add async interpreter with coqui, openai, elevenlabs tts
This commit is contained in:
		
							parent
							
								
									2627fba481
								
							
						
					
					
						commit
						eee00ac026
					
				|  | @ -90,6 +90,7 @@ class Device: | ||||||
|         self.audiosegments = asyncio.Queue() |         self.audiosegments = asyncio.Queue() | ||||||
|         self.server_url = "" |         self.server_url = "" | ||||||
|         self.ctrl_pressed = False |         self.ctrl_pressed = False | ||||||
|  |         self.tts_service = "" | ||||||
|         self.playback_latency = None |         self.playback_latency = None | ||||||
| 
 | 
 | ||||||
|     def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX): |     def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX): | ||||||
|  | @ -164,30 +165,18 @@ class Device: | ||||||
|         while True: |         while True: | ||||||
|             try: |             try: | ||||||
|                 audio = await self.audiosegments.get() |                 audio = await self.audiosegments.get() | ||||||
|                 # print("got audio segment!!!!") |                 if self.playback_latency and isinstance(audio, bytes): | ||||||
|                 if self.playback_latency: |  | ||||||
|                     elapsed_time = time.time() - self.playback_latency |                     elapsed_time = time.time() - self.playback_latency | ||||||
|                     print(f"Time from request to playback: {elapsed_time} seconds") |                     # print(f"Time from request to playback: {elapsed_time} seconds") | ||||||
|                     self.playback_latency = None |                     self.playback_latency = None | ||||||
| 
 | 
 | ||||||
|                 if audio is not None: |                 if self.tts_service == "elevenlabs": | ||||||
|                     mpv_process.stdin.write(audio)  # type: ignore |                     mpv_process.stdin.write(audio)  # type: ignore | ||||||
|                     mpv_process.stdin.flush()  # type: ignore |                     mpv_process.stdin.flush()  # type: ignore | ||||||
|                 """ |                 else: | ||||||
|                 args = ["ffplay", "-autoexit", "-", "-nodisp"] |                     play(audio) | ||||||
|                 proc = subprocess.Popen( |  | ||||||
|                     args=args, |  | ||||||
|                     stdout=subprocess.PIPE, |  | ||||||
|                     stdin=subprocess.PIPE, |  | ||||||
|                     stderr=subprocess.PIPE, |  | ||||||
|                 ) |  | ||||||
|                 out, err = proc.communicate(input=audio) |  | ||||||
|                 proc.poll() |  | ||||||
| 
 | 
 | ||||||
|                 play(audio) |                 await asyncio.sleep(0.1) | ||||||
|                 """ |  | ||||||
|                 # self.audiosegments.remove(audio) |  | ||||||
|                 # await asyncio.sleep(0.1) |  | ||||||
|             except asyncio.exceptions.CancelledError: |             except asyncio.exceptions.CancelledError: | ||||||
|                 # This happens once at the start? |                 # This happens once at the start? | ||||||
|                 pass |                 pass | ||||||
|  | @ -342,24 +331,17 @@ class Device: | ||||||
| 
 | 
 | ||||||
|     async def message_sender(self, websocket): |     async def message_sender(self, websocket): | ||||||
|         while True: |         while True: | ||||||
|             try: |             message = await asyncio.get_event_loop().run_in_executor( | ||||||
|                 message = await asyncio.get_event_loop().run_in_executor( |                 None, send_queue.get | ||||||
|                     None, send_queue.get |             ) | ||||||
|                 ) |             if isinstance(message, bytes): | ||||||
| 
 |                 await websocket.send(message) | ||||||
|                 if isinstance(message, bytes): |             else: | ||||||
|                     await websocket.send(message) |                 await websocket.send(json.dumps(message)) | ||||||
| 
 |             send_queue.task_done() | ||||||
|                 else: |             await asyncio.sleep(0.01) | ||||||
|                     await websocket.send(json.dumps(message)) |  | ||||||
| 
 |  | ||||||
|                 send_queue.task_done() |  | ||||||
|                 await asyncio.sleep(0.01) |  | ||||||
|             except: |  | ||||||
|                 traceback.print_exc() |  | ||||||
| 
 | 
 | ||||||
|     async def websocket_communication(self, WS_URL): |     async def websocket_communication(self, WS_URL): | ||||||
|         print("websocket communication was called!!!!") |  | ||||||
|         show_connection_log = True |         show_connection_log = True | ||||||
| 
 | 
 | ||||||
|         async def exec_ws_communication(websocket): |         async def exec_ws_communication(websocket): | ||||||
|  | @ -373,48 +355,48 @@ class Device: | ||||||
|             asyncio.create_task(self.message_sender(websocket)) |             asyncio.create_task(self.message_sender(websocket)) | ||||||
| 
 | 
 | ||||||
|             while True: |             while True: | ||||||
|                 await asyncio.sleep(0.0001) |                 await asyncio.sleep(0.01) | ||||||
|                 chunk = await websocket.recv() |                 chunk = await websocket.recv() | ||||||
| 
 | 
 | ||||||
|                 logger.debug(f"Got this message from the server: {type(chunk)} {chunk}") |                 logger.debug(f"Got this message from the server: {type(chunk)} {chunk}") | ||||||
|                 # print((f"Got this message from the server: {type(chunk)} {chunk}")) |                 # print("received chunk from server") | ||||||
|  | 
 | ||||||
|                 if type(chunk) == str: |                 if type(chunk) == str: | ||||||
|                     chunk = json.loads(chunk) |                     chunk = json.loads(chunk) | ||||||
| 
 | 
 | ||||||
|                 # message = accumulator.accumulate(chunk) |                 if self.tts_service == "elevenlabs": | ||||||
|                 message = chunk |                     message = chunk | ||||||
|  |                 else: | ||||||
|  |                     message = accumulator.accumulate(chunk) | ||||||
|  | 
 | ||||||
|                 if message == None: |                 if message == None: | ||||||
|                     # Will be None until we have a full message ready |                     # Will be None until we have a full message ready | ||||||
|                     continue |                     continue | ||||||
| 
 | 
 | ||||||
|                 # At this point, we have our message |                 # At this point, we have our message | ||||||
|                 # print("checkpoint reached!", message) |                 if isinstance(message, bytes) or ( | ||||||
|                 if isinstance(message, bytes): |                     message["type"] == "audio" and message["format"].startswith("bytes") | ||||||
| 
 |                 ): | ||||||
|                     # if message["type"] == "audio" and message["format"].startswith("bytes"): |  | ||||||
|                     # Convert bytes to audio file |                     # Convert bytes to audio file | ||||||
|  |                     if self.tts_service == "elevenlabs": | ||||||
|  |                         audio_bytes = message | ||||||
|  |                         audio = audio_bytes | ||||||
|  |                     else: | ||||||
|  |                         audio_bytes = message["content"] | ||||||
| 
 | 
 | ||||||
|                     # audio_bytes = message["content"] |                         # Create an AudioSegment instance with the raw data | ||||||
|                     audio_bytes = message |                         audio = AudioSegment( | ||||||
|  |                             # raw audio data (bytes) | ||||||
|  |                             data=audio_bytes, | ||||||
|  |                             # signed 16-bit little-endian format | ||||||
|  |                             sample_width=2, | ||||||
|  |                             # 22,050 Hz frame rate | ||||||
|  |                             frame_rate=22050, | ||||||
|  |                             # mono sound | ||||||
|  |                             channels=1, | ||||||
|  |                         ) | ||||||
| 
 | 
 | ||||||
|                     # Create an AudioSegment instance with the raw data |                     await self.audiosegments.put(audio) | ||||||
|                     """ |  | ||||||
|                     audio = AudioSegment( |  | ||||||
|                         # raw audio data (bytes) |  | ||||||
|                         data=audio_bytes, |  | ||||||
|                         # signed 16-bit little-endian format |  | ||||||
|                         sample_width=2, |  | ||||||
|                         # 24,000 Hz frame rate |  | ||||||
|                         frame_rate=16000, |  | ||||||
|                         # mono sound |  | ||||||
|                         channels=1, |  | ||||||
|                     ) |  | ||||||
|                     """ |  | ||||||
| 
 |  | ||||||
|                     # print("audio segment was created") |  | ||||||
|                     await self.audiosegments.put(audio_bytes) |  | ||||||
| 
 |  | ||||||
|                     # await self.audiosegments.put(audio) |  | ||||||
| 
 | 
 | ||||||
|                 # Run the code if that's the client's job |                 # Run the code if that's the client's job | ||||||
|                 if os.getenv("CODE_RUNNER") == "client": |                 if os.getenv("CODE_RUNNER") == "client": | ||||||
|  | @ -434,29 +416,26 @@ class Device: | ||||||
|             except Exception as e: |             except Exception as e: | ||||||
|                 logger.error(f"Error while attempting to connect: {e}") |                 logger.error(f"Error while attempting to connect: {e}") | ||||||
|         else: |         else: | ||||||
|             print("websocket url is", WS_URL) |  | ||||||
|             while True: |             while True: | ||||||
|                 try: |                 try: | ||||||
|                     async with websockets.connect(WS_URL) as websocket: |                     async with websockets.connect(WS_URL) as websocket: | ||||||
|                         print("awaiting exec_ws_communication") |  | ||||||
|                         await exec_ws_communication(websocket) |                         await exec_ws_communication(websocket) | ||||||
|                 except: |                 except: | ||||||
|                     logger.info(traceback.format_exc()) |                     logger.debug(traceback.format_exc()) | ||||||
|                     if show_connection_log: |                     if show_connection_log: | ||||||
|                         logger.info(f"Connecting to `{WS_URL}`...") |                         logger.info(f"Connecting to `{WS_URL}`...") | ||||||
|                         show_connection_log = False |                         show_connection_log = False | ||||||
|                         await asyncio.sleep(2) |                         await asyncio.sleep(2) | ||||||
| 
 | 
 | ||||||
|     async def start_async(self): |     async def start_async(self): | ||||||
|         print("start async was called!!!!!") |  | ||||||
|         # Configuration for WebSocket |         # Configuration for WebSocket | ||||||
|         WS_URL = f"ws://{self.server_url}" |         WS_URL = f"ws://{self.server_url}" | ||||||
| 
 |  | ||||||
|         # Start the WebSocket communication |         # Start the WebSocket communication | ||||||
|         asyncio.create_task(self.websocket_communication(WS_URL)) |         asyncio.create_task(self.websocket_communication(WS_URL)) | ||||||
| 
 | 
 | ||||||
|         # Start watching the kernel if it's your job to do that |         # Start watching the kernel if it's your job to do that | ||||||
|         if os.getenv("CODE_RUNNER") == "client": |         if os.getenv("CODE_RUNNER") == "client": | ||||||
|  |             # client is not running code! | ||||||
|             asyncio.create_task(put_kernel_messages_into_queue(send_queue)) |             asyncio.create_task(put_kernel_messages_into_queue(send_queue)) | ||||||
| 
 | 
 | ||||||
|         asyncio.create_task(self.play_audiosegments()) |         asyncio.create_task(self.play_audiosegments()) | ||||||
|  | @ -488,10 +467,8 @@ class Device: | ||||||
|                 on_press=self.on_press, on_release=self.on_release |                 on_press=self.on_press, on_release=self.on_release | ||||||
|             ) |             ) | ||||||
|             listener.start() |             listener.start() | ||||||
|             print("listener for keyboard started!!!!!") |  | ||||||
| 
 | 
 | ||||||
|     def start(self): |     def start(self): | ||||||
|         print("device was started!!!!!!") |  | ||||||
|         if os.getenv("TEACH_MODE") != "True": |         if os.getenv("TEACH_MODE") != "True": | ||||||
|             asyncio.run(self.start_async()) |             asyncio.run(self.start_async()) | ||||||
|             p.terminate() |             p.terminate() | ||||||
|  |  | ||||||
|  | @ -3,8 +3,9 @@ from ..base_device import Device | ||||||
| device = Device() | device = Device() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def main(server_url): | def main(server_url, tts_service): | ||||||
|     device.server_url = server_url |     device.server_url = server_url | ||||||
|  |     device.tts_service = tts_service | ||||||
|     device.start() |     device.start() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -3,8 +3,9 @@ from ..base_device import Device | ||||||
| device = Device() | device = Device() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def main(server_url): | def main(server_url, tts_service): | ||||||
|     device.server_url = server_url |     device.server_url = server_url | ||||||
|  |     device.tts_service = tts_service | ||||||
|     device.start() |     device.start() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -3,8 +3,9 @@ from ..base_device import Device | ||||||
| device = Device() | device = Device() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def main(server_url): | def main(server_url, tts_service): | ||||||
|     device.server_url = server_url |     device.server_url = server_url | ||||||
|  |     device.tts_service = tts_service | ||||||
|     device.start() |     device.start() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -10,16 +10,9 @@ | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
| ### | ### | ||||||
| 
 |  | ||||||
| from pynput import keyboard | from pynput import keyboard | ||||||
| from RealtimeTTS import ( | 
 | ||||||
|     TextToAudioStream, | from RealtimeTTS import TextToAudioStream, CoquiEngine, OpenAIEngine, ElevenlabsEngine | ||||||
|     OpenAIEngine, |  | ||||||
|     CoquiEngine, |  | ||||||
|     ElevenlabsEngine, |  | ||||||
|     SystemEngine, |  | ||||||
|     GTTSEngine, |  | ||||||
| ) |  | ||||||
| from RealtimeSTT import AudioToTextRecorder | from RealtimeSTT import AudioToTextRecorder | ||||||
| import time | import time | ||||||
| import asyncio | import asyncio | ||||||
|  | @ -29,9 +22,9 @@ import os | ||||||
| 
 | 
 | ||||||
| class AsyncInterpreter: | class AsyncInterpreter: | ||||||
|     def __init__(self, interpreter): |     def __init__(self, interpreter): | ||||||
|         self.stt_latency = None |         # self.stt_latency = None | ||||||
|         self.tts_latency = None |         # self.tts_latency = None | ||||||
|         self.interpreter_latency = None |         # self.interpreter_latency = None | ||||||
|         self.interpreter = interpreter |         self.interpreter = interpreter | ||||||
| 
 | 
 | ||||||
|         # STT |         # STT | ||||||
|  | @ -45,12 +38,9 @@ class AsyncInterpreter: | ||||||
|             engine = CoquiEngine() |             engine = CoquiEngine() | ||||||
|         elif self.interpreter.tts == "openai": |         elif self.interpreter.tts == "openai": | ||||||
|             engine = OpenAIEngine() |             engine = OpenAIEngine() | ||||||
|         elif self.interpreter.tts == "gtts": |  | ||||||
|             engine = GTTSEngine() |  | ||||||
|         elif self.interpreter.tts == "elevenlabs": |         elif self.interpreter.tts == "elevenlabs": | ||||||
|             engine = ElevenlabsEngine(api_key=os.environ["ELEVEN_LABS_API_KEY"]) |             engine = ElevenlabsEngine(api_key=os.environ["ELEVEN_LABS_API_KEY"]) | ||||||
|         elif self.interpreter.tts == "system": |             engine.set_voice("Michael") | ||||||
|             engine = SystemEngine() |  | ||||||
|         else: |         else: | ||||||
|             raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}") |             raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}") | ||||||
|         self.tts = TextToAudioStream(engine) |         self.tts = TextToAudioStream(engine) | ||||||
|  | @ -112,111 +102,96 @@ class AsyncInterpreter: | ||||||
|         # print("ADDING TO QUEUE:", chunk) |         # print("ADDING TO QUEUE:", chunk) | ||||||
|         asyncio.create_task(self._add_to_queue(self._output_queue, chunk)) |         asyncio.create_task(self._add_to_queue(self._output_queue, chunk)) | ||||||
| 
 | 
 | ||||||
|  |     def generate(self, message, start_interpreter): | ||||||
|  |         last_lmc_start_flag = self._last_lmc_start_flag | ||||||
|  |         self.interpreter.messages = self.active_chat_messages | ||||||
|  | 
 | ||||||
|  |         # print("message is", message) | ||||||
|  | 
 | ||||||
|  |         for chunk in self.interpreter.chat(message, display=True, stream=True): | ||||||
|  | 
 | ||||||
|  |             if self._last_lmc_start_flag != last_lmc_start_flag: | ||||||
|  |                 # self.beeper.stop() | ||||||
|  |                 break | ||||||
|  | 
 | ||||||
|  |             # self.add_to_output_queue_sync(chunk) # To send text, not just audio | ||||||
|  | 
 | ||||||
|  |             content = chunk.get("content") | ||||||
|  | 
 | ||||||
|  |             # Handle message blocks | ||||||
|  |             if chunk.get("type") == "message": | ||||||
|  |                 if content: | ||||||
|  |                     # self.beeper.stop() | ||||||
|  | 
 | ||||||
|  |                     # Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer | ||||||
|  |                     # content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ") | ||||||
|  |                     # print("yielding ", content) | ||||||
|  |                     yield content | ||||||
|  | 
 | ||||||
|  |             # Handle code blocks | ||||||
|  |             elif chunk.get("type") == "code": | ||||||
|  |                 if "start" in chunk: | ||||||
|  |                     # self.beeper.start() | ||||||
|  |                     pass | ||||||
|  | 
 | ||||||
|  |                 # Experimental: If the AI wants to type, we should type immediately | ||||||
|  |                 if ( | ||||||
|  |                     self.interpreter.messages[-1] | ||||||
|  |                     .get("content", "") | ||||||
|  |                     .startswith("computer.keyboard.write(") | ||||||
|  |                 ): | ||||||
|  |                     keyboard.controller.type(content) | ||||||
|  |                     self._in_keyboard_write_block = True | ||||||
|  |                 if "end" in chunk and self._in_keyboard_write_block: | ||||||
|  |                     self._in_keyboard_write_block = False | ||||||
|  |                     # (This will make it so it doesn't type twice when the block executes) | ||||||
|  |                     if self.interpreter.messages[-1]["content"].startswith( | ||||||
|  |                         "computer.keyboard.write(" | ||||||
|  |                     ): | ||||||
|  |                         self.interpreter.messages[-1]["content"] = ( | ||||||
|  |                             "dummy_variable = (" | ||||||
|  |                             + self.interpreter.messages[-1]["content"][ | ||||||
|  |                                 len("computer.keyboard.write(") : | ||||||
|  |                             ] | ||||||
|  |                         ) | ||||||
|  | 
 | ||||||
|  |         # Send a completion signal | ||||||
|  |         # end_interpreter = time.time() | ||||||
|  |         # self.interpreter_latency = end_interpreter - start_interpreter | ||||||
|  |         # print("INTERPRETER LATENCY", self.interpreter_latency) | ||||||
|  |         # self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"}) | ||||||
|  | 
 | ||||||
|     async def run(self): |     async def run(self): | ||||||
|         """ |         """ | ||||||
|         Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue. |         Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue. | ||||||
|         """ |         """ | ||||||
|         self.interpreter.messages = self.active_chat_messages |         self.interpreter.messages = self.active_chat_messages | ||||||
| 
 | 
 | ||||||
|         # self.beeper.start() |  | ||||||
| 
 |  | ||||||
|         self.stt.stop() |         self.stt.stop() | ||||||
|         # message = self.stt.text() |  | ||||||
|         # print("THE MESSAGE:", message) |  | ||||||
| 
 | 
 | ||||||
|         # accumulates the input queue message |  | ||||||
|         input_queue = [] |         input_queue = [] | ||||||
|         while not self._input_queue.empty(): |         while not self._input_queue.empty(): | ||||||
|             input_queue.append(self._input_queue.get()) |             input_queue.append(self._input_queue.get()) | ||||||
| 
 | 
 | ||||||
|         # print("INPUT QUEUE:", input_queue) |         # start_stt = time.time() | ||||||
|         # message = [i for i in input_queue if i["type"] == "message"][0]["content"] |  | ||||||
|         start_stt = time.time() |  | ||||||
|         message = self.stt.text() |         message = self.stt.text() | ||||||
|         end_stt = time.time() |         # end_stt = time.time() | ||||||
|         self.stt_latency = end_stt - start_stt |         # self.stt_latency = end_stt - start_stt | ||||||
|         print("STT LATENCY", self.stt_latency) |         # print("STT LATENCY", self.stt_latency) | ||||||
| 
 | 
 | ||||||
|         # print(message) |         # print(message) | ||||||
|         end_interpreter = 0 |  | ||||||
| 
 |  | ||||||
|         # print(message) |  | ||||||
|         def generate(message): |  | ||||||
|             last_lmc_start_flag = self._last_lmc_start_flag |  | ||||||
|             self.interpreter.messages = self.active_chat_messages |  | ||||||
|             # print("🍀🍀🍀🍀GENERATING, using these messages: ", self.interpreter.messages) |  | ||||||
|             # print("🍀   🍀   🍀   🍀 active_chat_messages: ", self.active_chat_messages) |  | ||||||
|             print("message is", message) |  | ||||||
| 
 |  | ||||||
|             for chunk in self.interpreter.chat(message, display=True, stream=True): |  | ||||||
| 
 |  | ||||||
|                 if self._last_lmc_start_flag != last_lmc_start_flag: |  | ||||||
|                     # self.beeper.stop() |  | ||||||
|                     break |  | ||||||
| 
 |  | ||||||
|                 # self.add_to_output_queue_sync(chunk) # To send text, not just audio |  | ||||||
| 
 |  | ||||||
|                 content = chunk.get("content") |  | ||||||
| 
 |  | ||||||
|                 # Handle message blocks |  | ||||||
|                 if chunk.get("type") == "message": |  | ||||||
|                     if content: |  | ||||||
|                         # self.beeper.stop() |  | ||||||
| 
 |  | ||||||
|                         # Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer |  | ||||||
|                         # content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ") |  | ||||||
|                         # print("yielding this", content) |  | ||||||
|                         yield content |  | ||||||
| 
 |  | ||||||
|                 # Handle code blocks |  | ||||||
|                 elif chunk.get("type") == "code": |  | ||||||
|                     if "start" in chunk: |  | ||||||
|                         # self.beeper.start() |  | ||||||
|                         pass |  | ||||||
| 
 |  | ||||||
|                     # Experimental: If the AI wants to type, we should type immediately |  | ||||||
|                     if ( |  | ||||||
|                         self.interpreter.messages[-1] |  | ||||||
|                         .get("content", "") |  | ||||||
|                         .startswith("computer.keyboard.write(") |  | ||||||
|                     ): |  | ||||||
|                         keyboard.controller.type(content) |  | ||||||
|                         self._in_keyboard_write_block = True |  | ||||||
|                     if "end" in chunk and self._in_keyboard_write_block: |  | ||||||
|                         self._in_keyboard_write_block = False |  | ||||||
|                         # (This will make it so it doesn't type twice when the block executes) |  | ||||||
|                         if self.interpreter.messages[-1]["content"].startswith( |  | ||||||
|                             "computer.keyboard.write(" |  | ||||||
|                         ): |  | ||||||
|                             self.interpreter.messages[-1]["content"] = ( |  | ||||||
|                                 "dummy_variable = (" |  | ||||||
|                                 + self.interpreter.messages[-1]["content"][ |  | ||||||
|                                     len("computer.keyboard.write(") : |  | ||||||
|                                 ] |  | ||||||
|                             ) |  | ||||||
| 
 |  | ||||||
|             # Send a completion signal |  | ||||||
|             end_interpreter = time.time() |  | ||||||
|             self.interpreter_latency = end_interpreter - start_interpreter |  | ||||||
|             print("INTERPRETER LATENCY", self.interpreter_latency) |  | ||||||
|             # self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"}) |  | ||||||
| 
 | 
 | ||||||
|         # Feed generate to RealtimeTTS |         # Feed generate to RealtimeTTS | ||||||
|         self.add_to_output_queue_sync( |         self.add_to_output_queue_sync( | ||||||
|             {"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True} |             {"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True} | ||||||
|         ) |         ) | ||||||
|         start_interpreter = time.time() |         start_interpreter = time.time() | ||||||
|         text_iterator = generate(message) |         text_iterator = self.generate(message, start_interpreter) | ||||||
| 
 | 
 | ||||||
|         self.tts.feed(text_iterator) |         self.tts.feed(text_iterator) | ||||||
|  | 
 | ||||||
|         self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True) |         self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True) | ||||||
| 
 | 
 | ||||||
|         while True: |  | ||||||
|             if self.tts.is_playing(): |  | ||||||
|                 start_tts = time.time() |  | ||||||
| 
 |  | ||||||
|                 break |  | ||||||
|             await asyncio.sleep(0.1) |  | ||||||
|         while True: |         while True: | ||||||
|             await asyncio.sleep(0.1) |             await asyncio.sleep(0.1) | ||||||
|             # print("is_playing", self.tts.is_playing()) |             # print("is_playing", self.tts.is_playing()) | ||||||
|  | @ -229,14 +204,14 @@ class AsyncInterpreter: | ||||||
|                         "end": True, |                         "end": True, | ||||||
|                     } |                     } | ||||||
|                 ) |                 ) | ||||||
|                 end_tts = time.time() |                 # end_tts = time.time() | ||||||
|                 self.tts_latency = end_tts - start_tts |                 # self.tts_latency = end_tts - self.tts.stream_start_time | ||||||
|                 print("TTS LATENCY", self.tts_latency) |                 # print("TTS LATENCY", self.tts_latency) | ||||||
|                 self.tts.stop() |                 self.tts.stop() | ||||||
|                 break |                 break | ||||||
| 
 | 
 | ||||||
|     async def _on_tts_chunk_async(self, chunk): |     async def _on_tts_chunk_async(self, chunk): | ||||||
|         # print("SENDING TTS CHUNK") |         # print("adding chunk to queue") | ||||||
|         await self._add_to_queue(self._output_queue, chunk) |         await self._add_to_queue(self._output_queue, chunk) | ||||||
| 
 | 
 | ||||||
|     def on_tts_chunk(self, chunk): |     def on_tts_chunk(self, chunk): | ||||||
|  | @ -244,4 +219,5 @@ class AsyncInterpreter: | ||||||
|         asyncio.run(self._on_tts_chunk_async(chunk)) |         asyncio.run(self._on_tts_chunk_async(chunk)) | ||||||
| 
 | 
 | ||||||
|     async def output(self): |     async def output(self): | ||||||
|  |         # print("outputting chunks") | ||||||
|         return await self._output_queue.get() |         return await self._output_queue.get() | ||||||
|  |  | ||||||
|  | @ -1,43 +1,38 @@ | ||||||
| import asyncio | import asyncio | ||||||
| import traceback | import traceback | ||||||
| import json | import json | ||||||
| from fastapi import FastAPI, WebSocket, Header | from fastapi import FastAPI, WebSocket | ||||||
| from fastapi.responses import PlainTextResponse | from fastapi.responses import PlainTextResponse | ||||||
| from uvicorn import Config, Server | from uvicorn import Config, Server | ||||||
|  | from .i import configure_interpreter | ||||||
| from interpreter import interpreter as base_interpreter | from interpreter import interpreter as base_interpreter | ||||||
| from .async_interpreter import AsyncInterpreter | from .async_interpreter import AsyncInterpreter | ||||||
| from fastapi.middleware.cors import CORSMiddleware | from fastapi.middleware.cors import CORSMiddleware | ||||||
| from typing import List, Dict, Any | from typing import List, Dict, Any | ||||||
| from openai import OpenAI |  | ||||||
| from pydantic import BaseModel |  | ||||||
| import argparse |  | ||||||
| import os | import os | ||||||
| 
 | 
 | ||||||
| # import sentry_sdk |  | ||||||
| 
 |  | ||||||
| base_interpreter.system_message = ( |  | ||||||
|     "You are a helpful assistant that can answer questions and help with tasks." |  | ||||||
| ) |  | ||||||
| base_interpreter.computer.import_computer_api = False |  | ||||||
| base_interpreter.llm.model = "groq/llama3-8b-8192" |  | ||||||
| base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"] |  | ||||||
| base_interpreter.llm.supports_functions = False |  | ||||||
| base_interpreter.auto_run = True |  | ||||||
| 
 | 
 | ||||||
| os.environ["STT_RUNNER"] = "server" | os.environ["STT_RUNNER"] = "server" | ||||||
| os.environ["TTS_RUNNER"] = "server" | os.environ["TTS_RUNNER"] = "server" | ||||||
| 
 | 
 | ||||||
| # Parse command line arguments for port number |  | ||||||
| """ |  | ||||||
| parser = argparse.ArgumentParser(description="FastAPI server.") |  | ||||||
| parser.add_argument("--port", type=int, default=8000, help="Port to run on.") |  | ||||||
| args = parser.parse_args() |  | ||||||
| """ |  | ||||||
| base_interpreter.tts = "coqui" |  | ||||||
| 
 | 
 | ||||||
| 
 | async def main(server_host, server_port, tts_service, asynchronous): | ||||||
| async def main(server_host, server_port): |     if asynchronous: | ||||||
|     interpreter = AsyncInterpreter(base_interpreter) |         base_interpreter.system_message = ( | ||||||
|  |             "You are a helpful assistant that can answer questions and help with tasks." | ||||||
|  |         ) | ||||||
|  |         base_interpreter.computer.import_computer_api = False | ||||||
|  |         base_interpreter.llm.model = "groq/llama3-8b-8192" | ||||||
|  |         base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"] | ||||||
|  |         base_interpreter.llm.supports_functions = False | ||||||
|  |         base_interpreter.auto_run = True | ||||||
|  |         base_interpreter.tts = tts_service | ||||||
|  |         interpreter = AsyncInterpreter(base_interpreter) | ||||||
|  |     else: | ||||||
|  |         configured_interpreter = configure_interpreter(base_interpreter) | ||||||
|  |         configured_interpreter.llm.supports_functions = True | ||||||
|  |         configured_interpreter.tts = tts_service | ||||||
|  |         interpreter = AsyncInterpreter(configured_interpreter) | ||||||
| 
 | 
 | ||||||
|     app = FastAPI() |     app = FastAPI() | ||||||
| 
 | 
 | ||||||
|  | @ -107,37 +102,6 @@ async def main(server_host, server_port): | ||||||
|     server = Server(config) |     server = Server(config) | ||||||
|     await server.serve() |     await server.serve() | ||||||
| 
 | 
 | ||||||
|     class Rename(BaseModel): |  | ||||||
|         input: str |  | ||||||
| 
 |  | ||||||
|     @app.post("/rename-chat") |  | ||||||
|     async def rename_chat(body_content: Rename, x_api_key: str = Header(None)): |  | ||||||
|         print("RENAME CHAT REQUEST in PY 🌙🌙🌙🌙") |  | ||||||
|         input_value = body_content.input |  | ||||||
|         client = OpenAI( |  | ||||||
|             # defaults to os.environ.get("OPENAI_API_KEY") |  | ||||||
|             api_key=x_api_key, |  | ||||||
|         ) |  | ||||||
|         try: |  | ||||||
|             response = client.chat.completions.create( |  | ||||||
|                 model="gpt-3.5-turbo", |  | ||||||
|                 messages=[ |  | ||||||
|                     { |  | ||||||
|                         "role": "user", |  | ||||||
|                         "content": f"Given the following chat snippet, create a unique and descriptive title in less than 8 words. Your answer must not be related to customer service.\n\n{input_value}", |  | ||||||
|                     } |  | ||||||
|                 ], |  | ||||||
|                 temperature=0.3, |  | ||||||
|                 stream=False, |  | ||||||
|             ) |  | ||||||
|             print(response) |  | ||||||
|             completion = response["choices"][0]["message"]["content"] |  | ||||||
|             return {"data": {"content": completion}} |  | ||||||
|         except Exception as e: |  | ||||||
|             print(f"Error: {e}") |  | ||||||
|             traceback.print_exc() |  | ||||||
|             return {"error": str(e)} |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     asyncio.run(main()) |     asyncio.run(main()) | ||||||
|  |  | ||||||
|  | @ -6,6 +6,8 @@ import os | ||||||
| import importlib | import importlib | ||||||
| from source.server.tunnel import create_tunnel | from source.server.tunnel import create_tunnel | ||||||
| from source.server.async_server import main | from source.server.async_server import main | ||||||
|  | 
 | ||||||
|  | # from source.server.server import main | ||||||
| from source.server.utils.local_mode import select_local_model | from source.server.utils.local_mode import select_local_model | ||||||
| 
 | 
 | ||||||
| import signal | import signal | ||||||
|  | @ -63,7 +65,7 @@ def run( | ||||||
|         0.8, "--temperature", help="Specify the temperature for generation" |         0.8, "--temperature", help="Specify the temperature for generation" | ||||||
|     ), |     ), | ||||||
|     tts_service: str = typer.Option( |     tts_service: str = typer.Option( | ||||||
|         "openai", "--tts-service", help="Specify the TTS service" |         "elevenlabs", "--tts-service", help="Specify the TTS service" | ||||||
|     ), |     ), | ||||||
|     stt_service: str = typer.Option( |     stt_service: str = typer.Option( | ||||||
|         "openai", "--stt-service", help="Specify the STT service" |         "openai", "--stt-service", help="Specify the STT service" | ||||||
|  | @ -75,6 +77,9 @@ def run( | ||||||
|     mobile: bool = typer.Option( |     mobile: bool = typer.Option( | ||||||
|         False, "--mobile", help="Toggle server to support mobile app" |         False, "--mobile", help="Toggle server to support mobile app" | ||||||
|     ), |     ), | ||||||
|  |     asynchronous: bool = typer.Option( | ||||||
|  |         False, "--async", help="use interpreter optimized for latency" | ||||||
|  |     ), | ||||||
| ): | ): | ||||||
|     _run( |     _run( | ||||||
|         server=server or mobile, |         server=server or mobile, | ||||||
|  | @ -97,6 +102,7 @@ def run( | ||||||
|         local=local, |         local=local, | ||||||
|         qr=qr or mobile, |         qr=qr or mobile, | ||||||
|         mobile=mobile, |         mobile=mobile, | ||||||
|  |         asynchronous=asynchronous, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -116,14 +122,15 @@ def _run( | ||||||
|     context_window: int = 2048, |     context_window: int = 2048, | ||||||
|     max_tokens: int = 4096, |     max_tokens: int = 4096, | ||||||
|     temperature: float = 0.8, |     temperature: float = 0.8, | ||||||
|     tts_service: str = "openai", |     tts_service: str = "elevenlabs", | ||||||
|     stt_service: str = "openai", |     stt_service: str = "openai", | ||||||
|     local: bool = False, |     local: bool = False, | ||||||
|     qr: bool = False, |     qr: bool = False, | ||||||
|     mobile: bool = False, |     mobile: bool = False, | ||||||
|  |     asynchronous: bool = False, | ||||||
| ): | ): | ||||||
|     if local: |     if local: | ||||||
|         tts_service = "piper" |         tts_service = "coqui" | ||||||
|         # llm_service = "llamafile" |         # llm_service = "llamafile" | ||||||
|         stt_service = "local-whisper" |         stt_service = "local-whisper" | ||||||
|         select_local_model() |         select_local_model() | ||||||
|  | @ -154,6 +161,8 @@ def _run( | ||||||
|                 main( |                 main( | ||||||
|                     server_host, |                     server_host, | ||||||
|                     server_port, |                     server_port, | ||||||
|  |                     tts_service, | ||||||
|  |                     asynchronous, | ||||||
|                     # llm_service, |                     # llm_service, | ||||||
|                     # model, |                     # model, | ||||||
|                     # llm_supports_vision, |                     # llm_supports_vision, | ||||||
|  | @ -161,7 +170,6 @@ def _run( | ||||||
|                     # context_window, |                     # context_window, | ||||||
|                     # max_tokens, |                     # max_tokens, | ||||||
|                     # temperature, |                     # temperature, | ||||||
|                     # tts_service, |  | ||||||
|                     # stt_service, |                     # stt_service, | ||||||
|                     # mobile, |                     # mobile, | ||||||
|                 ), |                 ), | ||||||
|  | @ -180,7 +188,6 @@ def _run( | ||||||
|             system_type = platform.system() |             system_type = platform.system() | ||||||
|             if system_type == "Darwin":  # Mac OS |             if system_type == "Darwin":  # Mac OS | ||||||
|                 client_type = "mac" |                 client_type = "mac" | ||||||
|                 print("initiating mac device with base device!!!") |  | ||||||
|             elif system_type == "Windows":  # Windows System |             elif system_type == "Windows":  # Windows System | ||||||
|                 client_type = "windows" |                 client_type = "windows" | ||||||
|             elif system_type == "Linux":  # Linux System |             elif system_type == "Linux":  # Linux System | ||||||
|  | @ -196,9 +203,10 @@ def _run( | ||||||
|         module = importlib.import_module( |         module = importlib.import_module( | ||||||
|             f".clients.{client_type}.device", package="source" |             f".clients.{client_type}.device", package="source" | ||||||
|         ) |         ) | ||||||
|         # server_url = "0.0.0.0:8000" | 
 | ||||||
|         client_thread = threading.Thread(target=module.main, args=[server_url]) |         client_thread = threading.Thread( | ||||||
|         print("client thread started") |             target=module.main, args=[server_url, tts_service] | ||||||
|  |         ) | ||||||
|         client_thread.start() |         client_thread.start() | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 Ben Xu
						Ben Xu