amy011872 commited on
Commit
ff83c51
1 Parent(s): 566ea32

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +113 -0
handler.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline
3
+ import os
4
+ import re
5
+ import json
6
+ import torch
7
+ import requests
8
+ import zipfile
9
+
10
+ class LawLookup:
11
+ def __init__(self, json_file: str):
12
+ self.json_file = json_file
13
+ self.zip_url = 'https://law.moj.gov.tw/api/data/chlaw.json.zip'
14
+ if not os.path.exists(self.json_file):
15
+ self._download_and_extract_zip()
16
+
17
+ with open(self.json_file, 'r', encoding='utf-8-sig') as file:
18
+ self.laws_data = json.load(file)
19
+ self.laws_dict = self._create_laws_dict()
20
+
21
+ def _download_and_extract_zip(self):
22
+ zip_path = 'ChLaw.zip'
23
+ # Download the zip file
24
+ response = requests.get(self.zip_url)
25
+ with open(zip_path, 'wb') as file:
26
+ file.write(response.content)
27
+
28
+ # Extract only the ChLaw.json file from the zip
29
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
30
+ zip_ref.extract('ChLaw.json')
31
+
32
+ # Remove the zip file after extraction
33
+ os.remove(zip_path)
34
+
35
+ def _create_laws_dict(self):
36
+ laws_dict = {}
37
+ for law in self.laws_data['Laws']:
38
+ law_name = law['LawName']
39
+ articles = {self._extract_article_no(article['ArticleNo']): article['ArticleContent']
40
+ for article in law['LawArticles'] if self._extract_article_no(article['ArticleNo']) is not None}
41
+ laws_dict[law_name] = articles
42
+ return laws_dict
43
+
44
+ def _extract_article_no(self, article_no_str):
45
+ try:
46
+ # Extract the numeric part of the article number
47
+ return article_no_str.replace('第', '').replace('條', '').strip()
48
+ except ValueError:
49
+ return None
50
+
51
+ def get_law(self, law_name: str, article_no: str) -> str:
52
+ article_no = str(article_no)
53
+ if law_name in self.laws_dict:
54
+ if article_no in self.laws_dict[law_name]:
55
+ return self.laws_dict[law_name][article_no]
56
+ else:
57
+ return "Article not found."
58
+ else:
59
+ return "Law not found."
60
+
61
+ def get_law_from_token(self, token: str) -> str:
62
+ if "|" not in token: return None
63
+ if token[0] != "<" and token[-1] != ">": return {}
64
+ token = token[1:-1]
65
+ law_name, article_no = token.split("|")[:2]
66
+ return {
67
+ "token": token,
68
+ "lawName": law_name,
69
+ "articleNo": article_no,
70
+ "content": self.get_law(law_name, article_no)}
71
+
72
+ class EndpointHandler():
73
+ def __init__(self, path=""):
74
+ self.pipeline = pipeline(model="amy011872/LawToken-7B-a2", device=0, torch_dtype=torch.float16)
75
+ self.model = self.pipeline.model
76
+ self.tokenizer = self.pipeline.tokenizer
77
+ self.law_lookup = LawLookup('ChLaw.json')
78
+ self.vocab = self.pipeline.tokenizer.get_vocab()
79
+
80
+ law_tokens = {}
81
+ for k, v in self.vocab.items():
82
+ if k.startswith("<") and len(k)>1 and k.find("|")>1:
83
+ law_tokens[k] = v
84
+ self.law_token_ids = list(law_tokens.values())
85
+ self.law_token_names = self.tokenizer.convert_ids_to_tokens(self.law_token_ids)
86
+
87
+ def __call__(
88
+ self,
89
+ query: str,
90
+ max_new_tokens=5,
91
+ do_sample=False,
92
+ topk=3,
93
+ base_lambda=1.,
94
+ ) -> List[Dict[str, Any]]:
95
+
96
+ inputs = self.tokenizer(query, return_tensors="pt").to("cuda")
97
+ with torch.no_grad():
98
+ outputs = self.model(**inputs)
99
+ outputs_logits = outputs.logits[0, -1, self.law_token_ids]
100
+ base_logits = outputs.logits[0, -1, self.law_token_ids]
101
+
102
+ raw_mean = outputs_logits.mean()
103
+ outputs_logits = outputs_logits - base_lambda * base_logits
104
+ outputs_logits = outputs_logits + (raw_mean - outputs_logits.mean())
105
+
106
+ law_token_probs = outputs_logits.softmax(dim=0)
107
+ sorted_ids = torch.argsort(law_token_probs, descending=True)[:topk]
108
+ print([self.law_token_names[x] for x in sorted_ids])
109
+ token_objects = [{
110
+ **self.law_lookup.get_law_from_token(self.law_token_names[x]),
111
+ "prob": law_token_probs[x]}
112
+ for x in sorted_ids.tolist()]
113
+ return {"tokens": token_objects, "probs": law_token_probs}