другая структура проекта + beads + александр повтори + комментарии везде + readme
This commit is contained in:
265
app/audio/tts.py
Normal file
265
app/audio/tts.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Text-to-Speech module using Silero TTS.
|
||||
Generates natural Russian speech.
|
||||
Supports interruption via wake word detection using threading.
|
||||
"""
|
||||
|
||||
# Модуль синтеза речи (TTS - Text-to-Speech).
|
||||
# Использует нейросеть Silero TTS для качественной русской речи.
|
||||
# Также поддерживает прерывание речи, если пользователь скажет "Alexandr".
|
||||
|
||||
import torch
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
import re
|
||||
from ..core.config import TTS_SPEAKER, TTS_EN_SPEAKER, TTS_SAMPLE_RATE
|
||||
|
||||
# Подавляем предупреждения Silero о длинном тексте (мы сами его режем)
|
||||
warnings.filterwarnings("ignore", message="Text string is longer than 1000 symbols")
|
||||
|
||||
|
||||
class TextToSpeech:
|
||||
"""Класс синтеза речи с поддержкой прерывания."""
|
||||
|
||||
def __init__(self):
|
||||
self.model_ru = None
|
||||
self.model_en = None
|
||||
self.sample_rate = TTS_SAMPLE_RATE
|
||||
self.speaker_ru = TTS_SPEAKER
|
||||
self.speaker_en = TTS_EN_SPEAKER
|
||||
self._interrupted = False
|
||||
self._stop_flag = threading.Event()
|
||||
|
||||
def _load_model(self, language: str):
|
||||
"""
|
||||
Загрузка и кэширование модели Silero TTS.
|
||||
Загружается один раз при первом обращении.
|
||||
"""
|
||||
device = torch.device("cpu") # Работаем на процессоре (достаточно быстро)
|
||||
|
||||
if language == "en":
|
||||
if self.model_en:
|
||||
return self.model_en
|
||||
print("📦 Загрузка модели Silero TTS (en)...")
|
||||
model, _ = torch.hub.load(
|
||||
repo_or_dir="snakers4/silero-models",
|
||||
model="silero_tts",
|
||||
language="en",
|
||||
speaker="v3_en",
|
||||
)
|
||||
model.to(device)
|
||||
self.model_en = model
|
||||
return model
|
||||
|
||||
# По умолчанию русский
|
||||
if self.model_ru:
|
||||
return self.model_ru
|
||||
print("📦 Загрузка модели Silero TTS (ru)...")
|
||||
model, _ = torch.hub.load(
|
||||
repo_or_dir="snakers4/silero-models",
|
||||
model="silero_tts",
|
||||
language="ru",
|
||||
speaker="v5_ru",
|
||||
)
|
||||
model.to(device)
|
||||
self.model_ru = model
|
||||
return model
|
||||
|
||||
def initialize(self):
|
||||
"""Предварительная инициализация (прогрев) русской модели."""
|
||||
self._load_model("ru")
|
||||
|
||||
def _split_text(self, text: str, max_length: int = 900) -> list[str]:
|
||||
"""
|
||||
Разбивает длинный текст на части (чанки), так как Silero не принимает >1000 символов.
|
||||
Старается разбивать по предложениям (.!?).
|
||||
"""
|
||||
if len(text) <= max_length:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
# Разбиваем по знакам препинания, сохраняя их
|
||||
parts = re.split(r"([.!?]+\s*)", text)
|
||||
|
||||
current_chunk = ""
|
||||
|
||||
for part in parts:
|
||||
# Если добавление части превысит лимит, сохраняем текущий кусок
|
||||
if len(current_chunk) + len(part) > max_length:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = ""
|
||||
|
||||
current_chunk += part
|
||||
|
||||
# Если даже одна часть огромная (нет знаков препинания), режем жестко по пробелам
|
||||
while len(current_chunk) > max_length:
|
||||
split_idx = current_chunk.rfind(" ", 0, max_length)
|
||||
if split_idx == -1:
|
||||
split_idx = max_length # Если нет пробелов, режем посередине слова
|
||||
|
||||
chunks.append(current_chunk[:split_idx].strip())
|
||||
current_chunk = current_chunk[split_idx:].lstrip()
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return [c for c in chunks if c]
|
||||
|
||||
def speak(self, text: str, check_interrupt=None, language: str = "ru") -> bool:
|
||||
"""
|
||||
Основная функция: генерирует аудио и воспроизводит его.
|
||||
|
||||
Args:
|
||||
text: Текст для озвучки.
|
||||
check_interrupt: Функция, возвращающая True, если надо прерваться (например, check_wakeword_once).
|
||||
language: "ru" или "en".
|
||||
|
||||
Returns:
|
||||
True, если договорил до конца.
|
||||
False, если был прерван.
|
||||
"""
|
||||
if not text.strip():
|
||||
return True
|
||||
|
||||
# Выбор модели
|
||||
if language == "en":
|
||||
model = self._load_model("en")
|
||||
speaker = self.speaker_en
|
||||
else:
|
||||
model = self._load_model("ru")
|
||||
speaker = self.speaker_ru
|
||||
|
||||
# Проверка наличия спикера в модели (защита от ошибок конфига)
|
||||
if hasattr(model, "speakers") and speaker not in model.speakers:
|
||||
if model.speakers:
|
||||
speaker = model.speakers[0]
|
||||
|
||||
# Разбиваем текст на куски
|
||||
chunks = self._split_text(text)
|
||||
total_chunks = len(chunks)
|
||||
|
||||
if total_chunks > 1:
|
||||
print(f"🔊 Озвучивание (частей: {total_chunks}): {text[:50]}...")
|
||||
else:
|
||||
print(f"🔊 Озвучивание: {text[:50]}...")
|
||||
|
||||
self._interrupted = False
|
||||
self._stop_flag.clear()
|
||||
|
||||
success = True
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
if self._interrupted:
|
||||
break
|
||||
|
||||
try:
|
||||
# Генерация аудио (тензор)
|
||||
audio = model.apply_tts(
|
||||
text=chunk, speaker=speaker, sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
# Конвертация в numpy массив для sounddevice
|
||||
audio_np = audio.numpy()
|
||||
|
||||
if check_interrupt:
|
||||
# Воспроизведение с проверкой прерывания (сложная логика)
|
||||
if not self._play_with_interrupt(audio_np, check_interrupt):
|
||||
success = False
|
||||
break
|
||||
else:
|
||||
# Обычное воспроизведение (блокирующее)
|
||||
sd.play(audio_np, self.sample_rate)
|
||||
sd.wait()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка TTS (часть {i + 1}/{total_chunks}): {e}")
|
||||
success = False
|
||||
|
||||
if success and not self._interrupted:
|
||||
print("✅ Воспроизведение завершено")
|
||||
return True
|
||||
elif self._interrupted:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
def _check_interrupt_worker(self, check_interrupt):
|
||||
"""
|
||||
Фоновая функция для потока: постоянно опрашивает check_interrupt.
|
||||
Если вернуло True -> останавливаем звук.
|
||||
"""
|
||||
while not self._stop_flag.is_set():
|
||||
try:
|
||||
if check_interrupt():
|
||||
self._interrupted = True
|
||||
sd.stop() # Немедленная остановка звука
|
||||
print("⏹️ Воспроизведение прервано!")
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _play_with_interrupt(self, audio_np: np.ndarray, check_interrupt) -> bool:
|
||||
"""
|
||||
Воспроизводит аудио, параллельно проверяя условие прерывания в отдельном потоке.
|
||||
"""
|
||||
# Запускаем поток-наблюдатель
|
||||
checker_thread = threading.Thread(
|
||||
target=self._check_interrupt_worker, args=(check_interrupt,), daemon=True
|
||||
)
|
||||
checker_thread.start()
|
||||
|
||||
try:
|
||||
# Запускаем воспроизведение (неблокирующее)
|
||||
sd.play(audio_np, self.sample_rate)
|
||||
|
||||
# Ждем окончания воспроизведения в цикле
|
||||
while sd.get_stream().active:
|
||||
if self._interrupted:
|
||||
break
|
||||
time.sleep(0.05)
|
||||
|
||||
finally:
|
||||
# Сообщаем потоку-наблюдателю, что пора завершаться
|
||||
self._stop_flag.set()
|
||||
checker_thread.join(timeout=0.5)
|
||||
|
||||
if self._interrupted:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@property
|
||||
def was_interrupted(self) -> bool:
|
||||
"""Был ли прерван последний вызов speak."""
|
||||
return self._interrupted
|
||||
|
||||
|
||||
# Глобальный экземпляр TTS
|
||||
_tts = None
|
||||
|
||||
|
||||
def get_tts() -> TextToSpeech:
|
||||
"""Получить или создать экземпляр TTS."""
|
||||
global _tts
|
||||
if _tts is None:
|
||||
_tts = TextToSpeech()
|
||||
return _tts
|
||||
|
||||
|
||||
def speak(text: str, check_interrupt=None, language: str = "ru") -> bool:
|
||||
"""Внешняя функция для озвучивания."""
|
||||
return get_tts().speak(text, check_interrupt, language)
|
||||
|
||||
|
||||
def was_interrupted() -> bool:
|
||||
"""Проверка флага прерывания."""
|
||||
return get_tts().was_interrupted
|
||||
|
||||
|
||||
def initialize():
|
||||
"""Предварительная загрузка моделей."""
|
||||
get_tts().initialize()
|
||||
Reference in New Issue
Block a user