wenmengzhou commited on
Commit
74ffd9d
1 Parent(s): 7ea67e7

initialize on cpu for frontend model

Browse files
Files changed (1) hide show
  1. cosyvoice/cli/frontend.py +4 -1
cosyvoice/cli/frontend.py CHANGED
@@ -37,7 +37,8 @@ class CosyVoiceFrontEnd:
37
  allowed_special: str = 'all'):
38
  self.tokenizer = get_tokenizer()
39
  self.feat_extractor = feat_extractor
40
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
41
  option = onnxruntime.SessionOptions()
42
  option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
43
  option.intra_op_num_threads = 1
@@ -109,12 +110,14 @@ class CosyVoiceFrontEnd:
109
  return texts
110
 
111
  def frontend_sft(self, tts_text, spk_id):
 
112
  tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
113
  embedding = self.spk2info[spk_id]['embedding']
114
  model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
115
  return model_input
116
 
117
  def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
 
118
  tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
119
  prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
120
  prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
 
37
  allowed_special: str = 'all'):
38
  self.tokenizer = get_tokenizer()
39
  self.feat_extractor = feat_extractor
40
+ #self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
+ self.device = 'cpu'
42
  option = onnxruntime.SessionOptions()
43
  option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
44
  option.intra_op_num_threads = 1
 
110
  return texts
111
 
112
  def frontend_sft(self, tts_text, spk_id):
113
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
114
  tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
115
  embedding = self.spk2info[spk_id]['embedding']
116
  model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
117
  return model_input
118
 
119
  def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
120
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
121
  tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
122
  prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
123
  prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)