Spaces:
Running
yjwtheonly
committed on
Commit 6ebf426
1 Parent(s): fce1f4b
server
Browse files
- server/__pycache__/server_utils.cpython-38.pyc +0 -0
- server/server.py +677 -0
- server/server_utils.py +89 -0
server/__pycache__/server_utils.cpython-38.pyc
ADDED
Binary file (1.97 kB)
server/server.py
ADDED
@@ -0,0 +1,677 @@
#%%
import gradio as gr
import time
import sys
import os
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import json
import networkx as nx
import spacy
import pickle as pkl
#%%

from torch.nn.modules.loss import CrossEntropyLoss
from transformers import AutoTokenizer
from transformers import BioGptForCausalLM, BartForConditionalGeneration

import server_utils

sys.path.append("..")
import Parameters
from Openai.chat import generate_abstract
sys.path.append("../DiseaseSpecific")
import utils, attack
from attack import calculate_edge_bound, get_model_loss_without_softmax


specific_model = None

def capitalize_the_first_letter(s):
    return s[0].upper() + s[1:]

parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser.add_argument('--init-mode', type = str, default='single', help = 'How to select target nodes') # 'single' for case study
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
args.device = device
args.device1 = device
if torch.cuda.device_count() >= 2:
    args.device = "cuda:0"
    args.device1 = "cuda:1"

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
model_path = '../DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, model_name)
data_path = os.path.join('../DiseaseSpecific/processed_data', args.data)
data = utils.load_data(os.path.join(data_path, 'all.txt'))

n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
    filters = pkl.load(fl)
with open(os.path.join(data_path, 'entityid_to_nodetype.json'), 'r') as fl:
    entityid_to_nodetype = json.load(fl)
with open(os.path.join(data_path, 'edge_nghbrs.pickle'), 'rb') as fl:
    edge_nghbrs = pkl.load(fl)
with open(os.path.join(data_path, 'disease_meshid.pickle'), 'rb') as fl:
    disease_meshid = pkl.load(fl)
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    entity_to_id = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(os.path.join(data_path, 'entities_reverse_dict.json'), 'r') as fl:
    id_to_entity = json.load(fl)
id_to_meshid = id_to_entity.copy()
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
    drug_term = pkl.load(fl)

drug_dict = {}
disease_dict = {}
for k, v in entity_raw_name.items():
    #chemical_mesh:c050048
    tp = k.split('_')[0]
    v = capitalize_the_first_letter(v)
    if len(v) <= 2:
        continue
    if tp == 'chemical':
        drug_dict[v] = k
    elif tp == 'disease':
        disease_dict[v] = k

drug_list = list(drug_dict.keys())
disease_list = list(disease_dict.keys())
drug_list.sort()
disease_list.sort()
init_mask = np.asarray([0] * n_ent).astype('int64')
init_mask = (init_mask == 1)
for k, v in filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        tmp[np.asarray(vv)] = True
        t = torch.ByteTensor(tmp).to(args.device)
        filters[k][kk] = t

gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
gpt_model.eval()

specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
specific_model.eval()
divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)

nlp = spacy.load("en_core_web_sm")

bart_model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
bart_model.eval()
bart_tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')

def tune_chatgpt(draft, attack_data, dpath):
    dpath_i = 0
    bart_model.to(args.device1)
    for i, v in enumerate(draft):

        input = v['in'].replace('\n', '')
        output = v['out'].replace('\n', '')
        s, r, o = attack_data[i]

        path_text = dpath[dpath_i].replace('\n', '')
        dpath_i += 1
        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]

        doc = nlp(output)
        words = input.split(' ')
        tokenized_sens = [sen for sen in doc.sents]
        sens = np.array([sen.text for sen in doc.sents])

        checkset = set([text_s, text_o])
        e_entity = set(['start_entity', 'end_entity'])
        for path in path_text.split(' '):
            a, b, c = path.split('|')
            if a not in e_entity:
                checkset.add(a)
            if c not in e_entity:
                checkset.add(c)
        vec = []
        l = 0
        while(l < len(words)):
            bo = False
            for j in range(len(words), l, -1): # reversing is important !!!
                cc = ' '.join(words[l:j])
                if (cc in checkset):
                    vec += [True] * (j-l)
                    l = j
                    bo = True
                    break
            if not bo:
                vec.append(False)
                l += 1
        vec, span = server_utils.find_mini_span(vec, words, checkset)
        # vec = np.vectorize(lambda x: x in checkset)(words)
        vec[-1] = True
        prompt = []
        mask_num = 0
        for j, bo in enumerate(vec):
            if not bo:
                mask_num += 1
            else:
                if mask_num > 0:
                    # mask_num = mask_num // 3 # span length ~ poisson distribution (lambda = 3)
                    mask_num = max(mask_num, 1)
                    mask_num = min(8, mask_num)
                    prompt += ['<mask>'] * mask_num
                prompt.append(words[j])
                mask_num = 0
        prompt = ' '.join(prompt)
        Text = []
        Assist = []

        for j in range(len(sens)):
            Bart_input = list(sens[:j]) + [prompt] + list(sens[j+1:])
            assist = list(sens[:j]) + [input] + list(sens[j+1:])
            Text.append(' '.join(Bart_input))
            Assist.append(' '.join(assist))

        for j in range(len(sens)):
            Bart_input = server_utils.mask_func(tokenized_sens[:j]) + [input] + server_utils.mask_func(tokenized_sens[j+1:])
            assist = list(sens[:j]) + [input] + list(sens[j+1:])
            Text.append(' '.join(Bart_input))
            Assist.append(' '.join(assist))

    batch_size = 8
    Outs = []
    for l in range(0, len(Text), batch_size):
        R = min(len(Text), l + batch_size)
        A = bart_tokenizer(Text[l:R],
                           truncation = True,
                           padding = True,
                           max_length = 1024,
                           return_tensors="pt")
        input_ids = A['input_ids'].to(args.device1)
        attention_mask = A['attention_mask'].to(args.device1)
        aaid = bart_model.generate(input_ids, attention_mask = attention_mask, num_beams = 5, max_length = 1024)
        outs = bart_tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        Outs += outs
    bart_model.to('cpu')
    return span, prompt, Outs, Text, Assist

+
def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, v):
|
212 |
+
|
213 |
+
criterion = CrossEntropyLoss(reduction="none")
|
214 |
+
text_s = entity_raw_name[id_to_meshid[s]]
|
215 |
+
text_o = entity_raw_name[id_to_meshid[o]]
|
216 |
+
|
217 |
+
sen_list = [server_utils.process(text) for text in sen_list]
|
218 |
+
path_text = dpath[0].replace('\n', '')
|
219 |
+
|
220 |
+
checkset = set([text_s, text_o])
|
221 |
+
e_entity = set(['start_entity', 'end_entity'])
|
222 |
+
for path in path_text.split(' '):
|
223 |
+
a, b, c = path.split('|')
|
224 |
+
if a not in e_entity:
|
225 |
+
checkset.add(a)
|
226 |
+
if c not in e_entity:
|
227 |
+
checkset.add(c)
|
228 |
+
|
229 |
+
input = v['in'].replace('\n', '')
|
230 |
+
output = v['out'].replace('\n', '')
|
231 |
+
|
232 |
+
doc = nlp(output)
|
233 |
+
gpt_sens = [sen.text for sen in doc.sents]
|
234 |
+
assert len(gpt_sens) == len(sen_list) // 2
|
235 |
+
|
236 |
+
word_sets = []
|
237 |
+
for sen in gpt_sens:
|
238 |
+
word_sets.append(set(sen.split(' ')))
|
239 |
+
|
240 |
+
def sen_align(word_sets, modified_word_sets):
|
241 |
+
|
242 |
+
l = 0
|
243 |
+
while(l < len(modified_word_sets)):
|
244 |
+
if len(word_sets[l].intersection(modified_word_sets[l])) > len(word_sets[l]) * 0.8:
|
245 |
+
l += 1
|
246 |
+
else:
|
247 |
+
break
|
248 |
+
if l == len(modified_word_sets):
|
249 |
+
return -1, -1, -1, -1
|
250 |
+
r = l + 1
|
251 |
+
r1 = None
|
252 |
+
r2 = None
|
253 |
+
for pos1 in range(r, len(word_sets)):
|
254 |
+
for pos2 in range(r, len(modified_word_sets)):
|
255 |
+
if len(word_sets[pos1].intersection(modified_word_sets[pos2])) > len(word_sets[pos1]) * 0.8:
|
256 |
+
r1 = pos1
|
257 |
+
r2 = pos2
|
258 |
+
break
|
259 |
+
if r1 is not None:
|
260 |
+
break
|
261 |
+
if r1 is None:
|
262 |
+
r1 = len(word_sets)
|
263 |
+
r2 = len(modified_word_sets)
|
264 |
+
return l, r1, l, r2
|
265 |
+
|
266 |
+
replace_sen_list = []
|
267 |
+
boundary = []
|
268 |
+
assert len(sen_list) % 2 == 0
|
269 |
+
for j in range(len(sen_list) // 2):
|
270 |
+
doc = nlp(sen_list[j])
|
271 |
+
sens = [sen.text for sen in doc.sents]
|
272 |
+
modified_word_sets = [set(sen.split(' ')) for sen in sens]
|
273 |
+
l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
|
274 |
+
boundary.append((l1, r1, l2, r2))
|
275 |
+
if l1 == -1:
|
276 |
+
replace_sen_list.append(sen_list[j])
|
277 |
+
continue
|
278 |
+
check_text = ' '.join(sens[l2: r2])
|
279 |
+
replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
|
280 |
+
sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]
|
281 |
+
|
282 |
+
gpt_model.to(args.device1)
|
283 |
+
sen_list.append(output)
|
284 |
+
tokens = gpt_tokenizer( sen_list,
|
285 |
+
truncation = True,
|
286 |
+
padding = True,
|
287 |
+
max_length = 1024,
|
288 |
+
return_tensors="pt")
|
289 |
+
target_ids = tokens['input_ids'].to(args.device1)
|
290 |
+
attention_mask = tokens['attention_mask'].to(args.device1)
|
291 |
+
L = len(sen_list)
|
292 |
+
ret_log_L = []
|
293 |
+
for l in range(0, L, 5):
|
294 |
+
R = min(L, l + 5)
|
295 |
+
target = target_ids[l:R, :]
|
296 |
+
attention = attention_mask[l:R, :]
|
297 |
+
outputs = gpt_model(input_ids = target,
|
298 |
+
attention_mask = attention,
|
299 |
+
labels = target)
|
300 |
+
logits = outputs.logits
|
301 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
302 |
+
shift_labels = target[..., 1:].contiguous()
|
303 |
+
Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
|
304 |
+
Loss = Loss.view(-1, shift_logits.shape[1])
|
305 |
+
attention = attention[..., 1:].contiguous()
|
306 |
+
log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
|
307 |
+
ret_log_L.append(log_Loss.detach())
|
308 |
+
log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
|
309 |
+
gpt_model.to('cpu')
|
310 |
+
p = np.argmin(log_Loss)
|
311 |
+
return sen_list[p]
|
312 |
+
|
def generate_template_for_triplet(attack_data):

    criterion = CrossEntropyLoss(reduction="none")
    gpt_model.to(args.device1)
    print('Generating template ...')

    GPT_batch_size = 8
    single_sentence = []
    test_text = []
    test_dp = []
    test_parse = []
    s, r, o = attack_data[0]
    dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
    candidate_sen = []
    Dp_path = []
    L = len(dependency_sen_dict.keys())
    bound = 500 // L
    if bound == 0:
        bound = 1
    for dp_path, sen_list in dependency_sen_dict.items():
        if len(sen_list) > bound:
            index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
            sen_list = [sen_list[aa] for aa in index]
        ssen_list = []
        for aa in range(len(sen_list)):
            paper_id, sen_id = sen_list[aa]
            if raw_text_sen[paper_id][sen_id]['start_formatted'] == raw_text_sen[paper_id][sen_id]['end_formatted']:
                continue
            ssen_list.append(sen_list[aa])
        sen_list = ssen_list
        candidate_sen += sen_list
        Dp_path += [dp_path] * len(sen_list)

    text_s = entity_raw_name[id_to_meshid[s]]
    text_o = entity_raw_name[id_to_meshid[o]]
    candidate_text_sen = []
    candidate_ori_sen = []
    candidate_parse_sen = []

    for paper_id, sen_id in candidate_sen:
        sen = raw_text_sen[paper_id][sen_id]
        text = sen['text']
        candidate_ori_sen.append(text)
        ss = sen['start_formatted']
        oo = sen['end_formatted']
        text = text.replace('-LRB-', '(')
        text = text.replace('-RRB-', ')')
        text = text.replace('-LSB-', '[')
        text = text.replace('-RSB-', ']')
        text = text.replace('-LCB-', '{')
        text = text.replace('-RCB-', '}')
        parse_text = text
        parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
        parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
        text = text.replace(ss, text_s)
        text = text.replace(oo, text_o)
        text = text.replace('_', ' ')
        candidate_text_sen.append(text)
        candidate_parse_sen.append(parse_text)
    tokens = gpt_tokenizer(candidate_text_sen,
                           truncation = True,
                           padding = True,
                           max_length = 300,
                           return_tensors="pt")
    target_ids = tokens['input_ids'].to(args.device1)
    attention_mask = tokens['attention_mask'].to(args.device1)

    L = len(candidate_text_sen)
    assert L > 0
    ret_log_L = []
    for l in range(0, L, GPT_batch_size):
        R = min(L, l + GPT_batch_size)
        target = target_ids[l:R, :]
        attention = attention_mask[l:R, :]
        outputs = gpt_model(input_ids = target,
                            attention_mask = attention,
                            labels = target)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = target[..., 1:].contiguous()
        Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
        Loss = Loss.view(-1, shift_logits.shape[1])
        attention = attention[..., 1:].contiguous()
        log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
        ret_log_L.append(log_Loss.detach())

    ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
    sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
    sen_score.sort(key = lambda x: x[1])
    test_text.append(sen_score[0][2])
    test_dp.append(sen_score[0][3])
    test_parse.append(sen_score[0][4])
    single_sentence.append(sen_score[0][0])

    gpt_model.to('cpu')
    return single_sentence, test_text, test_dp, test_parse


meshids = list(id_to_meshid.values())
cal = {
    'chemical' : 0,
    'disease' : 0,
    'gene' : 0
}
for meshid in meshids:
    cal[meshid.split('_')[0]] += 1

def check_reasonable(s, r, o):

    train_trip = np.asarray([[s, r, o]])
    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
    # edge_losse_log_prob = torch.log(F.softmax(-edge_loss, dim = -1))

    edge_loss = edge_loss.item()
    edge_loss = (edge_loss - data_mean) / data_std
    edge_losses_prob = 1 / (1 + np.exp(edge_loss - divide_bound))
    bound = 1 - args.reasonable_rate

    return (edge_losses_prob > bound), edge_losses_prob

edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for k, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
        edgeid_to_edgetype[str(iid)] = k
        edgeid_to_reversemask[str(iid)] = mask
reverse_tot = 0
G = nx.DiGraph()
for s, r, o in data:
    assert id_to_meshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
    if edgeid_to_reversemask[r] == 1:
        reverse_tot += 1
        G.add_edge(int(o), int(s))
    else:
        G.add_edge(int(s), int(o))

print('Page ranking ...')
pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)

drug_meshid = []
drug_list = []
for meshid, nm in entity_raw_name.items():
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.append(meshid)
        drug_list.append(capitalize_the_first_letter(nm))
drug_list = list(set(drug_list))
drug_list.sort()
drug_meshid = set(drug_meshid)
pr = list(pagerank_value_1.items())
pr.sort(key = lambda x: x[1])
sorted_rank = { 'chemical' : [],
                'gene' : [],
                'disease': [],
                'merged' : []}
for iid, score in pr:
    tp = id_to_meshid[str(iid)].split('_')[0]
    if tp == 'chemical':
        if id_to_meshid[str(iid)] in drug_meshid:
            sorted_rank[tp].append((iid, score))
    else:
        sorted_rank[tp].append((iid, score))
    sorted_rank['merged'].append((iid, score))
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4 : ]

def generate_specific_attack_edge(start_entity, end_entity):

    global specific_model

    specific_model.to(device)
    strat_meshid = drug_dict[start_entity]
    end_meshid = disease_dict[end_entity]
    start_entity = entity_to_id[strat_meshid]
    end_entity = entity_to_id[end_meshid]
    target_data = np.array([[start_entity, '10', end_entity]])
    neighbors = attack.generate_nghbrs(target_data, edge_nghbrs, args)
    ret = f'Generating malicious link for {strat_meshid}_treatment_{end_meshid}', 'Generation malicious text ...'
    param_optimizer = list(specific_model.named_parameters())
    param_influence = []
    for n, p in param_optimizer:
        param_influence.append(p)
    len_list = []
    for v in neighbors.values():
        len_list.append(len(v))
    mean_len = np.mean(len_list)
    attack_trip, score_record = attack.addition_attack(param_influence, args.device, n_rel, data, target_data, neighbors, specific_model, filters, entityid_to_nodetype, args.attack_batch_size, args, load_Record = args.load_existed, divide_bound = divide_bound, data_mean = data_mean, data_std = data_std, cache_intermidiate = False)
    s, r, o = attack_trip[0]
    specific_model.to('cpu')
    return s, r, o

def generate_agnostic_attack_edge(targets):

    specific_model.to(device)
    attack_edge_list = []
    for target in targets:
        candidate_list = []
        score_list = []
        loss_list = []
        main_dict = {}
        for iid, score in sorted_rank['merged']:
            a = G.number_of_edges(iid, target) + 1
            if a != 1:
                continue
            b = G.out_degree(iid) + 1
            tp = id_to_meshid[str(iid)].split('_')[0]
            edge_losses = []
            r_list = []
            for r in range(len(edgeid_to_edgetype)):
                r_tp = edgeid_to_edgetype[str(r)]
                if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
                    train_trip = np.array([[iid, r, target]])
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
                elif (edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
                    train_trip = np.array([[iid, r, target]]) # add batch dim
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
            if len(edge_losses) == 0:
                continue
            min_index = torch.argmin(torch.cat(edge_losses, dim = 0))
            r = r_list[min_index]
            r_tp = edgeid_to_edgetype[str(r)]

            old_len = len(candidate_list)
            if (edgeid_to_reversemask[str(r)] == 0):
                bo, prob = check_reasonable(iid, r, target)
                if bo:
                    candidate_list.append((iid, r, target))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())
            if (edgeid_to_reversemask[str(r)] == 1):
                bo, prob = check_reasonable(target, r, iid)
                if bo:
                    candidate_list.append((target, r, iid))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())

        if len(candidate_list) == 0:
            if args.added_edge_num == '' or int(args.added_edge_num) == 1:
                attack_edge_list.append((-1, -1, -1))
            else:
                attack_edge_list.append([])
            continue
        norm_score = np.array(score_list) / np.sum(score_list)
        norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))

        total_score = norm_score * norm_loss
        total_score_index = list(zip(range(len(total_score)), total_score))
        total_score_index.sort(key = lambda x: x[1], reverse = True)

        total_index = np.argsort(total_score)[::-1]
        assert total_index[0] == total_score_index[0][0]
        # find rank of main index

        max_index = np.argmax(total_score)
        assert max_index == total_score_index[0][0]

        tmp_add = []
        add_num = 1
        if args.added_edge_num == '' or int(args.added_edge_num) == 1:
            attack_edge_list.append(candidate_list[max_index])
        else:
            add_num = int(args.added_edge_num)
            for i in range(add_num):
                tmp_add.append(candidate_list[total_score_index[i][0]])
            attack_edge_list.append(tmp_add)
    specific_model.to('cpu')
    return attack_edge_list[0]

def specific_func(start_entity, end_entity):

    args.reasonable_rate = 0.5
    s, r, o = generate_specific_attack_edge(start_entity, end_entity)
    if int(s) == -1:
        return 'All candidate links are filtered out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    attack_data = np.array([[s, r, o]])
    path_list = []
    with open(f'../DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
        for line in fl.readlines():
            line = line.replace('\n', '')
            path_list.append(line)
    with open(f'../DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
        sentence_dict = json.load(fl)
    dpath = []
    for k, v in sentence_dict.items():
        if f'{s}_{r}_{o}' in k:
            single_sentence = [v]
            dpath = [path_list[int(k.split('_')[-1])]]
            break
    if len(dpath) == 0:
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    elif not(s_name in single_sentence[0] and o_name in single_sentence[0]):
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)

    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])

    print('Using BioBART for tuning...')
    span, prompt, sen_list, BART_in, Assist = tune_chatgpt([{'in': single_sentence[0], 'out': draft}], attack_data, dpath)
    text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, {'in': single_sentence[0], 'out': draft})
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
    # f'The sentence is: {single_sentence[0]}\n The path is: {dpath[0]}'

def agnostic_func(agnostic_entity):

    args.reasonable_rate = 0.7
    target_id = entity_to_id[drug_dict[agnostic_entity]]
    s = generate_agnostic_attack_edge([int(target_id)])
    if len(s) == 0:
        return 'All candidate links are filtered out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    if int(s[0]) == -1:
        return 'All candidate links are filtered out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    s, r, o = str(s[0]), str(s[1]), str(s[2])
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]

    attack_data = np.array([[s, r, o]])
    single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)

    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])

    print('Using BioBART for tuning...')
    span, prompt, sen_list, BART_in, Assist = tune_chatgpt([{'in': single_sentence[0], 'out': draft}], attack_data, dpath)
    text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, {'in': single_sentence[0], 'out': draft})
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)

#%%
with gr.Blocks() as demo:

    with gr.Column():
        gr.Markdown("Poison scientific knowledge with Scorpius")

    # with gr.Column():
    with gr.Row():
        # Center
        with gr.Column():
            gr.Markdown("Select your poison target")
            with gr.Tab('Target specific'):
                with gr.Column():
                    with gr.Row():
                        start_entity = gr.Dropdown(drug_list, label="Promoting drug")
                        end_entity = gr.Dropdown(disease_list, label="Target disease")
                    specific_generation_button = gr.Button('Poison!')
            with gr.Tab('Target agnostic'):
                agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
                agnostic_generation_button = gr.Button('Poison!')
        with gr.Column():
            gr.Markdown("Malicious link")
            malicisous_link = gr.Textbox(lines=1, label="Malicious link")
            gr.Markdown("Malicious text")
            malicious_text = gr.Textbox(label="Malicious text", lines=5)
    specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicisous_link, malicious_text])
    agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicisous_link, malicious_text])

demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)
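
Note on the scoring step above: score_and_select and generate_template_for_triplet both rank candidate sentences by their mean per-token cross-entropy under BioGPT and keep the lowest-loss (most fluent) one. A minimal, self-contained sketch of that idea (not part of this commit; the two example sentences are invented for illustration):

import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, BioGptForCausalLM

tok = AutoTokenizer.from_pretrained('microsoft/biogpt')
tok.pad_token = tok.eos_token
lm = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tok.eos_token_id).eval()

def mean_token_loss(sentences):
    # Mean cross-entropy per real (non-padding) token, one value per sentence.
    batch = tok(sentences, padding=True, truncation=True, max_length=1024, return_tensors='pt')
    ids, attn = batch['input_ids'], batch['attention_mask']
    with torch.no_grad():
        logits = lm(input_ids=ids, attention_mask=attn, labels=ids).logits
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = ids[..., 1:].contiguous()
    loss = CrossEntropyLoss(reduction='none')(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    ).view(ids.size(0), -1)
    attn = attn[..., 1:].float()
    return (loss * attn).sum(dim=1) / attn.sum(dim=1)

# Toy candidates (hypothetical): the fluent sentence should get the lower loss.
candidates = ["Aspirin is widely used to treat headache.",
              "Aspirin headache treat widely is used to."]
print(candidates[int(torch.argmin(mean_token_loss(candidates)))])
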
server/server_utils.py
ADDED
@@ -0,0 +1,89 @@
import numpy as np

def mask_func(tokenized_sen):

    if len(tokenized_sen) == 0:
        return []
    token_list = []
    # for sen in tokenized_sen:
    #     for token in sen:
    #         token_list.append(token)
    for sen in tokenized_sen:
        token_list += sen.text.split(' ')
    P = 0.5

    ret_list = []
    i = 0
    mask_num = 0
    while i < len(token_list):
        t = token_list[i]
        if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
            ret_list.append(t)
            i += 1
            mask_num = 0
        else:
            length = np.random.poisson(3)
            if np.random.rand() < P and length > 0:
                if mask_num < 8:
                    ret_list.append('<mask>')
                    mask_num += 1
                i += length
            else:
                ret_list.append(t)
                i += 1
                mask_num = 0
    return [' '.join(ret_list)]

def find_mini_span(vec, words, check_set):

    def cal(text, sset):
        add = 0
        for tt in sset:
            if tt in text:
                add += 1
        return add
    text = ' '.join(words)
    max_add = cal(text, check_set)

    minn = 10000000
    span = ''
    rc = None
    for i in range(len(vec)):
        if vec[i] == True:
            p = -1
            for j in range(i+1, len(vec)+1):
                if vec[j-1] == True:
                    text = ' '.join(words[i:j])
                    if cal(text, check_set) == max_add:
                        p = j
                        break
            if p > 0:
                if (p-i) < minn:
                    minn = p-i
                    span = ' '.join(words[i:p])
                    rc = (i, p)
    if rc:
        for i in range(rc[0], rc[1]):
            vec[i] = True
    return vec, span

def process(text):

    for i in range(ord('A'), ord('Z')+1):
        text = text.replace(f'.{chr(i)}', f'. {chr(i)}')
    Left = ['(', '[', '{']
    Right = [')', ']', '}']
    for s in Left:
        text = text.replace(s+' ', s)
    for s in Right:
        text = text.replace(' '+s, s)
    for i in range(10):
        text = text.replace(f'{i} %', f'{i}%')
    text = text.replace(' .', '.')
    text = text.replace(' ,', ',')
    text = text.replace(' ?', '?')
    text = text.replace(' !', '!')
    text = text.replace(' :', ':')
    text = text.replace(' ;', ';')
    text = text.replace('  ', ' ')
    return text
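
Usage note for the helpers above (not part of this commit): a small sketch assuming it is run from the server/ directory so that server_utils is importable; the sentence and keyword set are invented for illustration.

from server_utils import find_mini_span, process

words = "aspirin reduces the risk of ischemic stroke in adults".split(' ')
checkset = {'aspirin', 'stroke'}
vec = [w in checkset for w in words]              # True marks words that must be preserved
vec, span = find_mini_span(vec, words, checkset)
print(span)  # smallest word window containing every checkset term: "aspirin ... stroke"
print(vec)   # every position inside that window is now True, i.e. protected from masking

print(process("Aspirin ( 81 mg ) reduced events by 20 % ."))
# -> "Aspirin (81 mg) reduced events by 20%."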