File size: 2,080 Bytes
2e90b71
3251ecc
d81f76b
0af61f8
bf7289c
36f9031
 
2e90b71
 
bf7289c
 
 
 
2e90b71
 
 
 
0af61f8
d81f76b
 
 
 
36f9031
2e90b71
 
 
 
 
 
d81f76b
36f9031
2e90b71
 
36f9031
 
 
 
 
 
 
 
 
2e90b71
 
 
36f9031
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from unsloth import FastLanguageModel
import torch
import os


class EndpointHandler:
    def __init__(self, path=""):
        # ระบุชื่อโมเดลที่คุณต้องการใช้งาน
        model_name = "defog/llama-3-sqlcoder-8b"
        
        torch.cuda.empty_cache() 
        os.environ[ "PYTORCH_CUDA_ALLOC_CONF" ] = "expandable_segments:True"
        
        # Configuration settings
        max_seq_length = 2048
        dtype = None  # Keep as None, you can change later if needed
        load_in_4bit = True

        # กำหนดไดเรกทอรีสำหรับการ offload โมเดล (สร้างขึ้นถ้ายังไม่มี)
        offload_dir = "./offload"
        os.makedirs(offload_dir, exist_ok=True)

        # โหลดโมเดลและ tokenizer
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
            offload_folder=offload_dir  # ระบุโฟลเดอร์สำหรับการ offload
        )

        # เตรียมโมเดลสำหรับการประมวลผลข้อความ
        FastLanguageModel.for_inference(self.model)

    def __call__(self, data):
        # รับข้อความ input จากผู้ใช้
        input_text = data.get("inputs", "")
        if not input_text:
            return {"error": "No input text provided."}

        # สร้างข้อความโดยใช้โมเดล
        try:
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(**inputs, max_new_tokens=150)
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return {"generated_text": generated_text}
        except Exception as e:
            return {"error": str(e)}