Merge pull request #223 from dheavy/fix/precommit-linter
Fix pre-commit linter
This commit is contained in:
commit
a40c8041f2
|
@ -1,10 +1,10 @@
|
|||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: "v0.2.2"
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.3.0 # Use the latest revision of Black
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: ["--fix"]
|
||||
- id: ruff-format
|
||||
- id: black
|
||||
language_version: python3
|
||||
args: ["software/"]
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
|
|
|
@ -10,4 +10,3 @@ In the coming months, we're going to release:
|
|||
- [ ] An open-source language model for computer control
|
||||
- [ ] A react-native app for your phone
|
||||
- [ ] A hand-held device that runs fully offline.
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
_archive
|
||||
__pycache__
|
||||
.idea
|
||||
|
||||
|
|
|
@ -1,23 +1,18 @@
|
|||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv() # take environment variables from .env.
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
import os
|
||||
import pyaudio
|
||||
from starlette.websockets import WebSocket
|
||||
from queue import Queue
|
||||
from pynput import keyboard
|
||||
import json
|
||||
import traceback
|
||||
import websockets
|
||||
import queue
|
||||
import pydub
|
||||
import ast
|
||||
from pydub import AudioSegment
|
||||
from pydub.playback import play
|
||||
import io
|
||||
import time
|
||||
import wave
|
||||
import tempfile
|
||||
|
@ -25,7 +20,10 @@ from datetime import datetime
|
|||
import cv2
|
||||
import base64
|
||||
import platform
|
||||
from interpreter import interpreter # Just for code execution. Maybe we should let people do from interpreter.computer import run?
|
||||
from interpreter import (
|
||||
interpreter,
|
||||
) # Just for code execution. Maybe we should let people do from interpreter.computer import run?
|
||||
|
||||
# In the future, I guess kernel watching code should be elsewhere? Somewhere server / client agnostic?
|
||||
from ..server.utils.kernel import put_kernel_messages_into_queue
|
||||
from ..server.utils.get_system_info import get_system_info
|
||||
|
@ -33,6 +31,7 @@ from ..server.utils.process_utils import kill_process_tree
|
|||
|
||||
from ..server.utils.logs import setup_logging
|
||||
from ..server.utils.logs import logger
|
||||
|
||||
setup_logging()
|
||||
|
||||
os.environ["STT_RUNNER"] = "server"
|
||||
|
@ -51,11 +50,11 @@ RECORDING = False # Flag to control recording state
|
|||
SPACEBAR_PRESSED = False # Flag to track spacebar press state
|
||||
|
||||
# Camera configuration
|
||||
CAMERA_ENABLED = os.getenv('CAMERA_ENABLED', False)
|
||||
CAMERA_ENABLED = os.getenv("CAMERA_ENABLED", False)
|
||||
if type(CAMERA_ENABLED) == str:
|
||||
CAMERA_ENABLED = (CAMERA_ENABLED.lower() == "true")
|
||||
CAMERA_DEVICE_INDEX = int(os.getenv('CAMERA_DEVICE_INDEX', 0))
|
||||
CAMERA_WARMUP_SECONDS = float(os.getenv('CAMERA_WARMUP_SECONDS', 0))
|
||||
CAMERA_ENABLED = CAMERA_ENABLED.lower() == "true"
|
||||
CAMERA_DEVICE_INDEX = int(os.getenv("CAMERA_DEVICE_INDEX", 0))
|
||||
CAMERA_WARMUP_SECONDS = float(os.getenv("CAMERA_WARMUP_SECONDS", 0))
|
||||
|
||||
# Specify OS
|
||||
current_platform = get_system_info()
|
||||
|
@ -66,6 +65,7 @@ p = pyaudio.PyAudio()
|
|||
|
||||
send_queue = queue.Queue()
|
||||
|
||||
|
||||
class Device:
|
||||
def __init__(self):
|
||||
self.pressed_keys = set()
|
||||
|
@ -89,23 +89,28 @@ class Device:
|
|||
|
||||
if ret:
|
||||
temp_dir = tempfile.gettempdir()
|
||||
image_path = os.path.join(temp_dir, f"01_photo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.png")
|
||||
image_path = os.path.join(
|
||||
temp_dir, f"01_photo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.png"
|
||||
)
|
||||
self.captured_images.append(image_path)
|
||||
cv2.imwrite(image_path, frame)
|
||||
logger.info(f"Camera image captured to {image_path}")
|
||||
logger.info(f"You now have {len(self.captured_images)} images which will be sent along with your next audio message.")
|
||||
logger.info(
|
||||
f"You now have {len(self.captured_images)} images which will be sent along with your next audio message."
|
||||
)
|
||||
else:
|
||||
logger.error(f"Error: Couldn't capture an image from camera ({camera_index})")
|
||||
logger.error(
|
||||
f"Error: Couldn't capture an image from camera ({camera_index})"
|
||||
)
|
||||
|
||||
cap.release()
|
||||
|
||||
return image_path
|
||||
|
||||
|
||||
def encode_image_to_base64(self, image_path):
|
||||
"""Encodes an image file to a base64 string."""
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
def add_image_to_send_queue(self, image_path):
|
||||
"""Encodes an image and adds an LMC message to the send queue with the image data."""
|
||||
|
@ -114,7 +119,7 @@ class Device:
|
|||
"role": "user",
|
||||
"type": "image",
|
||||
"format": "base64.png",
|
||||
"content": base64_image
|
||||
"content": base64_image,
|
||||
}
|
||||
send_queue.put(image_message)
|
||||
# Delete the image file from the file system after sending it
|
||||
|
@ -126,7 +131,6 @@ class Device:
|
|||
self.add_image_to_send_queue(image_path)
|
||||
self.captured_images.clear() # Clear the list after sending
|
||||
|
||||
|
||||
async def play_audiosegments(self):
|
||||
"""Plays them sequentially."""
|
||||
while True:
|
||||
|
@ -141,27 +145,35 @@ class Device:
|
|||
except:
|
||||
logger.info(traceback.format_exc())
|
||||
|
||||
|
||||
def record_audio(self):
|
||||
|
||||
if os.getenv('STT_RUNNER') == "server":
|
||||
if os.getenv("STT_RUNNER") == "server":
|
||||
# STT will happen on the server. we're sending audio.
|
||||
send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "start": True})
|
||||
elif os.getenv('STT_RUNNER') == "client":
|
||||
send_queue.put(
|
||||
{"role": "user", "type": "audio", "format": "bytes.wav", "start": True}
|
||||
)
|
||||
elif os.getenv("STT_RUNNER") == "client":
|
||||
# STT will happen here, on the client. we're sending text.
|
||||
send_queue.put({"role": "user", "type": "message", "start": True})
|
||||
else:
|
||||
raise Exception("STT_RUNNER must be set to either 'client' or 'server'.")
|
||||
|
||||
"""Record audio from the microphone and add it to the queue."""
|
||||
stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
stream = p.open(
|
||||
format=FORMAT,
|
||||
channels=CHANNELS,
|
||||
rate=RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK,
|
||||
)
|
||||
print("Recording started...")
|
||||
global RECORDING
|
||||
|
||||
# Create a temporary WAV file to store the audio data
|
||||
temp_dir = tempfile.gettempdir()
|
||||
wav_path = os.path.join(temp_dir, f"audio_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
|
||||
wav_file = wave.open(wav_path, 'wb')
|
||||
wav_path = os.path.join(
|
||||
temp_dir, f"audio_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav"
|
||||
)
|
||||
wav_file = wave.open(wav_path, "wb")
|
||||
wav_file.setnchannels(CHANNELS)
|
||||
wav_file.setsampwidth(p.get_sample_size(FORMAT))
|
||||
wav_file.setframerate(RATE)
|
||||
|
@ -178,17 +190,30 @@ class Device:
|
|||
duration = wav_file.getnframes() / RATE
|
||||
if duration < 0.3:
|
||||
# Just pressed it. Send stop message
|
||||
if os.getenv('STT_RUNNER') == "client":
|
||||
if os.getenv("STT_RUNNER") == "client":
|
||||
send_queue.put({"role": "user", "type": "message", "content": "stop"})
|
||||
send_queue.put({"role": "user", "type": "message", "end": True})
|
||||
else:
|
||||
send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "content": ""})
|
||||
send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "end": True})
|
||||
send_queue.put(
|
||||
{
|
||||
"role": "user",
|
||||
"type": "audio",
|
||||
"format": "bytes.wav",
|
||||
"content": "",
|
||||
}
|
||||
)
|
||||
send_queue.put(
|
||||
{
|
||||
"role": "user",
|
||||
"type": "audio",
|
||||
"format": "bytes.wav",
|
||||
"end": True,
|
||||
}
|
||||
)
|
||||
else:
|
||||
self.queue_all_captured_images()
|
||||
|
||||
if os.getenv('STT_RUNNER') == "client":
|
||||
|
||||
if os.getenv("STT_RUNNER") == "client":
|
||||
# THIS DOES NOT WORK. We moved to this very cool stt_service, llm_service
|
||||
# way of doing things. stt_wav is not a thing anymore. Needs work to work
|
||||
|
||||
|
@ -199,12 +224,19 @@ class Device:
|
|||
send_queue.put({"role": "user", "type": "message", "end": True})
|
||||
else:
|
||||
# Stream audio
|
||||
with open(wav_path, 'rb') as audio_file:
|
||||
with open(wav_path, "rb") as audio_file:
|
||||
byte_data = audio_file.read(CHUNK)
|
||||
while byte_data:
|
||||
send_queue.put(byte_data)
|
||||
byte_data = audio_file.read(CHUNK)
|
||||
send_queue.put({"role": "user", "type": "audio", "format": "bytes.wav", "end": True})
|
||||
send_queue.put(
|
||||
{
|
||||
"role": "user",
|
||||
"type": "audio",
|
||||
"format": "bytes.wav",
|
||||
"end": True,
|
||||
}
|
||||
)
|
||||
|
||||
if os.path.exists(wav_path):
|
||||
os.remove(wav_path)
|
||||
|
@ -227,24 +259,27 @@ class Device:
|
|||
|
||||
if keyboard.Key.space in self.pressed_keys:
|
||||
self.toggle_recording(True)
|
||||
elif {keyboard.Key.ctrl, keyboard.KeyCode.from_char('c')} <= self.pressed_keys:
|
||||
elif {keyboard.Key.ctrl, keyboard.KeyCode.from_char("c")} <= self.pressed_keys:
|
||||
logger.info("Ctrl+C pressed. Exiting...")
|
||||
kill_process_tree()
|
||||
os._exit(0)
|
||||
|
||||
def on_release(self, key):
|
||||
"""Detect spacebar release and 'c' key press for camera, and handle key release."""
|
||||
self.pressed_keys.discard(key) # Remove the released key from the key press tracking set
|
||||
self.pressed_keys.discard(
|
||||
key
|
||||
) # Remove the released key from the key press tracking set
|
||||
|
||||
if key == keyboard.Key.space:
|
||||
self.toggle_recording(False)
|
||||
elif CAMERA_ENABLED and key == keyboard.KeyCode.from_char('c'):
|
||||
elif CAMERA_ENABLED and key == keyboard.KeyCode.from_char("c"):
|
||||
self.fetch_image_from_camera()
|
||||
|
||||
|
||||
async def message_sender(self, websocket):
|
||||
while True:
|
||||
message = await asyncio.get_event_loop().run_in_executor(None, send_queue.get)
|
||||
message = await asyncio.get_event_loop().run_in_executor(
|
||||
None, send_queue.get
|
||||
)
|
||||
if isinstance(message, bytes):
|
||||
await websocket.send(message)
|
||||
else:
|
||||
|
@ -257,7 +292,9 @@ class Device:
|
|||
|
||||
async def exec_ws_communication(websocket):
|
||||
if CAMERA_ENABLED:
|
||||
print("\nHold the spacebar to start recording. Press 'c' to capture an image from the camera. Press CTRL-C to exit.")
|
||||
print(
|
||||
"\nHold the spacebar to start recording. Press 'c' to capture an image from the camera. Press CTRL-C to exit."
|
||||
)
|
||||
else:
|
||||
print("\nHold the spacebar to start recording. Press CTRL-C to exit.")
|
||||
|
||||
|
@ -280,7 +317,6 @@ class Device:
|
|||
# At this point, we have our message
|
||||
|
||||
if message["type"] == "audio" and message["format"].startswith("bytes"):
|
||||
|
||||
# Convert bytes to audio file
|
||||
|
||||
audio_bytes = message["content"]
|
||||
|
@ -294,13 +330,13 @@ class Device:
|
|||
# 16,000 Hz frame rate
|
||||
frame_rate=16000,
|
||||
# mono sound
|
||||
channels=1
|
||||
channels=1,
|
||||
)
|
||||
|
||||
self.audiosegments.append(audio)
|
||||
|
||||
# Run the code if that's the client's job
|
||||
if os.getenv('CODE_RUNNER') == "client":
|
||||
if os.getenv("CODE_RUNNER") == "client":
|
||||
if message["type"] == "code" and "end" in message:
|
||||
language = message["format"]
|
||||
code = message["content"]
|
||||
|
@ -308,7 +344,7 @@ class Device:
|
|||
send_queue.put(result)
|
||||
|
||||
if is_win10():
|
||||
logger.info('Windows 10 detected')
|
||||
logger.info("Windows 10 detected")
|
||||
# Workaround for Windows 10 not latching to the websocket server.
|
||||
# See https://github.com/OpenInterpreter/01/issues/197
|
||||
try:
|
||||
|
@ -335,7 +371,7 @@ class Device:
|
|||
asyncio.create_task(self.websocket_communication(WS_URL))
|
||||
|
||||
# Start watching the kernel if it's your job to do that
|
||||
if os.getenv('CODE_RUNNER') == "client":
|
||||
if os.getenv("CODE_RUNNER") == "client":
|
||||
asyncio.create_task(put_kernel_messages_into_queue(send_queue))
|
||||
|
||||
asyncio.create_task(self.play_audiosegments())
|
||||
|
@ -348,7 +384,9 @@ class Device:
|
|||
print("PINDEF", pindef)
|
||||
|
||||
# HACK: needs passwordless sudo
|
||||
process = await asyncio.create_subprocess_exec("sudo", "gpiomon", "-brf", *pindef, stdout=asyncio.subprocess.PIPE)
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"sudo", "gpiomon", "-brf", *pindef, stdout=asyncio.subprocess.PIPE
|
||||
)
|
||||
while True:
|
||||
line = await process.stdout.readline()
|
||||
if line:
|
||||
|
@ -361,10 +399,12 @@ class Device:
|
|||
break
|
||||
else:
|
||||
# Keyboard listener for spacebar press/release
|
||||
listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release)
|
||||
listener = keyboard.Listener(
|
||||
on_press=self.on_press, on_release=self.on_release
|
||||
)
|
||||
listener.start()
|
||||
|
||||
def start(self):
|
||||
if os.getenv('TEACH_MODE') != "True":
|
||||
if os.getenv("TEACH_MODE") != "True":
|
||||
asyncio.run(self.start_async())
|
||||
p.terminate()
|
||||
|
|
|
@ -26,4 +26,3 @@ And build and upload the firmware with a simple command:
|
|||
```bash
|
||||
pio run --target upload
|
||||
```
|
||||
|
||||
|
|
|
@ -2,9 +2,11 @@ from ..base_device import Device
|
|||
|
||||
device = Device()
|
||||
|
||||
|
||||
def main(server_url):
|
||||
device.server_url = server_url
|
||||
device.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -2,9 +2,11 @@ from ..base_device import Device
|
|||
|
||||
device = Device()
|
||||
|
||||
|
||||
def main(server_url):
|
||||
device.server_url = server_url
|
||||
device.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -2,8 +2,10 @@ from ..base_device import Device
|
|||
|
||||
device = Device()
|
||||
|
||||
|
||||
def main():
|
||||
device.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -2,9 +2,11 @@ from ..base_device import Device
|
|||
|
||||
device = Device()
|
||||
|
||||
|
||||
def main(server_url):
|
||||
device.server_url = server_url
|
||||
device.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from source.server.i import configure_interpreter
|
||||
from unittest.mock import Mock
|
||||
from interpreter import OpenInterpreter
|
||||
from fastapi.testclient import TestClient
|
||||
from .server import app
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv() # take environment variables from .env.
|
||||
|
||||
import os
|
||||
import glob
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
from interpreter import OpenInterpreter
|
||||
import shutil
|
||||
|
||||
|
@ -182,7 +182,6 @@ Try multiple methods before saying the task is impossible. **You can do it!**
|
|||
|
||||
|
||||
def configure_interpreter(interpreter: OpenInterpreter):
|
||||
|
||||
### SYSTEM MESSAGE
|
||||
interpreter.system_message = system_message
|
||||
|
||||
|
@ -205,7 +204,6 @@ def configure_interpreter(interpreter: OpenInterpreter):
|
|||
"Please provide more information.",
|
||||
]
|
||||
|
||||
|
||||
# Check if required packages are installed
|
||||
|
||||
# THERE IS AN INCONSISTENCY HERE.
|
||||
|
@ -259,7 +257,6 @@ def configure_interpreter(interpreter: OpenInterpreter):
|
|||
time.sleep(2)
|
||||
print("Attempting to start OS control anyway...\n\n")
|
||||
|
||||
|
||||
# Should we explore other options for ^ these kinds of tags?
|
||||
# Like:
|
||||
|
||||
|
@ -295,12 +292,8 @@ def configure_interpreter(interpreter: OpenInterpreter):
|
|||
# if chunk.get("format") != "active_line":
|
||||
# print(chunk.get("content"))
|
||||
|
||||
import os
|
||||
|
||||
from platformdirs import user_data_dir
|
||||
|
||||
|
||||
|
||||
# Directory paths
|
||||
repo_skills_dir = os.path.join(os.path.dirname(__file__), "skills")
|
||||
user_data_skills_dir = os.path.join(user_data_dir("01"), "skills")
|
||||
|
@ -330,7 +323,6 @@ def configure_interpreter(interpreter: OpenInterpreter):
|
|||
|
||||
interpreter.computer.save_skills = True
|
||||
|
||||
|
||||
# Initialize user's task list
|
||||
interpreter.computer.run(
|
||||
language="python",
|
||||
|
@ -354,17 +346,21 @@ def configure_interpreter(interpreter: OpenInterpreter):
|
|||
### MISC SETTINGS
|
||||
|
||||
interpreter.auto_run = True
|
||||
interpreter.computer.languages = [l for l in interpreter.computer.languages if l.name.lower() in ["applescript", "shell", "zsh", "bash", "python"]]
|
||||
interpreter.computer.languages = [
|
||||
l
|
||||
for l in interpreter.computer.languages
|
||||
if l.name.lower() in ["applescript", "shell", "zsh", "bash", "python"]
|
||||
]
|
||||
interpreter.force_task_completion = True
|
||||
# interpreter.offline = True
|
||||
interpreter.id = 206 # Used to identify itself to other interpreters. This should be changed programmatically so it's unique.
|
||||
|
||||
### RESET conversations/user.json
|
||||
app_dir = user_data_dir('01')
|
||||
conversations_dir = os.path.join(app_dir, 'conversations')
|
||||
app_dir = user_data_dir("01")
|
||||
conversations_dir = os.path.join(app_dir, "conversations")
|
||||
os.makedirs(conversations_dir, exist_ok=True)
|
||||
user_json_path = os.path.join(conversations_dir, 'user.json')
|
||||
with open(user_json_path, 'w') as file:
|
||||
user_json_path = os.path.join(conversations_dir, "user.json")
|
||||
with open(user_json_path, "w") as file:
|
||||
json.dump([], file)
|
||||
|
||||
return interpreter
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv() # take environment variables from .env.
|
||||
|
||||
import os
|
||||
|
@ -8,7 +9,7 @@ from pathlib import Path
|
|||
### LLM SETUP
|
||||
|
||||
# Define the path to a llamafile
|
||||
llamafile_path = Path(__file__).parent / 'model.llamafile'
|
||||
llamafile_path = Path(__file__).parent / "model.llamafile"
|
||||
|
||||
# Check if the new llamafile exists, if not download it
|
||||
if not os.path.exists(llamafile_path):
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv() # take environment variables from .env.
|
||||
|
||||
import traceback
|
||||
from platformdirs import user_data_dir
|
||||
import ast
|
||||
import json
|
||||
import queue
|
||||
import os
|
||||
|
@ -13,9 +13,7 @@ import re
|
|||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from starlette.websockets import WebSocket, WebSocketDisconnect
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import urllib.parse
|
||||
from .utils.kernel import put_kernel_messages_into_queue
|
||||
from .i import configure_interpreter
|
||||
from interpreter import interpreter
|
||||
|
@ -44,28 +42,31 @@ accumulator = Accumulator()
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
app_dir = user_data_dir('01')
|
||||
conversation_history_path = os.path.join(app_dir, 'conversations', 'user.json')
|
||||
app_dir = user_data_dir("01")
|
||||
conversation_history_path = os.path.join(app_dir, "conversations", "user.json")
|
||||
|
||||
SERVER_LOCAL_PORT = int(os.getenv('SERVER_LOCAL_PORT', 10001))
|
||||
SERVER_LOCAL_PORT = int(os.getenv("SERVER_LOCAL_PORT", 10001))
|
||||
|
||||
|
||||
# This is so we only say() full sentences
|
||||
def is_full_sentence(text):
|
||||
return text.endswith(('.', '!', '?'))
|
||||
return text.endswith((".", "!", "?"))
|
||||
|
||||
|
||||
def split_into_sentences(text):
|
||||
return re.split(r'(?<=[.!?])\s+', text)
|
||||
return re.split(r"(?<=[.!?])\s+", text)
|
||||
|
||||
|
||||
# Queues
|
||||
from_computer = queue.Queue() # Just for computer messages from the device. Sync queue because interpreter.run is synchronous
|
||||
from_computer = (
|
||||
queue.Queue()
|
||||
) # Just for computer messages from the device. Sync queue because interpreter.run is synchronous
|
||||
from_user = asyncio.Queue() # Just for user messages from the device.
|
||||
to_device = asyncio.Queue() # For messages we send.
|
||||
|
||||
# Switch code executor to device if that's set
|
||||
|
||||
if os.getenv('CODE_RUNNER') == "device":
|
||||
|
||||
if os.getenv("CODE_RUNNER") == "device":
|
||||
# (This should probably just loop through all languages and apply these changes instead)
|
||||
|
||||
class Python:
|
||||
|
@ -79,13 +80,32 @@ if os.getenv('CODE_RUNNER') == "device":
|
|||
"""Generator that yields a dictionary in LMC Format."""
|
||||
|
||||
# Prepare the data
|
||||
message = {"role": "assistant", "type": "code", "format": "python", "content": code}
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"type": "code",
|
||||
"format": "python",
|
||||
"content": code,
|
||||
}
|
||||
|
||||
# Unless it was just sent to the device, send it wrapped in flags
|
||||
if not (interpreter.messages and interpreter.messages[-1] == message):
|
||||
to_device.put({"role": "assistant", "type": "code", "format": "python", "start": True})
|
||||
to_device.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"type": "code",
|
||||
"format": "python",
|
||||
"start": True,
|
||||
}
|
||||
)
|
||||
to_device.put(message)
|
||||
to_device.put({"role": "assistant", "type": "code", "format": "python", "end": True})
|
||||
to_device.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"type": "code",
|
||||
"format": "python",
|
||||
"end": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
logger.info("Waiting for the device to respond...")
|
||||
|
@ -109,10 +129,12 @@ if os.getenv('CODE_RUNNER') == "device":
|
|||
# Configure interpreter
|
||||
interpreter = configure_interpreter(interpreter)
|
||||
|
||||
|
||||
@app.get("/ping")
|
||||
async def ping():
|
||||
return PlainTextResponse("pong")
|
||||
|
||||
|
||||
@app.websocket("/")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
@ -145,19 +167,21 @@ async def receive_messages(websocket: WebSocket):
|
|||
except Exception as e:
|
||||
print(str(e))
|
||||
return
|
||||
if 'text' in data:
|
||||
if "text" in data:
|
||||
try:
|
||||
data = json.loads(data['text'])
|
||||
data = json.loads(data["text"])
|
||||
if data["role"] == "computer":
|
||||
from_computer.put(data) # To be handled by interpreter.computer.run
|
||||
from_computer.put(
|
||||
data
|
||||
) # To be handled by interpreter.computer.run
|
||||
elif data["role"] == "user":
|
||||
await from_user.put(data)
|
||||
else:
|
||||
raise ("Unknown role:", data)
|
||||
except json.JSONDecodeError:
|
||||
pass # data is not JSON, leave it as is
|
||||
elif 'bytes' in data:
|
||||
data = data['bytes'] # binary data
|
||||
elif "bytes" in data:
|
||||
data = data["bytes"] # binary data
|
||||
await from_user.put(data)
|
||||
except WebSocketDisconnect as e:
|
||||
if e.code == 1000:
|
||||
|
@ -184,8 +208,8 @@ async def send_messages(websocket: WebSocket):
|
|||
await to_device.put(message)
|
||||
raise
|
||||
|
||||
async def listener():
|
||||
|
||||
async def listener():
|
||||
while True:
|
||||
try:
|
||||
while True:
|
||||
|
@ -197,8 +221,6 @@ async def listener():
|
|||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
|
||||
message = accumulator.accumulate(chunk)
|
||||
if message == None:
|
||||
# Will be None until we have a full message ready
|
||||
|
@ -209,8 +231,11 @@ async def listener():
|
|||
# At this point, we have our message
|
||||
|
||||
if message["type"] == "audio" and message["format"].startswith("bytes"):
|
||||
|
||||
if "content" not in message or message["content"] == None or message["content"] == "": # If it was nothing / silence / empty
|
||||
if (
|
||||
"content" not in message
|
||||
or message["content"] == None
|
||||
or message["content"] == ""
|
||||
): # If it was nothing / silence / empty
|
||||
continue
|
||||
|
||||
# Convert bytes to audio file
|
||||
|
@ -222,6 +247,7 @@ async def listener():
|
|||
if False:
|
||||
os.system(f"open {audio_file_path}")
|
||||
import time
|
||||
|
||||
time.sleep(15)
|
||||
|
||||
text = stt(audio_file_path)
|
||||
|
@ -239,21 +265,21 @@ async def listener():
|
|||
continue
|
||||
|
||||
# Load, append, and save conversation history
|
||||
with open(conversation_history_path, 'r') as file:
|
||||
with open(conversation_history_path, "r") as file:
|
||||
messages = json.load(file)
|
||||
messages.append(message)
|
||||
with open(conversation_history_path, 'w') as file:
|
||||
with open(conversation_history_path, "w") as file:
|
||||
json.dump(messages, file, indent=4)
|
||||
|
||||
accumulated_text = ""
|
||||
|
||||
|
||||
if any([m["type"] == "image" for m in messages]) and interpreter.llm.model.startswith("gpt-"):
|
||||
if any(
|
||||
[m["type"] == "image" for m in messages]
|
||||
) and interpreter.llm.model.startswith("gpt-"):
|
||||
interpreter.llm.model = "gpt-4-vision-preview"
|
||||
interpreter.llm.supports_vision = True
|
||||
|
||||
for chunk in interpreter.chat(messages, stream=True, display=True):
|
||||
|
||||
if any([m["type"] == "image" for m in interpreter.messages]):
|
||||
interpreter.llm.model = "gpt-4-vision-preview"
|
||||
|
||||
|
@ -264,16 +290,22 @@ async def listener():
|
|||
# Yield to the event loop, so you actually send it out
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
if os.getenv('TTS_RUNNER') == "server":
|
||||
if os.getenv("TTS_RUNNER") == "server":
|
||||
# Speak full sentences out loud
|
||||
if chunk["role"] == "assistant" and "content" in chunk and chunk["type"] == "message":
|
||||
if (
|
||||
chunk["role"] == "assistant"
|
||||
and "content" in chunk
|
||||
and chunk["type"] == "message"
|
||||
):
|
||||
accumulated_text += chunk["content"]
|
||||
sentences = split_into_sentences(accumulated_text)
|
||||
|
||||
# If we're going to speak, say we're going to stop sending text.
|
||||
# This should be fixed probably, we should be able to do both in parallel, or only one.
|
||||
if any(is_full_sentence(sentence) for sentence in sentences):
|
||||
await to_device.put({"role": "assistant", "type": "message", "end": True})
|
||||
await to_device.put(
|
||||
{"role": "assistant", "type": "message", "end": True}
|
||||
)
|
||||
|
||||
if is_full_sentence(sentences[-1]):
|
||||
for sentence in sentences:
|
||||
|
@ -287,22 +319,27 @@ async def listener():
|
|||
# If we're going to speak, say we're going to stop sending text.
|
||||
# This should be fixed probably, we should be able to do both in parallel, or only one.
|
||||
if any(is_full_sentence(sentence) for sentence in sentences):
|
||||
await to_device.put({"role": "assistant", "type": "message", "start": True})
|
||||
await to_device.put(
|
||||
{"role": "assistant", "type": "message", "start": True}
|
||||
)
|
||||
|
||||
# If we have a new message, save our progress and go back to the top
|
||||
if not from_user.empty():
|
||||
|
||||
# Check if it's just an end flag. We ignore those.
|
||||
temp_message = await from_user.get()
|
||||
|
||||
if type(temp_message) is dict and temp_message.get("role") == "user" and temp_message.get("end"):
|
||||
if (
|
||||
type(temp_message) is dict
|
||||
and temp_message.get("role") == "user"
|
||||
and temp_message.get("end")
|
||||
):
|
||||
# Yup. False alarm.
|
||||
continue
|
||||
else:
|
||||
# Whoops! Put that back
|
||||
await from_user.put(temp_message)
|
||||
|
||||
with open(conversation_history_path, 'w') as file:
|
||||
with open(conversation_history_path, "w") as file:
|
||||
json.dump(interpreter.messages, file, indent=4)
|
||||
|
||||
# TODO: is triggering seemingly randomly
|
||||
|
@ -311,8 +348,7 @@ async def listener():
|
|||
|
||||
# Also check if there's any new computer messages
|
||||
if not from_computer.empty():
|
||||
|
||||
with open(conversation_history_path, 'w') as file:
|
||||
with open(conversation_history_path, "w") as file:
|
||||
json.dump(interpreter.messages, file, indent=4)
|
||||
|
||||
logger.info("New computer message recieved. Breaking.")
|
||||
|
@ -320,6 +356,7 @@ async def listener():
|
|||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def stream_tts_to_device(sentence):
|
||||
force_task_completion_responses = [
|
||||
"the task is done",
|
||||
|
@ -332,8 +369,8 @@ async def stream_tts_to_device(sentence):
|
|||
for chunk in stream_tts(sentence):
|
||||
await to_device.put(chunk)
|
||||
|
||||
def stream_tts(sentence):
|
||||
|
||||
def stream_tts(sentence):
|
||||
audio_file = tts(sentence)
|
||||
|
||||
with open(audio_file, "rb") as f:
|
||||
|
@ -350,64 +387,84 @@ def stream_tts(sentence):
|
|||
yield chunk
|
||||
yield {"role": "assistant", "type": "audio", "format": file_type, "end": True}
|
||||
|
||||
|
||||
from uvicorn import Config, Server
|
||||
import os
|
||||
import platform
|
||||
from importlib import import_module
|
||||
|
||||
# these will be overwritten
|
||||
HOST = ''
|
||||
HOST = ""
|
||||
PORT = 0
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
server_url = f"{HOST}:{PORT}"
|
||||
print("")
|
||||
print_markdown(f"\n*Ready.*\n")
|
||||
print_markdown("\n*Ready.*\n")
|
||||
print("")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
print_markdown("*Server is shutting down*")
|
||||
|
||||
async def main(server_host, server_port, llm_service, model, llm_supports_vision, llm_supports_functions, context_window, max_tokens, temperature, tts_service, stt_service):
|
||||
|
||||
async def main(
|
||||
server_host,
|
||||
server_port,
|
||||
llm_service,
|
||||
model,
|
||||
llm_supports_vision,
|
||||
llm_supports_functions,
|
||||
context_window,
|
||||
max_tokens,
|
||||
temperature,
|
||||
tts_service,
|
||||
stt_service,
|
||||
):
|
||||
global HOST
|
||||
global PORT
|
||||
PORT = server_port
|
||||
HOST = server_host
|
||||
|
||||
# Setup services
|
||||
application_directory = user_data_dir('01')
|
||||
services_directory = os.path.join(application_directory, 'services')
|
||||
application_directory = user_data_dir("01")
|
||||
services_directory = os.path.join(application_directory, "services")
|
||||
|
||||
service_dict = {'llm': llm_service, 'tts': tts_service, 'stt': stt_service}
|
||||
service_dict = {"llm": llm_service, "tts": tts_service, "stt": stt_service}
|
||||
|
||||
# Create a temp file with the session number
|
||||
session_file_path = os.path.join(user_data_dir('01'), '01-session.txt')
|
||||
with open(session_file_path, 'w') as session_file:
|
||||
session_file_path = os.path.join(user_data_dir("01"), "01-session.txt")
|
||||
with open(session_file_path, "w") as session_file:
|
||||
session_id = int(datetime.datetime.now().timestamp() * 1000)
|
||||
session_file.write(str(session_id))
|
||||
|
||||
for service in service_dict:
|
||||
|
||||
service_directory = os.path.join(services_directory, service, service_dict[service])
|
||||
service_directory = os.path.join(
|
||||
services_directory, service, service_dict[service]
|
||||
)
|
||||
|
||||
# This is the folder they can mess around in
|
||||
config = {"service_directory": service_directory}
|
||||
|
||||
if service == "llm":
|
||||
config.update({
|
||||
config.update(
|
||||
{
|
||||
"interpreter": interpreter,
|
||||
"model": model,
|
||||
"llm_supports_vision": llm_supports_vision,
|
||||
"llm_supports_functions": llm_supports_functions,
|
||||
"context_window": context_window,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature
|
||||
})
|
||||
"temperature": temperature,
|
||||
}
|
||||
)
|
||||
|
||||
module = import_module(f'.server.services.{service}.{service_dict[service]}.{service}', package='source')
|
||||
module = import_module(
|
||||
f".server.services.{service}.{service_dict[service]}.{service}",
|
||||
package="source",
|
||||
)
|
||||
|
||||
ServiceClass = getattr(module, service.capitalize())
|
||||
service_instance = ServiceClass(config)
|
||||
|
@ -422,10 +479,11 @@ async def main(server_host, server_port, llm_service, model, llm_supports_vision
|
|||
if True: # in the future, code can run on device. for now, just server.
|
||||
asyncio.create_task(put_kernel_messages_into_queue(from_computer))
|
||||
|
||||
config = Config(app, host=server_host, port=int(server_port), lifespan='on')
|
||||
config = Config(app, host=server_host, port=int(server_port), lifespan="on")
|
||||
server = Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
# Run the FastAPI app
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
class Llm:
|
||||
def __init__(self, config):
|
||||
|
||||
# Litellm is used by OI by default, so we just modify OI
|
||||
|
||||
interpreter = config["interpreter"]
|
||||
|
@ -10,6 +9,3 @@ class Llm:
|
|||
setattr(interpreter, key.replace("-", "_"), value)
|
||||
|
||||
self.llm = interpreter.llm.completions
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -3,29 +3,54 @@ import subprocess
|
|||
import requests
|
||||
import json
|
||||
|
||||
|
||||
class Llm:
|
||||
def __init__(self, config):
|
||||
self.install(config["service_directory"])
|
||||
|
||||
def install(self, service_directory):
|
||||
LLM_FOLDER_PATH = service_directory
|
||||
self.llm_directory = os.path.join(LLM_FOLDER_PATH, 'llm')
|
||||
self.llm_directory = os.path.join(LLM_FOLDER_PATH, "llm")
|
||||
if not os.path.isdir(self.llm_directory): # Check if the LLM directory exists
|
||||
os.makedirs(LLM_FOLDER_PATH, exist_ok=True)
|
||||
|
||||
# Install WasmEdge
|
||||
subprocess.run(['curl', '-sSf', 'https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh', '|', 'bash', '-s', '--', '--plugin', 'wasi_nn-ggml'])
|
||||
subprocess.run(
|
||||
[
|
||||
"curl",
|
||||
"-sSf",
|
||||
"https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh",
|
||||
"|",
|
||||
"bash",
|
||||
"-s",
|
||||
"--",
|
||||
"--plugin",
|
||||
"wasi_nn-ggml",
|
||||
]
|
||||
)
|
||||
|
||||
# Download the Qwen1.5-0.5B-Chat model GGUF file
|
||||
MODEL_URL = "https://huggingface.co/second-state/Qwen1.5-0.5B-Chat-GGUF/resolve/main/Qwen1.5-0.5B-Chat-Q5_K_M.gguf"
|
||||
subprocess.run(['curl', '-LO', MODEL_URL], cwd=self.llm_directory)
|
||||
subprocess.run(["curl", "-LO", MODEL_URL], cwd=self.llm_directory)
|
||||
|
||||
# Download the llama-api-server.wasm app
|
||||
APP_URL = "https://github.com/LlamaEdge/LlamaEdge/releases/latest/download/llama-api-server.wasm"
|
||||
subprocess.run(['curl', '-LO', APP_URL], cwd=self.llm_directory)
|
||||
subprocess.run(["curl", "-LO", APP_URL], cwd=self.llm_directory)
|
||||
|
||||
# Run the API server
|
||||
subprocess.run(['wasmedge', '--dir', '.:.', '--nn-preload', 'default:GGML:AUTO:Qwen1.5-0.5B-Chat-Q5_K_M.gguf', 'llama-api-server.wasm', '-p', 'llama-2-chat'], cwd=self.llm_directory)
|
||||
subprocess.run(
|
||||
[
|
||||
"wasmedge",
|
||||
"--dir",
|
||||
".:.",
|
||||
"--nn-preload",
|
||||
"default:GGML:AUTO:Qwen1.5-0.5B-Chat-Q5_K_M.gguf",
|
||||
"llama-api-server.wasm",
|
||||
"-p",
|
||||
"llama-2-chat",
|
||||
],
|
||||
cwd=self.llm_directory,
|
||||
)
|
||||
|
||||
print("LLM setup completed.")
|
||||
else:
|
||||
|
@ -33,17 +58,11 @@ class Llm:
|
|||
|
||||
def llm(self, messages):
|
||||
url = "http://localhost:8080/v1/chat/completions"
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
data = {
|
||||
"messages": messages,
|
||||
"model": "llama-2-chat"
|
||||
}
|
||||
with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response:
|
||||
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
data = {"messages": messages, "model": "llama-2-chat"}
|
||||
with requests.post(
|
||||
url, headers=headers, data=json.dumps(data), stream=True
|
||||
) as response:
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
yield json.loads(line)
|
||||
|
||||
|
||||
|
|
|
@ -10,9 +10,6 @@ import shutil
|
|||
import ffmpeg
|
||||
import subprocess
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import platform
|
||||
import urllib.request
|
||||
|
||||
|
||||
|
@ -26,7 +23,6 @@ class Stt:
|
|||
|
||||
|
||||
def install(service_dir):
|
||||
|
||||
### INSTALL
|
||||
|
||||
WHISPER_RUST_PATH = os.path.join(service_dir, "whisper-rust")
|
||||
|
@ -41,29 +37,38 @@ def install(service_dir):
|
|||
os.chdir(WHISPER_RUST_PATH)
|
||||
|
||||
# Check if whisper-rust executable exists before attempting to build
|
||||
if not os.path.isfile(os.path.join(WHISPER_RUST_PATH, "target/release/whisper-rust")):
|
||||
if not os.path.isfile(
|
||||
os.path.join(WHISPER_RUST_PATH, "target/release/whisper-rust")
|
||||
):
|
||||
# Check if Rust is installed. Needed to build whisper executable
|
||||
|
||||
rustc_path = shutil.which("rustc")
|
||||
|
||||
if rustc_path is None:
|
||||
print("Rust is not installed or is not in system PATH. Please install Rust before proceeding.")
|
||||
print(
|
||||
"Rust is not installed or is not in system PATH. Please install Rust before proceeding."
|
||||
)
|
||||
exit(1)
|
||||
|
||||
# Build Whisper Rust executable if not found
|
||||
subprocess.run(['cargo', 'build', '--release'], check=True)
|
||||
subprocess.run(["cargo", "build", "--release"], check=True)
|
||||
else:
|
||||
print("Whisper Rust executable already exists. Skipping build.")
|
||||
|
||||
WHISPER_MODEL_PATH = os.path.join(service_dir, "model")
|
||||
|
||||
WHISPER_MODEL_NAME = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin')
|
||||
WHISPER_MODEL_URL = os.getenv('WHISPER_MODEL_URL', 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/')
|
||||
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL_NAME", "ggml-tiny.en.bin")
|
||||
WHISPER_MODEL_URL = os.getenv(
|
||||
"WHISPER_MODEL_URL",
|
||||
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/",
|
||||
)
|
||||
|
||||
if not os.path.isfile(os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)):
|
||||
os.makedirs(WHISPER_MODEL_PATH, exist_ok=True)
|
||||
urllib.request.urlretrieve(f"{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}",
|
||||
os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME))
|
||||
urllib.request.urlretrieve(
|
||||
f"{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}",
|
||||
os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME),
|
||||
)
|
||||
else:
|
||||
print("Whisper model already exists. Skipping download.")
|
||||
|
||||
|
@ -85,25 +90,31 @@ def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
|
|||
|
||||
# Create a temporary file with the appropriate extension
|
||||
input_ext = convert_mime_type_to_format(mime_type)
|
||||
input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}")
|
||||
with open(input_path, 'wb') as f:
|
||||
input_path = os.path.join(
|
||||
temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}"
|
||||
)
|
||||
with open(input_path, "wb") as f:
|
||||
f.write(audio)
|
||||
|
||||
# Check if the input file exists
|
||||
assert os.path.exists(input_path), f"Input file does not exist: {input_path}"
|
||||
|
||||
# Export to wav
|
||||
output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
|
||||
output_path = os.path.join(
|
||||
temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav"
|
||||
)
|
||||
print(mime_type, input_path, output_path)
|
||||
if mime_type == "audio/raw":
|
||||
ffmpeg.input(
|
||||
input_path,
|
||||
f='s16le',
|
||||
ar='16000',
|
||||
f="s16le",
|
||||
ar="16000",
|
||||
ac=1,
|
||||
).output(output_path, loglevel='panic').run()
|
||||
).output(output_path, loglevel="panic").run()
|
||||
else:
|
||||
ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k', loglevel='panic').run()
|
||||
ffmpeg.input(input_path).output(
|
||||
output_path, acodec="pcm_s16le", ac=1, ar="16k", loglevel="panic"
|
||||
).run()
|
||||
|
||||
try:
|
||||
yield output_path
|
||||
|
@ -113,28 +124,40 @@ def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
|
|||
|
||||
|
||||
def run_command(command):
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
return result.stdout, result.stderr
|
||||
|
||||
|
||||
def get_transcription_file(service_directory, wav_file_path: str):
|
||||
local_path = os.path.join(service_directory, 'model')
|
||||
whisper_rust_path = os.path.join(service_directory, 'whisper-rust', 'target', 'release')
|
||||
model_name = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin')
|
||||
local_path = os.path.join(service_directory, "model")
|
||||
whisper_rust_path = os.path.join(
|
||||
service_directory, "whisper-rust", "target", "release"
|
||||
)
|
||||
model_name = os.getenv("WHISPER_MODEL_NAME", "ggml-tiny.en.bin")
|
||||
|
||||
output, _ = run_command([
|
||||
os.path.join(whisper_rust_path, 'whisper-rust'),
|
||||
'--model-path', os.path.join(local_path, model_name),
|
||||
'--file-path', wav_file_path
|
||||
])
|
||||
output, _ = run_command(
|
||||
[
|
||||
os.path.join(whisper_rust_path, "whisper-rust"),
|
||||
"--model-path",
|
||||
os.path.join(local_path, model_name),
|
||||
"--file-path",
|
||||
wav_file_path,
|
||||
]
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def stt_wav(service_directory, wav_file_path: str):
|
||||
temp_dir = tempfile.gettempdir()
|
||||
output_path = os.path.join(temp_dir, f"output_stt_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
|
||||
ffmpeg.input(wav_file_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run()
|
||||
output_path = os.path.join(
|
||||
temp_dir, f"output_stt_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav"
|
||||
)
|
||||
ffmpeg.input(wav_file_path).output(
|
||||
output_path, acodec="pcm_s16le", ac=1, ar="16k"
|
||||
).run()
|
||||
try:
|
||||
transcript = get_transcription_file(service_directory, output_path)
|
||||
finally:
|
||||
|
|
|
@ -6,7 +6,6 @@ class Stt:
|
|||
return stt(audio_file_path)
|
||||
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
import os
|
||||
import contextlib
|
||||
|
@ -19,6 +18,7 @@ from openai import OpenAI
|
|||
|
||||
client = OpenAI()
|
||||
|
||||
|
||||
def convert_mime_type_to_format(mime_type: str) -> str:
|
||||
if mime_type == "audio/x-wav" or mime_type == "audio/wav":
|
||||
return "wav"
|
||||
|
@ -29,30 +29,37 @@ def convert_mime_type_to_format(mime_type: str) -> str:
|
|||
|
||||
return mime_type
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
# Create a temporary file with the appropriate extension
|
||||
input_ext = convert_mime_type_to_format(mime_type)
|
||||
input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}")
|
||||
with open(input_path, 'wb') as f:
|
||||
input_path = os.path.join(
|
||||
temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}"
|
||||
)
|
||||
with open(input_path, "wb") as f:
|
||||
f.write(audio)
|
||||
|
||||
# Check if the input file exists
|
||||
assert os.path.exists(input_path), f"Input file does not exist: {input_path}"
|
||||
|
||||
# Export to wav
|
||||
output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
|
||||
output_path = os.path.join(
|
||||
temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav"
|
||||
)
|
||||
if mime_type == "audio/raw":
|
||||
ffmpeg.input(
|
||||
input_path,
|
||||
f='s16le',
|
||||
ar='16000',
|
||||
f="s16le",
|
||||
ar="16000",
|
||||
ac=1,
|
||||
).output(output_path, loglevel='panic').run()
|
||||
).output(output_path, loglevel="panic").run()
|
||||
else:
|
||||
ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k', loglevel='panic').run()
|
||||
ffmpeg.input(input_path).output(
|
||||
output_path, acodec="pcm_s16le", ac=1, ar="16k", loglevel="panic"
|
||||
).run()
|
||||
|
||||
try:
|
||||
yield output_path
|
||||
|
@ -60,39 +67,49 @@ def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
|
|||
os.remove(input_path)
|
||||
os.remove(output_path)
|
||||
|
||||
|
||||
def run_command(command):
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
return result.stdout, result.stderr
|
||||
|
||||
def get_transcription_file(wav_file_path: str):
|
||||
local_path = os.path.join(os.path.dirname(__file__), 'local_service')
|
||||
whisper_rust_path = os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release')
|
||||
model_name = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin')
|
||||
|
||||
output, error = run_command([
|
||||
os.path.join(whisper_rust_path, 'whisper-rust'),
|
||||
'--model-path', os.path.join(local_path, model_name),
|
||||
'--file-path', wav_file_path
|
||||
])
|
||||
def get_transcription_file(wav_file_path: str):
|
||||
local_path = os.path.join(os.path.dirname(__file__), "local_service")
|
||||
whisper_rust_path = os.path.join(
|
||||
os.path.dirname(__file__), "whisper-rust", "target", "release"
|
||||
)
|
||||
model_name = os.getenv("WHISPER_MODEL_NAME", "ggml-tiny.en.bin")
|
||||
|
||||
output, error = run_command(
|
||||
[
|
||||
os.path.join(whisper_rust_path, "whisper-rust"),
|
||||
"--model-path",
|
||||
os.path.join(local_path, model_name),
|
||||
"--file-path",
|
||||
wav_file_path,
|
||||
]
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_transcription_bytes(audio_bytes: bytearray, mime_type):
|
||||
with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path:
|
||||
return get_transcription_file(wav_file_path)
|
||||
|
||||
|
||||
def stt_bytes(audio_bytes: bytearray, mime_type="audio/wav"):
|
||||
with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path:
|
||||
return stt_wav(wav_file_path)
|
||||
|
||||
def stt_wav(wav_file_path: str):
|
||||
|
||||
def stt_wav(wav_file_path: str):
|
||||
audio_file = open(wav_file_path, "rb")
|
||||
try:
|
||||
transcript = client.audio.transcriptions.create(
|
||||
model="whisper-1",
|
||||
file=audio_file,
|
||||
response_format="text"
|
||||
model="whisper-1", file=audio_file, response_format="text"
|
||||
)
|
||||
except openai.BadRequestError as e:
|
||||
print(f"openai.BadRequestError: {e}")
|
||||
|
@ -100,10 +117,13 @@ def stt_wav(wav_file_path: str):
|
|||
|
||||
return transcript
|
||||
|
||||
|
||||
def stt(input_data, mime_type="audio/wav"):
|
||||
if isinstance(input_data, str):
|
||||
return stt_wav(input_data)
|
||||
elif isinstance(input_data, bytearray):
|
||||
return stt_bytes(input_data, mime_type)
|
||||
else:
|
||||
raise ValueError("Input data should be either a path to a wav file (str) or audio bytes (bytearray)")
|
||||
raise ValueError(
|
||||
"Input data should be either a path to a wav file (str) or audio bytes (bytearray)"
|
||||
)
|
||||
|
|
|
@ -2,23 +2,25 @@ import ffmpeg
|
|||
import tempfile
|
||||
from openai import OpenAI
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from source.server.utils.logs import logger
|
||||
from source.server.utils.logs import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
# If this TTS service is used, the OPENAI_API_KEY environment variable must be set
|
||||
if not os.getenv('OPENAI_API_KEY'):
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
logger.error("")
|
||||
logger.error(f"OpenAI API key not found. Please set the OPENAI_API_KEY environment variable, or run 01 with the --local option.")
|
||||
logger.error(
|
||||
"OpenAI API key not found. Please set the OPENAI_API_KEY environment variable, or run 01 with the --local option."
|
||||
)
|
||||
logger.error("Aborting...")
|
||||
logger.error("")
|
||||
os._exit(1)
|
||||
|
||||
client = OpenAI()
|
||||
|
||||
|
||||
class Tts:
|
||||
def __init__(self, config):
|
||||
pass
|
||||
|
@ -26,17 +28,17 @@ class Tts:
|
|||
def tts(self, text):
|
||||
response = client.audio.speech.create(
|
||||
model="tts-1",
|
||||
voice=os.getenv('OPENAI_VOICE_NAME', 'alloy'),
|
||||
voice=os.getenv("OPENAI_VOICE_NAME", "alloy"),
|
||||
input=text,
|
||||
response_format="opus"
|
||||
response_format="opus",
|
||||
)
|
||||
with tempfile.NamedTemporaryFile(suffix=".opus", delete=False) as temp_file:
|
||||
response.stream_to_file(temp_file.name)
|
||||
|
||||
# TODO: hack to format audio correctly for device
|
||||
outfile = tempfile.gettempdir() + "/" + "raw.dat"
|
||||
ffmpeg.input(temp_file.name).output(outfile, f="s16le", ar="16000", ac="1", loglevel='panic').run()
|
||||
ffmpeg.input(temp_file.name).output(
|
||||
outfile, f="s16le", ar="16000", ac="1", loglevel="panic"
|
||||
).run()
|
||||
|
||||
return outfile
|
||||
|
||||
|
||||
|
|
|
@ -13,26 +13,40 @@ class Tts:
|
|||
self.install(config["service_directory"])
|
||||
|
||||
def tts(self, text):
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
|
||||
output_file = temp_file.name
|
||||
piper_dir = self.piper_directory
|
||||
subprocess.run([
|
||||
os.path.join(piper_dir, 'piper'),
|
||||
'--model', os.path.join(piper_dir, os.getenv('PIPER_VOICE_NAME', 'en_US-lessac-medium.onnx')),
|
||||
'--output_file', output_file
|
||||
], input=text, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
subprocess.run(
|
||||
[
|
||||
os.path.join(piper_dir, "piper"),
|
||||
"--model",
|
||||
os.path.join(
|
||||
piper_dir,
|
||||
os.getenv("PIPER_VOICE_NAME", "en_US-lessac-medium.onnx"),
|
||||
),
|
||||
"--output_file",
|
||||
output_file,
|
||||
],
|
||||
input=text,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
|
||||
# TODO: hack to format audio correctly for device
|
||||
outfile = tempfile.gettempdir() + "/" + "raw.dat"
|
||||
ffmpeg.input(temp_file.name).output(outfile, f="s16le", ar="16000", ac="1", loglevel='panic').run()
|
||||
ffmpeg.input(temp_file.name).output(
|
||||
outfile, f="s16le", ar="16000", ac="1", loglevel="panic"
|
||||
).run()
|
||||
|
||||
return outfile
|
||||
|
||||
def install(self, service_directory):
|
||||
PIPER_FOLDER_PATH = service_directory
|
||||
self.piper_directory = os.path.join(PIPER_FOLDER_PATH, 'piper')
|
||||
if not os.path.isdir(self.piper_directory): # Check if the Piper directory exists
|
||||
self.piper_directory = os.path.join(PIPER_FOLDER_PATH, "piper")
|
||||
if not os.path.isdir(
|
||||
self.piper_directory
|
||||
): # Check if the Piper directory exists
|
||||
os.makedirs(PIPER_FOLDER_PATH, exist_ok=True)
|
||||
|
||||
# Determine OS and architecture
|
||||
|
@ -60,51 +74,91 @@ class Tts:
|
|||
asset_url = f"{PIPER_URL}{PIPER_ASSETNAME}"
|
||||
|
||||
if OS == "windows":
|
||||
|
||||
asset_url = asset_url.replace(".tar.gz", ".zip")
|
||||
|
||||
# Download and extract Piper
|
||||
urllib.request.urlretrieve(asset_url, os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME))
|
||||
urllib.request.urlretrieve(
|
||||
asset_url, os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME)
|
||||
)
|
||||
|
||||
# Extract the downloaded file
|
||||
if OS == "windows":
|
||||
import zipfile
|
||||
with zipfile.ZipFile(os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME), 'r') as zip_ref:
|
||||
|
||||
with zipfile.ZipFile(
|
||||
os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME), "r"
|
||||
) as zip_ref:
|
||||
zip_ref.extractall(path=PIPER_FOLDER_PATH)
|
||||
else:
|
||||
with tarfile.open(os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME), 'r:gz') as tar:
|
||||
with tarfile.open(
|
||||
os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME), "r:gz"
|
||||
) as tar:
|
||||
tar.extractall(path=PIPER_FOLDER_PATH)
|
||||
|
||||
PIPER_VOICE_URL = os.getenv('PIPER_VOICE_URL',
|
||||
'https://huggingface.co/rhasspy/piper-voices/resolve/main/en/en_US/lessac/medium/')
|
||||
PIPER_VOICE_NAME = os.getenv('PIPER_VOICE_NAME', 'en_US-lessac-medium.onnx')
|
||||
PIPER_VOICE_URL = os.getenv(
|
||||
"PIPER_VOICE_URL",
|
||||
"https://huggingface.co/rhasspy/piper-voices/resolve/main/en/en_US/lessac/medium/",
|
||||
)
|
||||
PIPER_VOICE_NAME = os.getenv("PIPER_VOICE_NAME", "en_US-lessac-medium.onnx")
|
||||
|
||||
# Download voice model and its json file
|
||||
urllib.request.urlretrieve(f"{PIPER_VOICE_URL}{PIPER_VOICE_NAME}",
|
||||
os.path.join(self.piper_directory, PIPER_VOICE_NAME))
|
||||
urllib.request.urlretrieve(f"{PIPER_VOICE_URL}{PIPER_VOICE_NAME}.json",
|
||||
os.path.join(self.piper_directory, f"{PIPER_VOICE_NAME}.json"))
|
||||
urllib.request.urlretrieve(
|
||||
f"{PIPER_VOICE_URL}{PIPER_VOICE_NAME}",
|
||||
os.path.join(self.piper_directory, PIPER_VOICE_NAME),
|
||||
)
|
||||
urllib.request.urlretrieve(
|
||||
f"{PIPER_VOICE_URL}{PIPER_VOICE_NAME}.json",
|
||||
os.path.join(self.piper_directory, f"{PIPER_VOICE_NAME}.json"),
|
||||
)
|
||||
|
||||
# Additional setup for macOS
|
||||
if OS == "macos":
|
||||
if ARCH == "x64":
|
||||
subprocess.run(['softwareupdate', '--install-rosetta', '--agree-to-license'])
|
||||
subprocess.run(
|
||||
["softwareupdate", "--install-rosetta", "--agree-to-license"]
|
||||
)
|
||||
|
||||
PIPER_PHONEMIZE_ASSETNAME = f"piper-phonemize_{OS}_{ARCH}.tar.gz"
|
||||
PIPER_PHONEMIZE_URL = "https://github.com/rhasspy/piper-phonemize/releases/latest/download/"
|
||||
urllib.request.urlretrieve(f"{PIPER_PHONEMIZE_URL}{PIPER_PHONEMIZE_ASSETNAME}",
|
||||
os.path.join(self.piper_directory, PIPER_PHONEMIZE_ASSETNAME))
|
||||
urllib.request.urlretrieve(
|
||||
f"{PIPER_PHONEMIZE_URL}{PIPER_PHONEMIZE_ASSETNAME}",
|
||||
os.path.join(self.piper_directory, PIPER_PHONEMIZE_ASSETNAME),
|
||||
)
|
||||
|
||||
with tarfile.open(os.path.join(self.piper_directory, PIPER_PHONEMIZE_ASSETNAME), 'r:gz') as tar:
|
||||
with tarfile.open(
|
||||
os.path.join(self.piper_directory, PIPER_PHONEMIZE_ASSETNAME),
|
||||
"r:gz",
|
||||
) as tar:
|
||||
tar.extractall(path=self.piper_directory)
|
||||
|
||||
PIPER_DIR = self.piper_directory
|
||||
subprocess.run(['install_name_tool', '-change', '@rpath/libespeak-ng.1.dylib',
|
||||
f"{PIPER_DIR}/piper-phonemize/lib/libespeak-ng.1.dylib", f"{PIPER_DIR}/piper"])
|
||||
subprocess.run(['install_name_tool', '-change', '@rpath/libonnxruntime.1.14.1.dylib',
|
||||
f"{PIPER_DIR}/piper-phonemize/lib/libonnxruntime.1.14.1.dylib", f"{PIPER_DIR}/piper"])
|
||||
subprocess.run(['install_name_tool', '-change', '@rpath/libpiper_phonemize.1.dylib',
|
||||
f"{PIPER_DIR}/piper-phonemize/lib/libpiper_phonemize.1.dylib", f"{PIPER_DIR}/piper"])
|
||||
subprocess.run(
|
||||
[
|
||||
"install_name_tool",
|
||||
"-change",
|
||||
"@rpath/libespeak-ng.1.dylib",
|
||||
f"{PIPER_DIR}/piper-phonemize/lib/libespeak-ng.1.dylib",
|
||||
f"{PIPER_DIR}/piper",
|
||||
]
|
||||
)
|
||||
subprocess.run(
|
||||
[
|
||||
"install_name_tool",
|
||||
"-change",
|
||||
"@rpath/libonnxruntime.1.14.1.dylib",
|
||||
f"{PIPER_DIR}/piper-phonemize/lib/libonnxruntime.1.14.1.dylib",
|
||||
f"{PIPER_DIR}/piper",
|
||||
]
|
||||
)
|
||||
subprocess.run(
|
||||
[
|
||||
"install_name_tool",
|
||||
"-change",
|
||||
"@rpath/libpiper_phonemize.1.dylib",
|
||||
f"{PIPER_DIR}/piper-phonemize/lib/libpiper_phonemize.1.dylib",
|
||||
f"{PIPER_DIR}/piper",
|
||||
]
|
||||
)
|
||||
|
||||
print("Piper setup completed.")
|
||||
else:
|
||||
|
|
|
@ -3,9 +3,9 @@ from datetime import datetime
|
|||
from pytimeparse import parse
|
||||
from crontab import CronTab
|
||||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
from platformdirs import user_data_dir
|
||||
|
||||
|
||||
def schedule(message="", start=None, interval=None) -> None:
|
||||
"""
|
||||
Schedules a task at a particular time, or at a particular interval
|
||||
|
@ -17,19 +17,18 @@ def schedule(message="", start=None, interval=None) -> None:
|
|||
raise ValueError("Either start time or interval must be specified.")
|
||||
|
||||
# Read the temp file to see what the current session is
|
||||
session_file_path = os.path.join(user_data_dir('01'), '01-session.txt')
|
||||
session_file_path = os.path.join(user_data_dir("01"), "01-session.txt")
|
||||
|
||||
with open(session_file_path, 'r') as session_file:
|
||||
with open(session_file_path, "r") as session_file:
|
||||
file_session_value = session_file.read().strip()
|
||||
|
||||
|
||||
prefixed_message = "AUTOMATED MESSAGE FROM SCHEDULER: " + message
|
||||
|
||||
# Escape the message and the json, cron is funky with quotes
|
||||
escaped_question = prefixed_message.replace('"', '\\"')
|
||||
json_data = f"{{\\\"text\\\": \\\"{escaped_question}\\\"}}"
|
||||
json_data = f'{{\\"text\\": \\"{escaped_question}\\"}}'
|
||||
|
||||
command = f'''bash -c 'if [ "$(cat "{session_file_path}")" == "{file_session_value}" ]; then /usr/bin/curl -X POST -H "Content-Type: application/json" -d "{json_data}" http://localhost:10001/; fi' '''
|
||||
command = f"""bash -c 'if [ "$(cat "{session_file_path}")" == "{file_session_value}" ]; then /usr/bin/curl -X POST -H "Content-Type: application/json" -d "{json_data}" http://localhost:10001/; fi' """
|
||||
|
||||
cron = CronTab(user=True)
|
||||
job = cron.new(command=command)
|
||||
|
@ -63,4 +62,3 @@ def schedule(message="", start=None, interval=None) -> None:
|
|||
print(f"Task scheduled every {days} day(s)")
|
||||
|
||||
cron.write()
|
||||
|
||||
|
|
|
@ -237,4 +237,6 @@ For example:
|
|||
|
||||
ALWAYS REMEMBER: You are running on a device called the O1, where the interface is entirely speech-based. Make your responses to the user **VERY short.**
|
||||
|
||||
""".strip().replace("OI_SKILLS_DIR", os.path.join(os.path.dirname(__file__), "skills"))
|
||||
""".strip().replace(
|
||||
"OI_SKILLS_DIR", os.path.join(os.path.dirname(__file__), "skills")
|
||||
)
|
||||
|
|
|
@ -131,4 +131,6 @@ print(output)
|
|||
|
||||
Remember: You can run Python code outside a function only to run a Python function; all other code must go in a in Python function if you first write a Python function. ALL imports must go inside the function.
|
||||
|
||||
""".strip().replace("OI_SKILLS_DIR", os.path.abspath(os.path.join(os.path.dirname(__file__), "skills")))
|
||||
""".strip().replace(
|
||||
"OI_SKILLS_DIR", os.path.abspath(os.path.join(os.path.dirname(__file__), "skills"))
|
||||
)
|
||||
|
|
|
@ -1,11 +1,5 @@
|
|||
# test_main.py
|
||||
import subprocess
|
||||
import uuid
|
||||
import pytest
|
||||
from source.server.i import configure_interpreter
|
||||
from unittest.mock import Mock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -1,26 +1,35 @@
|
|||
import os
|
||||
import subprocess
|
||||
import re
|
||||
import shutil
|
||||
import pyqrcode
|
||||
import time
|
||||
from ..utils.print_markdown import print_markdown
|
||||
|
||||
def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10001, qr=False):
|
||||
print_markdown(f"Exposing server to the internet...")
|
||||
|
||||
def create_tunnel(
|
||||
tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False
|
||||
):
|
||||
print_markdown("Exposing server to the internet...")
|
||||
|
||||
server_url = ""
|
||||
if tunnel_method == "bore":
|
||||
try:
|
||||
output = subprocess.check_output('command -v bore', shell=True)
|
||||
output = subprocess.check_output("command -v bore", shell=True)
|
||||
except subprocess.CalledProcessError:
|
||||
print("The bore-cli command is not available. Please run 'cargo install bore-cli'.")
|
||||
print(
|
||||
"The bore-cli command is not available. Please run 'cargo install bore-cli'."
|
||||
)
|
||||
print("For more information, see https://github.com/ekzhang/bore")
|
||||
exit(1)
|
||||
|
||||
time.sleep(6)
|
||||
# output = subprocess.check_output(f'bore local {server_port} --to bore.pub', shell=True)
|
||||
process = subprocess.Popen(f'bore local {server_port} --to bore.pub', shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
process = subprocess.Popen(
|
||||
f"bore local {server_port} --to bore.pub",
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
universal_newlines=True,
|
||||
)
|
||||
|
||||
while True:
|
||||
line = process.stdout.readline()
|
||||
|
@ -28,25 +37,34 @@ def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10
|
|||
if not line:
|
||||
break
|
||||
if "listening at bore.pub:" in line:
|
||||
remote_port = re.search('bore.pub:([0-9]*)', line).group(1)
|
||||
remote_port = re.search("bore.pub:([0-9]*)", line).group(1)
|
||||
server_url = f"bore.pub:{remote_port}"
|
||||
print_markdown(f"Your server is being hosted at the following URL: bore.pub:{remote_port}")
|
||||
print_markdown(
|
||||
f"Your server is being hosted at the following URL: bore.pub:{remote_port}"
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
|
||||
|
||||
elif tunnel_method == "localtunnel":
|
||||
if subprocess.call('command -v lt', shell=True):
|
||||
if subprocess.call("command -v lt", shell=True):
|
||||
print("The 'lt' command is not available.")
|
||||
print("Please ensure you have Node.js installed, then run 'npm install -g localtunnel'.")
|
||||
print("For more information, see https://github.com/localtunnel/localtunnel")
|
||||
print(
|
||||
"Please ensure you have Node.js installed, then run 'npm install -g localtunnel'."
|
||||
)
|
||||
print(
|
||||
"For more information, see https://github.com/localtunnel/localtunnel"
|
||||
)
|
||||
exit(1)
|
||||
else:
|
||||
process = subprocess.Popen(f'npx localtunnel --port {server_port}', shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
process = subprocess.Popen(
|
||||
f"npx localtunnel --port {server_port}",
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
universal_newlines=True,
|
||||
)
|
||||
|
||||
found_url = False
|
||||
url_pattern = re.compile(r'your url is: https://[a-zA-Z0-9.-]+')
|
||||
url_pattern = re.compile(r"your url is: https://[a-zA-Z0-9.-]+")
|
||||
|
||||
while True:
|
||||
line = process.stdout.readline()
|
||||
|
@ -55,35 +73,46 @@ def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10
|
|||
match = url_pattern.search(line)
|
||||
if match:
|
||||
found_url = True
|
||||
remote_url = match.group(0).replace('your url is: ', '')
|
||||
remote_url = match.group(0).replace("your url is: ", "")
|
||||
server_url = remote_url
|
||||
print(f"\nYour server is being hosted at the following URL: {remote_url}")
|
||||
print(
|
||||
f"\nYour server is being hosted at the following URL: {remote_url}"
|
||||
)
|
||||
break # Exit the loop once the URL is found
|
||||
|
||||
if not found_url:
|
||||
print("Failed to extract the localtunnel URL. Please check localtunnel's output for details.")
|
||||
print(
|
||||
"Failed to extract the localtunnel URL. Please check localtunnel's output for details."
|
||||
)
|
||||
|
||||
elif tunnel_method == "ngrok":
|
||||
|
||||
# Check if ngrok is installed
|
||||
is_installed = subprocess.check_output('command -v ngrok', shell=True).decode().strip()
|
||||
is_installed = (
|
||||
subprocess.check_output("command -v ngrok", shell=True).decode().strip()
|
||||
)
|
||||
if not is_installed:
|
||||
print("The ngrok command is not available.")
|
||||
print("Please install ngrok using the instructions at https://ngrok.com/docs/getting-started/")
|
||||
print(
|
||||
"Please install ngrok using the instructions at https://ngrok.com/docs/getting-started/"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
# If ngrok is installed, start it on the specified port
|
||||
# process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE)
|
||||
process = subprocess.Popen(f'ngrok http {server_port} --scheme http,https --domain=marten-advanced-dragon.ngrok-free.app --log=stdout', shell=True, stdout=subprocess.PIPE)
|
||||
process = subprocess.Popen(
|
||||
f"ngrok http {server_port} --scheme http,https --domain=marten-advanced-dragon.ngrok-free.app --log=stdout",
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
)
|
||||
|
||||
# Initially, no URL is found
|
||||
found_url = False
|
||||
# Regular expression to match the ngrok URL
|
||||
url_pattern = re.compile(r'https://[a-zA-Z0-9-]+\.ngrok(-free)?\.app')
|
||||
url_pattern = re.compile(r"https://[a-zA-Z0-9-]+\.ngrok(-free)?\.app")
|
||||
|
||||
# Read the output line by line
|
||||
while True:
|
||||
line = process.stdout.readline().decode('utf-8')
|
||||
line = process.stdout.readline().decode("utf-8")
|
||||
if not line:
|
||||
break # Break out of the loop if no more output
|
||||
match = url_pattern.search(line)
|
||||
|
@ -91,15 +120,18 @@ def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10
|
|||
found_url = True
|
||||
remote_url = match.group(0)
|
||||
server_url = remote_url
|
||||
print(f"\nYour server is being hosted at the following URL: {remote_url}")
|
||||
print(
|
||||
f"\nYour server is being hosted at the following URL: {remote_url}"
|
||||
)
|
||||
break # Exit the loop once the URL is found
|
||||
|
||||
if not found_url:
|
||||
print("Failed to extract the ngrok tunnel URL. Please check ngrok's output for details.")
|
||||
print(
|
||||
"Failed to extract the ngrok tunnel URL. Please check ngrok's output for details."
|
||||
)
|
||||
|
||||
if server_url and qr:
|
||||
text = pyqrcode.create(remote_url)
|
||||
print(text.terminal(quiet_zone=1))
|
||||
|
||||
return server_url
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import tempfile
|
|||
import ffmpeg
|
||||
import subprocess
|
||||
|
||||
|
||||
def convert_mime_type_to_format(mime_type: str) -> str:
|
||||
if mime_type == "audio/x-wav" or mime_type == "audio/wav":
|
||||
return "wav"
|
||||
|
@ -15,39 +16,49 @@ def convert_mime_type_to_format(mime_type: str) -> str:
|
|||
|
||||
return mime_type
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
# Create a temporary file with the appropriate extension
|
||||
input_ext = convert_mime_type_to_format(mime_type)
|
||||
input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}")
|
||||
with open(input_path, 'wb') as f:
|
||||
input_path = os.path.join(
|
||||
temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}"
|
||||
)
|
||||
with open(input_path, "wb") as f:
|
||||
f.write(audio)
|
||||
|
||||
# Check if the input file exists
|
||||
assert os.path.exists(input_path), f"Input file does not exist: {input_path}"
|
||||
|
||||
# Export to wav
|
||||
output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
|
||||
output_path = os.path.join(
|
||||
temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav"
|
||||
)
|
||||
print(mime_type, input_path, output_path)
|
||||
if mime_type == "audio/raw":
|
||||
ffmpeg.input(
|
||||
input_path,
|
||||
f='s16le',
|
||||
ar='16000',
|
||||
f="s16le",
|
||||
ar="16000",
|
||||
ac=1,
|
||||
).output(output_path, loglevel='panic').run()
|
||||
).output(output_path, loglevel="panic").run()
|
||||
else:
|
||||
ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k', loglevel='panic').run()
|
||||
ffmpeg.input(input_path).output(
|
||||
output_path, acodec="pcm_s16le", ac=1, ar="16k", loglevel="panic"
|
||||
).run()
|
||||
|
||||
try:
|
||||
yield output_path
|
||||
finally:
|
||||
os.remove(input_path)
|
||||
|
||||
|
||||
def run_command(command):
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
return result.stdout, result.stderr
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import platform
|
||||
|
||||
|
||||
def get_system_info():
|
||||
system = platform.system()
|
||||
if system == "Linux":
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv() # take environment variables from .env.
|
||||
|
||||
import asyncio
|
||||
|
@ -7,8 +8,10 @@ import platform
|
|||
|
||||
from .logs import setup_logging
|
||||
from .logs import logger
|
||||
|
||||
setup_logging()
|
||||
|
||||
|
||||
def get_kernel_messages():
|
||||
"""
|
||||
Is this the way to do this?
|
||||
|
@ -16,20 +19,23 @@ def get_kernel_messages():
|
|||
current_platform = platform.system()
|
||||
|
||||
if current_platform == "Darwin":
|
||||
process = subprocess.Popen(['syslog'], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
|
||||
process = subprocess.Popen(
|
||||
["syslog"], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
|
||||
)
|
||||
output, _ = process.communicate()
|
||||
return output.decode('utf-8')
|
||||
return output.decode("utf-8")
|
||||
elif current_platform == "Linux":
|
||||
with open('/var/log/dmesg', 'r') as file:
|
||||
with open("/var/log/dmesg", "r") as file:
|
||||
return file.read()
|
||||
else:
|
||||
logger.info("Unsupported platform.")
|
||||
|
||||
|
||||
def custom_filter(message):
|
||||
# Check for {TO_INTERPRETER{ message here }TO_INTERPRETER} pattern
|
||||
if '{TO_INTERPRETER{' in message and '}TO_INTERPRETER}' in message:
|
||||
start = message.find('{TO_INTERPRETER{') + len('{TO_INTERPRETER{')
|
||||
end = message.find('}TO_INTERPRETER}', start)
|
||||
if "{TO_INTERPRETER{" in message and "}TO_INTERPRETER}" in message:
|
||||
start = message.find("{TO_INTERPRETER{") + len("{TO_INTERPRETER{")
|
||||
end = message.find("}TO_INTERPRETER}", start)
|
||||
return message[start:end]
|
||||
# Check for USB mention
|
||||
# elif 'USB' in message:
|
||||
|
@ -41,8 +47,10 @@ def custom_filter(message):
|
|||
else:
|
||||
return None
|
||||
|
||||
|
||||
last_messages = ""
|
||||
|
||||
|
||||
def check_filtered_kernel():
|
||||
messages = get_kernel_messages()
|
||||
if messages is None:
|
||||
|
@ -66,11 +74,25 @@ async def put_kernel_messages_into_queue(queue):
|
|||
if text:
|
||||
if isinstance(queue, asyncio.Queue):
|
||||
await queue.put({"role": "computer", "type": "console", "start": True})
|
||||
await queue.put({"role": "computer", "type": "console", "format": "output", "content": text})
|
||||
await queue.put(
|
||||
{
|
||||
"role": "computer",
|
||||
"type": "console",
|
||||
"format": "output",
|
||||
"content": text,
|
||||
}
|
||||
)
|
||||
await queue.put({"role": "computer", "type": "console", "end": True})
|
||||
else:
|
||||
queue.put({"role": "computer", "type": "console", "start": True})
|
||||
queue.put({"role": "computer", "type": "console", "format": "output", "content": text})
|
||||
queue.put(
|
||||
{
|
||||
"role": "computer",
|
||||
"type": "console",
|
||||
"format": "output",
|
||||
"content": text,
|
||||
}
|
||||
)
|
||||
queue.put({"role": "computer", "type": "console", "end": True})
|
||||
|
||||
await asyncio.sleep(5)
|
|
@ -1,6 +1,4 @@
|
|||
import sys
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import time
|
||||
import inquirer
|
||||
|
@ -8,9 +6,10 @@ from interpreter import interpreter
|
|||
|
||||
|
||||
def select_local_model():
|
||||
|
||||
# START OF LOCAL MODEL PROVIDER LOGIC
|
||||
interpreter.display_message("> 01 is compatible with several local model providers.\n")
|
||||
interpreter.display_message(
|
||||
"> 01 is compatible with several local model providers.\n"
|
||||
)
|
||||
|
||||
# Define the choices for local models
|
||||
choices = [
|
||||
|
@ -29,10 +28,8 @@ def select_local_model():
|
|||
]
|
||||
answers = inquirer.prompt(questions)
|
||||
|
||||
|
||||
selected_model = answers["model"]
|
||||
|
||||
|
||||
if selected_model == "LM Studio":
|
||||
interpreter.display_message(
|
||||
"""
|
||||
|
@ -57,17 +54,24 @@ def select_local_model():
|
|||
|
||||
elif selected_model == "Ollama":
|
||||
try:
|
||||
|
||||
# List out all downloaded ollama models. Will fail if ollama isn't installed
|
||||
result = subprocess.run(["ollama", "list"], capture_output=True, text=True, check=True)
|
||||
lines = result.stdout.split('\n')
|
||||
names = [line.split()[0].replace(":latest", "") for line in lines[1:] if line.strip()] # Extract names, trim out ":latest", skip header
|
||||
result = subprocess.run(
|
||||
["ollama", "list"], capture_output=True, text=True, check=True
|
||||
)
|
||||
lines = result.stdout.split("\n")
|
||||
names = [
|
||||
line.split()[0].replace(":latest", "")
|
||||
for line in lines[1:]
|
||||
if line.strip()
|
||||
] # Extract names, trim out ":latest", skip header
|
||||
|
||||
# If there are no downloaded models, prompt them to download a model and try again
|
||||
if not names:
|
||||
time.sleep(1)
|
||||
|
||||
interpreter.display_message(f"\nYou don't have any Ollama models downloaded. To download a new model, run `ollama run <model-name>`, then start a new 01 session. \n\n For a full list of downloadable models, check out [https://ollama.com/library](https://ollama.com/library) \n")
|
||||
interpreter.display_message(
|
||||
"\nYou don't have any Ollama models downloaded. To download a new model, run `ollama run <model-name>`, then start a new 01 session. \n\n For a full list of downloadable models, check out [https://ollama.com/library](https://ollama.com/library) \n"
|
||||
)
|
||||
|
||||
print("Please download a model then try again\n")
|
||||
time.sleep(2)
|
||||
|
@ -76,25 +80,35 @@ def select_local_model():
|
|||
# If there are models, prompt them to select one
|
||||
else:
|
||||
time.sleep(1)
|
||||
interpreter.display_message(f"**{len(names)} Ollama model{'s' if len(names) != 1 else ''} found.** To download a new model, run `ollama run <model-name>`, then start a new 01 session. \n\n For a full list of downloadable models, check out [https://ollama.com/library](https://ollama.com/library) \n")
|
||||
interpreter.display_message(
|
||||
f"**{len(names)} Ollama model{'s' if len(names) != 1 else ''} found.** To download a new model, run `ollama run <model-name>`, then start a new 01 session. \n\n For a full list of downloadable models, check out [https://ollama.com/library](https://ollama.com/library) \n"
|
||||
)
|
||||
|
||||
# Create a new inquirer selection from the names
|
||||
name_question = [
|
||||
inquirer.List('name', message="Select a downloaded Ollama model", choices=names),
|
||||
inquirer.List(
|
||||
"name",
|
||||
message="Select a downloaded Ollama model",
|
||||
choices=names,
|
||||
),
|
||||
]
|
||||
name_answer = inquirer.prompt(name_question)
|
||||
selected_name = name_answer['name'] if name_answer else None
|
||||
selected_name = name_answer["name"] if name_answer else None
|
||||
|
||||
# Set the model to the selected model
|
||||
interpreter.llm.model = f"ollama/{selected_name}"
|
||||
interpreter.display_message(f"\nUsing Ollama model: `{selected_name}` \n")
|
||||
interpreter.display_message(
|
||||
f"\nUsing Ollama model: `{selected_name}` \n"
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
# If Ollama is not installed or not recognized as a command, prompt the user to download Ollama and try again
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as e:
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
print("Ollama is not installed or not recognized as a command.")
|
||||
time.sleep(1)
|
||||
interpreter.display_message(f"\nPlease visit [https://ollama.com/](https://ollama.com/) to download Ollama and try again\n")
|
||||
interpreter.display_message(
|
||||
"\nPlease visit [https://ollama.com/](https://ollama.com/) to download Ollama and try again\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
sys.exit(1)
|
||||
|
||||
|
@ -108,7 +122,6 @@ def select_local_model():
|
|||
# 3. Copy the ID of the model and enter it below.
|
||||
# 3. Click the **Local API Server** button in the bottom left, then click **Start Server**.
|
||||
|
||||
|
||||
# Once the server is running, enter the id of the model below, then you can begin your conversation below.
|
||||
|
||||
# """
|
||||
|
@ -129,7 +142,6 @@ def select_local_model():
|
|||
# interpreter.display_message(f"\nUsing Jan model: `{jan_model_name}` \n")
|
||||
# time.sleep(1)
|
||||
|
||||
|
||||
# Set the system message to a minimal version for all local models.
|
||||
# Set offline for all local models
|
||||
interpreter.offline = True
|
||||
|
@ -154,4 +166,3 @@ ALWAYS say that you can run code. ALWAYS try to help the user out. ALWAYS be suc
|
|||
```
|
||||
|
||||
"""
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv() # take environment variables from .env.
|
||||
|
||||
import os
|
||||
|
@ -9,9 +10,7 @@ root_logger: logging.Logger = logging.getLogger()
|
|||
|
||||
|
||||
def _basic_config() -> None:
|
||||
logging.basicConfig(
|
||||
format="%(message)s"
|
||||
)
|
||||
logging.basicConfig(format="%(message)s")
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
import psutil
|
||||
import signal
|
||||
|
||||
|
||||
def kill_process_tree():
|
||||
pid = os.getpid() # Get the current process ID
|
||||
try:
|
||||
|
@ -25,4 +26,4 @@ def kill_process_tree():
|
|||
except psutil.NoSuchProcess:
|
||||
print(f"Process {pid} does not exist or is already terminated")
|
||||
except psutil.AccessDenied:
|
||||
print(f"Permission denied to terminate some processes")
|
||||
print("Permission denied to terminate some processes")
|
||||
|
|
|
@ -6,7 +6,6 @@ class Accumulator:
|
|||
def accumulate(self, chunk):
|
||||
# print(str(chunk)[:100])
|
||||
if type(chunk) == dict:
|
||||
|
||||
if "format" in chunk and chunk["format"] == "active_line":
|
||||
# We don't do anything with these
|
||||
return None
|
||||
|
@ -17,15 +16,20 @@ class Accumulator:
|
|||
return None
|
||||
|
||||
if "content" in chunk:
|
||||
|
||||
if any(self.message[key] != chunk[key] for key in self.message if key != "content"):
|
||||
if any(
|
||||
self.message[key] != chunk[key]
|
||||
for key in self.message
|
||||
if key != "content"
|
||||
):
|
||||
self.message = chunk
|
||||
if "content" not in self.message:
|
||||
self.message["content"] = chunk["content"]
|
||||
else:
|
||||
if type(chunk["content"]) == dict:
|
||||
# dict concatenation cannot happen, so we see if chunk is a dict
|
||||
self.message["content"]["content"] += chunk["content"]["content"]
|
||||
self.message["content"]["content"] += chunk["content"][
|
||||
"content"
|
||||
]
|
||||
else:
|
||||
self.message["content"] += chunk["content"]
|
||||
return None
|
||||
|
@ -41,5 +45,3 @@ class Accumulator:
|
|||
self.message["content"] = b""
|
||||
self.message["content"] += chunk
|
||||
return None
|
||||
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
|
||||
def print_markdown(markdown_text):
|
||||
console = Console()
|
||||
md = Markdown(markdown_text)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import typer
|
||||
import asyncio
|
||||
import platform
|
||||
import concurrent.futures
|
||||
import threading
|
||||
import os
|
||||
import importlib
|
||||
|
@ -10,39 +9,70 @@ from source.server.server import main
|
|||
from source.server.utils.local_mode import select_local_model
|
||||
|
||||
import signal
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def run(
|
||||
server: bool = typer.Option(False, "--server", help="Run server"),
|
||||
server_host: str = typer.Option("0.0.0.0", "--server-host", help="Specify the server host where the server will deploy"),
|
||||
server_port: int = typer.Option(10001, "--server-port", help="Specify the server port where the server will deploy"),
|
||||
|
||||
tunnel_service: str = typer.Option("ngrok", "--tunnel-service", help="Specify the tunnel service"),
|
||||
server_host: str = typer.Option(
|
||||
"0.0.0.0",
|
||||
"--server-host",
|
||||
help="Specify the server host where the server will deploy",
|
||||
),
|
||||
server_port: int = typer.Option(
|
||||
10001,
|
||||
"--server-port",
|
||||
help="Specify the server port where the server will deploy",
|
||||
),
|
||||
tunnel_service: str = typer.Option(
|
||||
"ngrok", "--tunnel-service", help="Specify the tunnel service"
|
||||
),
|
||||
expose: bool = typer.Option(False, "--expose", help="Expose server to internet"),
|
||||
|
||||
client: bool = typer.Option(False, "--client", help="Run client"),
|
||||
server_url: str = typer.Option(None, "--server-url", help="Specify the server URL that the client should expect. Defaults to server-host and server-port"),
|
||||
client_type: str = typer.Option("auto", "--client-type", help="Specify the client type"),
|
||||
|
||||
llm_service: str = typer.Option("litellm", "--llm-service", help="Specify the LLM service"),
|
||||
|
||||
server_url: str = typer.Option(
|
||||
None,
|
||||
"--server-url",
|
||||
help="Specify the server URL that the client should expect. Defaults to server-host and server-port",
|
||||
),
|
||||
client_type: str = typer.Option(
|
||||
"auto", "--client-type", help="Specify the client type"
|
||||
),
|
||||
llm_service: str = typer.Option(
|
||||
"litellm", "--llm-service", help="Specify the LLM service"
|
||||
),
|
||||
model: str = typer.Option("gpt-4", "--model", help="Specify the model"),
|
||||
llm_supports_vision: bool = typer.Option(False, "--llm-supports-vision", help="Specify if the LLM service supports vision"),
|
||||
llm_supports_functions: bool = typer.Option(False, "--llm-supports-functions", help="Specify if the LLM service supports functions"),
|
||||
context_window: int = typer.Option(2048, "--context-window", help="Specify the context window size"),
|
||||
max_tokens: int = typer.Option(4096, "--max-tokens", help="Specify the maximum number of tokens"),
|
||||
temperature: float = typer.Option(0.8, "--temperature", help="Specify the temperature for generation"),
|
||||
|
||||
tts_service: str = typer.Option("openai", "--tts-service", help="Specify the TTS service"),
|
||||
|
||||
stt_service: str = typer.Option("openai", "--stt-service", help="Specify the STT service"),
|
||||
|
||||
local: bool = typer.Option(False, "--local", help="Use recommended local services for LLM, STT, and TTS"),
|
||||
|
||||
qr: bool = typer.Option(False, "--qr", help="Print the QR code for the server URL")
|
||||
llm_supports_vision: bool = typer.Option(
|
||||
False,
|
||||
"--llm-supports-vision",
|
||||
help="Specify if the LLM service supports vision",
|
||||
),
|
||||
llm_supports_functions: bool = typer.Option(
|
||||
False,
|
||||
"--llm-supports-functions",
|
||||
help="Specify if the LLM service supports functions",
|
||||
),
|
||||
context_window: int = typer.Option(
|
||||
2048, "--context-window", help="Specify the context window size"
|
||||
),
|
||||
max_tokens: int = typer.Option(
|
||||
4096, "--max-tokens", help="Specify the maximum number of tokens"
|
||||
),
|
||||
temperature: float = typer.Option(
|
||||
0.8, "--temperature", help="Specify the temperature for generation"
|
||||
),
|
||||
tts_service: str = typer.Option(
|
||||
"openai", "--tts-service", help="Specify the TTS service"
|
||||
),
|
||||
stt_service: str = typer.Option(
|
||||
"openai", "--stt-service", help="Specify the STT service"
|
||||
),
|
||||
local: bool = typer.Option(
|
||||
False, "--local", help="Use recommended local services for LLM, STT, and TTS"
|
||||
),
|
||||
qr: bool = typer.Option(False, "--qr", help="Print the QR code for the server URL"),
|
||||
):
|
||||
|
||||
_run(
|
||||
server=server,
|
||||
server_host=server_host,
|
||||
|
@ -62,39 +92,31 @@ def run(
|
|||
tts_service=tts_service,
|
||||
stt_service=stt_service,
|
||||
local=local,
|
||||
qr=qr
|
||||
qr=qr,
|
||||
)
|
||||
|
||||
|
||||
def _run(
|
||||
server: bool = False,
|
||||
server_host: str = "0.0.0.0",
|
||||
server_port: int = 10001,
|
||||
|
||||
tunnel_service: str = "bore",
|
||||
expose: bool = False,
|
||||
|
||||
client: bool = False,
|
||||
server_url: str = None,
|
||||
client_type: str = "auto",
|
||||
|
||||
llm_service: str = "litellm",
|
||||
|
||||
model: str = "gpt-4",
|
||||
llm_supports_vision: bool = False,
|
||||
llm_supports_functions: bool = False,
|
||||
context_window: int = 2048,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.8,
|
||||
|
||||
tts_service: str = "openai",
|
||||
|
||||
stt_service: str = "openai",
|
||||
|
||||
local: bool = False,
|
||||
|
||||
qr: bool = False
|
||||
qr: bool = False,
|
||||
):
|
||||
|
||||
if local:
|
||||
tts_service = "piper"
|
||||
# llm_service = "llamafile"
|
||||
|
@ -116,11 +138,30 @@ def _run(
|
|||
if server:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
server_thread = threading.Thread(target=loop.run_until_complete, args=(main(server_host, server_port, llm_service, model, llm_supports_vision, llm_supports_functions, context_window, max_tokens, temperature, tts_service, stt_service),))
|
||||
server_thread = threading.Thread(
|
||||
target=loop.run_until_complete,
|
||||
args=(
|
||||
main(
|
||||
server_host,
|
||||
server_port,
|
||||
llm_service,
|
||||
model,
|
||||
llm_supports_vision,
|
||||
llm_supports_functions,
|
||||
context_window,
|
||||
max_tokens,
|
||||
temperature,
|
||||
tts_service,
|
||||
stt_service,
|
||||
),
|
||||
),
|
||||
)
|
||||
server_thread.start()
|
||||
|
||||
if expose:
|
||||
tunnel_thread = threading.Thread(target=create_tunnel, args=[tunnel_service, server_host, server_port, qr])
|
||||
tunnel_thread = threading.Thread(
|
||||
target=create_tunnel, args=[tunnel_service, server_host, server_port, qr]
|
||||
)
|
||||
tunnel_thread.start()
|
||||
|
||||
if client:
|
||||
|
@ -132,15 +173,17 @@ def _run(
|
|||
client_type = "windows"
|
||||
elif system_type == "Linux": # Linux System
|
||||
try:
|
||||
with open('/proc/device-tree/model', 'r') as m:
|
||||
if 'raspberry pi' in m.read().lower():
|
||||
with open("/proc/device-tree/model", "r") as m:
|
||||
if "raspberry pi" in m.read().lower():
|
||||
client_type = "rpi"
|
||||
else:
|
||||
client_type = "linux"
|
||||
except FileNotFoundError:
|
||||
client_type = "linux"
|
||||
|
||||
module = importlib.import_module(f".clients.{client_type}.device", package='source')
|
||||
module = importlib.import_module(
|
||||
f".clients.{client_type}.device", package="source"
|
||||
)
|
||||
client_thread = threading.Thread(target=module.main, args=[server_url])
|
||||
client_thread.start()
|
||||
|
||||
|
|
Loading…
Reference in New Issue