import json
import logging
import os
import zipfile
from typing import Any, Dict, Optional

import requests
import torch
from transformers import pipeline

logger = logging.getLogger("handler.py")

class LawLookup:
    def __init__(self, json_file: str):
        self.json_file = json_file
        self.zip_url = 'https://law.moj.gov.tw/api/data/chlaw.json.zip'
        if not os.path.exists(self.json_file):
            self._download_and_extract_zip()

        # 'utf-8-sig' tolerates a leading BOM if the MOJ export ships one.
        with open(self.json_file, 'r', encoding='utf-8-sig') as file:
            self.laws_data = json.load(file)
        self.laws_dict = self._create_laws_dict()

    def _download_and_extract_zip(self):
        zip_path = 'ChLaw.zip'
        # Download the zip file; fail early on HTTP errors instead of
        # writing an error page to disk.
        response = requests.get(self.zip_url, timeout=60)
        response.raise_for_status()
        with open(zip_path, 'wb') as file:
            file.write(response.content)

        # Extract only the ChLaw.json file from the zip
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extract('ChLaw.json')

        # Remove the zip file after extraction
        os.remove(zip_path)

    def _create_laws_dict(self):
        # Map law name -> {article number (as a string) -> article content}.
        laws_dict = {}
        for law in self.laws_data['Laws']:
            law_name = law['LawName']
            articles = {}
            for article in law['LawArticles']:
                article_no = self._extract_article_no(article['ArticleNo'])
                if article_no is not None:
                    articles[article_no] = article['ArticleContent']
            laws_dict[law_name] = articles
        return laws_dict

    def _extract_article_no(self, article_no_str):
        # Strip the surrounding '第' / '條' markers, e.g. '第 184 條' -> '184'.
        # The number is kept as a string so sub-articles like '184-1' survive.
        try:
            return article_no_str.replace('第', '').replace('條', '').strip()
        except AttributeError:  # not a string
            return None

    def get_law(self, law_name: str, article_no: str) -> str:
        article_no = str(article_no)
        if law_name in self.laws_dict:
            if article_no in self.laws_dict[law_name]:
                return self.laws_dict[law_name][article_no]
            else:
                return "Article not found."
        else:
            return "Law not found."

    def get_law_from_token(self, token: str) -> Optional[Dict[str, str]]:
        # Law tokens look like '<LawName|ArticleNo>'; reject anything else.
        if "|" not in token:
            return None
        if token[0] != "<" or token[-1] != ">":
            return None
        token = token[1:-1]
        law_name, article_no = token.split("|")[:2]
        return {
            "token": token,
            "lawName": law_name,
            "articleNo": article_no,
            "content": self.get_law(law_name, article_no)}

class EndpointHandler:
    def __init__(self, path=""):
        self.pipeline = pipeline(model="amy011872/LawToken-7B-a2", device=0, torch_dtype=torch.float16)
        self.model = self.pipeline.model
        self.tokenizer = self.pipeline.tokenizer
        self.law_lookup = LawLookup('ChLaw.json')
        self.vocab = self.pipeline.tokenizer.get_vocab()

        # Collect the special law tokens of the form '<LawName|ArticleNo>'
        # that were added to the tokenizer's vocabulary.
        law_tokens = {}
        for k, v in self.vocab.items():
            if k.startswith("<") and len(k) > 1 and k.find("|") > 1:
                law_tokens[k] = v
        self.law_token_ids = list(law_tokens.values())
        self.law_token_names = self.tokenizer.convert_ids_to_tokens(self.law_token_ids)

    def __call__(
            self,
            query: Dict[str, Any]
            ) -> Dict[str, Any]:

        topk = 3
        base_lambda = 1.
        inputs = query.pop("inputs", query)
        logger.info("input type: %s", type(inputs))
        if isinstance(inputs, str):
            inputs = [inputs]
        inputs = self.tokenizer(inputs, return_tensors="pt").to("cuda")
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Logits over the law tokens at the final position (first sequence only).
        outputs_logits = outputs.logits[0, -1, self.law_token_ids]
        base_logits = outputs.logits[0, -1, self.law_token_ids]

        # Subtract the base logits, then shift so the mean matches the raw
        # logits. Note: base_logits is read from the same forward pass as
        # outputs_logits, so with base_lambda = 1.0 the subtraction cancels
        # and the distribution over law tokens becomes uniform; this reads
        # like a hook for rescoring against a separate base model's logits.
        raw_mean = outputs_logits.mean()
        outputs_logits = outputs_logits - base_lambda * base_logits
        outputs_logits = outputs_logits + (raw_mean - outputs_logits.mean())

        law_token_probs = outputs_logits.softmax(dim=0)
        sorted_ids = torch.argsort(law_token_probs, descending=True)[:topk]
        logger.info("top law tokens: %s", [self.law_token_names[x] for x in sorted_ids.tolist()])
        token_objects = [
            self.law_lookup.get_law_from_token(self.law_token_names[x])
            for x in sorted_ids.tolist()]
        return {"tokens": token_objects}