diff --git a/software/config.yaml b/software/config.yaml index 7e6c3c6..c647beb 100644 --- a/software/config.yaml +++ b/software/config.yaml @@ -1,7 +1,7 @@ client: enabled: false url: null - platform: auto + platform: '' llm: service: litellm @@ -13,7 +13,7 @@ llm: temperature: 0.8 local: - enabled: true + enabled: false tts_service: piper stt_service: local-whisper diff --git a/software/pyproject.toml b/software/pyproject.toml index e85bad0..8ff54ac 100644 --- a/software/pyproject.toml +++ b/software/pyproject.toml @@ -51,6 +51,7 @@ pytest = "^8.1.1" target-version = ['py311'] [tool.isort] -profile = "black" -multi_line_output = 3 include_trailing_comma = true +known_first_party = ["source"] +multi_line_output = 3 +profile = "black" diff --git a/software/source/__init__.py b/software/source/__init__.py index b6bac92..64ec270 100644 --- a/software/source/__init__.py +++ b/software/source/__init__.py @@ -1,6 +1,11 @@ +""" +Loads environment variables and creates a global configuration object. +""" + from dotenv import load_dotenv -from source.core.models import Config + +from source.core.config import Config, get_config load_dotenv() -config = Config() +config: Config = get_config() diff --git a/software/source/core/config.py b/software/source/core/config.py new file mode 100644 index 0000000..1444031 --- /dev/null +++ b/software/source/core/config.py @@ -0,0 +1,110 @@ +""" +Application configuration model. +""" + +import os +from functools import lru_cache +from typing import Any + +from pydantic_settings import ( + BaseSettings, + DotEnvSettingsSource, + PydanticBaseSettingsSource, + SettingsConfigDict, + YamlConfigSettingsSource, +) + +from source.core.models import LLM, STT, TTS, Client, Local, Server, Tunnel + +APP_PREFIX: str = os.getenv("01_PREFIX", "01_") + + +class Config(BaseSettings): + """ + Base configuration model + """ + + client: Client = Client() + llm: LLM = LLM() + local: Local = Local() + server: Server = Server() + stt: STT = STT() + tts: TTS = TTS() + tunnel: Tunnel = Tunnel() + + model_config = SettingsConfigDict(extra="allow") + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + """ + Modify the order of precedence for settings sources. + """ + return ( + DotEnvSettingsSource( + settings_cls, + env_prefix=APP_PREFIX, + env_file=".env", + env_file_encoding="utf-8", + env_nested_delimiter="_", + ), + YamlConfigSettingsSource( + settings_cls, + yaml_file=os.getenv(f"{APP_PREFIX}CONFIG_FILE", "config.yaml"), + ), + ) + + def apply_cli_args(self, args: dict) -> None: + """ + Apply CLI arguments to config. + """ + mapping: dict[str, str] = { + "server": "server.enabled", + "server_host": "server.host", + "server_port": "server.port", + "tunnel_service": "tunnel.service", + "expose": "tunnel.exposed", + "client": "client.enabled", + "server_url": "client.url", + "client_type": "client.platform", + "llm_service": "llm.service", + "model": "llm.model", + "llm_supports_vision": "llm.vision_enabled", + "llm_supports_functions": "llm.functions_enabled", + "context_window": "llm.context_window", + "max_tokens": "llm.max_tokens", + "temperature": "llm.temperature", + "tts_service": "tts.service", + "stt_service": "stt.service", + "local": "local.enabled", + } + + for key, path in mapping.items(): + if key in args and args[key] is not None: + self.set_field(path, args[key]) + + def set_field(self, field: str, value: Any) -> None: + """ + Set field value + """ + obj: Any = self + parts: list[str] = field.split(".") + + for part in parts[:-1]: + obj: Any = getattr(obj, part) + + setattr(obj, parts[-1], value) + + +@lru_cache() +def get_config() -> Config: + """ + Return the application configuration. + """ + return Config() diff --git a/software/source/core/models.py b/software/source/core/models.py index 10b871f..e715557 100644 --- a/software/source/core/models.py +++ b/software/source/core/models.py @@ -1,17 +1,8 @@ """ -Application configuration models. +Application models. """ -import os - from pydantic import BaseModel -from pydantic_settings import ( - BaseSettings, - PydanticBaseSettingsSource, - YamlConfigSettingsSource, -) - -APP_PREFIX: str = os.getenv("01_APP_PREFIX", "01_") class Client(BaseModel): @@ -20,8 +11,8 @@ class Client(BaseModel): """ enabled: bool = False - url: None | str = None - platform: str = "auto" + url: str | None = None + platform: str | None = None class LLM(BaseModel): @@ -44,8 +35,6 @@ class Local(BaseModel): """ enabled: bool = False - tts_service: str = "piper" - stt_service: str = "local-whisper" class Server(BaseModel): @@ -81,36 +70,3 @@ class Tunnel(BaseModel): service: str = "ngrok" exposed: bool = False - - -class Config(BaseSettings): - """ - Base configuration model - """ - - client: Client = Client() - llm: LLM = LLM() - local: Local = Local() - server: Server = Server() - stt: STT = STT() - tts: TTS = TTS() - tunnel: Tunnel = Tunnel() - - @classmethod - def settings_customise_sources( - cls, - settings_cls: type[BaseSettings], - init_settings: PydanticBaseSettingsSource, - env_settings: PydanticBaseSettingsSource, - dotenv_settings: PydanticBaseSettingsSource, - file_secret_settings: PydanticBaseSettingsSource, - ) -> tuple[PydanticBaseSettingsSource, ...]: - """ - Modify the order of precedence for settings sources. - """ - return ( - YamlConfigSettingsSource( - settings_cls, - yaml_file=os.getenv(f"{APP_PREFIX}CONFIG_FILE", "config.yaml"), - ), - ) diff --git a/software/source/core/utils.py b/software/source/core/utils.py deleted file mode 100644 index bc09916..0000000 --- a/software/source/core/utils.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Core utilty functions for the server and client -""" - -import asyncio -import importlib -import os -import platform -from threading import Thread -from typing import NoReturn - -from source.server.server import main -from source.server.tunnel import create_tunnel - - -def get_client_platform(config) -> None: - """ - Returns the client platform based on the system type. - """ - if config.client.platform == "auto": - system_type: str = platform.system() - - # macOS - if system_type == "Darwin": - config.client.platform = "mac" - - # Linux - elif system_type == "Linux": - try: - with open("/proc/device-tree/model", "r", encoding="utf-8") as m: - if "raspberry pi" in m.read().lower(): - config.client.platform = "rpi" - else: - config.client.platform = "linux" - except FileNotFoundError: - config.client.platform = "linux" - - -def handle_exit(signum, frame) -> NoReturn: # pylint: disable=unused-argument - """ - Handle exit signal. - """ - os._exit(0) - - -def start_client(config) -> Thread: - """ - Start the client. - """ - module = importlib.import_module( - f".clients.{config.client.platform}.device", package="source" - ) - - client_thread = Thread(target=module.main, args=[config.client.url]) - client_thread.start() - return client_thread - - -def start_server(config) -> Thread: - """ - Start the server. - """ - loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - server_thread = Thread( - target=loop.run_until_complete, - args=( - main( - config.server.host, - config.server.port, - config.llm.service, - config.llm.model, - config.llm.vision_enabled, - config.llm.functions_enabled, - config.llm.context_window, - config.llm.max_tokens, - config.llm.temperature, - config.tts.service, - config.stt.service, - ), - ), - ) - - server_thread.start() - return server_thread - - -def start_tunnel(config) -> Thread: - """ - Start the tunnel. - """ - tunnel_thread = Thread( - target=create_tunnel, - args=[config.tunnel.service, config.server.host, config.server.port], - ) - - tunnel_thread.start() - return tunnel_thread diff --git a/software/source/server/i.py b/software/source/server/i.py index e14d6f9..3fe01fd 100644 --- a/software/source/server/i.py +++ b/software/source/server/i.py @@ -9,6 +9,8 @@ from pathlib import Path from interpreter import OpenInterpreter import shutil +from source import config +from source.core.config import APP_PREFIX system_message = r""" @@ -189,7 +191,7 @@ def configure_interpreter(interpreter: OpenInterpreter): interpreter.llm.supports_vision = True interpreter.shrink_images = True # Faster but less accurate - interpreter.llm.model = "gpt-4" + interpreter.llm.model = config.llm.model interpreter.llm.supports_functions = False interpreter.llm.context_window = 110000 diff --git a/software/source/utils/system.py b/software/source/utils/system.py new file mode 100644 index 0000000..9ae2b78 --- /dev/null +++ b/software/source/utils/system.py @@ -0,0 +1,12 @@ +""" +System utility functions +""" + +import os + + +def handle_exit(signum, frame) -> None: # pylint: disable=unused-argument + """ + Handle exit signal. + """ + os._exit(0) diff --git a/software/start.py b/software/start.py index 10782ec..01175bb 100644 --- a/software/start.py +++ b/software/start.py @@ -2,67 +2,164 @@ Application entry point """ +import asyncio +import importlib +import os +import platform import signal -from threading import Thread +import threading + +import typer from source import config -from source.core.utils import ( - get_client_platform, - handle_exit, - start_client, - start_server, - start_tunnel, -) from source.server.utils.local_mode import select_local_model +from source.utils.system import handle_exit + +app = typer.Typer() -def run() -> None: +@app.command() +def start( + ctx: typer.Context, + server: bool = typer.Option(False, "--server", help="Run server"), + server_host: str = typer.Option( + "0.0.0.0", + "--server-host", + help="Specify the server host where the server will deploy", + ), + server_port: int = typer.Option( + 10001, + "--server-port", + help="Specify the server port where the server will deploy", + ), + tunnel_service: str = typer.Option( + "ngrok", "--tunnel-service", help="Specify the tunnel service" + ), + expose: bool = typer.Option(False, "--expose", help="Expose server to internet"), + client: bool = typer.Option(False, "--client", help="Run client"), + server_url: str = typer.Option( + None, + "--server-url", + help="Specify the server URL that the client should expect. Defaults to server-host and server-port", + ), + client_type: str = typer.Option( + "auto", "--client-type", help="Specify the client type" + ), + llm_service: str = typer.Option( + "litellm", "--llm-service", help="Specify the LLM service" + ), + model: str = typer.Option(None, "--model", help="Specify the model"), + llm_supports_vision: bool = typer.Option( + False, + "--llm-supports-vision", + help="Specify if the LLM service supports vision", + ), + llm_supports_functions: bool = typer.Option( + False, + "--llm-supports-functions", + help="Specify if the LLM service supports functions", + ), + context_window: int = typer.Option( + 2048, "--context-window", help="Specify the context window size" + ), + max_tokens: int = typer.Option( + 4096, "--max-tokens", help="Specify the maximum number of tokens" + ), + temperature: float = typer.Option( + 0.8, "--temperature", help="Specify the temperature for generation" + ), + tts_service: str = typer.Option( + "openai", "--tts-service", help="Specify the TTS service" + ), + stt_service: str = typer.Option( + "openai", "--stt-service", help="Specify the STT service" + ), + local: bool = typer.Option( + False, "--local", help="Use recommended local services for LLM, STT, and TTS" + ), +) -> None: """ - Run the application. + Setup the application. """ - # Set up signal handler for SIGINT (keyboard interrupt) signal.signal(signal.SIGINT, handle_exit) + config.apply_cli_args(ctx.params) - # If platform is set to auto, determine user's platform automatically. - if config.client.platform == "auto": - get_client_platform(config) - - # If local mode is enabled, set up local services if config.local.enabled: - config.tts.service = config.local.tts_service - config.stt.service = config.local.stt_service - select_local_model() + config.tts.service = "piper" + config.stt.service = "local-whisper" - # If no client URL is provided, set one using server host and port. - config.client.url = ( - config.client.url or f"{config.server.host}:{config.server.port}" - ) + if not model: + select_local_model() + + if not server_url: + server_url = f"{config.server.host}:{config.server.port}" if not config.server.enabled and not config.client.enabled: - config.server.enabled = config.client.enabled = True + config.server.enabled = True + config.client.enabled = True - server_thread: Thread | None = ( - start_server(config) if config.server.enabled else None - ) + # Temporary fix pending refactor of `server` and `tunnel` modules. + # Prevents early execution of top-level code until config is fully initialized. + server_module = importlib.import_module("source.server.server") + tunnel_module = importlib.import_module("source.server.tunnel") - tunnel_thread: Thread | None = ( - start_tunnel(config) if config.tunnel.exposed else None - ) + if config.server.enabled: + loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + server_thread = threading.Thread( + target=loop.run_until_complete, + args=( + server_module.main( + config.server.host, + config.server.port, + config.llm.service, + config.llm.model, + config.llm.vision_enabled, + config.llm.functions_enabled, + config.llm.context_window, + config.llm.max_tokens, + config.llm.temperature, + config.tts.service, + config.stt.service, + ), + ), + ) + server_thread.start() - client_thread: Thread | None = ( - start_client(config) if config.client.enabled else None - ) + if config.tunnel.exposed: + tunnel_thread = threading.Thread( + target=tunnel_module.create_tunnel, + args=[config.tunnel.service, config.server.host, config.server.port], + ) + tunnel_thread.start() + + if config.client.enabled: + if config.client.platform == "auto": + system: str = platform.system() + if system == "Darwin": # macOS + config.client.platform = "mac" + elif system == "Linux": + try: + with open("/proc/device-tree/model", "r", encoding="utf-8") as m: + if "raspberry pi" in m.read().lower(): + config.client.platform = "rpi" + else: + config.client.platform = "linux" + except FileNotFoundError: + config.client.platform = "linux" + + module = importlib.import_module( + f".clients.{config.client.platform}.device", package="source" + ) + client_thread = threading.Thread(target=module.main, args=[server_url]) + client_thread.start() try: - if server_thread: + if config.server.enabled: server_thread.join() - if tunnel_thread: + if config.tunnel.exposed: tunnel_thread.join() - if client_thread and client_thread.is_alive(): + if config.client.enabled: client_thread.join() except KeyboardInterrupt: - handle_exit(signal.SIGINT, None) - - -if __name__ == "__main__": - run() + os.kill(os.getpid(), signal.SIGINT)