Spaces:
Running
on
L40S
Running
on
L40S
File size: 5,257 Bytes
4f6613a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import os
import queue
from dataclasses import dataclass
from typing import Annotated, Literal, Optional
import torch
from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
from pydantic.functional_validators import SkipValidation
from fish_speech.conversation import Message, TextPart, VQPart
# Integer read once at import time from the GLOBAL_NUM_SAMPLES environment
# variable; evaluates to 1 when the variable is not set.
GLOBAL_NUM_SAMPLES = int(os.environ.get("GLOBAL_NUM_SAMPLES", 1))
# Message part that carries VQ codes. The literal `type` tag is always "vq".
class ServeVQPart(BaseModel):
    type: Literal["vq"] = "vq"
    # Nested lists of ints. Wrapped in SkipValidation, so pydantic does not
    # validate the contents of this field.
    codes: SkipValidation[list[list[int]]]
# Message part that carries plain text. The literal `type` tag is always
# "text".
class ServeTextPart(BaseModel):
    type: Literal["text"] = "text"
    # The text content of this part.
    text: str
# Message part that carries raw audio bytes. The literal `type` tag is always
# "audio".
class ServeAudioPart(BaseModel):
    type: Literal["audio"] = "audio"
    # Audio payload as bytes.
    # NOTE(review): the expected encoding is not stated here - confirm with
    # the producer/consumer of this part.
    audio: bytes
# Plain dataclass (not a pydantic model) bundling one ASR work item.
@dataclass
class ASRPackRequest:
    # Audio as a torch tensor.
    # NOTE(review): shape and dtype are not fixed by this file - confirm
    # against the code that builds these requests.
    audio: torch.Tensor
    # Standard-library queue attached to the request; presumably the worker
    # puts its result here - verify against the consumer.
    result_queue: queue.Queue
    # Language identifier as a free-form string (no Literal restriction here).
    language: str
# Request body for speech recognition over one or more audio clips.
class ServeASRRequest(BaseModel):
    # Each entry should be uncompressed PCM float16 audio.
    audios: list[bytes]
    # Sample rate of the submitted audio; defaults to 44100.
    sample_rate: int = 44100
    # Restricted to "zh", "en", "ja" or "auto"; defaults to "auto".
    language: Literal["zh", "en", "ja", "auto"] = "auto"
# One transcription result; all three fields are required.
class ServeASRTranscription(BaseModel):
    # Transcribed text.
    text: str
    # Duration as a float.
    # NOTE(review): presumably seconds - confirm with the ASR code.
    duration: float
    # Boolean flag set by the ASR code.
    # NOTE(review): the exact condition it signals is not visible in this
    # file - confirm before relying on it.
    huge_gap: bool
# A piece of transcribed text together with its start and end positions.
class ServeASRSegment(BaseModel):
    # Text of this segment.
    text: str
    # Segment boundaries as floats.
    # NOTE(review): presumably seconds from the start of the audio - confirm.
    start: float
    end: float
# ASR response that includes per-segment timing information.
class ServeTimedASRResponse(BaseModel):
    # Full text of the response.
    text: str
    # List of ServeASRSegment entries (text plus start/end).
    segments: list[ServeASRSegment]
    # Duration as a float.
    # NOTE(review): presumably seconds - confirm with the ASR code.
    duration: float
# ASR response wrapping a list of ServeASRTranscription results.
class ServeASRResponse(BaseModel):
    # NOTE(review): presumably one entry per submitted audio - confirm with
    # the handler that builds this response.
    transcriptions: list[ServeASRTranscription]
# A single chat message: a role plus an ordered list of VQ and/or text parts.
class ServeMessage(BaseModel):
    # Restricted to "system", "assistant" or "user".
    role: Literal["system", "assistant", "user"]
    # Each part is either a ServeVQPart or a ServeTextPart.
    parts: list[ServeVQPart | ServeTextPart]

    def to_conversation_message(self):
        """Build a fish_speech ``Message`` from this serving-layer message.

        Text parts become ``TextPart`` objects and VQ parts become ``VQPart``
        objects whose codes are an int ``torch`` tensor; part order is kept.

        Raises:
            ValueError: If a part is neither a ServeTextPart nor a
                ServeVQPart.
        """
        message = Message(role=self.role, parts=[])

        for item in self.parts:
            if isinstance(item, ServeVQPart):
                code_tensor = torch.tensor(item.codes, dtype=torch.int)
                message.parts.append(VQPart(codes=code_tensor))
                continue

            if isinstance(item, ServeTextPart):
                message.parts.append(TextPart(text=item.text))
                continue

            raise ValueError(f"Unsupported part type: {item}")

        return message
# Request payload: a list of messages plus generation settings with defaults.
class ServeRequest(BaseModel):
    # conlist(..., min_length=1) is attached as Annotated metadata, signalling
    # that at least one message is expected.
    # NOTE(review): confirm the installed pydantic version actually enforces a
    # constraint supplied this way; ServeTTSRequest uses Field(...) for its
    # bounded fields instead.
    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
    # Generation settings. No bounds are declared on these fields here.
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    # Streaming is off by default.
    streaming: bool = False
    # Defaults to a single sample.
    num_samples: int = 1
    # NOTE(review): how this threshold is interpreted is not visible in this
    # file - confirm with the generation code.
    early_stop_threshold: float = 1.0
# Request body for the VQGAN encode operation.
class ServeVQGANEncodeRequest(BaseModel):
    # Each entry should be an audio file in a container format such as wav or
    # mp3.
    audios: list[bytes]
# Response body for the VQGAN encode operation.
class ServeVQGANEncodeResponse(BaseModel):
    # Three-level nested lists of ints. Wrapped in SkipValidation, so pydantic
    # does not validate the contents.
    # NOTE(review): presumably one 2-D code list per input audio - confirm.
    tokens: SkipValidation[list[list[list[int]]]]
# Request body for the VQGAN decode operation.
class ServeVQGANDecodeRequest(BaseModel):
    # Three-level nested lists of ints, the same type as the `tokens` field of
    # ServeVQGANEncodeResponse. Wrapped in SkipValidation, so pydantic does
    # not validate the contents.
    tokens: SkipValidation[list[list[list[int]]]]
# Response body for the VQGAN decode operation.
class ServeVQGANDecodeResponse(BaseModel):
    # Each entry should be audio in PCM float16 format.
    audios: list[bytes]
# Reference audio clip paired with its text.
#
# NOTE: this class is defined a second time further down in this file, and
# that later definition replaces this one at import time. The two definitions
# are kept identical (including __repr__) so behavior does not depend on which
# one survives a future cleanup.
class ServeReferenceAudio(BaseModel):
    # Raw audio bytes of the reference clip.
    audio: bytes
    # Text associated with the reference clip.
    text: str

    def __repr__(self) -> str:
        """Return a short representation that reports the audio size only.

        The raw audio bytes are deliberately left out so that logging a
        request does not dump the whole payload.
        """
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
# Minimal role/content message used by ServeChatRequestV1.
class ServeForwardMessage(BaseModel):
    # Free-form string; unlike ServeMessage.role it is not limited to a
    # Literal set.
    role: str
    # Message content as plain text.
    content: str
# Non-streaming response: messages plus an optional finish reason and stats.
class ServeResponse(BaseModel):
    # Messages returned to the caller.
    messages: list[ServeMessage]
    # Either "stop", "error", or None (the default).
    finish_reason: Literal["stop", "error"] | None = None
    # Mapping of stat name to an int, float or str value; empty by default.
    stats: dict[str, int | float | str] = {}
# Incremental update carried by a ServeStreamResponse. Both fields are
# optional and default to None.
class ServeStreamDelta(BaseModel):
    # One of "system", "assistant", "user", or None.
    role: Literal["system", "assistant", "user"] | None = None
    # A single VQ or text part, or None.
    part: ServeVQPart | ServeTextPart | None = None
# One item of a streamed response.
class ServeStreamResponse(BaseModel):
    # Integer sample identifier; defaults to 0.
    # NOTE(review): presumably the index of the sample this item belongs to
    # when several samples are generated - confirm with the streaming code.
    sample_id: int = 0
    # Optional incremental update.
    delta: ServeStreamDelta | None = None
    # Either "stop", "error", or None (the default).
    finish_reason: Literal["stop", "error"] | None = None
    # Optional mapping of stat name to an int, float or str value.
    stats: dict[str, int | float | str] | None = None
# Reference audio clip paired with its text. This is the second definition of
# the name in this file and is the one in effect after import.
class ServeReferenceAudio(BaseModel):
    # Raw audio bytes of the reference clip.
    audio: bytes
    # Text associated with the reference clip.
    text: str

    def __repr__(self) -> str:
        """Return a short representation showing the text and the audio size.

        Only the byte count of ``audio`` is included, not the bytes
        themselves.
        """
        audio_size = len(self.audio)
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={audio_size})"
# Chat request (V1): model name, messages, optional audio, generation settings
# and TTS output options. Every field has a default.
class ServeChatRequestV1(BaseModel):
    # Model name; defaults to "llama3-8b".
    model: str = "llama3-8b"
    # Role/content messages; empty by default.
    messages: list[ServeForwardMessage] = []
    # Optional audio bytes.
    # NOTE(review): the expected encoding is not stated here - confirm with
    # the handler.
    audio: bytes | None = None
    # Generation settings. No bounds are declared on these fields here.
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    # Voice name; defaults to "jessica".
    voice: str = "jessica"
    # Output audio format, one of "mp3", "pcm" or "opus".
    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
    # Output bitrate, restricted to the listed values.
    # NOTE(review): presumably kbps - confirm with the encoder code.
    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
# Request body for text-to-speech synthesis.
class ServeTTSRequest(BaseModel):
    # Text to synthesize.
    text: str
    # Chunk length, declared with bounds 100..300 (strict int); defaults to 200.
    # NOTE(review): the bounds are passed as conint(...) inside Annotated,
    # whereas top_p / repetition_penalty / temperature below use Field(...).
    # Confirm the installed pydantic version enforces the conint form.
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    # Declared exactly once. This field used to be declared twice in this
    # class (Literal[64, 128, 192] = 128 here, Optional[int] = 64 further
    # down); Python keeps the last annotation and default while the field
    # stays at its first position, so Optional[int] = 64 at this position is
    # what the model effectively exposed. Behavior is unchanged.
    mp3_bitrate: Optional[int] = 64
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    # Optional integer seed; None by default.
    seed: int | None = None
    # Either "on-demand" or "never"; defaults to "never".
    use_memory_cache: Literal["on-demand", "never"] = "never"
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # NOTE(review): -1000 looks like a sentinel value - confirm its meaning
    # with the audio encoding code.
    opus_bitrate: Optional[int] = -1000
    # Balanced mode will reduce latency to 300ms, but may decrease stability
    latency: Literal["normal", "balanced"] = "normal"
    # not usually used below
    streaming: bool = False
    max_new_tokens: int = 1024
    # The three fields below are strict floats with inclusive bounds.
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|