stash server changes
This commit is contained in:
		
							parent
							
								
									2627fba481
								
							
						
					
					
						commit
						4b25239d0f
					
				|  | @ -2,6 +2,7 @@ from dotenv import load_dotenv | ||||||
| 
 | 
 | ||||||
| load_dotenv()  # take environment variables from .env. | load_dotenv()  # take environment variables from .env. | ||||||
| 
 | 
 | ||||||
|  | import requests | ||||||
| import subprocess | import subprocess | ||||||
| import os | import os | ||||||
| import sys | import sys | ||||||
|  | @ -12,6 +13,7 @@ from pynput import keyboard | ||||||
| import json | import json | ||||||
| import traceback | import traceback | ||||||
| import websockets | import websockets | ||||||
|  | import websockets.sync.client | ||||||
| import queue | import queue | ||||||
| from pydub import AudioSegment | from pydub import AudioSegment | ||||||
| from pydub.playback import play | from pydub.playback import play | ||||||
|  | @ -169,11 +171,11 @@ class Device: | ||||||
|                     elapsed_time = time.time() - self.playback_latency |                     elapsed_time = time.time() - self.playback_latency | ||||||
|                     print(f"Time from request to playback: {elapsed_time} seconds") |                     print(f"Time from request to playback: {elapsed_time} seconds") | ||||||
|                     self.playback_latency = None |                     self.playback_latency = None | ||||||
| 
 |                 """ | ||||||
|                 if audio is not None: |                 if audio is not None: | ||||||
|                     mpv_process.stdin.write(audio)  # type: ignore |                     mpv_process.stdin.write(audio)  # type: ignore | ||||||
|                     mpv_process.stdin.flush()  # type: ignore |                     mpv_process.stdin.flush()  # type: ignore | ||||||
|                 """ | 
 | ||||||
|                 args = ["ffplay", "-autoexit", "-", "-nodisp"] |                 args = ["ffplay", "-autoexit", "-", "-nodisp"] | ||||||
|                 proc = subprocess.Popen( |                 proc = subprocess.Popen( | ||||||
|                     args=args, |                     args=args, | ||||||
|  | @ -183,9 +185,8 @@ class Device: | ||||||
|                 ) |                 ) | ||||||
|                 out, err = proc.communicate(input=audio) |                 out, err = proc.communicate(input=audio) | ||||||
|                 proc.poll() |                 proc.poll() | ||||||
| 
 |  | ||||||
|                 play(audio) |  | ||||||
|                 """ |                 """ | ||||||
|  |                 play(audio) | ||||||
|                 # self.audiosegments.remove(audio) |                 # self.audiosegments.remove(audio) | ||||||
|                 # await asyncio.sleep(0.1) |                 # await asyncio.sleep(0.1) | ||||||
|             except asyncio.exceptions.CancelledError: |             except asyncio.exceptions.CancelledError: | ||||||
|  | @ -361,7 +362,7 @@ class Device: | ||||||
|     async def websocket_communication(self, WS_URL): |     async def websocket_communication(self, WS_URL): | ||||||
|         print("websocket communication was called!!!!") |         print("websocket communication was called!!!!") | ||||||
|         show_connection_log = True |         show_connection_log = True | ||||||
| 
 |         """ | ||||||
|         async def exec_ws_communication(websocket): |         async def exec_ws_communication(websocket): | ||||||
|             if CAMERA_ENABLED: |             if CAMERA_ENABLED: | ||||||
|                 print( |                 print( | ||||||
|  | @ -373,11 +374,11 @@ class Device: | ||||||
|             asyncio.create_task(self.message_sender(websocket)) |             asyncio.create_task(self.message_sender(websocket)) | ||||||
| 
 | 
 | ||||||
|             while True: |             while True: | ||||||
|                 await asyncio.sleep(0.0001) |                 await asyncio.sleep(0) | ||||||
|                 chunk = await websocket.recv() |                 chunk = await websocket.recv() | ||||||
| 
 | 
 | ||||||
|                 logger.debug(f"Got this message from the server: {type(chunk)} {chunk}") |                 #logger.debug(f"Got this message from the server: {type(chunk)} {chunk}") | ||||||
|                 # print((f"Got this message from the server: {type(chunk)} {chunk}")) |                 print((f"Got this message from the server: {type(chunk)}")) | ||||||
|                 if type(chunk) == str: |                 if type(chunk) == str: | ||||||
|                     chunk = json.loads(chunk) |                     chunk = json.loads(chunk) | ||||||
| 
 | 
 | ||||||
|  | @ -388,7 +389,7 @@ class Device: | ||||||
|                     continue |                     continue | ||||||
| 
 | 
 | ||||||
|                 # At this point, we have our message |                 # At this point, we have our message | ||||||
|                 # print("checkpoint reached!", message) |                 print("checkpoint reached!") | ||||||
|                 if isinstance(message, bytes): |                 if isinstance(message, bytes): | ||||||
| 
 | 
 | ||||||
|                     # if message["type"] == "audio" and message["format"].startswith("bytes"): |                     # if message["type"] == "audio" and message["format"].startswith("bytes"): | ||||||
|  | @ -398,23 +399,23 @@ class Device: | ||||||
|                     audio_bytes = message |                     audio_bytes = message | ||||||
| 
 | 
 | ||||||
|                     # Create an AudioSegment instance with the raw data |                     # Create an AudioSegment instance with the raw data | ||||||
|                     """ | 
 | ||||||
|                     audio = AudioSegment( |                     audio = AudioSegment( | ||||||
|                         # raw audio data (bytes) |                         # raw audio data (bytes) | ||||||
|                         data=audio_bytes, |                         data=audio_bytes, | ||||||
|                         # signed 16-bit little-endian format |                         # signed 16-bit little-endian format | ||||||
|                         sample_width=2, |                         sample_width=2, | ||||||
|                         # 24,000 Hz frame rate |                         # 24,000 Hz frame rate | ||||||
|                         frame_rate=16000, |                         frame_rate=24000, | ||||||
|                         # mono sound |                         # mono sound | ||||||
|                         channels=1, |                         channels=1, | ||||||
|                     ) |                     ) | ||||||
|                     """ |  | ||||||
| 
 | 
 | ||||||
|                     # print("audio segment was created") |  | ||||||
|                     await self.audiosegments.put(audio_bytes) |  | ||||||
| 
 | 
 | ||||||
|                     # await self.audiosegments.put(audio) |                     print("audio segment was created") | ||||||
|  |                     #await self.audiosegments.put(audio_bytes) | ||||||
|  | 
 | ||||||
|  |                     await self.audiosegments.put(audio) | ||||||
| 
 | 
 | ||||||
|                 # Run the code if that's the client's job |                 # Run the code if that's the client's job | ||||||
|                 if os.getenv("CODE_RUNNER") == "client": |                 if os.getenv("CODE_RUNNER") == "client": | ||||||
|  | @ -424,42 +425,65 @@ class Device: | ||||||
|                         result = interpreter.computer.run(language, code) |                         result = interpreter.computer.run(language, code) | ||||||
|                         send_queue.put(result) |                         send_queue.put(result) | ||||||
| 
 | 
 | ||||||
|  |             """ | ||||||
|         if is_win10(): |         if is_win10(): | ||||||
|             logger.info("Windows 10 detected") |             logger.info("Windows 10 detected") | ||||||
|             # Workaround for Windows 10 not latching to the websocket server. |             # Workaround for Windows 10 not latching to the websocket server. | ||||||
|             # See https://github.com/OpenInterpreter/01/issues/197 |             # See https://github.com/OpenInterpreter/01/issues/197 | ||||||
|             try: |             try: | ||||||
|                 ws = websockets.connect(WS_URL) |                 ws = websockets.connect(WS_URL) | ||||||
|                 await exec_ws_communication(ws) |                 # await exec_ws_communication(ws) | ||||||
|             except Exception as e: |             except Exception as e: | ||||||
|                 logger.error(f"Error while attempting to connect: {e}") |                 logger.error(f"Error while attempting to connect: {e}") | ||||||
|         else: |         else: | ||||||
|             print("websocket url is", WS_URL) |             print("websocket url is", WS_URL) | ||||||
|             while True: |             i = 0 | ||||||
|  |             # while True: | ||||||
|  |             #     try: | ||||||
|  |             #         i += 1 | ||||||
|  |             #         print("i is", i) | ||||||
|  | 
 | ||||||
|  |             #         # Hit the /ping endpoint | ||||||
|  |             #         ping_url = f"http://{self.server_url}/ping" | ||||||
|  |             #         response = requests.get(ping_url) | ||||||
|  |             #         print(response.text) | ||||||
|  |             #         # async with aiohttp.ClientSession() as session: | ||||||
|  |             #         #     async with session.get(ping_url) as response: | ||||||
|  |             #         #         print(f"Ping response: {await response.text()}") | ||||||
|  | 
 | ||||||
|  |             for i in range(3): | ||||||
|  |                 print(i) | ||||||
|                 try: |                 try: | ||||||
|                     async with websockets.connect(WS_URL) as websocket: |                     async with websockets.connect(WS_URL) as websocket: | ||||||
|                         print("awaiting exec_ws_communication") |                         print("happi happi happi :DDDDDDDDDDDDD") | ||||||
|                         await exec_ws_communication(websocket) |                         # await exec_ws_communication(websocket) | ||||||
|  |                         # print("exiting exec_ws_communication") | ||||||
|                 except: |                 except: | ||||||
|                     logger.info(traceback.format_exc()) |                     print("exception in websocket communication!!!!!!!!!!!!!!!!!") | ||||||
|                     if show_connection_log: |                     traceback.print_exc() | ||||||
|                         logger.info(f"Connecting to `{WS_URL}`...") | 
 | ||||||
|                         show_connection_log = False |                 # except: | ||||||
|                         await asyncio.sleep(2) |                 #     print("exception in websocket communication!!!!!!!!!!!!!!!!!") | ||||||
|  |                 #     traceback.print_exc() | ||||||
|  |                 #     if show_connection_log: | ||||||
|  |                 #         logger.info(f"Connecting to `{WS_URL}`...") | ||||||
|  |                 #         show_connection_log = False | ||||||
|  |                 #         await asyncio.sleep(2) | ||||||
| 
 | 
 | ||||||
|     async def start_async(self): |     async def start_async(self): | ||||||
|         print("start async was called!!!!!") |         print("start async was called!!!!!") | ||||||
|         # Configuration for WebSocket |         # Configuration for WebSocket | ||||||
|         WS_URL = f"ws://{self.server_url}" |         WS_URL = f"ws://{self.server_url}/" | ||||||
| 
 | 
 | ||||||
|         # Start the WebSocket communication |         # Start the WebSocket communication | ||||||
|         asyncio.create_task(self.websocket_communication(WS_URL)) |         await self.websocket_communication(WS_URL) | ||||||
| 
 | 
 | ||||||
|  |         """ | ||||||
|         # Start watching the kernel if it's your job to do that |         # Start watching the kernel if it's your job to do that | ||||||
|         if os.getenv("CODE_RUNNER") == "client": |         if os.getenv("CODE_RUNNER") == "client": | ||||||
|             asyncio.create_task(put_kernel_messages_into_queue(send_queue)) |             asyncio.create_task(put_kernel_messages_into_queue(send_queue)) | ||||||
| 
 | 
 | ||||||
|         asyncio.create_task(self.play_audiosegments()) |         #asyncio.create_task(self.play_audiosegments()) | ||||||
| 
 | 
 | ||||||
|         # If Raspberry Pi, add the button listener, otherwise use the spacebar |         # If Raspberry Pi, add the button listener, otherwise use the spacebar | ||||||
|         if current_platform.startswith("raspberry-pi"): |         if current_platform.startswith("raspberry-pi"): | ||||||
|  | @ -483,12 +507,11 @@ class Device: | ||||||
|                 else: |                 else: | ||||||
|                     break |                     break | ||||||
|         else: |         else: | ||||||
|             # Keyboard listener for spacebar press/release |         """ | ||||||
|             listener = keyboard.Listener( |         # Keyboard listener for spacebar press/release | ||||||
|                 on_press=self.on_press, on_release=self.on_release |         # listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) | ||||||
|             ) |         # listener.start() | ||||||
|             listener.start() |         # print("listener for keyboard started!!!!!") | ||||||
|             print("listener for keyboard started!!!!!") |  | ||||||
| 
 | 
 | ||||||
|     def start(self): |     def start(self): | ||||||
|         print("device was started!!!!!!") |         print("device was started!!!!!!") | ||||||
|  |  | ||||||
|  | @ -38,7 +38,7 @@ class AsyncInterpreter: | ||||||
|         self.stt = AudioToTextRecorder( |         self.stt = AudioToTextRecorder( | ||||||
|             model="tiny.en", spinner=False, use_microphone=False |             model="tiny.en", spinner=False, use_microphone=False | ||||||
|         ) |         ) | ||||||
|         self.stt.stop()  # It needs this for some reason |         self.stt.stop() | ||||||
| 
 | 
 | ||||||
|         # TTS |         # TTS | ||||||
|         if self.interpreter.tts == "coqui": |         if self.interpreter.tts == "coqui": | ||||||
|  | @ -118,8 +118,6 @@ class AsyncInterpreter: | ||||||
|         """ |         """ | ||||||
|         self.interpreter.messages = self.active_chat_messages |         self.interpreter.messages = self.active_chat_messages | ||||||
| 
 | 
 | ||||||
|         # self.beeper.start() |  | ||||||
| 
 |  | ||||||
|         self.stt.stop() |         self.stt.stop() | ||||||
|         # message = self.stt.text() |         # message = self.stt.text() | ||||||
|         # print("THE MESSAGE:", message) |         # print("THE MESSAGE:", message) | ||||||
|  | @ -137,15 +135,9 @@ class AsyncInterpreter: | ||||||
|         self.stt_latency = end_stt - start_stt |         self.stt_latency = end_stt - start_stt | ||||||
|         print("STT LATENCY", self.stt_latency) |         print("STT LATENCY", self.stt_latency) | ||||||
| 
 | 
 | ||||||
|         # print(message) |  | ||||||
|         end_interpreter = 0 |  | ||||||
| 
 |  | ||||||
|         # print(message) |  | ||||||
|         def generate(message): |         def generate(message): | ||||||
|             last_lmc_start_flag = self._last_lmc_start_flag |             last_lmc_start_flag = self._last_lmc_start_flag | ||||||
|             self.interpreter.messages = self.active_chat_messages |             self.interpreter.messages = self.active_chat_messages | ||||||
|             # print("🍀🍀🍀🍀GENERATING, using these messages: ", self.interpreter.messages) |  | ||||||
|             # print("🍀   🍀   🍀   🍀 active_chat_messages: ", self.active_chat_messages) |  | ||||||
|             print("message is", message) |             print("message is", message) | ||||||
| 
 | 
 | ||||||
|             for chunk in self.interpreter.chat(message, display=True, stream=True): |             for chunk in self.interpreter.chat(message, display=True, stream=True): | ||||||
|  | @ -209,7 +201,7 @@ class AsyncInterpreter: | ||||||
|         text_iterator = generate(message) |         text_iterator = generate(message) | ||||||
| 
 | 
 | ||||||
|         self.tts.feed(text_iterator) |         self.tts.feed(text_iterator) | ||||||
|         self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True) |         self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=False) | ||||||
| 
 | 
 | ||||||
|         while True: |         while True: | ||||||
|             if self.tts.is_playing(): |             if self.tts.is_playing(): | ||||||
|  | @ -236,7 +228,7 @@ class AsyncInterpreter: | ||||||
|                 break |                 break | ||||||
| 
 | 
 | ||||||
|     async def _on_tts_chunk_async(self, chunk): |     async def _on_tts_chunk_async(self, chunk): | ||||||
|         # print("SENDING TTS CHUNK") |         print(f"Adding chunk to output queue") | ||||||
|         await self._add_to_queue(self._output_queue, chunk) |         await self._add_to_queue(self._output_queue, chunk) | ||||||
| 
 | 
 | ||||||
|     def on_tts_chunk(self, chunk): |     def on_tts_chunk(self, chunk): | ||||||
|  | @ -244,4 +236,7 @@ class AsyncInterpreter: | ||||||
|         asyncio.run(self._on_tts_chunk_async(chunk)) |         asyncio.run(self._on_tts_chunk_async(chunk)) | ||||||
| 
 | 
 | ||||||
|     async def output(self): |     async def output(self): | ||||||
|         return await self._output_queue.get() |         print("entering output method") | ||||||
|  |         value = await self._output_queue.get() | ||||||
|  |         print("output method returning") | ||||||
|  |         return value | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ from fastapi import FastAPI, WebSocket, Header | ||||||
| from fastapi.responses import PlainTextResponse | from fastapi.responses import PlainTextResponse | ||||||
| from uvicorn import Config, Server | from uvicorn import Config, Server | ||||||
| from interpreter import interpreter as base_interpreter | from interpreter import interpreter as base_interpreter | ||||||
|  | from starlette.websockets import WebSocketDisconnect | ||||||
| from .async_interpreter import AsyncInterpreter | from .async_interpreter import AsyncInterpreter | ||||||
| from fastapi.middleware.cors import CORSMiddleware | from fastapi.middleware.cors import CORSMiddleware | ||||||
| from typing import List, Dict, Any | from typing import List, Dict, Any | ||||||
|  | @ -23,18 +24,11 @@ base_interpreter.llm.model = "groq/llama3-8b-8192" | ||||||
| base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"] | base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"] | ||||||
| base_interpreter.llm.supports_functions = False | base_interpreter.llm.supports_functions = False | ||||||
| base_interpreter.auto_run = True | base_interpreter.auto_run = True | ||||||
|  | base_interpreter.tts = "elevenlabs" | ||||||
| 
 | 
 | ||||||
| os.environ["STT_RUNNER"] = "server" | os.environ["STT_RUNNER"] = "server" | ||||||
| os.environ["TTS_RUNNER"] = "server" | os.environ["TTS_RUNNER"] = "server" | ||||||
| 
 | 
 | ||||||
| # Parse command line arguments for port number |  | ||||||
| """ |  | ||||||
| parser = argparse.ArgumentParser(description="FastAPI server.") |  | ||||||
| parser.add_argument("--port", type=int, default=8000, help="Port to run on.") |  | ||||||
| args = parser.parse_args() |  | ||||||
| """ |  | ||||||
| base_interpreter.tts = "coqui" |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| async def main(server_host, server_port): | async def main(server_host, server_port): | ||||||
|     interpreter = AsyncInterpreter(base_interpreter) |     interpreter = AsyncInterpreter(base_interpreter) | ||||||
|  | @ -60,84 +54,111 @@ async def main(server_host, server_port): | ||||||
|         print("🪼🪼🪼🪼🪼🪼 Messages loaded: ", interpreter.active_chat_messages) |         print("🪼🪼🪼🪼🪼🪼 Messages loaded: ", interpreter.active_chat_messages) | ||||||
|         return {"status": "success"} |         return {"status": "success"} | ||||||
| 
 | 
 | ||||||
|  |     print("About to set up the websocker endpoint!!!!!!!!!!!!!!!!!!!!!!!!!") | ||||||
|  | 
 | ||||||
|     @app.websocket("/") |     @app.websocket("/") | ||||||
|     async def websocket_endpoint(websocket: WebSocket): |     async def websocket_endpoint(websocket: WebSocket): | ||||||
|  |         print("websocket hit") | ||||||
|         await websocket.accept() |         await websocket.accept() | ||||||
|         try: |         print("websocket accepted") | ||||||
| 
 | 
 | ||||||
|             async def receive_input(): |         async def send_output(): | ||||||
|                 while True: |             try: | ||||||
|                     if websocket.client_state == "DISCONNECTED": |  | ||||||
|                         break |  | ||||||
| 
 |  | ||||||
|                     data = await websocket.receive() |  | ||||||
| 
 |  | ||||||
|                     if isinstance(data, bytes): |  | ||||||
|                         await interpreter.input(data) |  | ||||||
|                     elif "bytes" in data: |  | ||||||
|                         await interpreter.input(data["bytes"]) |  | ||||||
|                         # print("RECEIVED INPUT", data) |  | ||||||
|                     elif "text" in data: |  | ||||||
|                         # print("RECEIVED INPUT", data) |  | ||||||
|                         await interpreter.input(data["text"]) |  | ||||||
| 
 |  | ||||||
|             async def send_output(): |  | ||||||
|                 while True: |                 while True: | ||||||
|                     output = await interpreter.output() |                     output = await interpreter.output() | ||||||
| 
 | 
 | ||||||
|                     if isinstance(output, bytes): |                     if isinstance(output, bytes): | ||||||
|                         # print(f"Sending {len(output)} bytes of audio data.") |                         print("server sending bytes output") | ||||||
|                         await websocket.send_bytes(output) |                         try: | ||||||
|                         # we dont send out bytes rn, no TTS |                             await websocket.send_bytes(output) | ||||||
|  |                             print("server successfully sent bytes output") | ||||||
|  |                         except Exception as e: | ||||||
|  |                             print(f"Error: {e}") | ||||||
|  |                             traceback.print_exc() | ||||||
|  |                             return {"error": str(e)} | ||||||
| 
 | 
 | ||||||
|                     elif isinstance(output, dict): |                     elif isinstance(output, dict): | ||||||
|                         # print("sending text") |                         print("server sending text output") | ||||||
|                         await websocket.send_text(json.dumps(output)) |                         try: | ||||||
|  |                             await websocket.send_text(json.dumps(output)) | ||||||
|  |                             print("server successfully sent text output") | ||||||
|  |                         except Exception as e: | ||||||
|  |                             print(f"Error: {e}") | ||||||
|  |                             traceback.print_exc() | ||||||
|  |                             return {"error": str(e)} | ||||||
|  |             except asyncio.CancelledError: | ||||||
|  |                 print("WebSocket connection closed") | ||||||
|  |                 traceback.print_exc() | ||||||
| 
 | 
 | ||||||
|             await asyncio.gather(send_output(), receive_input()) |         async def receive_input(): | ||||||
|  |             try: | ||||||
|  |                 while True: | ||||||
|  |                     print("server awaiting input") | ||||||
|  |                     data = await websocket.receive() | ||||||
|  | 
 | ||||||
|  |                     if isinstance(data, bytes): | ||||||
|  |                         try: | ||||||
|  |                             await interpreter.input(data) | ||||||
|  |                         except Exception as e: | ||||||
|  |                             print(f"Error: {e}") | ||||||
|  |                             traceback.print_exc() | ||||||
|  |                             return {"error": str(e)} | ||||||
|  | 
 | ||||||
|  |                     elif "bytes" in data: | ||||||
|  |                         try: | ||||||
|  |                             await interpreter.input(data["bytes"]) | ||||||
|  |                         except Exception as e: | ||||||
|  |                             print(f"Error: {e}") | ||||||
|  |                             traceback.print_exc() | ||||||
|  |                             return {"error": str(e)} | ||||||
|  | 
 | ||||||
|  |                     elif "text" in data: | ||||||
|  |                         try: | ||||||
|  |                             await interpreter.input(data["text"]) | ||||||
|  |                         except Exception as e: | ||||||
|  |                             print(f"Error: {e}") | ||||||
|  |                             traceback.print_exc() | ||||||
|  |                             return {"error": str(e)} | ||||||
|  |             except asyncio.CancelledError: | ||||||
|  |                 print("WebSocket connection closed") | ||||||
|  |                 traceback.print_exc() | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             send_task = asyncio.create_task(send_output()) | ||||||
|  |             receive_task = asyncio.create_task(receive_input()) | ||||||
|  | 
 | ||||||
|  |             print("server starting to handle ws connection") | ||||||
|  |             """ | ||||||
|  |             done, pending = await asyncio.wait( | ||||||
|  |                 [send_task, receive_task], | ||||||
|  |                 return_when=asyncio.FIRST_COMPLETED, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             for task in pending: | ||||||
|  |                 task.cancel() | ||||||
|  | 
 | ||||||
|  |             for task in done: | ||||||
|  |                 if task.exception() is not None: | ||||||
|  |                     raise | ||||||
|  |             """ | ||||||
|  |             await asyncio.gather(send_task, receive_task) | ||||||
|  | 
 | ||||||
|  |             print("server finished handling ws connection") | ||||||
|  | 
 | ||||||
|  |         except WebSocketDisconnect: | ||||||
|  |             print("WebSocket disconnected") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"WebSocket connection closed with exception: {e}") |             print(f"WebSocket connection closed with exception: {e}") | ||||||
|             traceback.print_exc() |             traceback.print_exc() | ||||||
|         finally: |         finally: | ||||||
|             if not websocket.client_state == "DISCONNECTED": |             print("server closing ws connection") | ||||||
|                 await websocket.close() |             await websocket.close() | ||||||
| 
 | 
 | ||||||
|     print(f"Starting server on {server_host}:{server_port}") |     print(f"Starting server on {server_host}:{server_port}") | ||||||
|     config = Config(app, host=server_host, port=server_port, lifespan="on") |     config = Config(app, host=server_host, port=server_port, lifespan="on") | ||||||
|     server = Server(config) |     server = Server(config) | ||||||
|     await server.serve() |     await server.serve() | ||||||
| 
 | 
 | ||||||
|     class Rename(BaseModel): |  | ||||||
|         input: str |  | ||||||
| 
 |  | ||||||
|     @app.post("/rename-chat") |  | ||||||
|     async def rename_chat(body_content: Rename, x_api_key: str = Header(None)): |  | ||||||
|         print("RENAME CHAT REQUEST in PY 🌙🌙🌙🌙") |  | ||||||
|         input_value = body_content.input |  | ||||||
|         client = OpenAI( |  | ||||||
|             # defaults to os.environ.get("OPENAI_API_KEY") |  | ||||||
|             api_key=x_api_key, |  | ||||||
|         ) |  | ||||||
|         try: |  | ||||||
|             response = client.chat.completions.create( |  | ||||||
|                 model="gpt-3.5-turbo", |  | ||||||
|                 messages=[ |  | ||||||
|                     { |  | ||||||
|                         "role": "user", |  | ||||||
|                         "content": f"Given the following chat snippet, create a unique and descriptive title in less than 8 words. Your answer must not be related to customer service.\n\n{input_value}", |  | ||||||
|                     } |  | ||||||
|                 ], |  | ||||||
|                 temperature=0.3, |  | ||||||
|                 stream=False, |  | ||||||
|             ) |  | ||||||
|             print(response) |  | ||||||
|             completion = response["choices"][0]["message"]["content"] |  | ||||||
|             return {"data": {"content": completion}} |  | ||||||
|         except Exception as e: |  | ||||||
|             print(f"Error: {e}") |  | ||||||
|             traceback.print_exc() |  | ||||||
|             return {"error": str(e)} |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     asyncio.run(main()) |     asyncio.run(main("localhost", 8000)) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 Ben Xu
						Ben Xu