2023-11-15 19:57:17 +01:00
|
|
|
import os
|
|
|
|
import pyaudio
|
|
|
|
from TTS.api import TTS
|
|
|
|
from TTS.tts.configs.xtts_config import XttsConfig
|
|
|
|
from TTS.tts.models.xtts import Xtts
|
|
|
|
from TTS.utils.generic_utils import get_user_data_dir
|
|
|
|
import threading
|
|
|
|
import time
|
2023-12-01 21:48:56 +01:00
|
|
|
import re
|
|
|
|
|
2023-11-15 19:57:17 +01:00
|
|
|
|
|
|
|
|
|
|
|
# Coqui TTS model identifier for XTTS v2. TTSStream.__init__ maps it to the local
# cache directory name by replacing "/" with "--".
model_name = "tts_models/multilingual/multi-dataset/xtts_v2"


class TTSStream:
    """Streaming text-to-speech playback with a Coqui XTTS v2 model.

    The constructor loads the model (downloading it first if it is not cached
    locally). ``change_speaker`` selects the voice to clone, and ``tts_speak``
    synthesizes text sentence by sentence and plays it through PyAudio while the
    following audio is still being generated.
    """

    def __init__(self, speaker_wav=None, device=None):
        """Load the XTTS model and, optionally, a speaker.

        Args:
            speaker_wav: Reference audio passed as ``audio_path`` to
                ``get_conditioning_latents`` (see ``change_speaker``). If None,
                ``change_speaker`` must be called before ``tts_speak``.
            device: Device the model is moved to (e.g. "cuda" or "cpu"). If None,
                "cuda" is used when ``torch.cuda.is_available()``, else "cpu".
        """
        model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))

        if device is None:
            import torch  # deferred: only needed to probe for CUDA

            if torch.cuda.is_available():
                print("Using CUDA")
                device = "cuda"
            else:
                print("Using CPU")
                device = "cpu"

        # flush=True so the progress message shows up before the slow work starts.
        print("Loading TTS model... ", end="", flush=True)

        # Download the model if its config.json is not in the local cache yet.
        if not os.path.exists(os.path.join(model_path, "config.json")):
            print("Downloading model... ", end="", flush=True)
            TTS().download_model_by_name(model_name=model_name)

        config = XttsConfig()
        config.load_json(os.path.join(model_path, "config.json"))
        self.model = Xtts.init_from_config(config)
        self.model.load_checkpoint(
            config,
            checkpoint_path=os.path.join(model_path, "model.pth"),
            vocab_path=os.path.join(model_path, "vocab.json"),
            eval=True,
            use_deepspeed=False,
        )
        self.model.to(device)
        print("Done!")

        if speaker_wav is not None:
            self.change_speaker(speaker_wav)

    def change_speaker(self, speaker_wav):
        """Compute and store the conditioning latents for a new reference voice.

        Sets ``self.gpt_cond_latent`` and ``self.speaker_embedding``, which
        ``tts_speak`` passes to the model.
        """
        print("Loading speaker... ", end="", flush=True)
        self.gpt_cond_latent, self.speaker_embedding = self.model.get_conditioning_latents(audio_path=speaker_wav)
        print("Done!")

    @staticmethod
    def _split_sentences(text):
        """Split *text* into sentences at whitespace following '.', '!' or '?'.

        Splitting only at whitespace keeps decimals ("3.5") and ellipses ("...")
        in one piece. Fragments without any letter or digit (e.g. a lone ".")
        are dropped because they contain no text to synthesize.
        """
        sentences = []
        for piece in re.split(r"(?<=[.!?])\s+", text):
            piece = piece.strip()
            if any(ch.isalnum() for ch in piece):
                sentences.append(piece)
        return sentences

    def _write_stream(self, audio_queue):
        """Player-thread loop: write queued audio to ``self.stream``.

        Takes byte strings from *audio_queue* and writes each one in slices of
        ``self.play_buffer_size`` bytes. Returns when it receives the ``None``
        sentinel. Only this thread touches ``self.stream`` while it runs.
        """
        size = self.play_buffer_size
        while True:
            data = audio_queue.get()
            if data is None:
                return
            for start in range(0, len(data), size):
                self.stream.write(data[start:start + size])

    def _synthesize_sentence(self, sentence, language, stream_chunk_size, audio_queue, max_attempts=3):
        """Stream one sentence through the model and queue its audio as float32 bytes.

        If the model's chunk stream raises before any audio was queued, the
        sentence is retried with a fresh stream, up to *max_attempts* times in
        total. If it fails after audio was already queued, the rest of the
        sentence is skipped rather than replaying audio that is already queued.
        """
        for attempt in range(1, max_attempts + 1):
            # Build the stream outside the try, as the original code did, so
            # setup errors such as a missing speaker (AttributeError on
            # self.gpt_cond_latent) propagate instead of being retried.
            chunks = self.model.inference_stream(
                sentence,
                language,
                self.gpt_cond_latent,
                self.speaker_embedding,
                stream_chunk_size=stream_chunk_size,
            )
            produced_audio = False
            try:
                for chunk in chunks:
                    audio_queue.put(chunk.cpu().numpy().astype("float32").tobytes())
                    produced_audio = True
            except Exception as exc:  # errors raised by coqui-tts while streaming
                if produced_audio or attempt == max_attempts:
                    print(f"Error occurred when generating audio stream ({exc!r}). Skipping the rest of this sentence.")
                    return
                print("Error occured when generating audio stream. Retrying...")
            else:
                return

    def _play_sentences(self, sentences, language, stream_chunk_size):
        """Generate *sentences* on this thread while a player thread plays them.

        Uses one player thread for the whole call, so later sentences are
        generated while earlier ones are still playing. Returns after the player
        thread has written everything that was queued, including when
        generation is interrupted by an exception.
        """
        import queue  # thread-safe hand-off between generator and player threads

        audio_queue = queue.Queue()
        player = threading.Thread(target=self._write_stream, args=(audio_queue,))
        player.start()
        try:
            for sentence in sentences:
                self._synthesize_sentence(sentence, language, stream_chunk_size, audio_queue)
        finally:
            # Sentinel: the player stops after writing everything queued before it.
            audio_queue.put(None)
            player.join()

    def tts_speak(self, text, language="pl", stream_chunk_size=20):
        """Synthesize *text* with the current speaker and play it.

        Args:
            text: Text to speak. It is split into sentences (see
                ``_split_sentences``), and each sentence is streamed through the
                model separately.
            language: Language code passed to ``inference_stream`` (default "pl").
            stream_chunk_size: Value passed as ``stream_chunk_size`` to
                ``inference_stream`` (default 20).

        Raises:
            AttributeError: If no speaker was loaded (``change_speaker`` was
                never called) and *text* contains at least one sentence.
        """
        self.play_buffer_size = 512  # bytes per stream.write call (128 float32 samples)

        sentences = self._split_sentences(text)
        if not sentences:
            return

        audio = pyaudio.PyAudio()
        try:
            # Mono float32 at 24 kHz; the queued bytes are float32 samples.
            self.stream = audio.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True)
            try:
                self._play_sentences(sentences, language, stream_chunk_size)
            finally:
                # stop_stream lets buffered audio finish; close alone discards it.
                self.stream.stop_stream()
                self.stream.close()
        finally:
            audio.terminate()
|