# Test inference script for the end-of-speech ONNX model.
# (Origin: stefanpp, commit d5d7d09 — "added test inference script", 2.65 kB.)
from transformers import Wav2Vec2Processor, AutoConfig
import onnxruntime as rt
import torch
import torch.nn.functional as F
import numpy as np
import os
import torchaudio
import soundfile as sf
class EndOfSpeechDetection:
    """Classifies audio segments with an ONNX-exported Wav2Vec2 model,
    returning per-label probabilities (end-of-speech vs. not)."""

    # Populated by load_model().
    processor: Wav2Vec2Processor
    config: AutoConfig
    session: rt.InferenceSession

    def load_model(self, path, use_gpu=False):
        """Load the processor, config and ONNX inference session from *path*.

        Stores them on the instance AND returns them as a tuple, so existing
        callers that unpack the return value keep working.

        Args:
            path: directory containing the HF processor/config and "model.onnx".
            use_gpu: select the ROCm execution provider instead of CPU.
        Returns:
            (processor, config, session)
        """
        self.processor = Wav2Vec2Processor.from_pretrained(path)
        self.config = AutoConfig.from_pretrained(path)
        sess_options = rt.SessionOptions()
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
        # NOTE(review): GPU path assumes a ROCm build of onnxruntime; a CUDA
        # build would need "CUDAExecutionProvider" instead — confirm target HW.
        providers = ["ROCMExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
        self.session = rt.InferenceSession(
            os.path.join(path, "model.onnx"), sess_options, providers=providers
        )
        return self.processor, self.config, self.session

    def predict(self, segment, file_type="pcm", sampling_rate=16000):
        """Run inference on one audio segment.

        Args:
            segment: path to a raw float32 PCM file (file_type="pcm") or any
                audio file torchaudio can decode (any other file_type).
            file_type: "pcm" for raw float32 samples, anything else for
                decoded audio (e.g. "wav").
            sampling_rate: sample rate the processor should assume
                (default 16 kHz, matching the original hard-coded value).
        Returns:
            dict mapping each config label to its softmax probability.
        """
        if file_type == "pcm":
            # Raw float32 samples; memmap avoids reading the whole file up
            # front, astype() then copies into a regular writable ndarray.
            speech_array = np.memmap(segment, dtype="float32", mode="r").astype(
                np.float32
            )
        else:
            # Decoded audio file; keep only the first channel.
            speech_array, _ = torchaudio.load(segment)
            speech_array = speech_array[0].numpy()
        features = self.processor(
            speech_array, sampling_rate=sampling_rate, return_tensors="pt", padding=True
        )
        input_values = features.input_values
        # Feed the last graph input, read the last graph output (the logits).
        outputs = self.session.run(
            [self.session.get_outputs()[-1].name],
            {self.session.get_inputs()[-1].name: input_values.detach().cpu().numpy()},
        )[0]
        softmax_output = F.softmax(torch.tensor(outputs), dim=1)
        # Map every class id to its probability for the (single) batch item.
        return {
            self.config.id2label[i]: softmax_output[0][i].item()
            for i in range(len(softmax_output[0]))
        }
if __name__ == "__main__":
    # Smoke-test: split a short recording into 700 ms chunks and classify each.
    eos = EndOfSpeechDetection()
    eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")

    audio_file = "5sec_audio.wav"
    audio, sr = torchaudio.load(audio_file)
    audio = audio[0].numpy()  # first channel only

    # 700 ms worth of samples per segment. Slicing past the end of the array
    # simply yields the shorter final remainder, so no bounds check is needed
    # (the original if/else also silently dropped-into-one a segment ending
    # exactly at audio_len).
    segment_len = 700 * sr // 1000
    segments = [
        audio[i : i + segment_len] for i in range(0, len(audio), segment_len)
    ]

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("segments", exist_ok=True)
    for i, segment in enumerate(segments):
        segment_path = f"segments/segment_{i}.wav"
        sf.write(segment_path, segment, sr)
        print(eos.predict(segment_path, file_type="wav"))