first commit
This commit is contained in:
178
tts.py
Normal file
178
tts.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Text-to-Speech module using Silero TTS.
|
||||
Generates natural Russian speech with Xenia voice.
|
||||
Supports interruption via wake word detection using threading.
|
||||
"""
|
||||
import torch
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import threading
|
||||
import time
|
||||
from config import TTS_SPEAKER, TTS_SAMPLE_RATE
|
||||
|
||||
|
||||
class TextToSpeech:
|
||||
"""Text-to-Speech using Silero TTS with wake word interruption support."""
|
||||
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.sample_rate = TTS_SAMPLE_RATE
|
||||
self.speaker = TTS_SPEAKER
|
||||
self._interrupted = False
|
||||
self._stop_flag = threading.Event()
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize Silero TTS model."""
|
||||
print("📦 Загрузка модели Silero TTS...")
|
||||
|
||||
# Load Silero TTS model
|
||||
self.model, _ = torch.hub.load(
|
||||
repo_or_dir='snakers4/silero-models',
|
||||
model='silero_tts',
|
||||
language='ru',
|
||||
speaker='v4_ru'
|
||||
)
|
||||
|
||||
print(f"✅ Модель TTS загружена (голос: {self.speaker})")
|
||||
|
||||
def speak(self, text: str, check_interrupt=None) -> bool:
|
||||
"""
|
||||
Convert text to speech and play it.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize and speak
|
||||
check_interrupt: Optional callback function that returns True if playback should stop
|
||||
|
||||
Returns:
|
||||
True if playback completed normally, False if interrupted
|
||||
"""
|
||||
if not text.strip():
|
||||
return True
|
||||
|
||||
if not self.model:
|
||||
self.initialize()
|
||||
|
||||
print(f"🔊 Озвучивание: {text[:50]}...")
|
||||
|
||||
self._interrupted = False
|
||||
self._stop_flag.clear()
|
||||
|
||||
try:
|
||||
# Generate audio
|
||||
audio = self.model.apply_tts(
|
||||
text=text,
|
||||
speaker=self.speaker,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
# Convert to numpy array
|
||||
audio_np = audio.numpy()
|
||||
|
||||
if check_interrupt:
|
||||
# Play with interrupt checking in parallel thread
|
||||
return self._play_with_interrupt(audio_np, check_interrupt)
|
||||
else:
|
||||
# Standard playback
|
||||
sd.play(audio_np, self.sample_rate)
|
||||
sd.wait()
|
||||
print("✅ Воспроизведение завершено")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка TTS: {e}")
|
||||
return False
|
||||
|
||||
def _check_interrupt_worker(self, check_interrupt):
|
||||
"""
|
||||
Worker thread that continuously checks for interrupt signal.
|
||||
"""
|
||||
while not self._stop_flag.is_set():
|
||||
try:
|
||||
if check_interrupt():
|
||||
self._interrupted = True
|
||||
sd.stop()
|
||||
print("⏹️ Воспроизведение прервано!")
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _play_with_interrupt(self, audio_np: np.ndarray, check_interrupt) -> bool:
|
||||
"""
|
||||
Play audio with interrupt checking in parallel thread.
|
||||
|
||||
Args:
|
||||
audio_np: Audio data as numpy array
|
||||
check_interrupt: Callback that returns True if should interrupt
|
||||
|
||||
Returns:
|
||||
True if completed normally, False if interrupted
|
||||
"""
|
||||
# Start interrupt checker thread
|
||||
checker_thread = threading.Thread(
|
||||
target=self._check_interrupt_worker,
|
||||
args=(check_interrupt,),
|
||||
daemon=True
|
||||
)
|
||||
checker_thread.start()
|
||||
|
||||
try:
|
||||
# Play audio (non-blocking start)
|
||||
sd.play(audio_np, self.sample_rate)
|
||||
|
||||
# Wait for playback to finish or interrupt
|
||||
while sd.get_stream().active:
|
||||
if self._interrupted:
|
||||
break
|
||||
time.sleep(0.05)
|
||||
|
||||
finally:
|
||||
# Signal checker thread to stop
|
||||
self._stop_flag.set()
|
||||
checker_thread.join(timeout=0.5)
|
||||
|
||||
if self._interrupted:
|
||||
return False
|
||||
|
||||
print("✅ Воспроизведение завершено")
|
||||
return True
|
||||
|
||||
@property
|
||||
def was_interrupted(self) -> bool:
|
||||
"""Check if the last playback was interrupted."""
|
||||
return self._interrupted
|
||||
|
||||
|
||||
# Global instance
|
||||
_tts = None
|
||||
|
||||
|
||||
def get_tts() -> TextToSpeech:
|
||||
"""Get or create TTS instance."""
|
||||
global _tts
|
||||
if _tts is None:
|
||||
_tts = TextToSpeech()
|
||||
return _tts
|
||||
|
||||
|
||||
def speak(text: str, check_interrupt=None) -> bool:
|
||||
"""
|
||||
Synthesize and speak the given text.
|
||||
|
||||
Args:
|
||||
text: Text to speak
|
||||
check_interrupt: Optional callback for interrupt checking
|
||||
|
||||
Returns:
|
||||
True if completed normally, False if interrupted
|
||||
"""
|
||||
return get_tts().speak(text, check_interrupt)
|
||||
|
||||
|
||||
def was_interrupted() -> bool:
|
||||
"""Check if the last speak() call was interrupted."""
|
||||
return get_tts().was_interrupted
|
||||
|
||||
|
||||
def initialize():
|
||||
"""Pre-initialize TTS model."""
|
||||
get_tts().initialize()
|
||||
Reference in New Issue
Block a user