Kunyi commited on
Commit
6fbc593
1 Parent(s): be3eee7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +21 -22
README.md CHANGED
@@ -174,36 +174,35 @@ pip install -r requirements.txt
174
  ```
175
 
176
  ## Inference Code
177
- ```bash
178
- export PYTHONPATH=/yourpath/QA-CLIP-main
179
- ```
180
  Inference code example:
181
  ```python
182
- import torch
183
  from PIL import Image
 
 
184
 
185
- import clip as clip
186
- from clip import load_from_name, available_models
187
- print("Available models:", available_models())
188
- # Available models: ['ViT-B-16', 'ViT-L-14', 'RN50']
189
 
190
- device = "cuda" if torch.cuda.is_available() else "cpu"
191
- model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
192
- model.eval()
193
- image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
194
- text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)
195
 
196
- with torch.no_grad():
197
- image_features = model.encode_image(image)
198
- text_features = model.encode_text(text)
199
- # Normalize the features. Please use the normalized features for downstream tasks.
200
- image_features /= image_features.norm(dim=-1, keepdim=True)
201
- text_features /= text_features.norm(dim=-1, keepdim=True)
202
 
203
- logits_per_image, logits_per_text = model.get_similarity(image, text)
204
- probs = logits_per_image.softmax(dim=-1).cpu().numpy()
 
 
205
 
206
- print("Label probs:", probs)
 
 
 
 
207
  ```
208
  <br><br>
209
 
 
174
  ```
175
 
176
  ## Inference Code
 
 
 
177
  Inference code example:
178
  ```python
 
179
  from PIL import Image
180
+ import requests
181
+ from transformers import ChineseCLIPProcessor, ChineseCLIPModel
182
 
183
+ model = ChineseCLIPModel.from_pretrained("TencentARC/QA-CLIP-ViT-B-16")
184
+ processor = ChineseCLIPProcessor.from_pretrained("TencentARC/QA-CLIP-ViT-B-16")
 
 
185
 
186
+ url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
187
+ image = Image.open(requests.get(url, stream=True).raw)
188
+ # Squirtle, Bulbasaur, Charmander, Pikachu in English
189
+ texts = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]
 
190
 
191
+ # compute image feature
192
+ inputs = processor(images=image, return_tensors="pt")
193
+ image_features = model.get_image_features(**inputs)
194
+ image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True) # normalize
 
 
195
 
196
+ # compute text features
197
+ inputs = processor(text=texts, padding=True, return_tensors="pt")
198
+ text_features = model.get_text_features(**inputs)
199
+ text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True) # normalize
200
 
201
+ # compute image-text similarity scores
202
+ inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
203
+ outputs = model(**inputs)
204
+ logits_per_image = outputs.logits_per_image # this is the image-text similarity score
205
+ probs = logits_per_image.softmax(dim=1)
206
  ```
207
  <br><br>
208