switch to cos loss with exp mapping
Browse files- .gitignore +3 -1
- finetune_cos.py +87 -0
- pytorch_model.bin +1 -1
- pytorch_model.onnx +1 -1
.gitignore
CHANGED
@@ -1,2 +1,4 @@
|
|
1 |
.venv
|
2 |
-
venv
|
|
|
|
|
|
1 |
.venv
|
2 |
+
venv
|
3 |
+
cos-exp
|
4 |
+
cos-lin
|
finetune_cos.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder
|
2 |
+
from torch import nn
|
3 |
+
import csv
|
4 |
+
from torch.utils.data import DataLoader, Dataset
|
5 |
+
import torch
|
6 |
+
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator
|
7 |
+
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
|
8 |
+
import logging
|
9 |
+
import json
|
10 |
+
import random
|
11 |
+
import gzip
|
12 |
+
|
13 |
+
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
|
14 |
+
|
15 |
+
train_batch_size = 100
|
16 |
+
max_seq_length = 128
|
17 |
+
num_epochs = 1
|
18 |
+
warmup_steps = 1000
|
19 |
+
model_save_path = 'cos-exp'
|
20 |
+
lr = 2e-5
|
21 |
+
|
22 |
+
class ESCIDataset(Dataset):
|
23 |
+
def __init__(self, input):
|
24 |
+
self.queries = []
|
25 |
+
with gzip.open(input) as jsonfile:
|
26 |
+
for line in jsonfile.readlines():
|
27 |
+
query = json.loads(line)
|
28 |
+
for p in query['e']:
|
29 |
+
positive = p['title']
|
30 |
+
self.queries.append(InputExample(texts=[query['query'], positive], label=1.0))
|
31 |
+
for p in query['s']:
|
32 |
+
positive = p['title']
|
33 |
+
self.queries.append(InputExample(texts=[query['query'], positive], label=0.1))
|
34 |
+
for p in query['c']:
|
35 |
+
positive = p['title']
|
36 |
+
self.queries.append(InputExample(texts=[query['query'], positive], label=0.01))
|
37 |
+
for p in query['i']:
|
38 |
+
positive = p['title']
|
39 |
+
self.queries.append(InputExample(texts=[query['query'], positive], label=0.0))
|
40 |
+
|
41 |
+
def __getitem__(self, item):
|
42 |
+
return self.queries[item]
|
43 |
+
|
44 |
+
def __len__(self):
|
45 |
+
return len(self.queries)
|
46 |
+
|
47 |
+
|
48 |
+
model = SentenceTransformer(model_name, device='cpu')
|
49 |
+
model.max_seq_length = max_seq_length
|
50 |
+
|
51 |
+
|
52 |
+
train_dataset = ESCIDataset(input='train-small.json.gz')
|
53 |
+
eval_dataset = ESCIDataset(input='test-small.json.gz')
|
54 |
+
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
|
55 |
+
train_loss = losses.CosineSimilarityLoss(model=model)
|
56 |
+
|
57 |
+
# samples = {}
|
58 |
+
# for query in eval_dataset.queries:
|
59 |
+
# qstr = query.texts[0]
|
60 |
+
# sample = samples.get(qstr, {'query': qstr})
|
61 |
+
# positive = sample.get('positive', [])
|
62 |
+
# positive.append(query.texts[1])
|
63 |
+
# sample['positive'] = positive
|
64 |
+
# negative = sample.get('negative', [])
|
65 |
+
# negative.append(query.texts[2])
|
66 |
+
# sample['negative'] = negative
|
67 |
+
# samples[qstr] = sample
|
68 |
+
|
69 |
+
# evaluator = RerankingEvaluator(samples=samples,name='esci')
|
70 |
+
|
71 |
+
# Train the model
|
72 |
+
|
73 |
+
model.fit(train_objectives=[(train_dataloader, train_loss)],
|
74 |
+
epochs=num_epochs,
|
75 |
+
warmup_steps=warmup_steps,
|
76 |
+
use_amp=True,
|
77 |
+
# checkpoint_path=model_save_path,
|
78 |
+
# checkpoint_save_steps=len(train_dataloader),
|
79 |
+
optimizer_params = {'lr': lr},
|
80 |
+
# evaluator=evaluator,
|
81 |
+
# evaluation_steps=1000,
|
82 |
+
output_path=model_save_path
|
83 |
+
)
|
84 |
+
|
85 |
+
# Save the model
|
86 |
+
|
87 |
+
model.save(model_save_path)
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 90891565
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6c0cc55f150c9ab5404bc6b64cd1b9399b81bfe3cf083a1d30c89e9cd7d4235a
|
3 |
size 90891565
|
pytorch_model.onnx
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 90984263
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1264e3429772fc7d27c6c2f6e9bbf04e4358821fb380dfcff18e5f96d14f8f32
|
3 |
size 90984263
|