Add global `config`, support for `.env` and apply cli args to `config` at start.
This commit is contained in:
parent
94edb8e001
commit
fe2a343c41
|
@ -1,7 +1,7 @@
|
|||
client:
|
||||
enabled: false
|
||||
url: null
|
||||
platform: auto
|
||||
platform: ''
|
||||
|
||||
llm:
|
||||
service: litellm
|
||||
|
@ -13,7 +13,7 @@ llm:
|
|||
temperature: 0.8
|
||||
|
||||
local:
|
||||
enabled: true
|
||||
enabled: false
|
||||
tts_service: piper
|
||||
stt_service: local-whisper
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ pytest = "^8.1.1"
|
|||
target-version = ['py311']
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = true
|
||||
known_first_party = ["source"]
|
||||
multi_line_output = 3
|
||||
profile = "black"
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
"""
|
||||
Loads environment variables and creates a global configuration object.
|
||||
"""
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from source.core.models import Config
|
||||
|
||||
from source.core.config import Config, get_config
|
||||
|
||||
load_dotenv()
|
||||
|
||||
config = Config()
|
||||
config: Config = get_config()
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
"""
|
||||
Application configuration model.
|
||||
"""
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
DotEnvSettingsSource,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
from source.core.models import LLM, STT, TTS, Client, Local, Server, Tunnel
|
||||
|
||||
APP_PREFIX: str = os.getenv("01_PREFIX", "01_")
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""
|
||||
Base configuration model
|
||||
"""
|
||||
|
||||
client: Client = Client()
|
||||
llm: LLM = LLM()
|
||||
local: Local = Local()
|
||||
server: Server = Server()
|
||||
stt: STT = STT()
|
||||
tts: TTS = TTS()
|
||||
tunnel: Tunnel = Tunnel()
|
||||
|
||||
model_config = SettingsConfigDict(extra="allow")
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
"""
|
||||
Modify the order of precedence for settings sources.
|
||||
"""
|
||||
return (
|
||||
DotEnvSettingsSource(
|
||||
settings_cls,
|
||||
env_prefix=APP_PREFIX,
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
env_nested_delimiter="_",
|
||||
),
|
||||
YamlConfigSettingsSource(
|
||||
settings_cls,
|
||||
yaml_file=os.getenv(f"{APP_PREFIX}CONFIG_FILE", "config.yaml"),
|
||||
),
|
||||
)
|
||||
|
||||
def apply_cli_args(self, args: dict) -> None:
|
||||
"""
|
||||
Apply CLI arguments to config.
|
||||
"""
|
||||
mapping: dict[str, str] = {
|
||||
"server": "server.enabled",
|
||||
"server_host": "server.host",
|
||||
"server_port": "server.port",
|
||||
"tunnel_service": "tunnel.service",
|
||||
"expose": "tunnel.exposed",
|
||||
"client": "client.enabled",
|
||||
"server_url": "client.url",
|
||||
"client_type": "client.platform",
|
||||
"llm_service": "llm.service",
|
||||
"model": "llm.model",
|
||||
"llm_supports_vision": "llm.vision_enabled",
|
||||
"llm_supports_functions": "llm.functions_enabled",
|
||||
"context_window": "llm.context_window",
|
||||
"max_tokens": "llm.max_tokens",
|
||||
"temperature": "llm.temperature",
|
||||
"tts_service": "tts.service",
|
||||
"stt_service": "stt.service",
|
||||
"local": "local.enabled",
|
||||
}
|
||||
|
||||
for key, path in mapping.items():
|
||||
if key in args and args[key] is not None:
|
||||
self.set_field(path, args[key])
|
||||
|
||||
def set_field(self, field: str, value: Any) -> None:
|
||||
"""
|
||||
Set field value
|
||||
"""
|
||||
obj: Any = self
|
||||
parts: list[str] = field.split(".")
|
||||
|
||||
for part in parts[:-1]:
|
||||
obj: Any = getattr(obj, part)
|
||||
|
||||
setattr(obj, parts[-1], value)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_config() -> Config:
|
||||
"""
|
||||
Return the application configuration.
|
||||
"""
|
||||
return Config()
|
|
@ -1,17 +1,8 @@
|
|||
"""
|
||||
Application configuration models.
|
||||
Application models.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
APP_PREFIX: str = os.getenv("01_APP_PREFIX", "01_")
|
||||
|
||||
|
||||
class Client(BaseModel):
|
||||
|
@ -20,8 +11,8 @@ class Client(BaseModel):
|
|||
"""
|
||||
|
||||
enabled: bool = False
|
||||
url: None | str = None
|
||||
platform: str = "auto"
|
||||
url: str | None = None
|
||||
platform: str | None = None
|
||||
|
||||
|
||||
class LLM(BaseModel):
|
||||
|
@ -44,8 +35,6 @@ class Local(BaseModel):
|
|||
"""
|
||||
|
||||
enabled: bool = False
|
||||
tts_service: str = "piper"
|
||||
stt_service: str = "local-whisper"
|
||||
|
||||
|
||||
class Server(BaseModel):
|
||||
|
@ -81,36 +70,3 @@ class Tunnel(BaseModel):
|
|||
|
||||
service: str = "ngrok"
|
||||
exposed: bool = False
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""
|
||||
Base configuration model
|
||||
"""
|
||||
|
||||
client: Client = Client()
|
||||
llm: LLM = LLM()
|
||||
local: Local = Local()
|
||||
server: Server = Server()
|
||||
stt: STT = STT()
|
||||
tts: TTS = TTS()
|
||||
tunnel: Tunnel = Tunnel()
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
"""
|
||||
Modify the order of precedence for settings sources.
|
||||
"""
|
||||
return (
|
||||
YamlConfigSettingsSource(
|
||||
settings_cls,
|
||||
yaml_file=os.getenv(f"{APP_PREFIX}CONFIG_FILE", "config.yaml"),
|
||||
),
|
||||
)
|
||||
|
|
|
@ -1,99 +0,0 @@
|
|||
"""
|
||||
Core utilty functions for the server and client
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import platform
|
||||
from threading import Thread
|
||||
from typing import NoReturn
|
||||
|
||||
from source.server.server import main
|
||||
from source.server.tunnel import create_tunnel
|
||||
|
||||
|
||||
def get_client_platform(config) -> None:
|
||||
"""
|
||||
Returns the client platform based on the system type.
|
||||
"""
|
||||
if config.client.platform == "auto":
|
||||
system_type: str = platform.system()
|
||||
|
||||
# macOS
|
||||
if system_type == "Darwin":
|
||||
config.client.platform = "mac"
|
||||
|
||||
# Linux
|
||||
elif system_type == "Linux":
|
||||
try:
|
||||
with open("/proc/device-tree/model", "r", encoding="utf-8") as m:
|
||||
if "raspberry pi" in m.read().lower():
|
||||
config.client.platform = "rpi"
|
||||
else:
|
||||
config.client.platform = "linux"
|
||||
except FileNotFoundError:
|
||||
config.client.platform = "linux"
|
||||
|
||||
|
||||
def handle_exit(signum, frame) -> NoReturn: # pylint: disable=unused-argument
|
||||
"""
|
||||
Handle exit signal.
|
||||
"""
|
||||
os._exit(0)
|
||||
|
||||
|
||||
def start_client(config) -> Thread:
|
||||
"""
|
||||
Start the client.
|
||||
"""
|
||||
module = importlib.import_module(
|
||||
f".clients.{config.client.platform}.device", package="source"
|
||||
)
|
||||
|
||||
client_thread = Thread(target=module.main, args=[config.client.url])
|
||||
client_thread.start()
|
||||
return client_thread
|
||||
|
||||
|
||||
def start_server(config) -> Thread:
|
||||
"""
|
||||
Start the server.
|
||||
"""
|
||||
loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
server_thread = Thread(
|
||||
target=loop.run_until_complete,
|
||||
args=(
|
||||
main(
|
||||
config.server.host,
|
||||
config.server.port,
|
||||
config.llm.service,
|
||||
config.llm.model,
|
||||
config.llm.vision_enabled,
|
||||
config.llm.functions_enabled,
|
||||
config.llm.context_window,
|
||||
config.llm.max_tokens,
|
||||
config.llm.temperature,
|
||||
config.tts.service,
|
||||
config.stt.service,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
server_thread.start()
|
||||
return server_thread
|
||||
|
||||
|
||||
def start_tunnel(config) -> Thread:
|
||||
"""
|
||||
Start the tunnel.
|
||||
"""
|
||||
tunnel_thread = Thread(
|
||||
target=create_tunnel,
|
||||
args=[config.tunnel.service, config.server.host, config.server.port],
|
||||
)
|
||||
|
||||
tunnel_thread.start()
|
||||
return tunnel_thread
|
|
@ -9,6 +9,8 @@ from pathlib import Path
|
|||
from interpreter import OpenInterpreter
|
||||
import shutil
|
||||
|
||||
from source import config
|
||||
from source.core.config import APP_PREFIX
|
||||
|
||||
system_message = r"""
|
||||
|
||||
|
@ -189,7 +191,7 @@ def configure_interpreter(interpreter: OpenInterpreter):
|
|||
interpreter.llm.supports_vision = True
|
||||
interpreter.shrink_images = True # Faster but less accurate
|
||||
|
||||
interpreter.llm.model = "gpt-4"
|
||||
interpreter.llm.model = config.llm.model
|
||||
|
||||
interpreter.llm.supports_functions = False
|
||||
interpreter.llm.context_window = 110000
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
"""
|
||||
System utility functions
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def handle_exit(signum, frame) -> None: # pylint: disable=unused-argument
|
||||
"""
|
||||
Handle exit signal.
|
||||
"""
|
||||
os._exit(0)
|
|
@ -2,67 +2,164 @@
|
|||
Application entry point
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
from threading import Thread
|
||||
import threading
|
||||
|
||||
import typer
|
||||
|
||||
from source import config
|
||||
from source.core.utils import (
|
||||
get_client_platform,
|
||||
handle_exit,
|
||||
start_client,
|
||||
start_server,
|
||||
start_tunnel,
|
||||
)
|
||||
from source.server.utils.local_mode import select_local_model
|
||||
from source.utils.system import handle_exit
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def run() -> None:
|
||||
@app.command()
|
||||
def start(
|
||||
ctx: typer.Context,
|
||||
server: bool = typer.Option(False, "--server", help="Run server"),
|
||||
server_host: str = typer.Option(
|
||||
"0.0.0.0",
|
||||
"--server-host",
|
||||
help="Specify the server host where the server will deploy",
|
||||
),
|
||||
server_port: int = typer.Option(
|
||||
10001,
|
||||
"--server-port",
|
||||
help="Specify the server port where the server will deploy",
|
||||
),
|
||||
tunnel_service: str = typer.Option(
|
||||
"ngrok", "--tunnel-service", help="Specify the tunnel service"
|
||||
),
|
||||
expose: bool = typer.Option(False, "--expose", help="Expose server to internet"),
|
||||
client: bool = typer.Option(False, "--client", help="Run client"),
|
||||
server_url: str = typer.Option(
|
||||
None,
|
||||
"--server-url",
|
||||
help="Specify the server URL that the client should expect. Defaults to server-host and server-port",
|
||||
),
|
||||
client_type: str = typer.Option(
|
||||
"auto", "--client-type", help="Specify the client type"
|
||||
),
|
||||
llm_service: str = typer.Option(
|
||||
"litellm", "--llm-service", help="Specify the LLM service"
|
||||
),
|
||||
model: str = typer.Option(None, "--model", help="Specify the model"),
|
||||
llm_supports_vision: bool = typer.Option(
|
||||
False,
|
||||
"--llm-supports-vision",
|
||||
help="Specify if the LLM service supports vision",
|
||||
),
|
||||
llm_supports_functions: bool = typer.Option(
|
||||
False,
|
||||
"--llm-supports-functions",
|
||||
help="Specify if the LLM service supports functions",
|
||||
),
|
||||
context_window: int = typer.Option(
|
||||
2048, "--context-window", help="Specify the context window size"
|
||||
),
|
||||
max_tokens: int = typer.Option(
|
||||
4096, "--max-tokens", help="Specify the maximum number of tokens"
|
||||
),
|
||||
temperature: float = typer.Option(
|
||||
0.8, "--temperature", help="Specify the temperature for generation"
|
||||
),
|
||||
tts_service: str = typer.Option(
|
||||
"openai", "--tts-service", help="Specify the TTS service"
|
||||
),
|
||||
stt_service: str = typer.Option(
|
||||
"openai", "--stt-service", help="Specify the STT service"
|
||||
),
|
||||
local: bool = typer.Option(
|
||||
False, "--local", help="Use recommended local services for LLM, STT, and TTS"
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Run the application.
|
||||
Setup the application.
|
||||
"""
|
||||
# Set up signal handler for SIGINT (keyboard interrupt)
|
||||
signal.signal(signal.SIGINT, handle_exit)
|
||||
config.apply_cli_args(ctx.params)
|
||||
|
||||
# If platform is set to auto, determine user's platform automatically.
|
||||
if config.client.platform == "auto":
|
||||
get_client_platform(config)
|
||||
|
||||
# If local mode is enabled, set up local services
|
||||
if config.local.enabled:
|
||||
config.tts.service = config.local.tts_service
|
||||
config.stt.service = config.local.stt_service
|
||||
select_local_model()
|
||||
config.tts.service = "piper"
|
||||
config.stt.service = "local-whisper"
|
||||
|
||||
# If no client URL is provided, set one using server host and port.
|
||||
config.client.url = (
|
||||
config.client.url or f"{config.server.host}:{config.server.port}"
|
||||
)
|
||||
if not model:
|
||||
select_local_model()
|
||||
|
||||
if not server_url:
|
||||
server_url = f"{config.server.host}:{config.server.port}"
|
||||
|
||||
if not config.server.enabled and not config.client.enabled:
|
||||
config.server.enabled = config.client.enabled = True
|
||||
config.server.enabled = True
|
||||
config.client.enabled = True
|
||||
|
||||
server_thread: Thread | None = (
|
||||
start_server(config) if config.server.enabled else None
|
||||
)
|
||||
# Temporary fix pending refactor of `server` and `tunnel` modules.
|
||||
# Prevents early execution of top-level code until config is fully initialized.
|
||||
server_module = importlib.import_module("source.server.server")
|
||||
tunnel_module = importlib.import_module("source.server.tunnel")
|
||||
|
||||
tunnel_thread: Thread | None = (
|
||||
start_tunnel(config) if config.tunnel.exposed else None
|
||||
)
|
||||
if config.server.enabled:
|
||||
loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
server_thread = threading.Thread(
|
||||
target=loop.run_until_complete,
|
||||
args=(
|
||||
server_module.main(
|
||||
config.server.host,
|
||||
config.server.port,
|
||||
config.llm.service,
|
||||
config.llm.model,
|
||||
config.llm.vision_enabled,
|
||||
config.llm.functions_enabled,
|
||||
config.llm.context_window,
|
||||
config.llm.max_tokens,
|
||||
config.llm.temperature,
|
||||
config.tts.service,
|
||||
config.stt.service,
|
||||
),
|
||||
),
|
||||
)
|
||||
server_thread.start()
|
||||
|
||||
client_thread: Thread | None = (
|
||||
start_client(config) if config.client.enabled else None
|
||||
)
|
||||
if config.tunnel.exposed:
|
||||
tunnel_thread = threading.Thread(
|
||||
target=tunnel_module.create_tunnel,
|
||||
args=[config.tunnel.service, config.server.host, config.server.port],
|
||||
)
|
||||
tunnel_thread.start()
|
||||
|
||||
if config.client.enabled:
|
||||
if config.client.platform == "auto":
|
||||
system: str = platform.system()
|
||||
if system == "Darwin": # macOS
|
||||
config.client.platform = "mac"
|
||||
elif system == "Linux":
|
||||
try:
|
||||
with open("/proc/device-tree/model", "r", encoding="utf-8") as m:
|
||||
if "raspberry pi" in m.read().lower():
|
||||
config.client.platform = "rpi"
|
||||
else:
|
||||
config.client.platform = "linux"
|
||||
except FileNotFoundError:
|
||||
config.client.platform = "linux"
|
||||
|
||||
module = importlib.import_module(
|
||||
f".clients.{config.client.platform}.device", package="source"
|
||||
)
|
||||
client_thread = threading.Thread(target=module.main, args=[server_url])
|
||||
client_thread.start()
|
||||
|
||||
try:
|
||||
if server_thread:
|
||||
if config.server.enabled:
|
||||
server_thread.join()
|
||||
if tunnel_thread:
|
||||
if config.tunnel.exposed:
|
||||
tunnel_thread.join()
|
||||
if client_thread and client_thread.is_alive():
|
||||
if config.client.enabled:
|
||||
client_thread.join()
|
||||
except KeyboardInterrupt:
|
||||
handle_exit(signal.SIGINT, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
|
Loading…
Reference in New Issue