|
from unsloth import FastLanguageModel |
|
import torch |
|
import os |
|
|
|
class EndpointHandler:
    """HuggingFace Inference Endpoints handler serving defog/llama-3-sqlcoder-8b
    through Unsloth's FastLanguageModel with 4-bit quantization.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer once at endpoint start-up.

        Args:
            path: Repository path injected by the Inference Endpoints runtime.
                Currently unused; the model name is pinned below.
        """
        model_name = "defog/llama-3-sqlcoder-8b"

        max_seq_length = 2048
        dtype = None          # let Unsloth auto-detect (bf16 on Ampere+, else fp16)
        load_in_4bit = True   # 4-bit quantization to fit the 8B model in limited VRAM

        # Directory for offloading weights that do not fit on the accelerator.
        offload_dir = "./offload"
        os.makedirs(offload_dir, exist_ok=True)

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
            offload_folder=offload_dir,
        )

        # Switch Unsloth's patched model into its optimized inference mode.
        FastLanguageModel.for_inference(self.model)

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: Request payload; expects ``{"inputs": "<prompt>"}`` and
                optionally ``{"parameters": {...}}`` with extra ``generate()``
                keyword arguments (e.g. ``{"max_new_tokens": 300}``).

        Returns:
            ``{"generated_text": str}`` on success, ``{"error": str}`` on failure.
        """
        input_text = data.get("inputs", "")
        if not input_text:
            return {"error": "No input text provided."}

        # Caller-supplied generation parameters, defaulting to the original
        # fixed budget of 150 new tokens for backward compatibility.
        gen_kwargs = {"max_new_tokens": 150}
        gen_kwargs.update(data.get("parameters") or {})

        try:
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            # inference_mode disables autograd bookkeeping, cutting memory use
            # and latency during generation.
            with torch.inference_mode():
                outputs = self.model.generate(**inputs, **gen_kwargs)
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return {"generated_text": generated_text}
        except Exception as e:
            # Serving boundary: report the failure to the client rather than
            # crashing the endpoint worker.
            return {"error": str(e)}
|
|