mrfakename committed
Commit 9eac142
Parent: 8474faf

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space via the GitHub repo.

app.py CHANGED
@@ -4,7 +4,6 @@ import torchaudio
 import gradio as gr
 import numpy as np
 import tempfile
- from einops import rearrange
 from vocos import Vocos
 from pydub import AudioSegment, silence
 from model import CFM, UNetT, DiT, MMDiT
@@ -175,7 +174,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
 
 generated = generated.to(torch.float32)
 generated = generated[:, ref_audio_len:, :]
- generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
+ generated_mel_spec = generated.permute(0, 2, 1)
 generated_wave = vocos.decode(generated_mel_spec.cpu())
 if rms < target_rms:
     generated_wave = generated_wave * rms / target_rms
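For readers following the einops removal: a minimal sanity check (not part of the commit; the dummy tensor shape is assumed) that permute(0, 2, 1) reproduces the removed rearrange(generated, "1 n d -> 1 d n"):

# Not part of the commit: quick equivalence check for the permute replacement.
import torch
from einops import rearrange  # only needed for this comparison

x = torch.randn(1, 173, 100)  # assumed dummy mel: (batch=1, frames, mel bins)
assert torch.equal(rearrange(x, "1 n d -> 1 d n"), x.permute(0, 2, 1))
print(x.permute(0, 2, 1).shape)  # torch.Size([1, 100, 173])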
inference-cli.py CHANGED
@@ -11,7 +11,6 @@ import torch
 import torchaudio
 import tqdm
 from cached_path import cached_path
- from einops import rearrange
 from pydub import AudioSegment, silence
 from transformers import pipeline
 from vocos import Vocos
@@ -274,7 +273,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cr
 
 generated = generated.to(torch.float32)
 generated = generated[:, ref_audio_len:, :]
- generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
+ generated_mel_spec = generated.permute(0, 2, 1)
 generated_wave = vocos.decode(generated_mel_spec.cpu())
 if rms < target_rms:
     generated_wave = generated_wave * rms / target_rms
@@ -427,4 +426,4 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
 print(f.name)
 
 
- process(ref_audio, ref_text, gen_text, model, remove_silence)
+ process(ref_audio, ref_text, gen_text, model, remove_silence)
model/backbones/dit.py CHANGED
@@ -13,8 +13,6 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
- from einops import repeat
-
 from x_transformers.x_transformers import RotaryEmbedding
 
 from model.modules import (
@@ -134,7 +132,7 @@ class DiT(nn.Module):
 ):
     batch, seq_len = x.shape[0], x.shape[1]
     if time.ndim == 0:
-         time = repeat(time, ' -> b', b = batch)
+         time = time.repeat(batch)
 
     # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
     t = self.time_embed(time)
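A small check (not part of the commit; scalar value and batch size are arbitrary) that Tensor.repeat on a 0-dim time tensor matches the removed repeat(time, ' -> b', b=batch):

# Not part of the commit: 0-dim time tensor expanded to a per-batch vector.
import torch

time = torch.tensor(0.3)       # scalar time step, the `time.ndim == 0` case
batch = 4
expanded = time.repeat(batch)  # shape (4,), all entries 0.3
assert expanded.shape == (batch,) and bool((expanded == time).all())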
model/backbones/mmdit.py CHANGED
@@ -12,8 +12,6 @@ from __future__ import annotations
 import torch
 from torch import nn
 
- from einops import repeat
-
 from x_transformers.x_transformers import RotaryEmbedding
 
 from model.modules import (
@@ -115,7 +113,7 @@ class MMDiT(nn.Module):
 ):
     batch = x.shape[0]
     if time.ndim == 0:
-         time = repeat(time, ' -> b', b = batch)
+         time = time.repeat(batch)
 
     # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
     t = self.time_embed(time)
model/backbones/unett.py CHANGED
@@ -14,8 +14,6 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
- from einops import repeat, pack, unpack
-
 from x_transformers import RMSNorm
 from x_transformers.x_transformers import RotaryEmbedding
 
@@ -155,7 +153,7 @@ class UNetT(nn.Module):
 ):
     batch, seq_len = x.shape[0], x.shape[1]
     if time.ndim == 0:
-         time = repeat(time, ' -> b', b = batch)
+         time = time.repeat(batch)
 
     # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
     t = self.time_embed(time)
@@ -163,7 +161,7 @@ class UNetT(nn.Module):
 x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
 
 # postfix time t to input x, [b n d] -> [b n+1 d]
- x, ps = pack((t, x), 'b * d')
+ x = torch.cat([t.unsqueeze(1), x], dim=1)  # pack t to x
 if mask is not None:
     mask = F.pad(mask, (1, 0), value=1)
 
@@ -196,6 +194,6 @@ class UNetT(nn.Module):
 
 assert len(skips) == 0
 
- _, x = unpack(self.norm_out(x), ps, 'b * d')
+ x = self.norm_out(x)[:, 1:, :]  # unpack t from x
 
 return self.proj_out(x)
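The pack/unpack replacement is the least mechanical change in this commit; a sketch (dummy shapes assumed) of how torch.cat plus slicing reproduces prepending and later stripping the time token:

# Not part of the commit: cat/slice equivalent of pack((t, x), 'b * d') / unpack(..., 'b * d').
import torch

b, n, d = 2, 5, 8
t = torch.randn(b, d)      # time embedding, one vector per batch item
x = torch.randn(b, n, d)   # sequence features

packed = torch.cat([t.unsqueeze(1), x], dim=1)   # [b, n+1, d], t prepended as a token
assert packed.shape == (b, n + 1, d)
assert torch.equal(packed[:, 1:, :], x)          # dropping index 0 recovers x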
model/cfm.py CHANGED
@@ -18,10 +18,7 @@ from torch.nn.utils.rnn import pad_sequence
 
 from torchdiffeq import odeint
 
- from einops import rearrange
-
 from model.modules import MelSpec
-
 from model.utils import (
     default, exists,
     list_str_to_idx, list_str_to_tensor,
@@ -105,7 +102,7 @@ class CFM(nn.Module):
 
 if cond.ndim == 2:
     cond = self.mel_spec(cond)
-     cond = rearrange(cond, 'b d n -> b n d')
+     cond = cond.permute(0, 2, 1)
 assert cond.shape[-1] == self.num_channels
 
 batch, cond_seq_len, device = *cond.shape[:2], cond.device
@@ -144,7 +141,7 @@ class CFM(nn.Module):
 
 cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
 cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
- cond_mask = rearrange(cond_mask, '... -> ... 1')
+ cond_mask = cond_mask.unsqueeze(-1)
 step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))  # allow direct control (cut cond audio) with lens passed in
 
 if batch > 1:
@@ -199,7 +196,7 @@ class CFM(nn.Module):
 out = torch.where(cond_mask, cond, out)
 
 if exists(vocoder):
-     out = rearrange(out, 'b n d -> b d n')
+     out = out.permute(0, 2, 1)
     out = vocoder(out)
 
 return out, trajectory
@@ -215,7 +212,7 @@ class CFM(nn.Module):
 # handle raw wave
 if inp.ndim == 2:
     inp = self.mel_spec(inp)
-     inp = rearrange(inp, 'b d n -> b n d')
+     inp = inp.permute(0, 2, 1)
 assert inp.shape[-1] == self.num_channels
 
 batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
@@ -252,7 +249,7 @@ class CFM(nn.Module):
 # TODO. noise_scheduler
 
 # sample xt (φ_t(x) in the paper)
- t = rearrange(time, 'b -> b 1 1')
+ t = time.unsqueeze(-1).unsqueeze(-1)
 φ = (1 - t) * x0 + t * x1
 flow = x1 - x0
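The cfm.py changes are all broadcasting-shape tweaks; a quick check (dummy shapes assumed) of the two unsqueeze patterns:

# Not part of the commit: unsqueeze replacements for 'b -> b 1 1' and '... -> ... 1'.
import torch

time = torch.rand(3)                                         # (b,)
assert time.unsqueeze(-1).unsqueeze(-1).shape == (3, 1, 1)   # rearrange(time, 'b -> b 1 1')

cond_mask = torch.ones(3, 7, dtype=torch.bool)               # (b, n)
assert cond_mask.unsqueeze(-1).shape == (3, 7, 1)            # rearrange(cond_mask, '... -> ... 1')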
 
model/dataset.py CHANGED
@@ -9,8 +9,6 @@ import torchaudio
 from datasets import load_dataset, load_from_disk
 from datasets import Dataset as Dataset_
 
- from einops import rearrange
-
 from model.modules import MelSpec
 
 
@@ -54,11 +52,11 @@ class HFDataset(Dataset):
 resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
 audio_tensor = resampler(audio_tensor)
 
- audio_tensor = rearrange(audio_tensor, 't -> 1 t')
+ audio_tensor = audio_tensor.unsqueeze(0)  # 't -> 1 t'
 
 mel_spec = self.mel_spectrogram(audio_tensor)
 
- mel_spec = rearrange(mel_spec, '1 d t -> d t')
+ mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'
 
 text = row['text']
 
@@ -114,7 +112,7 @@ class CustomDataset(Dataset):
 audio = resampler(audio)
 
 mel_spec = self.mel_spectrogram(audio)
- mel_spec = rearrange(mel_spec, '1 d t -> d t')
+ mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'
 
 return dict(
     mel_spec = mel_spec,
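In dataset.py the rearranges only add or drop a singleton batch dim; a sketch with assumed waveform length and mel sizes:

# Not part of the commit: unsqueeze/squeeze replacements for 't -> 1 t' and '1 d t -> d t'.
import torch

audio = torch.randn(16000)                       # (t,) mono waveform
assert audio.unsqueeze(0).shape == (1, 16000)    # 't -> 1 t'

mel = torch.randn(1, 100, 400)                   # (1, d, t) from the mel transform
assert mel.squeeze(0).shape == (100, 400)        # '1 d t -> d t' (dim 0 must be size 1)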
model/modules.py CHANGED
@@ -16,7 +16,6 @@ from torch import nn
 import torch.nn.functional as F
 import torchaudio
 
- from einops import rearrange
 from x_transformers.x_transformers import apply_rotary_pos_emb
 
 
@@ -54,7 +53,7 @@ class MelSpec(nn.Module):
 
 def forward(self, inp):
     if len(inp.shape) == 3:
-         inp = rearrange(inp, 'b 1 nw -> b nw')
+         inp = inp.squeeze(1)  # 'b 1 nw -> b nw'
 
     assert len(inp.shape) == 2
 
@@ -101,9 +100,9 @@ class ConvPositionEmbedding(nn.Module):
 mask = mask[..., None]
 x = x.masked_fill(~mask, 0.)
 
- x = rearrange(x, 'b n d -> b d n')
+ x = x.permute(0, 2, 1)
 x = self.conv1d(x)
- out = rearrange(x, 'b d n -> b n d')
+ out = x.permute(0, 2, 1)
 
 if mask is not None:
     out = out.masked_fill(~mask, 0.)
@@ -345,7 +344,7 @@ class AttnProcessor:
 # mask. e.g. inference got a batch with different target durations, mask out the padding
 if mask is not None:
     attn_mask = mask
-     attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
+     attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
     attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
 else:
     attn_mask = None
@@ -360,7 +359,7 @@ class AttnProcessor:
 x = attn.to_out[1](x)
 
 if mask is not None:
-     mask = rearrange(mask, 'b n -> b n 1')
+     mask = mask.unsqueeze(-1)
     x = x.masked_fill(~mask, 0.)
 
 return x
@@ -422,7 +421,7 @@ class JointAttnProcessor:
 # mask. e.g. inference got a batch with different target durations, mask out the padding
 if mask is not None:
     attn_mask = F.pad(mask, (0, c.shape[1]), value = True)  # no mask for c (text)
-     attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
+     attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
     attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
 else:
     attn_mask = None
@@ -445,7 +444,7 @@ class JointAttnProcessor:
 c = attn.to_out_c(c)
 
 if mask is not None:
-     mask = rearrange(mask, 'b n -> b n 1')
+     mask = mask.unsqueeze(-1)
     x = x.masked_fill(~mask, 0.)
     # c = c.masked_fill(~mask, 0.)  # no mask for c (text)
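The attention-mask rearranges in modules.py become chained unsqueezes; a check with assumed batch, head, and sequence sizes:

# Not part of the commit: mask reshaping for the attention processors.
import torch

b, heads, n = 2, 4, 6
mask = torch.rand(b, n) > 0.5                    # (b, n) padding mask

attn_mask = mask.unsqueeze(1).unsqueeze(1)       # (b, 1, 1, n), i.e. 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(b, heads, n, n)     # broadcast over heads and query positions
assert attn_mask.shape == (b, heads, n, n)

assert mask.unsqueeze(-1).shape == (b, n, 1)     # 'b n -> b n 1', used to zero padded outputs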
 
model/trainer.py CHANGED
@@ -10,8 +10,6 @@ from torch.optim import AdamW
 from torch.utils.data import DataLoader, Dataset, SequentialSampler
 from torch.optim.lr_scheduler import LinearLR, SequentialLR
 
- from einops import rearrange
-
 from accelerate import Accelerator
 from accelerate.utils import DistributedDataParallelKwargs
 
@@ -222,7 +220,7 @@ class Trainer:
 for batch in progress_bar:
     with self.accelerator.accumulate(self.model):
         text_inputs = batch['text']
-         mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
+         mel_spec = batch['mel'].permute(0, 2, 1)
         mel_lengths = batch["mel_lengths"]
 
         # TODO. add duration predictor training
model/utils.py CHANGED
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import os
- import re
 import math
 import random
 import string
@@ -17,9 +16,6 @@ import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 import torchaudio
 
- import einx
- from einops import rearrange, reduce
-
 import jieba
 from pypinyin import lazy_pinyin, Style
 
@@ -57,7 +53,7 @@ def lens_to_mask(
 length = t.amax()
 
 seq = torch.arange(length, device = t.device)
- return einx.less('n, b -> b n', seq, t)
+ return seq[None, :] < t[:, None]
 
 def mask_from_start_end_indices(
     seq_len: int['b'],
@@ -66,7 +62,9 @@ def mask_from_start_end_indices(
 ):
 max_seq_len = seq_len.max().item()
 seq = torch.arange(max_seq_len, device = start.device).long()
- return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
+ start_mask = seq[None, :] >= start[:, None]
+ end_mask = seq[None, :] < end[:, None]
+ return start_mask & end_mask
 
 def mask_from_frac_lengths(
     seq_len: int['b'],
@@ -89,11 +87,11 @@ def maybe_masked_mean(
 if not exists(mask):
     return t.mean(dim = 1)
 
- t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
- num = reduce(t, 'b n d -> b d', 'sum')
- den = reduce(mask.float(), 'b n -> b', 'sum')
+ t = torch.where(mask[:, :, None], t, torch.tensor(0., device=t.device))
+ num = t.sum(dim=1)
+ den = mask.float().sum(dim=1)
 
- return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
+ return num / den.clamp(min=1.)
 
 
 # simple utf-8 tokenizer, since paper went character based
@@ -239,7 +237,7 @@ def padded_mel_batch(ref_mels):
 padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
 padded_ref_mels.append(padded_ref_mel)
 padded_ref_mels = torch.stack(padded_ref_mels)
- padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
+ padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
 return padded_ref_mels
 
 
@@ -302,7 +300,7 @@ def get_inference_prompt(
 
 # to mel spectrogram
 ref_mel = mel_spectrogram(ref_audio)
- ref_mel = rearrange(ref_mel, '1 d n -> d n')
+ ref_mel = ref_mel.squeeze(0)
 
 # deal with batch
 assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
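The einx calls in model/utils.py reduce to broadcasted comparisons, sums, and a division; a sketch with assumed lengths. Note that dividing the (b, d) sums by the (b,) counts needs a trailing singleton dim on the divisor for PyTorch broadcasting:

# Not part of the commit: plain-PyTorch equivalents of the removed einx/einops calls.
import torch

lens = torch.tensor([2, 4])
seq = torch.arange(lens.amax())
mask = seq[None, :] < lens[:, None]              # einx.less('n, b -> b n', seq, lens)
assert mask.tolist() == [[True, True, False, False],
                         [True, True, True, True]]

t = torch.randn(2, 4, 3)
t = torch.where(mask[:, :, None], t, torch.tensor(0.))   # zero out padded positions
num = t.sum(dim=1)                               # reduce(t, 'b n d -> b d', 'sum')
den = mask.float().sum(dim=1)                    # reduce(mask.float(), 'b n -> b', 'sum')
mean = num / den.clamp(min=1.)[:, None]          # divisor reshaped to (b, 1) to broadcast
assert mean.shape == (2, 3)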
requirements.txt CHANGED
@@ -3,8 +3,6 @@ bitsandbytes>0.37.0
 cached_path
 click
 datasets
- einops>=0.8.0
- einx>=0.3.0
 ema_pytorch>=0.5.2
 gradio
 jieba
scripts/eval_infer_batch.py CHANGED
@@ -9,7 +9,6 @@ import argparse
 import torch
 import torchaudio
 from accelerate import Accelerator
- from einops import rearrange
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT
@@ -187,7 +186,7 @@ with accelerator.split_between_processes(prompts_all) as prompts:
 # Final result
 for i, gen in enumerate(generated):
     gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
-     gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
+     gen_mel_spec = gen.permute(0, 2, 1)
     generated_wave = vocos.decode(gen_mel_spec.cpu())
     if ref_rms_list[i] < target_rms:
         generated_wave = generated_wave * ref_rms_list[i] / target_rms
speech_edit.py CHANGED
@@ -3,7 +3,6 @@ import os
 import torch
 import torch.nn.functional as F
 import torchaudio
- from einops import rearrange
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT, MMDiT
@@ -174,7 +173,7 @@ print(f"Generated mel: {generated.shape}")
 # Final result
 generated = generated.to(torch.float32)
 generated = generated[:, ref_audio_len:, :]
- generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
+ generated_mel_spec = generated.permute(0, 2, 1)
 generated_wave = vocos.decode(generated_mel_spec.cpu())
 if rms < target_rms:
     generated_wave = generated_wave * rms / target_rms
train.py CHANGED
@@ -56,7 +56,7 @@ def main():
     hop_length = hop_length,
 )
 
- e2tts = CFM(
+ model = CFM(
     transformer = model_cls(
         **model_cfg,
         text_num_embeds = vocab_size,
@@ -67,7 +67,7 @@ def main():
 )
 
 trainer = Trainer(
-     e2tts,
+     model,
     epochs,
     learning_rate,
     num_warmup_updates = num_warmup_updates,