cadurosar commited on
Commit
70f429b
1 Parent(s): 47c0876
Files changed (2) hide show
  1. README.md +52 -0
  2. pytorch_model.bin +3 -0
README.md CHANGED
@@ -1,3 +1,55 @@
1
  ---
2
  license: cc-by-nc-sa-4.0
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: cc-by-nc-sa-4.0
3
  ---
4
+
5
+ Our best attempt at reproducing [RankT5 Enc-Softmax](https://arxiv.org/pdf/2210.10634.pdf), with a few important differences:
6
+
7
+ 1. We use a SPLADE first stage for the negatives vs GTR on the paper
8
+ 2. We train using Pytorch vs Flaxx on the paper
9
+ 3. We use the original t5-3b vs Flan T5-3b on the paper
10
+
11
+ This leads to what seems to be a slightly worse performance (42.8 vs 43.? on the paper) and seems slightly worse on BEIR as well.
12
+
13
+ To use this model, first clone the huggingface repo
14
+
15
+ ```
16
+
17
+ ```
18
+
19
+
20
+ ```
21
+ import torch
22
+ from transformers import T5EncoderModel
23
+
24
+ class T5EncoderRerank(torch.nn.Module):
25
+ def __init__(self, model_type_or_dir,fp16=False, bf16=False):
26
+ """
27
+ model_type_or_dir is either the name of a pre-trained model (e.g. bert-base-uncased), or the path to
28
+ directory containing model weights, vocab etc.
29
+ """
30
+ super().__init__()
31
+ self.model = T5EncoderModel.from_pretrained(model_type_or_dir)
32
+ self.config = self.model.config
33
+ self.first_transform = torch.nn.Linear(self.config.d_model, self.config.d_model)
34
+ self.layer_norm = torch.nn.LayerNorm(self.config.d_model, eps=1e-12)
35
+ self.linear = torch.nn.Linear(self.config.d_model,1)
36
+
37
+ def forward(self, **kwargs):
38
+ result = self.model(**kwargs).last_hidden_state[:,0,:]
39
+ first_transformed = self.first_transform(result)
40
+ layer_normed = self.layer_norm(first_transformed)
41
+ logits = self.linear(layer_normed)
42
+ return SequenceClassifierOutput(
43
+ logits=logits
44
+ )
45
+
46
+
47
+ original_model="t5-3b"
48
+ path_checkpoint="trecdl22-crossencoder-rankT53b-repro/pytorch_model.bin"
49
+
50
+ print("Loading")
51
+ model = T5EncoderRerank(original_model,bf16=True)
52
+ model.load_state_dict(torch.load(path_checkpoint,map_location=torch.device("cpu")))
53
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
54
+ model.to(device)
55
+ ```
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e54e0aa04314ffc5aa2a07a23e890a932078254b88d2570797ffd15e0057e64e
3
+ size 4967947259