import json
import logging
import os
import zipfile
from typing import Any, Dict, Optional

import requests
import torch
from transformers import pipeline

logger = logging.getLogger(__name__)


class LawLookup:
    """Look up Taiwanese statutes from Taiwan's Ministry of Justice open-data dump."""

    def __init__(self, json_file: str):
        self.json_file = json_file
        self.zip_url = 'https://law.moj.gov.tw/api/data/chlaw.json.zip'
        if not os.path.exists(self.json_file):
            self._download_and_extract_zip()

        # The dump ships with a UTF-8 byte-order mark, hence 'utf-8-sig'.
        with open(self.json_file, 'r', encoding='utf-8-sig') as file:
            self.laws_data = json.load(file)
        self.laws_dict = self._create_laws_dict()

    def _download_and_extract_zip(self):
        zip_path = 'ChLaw.zip'

        # Download the zipped law database, failing loudly on HTTP errors.
        response = requests.get(self.zip_url)
        response.raise_for_status()
        with open(zip_path, 'wb') as file:
            file.write(response.content)

        # Extract only the JSON payload, then discard the archive.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extract('ChLaw.json')

        os.remove(zip_path)

    def _create_laws_dict(self):
        # Index the raw data as {law name: {article number: article content}}.
        laws_dict = {}
        for law in self.laws_data['Laws']:
            law_name = law['LawName']
            articles = {
                self._extract_article_no(article['ArticleNo']): article['ArticleContent']
                for article in law['LawArticles']
                if self._extract_article_no(article['ArticleNo']) is not None
            }
            laws_dict[law_name] = articles
        return laws_dict

    def _extract_article_no(self, article_no_str):
        # Normalize an article label such as "第 184 條" to "184".
        try:
            return article_no_str.replace('第', '').replace('條', '').strip()
        except AttributeError:
            # Non-string input (e.g. a missing ArticleNo field).
            return None

    def get_law(self, law_name: str, article_no: str) -> str:
        article_no = str(article_no)
        if law_name not in self.laws_dict:
            return "Law not found."
        if article_no not in self.laws_dict[law_name]:
            return "Article not found."
        return self.laws_dict[law_name][article_no]

    def get_law_from_token(self, token: str) -> Optional[Dict[str, str]]:
        # Law tokens have the form "<LawName|ArticleNo>"; reject anything else.
        if "|" not in token:
            return None
        if not (token.startswith("<") and token.endswith(">")):
            return None
        token = token[1:-1]
        law_name, article_no = token.split("|")[:2]
        return {
            "token": token,
            "lawName": law_name,
            "articleNo": article_no,
            "content": self.get_law(law_name, article_no),
        }


class EndpointHandler:
    def __init__(self, path=""):
        self.pipeline = pipeline(model="amy011872/LawToken-7B-a2", device=0, torch_dtype=torch.float16)
        self.model = self.pipeline.model
        self.tokenizer = self.pipeline.tokenizer
        self.law_lookup = LawLookup('ChLaw.json')
        self.vocab = self.pipeline.tokenizer.get_vocab()

        # Collect the special law tokens ("<LawName|ArticleNo>") that were
        # added to the tokenizer vocabulary, so scoring can be restricted
        # to them.
        law_tokens = {}
        for k, v in self.vocab.items():
            if k.startswith("<") and len(k) > 1 and k.find("|") > 1:
                law_tokens[k] = v
        self.law_token_ids = list(law_tokens.values())
        self.law_token_names = self.tokenizer.convert_ids_to_tokens(self.law_token_ids)

    def __call__(
        self,
        query: Dict[str, Any],
    ) -> Dict[str, Any]:
        topk = 10
        base_lambda = 1.0

        inputs = query.pop("inputs", query)
        # Ensure the prompt ends with the citation trigger token.
        if not inputs.endswith("<cite>"):
            inputs += "<cite>"
        logger.info(inputs)

        # Score every law token as the continuation of the full prompt.
        inputs = self.tokenizer(inputs, return_tensors="pt").to("cuda")
        with torch.no_grad():
            outputs = self.model(**inputs)
        outputs_logits = outputs.logits[0, -1, self.law_token_ids]

        # Contrastive correction: subtract the logits the model assigns to the
        # same tokens after a bare "<cite>", so tokens that are likely by
        # default do not dominate, then shift back to the original mean.
        base_input = self.tokenizer("<cite>", return_tensors="pt").to("cuda")
        with torch.no_grad():
            base_output = self.model(**base_input)
        base_logits = base_output.logits[0, -1, self.law_token_ids]
        raw_mean = outputs_logits.mean()
        outputs_logits = outputs_logits - base_lambda * base_logits
        outputs_logits = outputs_logits + (raw_mean - outputs_logits.mean())

        # Rank the law tokens and resolve the top-k into article records.
        law_token_probs = outputs_logits.softmax(dim=0)
        sorted_ids = torch.argsort(law_token_probs, descending=True)[:topk]
        logger.info([self.law_token_names[x] for x in sorted_ids])
        token_objects = [
            self.law_lookup.get_law_from_token(self.law_token_names[x])
            for x in sorted_ids.tolist()]

        return {"tokens": token_objects}