mrfakename commited on
Commit
fe296ca
1 Parent(s): 9c54d62

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. model/dataset.py +17 -3
model/dataset.py CHANGED
@@ -8,8 +8,10 @@ from torch.utils.data import Dataset, Sampler
8
  import torchaudio
9
  from datasets import load_from_disk
10
  from datasets import Dataset as Dataset_
 
11
 
12
  from model.modules import MelSpec
 
13
 
14
 
15
  class HFDataset(Dataset):
@@ -77,15 +79,22 @@ class CustomDataset(Dataset):
77
  hop_length=256,
78
  n_mel_channels=100,
79
  preprocessed_mel=False,
 
80
  ):
81
  self.data = custom_dataset
82
  self.durations = durations
83
  self.target_sample_rate = target_sample_rate
84
  self.hop_length = hop_length
85
  self.preprocessed_mel = preprocessed_mel
 
86
  if not preprocessed_mel:
87
- self.mel_spectrogram = MelSpec(
88
- target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
 
 
 
 
 
89
  )
90
 
91
  def get_frame_len(self, index):
@@ -201,6 +210,7 @@ def load_dataset(
201
  tokenizer: str = "pinyin",
202
  dataset_type: str = "CustomDataset",
203
  audio_type: str = "raw",
 
204
  mel_spec_kwargs: dict = dict(),
205
  ) -> CustomDataset | HFDataset:
206
  """
@@ -224,7 +234,11 @@ def load_dataset(
224
  data_dict = json.load(f)
225
  durations = data_dict["duration"]
226
  train_dataset = CustomDataset(
227
- train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
 
 
 
 
228
  )
229
 
230
  elif dataset_type == "CustomDatasetPath":
 
8
  import torchaudio
9
  from datasets import load_from_disk
10
  from datasets import Dataset as Dataset_
11
+ from torch import nn
12
 
13
  from model.modules import MelSpec
14
+ from model.utils import default
15
 
16
 
17
  class HFDataset(Dataset):
 
79
  hop_length=256,
80
  n_mel_channels=100,
81
  preprocessed_mel=False,
82
+ mel_spec_module: nn.Module | None = None,
83
  ):
84
  self.data = custom_dataset
85
  self.durations = durations
86
  self.target_sample_rate = target_sample_rate
87
  self.hop_length = hop_length
88
  self.preprocessed_mel = preprocessed_mel
89
+
90
  if not preprocessed_mel:
91
+ self.mel_spectrogram = default(
92
+ mel_spec_module,
93
+ MelSpec(
94
+ target_sample_rate=target_sample_rate,
95
+ hop_length=hop_length,
96
+ n_mel_channels=n_mel_channels,
97
+ ),
98
  )
99
 
100
  def get_frame_len(self, index):
 
210
  tokenizer: str = "pinyin",
211
  dataset_type: str = "CustomDataset",
212
  audio_type: str = "raw",
213
+ mel_spec_module: nn.Module | None = None,
214
  mel_spec_kwargs: dict = dict(),
215
  ) -> CustomDataset | HFDataset:
216
  """
 
234
  data_dict = json.load(f)
235
  durations = data_dict["duration"]
236
  train_dataset = CustomDataset(
237
+ train_dataset,
238
+ durations=durations,
239
+ preprocessed_mel=preprocessed_mel,
240
+ mel_spec_module=mel_spec_module,
241
+ **mel_spec_kwargs,
242
  )
243
 
244
  elif dataset_type == "CustomDatasetPath":