merge server from temp-branch

This commit is contained in:
Ben Xu 2024-06-18 12:35:13 -07:00
parent 375ed1f575
commit 456ac51634
2 changed files with 40 additions and 88 deletions

View File

@ -1,12 +1,19 @@
# make this obvious
from .profiles.default import interpreter as base_interpreter
# from .profiles.fast import interpreter as base_interpreter
# from .profiles.local import interpreter as base_interpreter
# TODO: remove files i.py, llm.py, conftest?, services
import asyncio import asyncio
import traceback import traceback
import json import json
from fastapi import FastAPI, WebSocket from fastapi import FastAPI, WebSocket
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from uvicorn import Config, Server from uvicorn import Config, Server
from .i import configure_interpreter
from interpreter import interpreter as base_interpreter # from interpreter import interpreter as base_interpreter
from starlette.websockets import WebSocketDisconnect
from .async_interpreter import AsyncInterpreter from .async_interpreter import AsyncInterpreter
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any from typing import List, Dict, Any
@ -17,23 +24,9 @@ os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server" os.environ["TTS_RUNNER"] = "server"
async def main(server_host, server_port, tts_service, asynchronous): async def main(server_host, server_port, tts_service):
if asynchronous: base_interpreter.tts = tts_service
base_interpreter.system_message = ( interpreter = AsyncInterpreter(base_interpreter)
"You are a helpful assistant that can answer questions and help with tasks."
)
base_interpreter.computer.import_computer_api = False
base_interpreter.llm.model = "groq/llama3-8b-8192"
base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
base_interpreter.llm.supports_functions = False
base_interpreter.auto_run = True
base_interpreter.tts = tts_service
interpreter = AsyncInterpreter(base_interpreter)
else:
configured_interpreter = configure_interpreter(base_interpreter)
configured_interpreter.llm.supports_functions = True
configured_interpreter.tts = tts_service
interpreter = AsyncInterpreter(configured_interpreter)
app = FastAPI() app = FastAPI()
@ -59,79 +52,44 @@ async def main(server_host, server_port, tts_service, asynchronous):
@app.websocket("/") @app.websocket("/")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await websocket.accept() await websocket.accept()
try:
async def send_output(): async def receive_input():
try: while True:
if websocket.client_state == "DISCONNECTED":
break
data = await websocket.receive()
if isinstance(data, bytes):
await interpreter.input(data)
elif "bytes" in data:
await interpreter.input(data["bytes"])
# print("RECEIVED INPUT", data)
elif "text" in data:
# print("RECEIVED INPUT", data)
await interpreter.input(data["text"])
async def send_output():
while True: while True:
output = await interpreter.output() output = await interpreter.output()
if isinstance(output, bytes): if isinstance(output, bytes):
try: # print(f"Sending {len(output)} bytes of audio data.")
await websocket.send_bytes(output) await websocket.send_bytes(output)
except Exception as e: # we dont send out bytes rn, no TTS
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
elif isinstance(output, dict): elif isinstance(output, dict):
try: # print("sending text")
await websocket.send_text(json.dumps(output)) await websocket.send_text(json.dumps(output))
except Exception as e: await asyncio.gather(send_output(), receive_input())
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
except asyncio.CancelledError:
print("WebSocket connection closed")
traceback.print_exc()
async def receive_input():
try:
while True:
# print("server awaiting input")
data = await websocket.receive()
if isinstance(data, bytes):
try:
await interpreter.input(data)
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
elif "bytes" in data:
try:
await interpreter.input(data["bytes"])
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
elif "text" in data:
try:
await interpreter.input(data["text"])
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
except asyncio.CancelledError:
print("WebSocket connection closed")
traceback.print_exc()
try:
send_task = asyncio.create_task(send_output())
receive_task = asyncio.create_task(receive_input())
await asyncio.gather(send_task, receive_task)
except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e: except Exception as e:
print(f"WebSocket connection closed with exception: {e}") print(f"WebSocket connection closed with exception: {e}")
traceback.print_exc() traceback.print_exc()
finally: finally:
print("server closing ws connection") if not websocket.client_state == "DISCONNECTED":
await websocket.close() await websocket.close()
print(f"Starting server on {server_host}:{server_port}") print(f"Starting server on {server_host}:{server_port}")
config = Config(app, host=server_host, port=server_port, lifespan="on") config = Config(app, host=server_host, port=server_port, lifespan="on")
@ -140,4 +98,4 @@ async def main(server_host, server_port, tts_service, asynchronous):
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main("localhost", 8000)) asyncio.run(main())

View File

@ -77,9 +77,6 @@ def run(
mobile: bool = typer.Option( mobile: bool = typer.Option(
False, "--mobile", help="Toggle server to support mobile app" False, "--mobile", help="Toggle server to support mobile app"
), ),
asynchronous: bool = typer.Option(
False, "--async", help="use interpreter optimized for latency"
),
): ):
_run( _run(
server=server or mobile, server=server or mobile,
@ -102,7 +99,6 @@ def run(
local=local, local=local,
qr=qr or mobile, qr=qr or mobile,
mobile=mobile, mobile=mobile,
asynchronous=asynchronous,
) )
@ -127,7 +123,6 @@ def _run(
local: bool = False, local: bool = False,
qr: bool = False, qr: bool = False,
mobile: bool = False, mobile: bool = False,
asynchronous: bool = False,
): ):
if local: if local:
tts_service = "coqui" tts_service = "coqui"
@ -162,7 +157,6 @@ def _run(
server_host, server_host,
server_port, server_port,
tts_service, tts_service,
asynchronous,
# llm_service, # llm_service,
# model, # model,
# llm_supports_vision, # llm_supports_vision,