cnclip_model.py
# Originally from https://github.com/OFA-Sys/Chinese-CLIP. MIT License.
import torch

from clip_server.model.clip_model import CLIPModel
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
from cn_clip.clip import load_from_name

# Maps clip-server model names to the checkpoint names used by `cn_clip`.
_CNCLIP_MODEL_MAPS = {
    'CN-CLIP/ViT-B-16': 'ViT-B-16',
    'CN-CLIP/ViT-L-14': 'ViT-L-14',
    'CN-CLIP/ViT-L-14-336': 'ViT-L-14-336',
    'CN-CLIP/ViT-H-14': 'ViT-H-14',
    'CN-CLIP/RN50': 'RN50',
}


class CNClipModel(CLIPModel):
    def __init__(
        self,
        name: str,
        device: str = 'cpu',
        jit: bool = False,
        dtype: str = None,
        **kwargs
    ):
        super().__init__(name, **kwargs)
        # Keep the clip-server name (e.g. 'CN-CLIP/ViT-B-16'); the
        # `model_name` and `image_size` properties below look it up in the
        # maps keyed by that name.
        self._name = name
        self._model, self._preprocess = load_from_name(
            _CNCLIP_MODEL_MAPS[name], device=device
        )
        self._model.eval()

    @staticmethod
    def get_model_name(name: str):
        # Translate a clip-server name into the underlying cn_clip name.
        return _CNCLIP_MODEL_MAPS[name]

    def encode_text(self, input_ids: 'torch.Tensor', **kwargs):
        # Token ids -> text embeddings, detached from the autograd graph.
        return self._model.encode_text(input_ids).detach()

    def encode_image(self, pixel_values: 'torch.Tensor', **kwargs):
        # Preprocessed pixel batch -> image embeddings, detached.
        return self._model.encode_image(pixel_values).detach()

    @property
    def model_name(self):
        return self.__class__.get_model_name(self._name)

    @property
    def image_size(self):
        return _VISUAL_MODEL_IMAGE_SIZE.get(self._name, None)
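

# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal example of exercising this wrapper end to end. It assumes the
# `cn_clip` package can fetch the 'CN-CLIP/ViT-B-16' checkpoint, that text
# is tokenized with cn_clip.clip.tokenize as in the upstream Chinese-CLIP
# examples, and that 'cat.jpg' is a hypothetical local image file.
if __name__ == '__main__':
    from PIL import Image
    from cn_clip.clip import tokenize

    model = CNClipModel('CN-CLIP/ViT-B-16', device='cpu')

    with torch.no_grad():
        # Encode a batch of Chinese queries ("a cat", "a dog").
        text_feats = model.encode_text(tokenize(['一只猫', '一只狗']))

        # Preprocess one image with the transform returned by
        # load_from_name, add a batch dimension, then encode it.
        pixel_values = model._preprocess(Image.open('cat.jpg')).unsqueeze(0)
        image_feats = model.encode_image(pixel_values)

    # L2-normalize both sides and compare with cosine similarity.
    text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
    image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
    print(image_feats @ text_feats.T)  # 1x2 image-to-text similarity matrix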