lm3 / handler.py
ahmetmete's picture
Create handler.py
da3cbbe verified
raw
history blame
1.25 kB
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
class EndpointHandler():
    """Inference Endpoints handler that runs text generation with a causal LM.

    The tokenizer, model and generation config are loaded once from ``path``.
    Each call generates a continuation for a single prompt string.
    """

    def __init__(self, path=""):
        """Load tokenizer, model (bfloat16, ``device_map="auto"``) and generation config.

        Args:
            path: Model directory or Hub repo id passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=torch.bfloat16, device_map="auto"
        )
        generation_config = GenerationConfig.from_pretrained(path)
        # Pad with the EOS id so generate() has a pad token. eos_token_id may be
        # a list of stop ids, or missing entirely, but pad_token_id must be a
        # single int: take the first stop id, else the tokenizer's EOS id.
        eos_token_id = generation_config.eos_token_id
        if isinstance(eos_token_id, (list, tuple)):
            eos_token_id = eos_token_id[0] if eos_token_id else None
        if eos_token_id is None:
            eos_token_id = self.tokenizer.eos_token_id
        generation_config.pad_token_id = eos_token_id
        self.model.generation_config = generation_config

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for the prompt in ``data["inputs"]``.

        Args:
            data: Request payload; it is not modified.
                ``"inputs"`` must be the prompt string.
                ``"parameters"`` is an optional mapping forwarded to
                ``model.generate`` as keyword arguments
                (``max_new_tokens`` defaults to 100).
                ``"parameters"`` may also contain ``return_full_text``
                (default True), which controls whether the prompt is
                included in the decoded result.

        Returns:
            A one-element list ``[{"result": text}]``.

        Raises:
            TypeError: If ``"inputs"`` is missing or is not a string.
        """
        inputs = data.get("inputs")
        if not isinstance(inputs, str):
            raise TypeError(
                "expected data['inputs'] to be a prompt string, got "
                f"{type(inputs).__name__}"
            )
        gen_kwargs = {"max_new_tokens": 100, **(data.get("parameters") or {})}
        # Not a generate() argument; consumed here to decide what to decode.
        return_full_text = gen_kwargs.pop("return_full_text", True)

        # Tokenize with an attention mask. Because pad_token_id equals the EOS
        # id, generate() cannot reliably infer the mask from input_ids alone.
        encoded = self.tokenizer(inputs, return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.model.device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.model.device)

        # Generate a response from the model.
        outputs = self.model.generate(
            input_ids, attention_mask=attention_mask, **gen_kwargs
        )

        # Decode the generated sequence. It starts with the prompt tokens,
        # which are dropped when the caller asks for the new text only.
        sequence = outputs[0]
        if not return_full_text:
            sequence = sequence[input_ids.shape[-1]:]
        result = self.tokenizer.decode(sequence, skip_special_tokens=True)
        return [{"result": result}]