"""Pydantic request/response schemas for the serving API.

Models in this module are documented with ``#`` comments instead of class
docstrings on purpose: pydantic copies a model's docstring into the
generated JSON schema ``description``, so adding docstrings would change
the published schema.
"""

import os
import queue
from dataclasses import dataclass
from typing import Annotated, Literal, Optional

import torch
from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
from pydantic.functional_validators import SkipValidation

from fish_speech.conversation import Message, TextPart, VQPart

# Read once at import time from the environment; falls back to 1 when unset.
GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", "1"))


# Message part carrying VQ codes. Validation of `codes` is skipped
# (SkipValidation), so the nested lists are stored as given.
class ServeVQPart(BaseModel):
    type: Literal["vq"] = "vq"
    codes: SkipValidation[list[list[int]]]


# Message part carrying plain text.
class ServeTextPart(BaseModel):
    type: Literal["text"] = "text"
    text: str


# Message part carrying raw audio bytes.
class ServeAudioPart(BaseModel):
    type: Literal["audio"] = "audio"
    audio: bytes


# Plain dataclass (not a pydantic model) bundling an audio tensor, the queue
# on which a result is expected, and a language tag.
@dataclass
class ASRPackRequest:
    audio: torch.Tensor
    result_queue: queue.Queue
    language: str


class ServeASRRequest(BaseModel):
    # The audio should be an uncompressed PCM float16 audio
    audios: list[bytes]
    sample_rate: int = 44100
    language: Literal["zh", "en", "ja", "auto"] = "auto"


# One transcription result entry of an ASR response.
class ServeASRTranscription(BaseModel):
    text: str
    duration: float
    huge_gap: bool


# A piece of transcribed text with its start/end positions.
class ServeASRSegment(BaseModel):
    text: str
    start: float
    end: float


# ASR response that carries per-segment timing alongside the full text.
class ServeTimedASRResponse(BaseModel):
    text: str
    segments: list[ServeASRSegment]
    duration: float


# ASR response: one transcription per entry.
class ServeASRResponse(BaseModel):
    transcriptions: list[ServeASRTranscription]


# A chat message made of VQ and/or text parts.
class ServeMessage(BaseModel):
    role: Literal["system", "assistant", "user"]
    parts: list[ServeVQPart | ServeTextPart]

    def to_conversation_message(self):
        """Convert this message into a ``fish_speech.conversation.Message``.

        Text parts become ``TextPart`` objects and VQ parts become ``VQPart``
        objects whose codes are wrapped in a ``torch.int`` tensor. Parts keep
        their original order.

        Raises:
            ValueError: If a part is neither a ``ServeTextPart`` nor a
                ``ServeVQPart``.
        """
        new_message = Message(role=self.role, parts=[])

        for part in self.parts:
            if isinstance(part, ServeTextPart):
                new_message.parts.append(TextPart(text=part.text))
            elif isinstance(part, ServeVQPart):
                new_message.parts.append(
                    VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
                )
            else:
                raise ValueError(f"Unsupported part type: {part}")

        return new_message


# Generation request built from a list of messages plus sampling settings.
class ServeRequest(BaseModel):
    # NOTE(review): conlist(...) is supplied as Annotated *metadata* rather
    # than as the field type; confirm pydantic enforces min_length=1 here.
    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    streaming: bool = False
    num_samples: int = 1
    early_stop_threshold: float = 1.0


class ServeVQGANEncodeRequest(BaseModel):
    # The audio here should be in wav, mp3, etc
    audios: list[bytes]


# Encoded tokens; validation of the nested lists is skipped.
class ServeVQGANEncodeResponse(BaseModel):
    tokens: SkipValidation[list[list[list[int]]]]


# Tokens to decode; validation of the nested lists is skipped.
class ServeVQGANDecodeRequest(BaseModel):
    tokens: SkipValidation[list[list[list[int]]]]


class ServeVQGANDecodeResponse(BaseModel):
    # The audio here should be in PCM float16 format
    audios: list[bytes]


# Simple role/content message used by ServeChatRequestV1.
class ServeForwardMessage(BaseModel):
    role: str
    content: str


# Non-streaming response: the produced messages plus optional finish reason
# and free-form stats.
class ServeResponse(BaseModel):
    messages: list[ServeMessage]
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] = {}


# Incremental update inside a streaming response.
class ServeStreamDelta(BaseModel):
    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None


# One chunk of a streaming response, tagged with the sample it belongs to.
class ServeStreamResponse(BaseModel):
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None


# Reference audio together with its text. This is the only definition of the
# model; it carries a custom __repr__.
class ServeReferenceAudio(BaseModel):
    audio: bytes
    text: str

    def __repr__(self) -> str:
        """Return a compact repr showing the text and the audio byte count.

        The raw audio bytes are deliberately left out of the output; only
        their length is reported.
        """
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"


# Chat request with optional audio input and TTS output settings.
class ServeChatRequestV1(BaseModel):
    model: str = "llama3-8b"
    messages: list[ServeForwardMessage] = []
    audio: bytes | None = None
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    voice: str = "jessica"
    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128


# Text-to-speech request.
class ServeTTSRequest(BaseModel):
    text: str
    # NOTE(review): conint(...) is supplied as Annotated *metadata* rather
    # than as the field type; confirm pydantic enforces the 100..300 bounds.
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    # Declared exactly once: any int (or None) is accepted, default 64.
    mp3_bitrate: Optional[int] = 64
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want to use
    # https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    seed: int | None = None
    use_memory_cache: Literal["on-demand", "never"] = "never"
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    opus_bitrate: Optional[int] = -1000
    # Balanced mode will reduce latency to 300ms, but may decrease stability
    latency: Literal["normal", "balanced"] = "normal"
    # not usually used below
    streaming: bool = False
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7