Files
smart-speaker/app/audio/tts.py

350 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Text-to-Speech module using Silero TTS.
Generates natural Russian speech.
Supports interruption via wake word detection using threading.
"""
# Модуль синтеза речи (TTS - Text-to-Speech).
# Использует нейросеть Silero TTS для качественной русской речи.
# Также поддерживает прерывание речи, если пользователь скажет "Alexandr".
import torch
import sounddevice as sd
import numpy as np
import threading
import time
import warnings
import re
from ..core.config import TTS_SPEAKER, TTS_EN_SPEAKER, TTS_SAMPLE_RATE
# Подавляем предупреждения Silero о длинном тексте (мы сами его режем)
warnings.filterwarnings("ignore", message="Text string is longer than 1000 symbols")
_EN_WORD_RE = re.compile(r"[A-Za-z][A-Za-z0-9'-]*")
class TextToSpeech:
"""Класс синтеза речи с поддержкой прерывания."""
def __init__(self):
self.model_ru = None
self.model_en = None
self.sample_rate = TTS_SAMPLE_RATE
self.speaker_ru = TTS_SPEAKER
self.speaker_en = TTS_EN_SPEAKER
self._interrupted = False
self._stop_flag = threading.Event()
def _load_model(self, language: str):
"""
Загрузка и кэширование модели Silero TTS.
Загружается один раз при первом обращении.
"""
device = torch.device("cpu") # Работаем на процессоре (достаточно быстро)
if language == "en":
if self.model_en:
return self.model_en
print("📦 Загрузка модели Silero TTS (en)...")
model, _ = torch.hub.load(
repo_or_dir="snakers4/silero-models",
model="silero_tts",
language="en",
speaker="v3_en",
)
model.to(device)
self.model_en = model
return model
# По умолчанию русский
if self.model_ru:
return self.model_ru
print("📦 Загрузка модели Silero TTS (ru)...")
model, _ = torch.hub.load(
repo_or_dir="snakers4/silero-models",
model="silero_tts",
language="ru",
speaker="v5_ru",
)
model.to(device)
self.model_ru = model
return model
def initialize(self):
"""Предварительная инициализация (прогрев) русской модели."""
self._load_model("ru")
def _split_text(self, text: str, max_length: int = 900) -> list[str]:
"""
Разбивает длинный текст на части (чанки), так как Silero не принимает >1000 символов.
Старается разбивать по предложениям (.!?).
"""
if len(text) <= max_length:
return [text]
chunks = []
# Разбиваем по знакам препинания, сохраняя их
parts = re.split(r"([.!?]+\s*)", text)
current_chunk = ""
for part in parts:
# Если добавление части превысит лимит, сохраняем текущий кусок
if len(current_chunk) + len(part) > max_length:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = ""
current_chunk += part
# Если даже одна часть огромная (нет знаков препинания), режем жестко по пробелам
while len(current_chunk) > max_length:
split_idx = current_chunk.rfind(" ", 0, max_length)
if split_idx == -1:
split_idx = max_length # Если нет пробелов, режем посередине слова
chunks.append(current_chunk[:split_idx].strip())
current_chunk = current_chunk[split_idx:].lstrip()
if current_chunk:
chunks.append(current_chunk.strip())
return [c for c in chunks if c]
def _split_mixed_language(self, text: str) -> list[tuple[str, str]]:
"""
Разбивает текст на сегменты русского и английского текста.
Английские слова (латиница) будут озвучены английской моделью.
"""
matches = list(_EN_WORD_RE.finditer(text))
if not matches:
return [(text, "ru")]
segments = []
idx = 0
for match in matches:
if match.start() > idx:
segments.append((text[idx : match.start()], "ru"))
segments.append((match.group(0), "en"))
idx = match.end()
if idx < len(text):
segments.append((text[idx:], "ru"))
# Склеиваем соседние сегменты и прикрепляем чистую пунктуацию к предыдущему.
merged = []
for segment, lang in segments:
if not segment:
continue
if not any(ch.isalnum() for ch in segment):
if merged:
merged[-1] = (merged[-1][0] + segment, merged[-1][1])
else:
merged.append((segment, lang))
continue
if merged and merged[-1][1] == lang:
merged[-1] = (merged[-1][0] + segment, lang)
else:
merged.append((segment, lang))
if merged and not any(ch.isalnum() for ch in merged[0][0]) and len(merged) > 1:
merged[1] = (merged[0][0] + merged[1][0], merged[1][1])
merged = merged[1:]
return merged
def _speak_single_language(
self, text: str, check_interrupt=None, language: str = "ru"
) -> bool:
"""Озвучивание текста одной моделью языка."""
if not text.strip():
return True
# Выбор модели
if language == "en":
model = self._load_model("en")
speaker = self.speaker_en
else:
model = self._load_model("ru")
speaker = self.speaker_ru
# Проверка наличия спикера в модели (защита от ошибок конфига).
# Для русского языка сохраняем мужской голос по умолчанию.
if hasattr(model, "speakers") and model.speakers:
if language == "ru":
male_speakers = ("eugene", "aidar")
if speaker not in model.speakers or speaker not in male_speakers:
for candidate in male_speakers:
if candidate in model.speakers:
speaker = candidate
break
else:
speaker = model.speakers[0]
elif speaker not in model.speakers:
speaker = model.speakers[0]
# Разбиваем текст на куски
chunks = self._split_text(text)
total_chunks = len(chunks)
if total_chunks > 1:
print(f"🔊 Озвучивание (частей: {total_chunks}): {text[:50]}...")
else:
print(f"🔊 Озвучивание: {text[:50]}...")
self._interrupted = False
self._stop_flag.clear()
success = True
for i, chunk in enumerate(chunks):
if self._interrupted:
break
try:
# Генерация аудио (тензор)
audio = model.apply_tts(
text=chunk, speaker=speaker, sample_rate=self.sample_rate
)
# Конвертация в numpy массив для sounddevice
audio_np = audio.numpy()
if check_interrupt:
# Воспроизведение с проверкой прерывания (сложная логика)
if not self._play_with_interrupt(audio_np, check_interrupt):
success = False
break
else:
# Обычное воспроизведение (блокирующее)
sd.play(audio_np, self.sample_rate)
sd.wait()
except Exception as e:
print(f"❌ Ошибка TTS (часть {i + 1}/{total_chunks}): {e}")
success = False
if success and not self._interrupted:
print("✅ Воспроизведение завершено")
return True
elif self._interrupted:
return False
else:
return False
def _speak_mixed(
self, segments: list[tuple[str, str]], check_interrupt=None
) -> bool:
"""Озвучивание текста с переключением RU/EN по сегментам."""
for segment, lang in segments:
if not segment.strip():
continue
completed = self._speak_single_language(
segment, check_interrupt=check_interrupt, language=lang
)
if not completed:
return False
return True
def speak(self, text: str, check_interrupt=None, language: str = "ru") -> bool:
"""
Основная функция: генерирует аудио и воспроизводит его.
Args:
text: Текст для озвучки.
check_interrupt: Функция, возвращающая True, если надо прерваться (например, check_wakeword_once).
language: "ru" или "en".
Returns:
True, если договорил до конца.
False, если был прерван.
"""
if not text.strip():
return True
if language == "ru":
segments = self._split_mixed_language(text)
if any(lang == "en" for _, lang in segments):
return self._speak_mixed(segments, check_interrupt=check_interrupt)
return self._speak_single_language(
text, check_interrupt=check_interrupt, language=language
)
def _check_interrupt_worker(self, check_interrupt):
"""
Фоновая функция для потока: постоянно опрашивает check_interrupt.
Если вернуло True -> останавливаем звук.
"""
while not self._stop_flag.is_set():
try:
if check_interrupt():
self._interrupted = True
sd.stop() # Немедленная остановка звука
print("⏹️ Воспроизведение прервано!")
return
except Exception:
pass
def _play_with_interrupt(self, audio_np: np.ndarray, check_interrupt) -> bool:
"""
Воспроизводит аудио, параллельно проверяя условие прерывания в отдельном потоке.
"""
# Запускаем поток-наблюдатель
checker_thread = threading.Thread(
target=self._check_interrupt_worker, args=(check_interrupt,), daemon=True
)
checker_thread.start()
try:
# Запускаем воспроизведение (неблокирующее)
sd.play(audio_np, self.sample_rate)
# Ждем окончания воспроизведения в цикле
while sd.get_stream().active:
if self._interrupted:
break
time.sleep(0.05)
finally:
# Сообщаем потоку-наблюдателю, что пора завершаться
self._stop_flag.set()
checker_thread.join(timeout=0.5)
if self._interrupted:
return False
return True
@property
def was_interrupted(self) -> bool:
"""Был ли прерван последний вызов speak."""
return self._interrupted
# Глобальный экземпляр TTS
_tts = None
def get_tts() -> TextToSpeech:
"""Получить или создать экземпляр TTS."""
global _tts
if _tts is None:
_tts = TextToSpeech()
return _tts
def speak(text: str, check_interrupt=None, language: str = "ru") -> bool:
"""Внешняя функция для озвучивания."""
return get_tts().speak(text, check_interrupt, language)
def was_interrupted() -> bool:
"""Проверка флага прерывания."""
return get_tts().was_interrupted
def initialize():
"""Предварительная загрузка моделей."""
get_tts().initialize()