273 lines
8.6 KiB
Python
273 lines
8.6 KiB
Python
"""
|
|
Text-to-Speech module using Silero TTS.
|
|
Generates natural Russian speech.
|
|
Supports interruption via wake word detection using threading.
|
|
"""
|
|
|
|
import torch
|
|
import sounddevice as sd
|
|
import numpy as np
|
|
import threading
|
|
import time
|
|
import warnings
|
|
import re
|
|
from config import TTS_SPEAKER, TTS_EN_SPEAKER, TTS_SAMPLE_RATE
|
|
|
|
# Suppress Silero TTS warning about text length
|
|
warnings.filterwarnings("ignore", message="Text string is longer than 1000 symbols")
|
|
|
|
|
|
class TextToSpeech:
    """Text-to-Speech using Silero TTS with wake word interruption support.

    Models are loaded lazily per language and cached in ``self.models``.
    Long texts are split into chunks that are synthesized and played one
    after another. When a ``check_interrupt`` callback is supplied to
    ``speak``, it is polled from a background thread during playback and
    playback is stopped as soon as the callback returns a truthy value.
    """

    # Arguments passed to torch.hub.load for each supported language.
    _MODEL_CONFIG = {
        "ru": {"language": "ru", "model_id": "v5_ru"},
        "en": {"language": "en", "model_id": "v3_en"},
    }

    # Pause between two polls of the interrupt callback, in seconds. Keeps
    # the checker thread from spinning at full speed next to audio playback.
    _POLL_INTERVAL = 0.01

    def __init__(self):
        # language code -> loaded Silero model
        self.models = {}
        self.sample_rate = TTS_SAMPLE_RATE
        # language code -> preferred speaker name (taken from config)
        self.speakers = {
            "ru": TTS_SPEAKER,
            "en": TTS_EN_SPEAKER,
        }
        # True once the interrupt callback fired during the last speak() call.
        self._interrupted = False
        # Stop signal of the most recent interrupt-checker thread.
        self._stop_flag = threading.Event()

    def _load_model(self, language: str):
        """Load and cache the Silero TTS model for the given language.

        Args:
            language: Language code, "ru" or "en".

        Returns:
            The loaded model (served from the cache on repeated calls).

        Raises:
            ValueError: If the language has no configured model.
        """
        if language in self.models:
            return self.models[language]

        if language not in self._MODEL_CONFIG:
            raise ValueError(f"Unsupported TTS language: {language}")

        config = self._MODEL_CONFIG[language]
        print(f"📦 Загрузка модели Silero TTS ({language})...")

        device = torch.device("cpu")
        # The second element returned by the hub entry point is not used here.
        model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_tts",
            language=config["language"],
            speaker=config["model_id"],
        )
        model.to(device)

        self.models[language] = model
        return model

    def _get_speaker(self, language: str, model) -> str:
        """Return a valid speaker for the loaded model.

        If the model exposes a ``speakers`` collection that does not contain
        the configured speaker, the model's first speaker is used instead
        (or the configured one when the collection is empty).
        """
        speaker = self.speakers.get(language)
        if hasattr(model, "speakers") and speaker not in model.speakers:
            fallback = model.speakers[0] if model.speakers else speaker
            print(f"⚠️ Голос '{speaker}' недоступен, использую '{fallback}'")
            return fallback
        return speaker

    def initialize(self):
        """Initialize default (Russian) TTS model."""
        self._load_model("ru")

    def _split_text(self, text: str, max_length: int = 900) -> list[str]:
        """Split text into chunks of at most ``max_length`` characters.

        Text is split on sentence boundaries (runs of ``.``, ``!``, ``?``
        plus trailing whitespace). A sentence always stays together with its
        closing punctuation. A single sentence longer than ``max_length`` is
        cut at the last space before the limit, or hard-cut when it contains
        no space.

        Args:
            text: Text to split.
            max_length: Maximum length of one chunk.

        Returns:
            Non-empty, whitespace-stripped chunks in original order. Text
            that already fits is returned unchanged as a single chunk.

        Raises:
            ValueError: If the text needs splitting and ``max_length`` is
                smaller than 1 (no progress could be made).
        """
        if len(text) <= max_length:
            return [text]
        if max_length < 1:
            raise ValueError("max_length must be a positive integer")

        # re.split with a capturing group yields
        # [text, delimiter, text, delimiter, ..., text]; glue every text to
        # the delimiter that follows it so punctuation is never orphaned.
        pieces = re.split(r"([.!?]+\s*)", text)
        sentences = [
            "".join(pieces[i:i + 2]) for i in range(0, len(pieces), 2)
        ]

        chunks = []
        current = ""
        for sentence in sentences:
            candidate = current + sentence
            # Surrounding whitespace is stripped on output, so it does not
            # count against the limit.
            if len(candidate.strip()) <= max_length:
                current = candidate
                continue

            if current.strip():
                chunks.append(current.strip())
            current = sentence.lstrip()

            # The sentence alone is too long: force-split it.
            while len(current.rstrip()) > max_length:
                cut = current.rfind(" ", 0, max_length)
                if cut <= 0:
                    # No usable space found, hard cut.
                    cut = max_length
                chunks.append(current[:cut].rstrip())
                current = current[cut:].lstrip()

        if current.strip():
            chunks.append(current.strip())

        return chunks

    def speak(self, text: str, check_interrupt=None, language: str = "ru") -> bool:
        """Convert text to speech and play it.

        Args:
            text: Text to synthesize and speak.
            check_interrupt: Optional callback that returns True if playback
                should stop.
            language: Language code for voice selection ("ru" or "en").

        Returns:
            True if every chunk was played to the end. False if playback was
            interrupted or synthesis/playback of any chunk failed.

        Raises:
            ValueError: If ``language`` is not supported.
        """
        if not text.strip():
            return True

        model = self._load_model(language)
        speaker = self._get_speaker(language, model)

        # Split text into manageable chunks.
        chunks = self._split_text(text)
        total_chunks = len(chunks)

        if total_chunks > 1:
            print(f"🔊 Озвучивание (частей: {total_chunks}): {text[:50]}...")
        else:
            print(f"🔊 Озвучивание: {text[:50]}...")

        self._interrupted = False
        self._stop_flag.clear()

        success = True
        for i, chunk in enumerate(chunks):
            if self._interrupted:
                break

            try:
                audio = model.apply_tts(
                    text=chunk, speaker=speaker, sample_rate=self.sample_rate
                )
                audio_np = audio.numpy()

                if check_interrupt:
                    # Play while the callback is polled in a parallel thread.
                    if not self._play_with_interrupt(audio_np, check_interrupt):
                        success = False
                        break
                else:
                    # Standard blocking playback.
                    sd.play(audio_np, self.sample_rate)
                    sd.wait()
            except Exception as e:
                # Best effort: report the failed chunk and keep going with
                # the remaining ones; the overall result is still False.
                print(f"❌ Ошибка TTS (часть {i + 1}/{total_chunks}): {e}")
                success = False

        if success and not self._interrupted:
            print("✅ Воспроизведение завершено")
            return True
        return False

    def _check_interrupt_worker(self, check_interrupt, stop_event=None):
        """Poll ``check_interrupt`` until it fires or a stop is requested.

        Runs in a background thread. When the callback returns a truthy
        value, the interrupted flag is set and playback is stopped.

        Args:
            check_interrupt: Callback that returns True if playback should
                be interrupted.
            stop_event: Event that ends the polling loop when set. Defaults
                to ``self._stop_flag``.
        """
        if stop_event is None:
            stop_event = self._stop_flag

        error_reported = False
        while not stop_event.is_set():
            try:
                if check_interrupt():
                    self._interrupted = True
                    sd.stop()
                    print("⏹️ Воспроизведение прервано!")
                    return
            except Exception as e:
                # Keep polling (a failing callback must not kill playback),
                # but report the first failure instead of hiding it.
                if not error_reported:
                    print(f"⚠️ Ошибка проверки прерывания: {e}")
                    error_reported = True
            # Returns early as soon as a stop is requested.
            stop_event.wait(self._POLL_INTERVAL)

    def _play_with_interrupt(self, audio_np: np.ndarray, check_interrupt) -> bool:
        """Play audio with interrupt checking in a parallel thread.

        Args:
            audio_np: Audio data as numpy array.
            check_interrupt: Callback that returns True if should interrupt.

        Returns:
            True if completed normally, False if interrupted.
        """
        # A fresh event per playback: reusing one shared, already-set event
        # would make the checker of every later chunk exit immediately, and
        # clearing a shared event could revive a checker that has not
        # finished yet.
        stop_event = threading.Event()
        self._stop_flag = stop_event

        checker_thread = threading.Thread(
            target=self._check_interrupt_worker,
            args=(check_interrupt, stop_event),
            daemon=True,
        )
        checker_thread.start()

        try:
            # Play audio (non-blocking start).
            sd.play(audio_np, self.sample_rate)

            # Wait for playback to finish or interrupt.
            while sd.get_stream().active:
                if self._interrupted:
                    break
                time.sleep(0.05)
        finally:
            # Signal checker thread to stop.
            stop_event.set()
            checker_thread.join(timeout=0.5)
            if self._interrupted:
                # The interrupt may have fired before playback had started,
                # in which case the checker's stop had nothing to stop.
                sd.stop()

        return not self._interrupted

    @property
    def was_interrupted(self) -> bool:
        """Check if the last playback was interrupted."""
        return self._interrupted
|
|
|
|
|
|
# Module-wide TextToSpeech singleton; stays None until get_tts() is first called.
_tts = None


def get_tts() -> TextToSpeech:
    """Return the shared TextToSpeech instance, creating it on first use."""
    global _tts
    if _tts is not None:
        return _tts
    _tts = TextToSpeech()
    return _tts
|
|
|
|
|
|
def speak(text: str, check_interrupt=None, language: str = "ru") -> bool:
    """
    Speak the given text through the shared TTS instance.

    Args:
        text: Text to speak
        check_interrupt: Optional callback polled to decide whether to stop
        language: Language code for voice selection ("ru" or "en")

    Returns:
        The result of TextToSpeech.speak: True if playback completed
        normally, False otherwise
    """
    engine = get_tts()
    return engine.speak(text, check_interrupt, language)
|
|
|
|
|
|
def was_interrupted() -> bool:
    """Report whether the shared TTS instance's last playback was interrupted."""
    engine = get_tts()
    return engine.was_interrupted
|
|
|
|
|
|
def initialize():
    """Load the default TTS model ahead of time on the shared instance."""
    engine = get_tts()
    engine.initialize()
|