add realtime tts streaming
This commit is contained in:
parent
9e04e2c5de
commit
72f7d140d4
File diff suppressed because it is too large
Load Diff
|
@ -33,19 +33,20 @@ python-crontab = "^3.0.0"
|
|||
inquirer = "^3.2.4"
|
||||
pyqrcode = "^1.2.1"
|
||||
realtimestt = "^0.1.12"
|
||||
realtimetts = "^0.3.44"
|
||||
realtimetts = "^0.4.1"
|
||||
keyboard = "^0.13.5"
|
||||
pyautogui = "^0.9.54"
|
||||
ctranslate2 = "4.1.0"
|
||||
py3-tts = "^3.5"
|
||||
elevenlabs = "0.2.27"
|
||||
elevenlabs = "1.2.2"
|
||||
groq = "^0.5.0"
|
||||
open-interpreter = "^0.2.5"
|
||||
open-interpreter = "^0.2.6"
|
||||
litellm = "1.35.35"
|
||||
openai = "1.13.3"
|
||||
openai = "1.30.5"
|
||||
pywebview = "*"
|
||||
pyobjc = "*"
|
||||
|
||||
sentry-sdk = "^2.4.0"
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
|
|
@ -2,6 +2,7 @@ from dotenv import load_dotenv
|
|||
|
||||
load_dotenv() # take environment variables from .env.
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
|
@ -46,7 +47,7 @@ accumulator = Accumulator()
|
|||
CHUNK = 1024 # Record in chunks of 1024 samples
|
||||
FORMAT = pyaudio.paInt16 # 16 bits per sample
|
||||
CHANNELS = 1 # Mono
|
||||
RATE = 44100 # Sample rate
|
||||
RATE = 16000 # Sample rate
|
||||
RECORDING = False # Flag to control recording state
|
||||
SPACEBAR_PRESSED = False # Flag to track spacebar press state
|
||||
|
||||
|
@ -86,10 +87,10 @@ class Device:
|
|||
def __init__(self):
|
||||
self.pressed_keys = set()
|
||||
self.captured_images = []
|
||||
self.audiosegments = []
|
||||
self.audiosegments = asyncio.Queue()
|
||||
self.server_url = ""
|
||||
self.ctrl_pressed = False
|
||||
# self.latency = None
|
||||
self.playback_latency = None
|
||||
|
||||
def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
|
||||
"""Captures an image from the specified camera device and saves it to a temporary file. Adds the image to the captured_images list."""
|
||||
|
@ -153,14 +154,26 @@ class Device:
|
|||
"""Plays them sequentially."""
|
||||
while True:
|
||||
try:
|
||||
for audio in self.audiosegments:
|
||||
# if self.latency:
|
||||
# elapsed_time = time.time() - self.latency
|
||||
# print(f"Time from request to playback: {elapsed_time} seconds")
|
||||
# self.latency = None
|
||||
play(audio)
|
||||
self.audiosegments.remove(audio)
|
||||
await asyncio.sleep(0.1)
|
||||
audio = await self.audiosegments.get()
|
||||
# print("got audio segment!!!!")
|
||||
if self.playback_latency:
|
||||
elapsed_time = time.time() - self.playback_latency
|
||||
print(f"Time from request to playback: {elapsed_time} seconds")
|
||||
self.playback_latency = None
|
||||
|
||||
args = ["ffplay", "-autoexit", "-", "-nodisp"]
|
||||
proc = subprocess.Popen(
|
||||
args=args,
|
||||
stdout=subprocess.PIPE,
|
||||
stdin=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
out, err = proc.communicate(input=audio)
|
||||
proc.poll()
|
||||
|
||||
# play(audio)
|
||||
# self.audiosegments.remove(audio)
|
||||
# await asyncio.sleep(0.1)
|
||||
except asyncio.exceptions.CancelledError:
|
||||
# This happens once at the start?
|
||||
pass
|
||||
|
@ -208,7 +221,7 @@ class Device:
|
|||
stream.stop_stream()
|
||||
stream.close()
|
||||
print("Recording stopped.")
|
||||
# self.latency = time.time()
|
||||
self.playback_latency = time.time()
|
||||
|
||||
duration = wav_file.getnframes() / RATE
|
||||
if duration < 0.3:
|
||||
|
@ -315,18 +328,21 @@ class Device:
|
|||
|
||||
async def message_sender(self, websocket):
|
||||
while True:
|
||||
message = await asyncio.get_event_loop().run_in_executor(
|
||||
None, send_queue.get
|
||||
)
|
||||
try:
|
||||
message = await asyncio.get_event_loop().run_in_executor(
|
||||
None, send_queue.get
|
||||
)
|
||||
|
||||
if isinstance(message, bytes):
|
||||
await websocket.send(message)
|
||||
if isinstance(message, bytes):
|
||||
await websocket.send(message)
|
||||
|
||||
else:
|
||||
await websocket.send(json.dumps(message))
|
||||
else:
|
||||
await websocket.send(json.dumps(message))
|
||||
|
||||
send_queue.task_done()
|
||||
await asyncio.sleep(0.01)
|
||||
send_queue.task_done()
|
||||
await asyncio.sleep(0.01)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
async def websocket_communication(self, WS_URL):
|
||||
print("websocket communication was called!!!!")
|
||||
|
@ -343,7 +359,7 @@ class Device:
|
|||
asyncio.create_task(self.message_sender(websocket))
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(0.01)
|
||||
await asyncio.sleep(0.0001)
|
||||
chunk = await websocket.recv()
|
||||
|
||||
logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
|
||||
|
@ -351,31 +367,38 @@ class Device:
|
|||
if type(chunk) == str:
|
||||
chunk = json.loads(chunk)
|
||||
|
||||
message = accumulator.accumulate(chunk)
|
||||
# message = accumulator.accumulate(chunk)
|
||||
message = chunk
|
||||
if message == None:
|
||||
# Will be None until we have a full message ready
|
||||
continue
|
||||
|
||||
# At this point, we have our message
|
||||
# print("checkpoint reached!", message)
|
||||
if isinstance(message, bytes):
|
||||
|
||||
if message["type"] == "audio" and message["format"].startswith("bytes"):
|
||||
# if message["type"] == "audio" and message["format"].startswith("bytes"):
|
||||
# Convert bytes to audio file
|
||||
|
||||
audio_bytes = message["content"]
|
||||
# audio_bytes = message["content"]
|
||||
audio_bytes = message
|
||||
|
||||
# Create an AudioSegment instance with the raw data
|
||||
"""
|
||||
audio = AudioSegment(
|
||||
# raw audio data (bytes)
|
||||
data=audio_bytes,
|
||||
# signed 16-bit little-endian format
|
||||
sample_width=2,
|
||||
# 16,000 Hz frame rate
|
||||
frame_rate=16000,
|
||||
# 24,000 Hz frame rate
|
||||
frame_rate=24000,
|
||||
# mono sound
|
||||
channels=1,
|
||||
)
|
||||
"""
|
||||
|
||||
self.audiosegments.append(audio)
|
||||
# print("audio segment was created")
|
||||
await self.audiosegments.put(audio_bytes)
|
||||
|
||||
# Run the code if that's the client's job
|
||||
if os.getenv("CODE_RUNNER") == "client":
|
||||
|
@ -399,6 +422,7 @@ class Device:
|
|||
while True:
|
||||
try:
|
||||
async with websockets.connect(WS_URL) as websocket:
|
||||
print("awaiting exec_ws_communication")
|
||||
await exec_ws_communication(websocket)
|
||||
except:
|
||||
logger.info(traceback.format_exc())
|
||||
|
@ -410,7 +434,7 @@ class Device:
|
|||
async def start_async(self):
|
||||
print("start async was called!!!!!")
|
||||
# Configuration for WebSocket
|
||||
WS_URL = f"ws://{self.server_url}"
|
||||
WS_URL = f"ws://{self.server_url}/ws"
|
||||
# Start the WebSocket communication
|
||||
asyncio.create_task(self.websocket_communication(WS_URL))
|
||||
|
||||
|
|
|
@ -12,7 +12,14 @@
|
|||
###
|
||||
|
||||
from pynput import keyboard
|
||||
from RealtimeTTS import TextToAudioStream, OpenAIEngine, CoquiEngine
|
||||
from RealtimeTTS import (
|
||||
TextToAudioStream,
|
||||
OpenAIEngine,
|
||||
CoquiEngine,
|
||||
ElevenlabsEngine,
|
||||
SystemEngine,
|
||||
GTTSEngine,
|
||||
)
|
||||
from RealtimeSTT import AudioToTextRecorder
|
||||
import time
|
||||
import asyncio
|
||||
|
@ -21,11 +28,14 @@ import json
|
|||
|
||||
class AsyncInterpreter:
|
||||
def __init__(self, interpreter):
|
||||
self.stt_latency = None
|
||||
self.tts_latency = None
|
||||
self.interpreter_latency = None
|
||||
self.interpreter = interpreter
|
||||
|
||||
# STT
|
||||
self.stt = AudioToTextRecorder(
|
||||
model="tiny", spinner=False, use_microphone=False
|
||||
model="tiny.en", spinner=False, use_microphone=False
|
||||
)
|
||||
self.stt.stop() # It needs this for some reason
|
||||
|
||||
|
@ -34,6 +44,16 @@ class AsyncInterpreter:
|
|||
engine = CoquiEngine()
|
||||
elif self.interpreter.tts == "openai":
|
||||
engine = OpenAIEngine()
|
||||
elif self.interpreter.tts == "gtts":
|
||||
engine = GTTSEngine()
|
||||
elif self.interpreter.tts == "elevenlabs":
|
||||
engine = ElevenlabsEngine(
|
||||
api_key="sk_077cb1cabdf67e62b85f8782e66e5d8e11f78b450c7ce171"
|
||||
)
|
||||
elif self.interpreter.tts == "system":
|
||||
engine = SystemEngine()
|
||||
else:
|
||||
raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}")
|
||||
self.tts = TextToAudioStream(engine)
|
||||
|
||||
self.active_chat_messages = []
|
||||
|
@ -112,7 +132,11 @@ class AsyncInterpreter:
|
|||
|
||||
# print("INPUT QUEUE:", input_queue)
|
||||
# message = [i for i in input_queue if i["type"] == "message"][0]["content"]
|
||||
start_stt = time.time()
|
||||
message = self.stt.text()
|
||||
end_stt = time.time()
|
||||
self.stt_latency = end_stt - start_stt
|
||||
print("STT LATENCY", self.stt_latency)
|
||||
|
||||
# print(message)
|
||||
|
||||
|
@ -141,7 +165,7 @@ class AsyncInterpreter:
|
|||
|
||||
# Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer
|
||||
# content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ")
|
||||
|
||||
print("yielding this", content)
|
||||
yield content
|
||||
|
||||
# Handle code blocks
|
||||
|
@ -172,17 +196,24 @@ class AsyncInterpreter:
|
|||
)
|
||||
|
||||
# Send a completion signal
|
||||
|
||||
end_interpreter = time.time()
|
||||
self.interpreter_latency = end_interpreter - start_interpreter
|
||||
print("INTERPRETER LATENCY", self.interpreter_latency)
|
||||
# self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"})
|
||||
|
||||
# Feed generate to RealtimeTTS
|
||||
self.add_to_output_queue_sync(
|
||||
{"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True}
|
||||
)
|
||||
self.tts.feed(generate(message))
|
||||
start_interpreter = time.time()
|
||||
text_iterator = generate(message)
|
||||
|
||||
self.tts.feed(text_iterator)
|
||||
self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True)
|
||||
|
||||
while True:
|
||||
if self.tts.is_playing():
|
||||
start_tts = time.time()
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
while True:
|
||||
|
@ -197,6 +228,9 @@ class AsyncInterpreter:
|
|||
"end": True,
|
||||
}
|
||||
)
|
||||
end_tts = time.time()
|
||||
self.tts_latency = end_tts - start_tts
|
||||
print("TTS LATENCY", self.tts_latency)
|
||||
break
|
||||
|
||||
async def _on_tts_chunk_async(self, chunk):
|
||||
|
@ -204,6 +238,7 @@ class AsyncInterpreter:
|
|||
await self._add_to_queue(self._output_queue, chunk)
|
||||
|
||||
def on_tts_chunk(self, chunk):
|
||||
# print("ye")
|
||||
asyncio.run(self._on_tts_chunk_async(chunk))
|
||||
|
||||
async def output(self):
|
||||
|
|
|
@ -12,6 +12,18 @@ from pydantic import BaseModel
|
|||
import argparse
|
||||
import os
|
||||
|
||||
# import sentry_sdk
|
||||
|
||||
base_interpreter.system_message = (
|
||||
"You are a helpful assistant that can answer questions and help with tasks."
|
||||
)
|
||||
base_interpreter.computer.import_computer_api = False
|
||||
base_interpreter.llm.model = "groq/mixtral-8x7b-32768"
|
||||
base_interpreter.llm.api_key = (
|
||||
"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"
|
||||
)
|
||||
base_interpreter.llm.supports_functions = False
|
||||
|
||||
os.environ["STT_RUNNER"] = "server"
|
||||
os.environ["TTS_RUNNER"] = "server"
|
||||
|
||||
|
@ -20,11 +32,24 @@ parser = argparse.ArgumentParser(description="FastAPI server.")
|
|||
parser.add_argument("--port", type=int, default=8000, help="Port to run on.")
|
||||
args = parser.parse_args()
|
||||
|
||||
base_interpreter.tts = "openai"
|
||||
base_interpreter.llm.model = "gpt-4-turbo"
|
||||
base_interpreter.tts = "elevenlabs"
|
||||
|
||||
|
||||
async def main():
|
||||
"""
|
||||
sentry_sdk.init(
|
||||
dsn="https://a1465f62a31c7dfb23e1616da86341e9@o4506046614667264.ingest.us.sentry.io/4507374662385664",
|
||||
enable_tracing=True,
|
||||
# Set traces_sample_rate to 1.0 to capture 100%
|
||||
# of transactions for performance monitoring.
|
||||
traces_sample_rate=1.0,
|
||||
# Set profiles_sample_rate to 1.0 to profile 100%
|
||||
# of sampled transactions.
|
||||
# We recommend adjusting this value in production.
|
||||
profiles_sample_rate=1.0,
|
||||
)
|
||||
"""
|
||||
|
||||
interpreter = AsyncInterpreter(base_interpreter)
|
||||
|
||||
app = FastAPI()
|
||||
|
@ -51,6 +76,9 @@ async def main():
|
|||
|
||||
async def receive_input():
|
||||
while True:
|
||||
if websocket.client_state == "DISCONNECTED":
|
||||
break
|
||||
|
||||
data = await websocket.receive()
|
||||
|
||||
if isinstance(data, bytes):
|
||||
|
@ -65,19 +93,23 @@ async def main():
|
|||
async def send_output():
|
||||
while True:
|
||||
output = await interpreter.output()
|
||||
|
||||
if isinstance(output, bytes):
|
||||
# print(f"Sending {len(output)} bytes of audio data.")
|
||||
await websocket.send_bytes(output)
|
||||
# we dont send out bytes rn, no TTS
|
||||
pass
|
||||
|
||||
elif isinstance(output, dict):
|
||||
# print("sending text")
|
||||
await websocket.send_text(json.dumps(output))
|
||||
|
||||
await asyncio.gather(receive_input(), send_output())
|
||||
await asyncio.gather(send_output(), receive_input())
|
||||
except Exception as e:
|
||||
print(f"WebSocket connection closed with exception: {e}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
await websocket.close()
|
||||
if not websocket.client_state == "DISCONNECTED":
|
||||
await websocket.close()
|
||||
|
||||
config = Config(app, host="0.0.0.0", port=8000, lifespan="on")
|
||||
server = Server(config)
|
||||
|
|
|
@ -23,6 +23,7 @@ from .utils.logs import logger
|
|||
import base64
|
||||
import shutil
|
||||
from ..utils.print_markdown import print_markdown
|
||||
import time
|
||||
|
||||
os.environ["STT_RUNNER"] = "server"
|
||||
os.environ["TTS_RUNNER"] = "server"
|
||||
|
@ -383,6 +384,7 @@ async def stream_tts_to_device(sentence, mobile: bool):
|
|||
|
||||
|
||||
def stream_tts(sentence, mobile: bool):
|
||||
|
||||
audio_file = tts(sentence, mobile)
|
||||
|
||||
# Read the entire WAV file
|
||||
|
|
|
@ -5,7 +5,7 @@ import threading
|
|||
import os
|
||||
import importlib
|
||||
from source.server.tunnel import create_tunnel
|
||||
from source.server.server import main
|
||||
from source.server.async_server import main
|
||||
from source.server.utils.local_mode import select_local_model
|
||||
|
||||
import signal
|
||||
|
@ -152,18 +152,18 @@ def _run(
|
|||
target=loop.run_until_complete,
|
||||
args=(
|
||||
main(
|
||||
server_host,
|
||||
server_port,
|
||||
llm_service,
|
||||
model,
|
||||
llm_supports_vision,
|
||||
llm_supports_functions,
|
||||
context_window,
|
||||
max_tokens,
|
||||
temperature,
|
||||
tts_service,
|
||||
stt_service,
|
||||
mobile,
|
||||
# server_host,
|
||||
# server_port,
|
||||
# llm_service,
|
||||
# model,
|
||||
# llm_supports_vision,
|
||||
# llm_supports_functions,
|
||||
# context_window,
|
||||
# max_tokens,
|
||||
# temperature,
|
||||
# tts_service,
|
||||
# stt_service,
|
||||
# mobile,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -196,7 +196,7 @@ def _run(
|
|||
module = importlib.import_module(
|
||||
f".clients.{client_type}.device", package="source"
|
||||
)
|
||||
|
||||
server_url = "0.0.0.0:8000"
|
||||
client_thread = threading.Thread(target=module.main, args=[server_url])
|
||||
print("client thread started")
|
||||
client_thread.start()
|
||||
|
|
Loading…
Reference in New Issue