LawToken-7B-a2 / handler.py
amy011872's picture
Update handler.py
7ed842f verified
raw
history blame
4.4 kB
import io
import json
import logging
import os
import re
import zipfile
from typing import Any, Dict, List, Optional

import requests
import torch
from transformers import pipeline
# Module-level logger; note it is registered under the literal name "handler.py", not __name__.
logger = logging.getLogger("handler.py")
class LawLookup:
    """Look up article texts from the law.moj.gov.tw ``ChLaw.json`` export.

    The JSON file is downloaded (from ``zip_url``) and extracted on first use
    when ``json_file`` does not exist yet, then indexed as
    ``{law name: {article number: article text}}``.
    """

    # Seconds to wait on the download socket, so a stalled server cannot hang
    # start-up forever.
    DOWNLOAD_TIMEOUT = 120

    def __init__(self, json_file: str,
                 zip_url: str = 'https://law.moj.gov.tw/api/data/chlaw.json.zip'):
        """Load (downloading first if necessary) the law data and index it.

        Args:
            json_file: Path of the ``ChLaw.json`` file to read. If it does not
                exist, the ``ChLaw.json`` member of the archive at ``zip_url``
                is written to this path first.
            zip_url: URL of the zip archive that contains ``ChLaw.json``.
        """
        self.json_file = json_file
        self.zip_url = zip_url
        if not os.path.exists(self.json_file):
            self._download_and_extract_zip()
        # utf-8-sig transparently drops a leading BOM, if the file has one.
        with open(self.json_file, 'r', encoding='utf-8-sig') as file:
            self.laws_data = json.load(file)
        self.laws_dict = self._create_laws_dict()

    def _download_and_extract_zip(self):
        """Download ``self.zip_url`` and write its ``ChLaw.json`` member to ``self.json_file``.

        Raises:
            requests.HTTPError: if the server responds with an error status.
            zipfile.BadZipFile: if the response body is not a zip archive.
            KeyError: if the archive has no ``ChLaw.json`` member.
        """
        response = requests.get(self.zip_url, timeout=self.DOWNLOAD_TIMEOUT)
        # Fail loudly rather than trying to unzip an error page.
        response.raise_for_status()
        # Open the archive from memory: no temporary zip file to clean up or leak.
        with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
            data = zip_ref.read('ChLaw.json')
        # Write to a sibling temp file and rename it into place, so an
        # interrupted write never leaves a truncated file that the
        # os.path.exists() check in __init__ would accept next time.
        tmp_path = self.json_file + '.part'
        try:
            with open(tmp_path, 'wb') as file:
                file.write(data)
            os.replace(tmp_path, self.json_file)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _create_laws_dict(self) -> Dict[str, Dict[str, str]]:
        """Index ``self.laws_data`` as ``{LawName: {article number: ArticleContent}}``.

        Articles whose ``ArticleNo`` cannot be parsed (see
        ``_extract_article_no``) are skipped. When two laws share a name, or
        two articles of a law share a number, the later one wins.
        """
        laws_dict = {}
        for law in self.laws_data['Laws']:
            articles = {}
            for article in law['LawArticles']:
                # Parse once per article (previously done twice).
                article_no = self._extract_article_no(article.get('ArticleNo'))
                if article_no is not None:
                    articles[article_no] = article['ArticleContent']
            laws_dict[law['LawName']] = articles
        return laws_dict

    def _extract_article_no(self, article_no_str) -> Optional[str]:
        """Strip the ``第`` / ``條`` markers from an article label.

        ``'第 184 條'`` becomes ``'184'`` and ``'第 15-1 條'`` becomes ``'15-1'``.

        Returns:
            The bare article number, or None when the label is missing, is
            not a string, or is empty after stripping (such entries are not
            addressable by number, so they are not indexed).
        """
        if not isinstance(article_no_str, str):
            return None
        article_no = article_no_str.replace('第', '').replace('條', '').strip()
        return article_no or None

    def get_law(self, law_name: str, article_no: str) -> str:
        """Return the text of article ``article_no`` of law ``law_name``.

        ``article_no`` is passed through ``str()``, so ints are accepted.
        Instead of raising, returns ``"Law not found."`` for an unknown law
        and ``"Article not found."`` for an unknown article.
        """
        articles = self.laws_dict.get(law_name)
        if articles is None:
            return "Law not found."
        return articles.get(str(article_no), "Article not found.")

    def get_law_from_token(self, token: str) -> Optional[Dict[str, str]]:
        """Resolve a ``<law name|article no>`` token to its article.

        Returns:
            None if ``token`` contains no ``|``; an empty dict if it is not
            wrapped in both ``<`` and ``>``; otherwise a dict with the
            unwrapped ``token``, ``lawName``, ``articleNo`` and ``content``
            (the ``get_law`` result). Text after a second ``|`` is ignored.
        """
        if "|" not in token:
            return None
        # Both delimiters are required; otherwise the [1:-1] slice below would
        # cut off a real character of the law name or article number.
        if not (token.startswith("<") and token.endswith(">")):
            return {}
        token = token[1:-1]
        law_name, article_no = token.split("|")[:2]
        return {
            "token": token,
            "lawName": law_name,
            "articleNo": article_no,
            "content": self.get_law(law_name, article_no)}
class EndpointHandler:
    """Custom inference handler that ranks the model's law tokens for a text.

    The tokenizer vocabulary contains special tokens shaped like
    ``<law name|article no>``. A request is answered with the ``TOP_K`` such
    tokens the model scores highest at the position after the input text,
    each resolved to its article via ``LawLookup``.
    """

    # Model repository loaded by the pipeline.
    MODEL_ID = "amy011872/LawToken-7B-a2"
    # Number of law tokens returned per request.
    TOP_K = 3

    def __init__(self, path=""):
        """Load the model pipeline, the law database and the law-token index.

        Args:
            path: Accepted for the handler interface but unused; the model is
                always loaded from ``MODEL_ID``.
        """
        self.pipeline = pipeline(model=self.MODEL_ID, device=0, torch_dtype=torch.float16)
        self.model = self.pipeline.model
        self.tokenizer = self.pipeline.tokenizer
        self.law_lookup = LawLookup('ChLaw.json')
        self.vocab = self.tokenizer.get_vocab()
        # A law token starts with "<" and has its "|" after index 1; the
        # index check rules out tokens shaped like "<|...|>".
        law_tokens = {tok: idx for tok, idx in self.vocab.items()
                      if tok.startswith("<") and len(tok) > 1 and tok.find("|") > 1}
        self.law_token_ids = list(law_tokens.values())
        self.law_token_names = self.tokenizer.convert_ids_to_tokens(self.law_token_ids)

    @staticmethod
    def _rank_law_tokens(law_logits: torch.Tensor, topk: int) -> List[int]:
        """Return the indices of the ``topk`` largest entries of ``law_logits``, best first.

        Softmax is monotonic, so ranking raw logits gives the same order as
        ranking the probabilities. Returns fewer than ``topk`` indices when
        ``law_logits`` is shorter, and ``[]`` for empty input or ``topk <= 0``.
        """
        k = min(topk, law_logits.numel())
        if k <= 0:
            return []
        return torch.topk(law_logits, k).indices.tolist()

    def __call__(
        self,
        query: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return the top law tokens for the text in ``query["inputs"]``.

        Args:
            query: Request payload. ``"inputs"`` is popped from it and may be a
                string or a list of strings; if the key is absent, the payload
                itself is used as the input. Only the first sequence's
                prediction is returned.

        Returns:
            ``{"tokens": [...]}`` with up to ``TOP_K`` entries, best first, each
            the ``LawLookup.get_law_from_token`` result for that token.
        """
        inputs = query.pop("inputs", query)
        logger.debug("input type: %s", type(inputs))
        if isinstance(inputs, str):
            inputs = [inputs]
        if not self.law_token_ids:
            return {"tokens": []}
        # Follow the model's device instead of assuming "cuda".
        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**encoded)
        # Logits at the last position of the first sequence, restricted to law tokens.
        law_logits = outputs.logits[0, -1, self.law_token_ids]
        # Rank on these logits directly. Subtracting a "base" distribution only
        # calibrates when the base comes from a different context; subtracting
        # this very slice would zero every score and make the ranking arbitrary.
        best_names = [self.law_token_names[i]
                      for i in self._rank_law_tokens(law_logits, self.TOP_K)]
        logger.debug("top law tokens: %s", best_names)
        token_objects = [self.law_lookup.get_law_from_token(name) for name in best_names]
        return {"tokens": token_objects}