from typing import Any, Dict

import re

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# Use bfloat16 on GPUs with compute capability >= 8 (Ampere and newer),
# which support it natively; fall back to float16 otherwise.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16


class EndpointHandler:
    def __init__(self, path=""):
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Default generation settings; callers can override max_new_tokens
        # and temperature per request via the payload.
        gen_config = model.generation_config
        gen_config.max_new_tokens = 100
        # Greedy decoding by default; temperature only takes effect when
        # do_sample=True.
        gen_config.temperature = 0
        gen_config.num_return_sequences = 1
        gen_config.pad_token_id = tokenizer.eos_token_id
        gen_config.eos_token_id = tokenizer.eos_token_id
        self.generation_config = gen_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> str:
        question = data.pop("question", None)
        context = data.pop("context", None)
        temp = data.pop("temp", None)
        max_tokens = data.pop("max_tokens", None)

        # Alternative instruction-only system message, kept from the original
        # template but unused here:
        # "Below is an instruction that describes a task. Write a response
        # that appropriately completes the request."
        system_message = "Use the provided context followed by a question to answer it."

        full_prompt = f"""### Instruction:
{system_message}

### Context:
{context}

### Question:
{question}

### Answer:
"""
        # Collapse all whitespace so the prompt matches the single-line
        # format used during fine-tuning.
        full_prompt = " ".join(full_prompt.split())

        # Only override the defaults when the caller supplied a value;
        # assigning None here would break generation.
        if max_tokens is not None:
            self.generation_config.max_new_tokens = max_tokens
        if temp is not None:
            self.generation_config.temperature = temp

        result = self.pipeline(
            full_prompt, generation_config=self.generation_config
        )[0]["generated_text"]

        # Extract the text between "### Answer:" and the next "###" marker,
        # in case the model kept generating past the answer.
        match = re.search(r"### Answer:(.*?)###", result, re.DOTALL)
        if match:
            result = match.group(1).strip()

        # Fallback: if there was no trailing marker, take everything after
        # "### Answer:".
        match = re.search(r"### Answer:(.*)", result, re.DOTALL)
        if match:
            result = match.group(1).strip()

        return result
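
# --- Usage sketch (not part of the handler) --------------------------------
# A minimal local smoke test, assuming a checkpoint is available at the
# hypothetical path "./model" and a CUDA GPU is present. Hugging Face
# Inference Endpoints construct EndpointHandler and call it with the request
# payload, so invoking it directly mirrors what the service does.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # "./model" is a placeholder path
    payload = {
        "question": "What colour is the sky?",
        "context": "On a clear day the sky appears blue.",
        "temp": 0.1,
        "max_tokens": 64,
    }
    print(handler(payload))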