Commit: d29dd9f
Parent(s): 19c1b44

Fixing return structure

handler.py (+11, -8) CHANGED
@@ -27,11 +27,13 @@ class EndpointHandler:
 
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto",
-                                                          offload_folder='offload',
-                                                          trust_remote_code=True,
-                                                          load_in_8bit=True)
+        # self.tokenizer = AutoTokenizer.from_pretrained(path)
+        # self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto",
+        #                                                   offload_folder='offload',
+        #                                                   trust_remote_code=True,
+        #                                                   load_in_8bit=True)
+        self.tokenizer = tokenizer
+        self.model = model
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
@@ -39,6 +41,7 @@ class EndpointHandler:
         if 'prompt' in data.keys():
             text = data['prompt']
         else:
+            print(data.keys())
            user_data = data.pop('query',data)
            text = self.prompt_ar.format_map({'Question':user_data})
         inputs = data.pop("inputs", data)
@@ -71,10 +74,10 @@ class EndpointHandler:
         response = self.tokenizer.batch_decode(generate_ids,
                                                skip_special_tokens=True,
                                                clean_up_tokenization_spaces=True)[0]
-        final_response = response.split("### Response: [|AI|]")
-        turn = [f'[|Human|] {query}', f'[|AI|] {final_response[-1]}']
-        chat_history.extend(turn)
         if 'prompt' in data.keys():
             return response
         else:
+            final_response = response.split("### Response: [|AI|]")
+            turn = [f'[|Human|] {query}', f'[|AI|] {final_response[-1]}']
+            chat_history.extend(turn)
             return {"response": final_response, "chat_history": chat_history}
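For context, a minimal sketch of how the return structure behaves after this commit. It assumes the usual Inference Endpoints signature EndpointHandler(path=""), that the module-level tokenizer and model objects this commit switches to are created elsewhere in handler.py, and that the rest of __call__ (not shown in the diff) supplies query and chat_history; the payload contents below are illustrative, not taken from the commit.

# Illustrative sketch only, not part of the commit: exercises both payload shapes.
from handler import EndpointHandler   # assumes handler.py defines EndpointHandler

handler = EndpointHandler(path="")    # path no longer loads anything; tokenizer/model
                                      # now come from module-level objects in handler.py

# 'query' payload: after this commit the split/turn/chat_history bookkeeping runs
# only in this branch, and the call returns a dict.
out = handler({"query": "What is the capital of the UAE?"})
# out == {"response": <result of response.split("### Response: [|AI|]")>,
#         "chat_history": [..., "[|Human|] ...", "[|AI|] ..."]}

# 'prompt' payload: the raw decoded string is returned unchanged, with none of the
# final_response/turn/chat_history work done on the way out.
raw = handler({"prompt": "### Instruction: ...\n### Response: [|AI|]"})
print(raw)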