from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from vllm import AsyncLLMEngine

    from ..data import Template
    from ..data.mm_plugin import ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


@dataclass
class Response:
    r"""
    A single generation result: the decoded response text, the lengths (in tokens)
    of the response and the prompt, and the reason why generation finished.
    """

    response_text: str
    response_length: int
    prompt_length: int
    finish_reason: Literal["stop", "length"]


class BaseEngine(ABC):
    r"""
    Base class for the inference engines of chat models.

    Subclasses must implement the async methods: chat(), stream_chat() and get_scores().
    """

    model: Union["PreTrainedModel", "AsyncLLMEngine"]  # a Transformers model or a vLLM async engine
    tokenizer: "PreTrainedTokenizer"  # tokenizer paired with the model
    can_generate: bool  # whether the engine supports text generation
    template: "Template"  # chat template used to build prompts
    generating_args: Dict[str, Any]  # default generation arguments

    @abstractmethod
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        r"""
        Initializes an inference engine.
        """
        ...

    @abstractmethod
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Gets a list of responses from the chat model.
        """
        ...

    @abstractmethod
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
        Streams the response from the chat model token by token.
        """
        ...

    @abstractmethod
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Gets a list of scores from the reward model.
        """
        ...
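

# Illustrative usage sketch (an addition, not part of the original interface).
# It assumes `engine` is an instance of a concrete `BaseEngine` subclass built
# by one of the backend implementations elsewhere in the package; the message
# format and return types follow the signatures declared above.
async def _example_usage(engine: "BaseEngine") -> None:
    messages = [{"role": "user", "content": "Hello!"}]

    # Batched generation: returns a list of `Response` objects.
    responses = await engine.chat(messages, system="You are a helpful assistant.")
    print(responses[0].response_text, responses[0].finish_reason)

    # Streaming generation: yields incremental pieces of text.
    async for new_text in engine.stream_chat(messages):
        print(new_text, end="", flush=True)

    # An engine wrapping a reward model would score texts instead:
    # scores = await engine.get_scores(["a candidate response to score"])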