nreimers commited on
Commit
2c3c6f5
1 Parent(s): cdce8e4
1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 768,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
README.md ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: sentence-similarity
3
+ tags:
4
+ - sentence-transformers
5
+ - feature-extraction
6
+ - sentence-similarity
7
+ - transformers
8
+ ---
9
+
10
+ # msmarco-bert-base-dot-v4
11
+ This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 768 dimensional dense vector space and was designed for **semantic search**. It has been trained on 500K (query, answer) pairs from the [MS MARCO dataset](https://github.com/microsoft/MSMARCO-Passage-Ranking/). For an introduction to semantic search, have a look at: [SBERT.net - Semantic Search](https://www.sbert.net/examples/applications/semantic-search/README.html)
12
+
13
+
14
+ ## Usage (Sentence-Transformers)
15
+ Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
16
+
17
+ ```
18
+ pip install -U sentence-transformers
19
+ ```
20
+
21
+ Then you can use the model like this:
22
+ ```python
23
+ from sentence_transformers import SentenceTransformer, util
24
+
25
+ query = "How many people live in London?"
26
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
27
+
28
+ #Load the model
29
+ model = SentenceTransformer('sentence-transformers/msmarco-bert-base-dot-v4')
30
+
31
+ #Encode query and documents
32
+ query_emb = model.encode(query)
33
+ doc_emb = model.encode(docs)
34
+
35
+ #Compute dot score between query and all document embeddings
36
+ scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()
37
+
38
+ #Combine docs & scores
39
+ doc_score_pairs = list(zip(docs, scores))
40
+
41
+ #Sort by decreasing score
42
+ doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
43
+
44
+ #Output passages & scores
45
+ print("Query:", query)
46
+ for doc, score in doc_score_pairs:
47
+ print(score, doc)
48
+ ```
49
+
50
+
51
+ ## Usage (HuggingFace Transformers)
52
+ Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: First, you pass your input through the transformer model, then you have to apply the correct pooling-operation on-top of the contextualized word embeddings.
53
+
54
+ ```python
55
+ from transformers import AutoTokenizer, AutoModel
56
+ import torch
57
+
58
+ #Mean Pooling - Take attention mask into account for correct averaging
59
+ def mean_pooling(model_output, attention_mask):
60
+ token_embeddings = model_output.last_hidden_state
61
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
62
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
63
+
64
+
65
+ #Encode text
66
+ def encode(texts):
67
+ # Tokenize sentences
68
+ encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
69
+
70
+ # Compute token embeddings
71
+ with torch.no_grad():
72
+ model_output = model(**encoded_input, return_dict=True)
73
+
74
+ # Perform pooling
75
+ embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
76
+
77
+ return embeddings
78
+
79
+
80
+ # Sentences we want sentence embeddings for
81
+ query = "How many people live in London?"
82
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
83
+
84
+ # Load model from HuggingFace Hub
85
+ tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/msmarco-bert-base-dot-v4")
86
+ model = AutoModel.from_pretrained("sentence-transformers/msmarco-bert-base-dot-v4")
87
+
88
+ #Encode query and docs
89
+ query_emb = encode(query)
90
+ doc_emb = encode(docs)
91
+
92
+ #Compute dot score between query and all document embeddings
93
+ scores = torch.mm(query_emb, doc_emb.transpose(0, 1))[0].cpu().tolist()
94
+
95
+ #Combine docs & scores
96
+ doc_score_pairs = list(zip(docs, scores))
97
+
98
+ #Sort by decreasing score
99
+ doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
100
+
101
+ #Output passages & scores
102
+ print("Query:", query)
103
+ for doc, score in doc_score_pairs:
104
+ print(score, doc)
105
+ ```
106
+
107
+ ## Technical Details
108
+
109
+ In the following some technical details how this model must be used:
110
+
111
+ | Setting | Value |
112
+ | --- | :---: |
113
+ | Dimensions | 768 |
114
+ | Max Sequence Length | 512 |
115
+ | Produces normalized embeddings | No |
116
+ | Pooling-Method | Mean pooling |
117
+ | Suitable score functions | dot-product (e.g. `util.dot_score`) |
118
+
119
+
120
+ ## Evaluation Results
121
+
122
+ <!--- Describe how your model was evaluated -->
123
+
124
+ For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=msmarco-bert-base-base-dot-v4)
125
+
126
+
127
+ ## Training
128
+
129
+ See `train_script.py` in this repository for the used training script.
130
+
131
+
132
+
133
+ The model was trained with the parameters:
134
+
135
+ **DataLoader**:
136
+
137
+ `torch.utils.data.dataloader.DataLoader` of length 7858 with parameters:
138
+ ```
139
+ {'batch_size': 64, 'sampler': 'torch.utils.data.sampler.RandomSampler', 'batch_sampler': 'torch.utils.data.sampler.BatchSampler'}
140
+ ```
141
+
142
+ **Loss**:
143
+
144
+ `sentence_transformers.losses.MarginMSELoss.MarginMSELoss`
145
+
146
+ Parameters of the fit()-Method:
147
+ ```
148
+ {
149
+ "callback": null,
150
+ "epochs": 30,
151
+ "evaluation_steps": 0,
152
+ "evaluator": "NoneType",
153
+ "max_grad_norm": 1,
154
+ "optimizer_class": "<class 'transformers.optimization.AdamW'>",
155
+ "optimizer_params": {
156
+ "lr": 1e-05
157
+ },
158
+ "scheduler": "WarmupLinear",
159
+ "steps_per_epoch": null,
160
+ "warmup_steps": 10000,
161
+ "weight_decay": 0.01
162
+ }
163
+ ```
164
+
165
+
166
+ ## Full Model Architecture
167
+ ```
168
+ SentenceTransformer(
169
+ (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: bert-base-uncased
170
+ (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
171
+ )
172
+ ```
173
+
174
+ ## Citing & Authors
175
+
176
+ This model was trained by [sentence-transformers](https://www.sbert.net/).
177
+
178
+ If you find this model helpful, feel free to cite our publication [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084):
179
+ ```bibtex
180
+ @inproceedings{reimers-2019-sentence-bert,
181
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
182
+ author = "Reimers, Nils and Gurevych, Iryna",
183
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
184
+ month = "11",
185
+ year = "2019",
186
+ publisher = "Association for Computational Linguistics",
187
+ url = "http://arxiv.org/abs/1908.10084",
188
+ }
189
+ ```
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "final-models/co-condenser-margin_mse-sym_mnrl-mean-v1/",
3
+ "architectures": [
4
+ "BertModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 768,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 3072,
13
+ "layer_norm_eps": 1e-12,
14
+ "max_position_embeddings": 512,
15
+ "model_type": "bert",
16
+ "num_attention_heads": 12,
17
+ "num_hidden_layers": 12,
18
+ "pad_token_id": 0,
19
+ "position_embedding_type": "absolute",
20
+ "transformers_version": "4.6.1",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 30522
24
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.0.0",
4
+ "transformers": "4.6.1",
5
+ "pytorch": "1.8.1"
6
+ }
7
+ }
modules.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ }
14
+ ]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbce66b05653369175bf318af64513fa2bf95f57782b850a6fa7f36c1723fd3c
3
+ size 438015479
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "max_seq_length": 512,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "do_basic_tokenize": true, "never_split": null, "model_max_length": 512, "name_or_path": "final-models/co-condenser-margin_mse-sym_mnrl-mean-v1/", "special_tokens_map_file": "/bos/tmp0/luyug/outputs/condenser/models/l2-s6-km-L128-e8-lr1e-4-b256/special_tokens_map.json"}
train_script.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sys
3
+ import json
4
+ from torch.utils.data import DataLoader
5
+ from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample
6
+ import logging
7
+ from datetime import datetime
8
+ import gzip
9
+ import os
10
+ import tarfile
11
+ from collections import defaultdict
12
+ from torch.utils.data import IterableDataset
13
+ import tqdm
14
+ from torch.utils.data import Dataset
15
+ import random
16
+ from shutil import copyfile
17
+
18
+ import argparse
19
+
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--train_batch_size", default=64, type=int)
22
+ parser.add_argument("--max_seq_length", default=300, type=int)
23
+ parser.add_argument("--model_name", required=True)
24
+ parser.add_argument("--max_passages", default=0, type=int)
25
+ parser.add_argument("--epochs", default=10, type=int)
26
+ parser.add_argument("--pooling", default="cls")
27
+ parser.add_argument("--negs_to_use", default=None, help="From which systems should negatives be used? Multiple systems seperated by comma. None = all")
28
+ parser.add_argument("--warmup_steps", default=1000, type=int)
29
+ parser.add_argument("--lr", default=2e-5, type=float)
30
+ parser.add_argument("--name", default='')
31
+ parser.add_argument("--num_negs_per_system", default=5, type=int)
32
+ parser.add_argument("--use_pre_trained_model", default=False, action="store_true")
33
+ parser.add_argument("--use_all_queries", default=False, action="store_true")
34
+ args = parser.parse_args()
35
+
36
+ print(args)
37
+
38
+ #### Just some code to print debug information to stdout
39
+ logging.basicConfig(format='%(asctime)s - %(message)s',
40
+ datefmt='%Y-%m-%d %H:%M:%S',
41
+ level=logging.INFO,
42
+ handlers=[LoggingHandler()])
43
+ #### /print debug information to stdout
44
+
45
+ # The model we want to fine-tune
46
+ train_batch_size = args.train_batch_size #Increasing the train batch size improves the model performance, but requires more GPU memory
47
+ model_name = args.model_name
48
+ max_passages = args.max_passages
49
+ max_seq_length = args.max_seq_length #Max length for passages. Increasing it, requires more GPU memory
50
+
51
+ num_negs_per_system = args.num_negs_per_system # We used different systems to mine hard negatives. Number of hard negatives to add from each system
52
+ num_epochs = args.epochs # Number of epochs we want to train
53
+
54
+ # We construct the SentenceTransformer bi-encoder from scratch
55
+ if args.use_pre_trained_model:
56
+ print("use pretrained SBERT model")
57
+ model = SentenceTransformer(model_name)
58
+ model.max_seq_length = max_seq_length
59
+ else:
60
+ print("Create new SBERT model")
61
+ word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
62
+ pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), args.pooling)
63
+ model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
64
+
65
+ model_save_path = f'output/train_bi-encoder-margin_mse_en-{args.name}-{model_name.replace("/", "-")}-batch_size_{train_batch_size}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
66
+
67
+
68
+ # Write self to path
69
+ os.makedirs(model_save_path, exist_ok=True)
70
+
71
+ train_script_path = os.path.join(model_save_path, 'train_script.py')
72
+ copyfile(__file__, train_script_path)
73
+ with open(train_script_path, 'a') as fOut:
74
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
75
+
76
+
77
+ ### Now we read the MS Marco dataset
78
+ data_folder = 'msmarco-data'
79
+
80
+ #### Read the corpus files, that contain all the passages. Store them in the corpus dict
81
+ corpus = {} #dict in the format: passage_id -> passage. Stores all existent passages
82
+ collection_filepath = os.path.join(data_folder, 'collection.tsv')
83
+ if not os.path.exists(collection_filepath):
84
+ tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
85
+ if not os.path.exists(tar_filepath):
86
+ logging.info("Download collection.tar.gz")
87
+ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)
88
+
89
+ with tarfile.open(tar_filepath, "r:gz") as tar:
90
+ tar.extractall(path=data_folder)
91
+
92
+ logging.info("Read corpus: collection.tsv")
93
+ with open(collection_filepath, 'r', encoding='utf8') as fIn:
94
+ for line in fIn:
95
+ pid, passage = line.strip().split("\t")
96
+ corpus[pid] = passage
97
+
98
+
99
+ ### Read the train queries, store in queries dict
100
+ queries = {} #dict in the format: query_id -> query. Stores all training queries
101
+ queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
102
+ if not os.path.exists(queries_filepath):
103
+ tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
104
+ if not os.path.exists(tar_filepath):
105
+ logging.info("Download queries.tar.gz")
106
+ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)
107
+
108
+ with tarfile.open(tar_filepath, "r:gz") as tar:
109
+ tar.extractall(path=data_folder)
110
+
111
+
112
+ with open(queries_filepath, 'r', encoding='utf8') as fIn:
113
+ for line in fIn:
114
+ qid, query = line.strip().split("\t")
115
+ queries[qid] = query
116
+
117
+
118
+ # Read our training file: msmarco-hard-negatives.jsonl.gz contains all queries and hard-negatives that were mined with different systems
119
+ # For each positive and mined-hard negative passage, we have a Cross-Encoder score from the cross-encoder/ms-marco-MiniLM-L-6-v2 model
120
+ # This Cross-Encoder score allows to de-noise our hard-negatives by requiring that their CE-score is below a certain treshold
121
+ train_filepath = '/home/msmarco/data/hard-negatives/msmarco-hard-negatives-v6.jsonl.gz'
122
+
123
+ #### Create our training data
124
+ logging.info("Read train dataset")
125
+ train_queries = {}
126
+ ce_scores = {}
127
+ negs_to_use = None
128
+ with gzip.open(train_filepath, 'rt') as fIn:
129
+ for line in tqdm.tqdm(fIn):
130
+ if max_passages > 0 and len(train_queries) >= max_passages:
131
+ break
132
+
133
+ data = json.loads(line)
134
+
135
+ if data['qid'] not in ce_scores:
136
+ ce_scores[data['qid']] = {}
137
+
138
+ # Add pos ce_scores
139
+ for item in data['pos'] :
140
+ ce_scores[data['qid']][item['pid']] = item['ce-score']
141
+
142
+ #Get the positive passage ids
143
+ pos_pids = [item['pid'] for item in data['pos']]
144
+
145
+ #Get the hard negatives
146
+ neg_pids = set()
147
+ if negs_to_use is None:
148
+ if args.negs_to_use is not None: #Use specific system for negatives
149
+ negs_to_use = args.negs_to_use.split(",")
150
+ else: #Use all systems
151
+ negs_to_use = list(data['neg'].keys())
152
+ print("Using negatives from the following systems:", negs_to_use)
153
+
154
+ for system_name in negs_to_use:
155
+ if system_name not in data['neg']:
156
+ continue
157
+
158
+ system_negs = data['neg'][system_name]
159
+
160
+ negs_added = 0
161
+ for item in system_negs:
162
+ #Add neg ce_scores
163
+ ce_scores[data['qid']][item['pid']] = item['ce-score']
164
+
165
+ pid = item['pid']
166
+ if pid not in neg_pids:
167
+ neg_pids.add(pid)
168
+ negs_added += 1
169
+ if negs_added >= num_negs_per_system:
170
+ break
171
+
172
+ if args.use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0):
173
+ train_queries[data['qid']] = {'qid': data['qid'], 'query': queries[data['qid']], 'pos': pos_pids, 'neg': neg_pids}
174
+
175
+ logging.info("Train queries: {}".format(len(train_queries)))
176
+
177
+ # We create a custom MSMARCO dataset that returns triplets (query, positive, negative)
178
+ # on-the-fly based on the information from the mined-hard-negatives jsonl file.
179
+ class MSMARCODataset(Dataset):
180
+ def __init__(self, queries, corpus, ce_scores):
181
+ self.queries = queries
182
+ self.queries_ids = list(queries.keys())
183
+ self.corpus = corpus
184
+ self.ce_scores = ce_scores
185
+
186
+ for qid in self.queries:
187
+ self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
188
+ self.queries[qid]['neg'] = list(self.queries[qid]['neg'])
189
+ random.shuffle(self.queries[qid]['neg'])
190
+
191
+ def __getitem__(self, item):
192
+ query = self.queries[self.queries_ids[item]]
193
+ query_text = query['query']
194
+ qid = query['qid']
195
+
196
+ if len(query['pos']) > 0:
197
+ pos_id = query['pos'].pop(0) #Pop positive and add at end
198
+ pos_text = self.corpus[pos_id]
199
+ query['pos'].append(pos_id)
200
+ else: #We only have negatives, use two negs
201
+ pos_id = query['neg'].pop(0) #Pop negative and add at end
202
+ pos_text = self.corpus[pos_id]
203
+ query['neg'].append(pos_id)
204
+
205
+ #Get a negative passage
206
+ neg_id = query['neg'].pop(0) #Pop negative and add at end
207
+ neg_text = self.corpus[neg_id]
208
+ query['neg'].append(neg_id)
209
+
210
+ pos_score = self.ce_scores[qid][pos_id]
211
+ neg_score = self.ce_scores[qid][neg_id]
212
+
213
+ return InputExample(texts=[query_text, pos_text, neg_text], label=pos_score-neg_score)
214
+
215
+ def __len__(self):
216
+ return len(self.queries)
217
+
218
+ # For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
219
+ train_dataset = MSMARCODataset(queries=train_queries, corpus=corpus, ce_scores=ce_scores)
220
+ train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, drop_last=True)
221
+ train_loss = losses.MarginMSELoss(model=model)
222
+
223
+ # Train the model
224
+ model.fit(train_objectives=[(train_dataloader, train_loss)],
225
+ epochs=num_epochs,
226
+ warmup_steps=args.warmup_steps,
227
+ use_amp=True,
228
+ checkpoint_path=model_save_path,
229
+ checkpoint_save_steps=10000,
230
+ checkpoint_save_total_limit = 0,
231
+ optimizer_params = {'lr': args.lr},
232
+ )
233
+
234
+ # Train latest model
235
+ model.save(model_save_path)
236
+
237
+
238
+ # Script was called via:
239
+ #python train_bi-encoder-margin_mse-en.py --model final-models/co-condenser-margin_mse-sym_mnrl-mean-v1 --lr=1e-5 --warmup_steps=10000 --negs_to_use=co-condenser-margin_mse-sym_mnrl-mean-v1 --num_negs_per_system=10 --epochs=30 --name=cnt_with_mined_negs_mean --use_pre_trained_model --train_batch_size 32
vocab.txt ADDED
The diff for this file is too large to render. See raw diff