153 lines
4.8 KiB
Python
153 lines
4.8 KiB
Python
import torch
|
|
import pyaudio
|
|
import numpy as np
|
|
import time
|
|
import onnxruntime as ort
|
|
import threading
|
|
|
|
# Raise onnxruntime's default logging threshold to severity 3 (ERROR on
# onnxruntime's 0-4 severity scale) so that its info/warning chatter is hidden.
ort.set_default_logger_severity(3)

# Sample rate (Hz) shared by the microphone stream and the Silero VAD iterator.
SAMPLERATE = 16000
class VADRecorder:
    """Record mono microphone audio in a background thread and cut it into
    speech segments with the Silero VAD model.

    Each detected utterance is appended to ``self.audios_for_whisper`` as a
    list of float samples captured at ``SAMPLERATE`` Hz.
    """

    def __init__(self, use_onnx=True):
        """Load the Silero VAD model via ``torch.hub`` and keep its ``VADIterator`` class.

        Args:
            use_onnx: Forwarded to ``torch.hub.load`` as ``onnx``; presumably
                selects the ONNX build of the model (TODO confirm per silero-vad docs).
        """
        print("Loading Silero VAD model... ", end="")

        self.vad_model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=False,
            onnx=use_onnx,
        )

        # utils is a 5-tuple of helpers; only VADIterator is used here.
        (
            _,  # get_speech_timestamps
            _,  # save_audio
            _,  # read_audio
            self.VADIterator,
            _,  # collect_chunks
        ) = utils

        print("Done!")

        # Created by the recorder thread at the start of each session.
        self.vad_iterator = None

    def _new_vad_iterator(self):
        """Build a fresh VADIterator from the settings stored by start_vad_recorder."""
        return self.VADIterator(
            self.vad_model,
            threshold=self.vad_threshold,
            sampling_rate=SAMPLERATE,
            min_silence_duration_ms=self.min_silence_duration_ms,
            speech_pad_ms=self.speech_pad_ms,
        )

    def _vad_recorder(self):
        """Recorder-thread body: read chunks and collect speech until ``rec_flag`` is cleared.

        Speech chunks are accumulated, prefixed with the chunk read just before
        speech started so the onset is not clipped. When the VAD leaves the
        speech state, the accumulated samples are appended to
        ``self.audios_for_whisper``.
        """
        print("Listening...")

        speech_win = 0  # number of chunks in the current speech segment
        detected_audio = []
        # Previous chunk, used as pre-roll when speech starts.
        last_chunk = np.zeros(self.window_size, dtype=np.float32)

        # The VAD iterator is periodically recreated because, after running for
        # a while, it freaks out and hallucinates speech.
        vad_iter_reload_delay = 60 * 2  # seconds
        # monotonic() is immune to wall-clock adjustments, unlike time.time().
        vad_iter_load_time = time.monotonic()

        self.vad_iterator = self._new_vad_iterator()

        while self.rec_flag:
            # exception_on_overflow=False: with the default (True) an input
            # overflow raises here and silently kills this daemon thread,
            # leaving rec_flag set but nothing being recorded.
            raw = self.stream_in.read(self.window_size, exception_on_overflow=False)
            chunk = np.frombuffer(raw, dtype=np.float32)

            speech_dict = self.vad_iterator(chunk)

            # speech_dict is None, {"start": x} or {"end": x}.
            if speech_dict is not None:
                self.speech = "start" in speech_dict

            if self.speech:
                if speech_win == 0:
                    detected_audio = last_chunk.tolist()
                speech_win += 1
                detected_audio += chunk.tolist()
            else:
                # Only reload while idle so an ongoing utterance is never split.
                if time.monotonic() - vad_iter_load_time > vad_iter_reload_delay:
                    self.vad_iterator.reset_states()
                    vad_iter_load_time = time.monotonic()
                    self.vad_iterator = self._new_vad_iterator()
                    print("Reloaded VADIterator!")

                if speech_win > 0:
                    speech_win = 0
                    self.audios_for_whisper.append(detected_audio)

            last_chunk = chunk.copy()

    def _find_input_device(self, target_device_name):
        """Return the index of the first input-capable device whose name contains target_device_name.

        Raises:
            ValueError: if no input device name contains ``target_device_name``.
        """
        for index in range(self.p.get_device_count()):
            info = self.p.get_device_info_by_index(index)
            if info['maxInputChannels'] > 0 and target_device_name in info['name']:
                return index
        raise ValueError(f"No target device found with \"{target_device_name}\" in its name.")

    def start_vad_recorder(self, window_size_sec=0.1, vad_threshold=0.6,
                           min_silence_duration_ms=150, speech_pad_ms=0,
                           target_device_name=None):
        """Open the microphone stream and start the background recorder thread.

        Args:
            window_size_sec: Duration of each chunk fed to the VAD, in seconds
                (converted to ``int(window_size_sec * SAMPLERATE)`` samples).
            vad_threshold: Passed to VADIterator as ``threshold``.
            min_silence_duration_ms: Passed to VADIterator unchanged.
            speech_pad_ms: Passed to VADIterator unchanged.
            target_device_name: If given, record from the first input device
                whose name contains this substring; otherwise PyAudio's default
                input device is used.

        Raises:
            ValueError: if ``target_device_name`` matches no input device.
            OSError: if PyAudio cannot open the input stream.

        NOTE(review): newer Silero VAD releases appear to accept only 512-sample
        chunks at 16 kHz; the default 0.1 s window (1600 samples) may be
        rejected there. Confirm against the model version torch.hub loads.
        """
        self.window_size = int(window_size_sec * SAMPLERATE)

        self.vad_threshold = vad_threshold
        self.min_silence_duration_ms = min_silence_duration_ms
        self.speech_pad_ms = speech_pad_ms

        self.p = pyaudio.PyAudio()
        try:
            device_index = None  # None -> PyAudio's default input device
            if target_device_name is not None:
                device_index = self._find_input_device(target_device_name)
            self.stream_in = self.p.open(
                format=pyaudio.paFloat32,
                channels=1,
                rate=SAMPLERATE,
                input=True,
                frames_per_buffer=self.window_size,
                input_device_index=device_index,
            )
        except Exception:
            # Don't leak the PyAudio session when no stream could be opened.
            self.p.terminate()
            raise

        self.speech = False
        self.audios_for_whisper = []

        if self.vad_iterator is not None:
            self.vad_iterator.reset_states()

        self.rec_flag = True
        self.vad_rec_thread = threading.Thread(target=self._vad_recorder, daemon=True)
        self.vad_rec_thread.start()

    def stop_vad_recorder(self):
        """Stop the recorder thread, wait for it to exit, then release the audio stream and PyAudio."""
        self.rec_flag = False
        # The thread notices the flag after its current read returns.
        self.vad_rec_thread.join()

        self.stream_in.stop_stream()
        self.stream_in.close()
        self.p.terminate()