---
license: apache-2.0
language:
- en
- zh
library_name: transformers
tags:
- mteb
- RAG-reranking
model-index:
- name: LdIR-reranker-large
  results:
  - task:
      type: Reranking
    dataset:
      type: C-MTEB/CMedQAv1-reranking
      name: MTEB CMedQAv1
      config: default
      split: test
      revision: None
    metrics:
    - type: map
      value: 86.50438688414654
    - type: mrr
      value: 88.91170634920635
  - task:
      type: Reranking
    dataset:
      type: C-MTEB/CMedQAv2-reranking
      name: MTEB CMedQAv2
      config: default
      split: test
      revision: None
    metrics:
    - type: map
      value: 87.10592353383732
    - type: mrr
      value: 89.10178571428571
  - task:
      type: Reranking
    dataset:
      type: C-MTEB/Mmarco-reranking
      name: MTEB MMarcoReranking
      config: default
      split: dev
      revision: None
    metrics:
    - type: map
      value: 39.354813242907133
    - type: mrr
      value: 39.075793650793655
  - task:
      type: Reranking
    dataset:
      type: C-MTEB/T2Reranking
      name: MTEB T2Reranking
      config: default
      split: dev
      revision: None
    metrics:
    - type: map
      value: 68.83696915006163
    - type: mrr
      value: 79.77644651857584
---

## Introduction

This model is a reranker fine-tuned from [Qwen/Qwen2-1.5B](https://huggingface.co/Qwen/Qwen2-1.5B).
It builds on the work of the [FlagEmbedding reranker](https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/reranker),
using Qwen2-1.5B as the pretrained model.

## Usage

```python
from typing import Dict, List, Tuple, Union

import torch
from tqdm import tqdm
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer, DataCollatorWithPadding
from transformers.models.qwen2 import Qwen2Config, Qwen2ForSequenceClassification

def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int = 1024,
) -> Dict:

    # Apply prompt templates
    input_ids, attention_masks = [], []
    for i, source in enumerate(sources):
        # Join the (query, passage) pair into a single user message
        messages = [
            {"role": "user",
            "content": "\n\n".join(source)}
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([text])
        input_id = model_inputs['input_ids'][0]
        attention_mask = model_inputs['attention_mask'][0]
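        if len(input_id) > max_len:
            # Truncate from the right of the content: drop `diff` tokens just
            # before the last five tokens, so the end of the chat template
            # (the generation prompt) is kept intact.
            diff = len(input_id) - max_len
            input_id = input_id[:-5-diff] + input_id[-5:]
            attention_mask = attention_mask[:-5-diff] + attention_mask[-5:]
            assert len(input_id) == max_len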
        input_ids.append(input_id)
        attention_masks.append(attention_mask)

    return dict(
        input_ids=input_ids,
        attention_mask=attention_masks
    )

class FlagRerankerCustom:
    def __init__(
            self,
            model: PreTrainedModel,
            tokenizer: PreTrainedTokenizer,
            use_fp16: bool = False
    ) -> None:
        self.tokenizer = tokenizer
        self.model = model
        self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')
            use_fp16 = False
        if use_fp16:
            self.model.half()

        self.model = self.model.to(self.device)

        self.model.eval()

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    @torch.no_grad()
    def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 64,
                      max_length: int = 1024) -> List[float]:
        
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus

        assert isinstance(sentence_pairs, (list, tuple))
        # A single (query, passage) pair is wrapped into a one-element batch
        if isinstance(sentence_pairs[0], str):
            sentence_pairs = [sentence_pairs]

        all_scores = []
        for start_index in tqdm(range(0, len(sentence_pairs), batch_size), desc="Compute Scores",
                                disable=True):
            sentences_batch = sentence_pairs[start_index:start_index + batch_size]
            inputs = preprocess(sources=sentences_batch, tokenizer=self.tokenizer, max_len=max_length)
            inputs = [dict(zip(inputs, t)) for t in zip(*inputs.values())]
            inputs = self.data_collator(inputs).to(self.device)
            scores = self.model(**inputs, return_dict=True).logits
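            # Flatten to shape (batch,). A bare squeeze() yields a 0-dim tensor
            # for a single pair, which breaks the extend() call below.
            scores = scores.reshape(-1)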
            all_scores.extend(scores.detach().to(torch.float).cpu().numpy().tolist())

        if len(all_scores) == 1:
            return all_scores[0]
        return all_scores

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "neofung/LdIR-Qwen2-reranker-1.5B",
    padding_side="right",
)

config = Qwen2Config.from_pretrained(
    "neofung/LdIR-Qwen2-reranker-1.5B",
    trust_remote_code=True,
    bf16=True,
)

model = Qwen2ForSequenceClassification.from_pretrained(
    "neofung/LdIR-Qwen2-reranker-1.5B",
    config=config,
    trust_remote_code=True,
)

model = FlagRerankerCustom(model=model, tokenizer=tokenizer, use_fp16=False)

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]

model.compute_score(pairs)

# [-2.655318021774292, 11.7670316696167]
```
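
The scores above are raw logits, and in the example the relevant passage receives the higher value. A minimal sketch of turning such scores into a ranking follows. The helper name `rank_passages` is hypothetical and not part of this repository, and the scores are copied from the example output rather than recomputed.

```python
from typing import List, Tuple


def rank_passages(passages: List[str], scores: List[float]) -> List[Tuple[str, float]]:
    """Return (passage, score) pairs sorted from most to least relevant.

    Hypothetical helper for illustration; assumes a higher score means
    a more relevant passage, as in the example output above.
    """
    if len(passages) != len(scores):
        raise ValueError("passages and scores must have the same length")
    return sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)


passages = [
    "hi",
    "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear "
    "or simply panda, is a bear species endemic to China.",
]
# Scores copied from the example output above, not recomputed here.
scores = [-2.655318021774292, 11.7670316696167]

ranked = rank_passages(passages, scores)
for passage, score in ranked:
    print(f"{score:8.3f}  {passage[:60]}")
```

In a retrieval pipeline, the `scores` list would come from `compute_score` called on `[query, passage]` pairs that share the same query.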