"""Plot per-frame semantic loss for several text2semantic checkpoints.

For each model, the codebook cross-entropy is averaged at every frame position
over a few batches, and the smoothed curves are written to semantic_loss.png.
"""

import pyrootutils
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from transformers import AutoTokenizer

# Register the project root and add it to PYTHONPATH before project imports
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from torch.utils.data import DataLoader

from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
from tools.llama.generate import load_model


def smooth(scalars: list[float], weight: float) -> list[float]:
    """Exponential moving average smoothing; `weight` must be between 0 and 1."""
    last = scalars[0]  # First value in the plot (first timestep)
    smoothed = []
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point
        smoothed.append(smoothed_val)
        last = smoothed_val  # Anchor on the last smoothed value
    return smoothed


@torch.inference_mode()
def analyze_one_model(loader, config, weight, max_length):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config,
        weight,
        device,
        torch.bfloat16,
        max_length,
        compile=False,
    )[0]
    model.eval()

    # Per-frame accumulators: summed loss and number of valid samples per position
    semantic_loss_sum = torch.zeros(
        max_length,
        dtype=torch.float32,
        device=device,
    )
    counter = torch.zeros(
        max_length,
        dtype=torch.long,
        device=device,
    )

    current_step = 0
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        labels = batch["labels"]
        outputs = model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )

        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Text-token loss (computed for reference; only the semantic loss is plotted)
        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        # Codebook (semantic) loss, one value per codebook per frame
        codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        base_loss = base_loss.reshape(labels[:, 0].shape)
        semantic_loss = semantic_loss.reshape(codebook_labels.shape)

        # Average over codebooks to get one loss value per frame
        semantic_loss_frame = semantic_loss.mean(-1)
        # A frame is padding when every codebook label is the ignore index (-100)
        pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks

        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
            semantic_loss_sum[~pad] += loss_sample[~pad]
            counter[~pad] += 1

        current_step += 1
        if current_step == 10:  # Only evaluate the first 10 batches
            break

    semantic_loss_sum = semantic_loss_sum.cpu()
    counter = counter.cpu()

    xs, ys = [], []
    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
        if count > 0:
            xs.append(i)
            ys.append((loss / count).item())

    # Smooth the curve for better loss visualization
    smoothed_ys = smooth(ys, 0.95)

    # Unload the model before the next one is loaded
    del model
    torch.cuda.empty_cache()

    return xs, ys, smoothed_ys


def main():
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    max_length = 4096

    ds = AutoAugTextDataset(
        ["data/protos/sft/云天河"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        max_length=max_length,
    )

    loader = DataLoader(
        ds,
        batch_size=8,
        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
        num_workers=0,
        shuffle=False,
    )

    plt.figure(figsize=(10, 5), dpi=200)
    plt.xlabel("Frame")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("Semantic Loss")
    plt.grid(which="both", axis="both")
    plt.xlim(0, max_length)

    # (plot label, model config, checkpoint path) for each model to compare
    tests = [
        (
            "pretrain-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
        ),
        (
            "sft-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
        ),
        (
            "sft-large",
            "dual_ar_2_codebook_large",
            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
        ),
    ]

    for name, config, weight in tests:
        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
        plt.plot(xs, smoothed_ys, label=name)

    plt.legend()
    plt.savefig("semantic_loss.png")


if __name__ == "__main__":
    main()
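# Example invocation, assuming the checkpoints and dataset protos referenced
# above exist locally (adjust the paths for your setup; the script location
# within the repo is also an assumption):
#
#   python <path-to-this-script>
#
# The output, semantic_loss.png, contains one smoothed loss-vs-frame curve per
# checkpoint on a log scale, so the pretrain and SFT models can be compared.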