merge server from temp-branch
This commit is contained in:
		
							parent
							
								
									375ed1f575
								
							
						
					
					
						commit
						456ac51634
					
				| 
						 | 
				
			
			@ -1,12 +1,19 @@
 | 
			
		|||
# make this obvious
 | 
			
		||||
from .profiles.default import interpreter as base_interpreter
 | 
			
		||||
 | 
			
		||||
# from .profiles.fast import interpreter as base_interpreter
 | 
			
		||||
# from .profiles.local import interpreter as base_interpreter
 | 
			
		||||
 | 
			
		||||
# TODO: remove files i.py, llm.py, conftest?, services
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
import traceback
 | 
			
		||||
import json
 | 
			
		||||
from fastapi import FastAPI, WebSocket
 | 
			
		||||
from fastapi.responses import PlainTextResponse
 | 
			
		||||
from uvicorn import Config, Server
 | 
			
		||||
from .i import configure_interpreter
 | 
			
		||||
from interpreter import interpreter as base_interpreter
 | 
			
		||||
from starlette.websockets import WebSocketDisconnect
 | 
			
		||||
 | 
			
		||||
# from interpreter import interpreter as base_interpreter
 | 
			
		||||
from .async_interpreter import AsyncInterpreter
 | 
			
		||||
from fastapi.middleware.cors import CORSMiddleware
 | 
			
		||||
from typing import List, Dict, Any
 | 
			
		||||
| 
						 | 
				
			
			@ -17,23 +24,9 @@ os.environ["STT_RUNNER"] = "server"
 | 
			
		|||
os.environ["TTS_RUNNER"] = "server"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def main(server_host, server_port, tts_service, asynchronous):
 | 
			
		||||
    if asynchronous:
 | 
			
		||||
        base_interpreter.system_message = (
 | 
			
		||||
            "You are a helpful assistant that can answer questions and help with tasks."
 | 
			
		||||
        )
 | 
			
		||||
        base_interpreter.computer.import_computer_api = False
 | 
			
		||||
        base_interpreter.llm.model = "groq/llama3-8b-8192"
 | 
			
		||||
        base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
 | 
			
		||||
        base_interpreter.llm.supports_functions = False
 | 
			
		||||
        base_interpreter.auto_run = True
 | 
			
		||||
        base_interpreter.tts = tts_service
 | 
			
		||||
        interpreter = AsyncInterpreter(base_interpreter)
 | 
			
		||||
    else:
 | 
			
		||||
        configured_interpreter = configure_interpreter(base_interpreter)
 | 
			
		||||
        configured_interpreter.llm.supports_functions = True
 | 
			
		||||
        configured_interpreter.tts = tts_service
 | 
			
		||||
        interpreter = AsyncInterpreter(configured_interpreter)
 | 
			
		||||
async def main(server_host, server_port, tts_service):
 | 
			
		||||
    base_interpreter.tts = tts_service
 | 
			
		||||
    interpreter = AsyncInterpreter(base_interpreter)
 | 
			
		||||
 | 
			
		||||
    app = FastAPI()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -59,79 +52,44 @@ async def main(server_host, server_port, tts_service, asynchronous):
 | 
			
		|||
    @app.websocket("/")
 | 
			
		||||
    async def websocket_endpoint(websocket: WebSocket):
 | 
			
		||||
        await websocket.accept()
 | 
			
		||||
        try:
 | 
			
		||||
 | 
			
		||||
        async def send_output():
 | 
			
		||||
            try:
 | 
			
		||||
            async def receive_input():
 | 
			
		||||
                while True:
 | 
			
		||||
                    if websocket.client_state == "DISCONNECTED":
 | 
			
		||||
                        break
 | 
			
		||||
 | 
			
		||||
                    data = await websocket.receive()
 | 
			
		||||
 | 
			
		||||
                    if isinstance(data, bytes):
 | 
			
		||||
                        await interpreter.input(data)
 | 
			
		||||
                    elif "bytes" in data:
 | 
			
		||||
                        await interpreter.input(data["bytes"])
 | 
			
		||||
                        # print("RECEIVED INPUT", data)
 | 
			
		||||
                    elif "text" in data:
 | 
			
		||||
                        # print("RECEIVED INPUT", data)
 | 
			
		||||
                        await interpreter.input(data["text"])
 | 
			
		||||
 | 
			
		||||
            async def send_output():
 | 
			
		||||
                while True:
 | 
			
		||||
                    output = await interpreter.output()
 | 
			
		||||
 | 
			
		||||
                    if isinstance(output, bytes):
 | 
			
		||||
                        try:
 | 
			
		||||
                            await websocket.send_bytes(output)
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            print(f"Error: {e}")
 | 
			
		||||
                            traceback.print_exc()
 | 
			
		||||
                            return {"error": str(e)}
 | 
			
		||||
                        # print(f"Sending {len(output)} bytes of audio data.")
 | 
			
		||||
                        await websocket.send_bytes(output)
 | 
			
		||||
                        # we dont send out bytes rn, no TTS
 | 
			
		||||
 | 
			
		||||
                    elif isinstance(output, dict):
 | 
			
		||||
                        try:
 | 
			
		||||
                            await websocket.send_text(json.dumps(output))
 | 
			
		||||
                        # print("sending text")
 | 
			
		||||
                        await websocket.send_text(json.dumps(output))
 | 
			
		||||
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            print(f"Error: {e}")
 | 
			
		||||
                            traceback.print_exc()
 | 
			
		||||
                            return {"error": str(e)}
 | 
			
		||||
            except asyncio.CancelledError:
 | 
			
		||||
                print("WebSocket connection closed")
 | 
			
		||||
                traceback.print_exc()
 | 
			
		||||
 | 
			
		||||
        async def receive_input():
 | 
			
		||||
            try:
 | 
			
		||||
                while True:
 | 
			
		||||
                    # print("server awaiting input")
 | 
			
		||||
                    data = await websocket.receive()
 | 
			
		||||
 | 
			
		||||
                    if isinstance(data, bytes):
 | 
			
		||||
                        try:
 | 
			
		||||
                            await interpreter.input(data)
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            print(f"Error: {e}")
 | 
			
		||||
                            traceback.print_exc()
 | 
			
		||||
                            return {"error": str(e)}
 | 
			
		||||
 | 
			
		||||
                    elif "bytes" in data:
 | 
			
		||||
                        try:
 | 
			
		||||
                            await interpreter.input(data["bytes"])
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            print(f"Error: {e}")
 | 
			
		||||
                            traceback.print_exc()
 | 
			
		||||
                            return {"error": str(e)}
 | 
			
		||||
 | 
			
		||||
                    elif "text" in data:
 | 
			
		||||
                        try:
 | 
			
		||||
                            await interpreter.input(data["text"])
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            print(f"Error: {e}")
 | 
			
		||||
                            traceback.print_exc()
 | 
			
		||||
                            return {"error": str(e)}
 | 
			
		||||
            except asyncio.CancelledError:
 | 
			
		||||
                print("WebSocket connection closed")
 | 
			
		||||
                traceback.print_exc()
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            send_task = asyncio.create_task(send_output())
 | 
			
		||||
            receive_task = asyncio.create_task(receive_input())
 | 
			
		||||
 | 
			
		||||
            await asyncio.gather(send_task, receive_task)
 | 
			
		||||
 | 
			
		||||
        except WebSocketDisconnect:
 | 
			
		||||
            print("WebSocket disconnected")
 | 
			
		||||
            await asyncio.gather(send_output(), receive_input())
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            print(f"WebSocket connection closed with exception: {e}")
 | 
			
		||||
            traceback.print_exc()
 | 
			
		||||
        finally:
 | 
			
		||||
            print("server closing ws connection")
 | 
			
		||||
            await websocket.close()
 | 
			
		||||
            if not websocket.client_state == "DISCONNECTED":
 | 
			
		||||
                await websocket.close()
 | 
			
		||||
 | 
			
		||||
    print(f"Starting server on {server_host}:{server_port}")
 | 
			
		||||
    config = Config(app, host=server_host, port=server_port, lifespan="on")
 | 
			
		||||
| 
						 | 
				
			
			@ -140,4 +98,4 @@ async def main(server_host, server_port, tts_service, asynchronous):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    asyncio.run(main("localhost", 8000))
 | 
			
		||||
    asyncio.run(main())
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -77,9 +77,6 @@ def run(
 | 
			
		|||
    mobile: bool = typer.Option(
 | 
			
		||||
        False, "--mobile", help="Toggle server to support mobile app"
 | 
			
		||||
    ),
 | 
			
		||||
    asynchronous: bool = typer.Option(
 | 
			
		||||
        False, "--async", help="use interpreter optimized for latency"
 | 
			
		||||
    ),
 | 
			
		||||
):
 | 
			
		||||
    _run(
 | 
			
		||||
        server=server or mobile,
 | 
			
		||||
| 
						 | 
				
			
			@ -102,7 +99,6 @@ def run(
 | 
			
		|||
        local=local,
 | 
			
		||||
        qr=qr or mobile,
 | 
			
		||||
        mobile=mobile,
 | 
			
		||||
        asynchronous=asynchronous,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -127,7 +123,6 @@ def _run(
 | 
			
		|||
    local: bool = False,
 | 
			
		||||
    qr: bool = False,
 | 
			
		||||
    mobile: bool = False,
 | 
			
		||||
    asynchronous: bool = False,
 | 
			
		||||
):
 | 
			
		||||
    if local:
 | 
			
		||||
        tts_service = "coqui"
 | 
			
		||||
| 
						 | 
				
			
			@ -162,7 +157,6 @@ def _run(
 | 
			
		|||
                    server_host,
 | 
			
		||||
                    server_port,
 | 
			
		||||
                    tts_service,
 | 
			
		||||
                    asynchronous,
 | 
			
		||||
                    # llm_service,
 | 
			
		||||
                    # model,
 | 
			
		||||
                    # llm_supports_vision,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue