|
from unsloth import FastLanguageModel |
|
from typing import Dict, List, Any |
|
import torch |
|
|
|
class EndpointHandler:
    """Request handler that paraphrases text with an Unsloth-loaded model.

    The model and tokenizer are loaded once in ``__init__``. Each call builds
    an Alpaca-style prompt from the request payload, runs ``generate`` and
    returns the decoded output.
    """

    # Fixed text placed in the "### Instruction:" slot of every prompt.
    _INSTRUCTION = (
        "You are an AI assistant, capable of paraphrasing any text to a "
        "human-like version of the text. Human writing often exhibits bursts "
        "and lulls, with a mix of long and short sentences"
    )

    def __init__(self, path=""):
        """Load the model and tokenizer found at ``path``.

        Args:
            path: Model name or directory handed to
                ``FastLanguageModel.from_pretrained`` as ``model_name``.
        """
        max_seq_length = 2048
        # NOTE(review): None presumably lets the loader choose the dtype --
        # confirm against the Unsloth documentation.
        dtype = None
        load_in_4bit = True

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
        )
        # Switch to inference mode once at load time rather than on every
        # request.
        FastLanguageModel.for_inference(self.model)

        # Two slots: instruction and input. The response section is left
        # open for the model to complete.
        # NOTE(review): whitespace reconstructed from a whitespace-mangled
        # copy of this file -- confirm it matches the training prompt.
        self.alpaca_prompt = (
            "\n"
            "### Instruction:\n"
            "{}\n"
            "\n"
            "### Input:\n"
            "{}\n"
            "\n"
            "### Response:\n"
        )

        self.EOS_TOKEN = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Args:
            data: Request payload. The options are read from ``data["inputs"]``
                when that key holds a dict, otherwise from ``data`` itself.
                When ``data["inputs"]`` is a plain string it is used as
                ``input_text`` and the remaining options are read from
                ``data["parameters"]`` (if present). Recognised options:

                * ``input_text`` (str, default ``""``): text to paraphrase.
                * ``lex_diversity`` (default ``80``) and ``order_diversity``
                  (default ``20``): values interpolated into the prompt.
                * ``repetition_penalty`` (default ``1.0``), ``use_cache``
                  (default ``False``), ``max_length`` (default ``128``,
                  used as ``max_new_tokens``): forwarded to ``generate``.

        Returns:
            ``{"prediction": <list of decoded strings>}``.
        """
        # .get instead of .pop so the caller's dict is left untouched.
        payload = data.get("inputs", data)
        if isinstance(payload, str):
            # {"inputs": "<text>", "parameters": {...}} style request.
            options = dict(data.get("parameters") or {})
            options["input_text"] = payload
            payload = options

        prediction = self.paraphrase(
            payload.get("input_text", ""),
            payload.get("lex_diversity", 80),
            payload.get("order_diversity", 20),
            repetition_penalty=payload.get("repetition_penalty", 1.0),
            use_cache=payload.get("use_cache", False),
            max_length=payload.get("max_length", 128),
        )
        return {"prediction": prediction}

    def paraphrase(self, input_text, lex_diversity, order_diversity, repetition_penalty, use_cache, max_length, **kwargs):
        """Generate a paraphrase of ``input_text``.

        Args:
            input_text: Text to paraphrase.
            lex_diversity: Value rendered after ``lexical =`` in the prompt.
            order_diversity: Value rendered after ``order =`` in the prompt.
            repetition_penalty: Forwarded to ``model.generate``.
            use_cache: Forwarded to ``model.generate``.
            max_length: Forwarded to ``model.generate`` as ``max_new_tokens``.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            The result of ``tokenizer.batch_decode`` on the generated ids,
            decoded as-is (no prompt stripping or special-token skipping is
            requested here).
        """
        prompt = self.alpaca_prompt.format(
            self._INSTRUCTION,
            f"lexical = {lex_diversity}, order = {order_diversity} {input_text}",
        )
        # NOTE(review): the device is hard-coded; this assumes the model
        # lives on a CUDA device.
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_length,
            # Honour the caller's choice instead of always disabling the cache.
            use_cache=use_cache,
            repetition_penalty=repetition_penalty,
        )
        return self.tokenizer.batch_decode(outputs)