llm-swp / handler.py
niruemon's picture
Update handler.py
bf7289c verified
raw
history blame contribute delete
No virus
2.08 kB
from unsloth import FastLanguageModel
import torch
import os
class EndpointHandler:
def __init__(self, path=""):
# ระบุชื่อโมเดลที่คุณต้องการใช้งาน
model_name = "defog/llama-3-sqlcoder-8b"
torch.cuda.empty_cache()
os.environ[ "PYTORCH_CUDA_ALLOC_CONF" ] = "expandable_segments:True"
# Configuration settings
max_seq_length = 2048
dtype = None # Keep as None, you can change later if needed
load_in_4bit = True
# กำหนดไดเรกทอรีสำหรับการ offload โมเดล (สร้างขึ้นถ้ายังไม่มี)
offload_dir = "./offload"
os.makedirs(offload_dir, exist_ok=True)
# โหลดโมเดลและ tokenizer
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
offload_folder=offload_dir # ระบุโฟลเดอร์สำหรับการ offload
)
# เตรียมโมเดลสำหรับการประมวลผลข้อความ
FastLanguageModel.for_inference(self.model)
def __call__(self, data):
# รับข้อความ input จากผู้ใช้
input_text = data.get("inputs", "")
if not input_text:
return {"error": "No input text provided."}
# สร้างข้อความโดยใช้โมเดล
try:
inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
outputs = self.model.generate(**inputs, max_new_tokens=150)
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"generated_text": generated_text}
except Exception as e:
return {"error": str(e)}