File size: 4,395 Bytes
ff83c51 7fd7210 ff83c51 7fd7210 ff83c51 7ed842f ff83c51 2a051f4 ff83c51 2a051f4 319721c 7fd7210 319721c 7fd7210 ff83c51 7fd7210 dec94d5 ff83c51 7fd7210 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import json
import logging
import os
import re
import zipfile
from typing import Any, Dict, List, Optional

import requests
import torch
from transformers import pipeline
# Module-level logger shared by the classes below; it is registered under the
# literal name "handler.py" rather than the module's __name__.
logger = logging.getLogger("handler.py")
class LawLookup:
    """Look up law article text by law name and article number.

    The data is read from a JSON file whose top-level ``Laws`` list holds
    objects with ``LawName`` and ``LawArticles`` keys; every article carries
    ``ArticleNo`` and ``ArticleContent``. When the JSON file is missing it is
    downloaded (as a zip archive) from ``zip_url`` first.
    """

    def __init__(self, json_file: str):
        """Load ``json_file`` (downloading it if absent) and index its articles.

        Args:
            json_file: Path of the JSON law dump to read or create.
        """
        self.json_file = json_file
        self.zip_url = 'https://law.moj.gov.tw/api/data/chlaw.json.zip'
        if not os.path.exists(self.json_file):
            self._download_and_extract_zip()
        # 'utf-8-sig' drops a leading byte-order mark if the file has one.
        with open(self.json_file, 'r', encoding='utf-8-sig') as file:
            self.laws_data = json.load(file)
        self.laws_dict = self._create_laws_dict()

    def _download_and_extract_zip(self):
        """Download the zip archive and write its ``ChLaw.json`` to ``json_file``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            KeyError: If the archive has no ``ChLaw.json`` member.
        """
        zip_path = 'ChLaw.zip'
        chunk_size = 1 << 20
        try:
            # Stream to disk so the archive is never held in memory at once;
            # the timeout keeps a stalled server from hanging start-up.
            with requests.get(self.zip_url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(zip_path, 'wb') as file:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        file.write(chunk)
            # Copy the member to the path __init__ is about to open, instead
            # of always extracting into the current working directory.
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                with zip_ref.open('ChLaw.json') as source, \
                        open(self.json_file, 'wb') as target:
                    for chunk in iter(lambda: source.read(chunk_size), b''):
                        target.write(chunk)
        finally:
            # Remove the temporary archive on success and on failure alike.
            if os.path.exists(zip_path):
                os.remove(zip_path)

    def _create_laws_dict(self):
        """Build ``{law name: {article number: article content}}``.

        Articles whose number cannot be extracted are left out.
        """
        laws_dict = {}
        for law in self.laws_data['Laws']:
            articles = {}
            for article in law['LawArticles']:
                article_no = self._extract_article_no(article['ArticleNo'])
                if article_no is not None:
                    articles[article_no] = article['ArticleContent']
            laws_dict[law['LawName']] = articles
        return laws_dict

    def _extract_article_no(self, article_no_str) -> Optional[str]:
        """Strip the '第' / '條' markers and surrounding whitespace.

        Returns:
            The remaining article number as a string, or ``None`` when the
            input is not a string or nothing is left after stripping.
        """
        if not isinstance(article_no_str, str):
            return None
        article_no = article_no_str.replace('第', '').replace('條', '').strip()
        return article_no or None

    def get_law(self, law_name: str, article_no: str) -> str:
        """Return the content of one article.

        Args:
            law_name: Exact law name as stored in the data.
            article_no: Article number; converted with ``str()`` before lookup.

        Returns:
            The article content, ``"Law not found."`` when the law is unknown,
            or ``"Article not found."`` when the law has no such article.
        """
        articles = self.laws_dict.get(law_name)
        if articles is None:
            return "Law not found."
        return articles.get(str(article_no), "Article not found.")

    def get_law_from_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Resolve a ``<law name|article no>`` token to its article.

        Args:
            token: Token text, expected to be wrapped in ``<`` and ``>`` with
                the law name and article number separated by ``|``.

        Returns:
            ``None`` when the token contains no ``|``; an empty dict when it is
            not wrapped in both ``<`` and ``>``; otherwise a dict with the keys
            ``token`` (brackets removed), ``lawName``, ``articleNo`` and
            ``content``.
        """
        if "|" not in token:
            return None
        # Both brackets are required: the slice below drops the first and last
        # character unconditionally, so a token missing either one would lose
        # a real character.
        if not (token.startswith("<") and token.endswith(">")):
            return {}
        token = token[1:-1]
        law_name, article_no = token.split("|")[:2]
        return {
            "token": token,
            "lawName": law_name,
            "articleNo": article_no,
            "content": self.get_law(law_name, article_no),
        }
class EndpointHandler():
    """Callable handler that maps a prompt to its highest-scoring law tokens.

    "Law tokens" are the vocabulary entries that start with ``<`` and contain
    a ``|`` after their second character. For each request the handler ranks
    those tokens by the model's next-token logits and resolves the best ones
    to article text through ``LawLookup``.
    """

    def __init__(self, path=""):
        """Load the model pipeline, the law data and the law-token index.

        Args:
            path: Accepted for interface compatibility; not used here because
                the model id is fixed below.
        """
        self.pipeline = pipeline(model="amy011872/LawToken-7B-a2", device=0, torch_dtype=torch.float16)
        self.model = self.pipeline.model
        self.tokenizer = self.pipeline.tokenizer
        self.law_lookup = LawLookup('ChLaw.json')
        self.vocab = self.pipeline.tokenizer.get_vocab()
        # find("|") > 1 already implies the token is longer than one character.
        law_tokens = {
            token: token_id
            for token, token_id in self.vocab.items()
            if token.startswith("<") and token.find("|") > 1
        }
        # The two lists are index-aligned: position i in one describes the
        # same token as position i in the other.
        self.law_token_ids = list(law_tokens.values())
        self.law_token_names = self.tokenizer.convert_ids_to_tokens(self.law_token_ids)

    def __call__(
        self,
        query: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return the top-ranked law tokens for the request.

        Args:
            query: Request payload. Its ``"inputs"`` entry (removed from the
                dict) is the prompt: a string, or a list whose first item is
                the one that gets scored. When the key is absent the payload
                itself is passed to the tokenizer.

        Returns:
            ``{"tokens": [...]}`` with one ``LawLookup.get_law_from_token``
            result per selected token, best first.
        """
        topk = 3
        inputs = query.pop("inputs", query)
        logger.info(type(inputs))
        if isinstance(inputs, str):
            inputs = [inputs]
        # Follow the device the model was actually placed on.
        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**encoded)
        # Logits of the law tokens at the last position of the first sequence.
        law_logits = outputs.logits[0, -1, self.law_token_ids]
        # NOTE(review): the previous code subtracted these same logits from
        # themselves (base_lambda = 1) before the softmax, which made every
        # score equal and the ranking independent of the prompt. Ranking now
        # uses the logits directly. If a calibration baseline is wanted, it
        # has to come from a separate forward pass — confirm the intent.
        k = min(topk, law_logits.shape[0])
        top_positions = torch.topk(law_logits, k).indices.tolist()
        top_names = [self.law_token_names[position] for position in top_positions]
        logger.info("top law tokens: %s", top_names)
        token_objects = [
            self.law_lookup.get_law_from_token(name)
            for name in top_names]
        return {"tokens": token_objects}
|