merge server from temp-branch
This commit is contained in:
		
							parent
							
								
									375ed1f575
								
							
						
					
					
						commit
						456ac51634
					
				| 
						 | 
					@ -1,12 +1,19 @@
 | 
				
			||||||
 | 
					# make this obvious
 | 
				
			||||||
 | 
					from .profiles.default import interpreter as base_interpreter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# from .profiles.fast import interpreter as base_interpreter
 | 
				
			||||||
 | 
					# from .profiles.local import interpreter as base_interpreter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO: remove files i.py, llm.py, conftest?, services
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
from fastapi import FastAPI, WebSocket
 | 
					from fastapi import FastAPI, WebSocket
 | 
				
			||||||
from fastapi.responses import PlainTextResponse
 | 
					from fastapi.responses import PlainTextResponse
 | 
				
			||||||
from uvicorn import Config, Server
 | 
					from uvicorn import Config, Server
 | 
				
			||||||
from .i import configure_interpreter
 | 
					
 | 
				
			||||||
from interpreter import interpreter as base_interpreter
 | 
					# from interpreter import interpreter as base_interpreter
 | 
				
			||||||
from starlette.websockets import WebSocketDisconnect
 | 
					 | 
				
			||||||
from .async_interpreter import AsyncInterpreter
 | 
					from .async_interpreter import AsyncInterpreter
 | 
				
			||||||
from fastapi.middleware.cors import CORSMiddleware
 | 
					from fastapi.middleware.cors import CORSMiddleware
 | 
				
			||||||
from typing import List, Dict, Any
 | 
					from typing import List, Dict, Any
 | 
				
			||||||
| 
						 | 
					@ -17,23 +24,9 @@ os.environ["STT_RUNNER"] = "server"
 | 
				
			||||||
os.environ["TTS_RUNNER"] = "server"
 | 
					os.environ["TTS_RUNNER"] = "server"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def main(server_host, server_port, tts_service, asynchronous):
 | 
					async def main(server_host, server_port, tts_service):
 | 
				
			||||||
    if asynchronous:
 | 
					    base_interpreter.tts = tts_service
 | 
				
			||||||
        base_interpreter.system_message = (
 | 
					    interpreter = AsyncInterpreter(base_interpreter)
 | 
				
			||||||
            "You are a helpful assistant that can answer questions and help with tasks."
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        base_interpreter.computer.import_computer_api = False
 | 
					 | 
				
			||||||
        base_interpreter.llm.model = "groq/llama3-8b-8192"
 | 
					 | 
				
			||||||
        base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
 | 
					 | 
				
			||||||
        base_interpreter.llm.supports_functions = False
 | 
					 | 
				
			||||||
        base_interpreter.auto_run = True
 | 
					 | 
				
			||||||
        base_interpreter.tts = tts_service
 | 
					 | 
				
			||||||
        interpreter = AsyncInterpreter(base_interpreter)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        configured_interpreter = configure_interpreter(base_interpreter)
 | 
					 | 
				
			||||||
        configured_interpreter.llm.supports_functions = True
 | 
					 | 
				
			||||||
        configured_interpreter.tts = tts_service
 | 
					 | 
				
			||||||
        interpreter = AsyncInterpreter(configured_interpreter)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    app = FastAPI()
 | 
					    app = FastAPI()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -59,79 +52,44 @@ async def main(server_host, server_port, tts_service, asynchronous):
 | 
				
			||||||
    @app.websocket("/")
 | 
					    @app.websocket("/")
 | 
				
			||||||
    async def websocket_endpoint(websocket: WebSocket):
 | 
					    async def websocket_endpoint(websocket: WebSocket):
 | 
				
			||||||
        await websocket.accept()
 | 
					        await websocket.accept()
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        async def send_output():
 | 
					            async def receive_input():
 | 
				
			||||||
            try:
 | 
					                while True:
 | 
				
			||||||
 | 
					                    if websocket.client_state == "DISCONNECTED":
 | 
				
			||||||
 | 
					                        break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    data = await websocket.receive()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if isinstance(data, bytes):
 | 
				
			||||||
 | 
					                        await interpreter.input(data)
 | 
				
			||||||
 | 
					                    elif "bytes" in data:
 | 
				
			||||||
 | 
					                        await interpreter.input(data["bytes"])
 | 
				
			||||||
 | 
					                        # print("RECEIVED INPUT", data)
 | 
				
			||||||
 | 
					                    elif "text" in data:
 | 
				
			||||||
 | 
					                        # print("RECEIVED INPUT", data)
 | 
				
			||||||
 | 
					                        await interpreter.input(data["text"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            async def send_output():
 | 
				
			||||||
                while True:
 | 
					                while True:
 | 
				
			||||||
                    output = await interpreter.output()
 | 
					                    output = await interpreter.output()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    if isinstance(output, bytes):
 | 
					                    if isinstance(output, bytes):
 | 
				
			||||||
                        try:
 | 
					                        # print(f"Sending {len(output)} bytes of audio data.")
 | 
				
			||||||
                            await websocket.send_bytes(output)
 | 
					                        await websocket.send_bytes(output)
 | 
				
			||||||
                        except Exception as e:
 | 
					                        # we dont send out bytes rn, no TTS
 | 
				
			||||||
                            print(f"Error: {e}")
 | 
					 | 
				
			||||||
                            traceback.print_exc()
 | 
					 | 
				
			||||||
                            return {"error": str(e)}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    elif isinstance(output, dict):
 | 
					                    elif isinstance(output, dict):
 | 
				
			||||||
                        try:
 | 
					                        # print("sending text")
 | 
				
			||||||
                            await websocket.send_text(json.dumps(output))
 | 
					                        await websocket.send_text(json.dumps(output))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        except Exception as e:
 | 
					            await asyncio.gather(send_output(), receive_input())
 | 
				
			||||||
                            print(f"Error: {e}")
 | 
					 | 
				
			||||||
                            traceback.print_exc()
 | 
					 | 
				
			||||||
                            return {"error": str(e)}
 | 
					 | 
				
			||||||
            except asyncio.CancelledError:
 | 
					 | 
				
			||||||
                print("WebSocket connection closed")
 | 
					 | 
				
			||||||
                traceback.print_exc()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async def receive_input():
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                while True:
 | 
					 | 
				
			||||||
                    # print("server awaiting input")
 | 
					 | 
				
			||||||
                    data = await websocket.receive()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    if isinstance(data, bytes):
 | 
					 | 
				
			||||||
                        try:
 | 
					 | 
				
			||||||
                            await interpreter.input(data)
 | 
					 | 
				
			||||||
                        except Exception as e:
 | 
					 | 
				
			||||||
                            print(f"Error: {e}")
 | 
					 | 
				
			||||||
                            traceback.print_exc()
 | 
					 | 
				
			||||||
                            return {"error": str(e)}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    elif "bytes" in data:
 | 
					 | 
				
			||||||
                        try:
 | 
					 | 
				
			||||||
                            await interpreter.input(data["bytes"])
 | 
					 | 
				
			||||||
                        except Exception as e:
 | 
					 | 
				
			||||||
                            print(f"Error: {e}")
 | 
					 | 
				
			||||||
                            traceback.print_exc()
 | 
					 | 
				
			||||||
                            return {"error": str(e)}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    elif "text" in data:
 | 
					 | 
				
			||||||
                        try:
 | 
					 | 
				
			||||||
                            await interpreter.input(data["text"])
 | 
					 | 
				
			||||||
                        except Exception as e:
 | 
					 | 
				
			||||||
                            print(f"Error: {e}")
 | 
					 | 
				
			||||||
                            traceback.print_exc()
 | 
					 | 
				
			||||||
                            return {"error": str(e)}
 | 
					 | 
				
			||||||
            except asyncio.CancelledError:
 | 
					 | 
				
			||||||
                print("WebSocket connection closed")
 | 
					 | 
				
			||||||
                traceback.print_exc()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            send_task = asyncio.create_task(send_output())
 | 
					 | 
				
			||||||
            receive_task = asyncio.create_task(receive_input())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            await asyncio.gather(send_task, receive_task)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        except WebSocketDisconnect:
 | 
					 | 
				
			||||||
            print("WebSocket disconnected")
 | 
					 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            print(f"WebSocket connection closed with exception: {e}")
 | 
					            print(f"WebSocket connection closed with exception: {e}")
 | 
				
			||||||
            traceback.print_exc()
 | 
					            traceback.print_exc()
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
            print("server closing ws connection")
 | 
					            if not websocket.client_state == "DISCONNECTED":
 | 
				
			||||||
            await websocket.close()
 | 
					                await websocket.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print(f"Starting server on {server_host}:{server_port}")
 | 
					    print(f"Starting server on {server_host}:{server_port}")
 | 
				
			||||||
    config = Config(app, host=server_host, port=server_port, lifespan="on")
 | 
					    config = Config(app, host=server_host, port=server_port, lifespan="on")
 | 
				
			||||||
| 
						 | 
					@ -140,4 +98,4 @@ async def main(server_host, server_port, tts_service, asynchronous):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    asyncio.run(main("localhost", 8000))
 | 
					    asyncio.run(main())
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -77,9 +77,6 @@ def run(
 | 
				
			||||||
    mobile: bool = typer.Option(
 | 
					    mobile: bool = typer.Option(
 | 
				
			||||||
        False, "--mobile", help="Toggle server to support mobile app"
 | 
					        False, "--mobile", help="Toggle server to support mobile app"
 | 
				
			||||||
    ),
 | 
					    ),
 | 
				
			||||||
    asynchronous: bool = typer.Option(
 | 
					 | 
				
			||||||
        False, "--async", help="use interpreter optimized for latency"
 | 
					 | 
				
			||||||
    ),
 | 
					 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    _run(
 | 
					    _run(
 | 
				
			||||||
        server=server or mobile,
 | 
					        server=server or mobile,
 | 
				
			||||||
| 
						 | 
					@ -102,7 +99,6 @@ def run(
 | 
				
			||||||
        local=local,
 | 
					        local=local,
 | 
				
			||||||
        qr=qr or mobile,
 | 
					        qr=qr or mobile,
 | 
				
			||||||
        mobile=mobile,
 | 
					        mobile=mobile,
 | 
				
			||||||
        asynchronous=asynchronous,
 | 
					 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -127,7 +123,6 @@ def _run(
 | 
				
			||||||
    local: bool = False,
 | 
					    local: bool = False,
 | 
				
			||||||
    qr: bool = False,
 | 
					    qr: bool = False,
 | 
				
			||||||
    mobile: bool = False,
 | 
					    mobile: bool = False,
 | 
				
			||||||
    asynchronous: bool = False,
 | 
					 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    if local:
 | 
					    if local:
 | 
				
			||||||
        tts_service = "coqui"
 | 
					        tts_service = "coqui"
 | 
				
			||||||
| 
						 | 
					@ -162,7 +157,6 @@ def _run(
 | 
				
			||||||
                    server_host,
 | 
					                    server_host,
 | 
				
			||||||
                    server_port,
 | 
					                    server_port,
 | 
				
			||||||
                    tts_service,
 | 
					                    tts_service,
 | 
				
			||||||
                    asynchronous,
 | 
					 | 
				
			||||||
                    # llm_service,
 | 
					                    # llm_service,
 | 
				
			||||||
                    # model,
 | 
					                    # model,
 | 
				
			||||||
                    # llm_supports_vision,
 | 
					                    # llm_supports_vision,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue