merge server from temp-branch
This commit is contained in:
parent
375ed1f575
commit
456ac51634
|
@ -1,12 +1,19 @@
|
||||||
|
# make this obvious
|
||||||
|
from .profiles.default import interpreter as base_interpreter
|
||||||
|
|
||||||
|
# from .profiles.fast import interpreter as base_interpreter
|
||||||
|
# from .profiles.local import interpreter as base_interpreter
|
||||||
|
|
||||||
|
# TODO: remove files i.py, llm.py, conftest?, services
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
import json
|
import json
|
||||||
from fastapi import FastAPI, WebSocket
|
from fastapi import FastAPI, WebSocket
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
from uvicorn import Config, Server
|
from uvicorn import Config, Server
|
||||||
from .i import configure_interpreter
|
|
||||||
from interpreter import interpreter as base_interpreter
|
# from interpreter import interpreter as base_interpreter
|
||||||
from starlette.websockets import WebSocketDisconnect
|
|
||||||
from .async_interpreter import AsyncInterpreter
|
from .async_interpreter import AsyncInterpreter
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
@ -17,23 +24,9 @@ os.environ["STT_RUNNER"] = "server"
|
||||||
os.environ["TTS_RUNNER"] = "server"
|
os.environ["TTS_RUNNER"] = "server"
|
||||||
|
|
||||||
|
|
||||||
async def main(server_host, server_port, tts_service, asynchronous):
|
async def main(server_host, server_port, tts_service):
|
||||||
if asynchronous:
|
base_interpreter.tts = tts_service
|
||||||
base_interpreter.system_message = (
|
interpreter = AsyncInterpreter(base_interpreter)
|
||||||
"You are a helpful assistant that can answer questions and help with tasks."
|
|
||||||
)
|
|
||||||
base_interpreter.computer.import_computer_api = False
|
|
||||||
base_interpreter.llm.model = "groq/llama3-8b-8192"
|
|
||||||
base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
|
|
||||||
base_interpreter.llm.supports_functions = False
|
|
||||||
base_interpreter.auto_run = True
|
|
||||||
base_interpreter.tts = tts_service
|
|
||||||
interpreter = AsyncInterpreter(base_interpreter)
|
|
||||||
else:
|
|
||||||
configured_interpreter = configure_interpreter(base_interpreter)
|
|
||||||
configured_interpreter.llm.supports_functions = True
|
|
||||||
configured_interpreter.tts = tts_service
|
|
||||||
interpreter = AsyncInterpreter(configured_interpreter)
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -59,79 +52,44 @@ async def main(server_host, server_port, tts_service, asynchronous):
|
||||||
@app.websocket("/")
|
@app.websocket("/")
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
try:
|
||||||
|
|
||||||
async def send_output():
|
async def receive_input():
|
||||||
try:
|
while True:
|
||||||
|
if websocket.client_state == "DISCONNECTED":
|
||||||
|
break
|
||||||
|
|
||||||
|
data = await websocket.receive()
|
||||||
|
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
await interpreter.input(data)
|
||||||
|
elif "bytes" in data:
|
||||||
|
await interpreter.input(data["bytes"])
|
||||||
|
# print("RECEIVED INPUT", data)
|
||||||
|
elif "text" in data:
|
||||||
|
# print("RECEIVED INPUT", data)
|
||||||
|
await interpreter.input(data["text"])
|
||||||
|
|
||||||
|
async def send_output():
|
||||||
while True:
|
while True:
|
||||||
output = await interpreter.output()
|
output = await interpreter.output()
|
||||||
|
|
||||||
if isinstance(output, bytes):
|
if isinstance(output, bytes):
|
||||||
try:
|
# print(f"Sending {len(output)} bytes of audio data.")
|
||||||
await websocket.send_bytes(output)
|
await websocket.send_bytes(output)
|
||||||
except Exception as e:
|
# we dont send out bytes rn, no TTS
|
||||||
print(f"Error: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
elif isinstance(output, dict):
|
elif isinstance(output, dict):
|
||||||
try:
|
# print("sending text")
|
||||||
await websocket.send_text(json.dumps(output))
|
await websocket.send_text(json.dumps(output))
|
||||||
|
|
||||||
except Exception as e:
|
await asyncio.gather(send_output(), receive_input())
|
||||||
print(f"Error: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return {"error": str(e)}
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
print("WebSocket connection closed")
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
async def receive_input():
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
# print("server awaiting input")
|
|
||||||
data = await websocket.receive()
|
|
||||||
|
|
||||||
if isinstance(data, bytes):
|
|
||||||
try:
|
|
||||||
await interpreter.input(data)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
elif "bytes" in data:
|
|
||||||
try:
|
|
||||||
await interpreter.input(data["bytes"])
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
elif "text" in data:
|
|
||||||
try:
|
|
||||||
await interpreter.input(data["text"])
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return {"error": str(e)}
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
print("WebSocket connection closed")
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
try:
|
|
||||||
send_task = asyncio.create_task(send_output())
|
|
||||||
receive_task = asyncio.create_task(receive_input())
|
|
||||||
|
|
||||||
await asyncio.gather(send_task, receive_task)
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
print("WebSocket disconnected")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"WebSocket connection closed with exception: {e}")
|
print(f"WebSocket connection closed with exception: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
print("server closing ws connection")
|
if not websocket.client_state == "DISCONNECTED":
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
print(f"Starting server on {server_host}:{server_port}")
|
print(f"Starting server on {server_host}:{server_port}")
|
||||||
config = Config(app, host=server_host, port=server_port, lifespan="on")
|
config = Config(app, host=server_host, port=server_port, lifespan="on")
|
||||||
|
@ -140,4 +98,4 @@ async def main(server_host, server_port, tts_service, asynchronous):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main("localhost", 8000))
|
asyncio.run(main())
|
||||||
|
|
|
@ -77,9 +77,6 @@ def run(
|
||||||
mobile: bool = typer.Option(
|
mobile: bool = typer.Option(
|
||||||
False, "--mobile", help="Toggle server to support mobile app"
|
False, "--mobile", help="Toggle server to support mobile app"
|
||||||
),
|
),
|
||||||
asynchronous: bool = typer.Option(
|
|
||||||
False, "--async", help="use interpreter optimized for latency"
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
_run(
|
_run(
|
||||||
server=server or mobile,
|
server=server or mobile,
|
||||||
|
@ -102,7 +99,6 @@ def run(
|
||||||
local=local,
|
local=local,
|
||||||
qr=qr or mobile,
|
qr=qr or mobile,
|
||||||
mobile=mobile,
|
mobile=mobile,
|
||||||
asynchronous=asynchronous,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,7 +123,6 @@ def _run(
|
||||||
local: bool = False,
|
local: bool = False,
|
||||||
qr: bool = False,
|
qr: bool = False,
|
||||||
mobile: bool = False,
|
mobile: bool = False,
|
||||||
asynchronous: bool = False,
|
|
||||||
):
|
):
|
||||||
if local:
|
if local:
|
||||||
tts_service = "coqui"
|
tts_service = "coqui"
|
||||||
|
@ -162,7 +157,6 @@ def _run(
|
||||||
server_host,
|
server_host,
|
||||||
server_port,
|
server_port,
|
||||||
tts_service,
|
tts_service,
|
||||||
asynchronous,
|
|
||||||
# llm_service,
|
# llm_service,
|
||||||
# model,
|
# model,
|
||||||
# llm_supports_vision,
|
# llm_supports_vision,
|
||||||
|
|
Loading…
Reference in New Issue