# SimLM: Pre-training with Representation Bottleneck for Dense Passage Retrieval

Paper available at [https://arxiv.org/pdf/2207.02578](https://arxiv.org/pdf/2207.02578)

Code available at [https://github.com/microsoft/unilm/tree/master/simlm](https://github.com/microsoft/unilm/tree/master/simlm)

## Paper abstract

In this paper, we propose SimLM (Similarity matching with Language Model pre-training), a simple yet effective pre-training method for dense passage retrieval.
It employs a simple bottleneck architecture that learns to compress the passage information into a dense vector through self-supervised pre-training.
We use a replaced language modeling objective, which is inspired by ELECTRA, to improve the sample efficiency and reduce the mismatch of the input distribution between pre-training and fine-tuning.
SimLM only requires access to an unlabeled corpus, and is more broadly applicable when there are no labeled data or queries.
We conduct experiments on several large-scale passage retrieval datasets, and show substantial improvements over strong baselines under various settings.
Remarkably, SimLM even outperforms multi-vector approaches such as ColBERTv2, which incur significantly higher storage costs.
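
The repository linked above contains the actual pre-training code; the sketch below is only a rough illustration of the bottleneck idea described in the abstract. All module names and sizes are invented for this example, and the corruption step (the paper uses an ELECTRA-style generator to replace tokens) is assumed to happen outside the snippet:

```python
import torch
import torch.nn as nn

# Illustrative sizes only; they do not match the released checkpoints.
VOCAB_SIZE, HIDDEN = 30522, 768

class BottleneckSketch(nn.Module):
    """Toy sketch: compress a corrupted passage into one vector, then decode from it."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, HIDDEN)
        # Deep encoder: its position-0 ([CLS]) output acts as the representation bottleneck.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(HIDDEN, nhead=12, batch_first=True), num_layers=12)
        # Shallow decoder: too weak to recover the passage alone, so it must lean on the bottleneck.
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(HIDDEN, nhead=12, batch_first=True), num_layers=2)
        self.lm_head = nn.Linear(HIDDEN, VOCAB_SIZE)

    def forward(self, corrupted_ids: torch.Tensor, original_ids: torch.Tensor) -> torch.Tensor:
        # Encode the corrupted passage and keep only the [CLS] vector as the bottleneck.
        bottleneck = self.encoder(self.embed(corrupted_ids))[:, :1, :]
        # The decoder sees the bottleneck in place of [CLS], plus the corrupted token embeddings.
        dec_input = torch.cat([bottleneck, self.embed(corrupted_ids)[:, 1:, :]], dim=1)
        logits = self.lm_head(self.decoder(dec_input))
        # Replaced language modeling: predict the original token at every position,
        # not only at masked positions.
        return nn.functional.cross_entropy(logits.reshape(-1, VOCAB_SIZE),
                                           original_ids.reshape(-1))
```

After pre-training, the decoder is discarded and the encoder is fine-tuned as the dense retriever; that fine-tuned encoder is what this model card provides.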

## Results on MS-MARCO passage ranking task

| Model | dev MRR@10 | dev R@50 | dev R@1k | TREC DL 2019 nDCG@10 | TREC DL 2020 nDCG@10 |
|--|---|---|---|---|---|
| RocketQAv2 | 38.8 | 86.2 | 98.1 | - | - |
| coCondenser | 38.2 | 86.5 | 98.4 | 71.7 | 68.4 |
| ColBERTv2 | 39.7 | 86.8 | 98.4 | - | - |
| **SimLM (this model)** | 41.1 | 87.8 | 98.7 | 71.4 | 69.7 |

## Usage

Get embeddings from our fine-tuned model:
```python
import torch
from transformers import AutoModel, AutoTokenizer, BatchEncoding, PreTrainedTokenizerFast
from transformers.modeling_outputs import BaseModelOutput

def l2_normalize(x: torch.Tensor):
    return torch.nn.functional.normalize(x, p=2, dim=-1)

def encode_query(tokenizer: PreTrainedTokenizerFast, query: str) -> BatchEncoding:
    return tokenizer(query,
                     max_length=32,
                     padding=True,
                     truncation=True,
                     return_tensors='pt')

def encode_passage(tokenizer: PreTrainedTokenizerFast, passage: str, title: str = '-') -> BatchEncoding:
    return tokenizer(title,
                     text_pair=passage,
                     max_length=144,
                     padding=True,
                     truncation=True,
                     return_tensors='pt')

tokenizer = AutoTokenizer.from_pretrained('intfloat/simlm-base-msmarco-finetuned')
model = AutoModel.from_pretrained('intfloat/simlm-base-msmarco-finetuned')
model.eval()

with torch.no_grad():
    query_batch_dict = encode_query(tokenizer, 'what is qa')
    outputs: BaseModelOutput = model(**query_batch_dict, return_dict=True)
    query_embedding = l2_normalize(outputs.last_hidden_state[0, 0, :])

    psg1 = 'Quality assurance (QA) is a process-centered approach to ensuring that a company or organization is providing the best possible products or services. It is related to quality control, which focuses on the end result, such as testing a sample of items from a batch after production.'
    psg1_batch_dict = encode_passage(tokenizer, psg1)
    outputs: BaseModelOutput = model(**psg1_batch_dict, return_dict=True)
    psg1_embedding = l2_normalize(outputs.last_hidden_state[0, 0, :])

    psg2 = 'The Super Bowl is typically four hours long. The game itself takes about three and a half hours, with a 30 minute halftime show built in.'
    psg2_batch_dict = encode_passage(tokenizer, psg2)
    outputs: BaseModelOutput = model(**psg2_batch_dict, return_dict=True)
    psg2_embedding = l2_normalize(outputs.last_hidden_state[0, 0, :])

# Higher cosine similarity means they are more relevant
print(query_embedding.dot(psg1_embedding), query_embedding.dot(psg2_embedding))
```
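
Since the embeddings are L2-normalized, the dot products printed above are cosine similarities, so the quality-assurance passage should score clearly higher than the Super Bowl passage for the query 'what is qa'. When scoring more than a handful of passages, it is usually more convenient to encode them as a batch. The helper below is not part of the repository; it is a small sketch that continues from the snippet above and reuses the same tokenizer, model, and title/passage input format:

```python
from typing import List

def encode_passages(tokenizer: PreTrainedTokenizerFast, passages: List[str]) -> BatchEncoding:
    # Batched version of encode_passage: a placeholder '-' title paired with each passage,
    # padded to the longest passage in the batch.
    return tokenizer(['-'] * len(passages),
                     text_pair=passages,
                     max_length=144,
                     padding=True,
                     truncation=True,
                     return_tensors='pt')

passages = [psg1, psg2]
with torch.no_grad():
    batch_dict = encode_passages(tokenizer, passages)
    outputs: BaseModelOutput = model(**batch_dict, return_dict=True)
    # L2-normalized [CLS] vector of every passage: shape [num_passages, hidden_size]
    passage_embeddings = l2_normalize(outputs.last_hidden_state[:, 0, :])

# One cosine similarity score per passage for the query encoded earlier
scores = passage_embeddings @ query_embedding
print(scores)
```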

## Citation

```bibtex
@article{Wang2022SimLMPW,
  title={SimLM: Pre-training with Representation Bottleneck for Dense Passage Retrieval},
  author={Liang Wang and Nan Yang and Xiaolong Huang and Binxing Jiao and Linjun Yang and Daxin Jiang and Rangan Majumder and Furu Wei},
  journal={ArXiv},
  year={2022},
  volume={abs/2207.02578}
}
```