PizzAI/frontend/vad_recorder.py
2023-12-04 23:36:07 +01:00

153 lines
4.8 KiB
Python

import torch
import pyaudio
import numpy as np
import time
import onnxruntime as ort
import threading
# Raise onnxruntime's default logger severity threshold to 3 so that
# lower-severity messages are not printed (presumably 3 corresponds to
# ERROR on onnxruntime's scale -- confirm against the onnxruntime docs).
ort.set_default_logger_severity(3)
# Sample rate in Hz; used both for the PyAudio input stream and as the
# sampling_rate passed to the Silero VADIterator.
SAMPLERATE = 16000
class VADRecorder:
    """Record microphone audio on a background thread and cut it into speech segments.

    A Silero VAD model is loaded once in the constructor. After
    ``start_vad_recorder()`` a daemon thread reads fixed-size float32 chunks
    from a PyAudio input stream, feeds them to a ``VADIterator`` and appends
    every finished speech segment to ``audios_for_whisper`` as a flat list of
    float samples. Each segment is prefixed with the chunk that preceded the
    detected speech start. A segment that is still in progress when the
    recorder is stopped is discarded.
    """

    # Seconds between VADIterator rebuilds. The iterator needs to be reloaded
    # because after running for a while it freaks out and hallucinates speech.
    _VAD_ITER_RELOAD_DELAY_SEC = 60 * 2

    def __init__(self, use_onnx=True):
        """Load the Silero VAD model via ``torch.hub``.

        Args:
            use_onnx: Forwarded to ``torch.hub.load`` as ``onnx``.
        """
        # flush=True so the banner is visible while the (slow) load is running.
        print("Loading Silero VAD model... ", end="", flush=True)
        self.vad_model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=False,
            onnx=use_onnx,
        )
        (
            _,  # get_speech_timestamps
            _,  # save_audio
            _,  # read_audio
            self.VADIterator,
            _,  # collect_chunks
        ) = utils
        print("Done!")

        self.vad_iterator = None
        # Recording state; populated by start_vad_recorder().
        self.p = None
        self.stream_in = None
        self.vad_rec_thread = None
        self.rec_flag = False
        self.speech = False
        self.audios_for_whisper = []

    def _make_vad_iterator(self):
        """Build a fresh VADIterator from the currently configured VAD settings."""
        return self.VADIterator(
            self.vad_model,
            threshold=self.vad_threshold,
            sampling_rate=SAMPLERATE,
            min_silence_duration_ms=self.min_silence_duration_ms,
            speech_pad_ms=self.speech_pad_ms,
        )

    def _find_input_device(self, target_device_name):
        """Return the index of the first input device whose name contains the text.

        Only devices reporting at least one input channel are considered.
        Returns ``None`` when no such device exists.
        """
        for index in range(self.p.get_device_count()):
            device_info = self.p.get_device_info_by_index(index)
            if device_info["maxInputChannels"] > 0 and target_device_name in device_info["name"]:
                return index
        return None

    def _vad_recorder(self):
        """Worker loop: read chunks, run VAD and collect finished speech segments.

        Runs until ``rec_flag`` becomes false. Intended to be executed on the
        thread created by ``start_vad_recorder()``.
        """
        print("Listening...")
        in_segment = False
        detected_audio = []
        last_chunk = np.zeros(self.window_size, dtype=np.float32)

        # Monotonic clock: the reload interval must not be affected by
        # wall-clock adjustments.
        vad_iter_load_time = time.monotonic()
        self.vad_iterator = self._make_vad_iterator()

        while self.rec_flag:
            # Do not raise on input overflow: VAD inference may occasionally
            # take longer than one window, and an exception here would end
            # this daemon thread and silently stop the recorder.
            raw = self.stream_in.read(self.window_size, exception_on_overflow=False)
            chunk = np.frombuffer(raw, dtype=np.float32)

            # The iterator yields None, {"start": x} or {"end": x}.
            speech_dict = self.vad_iterator(chunk)
            if speech_dict is not None:
                self.speech = "start" in speech_dict

            if self.speech:
                if not in_segment:
                    # Start the segment with the previous chunk so the onset
                    # of the utterance is not cut off.
                    detected_audio = last_chunk.tolist()
                    in_segment = True
                detected_audio.extend(chunk.tolist())
            else:
                # Only rebuild the iterator while no speech is in progress.
                if time.monotonic() - vad_iter_load_time > self._VAD_ITER_RELOAD_DELAY_SEC:
                    self.vad_iterator.reset_states()
                    vad_iter_load_time = time.monotonic()
                    self.vad_iterator = self._make_vad_iterator()
                    print("Reloaded VADIterator!")
                if in_segment:
                    in_segment = False
                    self.audios_for_whisper.append(detected_audio)

            last_chunk = chunk.copy()

    def start_vad_recorder(self, window_size_sec=0.1, vad_threshold=0.6, min_silence_duration_ms=150, speech_pad_ms=0, target_device_name=None):
        """Open the input stream and start the background recording thread.

        If a recording is already running it is stopped first, so the previous
        thread and audio stream are not leaked. ``audios_for_whisper`` is
        reset to an empty list.

        Args:
            window_size_sec: Length of one audio chunk in seconds.
            vad_threshold: Passed to the VADIterator as ``threshold``.
            min_silence_duration_ms: Passed to the VADIterator unchanged.
            speech_pad_ms: Passed to the VADIterator unchanged.
            target_device_name: Optional substring of the input device name to
                record from. ``None`` (default) uses the default input device.

        Raises:
            ValueError: If the window is shorter than one sample, or no input
                device matches ``target_device_name``.
        """
        window_size = int(window_size_sec * SAMPLERATE)
        if window_size <= 0:
            raise ValueError("window_size_sec is too small: it must cover at least one sample")

        if self.vad_rec_thread is not None:
            self.stop_vad_recorder()

        self.window_size = window_size
        self.vad_threshold = vad_threshold
        self.min_silence_duration_ms = min_silence_duration_ms
        self.speech_pad_ms = speech_pad_ms

        self.p = pyaudio.PyAudio()
        try:
            device_index = None
            if target_device_name is not None:
                device_index = self._find_input_device(target_device_name)
                if device_index is None:
                    raise ValueError(f"No target device found with \"{target_device_name}\" in its name.")
            self.stream_in = self.p.open(
                format=pyaudio.paFloat32,
                channels=1,
                rate=SAMPLERATE,
                input=True,
                frames_per_buffer=self.window_size,
                input_device_index=device_index,
            )
        except Exception:
            # Release PortAudio if the stream could not be opened.
            self.p.terminate()
            self.p = None
            raise

        self.speech = False
        self.audios_for_whisper = []
        if self.vad_iterator is not None:
            self.vad_iterator.reset_states()

        self.rec_flag = True
        self.vad_rec_thread = threading.Thread(target=self._vad_recorder, daemon=True)
        self.vad_rec_thread.start()

    def stop_vad_recorder(self):
        """Stop the recording thread and release the audio stream and PyAudio.

        Safe to call when the recorder is not running and safe to call more
        than once. The stream is closed and PyAudio terminated even if
        stopping the stream raises; that exception is then propagated.
        """
        self.rec_flag = False

        thread, self.vad_rec_thread = self.vad_rec_thread, None
        if thread is not None:
            # Wait for the worker to finish its current read before the
            # stream is torn down underneath it.
            thread.join()

        stream, self.stream_in = self.stream_in, None
        audio, self.p = self.p, None
        try:
            if stream is not None:
                try:
                    stream.stop_stream()
                finally:
                    stream.close()
        finally:
            if audio is not None:
                audio.terminate()