suneeln-duke committed on
Commit
a3afef0
1 Parent(s): d94ffb7

Create handler.py

Files changed (1): handler.py +87 -0
handler.py ADDED
@@ -0,0 +1,87 @@
+ from typing import Any, Dict
+
+ import re
+
+ import torch
+ import transformers
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # bfloat16 needs compute capability 8.x (Ampere) or newer; otherwise fall back to float16.
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+         model = AutoModelForCausalLM.from_pretrained(
+             path,
+             return_dict=True,
+             device_map="auto",
+             load_in_8bit=True,  # requires the bitsandbytes package
+             torch_dtype=dtype,
+             trust_remote_code=True,
+         )
+
+         # Default generation settings; individual requests may override them in __call__.
+         gen_config = model.generation_config
+         gen_config.max_new_tokens = 100
+         gen_config.temperature = 0
+         gen_config.num_return_sequences = 1
+         gen_config.pad_token_id = tokenizer.eos_token_id
+         gen_config.eos_token_id = tokenizer.eos_token_id
+
+         self.generation_config = gen_config
+
+         self.pipeline = transformers.pipeline(
+             "text-generation", model=model, tokenizer=tokenizer
+         )
+
+     def __call__(self, data: Dict[str, Any]) -> str:
+         question = data.pop("question", data)
+         context = data.pop("context", None)
+         temp = data.pop("temp", None)
+         max_tokens = data.pop("max_tokens", None)
+
+         system_message = "Use the provided context followed by a question to answer it."
+
+         full_prompt = f"""<s>### Instruction:
+ {system_message}
+
+ ### Context:
+ {context}
+
+ ### Question:
+ {question}
+
+ ### Answer:
+ """
+
+         # Collapse all whitespace so the prompt is a single line.
+         full_prompt = " ".join(full_prompt.split())
+
+         # Apply per-request overrides only when they were actually supplied,
+         # so a missing key does not clobber the defaults with None.
+         if max_tokens is not None:
+             self.generation_config.max_new_tokens = max_tokens
+         if temp is not None:
+             self.generation_config.temperature = temp
+
+         result = self.pipeline(full_prompt, generation_config=self.generation_config)[0]["generated_text"]
+
+         # Keep only the text after "### Answer:": prefer the span that ends at the
+         # next "###" marker; if the model never emitted one, take everything to the end.
+         match = re.search(r"### Answer:(.*?)###", result, re.DOTALL)
+         if match:
+             result = match.group(1).strip()
+         else:
+             match = re.search(r"### Answer:(.*)", result, re.DOTALL)
+             if match:
+                 result = match.group(1).strip()
+
+         return result
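
Not part of the commit: custom handlers on Hugging Face Inference Endpoints are constructed once with the repository path and then called with the decoded request body, so a local smoke test can drive the class the same way. Below is a minimal sketch under that assumption; "./my-model" is a placeholder path, and the payload keys simply mirror what __call__ above pops.

# Minimal local smoke test (sketch). Assumes handler.py sits next to the
# model weights; "./my-model" is a hypothetical path, not part of this commit.
from handler import EndpointHandler

handler = EndpointHandler(path="./my-model")

payload = {
    "question": "What color is the sky?",
    "context": "The sky is blue during the day.",
    "temp": 0.1,
    "max_tokens": 50,
}

# __call__ returns the extracted answer as a plain string.
answer = handler(payload)
print(answer)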