# LawToken-7B-a2 / handler.py
# Uploaded by: amy011872
# Last commit: "Update handler.py" (41d0d3e, verified)
import io
import json
import logging
import os
import re
import zipfile
from typing import Any, Dict, List, Optional

import requests
import torch
from transformers import pipeline
# Module-level logger (used by EndpointHandler); note it is registered under
# the fixed name "handler.py", not __name__.
logger = logging.getLogger("handler.py")
class LawLookup:
    """Index and look up law articles from the law.moj.gov.tw ``ChLaw.json`` dump.

    If ``json_file`` does not exist it is downloaded (as a zip) from
    ``zip_url`` first. The parsed file is kept in ``laws_data`` and indexed
    into ``laws_dict`` as ``{LawName: {article_no: ArticleContent}}``.
    """

    # Name of the JSON member inside the downloaded zip archive.
    ZIP_MEMBER = 'ChLaw.json'

    def __init__(self, json_file: str):
        """Load (downloading if missing) and index ``json_file``.

        Raises:
            requests.HTTPError / requests.Timeout: if a needed download fails.
            KeyError: if the JSON lacks the expected ``Laws`` / ``LawName`` /
                ``LawArticles`` / ``ArticleNo`` / ``ArticleContent`` keys.
        """
        self.json_file = json_file
        self.zip_url = 'https://law.moj.gov.tw/api/data/chlaw.json.zip'
        if not os.path.exists(self.json_file):
            self._download_and_extract_zip()
        # utf-8-sig transparently drops a leading BOM if the file has one.
        with open(self.json_file, 'r', encoding='utf-8-sig') as file:
            self.laws_data = json.load(file)
        self.laws_dict = self._create_laws_dict()

    def _download_and_extract_zip(self, timeout: float = 60.0):
        """Download ``zip_url`` and write its ``ChLaw.json`` member to ``self.json_file``.

        The archive is kept in memory, so no stray ``ChLaw.zip`` is left on
        disk. The JSON is first written to ``<json_file>.part`` and then
        renamed, so a failed write never leaves a truncated file that a later
        run would mistake for a valid cache.

        Args:
            timeout: per-request socket timeout in seconds for ``requests.get``.

        Raises:
            requests.HTTPError: if the server returns an error status.
            KeyError: if the archive has no ``ChLaw.json`` member.
        """
        response = requests.get(self.zip_url, timeout=timeout)
        # Without this, an HTML error page would be handed to ZipFile.
        response.raise_for_status()
        with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
            payload = zip_ref.read(self.ZIP_MEMBER)
        target_dir = os.path.dirname(self.json_file)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        tmp_path = self.json_file + '.part'
        try:
            with open(tmp_path, 'wb') as file:
                file.write(payload)
            os.replace(tmp_path, self.json_file)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _create_laws_dict(self):
        """Index ``laws_data`` as ``{LawName: {article_no: ArticleContent}}``.

        Articles whose ``ArticleNo`` yields no article number (see
        ``_extract_article_no``) are skipped.
        """
        laws_dict = {}
        for law in self.laws_data['Laws']:
            articles = {}
            for article in law['LawArticles']:
                article_no = self._extract_article_no(article['ArticleNo'])
                if article_no is not None:
                    articles[article_no] = article['ArticleContent']
            laws_dict[law['LawName']] = articles
        return laws_dict

    def _extract_article_no(self, article_no_str):
        """Strip the 第 / 條 markers from an ``ArticleNo`` string.

        E.g. ``'第 1-1 條'`` -> ``'1-1'``. Returns ``None`` when the input is
        not a string or nothing remains after stripping (e.g. an empty
        ``ArticleNo``), so such entries are excluded from the index.
        """
        if not isinstance(article_no_str, str):
            return None
        article_no = article_no_str.replace('第', '').replace('條', '').strip()
        return article_no or None

    def get_law(self, law_name: str, article_no: str) -> str:
        """Return the text of ``article_no`` in ``law_name``.

        ``article_no`` is passed through ``str()`` so ints are accepted.
        On a miss, returns the sentinel string ``"Law not found."`` or
        ``"Article not found."``.
        """
        articles = self.laws_dict.get(law_name)
        if articles is None:
            return "Law not found."
        return articles.get(str(article_no), "Article not found.")

    def get_law_from_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Resolve a ``<LawName|ArticleNo...>`` token to its article text.

        Returns:
            ``None`` if ``token`` contains no ``|``; ``{}`` if it is not
            wrapped in ``<`` and ``>``; otherwise a dict with the
            bracket-stripped ``token``, ``lawName``, ``articleNo`` (the first
            two ``|``-separated fields) and ``content`` from ``get_law``.
        """
        if "|" not in token:
            return None
        # Both brackets are required; a token missing either is malformed and
        # slicing it would cut off a real character.
        if not (token.startswith("<") and token.endswith(">")):
            return {}
        token = token[1:-1]
        law_name, article_no = token.split("|")[:2]
        return {
            "token": token,
            "lawName": law_name,
            "articleNo": article_no,
            "content": self.get_law(law_name, article_no),
        }
class EndpointHandler:
    """Inference handler that ranks law-citation tokens for a prompt.

    The prompt is scored at a trailing ``<cite>`` marker. Only vocabulary
    entries that look like law tokens (start with ``<`` and contain ``|``)
    are considered, and each top-ranked one is resolved to article text via
    ``LawLookup``.
    """

    def __init__(self, path=""):
        # NOTE(review): `path` is not used; the model is always loaded by Hub id.
        self.pipeline = pipeline(model="amy011872/LawToken-7B-a2", device=0, torch_dtype=torch.float16)
        self.model = self.pipeline.model
        self.tokenizer = self.pipeline.tokenizer
        self.law_lookup = LawLookup('ChLaw.json')
        self.vocab = self.tokenizer.get_vocab()
        # Law tokens: entries starting with "<" whose "|" is at index >= 2.
        law_tokens = {
            k: v for k, v in self.vocab.items()
            if k.startswith("<") and len(k) > 1 and k.find("|") > 1
        }
        self.law_token_ids = list(law_tokens.values())
        self.law_token_names = self.tokenizer.convert_ids_to_tokens(self.law_token_ids)
        # Logits for the bare "<cite>" prompt do not depend on the request;
        # they are computed on first use and cached here.
        self._base_law_logits = None

    def _last_law_logits(self, text: str) -> torch.Tensor:
        """Run one forward pass on ``text``; return the last position's logits
        restricted to ``self.law_token_ids`` (one value per law token)."""
        encoded = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output = self.model(**encoded)
        return output.logits[0, -1, self.law_token_ids]

    def _base_logits(self) -> torch.Tensor:
        """Return (computing once) the law-token logits for the bare ``<cite>`` prompt."""
        cached = getattr(self, "_base_law_logits", None)
        if cached is None:
            cached = self._last_law_logits("<cite>")
            self._base_law_logits = cached
        return cached

    def __call__(
        self,
        query: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Rank the most likely law citations for ``query["inputs"]``.

        Args:
            query: request payload. ``inputs`` is the prompt text; ``<cite>``
                is appended if it does not already end with it. An optional
                ``parameters`` dict may set ``top_k`` (default 10) and
                ``base_lambda`` (default 1.0). ``query`` is not modified.

        Returns:
            ``{"tokens": [...]}``: one ``law_lookup.get_law_from_token`` result
            per top-ranked law token, best first.

        Raises:
            ValueError: if ``inputs`` is missing or not a string.
        """
        inputs = query.get("inputs", query)
        if not isinstance(inputs, str):
            raise ValueError('query["inputs"] must be a string')
        parameters = query.get("parameters") or {}
        topk = max(0, int(parameters.get("top_k", 10)))
        base_lambda = float(parameters.get("base_lambda", 1.0))

        if not inputs.endswith("<cite>"):
            inputs += "<cite>"
        logger.info("%s", inputs)

        outputs_logits = self._last_law_logits(inputs)
        base_logits = self._base_logits()

        raw_mean = outputs_logits.mean()
        # Subtract the (scaled) logits the model gives after a bare "<cite>",
        # penalising tokens it favours independent of the query.
        outputs_logits = outputs_logits - base_lambda * base_logits
        # Restore the original mean. Softmax is shift-invariant, so this does
        # not change the probabilities or the ranking.
        outputs_logits = outputs_logits + (raw_mean - outputs_logits.mean())
        law_token_probs = outputs_logits.softmax(dim=0)
        sorted_ids = torch.argsort(law_token_probs, descending=True)[:topk].tolist()
        top_names = [self.law_token_names[i] for i in sorted_ids]
        logger.info("%s", top_names)
        token_objects = [self.law_lookup.get_law_from_token(name) for name in top_names]
        return {"tokens": token_objects}