neofung committed on
Commit
cb971c7
1 Parent(s): c8a5fc3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +133 -1
README.md CHANGED
@@ -62,4 +62,136 @@ model-index:
62
  value: 68.83696915006163
63
  - type: mrr
64
  value: 79.77644651857584
65
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  value: 68.83696915006163
63
  - type: mrr
64
  value: 79.77644651857584
65
+ ---
66
+
67
+ ## Introduction
68
+
69
+ This model is fine-tuned from [Qwen/Qwen2-1.5B](https://huggingface.co/Qwen/Qwen2-1.5B) for a downstream reranking task.
70
+ We build on the work of the [FlagEmbedding reranker](https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/reranker),
71
+ and implement it with Qwen2-1.5B as the pretrained model.
72
+
73
+ ## Usage
74
+
75
+ ```python
76
+ from typing import cast, List, Union, Tuple, Dict, Optional
77
+ import numpy as np
78
+ import torch
79
+ from tqdm import tqdm
80
+ import transformers
81
+ from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, DataCollatorWithPadding
82
+ from transformers.models.qwen2 import Qwen2Config, Qwen2ForSequenceClassification
83
+ from transformers.trainer_pt_utils import LabelSmoother
84
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
85
+
86
def preprocess(
    sources,
    tokenizer: "transformers.PreTrainedTokenizer",
    max_len: int = 1024,
) -> Dict:
    """Render each source with the chat template and tokenize it.

    Each element of ``sources`` is an iterable of strings (for example a
    ``[query, passage]`` pair). Its parts are joined with a blank line, sent
    as a single ``user`` message through ``tokenizer.apply_chat_template``
    (with ``add_generation_prompt=True``), and the rendered text is tokenized.

    Sequences longer than ``max_len`` are shortened to exactly ``max_len``
    tokens by dropping tokens from the middle: the last 5 tokens are always
    kept and the head is cut to fit.

    Args:
        sources: Iterable of string iterables, one per example.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        max_len: Maximum number of tokens per example.

    Returns:
        A dict with ``input_ids`` and ``attention_mask``; each is a list with
        one (unpadded) token list per source.

    Raises:
        ValueError: If an example has to be truncated but ``max_len`` is
            smaller than the 5-token tail that must be preserved.
    """
    # Trailing tokens that survive truncation.
    # NOTE(review): presumably the generation-prompt suffix appended by the
    # chat template -- confirm against the tokenizer's template.
    tail_len = 5

    input_ids, attention_masks = [], []
    for source in sources:
        messages = [{"role": "user", "content": "\n\n".join(source)}]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = tokenizer([text])
        input_id = model_inputs['input_ids'][0]
        attention_mask = model_inputs['attention_mask'][0]

        if len(input_id) > max_len:
            head_len = max_len - tail_len
            if head_len < 0:
                # Explicit error instead of an assert: asserts vanish under
                # ``python -O`` and an over-long sequence would slip through.
                raise ValueError(
                    f"max_len={max_len} is smaller than the {tail_len} "
                    "trailing tokens that must be kept when truncating"
                )
            input_id = input_id[:head_len] + input_id[-tail_len:]
            attention_mask = attention_mask[:head_len] + attention_mask[-tail_len:]

        input_ids.append(input_id)
        attention_masks.append(attention_mask)

    return dict(
        input_ids=input_ids,
        attention_mask=attention_masks,
    )
116
+
117
class FlagRerankerCustom:
    """Score text pairs with a sequence-classification model.

    Wraps a model and its tokenizer, moves the model to CUDA, MPS or CPU
    (checked in that order), switches it to eval mode, and exposes
    ``compute_score`` for batched scoring.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        use_fp16: bool = False
    ) -> None:
        """Place ``model`` on the selected device and prepare batching.

        Args:
            model: Model whose output exposes ``.logits``.
            tokenizer: Tokenizer used for templating, tokenizing and padding.
            use_fp16: Convert the model to half precision. Ignored when
                falling back to CPU.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')
            # Half precision is switched off on the CPU fallback.
            use_fp16 = False
        if use_fp16:
            self.model.half()

        self.model = self.model.to(self.device)
        self.model.eval()

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    @torch.no_grad()
    def compute_score(
        self,
        sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
        batch_size: int = 64,
        max_length: int = 1024,
    ) -> Union[float, List[float]]:
        """Return the model score for one or more text pairs.

        Args:
            sentence_pairs: Either a single pair of strings or a list/tuple
                of such pairs.
            batch_size: Pairs per forward pass; multiplied by the number of
                visible CUDA devices when there is at least one.
            max_length: Token budget per pair, forwarded to ``preprocess``.

        Returns:
            A single float when exactly one score was produced, otherwise a
            list with one score per pair (an empty list for empty input).

        Raises:
            TypeError: If ``sentence_pairs`` is neither a list nor a tuple.
        """
        if not isinstance(sentence_pairs, (list, tuple)):
            raise TypeError(
                "sentence_pairs must be a list or tuple, got "
                f"{type(sentence_pairs).__name__}"
            )
        if len(sentence_pairs) == 0:
            return []
        # A bare pair of strings is treated as a batch of one.
        if isinstance(sentence_pairs[0], str):
            sentence_pairs = [sentence_pairs]

        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus

        all_scores = []
        for start_index in tqdm(range(0, len(sentence_pairs), batch_size), desc="Compute Scores",
                                disable=True):
            sentences_batch = sentence_pairs[start_index:start_index + batch_size]
            encoded = preprocess(sources=sentences_batch, tokenizer=self.tokenizer, max_len=max_length)
            # Dict of per-field lists -> list of per-example dicts for the collator.
            features = [dict(zip(encoded, row)) for row in zip(*encoded.values())]
            inputs = self.data_collator(features).to(self.device)
            logits = self.model(**inputs, return_dict=True).logits
            # Drop only the trailing label dimension. A bare ``squeeze()``
            # would also collapse a batch of one into a 0-d tensor, whose
            # ``tolist()`` is a float that ``extend`` cannot iterate.
            scores = logits.squeeze(-1)
            all_scores.extend(scores.detach().to(torch.float).cpu().numpy().tolist())

        if len(all_scores) == 1:
            return all_scores[0]
        return all_scores
172
+
173
# Hub repository that provides the tokenizer, config and weights.
MODEL_ID = "neofung/LdIR-Qwen2-reranker-1.5B-large"

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right")

config = Qwen2Config.from_pretrained(MODEL_ID, trust_remote_code=True, bf16=True)

model = Qwen2ForSequenceClassification.from_pretrained(
    MODEL_ID,
    config=config,
    trust_remote_code=True,
)

# From here on ``model`` refers to the reranker wrapper, not the raw network.
model = FlagRerankerCustom(model=model, tokenizer=tokenizer, use_fp16=False)

# Two (query, passage) pairs: an unrelated reply and a relevant definition.
pairs = [
    ['what is panda?', 'hi'],
    ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.'],
]

model.compute_score(pairs)

# [-2.655318021774292, 11.7670316696167]
197
+ ```