"""A simple web interactive chat demo based on gradio.""" import os import time import gradio as gr import numpy as np import spaces import torch import os import lightning as L import torch import time import spaces from snac import SNAC from litgpt import Tokenizer from litgpt.utils import ( num_parameters, ) from litgpt.generate.base import ( generate_AA, generate_ASR, generate_TA, generate_TT, generate_AT, generate_TA_BATCH, ) from typing import Any, Literal, Optional import soundfile as sf from litgpt.model import GPT, Config from lightning.fabric.utilities.load import _lazy_load as lazy_load from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str from utils.snac_utils import get_snac import whisper from tqdm import tqdm from huggingface_hub import snapshot_download from litgpt.generate.base import sample device = "cuda" if torch.cuda.is_available() else "cpu" ckpt_dir = "./checkpoint" OUT_CHUNK = 4096 OUT_RATE = 24000 OUT_CHANNELS = 1 # TODO text_vocabsize = 151936 text_specialtokens = 64 audio_vocabsize = 4096 audio_specialtokens = 64 padded_text_vocabsize = text_vocabsize + text_specialtokens padded_audio_vocabsize = audio_vocabsize + audio_specialtokens _eot = text_vocabsize _pad_t = text_vocabsize + 1 _input_t = text_vocabsize + 2 _answer_t = text_vocabsize + 3 _asr = text_vocabsize + 4 _eoa = audio_vocabsize _pad_a = audio_vocabsize + 1 _input_a = audio_vocabsize + 2 _answer_a = audio_vocabsize + 3 _split = audio_vocabsize + 4 def download_model(ckpt_dir): repo_id = "gpt-omni/mini-omni" snapshot_download(repo_id, local_dir=ckpt_dir, revision="main") if not os.path.exists(ckpt_dir): print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface") download_model(ckpt_dir) snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device) whispermodel = whisper.load_model("small").to(device) whispermodel.eval() text_tokenizer = Tokenizer(ckpt_dir) # fabric = L.Fabric(devices=1, strategy="auto") config = Config.from_file(ckpt_dir + "/model_config.yaml") config.post_adapter = False model = GPT(config, device=device) state_dict = lazy_load(ckpt_dir + "/lit_model.pth") model.load_state_dict(state_dict, strict=True) model = model.to(device) model.eval() def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device): # with torch.no_grad(): mel = mel.unsqueeze(0).to(device) # audio_feature = whisper.decode(whispermodel,mel, options).audio_features audio_feature = whispermodel.embed_audio(mel)[0][:leng] T = audio_feature.size(0) input_ids_AA = [] for i in range(7): input_ids_item = [] input_ids_item.append(layershift(_input_a, i)) input_ids_item += [layershift(_pad_a, i)] * T input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)] input_ids_AA.append(torch.tensor(input_ids_item)) input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t]) input_ids_AA.append(input_id_T) input_ids_AT = [] for i in range(7): input_ids_item = [] input_ids_item.append(layershift(_input_a, i)) input_ids_item += [layershift(_pad_a, i)] * T input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)] input_ids_AT.append(torch.tensor(input_ids_item)) input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t]) input_ids_AT.append(input_id_T) input_ids = [input_ids_AA, input_ids_AT] stacked_inputids = [[] for _ in range(8)] for i in range(2): for j in range(8): stacked_inputids[j].append(input_ids[i][j]) stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids] return 
def next_token_batch(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    input_pos = input_pos.to(model.device)
    input_ids = [input_id.to(model.device) for input_id in input_ids]
    logits_a, logit_t = model(
        audio_features, input_ids, input_pos, whisper_lens=whisper_lens, task=task
    )

    for i in range(7):
        logits_a[i] = logits_a[i][0].unsqueeze(0)
    logit_t = logit_t[1].unsqueeze(0)

    next_audio_tokens = []
    for logit_a in logits_a:
        next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
        next_audio_tokens.append(next_a)
    next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
    return next_audio_tokens, next_t


def load_audio(path):
    audio = whisper.load_audio(path)
    duration_ms = (len(audio) / 16000) * 1000
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio)
    return mel, int(duration_ms / 20) + 1


def generate_audio_data(snac_tokens, snacmodel, device=None):
    audio = reconstruct_tensors(snac_tokens, device)
    with torch.inference_mode():
        audio_hat = snacmodel.decode(audio)
    audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
    audio_data = audio_data.astype(np.int16)
    audio_data = audio_data.tobytes()
    return audio_data


@spaces.GPU
@torch.inference_mode()
def run_AT_batch_stream(
    audio_path,
    stream_stride=4,
    max_returned_tokens=2048,
    temperature=0.9,
    top_k=1,
    top_p=1.0,
    eos_id_a=_eoa,
    eos_id_t=_eot,
):
    assert os.path.exists(audio_path), f"audio file {audio_path} not found"
    model.set_kv_cache(batch_size=2, device=device)
    mel, leng = load_audio(audio_path)
    audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
    T = input_ids[0].size(1)
    # device = input_ids[0].device
    assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
    if model.max_seq_length < max_returned_tokens - 1:
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )

    input_pos = torch.tensor([T], device=device)
    list_output = [[] for i in range(8)]
    tokens_A, token_T = next_token_batch(
        model,
        audio_feature.to(torch.float32).to(model.device),
        input_ids,
        [T - 3, T - 3],
        ["A1T2", "A1T2"],
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    for i in range(7):
        list_output[i].append(tokens_A[i].tolist()[0])
    list_output[7].append(token_T.tolist()[0])

    model_input_ids = [[] for i in range(8)]
    for i in range(7):
        tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
        model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
        model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
        model_input_ids[i] = torch.stack(model_input_ids[i])

    model_input_ids[-1].append(token_T.clone().to(torch.int32))
    model_input_ids[-1].append(token_T.clone().to(torch.int32))
    model_input_ids[-1] = torch.stack(model_input_ids[-1])

    text_end = False
    index = 1
    nums_generate = stream_stride
    begin_generate = False
    current_index = 0
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        tokens_A, token_T = next_token_batch(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        if text_end:
            token_T = torch.tensor([_pad_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break

        if token_T == eos_id_t:
            text_end = True

        for i in range(7):
            list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])

        model_input_ids = [[] for i in range(8)]
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
            model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
            model_input_ids[i].append(
                torch.tensor([layershift(4097, i)], device=device)
            )
            model_input_ids[i] = torch.stack(model_input_ids[i])

        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1] = torch.stack(model_input_ids[-1])

        if index == 7:
            begin_generate = True

        if begin_generate:
            current_index += 1
            if current_index == nums_generate:
                current_index = 0
                snac = get_snac(list_output, index, nums_generate)
                audio_stream = generate_audio_data(snac, snacmodel, device)
                yield audio_stream

        input_pos = input_pos.add_(1)
        index += 1

    text = text_tokenizer.decode(torch.tensor(list_output[-1]))
    print(f"text output: {text}")
    model.clear_kv_cache()
    return list_output


# Warm-up: run the streaming pipeline once at startup.
for chunk in run_AT_batch_stream('./data/samples/output1.wav'):
    pass


def process_audio(audio):
    filepath = audio
    print(f"filepath: {filepath}")
    if filepath is None:
        return

    cnt = 0
    tik = time.time()
    for chunk in run_AT_batch_stream(filepath):
        # Convert chunk to numpy array
        if cnt == 0:
            print(f"first chunk time cost: {time.time() - tik:.3f}")
        cnt += 1
        audio_data = np.frombuffer(chunk, dtype=np.int16)
        audio_data = audio_data.reshape(-1, OUT_CHANNELS)
        yield OUT_RATE, audio_data.astype(np.int16)


demo = gr.Interface(
    process_audio,
    inputs=gr.Audio(type="filepath", label="Microphone"),
    outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)],
    title="Chat Mini-Omni Demo",
    # live=True,
)
demo.queue()
demo.launch()