Upload handler.py
Browse files- handler.py +113 -0
handler.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
from transformers import pipeline
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
+
import requests
|
8 |
+
import zipfile
|
9 |
+
|
10 |
+
class LawLookup:
    """Look up Taiwanese law-article contents by law name and article number.

    Loads the Ministry of Justice "ChLaw" dataset from a local JSON file,
    downloading and extracting it from law.moj.gov.tw on first use.
    """

    def __init__(self, json_file: str):
        self.json_file = json_file
        self.zip_url = 'https://law.moj.gov.tw/api/data/chlaw.json.zip'
        if not os.path.exists(self.json_file):
            self._download_and_extract_zip()

        # 'utf-8-sig' tolerates the BOM the downloaded file starts with.
        with open(self.json_file, 'r', encoding='utf-8-sig') as file:
            self.laws_data = json.load(file)
        self.laws_dict = self._create_laws_dict()

    def _download_and_extract_zip(self):
        """Download the ChLaw archive and extract ChLaw.json into the CWD."""
        zip_path = 'ChLaw.zip'
        # Download the zip file. A timeout prevents hanging forever on a
        # stalled connection; raise_for_status surfaces HTTP errors instead
        # of silently writing an error page to disk.
        response = requests.get(self.zip_url, timeout=60)
        response.raise_for_status()
        with open(zip_path, 'wb') as file:
            file.write(response.content)

        try:
            # Extract only the ChLaw.json file from the zip
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extract('ChLaw.json')
        finally:
            # Remove the zip file even when extraction fails (the original
            # leaked it on error).
            os.remove(zip_path)

    def _create_laws_dict(self):
        """Build {law_name: {article_no: article_content}} from the raw data."""
        laws_dict = {}
        for law in self.laws_data['Laws']:
            law_name = law['LawName']
            articles = {}
            for article in law['LawArticles']:
                article_no = self._extract_article_no(article['ArticleNo'])
                if article_no is not None:
                    articles[article_no] = article['ArticleContent']
            laws_dict[law_name] = articles
        return laws_dict

    def _extract_article_no(self, article_no_str):
        """Strip the '第'/'條' markers and whitespace: '第 1-1 條' -> '1-1'.

        Returns None for non-string input. (The original caught ValueError,
        which str.replace/strip never raise — the guard was dead code and a
        non-string ArticleNo would have crashed with AttributeError.)
        """
        try:
            # Extract the numeric part of the article number
            return article_no_str.replace('第', '').replace('條', '').strip()
        except AttributeError:
            return None

    def get_law(self, law_name: str, article_no: str) -> str:
        """Return the article content, or a short message when not found."""
        article_no = str(article_no)
        articles = self.laws_dict.get(law_name)
        if articles is None:
            return "Law not found."
        return articles.get(article_no, "Article not found.")

    def get_law_from_token(self, token: str) -> str:
        """Resolve a '<LawName|ArticleNo>' model token to its article.

        Returns None when the token has no '|' separator, {} when it is not
        wrapped in angle brackets, otherwise a dict with the token parts and
        the looked-up content.
        """
        if "|" not in token:
            return None
        # Bug fix: the original used `and`, so a token missing only ONE of
        # the surrounding brackets slipped through and token[1:-1] chopped
        # off a valid character instead of a bracket.
        if token[0] != "<" or token[-1] != ">":
            return {}
        token = token[1:-1]
        law_name, article_no = token.split("|")[:2]
        return {
            "token": token,
            "lawName": law_name,
            "articleNo": article_no,
            "content": self.get_law(law_name, article_no)}
|
71 |
+
|
72 |
+
class EndpointHandler():
    """Inference endpoint for LawToken-7B.

    Scores only the special '<LawName|ArticleNo>' vocabulary tokens as the
    next token for a query and returns the top-k candidate law articles with
    their contents and probabilities.
    """

    def __init__(self, path=""):
        self.pipeline = pipeline(model="amy011872/LawToken-7B-a2", device=0, torch_dtype=torch.float16)
        self.model = self.pipeline.model
        self.tokenizer = self.pipeline.tokenizer
        self.law_lookup = LawLookup('ChLaw.json')
        self.vocab = self.pipeline.tokenizer.get_vocab()

        # Collect the special law tokens: they look like '<LawName|ArticleNo>',
        # i.e. start with '<' and contain a '|' separator past position 1.
        law_tokens = {}
        for k, v in self.vocab.items():
            if k.startswith("<") and len(k) > 1 and k.find("|") > 1:
                law_tokens[k] = v
        self.law_token_ids = list(law_tokens.values())
        self.law_token_names = self.tokenizer.convert_ids_to_tokens(self.law_token_ids)

    def __call__(
        self,
        query: str,
        max_new_tokens=5,
        do_sample=False,
        topk=3,
        base_lambda=1.,
    ) -> Dict[str, Any]:
        # Fixed return annotation: a dict is returned, not a list.
        # NOTE: max_new_tokens / do_sample are accepted for API compatibility
        # but unused — only a single forward pass is run.

        # Use the pipeline's device instead of hard-coding "cuda", keeping it
        # consistent with device=0 passed to pipeline() above.
        inputs = self.tokenizer(query, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        outputs_logits = outputs.logits[0, -1, self.law_token_ids]
        # NOTE(review): base_logits is sliced from the SAME logits tensor as
        # outputs_logits, so with the default base_lambda=1. the subtraction
        # below zeroes every score and the softmax becomes uniform. This was
        # presumably meant to come from a separate base (non-fine-tuned)
        # model — confirm intent before changing.
        base_logits = outputs.logits[0, -1, self.law_token_ids]

        raw_mean = outputs_logits.mean()
        outputs_logits = outputs_logits - base_lambda * base_logits
        # Re-center so the adjusted logits keep the original mean.
        outputs_logits = outputs_logits + (raw_mean - outputs_logits.mean())

        law_token_probs = outputs_logits.softmax(dim=0)
        sorted_ids = torch.argsort(law_token_probs, descending=True)[:topk]
        print([self.law_token_names[x] for x in sorted_ids])
        token_objects = [{
            **self.law_lookup.get_law_from_token(self.law_token_names[x]),
            # .item() -> plain float so the endpoint payload is
            # JSON-serializable (a 0-dim tensor is not).
            "prob": law_token_probs[x].item()}
            for x in sorted_ids.tolist()]
        # .tolist() for the same serializability reason as above.
        return {"tokens": token_objects, "probs": law_token_probs.tolist()}
|