anahita-b commited on
Commit
b1aa9eb
1 Parent(s): e54c068

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -5
README.md CHANGED
@@ -39,7 +39,6 @@ import requests
39
  from PIL import Image
40
  import torch
41
 
42
- device = torch.device('cuda')
43
  image_urls = [
44
  "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
45
     "http://images.cocodataset.org/val2017/000000039769.jpg"]
@@ -49,13 +48,13 @@ texts = [
49
  images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
50
 
51
  processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm")
52
- model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")model.to(device)
53
 
54
- inputs  = processor(images, texts, padding=True, return_tensors="pt").to(device)
55
  outputs = model(**inputs, labels=torch.ones(2,device=device))
56
 
57
- inputs  = processor(images, texts[::-1], padding=True, return_tensors="pt").to(device)
58
- outputs_swapped = model(**inputs, labels=torch.ones(2,device=device))
59
 
60
  print('Loss', outputs.loss.item())
61
  print('Loss with swapped images', outputs_swapped.loss.item())
 
39
  from PIL import Image
40
  import torch
41
 
 
42
  image_urls = [
43
  "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
44
     "http://images.cocodataset.org/val2017/000000039769.jpg"]
 
48
  images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
49
 
50
  processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm")
51
+ model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
52
 
53
+ inputs  = processor(images, texts, padding=True, return_tensors="pt")
54
  outputs = model(**inputs, labels=torch.ones(2,device=device))
55
 
56
+ inputs  = processor(images, texts[::-1], padding=True, return_tensors="pt")
57
+ outputs_swapped = model(**inputs, labels=torch.ones(2))
58
 
59
  print('Loss', outputs.loss.item())
60
  print('Loss with swapped images', outputs_swapped.loss.item())