Spaces:
Runtime error
Runtime error
wenmengzhou
committed on
Commit
•
74ffd9d
1
Parent(s):
7ea67e7
initialize on cpu for frontend model
Browse files
cosyvoice/cli/frontend.py
CHANGED
@@ -37,7 +37,8 @@ class CosyVoiceFrontEnd:
|
|
37 |
allowed_special: str = 'all'):
|
38 |
self.tokenizer = get_tokenizer()
|
39 |
self.feat_extractor = feat_extractor
|
40 |
-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
41 |
option = onnxruntime.SessionOptions()
|
42 |
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
43 |
option.intra_op_num_threads = 1
|
@@ -109,12 +110,14 @@ class CosyVoiceFrontEnd:
|
|
109 |
return texts
|
110 |
|
111 |
def frontend_sft(self, tts_text, spk_id):
|
|
|
112 |
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
113 |
embedding = self.spk2info[spk_id]['embedding']
|
114 |
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
115 |
return model_input
|
116 |
|
117 |
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
|
|
|
118 |
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
119 |
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
120 |
prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
|
|
|
37 |
allowed_special: str = 'all'):
|
38 |
self.tokenizer = get_tokenizer()
|
39 |
self.feat_extractor = feat_extractor
|
40 |
+
#self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
41 |
+
self.device = 'cpu'
|
42 |
option = onnxruntime.SessionOptions()
|
43 |
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
44 |
option.intra_op_num_threads = 1
|
|
|
110 |
return texts
|
111 |
|
112 |
def frontend_sft(self, tts_text, spk_id):
|
113 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
114 |
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
115 |
embedding = self.spk2info[spk_id]['embedding']
|
116 |
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
117 |
return model_input
|
118 |
|
119 |
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
|
120 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
121 |
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
122 |
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
123 |
prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
|