add realtime tts streaming
This commit is contained in:
		
							parent
							
								
									9e04e2c5de
								
							
						
					
					
						commit
						72f7d140d4
					
				
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -33,19 +33,20 @@ python-crontab = "^3.0.0" | |||
| inquirer = "^3.2.4" | ||||
| pyqrcode = "^1.2.1" | ||||
| realtimestt = "^0.1.12" | ||||
| realtimetts = "^0.3.44" | ||||
| realtimetts = "^0.4.1" | ||||
| keyboard = "^0.13.5" | ||||
| pyautogui = "^0.9.54" | ||||
| ctranslate2 = "4.1.0" | ||||
| py3-tts = "^3.5" | ||||
| elevenlabs = "0.2.27" | ||||
| elevenlabs = "1.2.2" | ||||
| groq = "^0.5.0" | ||||
| open-interpreter = "^0.2.5" | ||||
| open-interpreter = "^0.2.6" | ||||
| litellm = "1.35.35" | ||||
| openai = "1.13.3" | ||||
| openai = "1.30.5" | ||||
| pywebview = "*" | ||||
| pyobjc = "*" | ||||
| 
 | ||||
| sentry-sdk = "^2.4.0" | ||||
| [build-system] | ||||
| requires = ["poetry-core"] | ||||
| build-backend = "poetry.core.masonry.api" | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ from dotenv import load_dotenv | |||
| 
 | ||||
| load_dotenv()  # take environment variables from .env. | ||||
| 
 | ||||
| import subprocess | ||||
| import os | ||||
| import sys | ||||
| import asyncio | ||||
|  | @ -46,7 +47,7 @@ accumulator = Accumulator() | |||
| CHUNK = 1024  # Record in chunks of 1024 samples | ||||
| FORMAT = pyaudio.paInt16  # 16 bits per sample | ||||
| CHANNELS = 1  # Mono | ||||
| RATE = 44100  # Sample rate | ||||
| RATE = 16000  # Sample rate | ||||
| RECORDING = False  # Flag to control recording state | ||||
| SPACEBAR_PRESSED = False  # Flag to track spacebar press state | ||||
| 
 | ||||
|  | @ -86,10 +87,10 @@ class Device: | |||
|     def __init__(self): | ||||
|         self.pressed_keys = set() | ||||
|         self.captured_images = [] | ||||
|         self.audiosegments = [] | ||||
|         self.audiosegments = asyncio.Queue() | ||||
|         self.server_url = "" | ||||
|         self.ctrl_pressed = False | ||||
|         # self.latency = None | ||||
|         self.playback_latency = None | ||||
| 
 | ||||
|     def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX): | ||||
|         """Captures an image from the specified camera device and saves it to a temporary file. Adds the image to the captured_images list.""" | ||||
|  | @ -153,14 +154,26 @@ class Device: | |||
|         """Plays them sequentially.""" | ||||
|         while True: | ||||
|             try: | ||||
|                 for audio in self.audiosegments: | ||||
|                     # if self.latency: | ||||
|                     #    elapsed_time = time.time() - self.latency | ||||
|                     #    print(f"Time from request to playback: {elapsed_time} seconds") | ||||
|                     #    self.latency = None | ||||
|                     play(audio) | ||||
|                     self.audiosegments.remove(audio) | ||||
|                 await asyncio.sleep(0.1) | ||||
|                 audio = await self.audiosegments.get() | ||||
|                 # print("got audio segment!!!!") | ||||
|                 if self.playback_latency: | ||||
|                     elapsed_time = time.time() - self.playback_latency | ||||
|                     print(f"Time from request to playback: {elapsed_time} seconds") | ||||
|                     self.playback_latency = None | ||||
| 
 | ||||
|                 args = ["ffplay", "-autoexit", "-", "-nodisp"] | ||||
|                 proc = subprocess.Popen( | ||||
|                     args=args, | ||||
|                     stdout=subprocess.PIPE, | ||||
|                     stdin=subprocess.PIPE, | ||||
|                     stderr=subprocess.PIPE, | ||||
|                 ) | ||||
|                 out, err = proc.communicate(input=audio) | ||||
|                 proc.poll() | ||||
| 
 | ||||
|                 # play(audio) | ||||
|                 # self.audiosegments.remove(audio) | ||||
|                 # await asyncio.sleep(0.1) | ||||
|             except asyncio.exceptions.CancelledError: | ||||
|                 # This happens once at the start? | ||||
|                 pass | ||||
|  | @ -208,7 +221,7 @@ class Device: | |||
|         stream.stop_stream() | ||||
|         stream.close() | ||||
|         print("Recording stopped.") | ||||
|         # self.latency = time.time() | ||||
|         self.playback_latency = time.time() | ||||
| 
 | ||||
|         duration = wav_file.getnframes() / RATE | ||||
|         if duration < 0.3: | ||||
|  | @ -315,6 +328,7 @@ class Device: | |||
| 
 | ||||
|     async def message_sender(self, websocket): | ||||
|         while True: | ||||
|             try: | ||||
|                 message = await asyncio.get_event_loop().run_in_executor( | ||||
|                     None, send_queue.get | ||||
|                 ) | ||||
|  | @ -327,6 +341,8 @@ class Device: | |||
| 
 | ||||
|                 send_queue.task_done() | ||||
|                 await asyncio.sleep(0.01) | ||||
|             except: | ||||
|                 traceback.print_exc() | ||||
| 
 | ||||
|     async def websocket_communication(self, WS_URL): | ||||
|         print("websocket communication was called!!!!") | ||||
|  | @ -343,7 +359,7 @@ class Device: | |||
|             asyncio.create_task(self.message_sender(websocket)) | ||||
| 
 | ||||
|             while True: | ||||
|                 await asyncio.sleep(0.01) | ||||
|                 await asyncio.sleep(0.0001) | ||||
|                 chunk = await websocket.recv() | ||||
| 
 | ||||
|                 logger.debug(f"Got this message from the server: {type(chunk)} {chunk}") | ||||
|  | @ -351,31 +367,38 @@ class Device: | |||
|                 if type(chunk) == str: | ||||
|                     chunk = json.loads(chunk) | ||||
| 
 | ||||
|                 message = accumulator.accumulate(chunk) | ||||
|                 # message = accumulator.accumulate(chunk) | ||||
|                 message = chunk | ||||
|                 if message == None: | ||||
|                     # Will be None until we have a full message ready | ||||
|                     continue | ||||
| 
 | ||||
|                 # At this point, we have our message | ||||
|                 # print("checkpoint reached!", message) | ||||
|                 if isinstance(message, bytes): | ||||
| 
 | ||||
|                 if message["type"] == "audio" and message["format"].startswith("bytes"): | ||||
|                     # if message["type"] == "audio" and message["format"].startswith("bytes"): | ||||
|                     # Convert bytes to audio file | ||||
| 
 | ||||
|                     audio_bytes = message["content"] | ||||
|                     # audio_bytes = message["content"] | ||||
|                     audio_bytes = message | ||||
| 
 | ||||
|                     # Create an AudioSegment instance with the raw data | ||||
|                     """ | ||||
|                     audio = AudioSegment( | ||||
|                         # raw audio data (bytes) | ||||
|                         data=audio_bytes, | ||||
|                         # signed 16-bit little-endian format | ||||
|                         sample_width=2, | ||||
|                         # 16,000 Hz frame rate | ||||
|                         frame_rate=16000, | ||||
|                         # 24,000 Hz frame rate | ||||
|                         frame_rate=24000, | ||||
|                         # mono sound | ||||
|                         channels=1, | ||||
|                     ) | ||||
|                     """ | ||||
| 
 | ||||
|                     self.audiosegments.append(audio) | ||||
|                     # print("audio segment was created") | ||||
|                     await self.audiosegments.put(audio_bytes) | ||||
| 
 | ||||
|                 # Run the code if that's the client's job | ||||
|                 if os.getenv("CODE_RUNNER") == "client": | ||||
|  | @ -399,6 +422,7 @@ class Device: | |||
|             while True: | ||||
|                 try: | ||||
|                     async with websockets.connect(WS_URL) as websocket: | ||||
|                         print("awaiting exec_ws_communication") | ||||
|                         await exec_ws_communication(websocket) | ||||
|                 except: | ||||
|                     logger.info(traceback.format_exc()) | ||||
|  | @ -410,7 +434,7 @@ class Device: | |||
|     async def start_async(self): | ||||
|         print("start async was called!!!!!") | ||||
|         # Configuration for WebSocket | ||||
|         WS_URL = f"ws://{self.server_url}" | ||||
|         WS_URL = f"ws://{self.server_url}/ws" | ||||
|         # Start the WebSocket communication | ||||
|         asyncio.create_task(self.websocket_communication(WS_URL)) | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,7 +12,14 @@ | |||
| ### | ||||
| 
 | ||||
| from pynput import keyboard | ||||
| from RealtimeTTS import TextToAudioStream, OpenAIEngine, CoquiEngine | ||||
| from RealtimeTTS import ( | ||||
|     TextToAudioStream, | ||||
|     OpenAIEngine, | ||||
|     CoquiEngine, | ||||
|     ElevenlabsEngine, | ||||
|     SystemEngine, | ||||
|     GTTSEngine, | ||||
| ) | ||||
| from RealtimeSTT import AudioToTextRecorder | ||||
| import time | ||||
| import asyncio | ||||
|  | @ -21,11 +28,14 @@ import json | |||
| 
 | ||||
| class AsyncInterpreter: | ||||
|     def __init__(self, interpreter): | ||||
|         self.stt_latency = None | ||||
|         self.tts_latency = None | ||||
|         self.interpreter_latency = None | ||||
|         self.interpreter = interpreter | ||||
| 
 | ||||
|         # STT | ||||
|         self.stt = AudioToTextRecorder( | ||||
|             model="tiny", spinner=False, use_microphone=False | ||||
|             model="tiny.en", spinner=False, use_microphone=False | ||||
|         ) | ||||
|         self.stt.stop()  # It needs this for some reason | ||||
| 
 | ||||
|  | @ -34,6 +44,16 @@ class AsyncInterpreter: | |||
|             engine = CoquiEngine() | ||||
|         elif self.interpreter.tts == "openai": | ||||
|             engine = OpenAIEngine() | ||||
|         elif self.interpreter.tts == "gtts": | ||||
|             engine = GTTSEngine() | ||||
|         elif self.interpreter.tts == "elevenlabs": | ||||
|             engine = ElevenlabsEngine( | ||||
|                 api_key="sk_077cb1cabdf67e62b85f8782e66e5d8e11f78b450c7ce171" | ||||
|             ) | ||||
|         elif self.interpreter.tts == "system": | ||||
|             engine = SystemEngine() | ||||
|         else: | ||||
|             raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}") | ||||
|         self.tts = TextToAudioStream(engine) | ||||
| 
 | ||||
|         self.active_chat_messages = [] | ||||
|  | @ -112,7 +132,11 @@ class AsyncInterpreter: | |||
| 
 | ||||
|         # print("INPUT QUEUE:", input_queue) | ||||
|         # message = [i for i in input_queue if i["type"] == "message"][0]["content"] | ||||
|         start_stt = time.time() | ||||
|         message = self.stt.text() | ||||
|         end_stt = time.time() | ||||
|         self.stt_latency = end_stt - start_stt | ||||
|         print("STT LATENCY", self.stt_latency) | ||||
| 
 | ||||
|         # print(message) | ||||
| 
 | ||||
|  | @ -141,7 +165,7 @@ class AsyncInterpreter: | |||
| 
 | ||||
|                         # Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer | ||||
|                         # content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ") | ||||
| 
 | ||||
|                         print("yielding this", content) | ||||
|                         yield content | ||||
| 
 | ||||
|                 # Handle code blocks | ||||
|  | @ -172,17 +196,24 @@ class AsyncInterpreter: | |||
|                             ) | ||||
| 
 | ||||
|             # Send a completion signal | ||||
| 
 | ||||
|             end_interpreter = time.time() | ||||
|             self.interpreter_latency = end_interpreter - start_interpreter | ||||
|             print("INTERPRETER LATENCY", self.interpreter_latency) | ||||
|             # self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"}) | ||||
| 
 | ||||
|         # Feed generate to RealtimeTTS | ||||
|         self.add_to_output_queue_sync( | ||||
|             {"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True} | ||||
|         ) | ||||
|         self.tts.feed(generate(message)) | ||||
|         start_interpreter = time.time() | ||||
|         text_iterator = generate(message) | ||||
| 
 | ||||
|         self.tts.feed(text_iterator) | ||||
|         self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True) | ||||
| 
 | ||||
|         while True: | ||||
|             if self.tts.is_playing(): | ||||
|                 start_tts = time.time() | ||||
|                 break | ||||
|             await asyncio.sleep(0.1) | ||||
|         while True: | ||||
|  | @ -197,6 +228,9 @@ class AsyncInterpreter: | |||
|                         "end": True, | ||||
|                     } | ||||
|                 ) | ||||
|                 end_tts = time.time() | ||||
|                 self.tts_latency = end_tts - start_tts | ||||
|                 print("TTS LATENCY", self.tts_latency) | ||||
|                 break | ||||
| 
 | ||||
|     async def _on_tts_chunk_async(self, chunk): | ||||
|  | @ -204,6 +238,7 @@ class AsyncInterpreter: | |||
|         await self._add_to_queue(self._output_queue, chunk) | ||||
| 
 | ||||
|     def on_tts_chunk(self, chunk): | ||||
|         # print("ye") | ||||
|         asyncio.run(self._on_tts_chunk_async(chunk)) | ||||
| 
 | ||||
|     async def output(self): | ||||
|  |  | |||
|  | @ -12,6 +12,18 @@ from pydantic import BaseModel | |||
| import argparse | ||||
| import os | ||||
| 
 | ||||
| # import sentry_sdk | ||||
| 
 | ||||
| base_interpreter.system_message = ( | ||||
|     "You are a helpful assistant that can answer questions and help with tasks." | ||||
| ) | ||||
| base_interpreter.computer.import_computer_api = False | ||||
| base_interpreter.llm.model = "groq/mixtral-8x7b-32768" | ||||
| base_interpreter.llm.api_key = ( | ||||
|     "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" | ||||
| ) | ||||
| base_interpreter.llm.supports_functions = False | ||||
| 
 | ||||
| os.environ["STT_RUNNER"] = "server" | ||||
| os.environ["TTS_RUNNER"] = "server" | ||||
| 
 | ||||
|  | @ -20,11 +32,24 @@ parser = argparse.ArgumentParser(description="FastAPI server.") | |||
| parser.add_argument("--port", type=int, default=8000, help="Port to run on.") | ||||
| args = parser.parse_args() | ||||
| 
 | ||||
| base_interpreter.tts = "openai" | ||||
| base_interpreter.llm.model = "gpt-4-turbo" | ||||
| base_interpreter.tts = "elevenlabs" | ||||
| 
 | ||||
| 
 | ||||
| async def main(): | ||||
|     """ | ||||
|     sentry_sdk.init( | ||||
|         dsn="https://a1465f62a31c7dfb23e1616da86341e9@o4506046614667264.ingest.us.sentry.io/4507374662385664", | ||||
|         enable_tracing=True, | ||||
|         # Set traces_sample_rate to 1.0 to capture 100% | ||||
|         # of transactions for performance monitoring. | ||||
|         traces_sample_rate=1.0, | ||||
|         # Set profiles_sample_rate to 1.0 to profile 100% | ||||
|         # of sampled transactions. | ||||
|         # We recommend adjusting this value in production. | ||||
|         profiles_sample_rate=1.0, | ||||
|     ) | ||||
|     """ | ||||
| 
 | ||||
|     interpreter = AsyncInterpreter(base_interpreter) | ||||
| 
 | ||||
|     app = FastAPI() | ||||
|  | @ -51,6 +76,9 @@ async def main(): | |||
| 
 | ||||
|             async def receive_input(): | ||||
|                 while True: | ||||
|                     if websocket.client_state == "DISCONNECTED": | ||||
|                         break | ||||
| 
 | ||||
|                     data = await websocket.receive() | ||||
| 
 | ||||
|                     if isinstance(data, bytes): | ||||
|  | @ -65,18 +93,22 @@ async def main(): | |||
|             async def send_output(): | ||||
|                 while True: | ||||
|                     output = await interpreter.output() | ||||
| 
 | ||||
|                     if isinstance(output, bytes): | ||||
|                         # print(f"Sending {len(output)} bytes of audio data.") | ||||
|                         await websocket.send_bytes(output) | ||||
|                         # we dont send out bytes rn, no TTS | ||||
|                         pass | ||||
| 
 | ||||
|                     elif isinstance(output, dict): | ||||
|                         # print("sending text") | ||||
|                         await websocket.send_text(json.dumps(output)) | ||||
| 
 | ||||
|             await asyncio.gather(receive_input(), send_output()) | ||||
|             await asyncio.gather(send_output(), receive_input()) | ||||
|         except Exception as e: | ||||
|             print(f"WebSocket connection closed with exception: {e}") | ||||
|             traceback.print_exc() | ||||
|         finally: | ||||
|             if not websocket.client_state == "DISCONNECTED": | ||||
|                 await websocket.close() | ||||
| 
 | ||||
|     config = Config(app, host="0.0.0.0", port=8000, lifespan="on") | ||||
|  |  | |||
|  | @ -23,6 +23,7 @@ from .utils.logs import logger | |||
| import base64 | ||||
| import shutil | ||||
| from ..utils.print_markdown import print_markdown | ||||
| import time | ||||
| 
 | ||||
| os.environ["STT_RUNNER"] = "server" | ||||
| os.environ["TTS_RUNNER"] = "server" | ||||
|  | @ -383,6 +384,7 @@ async def stream_tts_to_device(sentence, mobile: bool): | |||
| 
 | ||||
| 
 | ||||
| def stream_tts(sentence, mobile: bool): | ||||
| 
 | ||||
|     audio_file = tts(sentence, mobile) | ||||
| 
 | ||||
|     # Read the entire WAV file | ||||
|  |  | |||
|  | @ -5,7 +5,7 @@ import threading | |||
| import os | ||||
| import importlib | ||||
| from source.server.tunnel import create_tunnel | ||||
| from source.server.server import main | ||||
| from source.server.async_server import main | ||||
| from source.server.utils.local_mode import select_local_model | ||||
| 
 | ||||
| import signal | ||||
|  | @ -152,18 +152,18 @@ def _run( | |||
|             target=loop.run_until_complete, | ||||
|             args=( | ||||
|                 main( | ||||
|                     server_host, | ||||
|                     server_port, | ||||
|                     llm_service, | ||||
|                     model, | ||||
|                     llm_supports_vision, | ||||
|                     llm_supports_functions, | ||||
|                     context_window, | ||||
|                     max_tokens, | ||||
|                     temperature, | ||||
|                     tts_service, | ||||
|                     stt_service, | ||||
|                     mobile, | ||||
|                     # server_host, | ||||
|                     # server_port, | ||||
|                     # llm_service, | ||||
|                     # model, | ||||
|                     # llm_supports_vision, | ||||
|                     # llm_supports_functions, | ||||
|                     # context_window, | ||||
|                     # max_tokens, | ||||
|                     # temperature, | ||||
|                     # tts_service, | ||||
|                     # stt_service, | ||||
|                     # mobile, | ||||
|                 ), | ||||
|             ), | ||||
|         ) | ||||
|  | @ -196,7 +196,7 @@ def _run( | |||
|         module = importlib.import_module( | ||||
|             f".clients.{client_type}.device", package="source" | ||||
|         ) | ||||
| 
 | ||||
|         server_url = "0.0.0.0:8000" | ||||
|         client_thread = threading.Thread(target=module.main, args=[server_url]) | ||||
|         print("client thread started") | ||||
|         client_thread.start() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Ben Xu
						Ben Xu