zhzluke96 committed on
Commit
f83b1b7
1 Parent(s): 32b2aaa
models/ChatTTS/config/decoder.yaml CHANGED
@@ -1,3 +1,5 @@
 
 
1
  dim: 384
2
 
3
  decoder_config:
 
1
+
2
+
3
  dim: 384
4
 
5
  decoder_config:
models/ChatTTS/config/dvae.yaml CHANGED
@@ -1,3 +1,5 @@
 
 
1
  dim: 512
2
  decoder_config:
3
  idim: ${dim}
@@ -7,6 +9,6 @@ decoder_config:
7
 
8
  vq_config:
9
  dim: 1024
10
- levels: [5, 5, 5, 5]
11
  G: 2
12
  R: 2
 
1
+
2
+
3
  dim: 512
4
  decoder_config:
5
  idim: ${dim}
 
9
 
10
  vq_config:
11
  dim: 1024
12
+ levels: [5,5,5,5]
13
  G: 2
14
  R: 2
models/ChatTTS/config/gpt.yaml CHANGED
@@ -1,3 +1,5 @@
 
 
1
  num_audio_tokens: 626
2
  num_text_tokens: 21178
3
 
@@ -15,3 +17,4 @@ gpt_config:
15
  num_audio_tokens: 626
16
  num_text_tokens: null
17
  num_vq: 4
 
 
1
+
2
+
3
  num_audio_tokens: 626
4
  num_text_tokens: 21178
5
 
 
17
  num_audio_tokens: 626
18
  num_text_tokens: null
19
  num_vq: 4
20
+
models/ChatTTS/config/path.yaml CHANGED
@@ -1,3 +1,5 @@
 
 
1
  vocos_config_path: config/vocos.yaml
2
  vocos_ckpt_path: asset/Vocos.pt
3
  dvae_config_path: config/dvae.yaml
 
1
+
2
+
3
  vocos_config_path: config/vocos.yaml
4
  vocos_ckpt_path: asset/Vocos.pt
5
  dvae_config_path: config/dvae.yaml
models/ChatTTS/config/vocos.yaml CHANGED
@@ -21,4 +21,4 @@ head:
21
  dim: 512
22
  n_fft: 1024
23
  hop_length: 256
24
- padding: center
 
21
  dim: 512
22
  n_fft: 1024
23
  hop_length: 256
24
+ padding: center
models/Denoise/.gitkeep ADDED
File without changes
models/Denoise/audio-denoiser-512-32-v1/config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"scaler": {"mean": -3.027921438217163, "std": 1.9317387342453003}, "in_channels": 257, "n_fft": 512, "num_frames": 32, "exp_id": "115233"}
models/Denoise/audio-denoiser-512-32-v1/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5079784e228d2b36496f2c72f8d06015c8fb1827a81f757ec8540ca708ada7a9
3
+ size 153639572
models/put_model_here ADDED
File without changes
models/resemble-enhance/hparams.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fg_dir: !!python/object/apply:pathlib.PosixPath
2
+ - data
3
+ - fg
4
+ bg_dir: !!python/object/apply:pathlib.PosixPath
5
+ - data
6
+ - bg
7
+ rir_dir: !!python/object/apply:pathlib.PosixPath
8
+ - data
9
+ - rir
10
+ load_fg_only: false
11
+ wav_rate: 44100
12
+ n_fft: 2048
13
+ win_size: 2048
14
+ hop_size: 420
15
+ num_mels: 128
16
+ stft_magnitude_min: 0.0001
17
+ preemphasis: 0.97
18
+ mix_alpha_range:
19
+ - 0.2
20
+ - 0.8
21
+ nj: 64
22
+ training_seconds: 3.0
23
+ batch_size_per_gpu: 32
24
+ min_lr: 1.0e-05
25
+ max_lr: 0.0001
26
+ warmup_steps: 1000
27
+ max_steps: 1000000
28
+ gradient_clipping: 1.0
29
+ cfm_solver_method: midpoint
30
+ cfm_solver_nfe: 64
31
+ cfm_time_mapping_divisor: 4
32
+ univnet_nc: 96
33
+ lcfm_latent_dim: 64
34
+ lcfm_training_mode: cfm
35
+ lcfm_z_scale: 6
36
+ vocoder_extra_dim: 32
37
+ gan_training_start_step: null
38
+ praat_augment_prob: 0.2
models/resemble-enhance/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9d035f318de3e6d919bc70cf7ad7d32b4fe92ec5cbe0b30029a27f5db07d9d6
3
+ size 713176232
modules/Enhancer/ResembleEnhance.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
- from typing import List
 
3
  from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
4
  from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
5
  from modules.repos_static.resemble_enhance.inference import inference
@@ -25,14 +26,11 @@ def load_enhancer(device: torch.device):
25
 
26
 
27
  class ResembleEnhance:
28
- hparams: HParams
29
- enhancer: Enhancer
30
-
31
  def __init__(self, device: torch.device):
32
  self.device = device
33
 
34
- self.enhancer = None
35
- self.hparams = None
36
 
37
  def load_model(self):
38
  hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
@@ -42,9 +40,7 @@ class ResembleEnhance:
42
  map_location="cpu",
43
  )["module"]
44
  enhancer.load_state_dict(state_dict)
45
- enhancer.eval()
46
- enhancer.to(self.device)
47
- enhancer.denoiser.to(self.device)
48
 
49
  self.hparams = hparams
50
  self.enhancer = enhancer
@@ -63,7 +59,7 @@ class ResembleEnhance:
63
  sr,
64
  device,
65
  nfe=32,
66
- solver="midpoint",
67
  lambd=0.5,
68
  tau=0.5,
69
  ) -> tuple[torch.Tensor, int]:
@@ -83,34 +79,51 @@ class ResembleEnhance:
83
 
84
  if __name__ == "__main__":
85
  import torchaudio
86
- from modules.models import load_chat_tts
87
-
88
- load_chat_tts()
89
 
90
  device = torch.device("cuda")
91
- ench = ResembleEnhance(device)
92
- ench.load_model()
93
 
94
- wav, sr = torchaudio.load("test.wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- print(wav.shape, type(wav), sr, type(sr))
97
- exit()
98
 
99
- wav = wav.squeeze(0).cuda()
100
 
101
- print(wav.device)
102
 
103
- denoised, d_sr = ench.denoise(wav.cpu(), sr, device)
104
- denoised = denoised.unsqueeze(0)
105
- print(denoised.shape)
106
- torchaudio.save("denoised.wav", denoised, d_sr)
107
 
108
- for solver in ("midpoint", "rk4", "euler"):
109
- for lambd in (0.1, 0.5, 0.9):
110
- for tau in (0.1, 0.5, 0.9):
111
- enhanced, e_sr = ench.enhance(
112
- wav.cpu(), sr, device, solver=solver, lambd=lambd, tau=tau, nfe=128
113
- )
114
- enhanced = enhanced.unsqueeze(0)
115
- print(enhanced.shape)
116
- torchaudio.save(f"enhanced_{solver}_{lambd}_{tau}.wav", enhanced, e_sr)
 
 
 
1
  import os
2
+ from typing import List, Literal
3
+ from modules.devices import devices
4
  from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
5
  from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
6
  from modules.repos_static.resemble_enhance.inference import inference
 
26
 
27
 
28
  class ResembleEnhance:
 
 
 
29
  def __init__(self, device: torch.device):
30
  self.device = device
31
 
32
+ self.enhancer: Enhancer = None
33
+ self.hparams: HParams = None
34
 
35
  def load_model(self):
36
  hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
 
40
  map_location="cpu",
41
  )["module"]
42
  enhancer.load_state_dict(state_dict)
43
+ enhancer.to(self.device).eval()
 
 
44
 
45
  self.hparams = hparams
46
  self.enhancer = enhancer
 
59
  sr,
60
  device,
61
  nfe=32,
62
+ solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
63
  lambd=0.5,
64
  tau=0.5,
65
  ) -> tuple[torch.Tensor, int]:
 
79
 
80
  if __name__ == "__main__":
81
  import torchaudio
82
+ import gradio as gr
 
 
83
 
84
  device = torch.device("cuda")
 
 
85
 
86
+ # def enhance(file):
87
+ # print(file)
88
+ # ench = load_enhancer(device)
89
+ # dwav, sr = torchaudio.load(file)
90
+ # dwav = dwav.mean(dim=0).to(device)
91
+ # enhanced, e_sr = ench.enhance(dwav, sr)
92
+ # return e_sr, enhanced.cpu().numpy()
93
+
94
+ # # just an arbitrary example
95
+ # gr.Interface(
96
+ # fn=enhance, inputs=[gr.Audio(type="filepath")], outputs=[gr.Audio()]
97
+ # ).launch()
98
+
99
+ # load_chat_tts()
100
+
101
+ # ench = load_enhancer(device)
102
+
103
+ # devices.torch_gc()
104
+
105
+ # wav, sr = torchaudio.load("test.wav")
106
 
107
+ # print(wav.shape, type(wav), sr, type(sr))
108
+ # # exit()
109
 
110
+ # wav = wav.squeeze(0).cuda()
111
 
112
+ # print(wav.device)
113
 
114
+ # denoised, d_sr = ench.denoise(wav, sr)
115
+ # denoised = denoised.unsqueeze(0)
116
+ # print(denoised.shape)
117
+ # torchaudio.save("denoised.wav", denoised.cpu(), d_sr)
118
 
119
+ # for solver in ("midpoint", "rk4", "euler"):
120
+ # for lambd in (0.1, 0.5, 0.9):
121
+ # for tau in (0.1, 0.5, 0.9):
122
+ # enhanced, e_sr = ench.enhance(
123
+ # wav, sr, solver=solver, lambd=lambd, tau=tau, nfe=128
124
+ # )
125
+ # enhanced = enhanced.unsqueeze(0)
126
+ # print(enhanced.shape)
127
+ # torchaudio.save(
128
+ # f"enhanced_{solver}_{lambd}_{tau}.wav", enhanced.cpu(), e_sr
129
+ # )
modules/generate_audio.py CHANGED
@@ -72,7 +72,7 @@ def generate_audio_batch(
72
  }
73
 
74
  if isinstance(spk, int):
75
- with SeedContext(spk):
76
  params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
77
  logger.info(("spk", spk))
78
  elif isinstance(spk, Speaker):
@@ -94,7 +94,7 @@ def generate_audio_batch(
94
  }
95
  )
96
 
97
- with SeedContext(infer_seed):
98
  wavs = chat_tts.generate_audio(
99
  texts, params_infer_code, use_decoder=use_decoder
100
  )
 
72
  }
73
 
74
  if isinstance(spk, int):
75
+ with SeedContext(spk, True):
76
  params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
77
  logger.info(("spk", spk))
78
  elif isinstance(spk, Speaker):
 
94
  }
95
  )
96
 
97
+ with SeedContext(infer_seed, True):
98
  wavs = chat_tts.generate_audio(
99
  texts, params_infer_code, use_decoder=use_decoder
100
  )
modules/repos_static/resemble_enhance/enhancer/enhancer.py CHANGED
@@ -73,8 +73,8 @@ class Enhancer(nn.Module):
73
  )
74
  self._load_pretrained(pretrained_path)
75
 
76
- logger.info(f"{self.__class__.__name__} summary")
77
- logger.info(f"{self.summarize()}")
78
 
79
  def _load_pretrained(self, path):
80
  # Clone is necessary as otherwise it holds a reference to the original model
 
73
  )
74
  self._load_pretrained(pretrained_path)
75
 
76
+ # logger.info(f"{self.__class__.__name__} summary")
77
+ # logger.info(f"{self.summarize()}")
78
 
79
  def _load_pretrained(self, path):
80
  # Clone is necessary as otherwise it holds a reference to the original model
modules/speaker.py CHANGED
@@ -11,7 +11,7 @@ import uuid
11
 
12
  def create_speaker_from_seed(seed):
13
  chat_tts = models.load_chat_tts()
14
- with SeedContext(seed):
15
  emb = chat_tts.sample_random_speaker()
16
  return emb
17
 
 
11
 
12
  def create_speaker_from_seed(seed):
13
  chat_tts = models.load_chat_tts()
14
+ with SeedContext(seed, True):
15
  emb = chat_tts.sample_random_speaker()
16
  return emb
17
 
modules/utils/SeedContext.py CHANGED
@@ -7,15 +7,17 @@ import logging
7
  logger = logging.getLogger(__name__)
8
 
9
 
10
- def deterministic(seed=0):
11
  random.seed(seed)
12
  np.random.seed(seed)
13
  torch_rn = rng.convert_np_to_torch(seed)
14
  torch.manual_seed(torch_rn)
15
  if torch.cuda.is_available():
16
  torch.cuda.manual_seed_all(torch_rn)
17
- torch.backends.cudnn.deterministic = True
18
- torch.backends.cudnn.benchmark = False
 
 
19
 
20
 
21
  def is_numeric(obj):
@@ -36,7 +38,7 @@ def is_numeric(obj):
36
 
37
 
38
  class SeedContext:
39
- def __init__(self, seed):
40
  assert is_numeric(seed), "Seed must be an number."
41
 
42
  try:
@@ -45,6 +47,7 @@ class SeedContext:
45
  raise ValueError(f"Seed must be an integer, but: {type(seed)}")
46
 
47
  self.seed = seed
 
48
  self.state = None
49
 
50
  if isinstance(seed, str) and seed.isdigit():
@@ -57,10 +60,16 @@ class SeedContext:
57
  self.seed = random.randint(0, 2**32 - 1)
58
 
59
  def __enter__(self):
60
- self.state = (torch.get_rng_state(), random.getstate(), np.random.get_state())
 
 
 
 
 
 
61
 
62
  try:
63
- deterministic(self.seed)
64
  except Exception as e:
65
  # raise ValueError(
66
  # f"Seed must be an integer, but: <{type(self.seed)}> {self.seed}"
@@ -73,6 +82,8 @@ class SeedContext:
73
  torch.set_rng_state(self.state[0])
74
  random.setstate(self.state[1])
75
  np.random.set_state(self.state[2])
 
 
76
 
77
 
78
  if __name__ == "__main__":
 
7
  logger = logging.getLogger(__name__)
8
 
9
 
10
+ def deterministic(seed=0, cudnn_deterministic=False):
11
  random.seed(seed)
12
  np.random.seed(seed)
13
  torch_rn = rng.convert_np_to_torch(seed)
14
  torch.manual_seed(torch_rn)
15
  if torch.cuda.is_available():
16
  torch.cuda.manual_seed_all(torch_rn)
17
+
18
+ if cudnn_deterministic:
19
+ torch.backends.cudnn.deterministic = True
20
+ torch.backends.cudnn.benchmark = False
21
 
22
 
23
  def is_numeric(obj):
 
38
 
39
 
40
  class SeedContext:
41
+ def __init__(self, seed, cudnn_deterministic=False):
42
  assert is_numeric(seed), "Seed must be an number."
43
 
44
  try:
 
47
  raise ValueError(f"Seed must be an integer, but: {type(seed)}")
48
 
49
  self.seed = seed
50
+ self.cudnn_deterministic = cudnn_deterministic
51
  self.state = None
52
 
53
  if isinstance(seed, str) and seed.isdigit():
 
60
  self.seed = random.randint(0, 2**32 - 1)
61
 
62
  def __enter__(self):
63
+ self.state = (
64
+ torch.get_rng_state(),
65
+ random.getstate(),
66
+ np.random.get_state(),
67
+ torch.backends.cudnn.deterministic,
68
+ torch.backends.cudnn.benchmark,
69
+ )
70
 
71
  try:
72
+ deterministic(self.seed, cudnn_deterministic=self.cudnn_deterministic)
73
  except Exception as e:
74
  # raise ValueError(
75
  # f"Seed must be an integer, but: <{type(self.seed)}> {self.seed}"
 
82
  torch.set_rng_state(self.state[0])
83
  random.setstate(self.state[1])
84
  np.random.set_state(self.state[2])
85
+ torch.backends.cudnn.deterministic = self.state[3]
86
+ torch.backends.cudnn.benchmark = self.state[4]
87
 
88
 
89
  if __name__ == "__main__":
modules/webui/app.py CHANGED
@@ -16,11 +16,6 @@ from modules.webui.readme_tab import create_readme_tab
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
- logging.basicConfig(
20
- level=os.getenv("LOG_LEVEL", "INFO"),
21
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
22
- )
23
-
24
 
25
  def webui_init():
26
  # fix: If the system proxy is enabled in the Windows system, you need to skip these
 
16
 
17
  logger = logging.getLogger(__name__)
18
 
 
 
 
 
 
19
 
20
  def webui_init():
21
  # fix: If the system proxy is enabled in the Windows system, you need to skip these
modules/webui/speaker/speaker_creator.py CHANGED
@@ -61,7 +61,7 @@ def create_spk_from_seed(
61
  desc: str,
62
  ):
63
  chat_tts = load_chat_tts()
64
- with SeedContext(seed):
65
  emb = chat_tts.sample_random_speaker()
66
  spk = Speaker(seed=-2, name=name, gender=gender, describe=desc)
67
  spk.emb = emb
@@ -118,7 +118,7 @@ def speaker_creator_ui():
118
  with gr.Row():
119
  current_seed = gr.Label(label="Current Seed", value=-1)
120
  with gr.Column(scale=4):
121
- output_audio = gr.Audio(label="Output Audio")
122
 
123
  test_voice_btn.click(
124
  fn=test_spk_voice,
 
61
  desc: str,
62
  ):
63
  chat_tts = load_chat_tts()
64
+ with SeedContext(seed, True):
65
  emb = chat_tts.sample_random_speaker()
66
  spk = Speaker(seed=-2, name=name, gender=gender, describe=desc)
67
  spk.emb = emb
 
118
  with gr.Row():
119
  current_seed = gr.Label(label="Current Seed", value=-1)
120
  with gr.Column(scale=4):
121
+ output_audio = gr.Audio(label="Output Audio", format="mp3")
122
 
123
  test_voice_btn.click(
124
  fn=test_spk_voice,
modules/webui/speaker/speaker_merger.py CHANGED
@@ -204,7 +204,9 @@ def create_speaker_merger():
204
  value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
205
  )
206
 
207
- output_audio = gr.Audio(label="Output Audio")
 
 
208
 
209
  with gr.Column(scale=1):
210
  with gr.Group():
 
204
  value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
205
  )
206
 
207
+ output_audio = gr.Audio(
208
+ label="Output Audio", format="mp3"
209
+ )
210
 
211
  with gr.Column(scale=1):
212
  with gr.Group():
modules/webui/ssml_tab.py CHANGED
@@ -44,7 +44,7 @@ def create_ssml_interface():
44
  inputs=[ssml_input],
45
  )
46
 
47
- ssml_output = gr.Audio(label="Generated Audio")
48
 
49
  ssml_button.click(
50
  synthesize_ssml,
 
44
  inputs=[ssml_input],
45
  )
46
 
47
+ ssml_output = gr.Audio(label="Generated Audio", format="mp3")
48
 
49
  ssml_button.click(
50
  synthesize_ssml,
modules/webui/tts_tab.py CHANGED
@@ -204,7 +204,7 @@ def create_tts_interface():
204
 
205
  with gr.Group():
206
  gr.Markdown("🎨Output")
207
- tts_output = gr.Audio(label="Generated Audio")
208
  with gr.Column(scale=1):
209
  with gr.Group():
210
  gr.Markdown("🎶Refiner")
@@ -220,10 +220,9 @@ def create_tts_interface():
220
  value=False, label="Disable Normalize"
221
  )
222
 
223
- # FIXME: for some unknown reason this is very slow here, even though calling the script on its own is fast
224
- with gr.Group(visible=webui_config.experimental):
225
  gr.Markdown("💪🏼Enhance")
226
- enable_enhance = gr.Checkbox(value=False, label="Enable Enhance")
227
  enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
228
  tts_button = gr.Button(
229
  "🔊Generate Audio",
 
204
 
205
  with gr.Group():
206
  gr.Markdown("🎨Output")
207
+ tts_output = gr.Audio(label="Generated Audio", format="mp3")
208
  with gr.Column(scale=1):
209
  with gr.Group():
210
  gr.Markdown("🎶Refiner")
 
220
  value=False, label="Disable Normalize"
221
  )
222
 
223
+ with gr.Group():
 
224
  gr.Markdown("💪🏼Enhance")
225
+ enable_enhance = gr.Checkbox(value=True, label="Enable Enhance")
226
  enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
227
  tts_button = gr.Button(
228
  "🔊Generate Audio",
modules/webui/webui_utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Union
2
  import numpy as np
3
 
@@ -23,6 +24,9 @@ from modules import refiner
23
  from modules.utils import audio
24
  from modules.SentenceSplitter import SentenceSplitter
25
 
 
 
 
26
 
27
  def get_speakers():
28
  return speaker_mgr.list_speakers()
@@ -67,22 +71,23 @@ def segments_length_limit(
67
  @torch.inference_mode()
68
  @spaces.GPU
69
  def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
70
- audio_data = torch.from_numpy(audio_data).float().squeeze().cpu()
71
- if enable_denoise or enable_enhance:
72
- enhancer = load_enhancer(devices.device)
73
- if enable_denoise:
74
- audio_data, sr = enhancer.denoise(audio_data, sr, devices.device)
75
- if enable_enhance:
76
- audio_data, sr = enhancer.enhance(
77
- audio_data,
78
- sr,
79
- devices.device,
80
- tau=0.9,
81
- nfe=64,
82
- solver="euler",
83
- lambd=0.5,
84
- )
85
- audio_data = audio_data.cpu().numpy()
 
86
  return audio_data, int(sr)
87
 
88
 
@@ -111,10 +116,12 @@ def synthesize_ssml(ssml: str, batch_size=4):
111
  audio_segments = synthesize.synthesize_segments(segments)
112
  combined_audio = combine_audio_segments(audio_segments)
113
 
114
- return audio.pydub_to_np(combined_audio)
115
 
 
116
 
117
- @torch.inference_mode()
 
118
  @spaces.GPU
119
  def tts_generate(
120
  text,
@@ -186,7 +193,6 @@ def tts_generate(
186
  audio_data, sample_rate = apply_audio_enhance(
187
  audio_data, sample_rate, enable_denoise, enable_enhance
188
  )
189
-
190
  audio_data = audio.audio_to_int16(audio_data)
191
  return sample_rate, audio_data
192
 
 
1
+ import io
2
  from typing import Union
3
  import numpy as np
4
 
 
24
  from modules.utils import audio
25
  from modules.SentenceSplitter import SentenceSplitter
26
 
27
+ from pydub import AudioSegment
28
+ import torch.profiler
29
+
30
 
31
  def get_speakers():
32
  return speaker_mgr.list_speakers()
 
71
  @torch.inference_mode()
72
  @spaces.GPU
73
  def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
74
+ if not enable_denoise and not enable_enhance:
75
+ return audio_data, sr
76
+
77
+ device = devices.device
78
+ # NOTE: oddly, the tensor should in principle be moved to `device`, but the enhancer raises an error when chunking it there, so it has to stay on the CPU via cpu()
79
+ tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
80
+ enhancer = load_enhancer(device)
81
+
82
+ if enable_enhance:
83
+ lambd = 0.9 if enable_denoise else 0.1
84
+ tensor, sr = enhancer.enhance(
85
+ tensor, sr, tau=0.5, nfe=64, solver="rk4", lambd=lambd, device=device
86
+ )
87
+ elif enable_denoise:
88
+ tensor, sr = enhancer.denoise(tensor, sr)
89
+
90
+ audio_data = tensor.cpu().numpy()
91
  return audio_data, int(sr)
92
 
93
 
 
116
  audio_segments = synthesize.synthesize_segments(segments)
117
  combined_audio = combine_audio_segments(audio_segments)
118
 
119
+ sr, audio_data = audio.pydub_to_np(combined_audio)
120
 
121
+ return sr, audio_data
122
 
123
+
124
+ # @torch.inference_mode()
125
  @spaces.GPU
126
  def tts_generate(
127
  text,
 
193
  audio_data, sample_rate = apply_audio_enhance(
194
  audio_data, sample_rate, enable_denoise, enable_enhance
195
  )
 
196
  audio_data = audio.audio_to_int16(audio_data)
197
  return sample_rate, audio_data
198
 
webui.py CHANGED
@@ -1,4 +1,11 @@
1
  import os
 
 
 
 
 
 
 
2
  from modules.devices import devices
3
  from modules.utils import env
4
  from modules.webui import webui_config
 
1
  import os
2
+ import logging
3
+
4
+ # logging.basicConfig(
5
+ # level=os.getenv("LOG_LEVEL", "INFO"),
6
+ # format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
7
+ # )
8
+
9
  from modules.devices import devices
10
  from modules.utils import env
11
  from modules.webui import webui_config