shuttie commited on
Commit
a16d1d7
1 Parent(s): 7db4440

switch to cos loss with exp mapping

Browse files
Files changed (4) hide show
  1. .gitignore +3 -1
  2. finetune_cos.py +87 -0
  3. pytorch_model.bin +1 -1
  4. pytorch_model.onnx +1 -1
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  .venv
2
- venv
 
 
 
1
  .venv
2
+ venv
3
+ cos-exp
4
+ cos-lin
finetune_cos.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample, CrossEncoder
2
+ from torch import nn
3
+ import csv
4
+ from torch.utils.data import DataLoader, Dataset
5
+ import torch
6
+ from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator, SimilarityFunction, RerankingEvaluator
7
+ from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
8
+ import logging
9
+ import json
10
+ import random
11
+ import gzip
12
+
13
+ model_name = 'sentence-transformers/all-MiniLM-L6-v2'
14
+
15
+ train_batch_size = 100
16
+ max_seq_length = 128
17
+ num_epochs = 1
18
+ warmup_steps = 1000
19
+ model_save_path = 'cos-exp'
20
+ lr = 2e-5
21
+
22
+ class ESCIDataset(Dataset):
23
+ def __init__(self, input):
24
+ self.queries = []
25
+ with gzip.open(input) as jsonfile:
26
+ for line in jsonfile.readlines():
27
+ query = json.loads(line)
28
+ for p in query['e']:
29
+ positive = p['title']
30
+ self.queries.append(InputExample(texts=[query['query'], positive], label=1.0))
31
+ for p in query['s']:
32
+ positive = p['title']
33
+ self.queries.append(InputExample(texts=[query['query'], positive], label=0.1))
34
+ for p in query['c']:
35
+ positive = p['title']
36
+ self.queries.append(InputExample(texts=[query['query'], positive], label=0.01))
37
+ for p in query['i']:
38
+ positive = p['title']
39
+ self.queries.append(InputExample(texts=[query['query'], positive], label=0.0))
40
+
41
+ def __getitem__(self, item):
42
+ return self.queries[item]
43
+
44
+ def __len__(self):
45
+ return len(self.queries)
46
+
47
+
48
+ model = SentenceTransformer(model_name, device='cpu')
49
+ model.max_seq_length = max_seq_length
50
+
51
+
52
+ train_dataset = ESCIDataset(input='train-small.json.gz')
53
+ eval_dataset = ESCIDataset(input='test-small.json.gz')
54
+ train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
55
+ train_loss = losses.CosineSimilarityLoss(model=model)
56
+
57
+ # samples = {}
58
+ # for query in eval_dataset.queries:
59
+ # qstr = query.texts[0]
60
+ # sample = samples.get(qstr, {'query': qstr})
61
+ # positive = sample.get('positive', [])
62
+ # positive.append(query.texts[1])
63
+ # sample['positive'] = positive
64
+ # negative = sample.get('negative', [])
65
+ # negative.append(query.texts[2])
66
+ # sample['negative'] = negative
67
+ # samples[qstr] = sample
68
+
69
+ # evaluator = RerankingEvaluator(samples=samples,name='esci')
70
+
71
+ # Train the model
72
+
73
+ model.fit(train_objectives=[(train_dataloader, train_loss)],
74
+ epochs=num_epochs,
75
+ warmup_steps=warmup_steps,
76
+ use_amp=True,
77
+ # checkpoint_path=model_save_path,
78
+ # checkpoint_save_steps=len(train_dataloader),
79
+ optimizer_params = {'lr': lr},
80
+ # evaluator=evaluator,
81
+ # evaluation_steps=1000,
82
+ output_path=model_save_path
83
+ )
84
+
85
+ # Save the model
86
+
87
+ model.save(model_save_path)
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:32be7d4ded580fffe92bfff7e8dc865b17dc0cfbba3dd598865adc54dd89d0c3
3
  size 90891565
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c0cc55f150c9ab5404bc6b64cd1b9399b81bfe3cf083a1d30c89e9cd7d4235a
3
  size 90891565
pytorch_model.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b1abb2047d142c643ed10c33b6a1517171dab1ffb3d3a57bc2043437d3f5bf77
3
  size 90984263
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1264e3429772fc7d27c6c2f6e9bbf04e4358821fb380dfcff18e5f96d14f8f32
3
  size 90984263