Implemented `profiles`
This commit is contained in:
		
							parent
							
								
									632af7f7ba
								
							
						
					
					
						commit
						fda23e95b2
					
				| 
						 | 
					@ -1,14 +1,3 @@
 | 
				
			||||||
# import from the profiles directory the interpreter to be served
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# add other profiles to the directory to define other interpreter instances and import them here
 | 
					 | 
				
			||||||
# {.profiles.fast: optimizes for STT/TTS latency with the fastest models }
 | 
					 | 
				
			||||||
# {.profiles.local: uses local models and local STT/TTS }
 | 
					 | 
				
			||||||
# {.profiles.default: uses default interpreter settings with optimized TTS latency }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# from .profiles.fast import interpreter as base_interpreter
 | 
					 | 
				
			||||||
# from .profiles.local import interpreter as base_interpreter
 | 
					 | 
				
			||||||
from .profiles.default import interpreter as base_interpreter
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
| 
						 | 
					@ -19,6 +8,7 @@ from .async_interpreter import AsyncInterpreter
 | 
				
			||||||
from fastapi.middleware.cors import CORSMiddleware
 | 
					from fastapi.middleware.cors import CORSMiddleware
 | 
				
			||||||
from typing import List, Dict, Any
 | 
					from typing import List, Dict, Any
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					import importlib.util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
os.environ["STT_RUNNER"] = "server"
 | 
					os.environ["STT_RUNNER"] = "server"
 | 
				
			||||||
os.environ["TTS_RUNNER"] = "server"
 | 
					os.environ["TTS_RUNNER"] = "server"
 | 
				
			||||||
| 
						 | 
					@ -50,14 +40,6 @@ async def websocket_endpoint(
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    await websocket.accept()
 | 
					    await websocket.accept()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # interpreter.tts set in the profiles directory!!!!
 | 
					 | 
				
			||||||
    interpreter = AsyncInterpreter(base_interpreter, debug)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Send the tts_service value to the client
 | 
					 | 
				
			||||||
    await websocket.send_text(
 | 
					 | 
				
			||||||
        json.dumps({"type": "config", "tts_service": interpreter.interpreter.tts})
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        async def receive_input():
 | 
					        async def receive_input():
 | 
				
			||||||
| 
						 | 
					@ -98,9 +80,21 @@ async def websocket_endpoint(
 | 
				
			||||||
            await websocket.close()
 | 
					            await websocket.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def main(server_host, server_port, debug):
 | 
					async def main(server_host, server_port, profile, debug):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    app.state.debug = debug
 | 
					    app.state.debug = debug
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Load the profile module from the provided path
 | 
				
			||||||
 | 
					    spec = importlib.util.spec_from_file_location("profile", profile)
 | 
				
			||||||
 | 
					    profile_module = importlib.util.module_from_spec(spec)
 | 
				
			||||||
 | 
					    spec.loader.exec_module(profile_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Get the interpreter from the profile
 | 
				
			||||||
 | 
					    interpreter = profile_module.interpreter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Make it async
 | 
				
			||||||
 | 
					    interpreter = AsyncInterpreter(interpreter, debug)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print(f"Starting server on {server_host}:{server_port}")
 | 
					    print(f"Starting server on {server_host}:{server_port}")
 | 
				
			||||||
    config = Config(app, host=server_host, port=server_port, lifespan="on")
 | 
					    config = Config(app, host=server_host, port=server_port, lifespan="on")
 | 
				
			||||||
    server = Server(config)
 | 
					    server = Server(config)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -6,7 +6,7 @@ from ..utils.print_markdown import print_markdown
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_tunnel(
 | 
					def create_tunnel(
 | 
				
			||||||
    tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False
 | 
					    tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False, domain=None
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    print_markdown("Exposing server to the internet...")
 | 
					    print_markdown("Exposing server to the internet...")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -99,8 +99,13 @@ def create_tunnel(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # If ngrok is installed, start it on the specified port
 | 
					        # If ngrok is installed, start it on the specified port
 | 
				
			||||||
        # process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE)
 | 
					        # process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if domain:
 | 
				
			||||||
 | 
					            domain = f"--domain={domain}"
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            domain = ""
 | 
				
			||||||
        process = subprocess.Popen(
 | 
					        process = subprocess.Popen(
 | 
				
			||||||
            f"ngrok http {server_port} --scheme http,https  --log=stdout",
 | 
					            f"ngrok http {server_port} --scheme http,https {domain} --log=stdout",
 | 
				
			||||||
            shell=True,
 | 
					            shell=True,
 | 
				
			||||||
            stdout=subprocess.PIPE,
 | 
					            stdout=subprocess.PIPE,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -6,6 +6,7 @@ import os
 | 
				
			||||||
import importlib
 | 
					import importlib
 | 
				
			||||||
from source.server.tunnel import create_tunnel
 | 
					from source.server.tunnel import create_tunnel
 | 
				
			||||||
from source.server.async_server import main
 | 
					from source.server.async_server import main
 | 
				
			||||||
 | 
					import subprocess
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import signal
 | 
					import signal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -41,11 +42,25 @@ def run(
 | 
				
			||||||
    qr: bool = typer.Option(
 | 
					    qr: bool = typer.Option(
 | 
				
			||||||
        False, "--qr", help="Display QR code to scan to connect to the server"
 | 
					        False, "--qr", help="Display QR code to scan to connect to the server"
 | 
				
			||||||
    ),
 | 
					    ),
 | 
				
			||||||
 | 
					    domain: str = typer.Option(
 | 
				
			||||||
 | 
					        None, "--domain", help="Connect ngrok to a custom domain"
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					    profiles: bool = typer.Option(
 | 
				
			||||||
 | 
					        False,
 | 
				
			||||||
 | 
					        "--profiles",
 | 
				
			||||||
 | 
					        help="Opens the folder where this script is contained",
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					    profile: str = typer.Option(
 | 
				
			||||||
 | 
					        "default.py", # default
 | 
				
			||||||
 | 
					        "--profile",
 | 
				
			||||||
 | 
					        help="Specify the path to the profile, or the name of the file if it's in the `profiles` directory (run `--profiles` to open the profiles directory)",
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
    debug: bool = typer.Option(
 | 
					    debug: bool = typer.Option(
 | 
				
			||||||
        False,
 | 
					        False,
 | 
				
			||||||
        "--debug",
 | 
					        "--debug",
 | 
				
			||||||
        help="Print latency measurements and save microphone recordings locally for manual playback.",
 | 
					        help="Print latency measurements and save microphone recordings locally for manual playback.",
 | 
				
			||||||
    ),
 | 
					    ),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    _run(
 | 
					    _run(
 | 
				
			||||||
        server=server,
 | 
					        server=server,
 | 
				
			||||||
| 
						 | 
					@ -58,6 +73,9 @@ def run(
 | 
				
			||||||
        client_type=client_type,
 | 
					        client_type=client_type,
 | 
				
			||||||
        qr=qr,
 | 
					        qr=qr,
 | 
				
			||||||
        debug=debug,
 | 
					        debug=debug,
 | 
				
			||||||
 | 
					        domain=domain,
 | 
				
			||||||
 | 
					        profiles=profiles,
 | 
				
			||||||
 | 
					        profile=profile,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -72,8 +90,33 @@ def _run(
 | 
				
			||||||
    client_type: str = "auto",
 | 
					    client_type: str = "auto",
 | 
				
			||||||
    qr: bool = False,
 | 
					    qr: bool = False,
 | 
				
			||||||
    debug: bool = False,
 | 
					    debug: bool = False,
 | 
				
			||||||
 | 
					    domain = None,
 | 
				
			||||||
 | 
					    profiles = None,
 | 
				
			||||||
 | 
					    profile = None,
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    profiles_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "source", "server", "profiles")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if profiles:
 | 
				
			||||||
 | 
					        if platform.system() == "Windows":
 | 
				
			||||||
 | 
					            subprocess.Popen(['explorer', profiles_dir])
 | 
				
			||||||
 | 
					        elif platform.system() == "Darwin":
 | 
				
			||||||
 | 
					            subprocess.Popen(['open', profiles_dir])
 | 
				
			||||||
 | 
					        elif platform.system() == "Linux":
 | 
				
			||||||
 | 
					            subprocess.Popen(['xdg-open', profiles_dir])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            subprocess.Popen(['open', profiles_dir])
 | 
				
			||||||
 | 
					        exit(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if profile:
 | 
				
			||||||
 | 
					        if not os.path.isfile(profile):
 | 
				
			||||||
 | 
					            profile = os.path.join(profiles_dir, profile)
 | 
				
			||||||
 | 
					            if not os.path.isfile(profile):
 | 
				
			||||||
 | 
					                profile += ".py"
 | 
				
			||||||
 | 
					                if not os.path.isfile(profile):
 | 
				
			||||||
 | 
					                    print(f"Invalid profile path: {profile}")
 | 
				
			||||||
 | 
					                    exit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    system_type = platform.system()
 | 
					    system_type = platform.system()
 | 
				
			||||||
    if system_type == "Windows":
 | 
					    if system_type == "Windows":
 | 
				
			||||||
        server_host = "localhost"
 | 
					        server_host = "localhost"
 | 
				
			||||||
| 
						 | 
					@ -91,7 +134,6 @@ def _run(
 | 
				
			||||||
    signal.signal(signal.SIGINT, handle_exit)
 | 
					    signal.signal(signal.SIGINT, handle_exit)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if server:
 | 
					    if server:
 | 
				
			||||||
        # print(f"Starting server with mobile = {mobile}")
 | 
					 | 
				
			||||||
        loop = asyncio.new_event_loop()
 | 
					        loop = asyncio.new_event_loop()
 | 
				
			||||||
        asyncio.set_event_loop(loop)
 | 
					        asyncio.set_event_loop(loop)
 | 
				
			||||||
        server_thread = threading.Thread(
 | 
					        server_thread = threading.Thread(
 | 
				
			||||||
| 
						 | 
					@ -100,6 +142,7 @@ def _run(
 | 
				
			||||||
                main(
 | 
					                main(
 | 
				
			||||||
                    server_host,
 | 
					                    server_host,
 | 
				
			||||||
                    server_port,
 | 
					                    server_port,
 | 
				
			||||||
 | 
					                    profile,
 | 
				
			||||||
                    debug,
 | 
					                    debug,
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
            ),
 | 
					            ),
 | 
				
			||||||
| 
						 | 
					@ -108,7 +151,7 @@ def _run(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if expose:
 | 
					    if expose:
 | 
				
			||||||
        tunnel_thread = threading.Thread(
 | 
					        tunnel_thread = threading.Thread(
 | 
				
			||||||
            target=create_tunnel, args=[tunnel_service, server_host, server_port, qr]
 | 
					            target=create_tunnel, args=[tunnel_service, server_host, server_port, qr, domain]
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        tunnel_thread.start()
 | 
					        tunnel_thread.start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue