Add global `config`, support for `.env` and apply cli args to `config` at start.

This commit is contained in:
James C. Palmer 2024-03-23 17:42:39 -07:00
parent 94edb8e001
commit fe2a343c41
9 changed files with 278 additions and 194 deletions

View File

@ -1,7 +1,7 @@
client:
enabled: false
url: null
platform: auto
platform: ''
llm:
service: litellm
@ -13,7 +13,7 @@ llm:
temperature: 0.8
local:
enabled: true
enabled: false
tts_service: piper
stt_service: local-whisper

View File

@ -51,6 +51,7 @@ pytest = "^8.1.1"
target-version = ['py311']
[tool.isort]
profile = "black"
multi_line_output = 3
include_trailing_comma = true
known_first_party = ["source"]
multi_line_output = 3
profile = "black"

View File

@ -1,6 +1,11 @@
"""
Loads environment variables and creates a global configuration object.
"""
from dotenv import load_dotenv
from source.core.models import Config
from source.core.config import Config, get_config
load_dotenv()
config = Config()
config: Config = get_config()

View File

@ -0,0 +1,110 @@
"""
Application configuration model.
"""
import os
from functools import lru_cache
from typing import Any
from pydantic_settings import (
BaseSettings,
DotEnvSettingsSource,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)
from source.core.models import LLM, STT, TTS, Client, Local, Server, Tunnel
APP_PREFIX: str = os.getenv("01_PREFIX", "01_")
class Config(BaseSettings):
"""
Base configuration model
"""
client: Client = Client()
llm: LLM = LLM()
local: Local = Local()
server: Server = Server()
stt: STT = STT()
tts: TTS = TTS()
tunnel: Tunnel = Tunnel()
model_config = SettingsConfigDict(extra="allow")
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
"""
Modify the order of precedence for settings sources.
"""
return (
DotEnvSettingsSource(
settings_cls,
env_prefix=APP_PREFIX,
env_file=".env",
env_file_encoding="utf-8",
env_nested_delimiter="_",
),
YamlConfigSettingsSource(
settings_cls,
yaml_file=os.getenv(f"{APP_PREFIX}CONFIG_FILE", "config.yaml"),
),
)
def apply_cli_args(self, args: dict) -> None:
"""
Apply CLI arguments to config.
"""
mapping: dict[str, str] = {
"server": "server.enabled",
"server_host": "server.host",
"server_port": "server.port",
"tunnel_service": "tunnel.service",
"expose": "tunnel.exposed",
"client": "client.enabled",
"server_url": "client.url",
"client_type": "client.platform",
"llm_service": "llm.service",
"model": "llm.model",
"llm_supports_vision": "llm.vision_enabled",
"llm_supports_functions": "llm.functions_enabled",
"context_window": "llm.context_window",
"max_tokens": "llm.max_tokens",
"temperature": "llm.temperature",
"tts_service": "tts.service",
"stt_service": "stt.service",
"local": "local.enabled",
}
for key, path in mapping.items():
if key in args and args[key] is not None:
self.set_field(path, args[key])
def set_field(self, field: str, value: Any) -> None:
"""
Set field value
"""
obj: Any = self
parts: list[str] = field.split(".")
for part in parts[:-1]:
obj: Any = getattr(obj, part)
setattr(obj, parts[-1], value)
@lru_cache()
def get_config() -> Config:
"""
Return the application configuration.
"""
return Config()

View File

@ -1,17 +1,8 @@
"""
Application configuration models.
Application models.
"""
import os
from pydantic import BaseModel
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
YamlConfigSettingsSource,
)
APP_PREFIX: str = os.getenv("01_APP_PREFIX", "01_")
class Client(BaseModel):
@ -20,8 +11,8 @@ class Client(BaseModel):
"""
enabled: bool = False
url: None | str = None
platform: str = "auto"
url: str | None = None
platform: str | None = None
class LLM(BaseModel):
@ -44,8 +35,6 @@ class Local(BaseModel):
"""
enabled: bool = False
tts_service: str = "piper"
stt_service: str = "local-whisper"
class Server(BaseModel):
@ -81,36 +70,3 @@ class Tunnel(BaseModel):
service: str = "ngrok"
exposed: bool = False
class Config(BaseSettings):
"""
Base configuration model
"""
client: Client = Client()
llm: LLM = LLM()
local: Local = Local()
server: Server = Server()
stt: STT = STT()
tts: TTS = TTS()
tunnel: Tunnel = Tunnel()
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
"""
Modify the order of precedence for settings sources.
"""
return (
YamlConfigSettingsSource(
settings_cls,
yaml_file=os.getenv(f"{APP_PREFIX}CONFIG_FILE", "config.yaml"),
),
)

View File

@ -1,99 +0,0 @@
"""
Core utilty functions for the server and client
"""
import asyncio
import importlib
import os
import platform
from threading import Thread
from typing import NoReturn
from source.server.server import main
from source.server.tunnel import create_tunnel
def get_client_platform(config) -> None:
"""
Returns the client platform based on the system type.
"""
if config.client.platform == "auto":
system_type: str = platform.system()
# macOS
if system_type == "Darwin":
config.client.platform = "mac"
# Linux
elif system_type == "Linux":
try:
with open("/proc/device-tree/model", "r", encoding="utf-8") as m:
if "raspberry pi" in m.read().lower():
config.client.platform = "rpi"
else:
config.client.platform = "linux"
except FileNotFoundError:
config.client.platform = "linux"
def handle_exit(signum, frame) -> NoReturn: # pylint: disable=unused-argument
"""
Handle exit signal.
"""
os._exit(0)
def start_client(config) -> Thread:
"""
Start the client.
"""
module = importlib.import_module(
f".clients.{config.client.platform}.device", package="source"
)
client_thread = Thread(target=module.main, args=[config.client.url])
client_thread.start()
return client_thread
def start_server(config) -> Thread:
"""
Start the server.
"""
loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_thread = Thread(
target=loop.run_until_complete,
args=(
main(
config.server.host,
config.server.port,
config.llm.service,
config.llm.model,
config.llm.vision_enabled,
config.llm.functions_enabled,
config.llm.context_window,
config.llm.max_tokens,
config.llm.temperature,
config.tts.service,
config.stt.service,
),
),
)
server_thread.start()
return server_thread
def start_tunnel(config) -> Thread:
"""
Start the tunnel.
"""
tunnel_thread = Thread(
target=create_tunnel,
args=[config.tunnel.service, config.server.host, config.server.port],
)
tunnel_thread.start()
return tunnel_thread

View File

@ -9,6 +9,8 @@ from pathlib import Path
from interpreter import OpenInterpreter
import shutil
from source import config
from source.core.config import APP_PREFIX
system_message = r"""
@ -189,7 +191,7 @@ def configure_interpreter(interpreter: OpenInterpreter):
interpreter.llm.supports_vision = True
interpreter.shrink_images = True # Faster but less accurate
interpreter.llm.model = "gpt-4"
interpreter.llm.model = config.llm.model
interpreter.llm.supports_functions = False
interpreter.llm.context_window = 110000

View File

@ -0,0 +1,12 @@
"""
System utility functions
"""
import os
def handle_exit(signum, frame) -> None: # pylint: disable=unused-argument
"""
Handle exit signal.
"""
os._exit(0)

View File

@ -2,67 +2,164 @@
Application entry point
"""
import asyncio
import importlib
import os
import platform
import signal
from threading import Thread
import threading
import typer
from source import config
from source.core.utils import (
get_client_platform,
handle_exit,
start_client,
start_server,
start_tunnel,
)
from source.server.utils.local_mode import select_local_model
from source.utils.system import handle_exit
app = typer.Typer()
def run() -> None:
@app.command()
def start(
ctx: typer.Context,
server: bool = typer.Option(False, "--server", help="Run server"),
server_host: str = typer.Option(
"0.0.0.0",
"--server-host",
help="Specify the server host where the server will deploy",
),
server_port: int = typer.Option(
10001,
"--server-port",
help="Specify the server port where the server will deploy",
),
tunnel_service: str = typer.Option(
"ngrok", "--tunnel-service", help="Specify the tunnel service"
),
expose: bool = typer.Option(False, "--expose", help="Expose server to internet"),
client: bool = typer.Option(False, "--client", help="Run client"),
server_url: str = typer.Option(
None,
"--server-url",
help="Specify the server URL that the client should expect. Defaults to server-host and server-port",
),
client_type: str = typer.Option(
"auto", "--client-type", help="Specify the client type"
),
llm_service: str = typer.Option(
"litellm", "--llm-service", help="Specify the LLM service"
),
model: str = typer.Option(None, "--model", help="Specify the model"),
llm_supports_vision: bool = typer.Option(
False,
"--llm-supports-vision",
help="Specify if the LLM service supports vision",
),
llm_supports_functions: bool = typer.Option(
False,
"--llm-supports-functions",
help="Specify if the LLM service supports functions",
),
context_window: int = typer.Option(
2048, "--context-window", help="Specify the context window size"
),
max_tokens: int = typer.Option(
4096, "--max-tokens", help="Specify the maximum number of tokens"
),
temperature: float = typer.Option(
0.8, "--temperature", help="Specify the temperature for generation"
),
tts_service: str = typer.Option(
"openai", "--tts-service", help="Specify the TTS service"
),
stt_service: str = typer.Option(
"openai", "--stt-service", help="Specify the STT service"
),
local: bool = typer.Option(
False, "--local", help="Use recommended local services for LLM, STT, and TTS"
),
) -> None:
"""
Run the application.
Setup the application.
"""
# Set up signal handler for SIGINT (keyboard interrupt)
signal.signal(signal.SIGINT, handle_exit)
config.apply_cli_args(ctx.params)
# If platform is set to auto, determine user's platform automatically.
if config.client.platform == "auto":
get_client_platform(config)
# If local mode is enabled, set up local services
if config.local.enabled:
config.tts.service = config.local.tts_service
config.stt.service = config.local.stt_service
select_local_model()
config.tts.service = "piper"
config.stt.service = "local-whisper"
# If no client URL is provided, set one using server host and port.
config.client.url = (
config.client.url or f"{config.server.host}:{config.server.port}"
)
if not model:
select_local_model()
if not server_url:
server_url = f"{config.server.host}:{config.server.port}"
if not config.server.enabled and not config.client.enabled:
config.server.enabled = config.client.enabled = True
config.server.enabled = True
config.client.enabled = True
server_thread: Thread | None = (
start_server(config) if config.server.enabled else None
)
# Temporary fix pending refactor of `server` and `tunnel` modules.
# Prevents early execution of top-level code until config is fully initialized.
server_module = importlib.import_module("source.server.server")
tunnel_module = importlib.import_module("source.server.tunnel")
tunnel_thread: Thread | None = (
start_tunnel(config) if config.tunnel.exposed else None
)
if config.server.enabled:
loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_thread = threading.Thread(
target=loop.run_until_complete,
args=(
server_module.main(
config.server.host,
config.server.port,
config.llm.service,
config.llm.model,
config.llm.vision_enabled,
config.llm.functions_enabled,
config.llm.context_window,
config.llm.max_tokens,
config.llm.temperature,
config.tts.service,
config.stt.service,
),
),
)
server_thread.start()
client_thread: Thread | None = (
start_client(config) if config.client.enabled else None
)
if config.tunnel.exposed:
tunnel_thread = threading.Thread(
target=tunnel_module.create_tunnel,
args=[config.tunnel.service, config.server.host, config.server.port],
)
tunnel_thread.start()
if config.client.enabled:
if config.client.platform == "auto":
system: str = platform.system()
if system == "Darwin": # macOS
config.client.platform = "mac"
elif system == "Linux":
try:
with open("/proc/device-tree/model", "r", encoding="utf-8") as m:
if "raspberry pi" in m.read().lower():
config.client.platform = "rpi"
else:
config.client.platform = "linux"
except FileNotFoundError:
config.client.platform = "linux"
module = importlib.import_module(
f".clients.{config.client.platform}.device", package="source"
)
client_thread = threading.Thread(target=module.main, args=[server_url])
client_thread.start()
try:
if server_thread:
if config.server.enabled:
server_thread.join()
if tunnel_thread:
if config.tunnel.exposed:
tunnel_thread.join()
if client_thread and client_thread.is_alive():
if config.client.enabled:
client_thread.join()
except KeyboardInterrupt:
handle_exit(signal.SIGINT, None)
if __name__ == "__main__":
run()
os.kill(os.getpid(), signal.SIGINT)