From c0ed82c46543734af0cc2fd6d48640c58d42e9e6 Mon Sep 17 00:00:00 2001 From: "James C. Palmer" Date: Sat, 23 Mar 2024 15:00:03 -0700 Subject: [PATCH] Add app configuration management - Integrate Pydantic and Pydantic Settings for config models and validation - Add `config.yaml` for structured and readable configuration - Centralize environment variable loading and configuration instance --- software/config.yaml | 33 ++++++ software/source/__init__.py | 6 + software/source/core/__init__.py | 0 software/source/core/models.py | 117 ++++++++++++++++++++ software/source/core/utils.py | 99 +++++++++++++++++ software/start.py | 184 +++++++++---------------------- 6 files changed, 307 insertions(+), 132 deletions(-) create mode 100644 software/config.yaml create mode 100644 software/source/core/__init__.py create mode 100644 software/source/core/models.py create mode 100644 software/source/core/utils.py diff --git a/software/config.yaml b/software/config.yaml new file mode 100644 index 0000000..7e6c3c6 --- /dev/null +++ b/software/config.yaml @@ -0,0 +1,33 @@ +client: + enabled: false + url: null + platform: auto + +llm: + service: litellm + model: gpt-4 + vision_enabled: false + functions_enabled: false + context_window: 2048 + max_tokens: 4096 + temperature: 0.8 + +local: + enabled: true + tts_service: piper + stt_service: local-whisper + +server: + enabled: false + host: 0.0.0.0 + port: 10001 + +stt: + service: openai + +tts: + service: openai + +tunnel: + service: ngrok + exposed: false diff --git a/software/source/__init__.py b/software/source/__init__.py index e69de29..b6bac92 100644 --- a/software/source/__init__.py +++ b/software/source/__init__.py @@ -0,0 +1,6 @@ +from dotenv import load_dotenv +from source.core.models import Config + +load_dotenv() + +config = Config() diff --git a/software/source/core/__init__.py b/software/source/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/software/source/core/models.py b/software/source/core/models.py new file mode 100644 index 0000000..419d8d1 --- /dev/null +++ b/software/source/core/models.py @@ -0,0 +1,117 @@ +""" +Application configuration models. +""" + +from pydantic import BaseModel +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, + YamlConfigSettingsSource, +) + +APP_PREFIX = "01_" + + +class Client(BaseModel): + """ + Client configuration model + """ + + enabled: bool = False + url: None | str = None + platform: str = "auto" + + +class LLM(BaseModel): + """ + LLM configuration model + """ + + service: str = "litellm" + model: str = "gpt-4" + vision_enabled: bool = False + functions_enabled: bool = False + context_window: int = 2048 + max_tokens: int = 4096 + temperature: float = 0.8 + + +class Local(BaseModel): + """ + Local configuration model + """ + + enabled: bool = False + tts_service: str = "piper" + stt_service: str = "local-whisper" + + +class Server(BaseModel): + """ + Server configuration model + """ + + enabled: bool = False + host: str = "0.0.0.0" + port: int = 10001 + + +class STT(BaseModel): + """ + Speech-to-text configuration model + """ + + service: str = "openai" + + +class TTS(BaseModel): + """ + Text-to-speech configuration model + """ + + service: str = "openai" + + +class Tunnel(BaseModel): + """ + Tunnel configuration model + """ + + service: str = "ngrok" + exposed: bool = False + + +class Config(BaseSettings): + """ + Base configuration model + """ + + client: Client = Client() + llm: LLM = LLM() + local: Local = Local() + server: Server = Server() + stt: STT = STT() + tts: TTS = TTS() + tunnel: Tunnel = Tunnel() + + model_config = SettingsConfigDict( + env_prefix=APP_PREFIX, + env_file=".env", + env_file_encoding="utf-8", + yaml_file="config.yaml", + ) + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + """ + Modify the order of precedence for settings sources. + """ + return (YamlConfigSettingsSource(settings_cls),) diff --git a/software/source/core/utils.py b/software/source/core/utils.py new file mode 100644 index 0000000..bc09916 --- /dev/null +++ b/software/source/core/utils.py @@ -0,0 +1,99 @@ +""" +Core utilty functions for the server and client +""" + +import asyncio +import importlib +import os +import platform +from threading import Thread +from typing import NoReturn + +from source.server.server import main +from source.server.tunnel import create_tunnel + + +def get_client_platform(config) -> None: + """ + Returns the client platform based on the system type. + """ + if config.client.platform == "auto": + system_type: str = platform.system() + + # macOS + if system_type == "Darwin": + config.client.platform = "mac" + + # Linux + elif system_type == "Linux": + try: + with open("/proc/device-tree/model", "r", encoding="utf-8") as m: + if "raspberry pi" in m.read().lower(): + config.client.platform = "rpi" + else: + config.client.platform = "linux" + except FileNotFoundError: + config.client.platform = "linux" + + +def handle_exit(signum, frame) -> NoReturn: # pylint: disable=unused-argument + """ + Handle exit signal. + """ + os._exit(0) + + +def start_client(config) -> Thread: + """ + Start the client. + """ + module = importlib.import_module( + f".clients.{config.client.platform}.device", package="source" + ) + + client_thread = Thread(target=module.main, args=[config.client.url]) + client_thread.start() + return client_thread + + +def start_server(config) -> Thread: + """ + Start the server. + """ + loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + server_thread = Thread( + target=loop.run_until_complete, + args=( + main( + config.server.host, + config.server.port, + config.llm.service, + config.llm.model, + config.llm.vision_enabled, + config.llm.functions_enabled, + config.llm.context_window, + config.llm.max_tokens, + config.llm.temperature, + config.tts.service, + config.stt.service, + ), + ), + ) + + server_thread.start() + return server_thread + + +def start_tunnel(config) -> Thread: + """ + Start the tunnel. + """ + tunnel_thread = Thread( + target=create_tunnel, + args=[config.tunnel.service, config.server.host, config.server.port], + ) + + tunnel_thread.start() + return tunnel_thread diff --git a/software/start.py b/software/start.py index 70088e4..10782ec 100644 --- a/software/start.py +++ b/software/start.py @@ -1,148 +1,68 @@ -import typer -import asyncio -import platform -import concurrent.futures -import threading -import os -import importlib -from source.server.tunnel import create_tunnel -from source.server.server import main -from source.server.utils.local_mode import select_local_model +""" +Application entry point +""" import signal -app = typer.Typer() +from threading import Thread -@app.command() -def run( - server: bool = typer.Option(False, "--server", help="Run server"), - server_host: str = typer.Option("0.0.0.0", "--server-host", help="Specify the server host where the server will deploy"), - server_port: int = typer.Option(10001, "--server-port", help="Specify the server port where the server will deploy"), - - tunnel_service: str = typer.Option("ngrok", "--tunnel-service", help="Specify the tunnel service"), - expose: bool = typer.Option(False, "--expose", help="Expose server to internet"), - - client: bool = typer.Option(False, "--client", help="Run client"), - server_url: str = typer.Option(None, "--server-url", help="Specify the server URL that the client should expect. Defaults to server-host and server-port"), - client_type: str = typer.Option("auto", "--client-type", help="Specify the client type"), - - llm_service: str = typer.Option("litellm", "--llm-service", help="Specify the LLM service"), - - model: str = typer.Option("gpt-4", "--model", help="Specify the model"), - llm_supports_vision: bool = typer.Option(False, "--llm-supports-vision", help="Specify if the LLM service supports vision"), - llm_supports_functions: bool = typer.Option(False, "--llm-supports-functions", help="Specify if the LLM service supports functions"), - context_window: int = typer.Option(2048, "--context-window", help="Specify the context window size"), - max_tokens: int = typer.Option(4096, "--max-tokens", help="Specify the maximum number of tokens"), - temperature: float = typer.Option(0.8, "--temperature", help="Specify the temperature for generation"), - - tts_service: str = typer.Option("openai", "--tts-service", help="Specify the TTS service"), - - stt_service: str = typer.Option("openai", "--stt-service", help="Specify the STT service"), +from source import config +from source.core.utils import ( + get_client_platform, + handle_exit, + start_client, + start_server, + start_tunnel, +) +from source.server.utils.local_mode import select_local_model - local: bool = typer.Option(False, "--local", help="Use recommended local services for LLM, STT, and TTS"), - ): - - _run( - server=server, - server_host=server_host, - server_port=server_port, - tunnel_service=tunnel_service, - expose=expose, - client=client, - server_url=server_url, - client_type=client_type, - llm_service=llm_service, - model=model, - llm_supports_vision=llm_supports_vision, - llm_supports_functions=llm_supports_functions, - context_window=context_window, - max_tokens=max_tokens, - temperature=temperature, - tts_service=tts_service, - stt_service=stt_service, - local=local - ) - -def _run( - server: bool = False, - server_host: str = "0.0.0.0", - server_port: int = 10001, - - tunnel_service: str = "bore", - expose: bool = False, - - client: bool = False, - server_url: str = None, - client_type: str = "auto", - - llm_service: str = "litellm", - - model: str = "gpt-4", - llm_supports_vision: bool = False, - llm_supports_functions: bool = False, - context_window: int = 2048, - max_tokens: int = 4096, - temperature: float = 0.8, - - tts_service: str = "openai", - - stt_service: str = "openai", - - local: bool = False - ): - - if local: - tts_service = "piper" - # llm_service = "llamafile" - stt_service = "local-whisper" - select_local_model() - - if not server_url: - server_url = f"{server_host}:{server_port}" - - if not server and not client: - server = True - client = True - - def handle_exit(signum, frame): - os._exit(0) +def run() -> None: + """ + Run the application. + """ + # Set up signal handler for SIGINT (keyboard interrupt) signal.signal(signal.SIGINT, handle_exit) - if server: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - server_thread = threading.Thread(target=loop.run_until_complete, args=(main(server_host, server_port, llm_service, model, llm_supports_vision, llm_supports_functions, context_window, max_tokens, temperature, tts_service, stt_service),)) - server_thread.start() + # If platform is set to auto, determine user's platform automatically. + if config.client.platform == "auto": + get_client_platform(config) - if expose: - tunnel_thread = threading.Thread(target=create_tunnel, args=[tunnel_service, server_host, server_port]) - tunnel_thread.start() + # If local mode is enabled, set up local services + if config.local.enabled: + config.tts.service = config.local.tts_service + config.stt.service = config.local.stt_service + select_local_model() - if client: - if client_type == "auto": - system_type = platform.system() - if system_type == "Darwin": # Mac OS - client_type = "mac" - elif system_type == "Linux": # Linux System - try: - with open('/proc/device-tree/model', 'r') as m: - if 'raspberry pi' in m.read().lower(): - client_type = "rpi" - else: - client_type = "linux" - except FileNotFoundError: - client_type = "linux" + # If no client URL is provided, set one using server host and port. + config.client.url = ( + config.client.url or f"{config.server.host}:{config.server.port}" + ) - module = importlib.import_module(f".clients.{client_type}.device", package='source') - client_thread = threading.Thread(target=module.main, args=[server_url]) - client_thread.start() + if not config.server.enabled and not config.client.enabled: + config.server.enabled = config.client.enabled = True + + server_thread: Thread | None = ( + start_server(config) if config.server.enabled else None + ) + + tunnel_thread: Thread | None = ( + start_tunnel(config) if config.tunnel.exposed else None + ) + + client_thread: Thread | None = ( + start_client(config) if config.client.enabled else None + ) try: - if server: + if server_thread: server_thread.join() - if expose: + if tunnel_thread: tunnel_thread.join() - if client: + if client_thread and client_thread.is_alive(): client_thread.join() except KeyboardInterrupt: - os.kill(os.getpid(), signal.SIGINT) \ No newline at end of file + handle_exit(signal.SIGINT, None) + + +if __name__ == "__main__": + run()