add async interpreter with coqui, openai, elevenlabs tts
This commit is contained in:
parent
2627fba481
commit
eee00ac026
|
@ -90,6 +90,7 @@ class Device:
|
||||||
self.audiosegments = asyncio.Queue()
|
self.audiosegments = asyncio.Queue()
|
||||||
self.server_url = ""
|
self.server_url = ""
|
||||||
self.ctrl_pressed = False
|
self.ctrl_pressed = False
|
||||||
|
self.tts_service = ""
|
||||||
self.playback_latency = None
|
self.playback_latency = None
|
||||||
|
|
||||||
def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
|
def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
|
||||||
|
@ -164,30 +165,18 @@ class Device:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
audio = await self.audiosegments.get()
|
audio = await self.audiosegments.get()
|
||||||
# print("got audio segment!!!!")
|
if self.playback_latency and isinstance(audio, bytes):
|
||||||
if self.playback_latency:
|
|
||||||
elapsed_time = time.time() - self.playback_latency
|
elapsed_time = time.time() - self.playback_latency
|
||||||
print(f"Time from request to playback: {elapsed_time} seconds")
|
# print(f"Time from request to playback: {elapsed_time} seconds")
|
||||||
self.playback_latency = None
|
self.playback_latency = None
|
||||||
|
|
||||||
if audio is not None:
|
if self.tts_service == "elevenlabs":
|
||||||
mpv_process.stdin.write(audio) # type: ignore
|
mpv_process.stdin.write(audio) # type: ignore
|
||||||
mpv_process.stdin.flush() # type: ignore
|
mpv_process.stdin.flush() # type: ignore
|
||||||
"""
|
else:
|
||||||
args = ["ffplay", "-autoexit", "-", "-nodisp"]
|
play(audio)
|
||||||
proc = subprocess.Popen(
|
|
||||||
args=args,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stdin=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
)
|
|
||||||
out, err = proc.communicate(input=audio)
|
|
||||||
proc.poll()
|
|
||||||
|
|
||||||
play(audio)
|
await asyncio.sleep(0.1)
|
||||||
"""
|
|
||||||
# self.audiosegments.remove(audio)
|
|
||||||
# await asyncio.sleep(0.1)
|
|
||||||
except asyncio.exceptions.CancelledError:
|
except asyncio.exceptions.CancelledError:
|
||||||
# This happens once at the start?
|
# This happens once at the start?
|
||||||
pass
|
pass
|
||||||
|
@ -342,24 +331,17 @@ class Device:
|
||||||
|
|
||||||
async def message_sender(self, websocket):
|
async def message_sender(self, websocket):
|
||||||
while True:
|
while True:
|
||||||
try:
|
message = await asyncio.get_event_loop().run_in_executor(
|
||||||
message = await asyncio.get_event_loop().run_in_executor(
|
None, send_queue.get
|
||||||
None, send_queue.get
|
)
|
||||||
)
|
if isinstance(message, bytes):
|
||||||
|
await websocket.send(message)
|
||||||
if isinstance(message, bytes):
|
else:
|
||||||
await websocket.send(message)
|
await websocket.send(json.dumps(message))
|
||||||
|
send_queue.task_done()
|
||||||
else:
|
await asyncio.sleep(0.01)
|
||||||
await websocket.send(json.dumps(message))
|
|
||||||
|
|
||||||
send_queue.task_done()
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
except:
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
async def websocket_communication(self, WS_URL):
|
async def websocket_communication(self, WS_URL):
|
||||||
print("websocket communication was called!!!!")
|
|
||||||
show_connection_log = True
|
show_connection_log = True
|
||||||
|
|
||||||
async def exec_ws_communication(websocket):
|
async def exec_ws_communication(websocket):
|
||||||
|
@ -373,48 +355,48 @@ class Device:
|
||||||
asyncio.create_task(self.message_sender(websocket))
|
asyncio.create_task(self.message_sender(websocket))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(0.0001)
|
await asyncio.sleep(0.01)
|
||||||
chunk = await websocket.recv()
|
chunk = await websocket.recv()
|
||||||
|
|
||||||
logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
|
logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
|
||||||
# print((f"Got this message from the server: {type(chunk)} {chunk}"))
|
# print("received chunk from server")
|
||||||
|
|
||||||
if type(chunk) == str:
|
if type(chunk) == str:
|
||||||
chunk = json.loads(chunk)
|
chunk = json.loads(chunk)
|
||||||
|
|
||||||
# message = accumulator.accumulate(chunk)
|
if self.tts_service == "elevenlabs":
|
||||||
message = chunk
|
message = chunk
|
||||||
|
else:
|
||||||
|
message = accumulator.accumulate(chunk)
|
||||||
|
|
||||||
if message == None:
|
if message == None:
|
||||||
# Will be None until we have a full message ready
|
# Will be None until we have a full message ready
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# At this point, we have our message
|
# At this point, we have our message
|
||||||
# print("checkpoint reached!", message)
|
if isinstance(message, bytes) or (
|
||||||
if isinstance(message, bytes):
|
message["type"] == "audio" and message["format"].startswith("bytes")
|
||||||
|
):
|
||||||
# if message["type"] == "audio" and message["format"].startswith("bytes"):
|
|
||||||
# Convert bytes to audio file
|
# Convert bytes to audio file
|
||||||
|
if self.tts_service == "elevenlabs":
|
||||||
|
audio_bytes = message
|
||||||
|
audio = audio_bytes
|
||||||
|
else:
|
||||||
|
audio_bytes = message["content"]
|
||||||
|
|
||||||
# audio_bytes = message["content"]
|
# Create an AudioSegment instance with the raw data
|
||||||
audio_bytes = message
|
audio = AudioSegment(
|
||||||
|
# raw audio data (bytes)
|
||||||
|
data=audio_bytes,
|
||||||
|
# signed 16-bit little-endian format
|
||||||
|
sample_width=2,
|
||||||
|
# 22,050 Hz frame rate
|
||||||
|
frame_rate=22050,
|
||||||
|
# mono sound
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
|
||||||
# Create an AudioSegment instance with the raw data
|
await self.audiosegments.put(audio)
|
||||||
"""
|
|
||||||
audio = AudioSegment(
|
|
||||||
# raw audio data (bytes)
|
|
||||||
data=audio_bytes,
|
|
||||||
# signed 16-bit little-endian format
|
|
||||||
sample_width=2,
|
|
||||||
# 24,000 Hz frame rate
|
|
||||||
frame_rate=16000,
|
|
||||||
# mono sound
|
|
||||||
channels=1,
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# print("audio segment was created")
|
|
||||||
await self.audiosegments.put(audio_bytes)
|
|
||||||
|
|
||||||
# await self.audiosegments.put(audio)
|
|
||||||
|
|
||||||
# Run the code if that's the client's job
|
# Run the code if that's the client's job
|
||||||
if os.getenv("CODE_RUNNER") == "client":
|
if os.getenv("CODE_RUNNER") == "client":
|
||||||
|
@ -434,29 +416,26 @@ class Device:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error while attempting to connect: {e}")
|
logger.error(f"Error while attempting to connect: {e}")
|
||||||
else:
|
else:
|
||||||
print("websocket url is", WS_URL)
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
async with websockets.connect(WS_URL) as websocket:
|
async with websockets.connect(WS_URL) as websocket:
|
||||||
print("awaiting exec_ws_communication")
|
|
||||||
await exec_ws_communication(websocket)
|
await exec_ws_communication(websocket)
|
||||||
except:
|
except:
|
||||||
logger.info(traceback.format_exc())
|
logger.debug(traceback.format_exc())
|
||||||
if show_connection_log:
|
if show_connection_log:
|
||||||
logger.info(f"Connecting to `{WS_URL}`...")
|
logger.info(f"Connecting to `{WS_URL}`...")
|
||||||
show_connection_log = False
|
show_connection_log = False
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
async def start_async(self):
|
async def start_async(self):
|
||||||
print("start async was called!!!!!")
|
|
||||||
# Configuration for WebSocket
|
# Configuration for WebSocket
|
||||||
WS_URL = f"ws://{self.server_url}"
|
WS_URL = f"ws://{self.server_url}"
|
||||||
|
|
||||||
# Start the WebSocket communication
|
# Start the WebSocket communication
|
||||||
asyncio.create_task(self.websocket_communication(WS_URL))
|
asyncio.create_task(self.websocket_communication(WS_URL))
|
||||||
|
|
||||||
# Start watching the kernel if it's your job to do that
|
# Start watching the kernel if it's your job to do that
|
||||||
if os.getenv("CODE_RUNNER") == "client":
|
if os.getenv("CODE_RUNNER") == "client":
|
||||||
|
# client is not running code!
|
||||||
asyncio.create_task(put_kernel_messages_into_queue(send_queue))
|
asyncio.create_task(put_kernel_messages_into_queue(send_queue))
|
||||||
|
|
||||||
asyncio.create_task(self.play_audiosegments())
|
asyncio.create_task(self.play_audiosegments())
|
||||||
|
@ -488,10 +467,8 @@ class Device:
|
||||||
on_press=self.on_press, on_release=self.on_release
|
on_press=self.on_press, on_release=self.on_release
|
||||||
)
|
)
|
||||||
listener.start()
|
listener.start()
|
||||||
print("listener for keyboard started!!!!!")
|
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
print("device was started!!!!!!")
|
|
||||||
if os.getenv("TEACH_MODE") != "True":
|
if os.getenv("TEACH_MODE") != "True":
|
||||||
asyncio.run(self.start_async())
|
asyncio.run(self.start_async())
|
||||||
p.terminate()
|
p.terminate()
|
||||||
|
|
|
@ -3,8 +3,9 @@ from ..base_device import Device
|
||||||
device = Device()
|
device = Device()
|
||||||
|
|
||||||
|
|
||||||
def main(server_url):
|
def main(server_url, tts_service):
|
||||||
device.server_url = server_url
|
device.server_url = server_url
|
||||||
|
device.tts_service = tts_service
|
||||||
device.start()
|
device.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,9 @@ from ..base_device import Device
|
||||||
device = Device()
|
device = Device()
|
||||||
|
|
||||||
|
|
||||||
def main(server_url):
|
def main(server_url, tts_service):
|
||||||
device.server_url = server_url
|
device.server_url = server_url
|
||||||
|
device.tts_service = tts_service
|
||||||
device.start()
|
device.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,9 @@ from ..base_device import Device
|
||||||
device = Device()
|
device = Device()
|
||||||
|
|
||||||
|
|
||||||
def main(server_url):
|
def main(server_url, tts_service):
|
||||||
device.server_url = server_url
|
device.server_url = server_url
|
||||||
|
device.tts_service = tts_service
|
||||||
device.start()
|
device.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,16 +10,9 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
###
|
###
|
||||||
|
|
||||||
from pynput import keyboard
|
from pynput import keyboard
|
||||||
from RealtimeTTS import (
|
|
||||||
TextToAudioStream,
|
from RealtimeTTS import TextToAudioStream, CoquiEngine, OpenAIEngine, ElevenlabsEngine
|
||||||
OpenAIEngine,
|
|
||||||
CoquiEngine,
|
|
||||||
ElevenlabsEngine,
|
|
||||||
SystemEngine,
|
|
||||||
GTTSEngine,
|
|
||||||
)
|
|
||||||
from RealtimeSTT import AudioToTextRecorder
|
from RealtimeSTT import AudioToTextRecorder
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -29,9 +22,9 @@ import os
|
||||||
|
|
||||||
class AsyncInterpreter:
|
class AsyncInterpreter:
|
||||||
def __init__(self, interpreter):
|
def __init__(self, interpreter):
|
||||||
self.stt_latency = None
|
# self.stt_latency = None
|
||||||
self.tts_latency = None
|
# self.tts_latency = None
|
||||||
self.interpreter_latency = None
|
# self.interpreter_latency = None
|
||||||
self.interpreter = interpreter
|
self.interpreter = interpreter
|
||||||
|
|
||||||
# STT
|
# STT
|
||||||
|
@ -45,12 +38,9 @@ class AsyncInterpreter:
|
||||||
engine = CoquiEngine()
|
engine = CoquiEngine()
|
||||||
elif self.interpreter.tts == "openai":
|
elif self.interpreter.tts == "openai":
|
||||||
engine = OpenAIEngine()
|
engine = OpenAIEngine()
|
||||||
elif self.interpreter.tts == "gtts":
|
|
||||||
engine = GTTSEngine()
|
|
||||||
elif self.interpreter.tts == "elevenlabs":
|
elif self.interpreter.tts == "elevenlabs":
|
||||||
engine = ElevenlabsEngine(api_key=os.environ["ELEVEN_LABS_API_KEY"])
|
engine = ElevenlabsEngine(api_key=os.environ["ELEVEN_LABS_API_KEY"])
|
||||||
elif self.interpreter.tts == "system":
|
engine.set_voice("Michael")
|
||||||
engine = SystemEngine()
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}")
|
raise ValueError(f"Unsupported TTS engine: {self.interpreter.tts}")
|
||||||
self.tts = TextToAudioStream(engine)
|
self.tts = TextToAudioStream(engine)
|
||||||
|
@ -112,111 +102,96 @@ class AsyncInterpreter:
|
||||||
# print("ADDING TO QUEUE:", chunk)
|
# print("ADDING TO QUEUE:", chunk)
|
||||||
asyncio.create_task(self._add_to_queue(self._output_queue, chunk))
|
asyncio.create_task(self._add_to_queue(self._output_queue, chunk))
|
||||||
|
|
||||||
|
def generate(self, message, start_interpreter):
|
||||||
|
last_lmc_start_flag = self._last_lmc_start_flag
|
||||||
|
self.interpreter.messages = self.active_chat_messages
|
||||||
|
|
||||||
|
# print("message is", message)
|
||||||
|
|
||||||
|
for chunk in self.interpreter.chat(message, display=True, stream=True):
|
||||||
|
|
||||||
|
if self._last_lmc_start_flag != last_lmc_start_flag:
|
||||||
|
# self.beeper.stop()
|
||||||
|
break
|
||||||
|
|
||||||
|
# self.add_to_output_queue_sync(chunk) # To send text, not just audio
|
||||||
|
|
||||||
|
content = chunk.get("content")
|
||||||
|
|
||||||
|
# Handle message blocks
|
||||||
|
if chunk.get("type") == "message":
|
||||||
|
if content:
|
||||||
|
# self.beeper.stop()
|
||||||
|
|
||||||
|
# Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer
|
||||||
|
# content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ")
|
||||||
|
# print("yielding ", content)
|
||||||
|
yield content
|
||||||
|
|
||||||
|
# Handle code blocks
|
||||||
|
elif chunk.get("type") == "code":
|
||||||
|
if "start" in chunk:
|
||||||
|
# self.beeper.start()
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Experimental: If the AI wants to type, we should type immediately
|
||||||
|
if (
|
||||||
|
self.interpreter.messages[-1]
|
||||||
|
.get("content", "")
|
||||||
|
.startswith("computer.keyboard.write(")
|
||||||
|
):
|
||||||
|
keyboard.controller.type(content)
|
||||||
|
self._in_keyboard_write_block = True
|
||||||
|
if "end" in chunk and self._in_keyboard_write_block:
|
||||||
|
self._in_keyboard_write_block = False
|
||||||
|
# (This will make it so it doesn't type twice when the block executes)
|
||||||
|
if self.interpreter.messages[-1]["content"].startswith(
|
||||||
|
"computer.keyboard.write("
|
||||||
|
):
|
||||||
|
self.interpreter.messages[-1]["content"] = (
|
||||||
|
"dummy_variable = ("
|
||||||
|
+ self.interpreter.messages[-1]["content"][
|
||||||
|
len("computer.keyboard.write(") :
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send a completion signal
|
||||||
|
# end_interpreter = time.time()
|
||||||
|
# self.interpreter_latency = end_interpreter - start_interpreter
|
||||||
|
# print("INTERPRETER LATENCY", self.interpreter_latency)
|
||||||
|
# self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"})
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""
|
"""
|
||||||
Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue.
|
Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue.
|
||||||
"""
|
"""
|
||||||
self.interpreter.messages = self.active_chat_messages
|
self.interpreter.messages = self.active_chat_messages
|
||||||
|
|
||||||
# self.beeper.start()
|
|
||||||
|
|
||||||
self.stt.stop()
|
self.stt.stop()
|
||||||
# message = self.stt.text()
|
|
||||||
# print("THE MESSAGE:", message)
|
|
||||||
|
|
||||||
# accumulates the input queue message
|
|
||||||
input_queue = []
|
input_queue = []
|
||||||
while not self._input_queue.empty():
|
while not self._input_queue.empty():
|
||||||
input_queue.append(self._input_queue.get())
|
input_queue.append(self._input_queue.get())
|
||||||
|
|
||||||
# print("INPUT QUEUE:", input_queue)
|
# start_stt = time.time()
|
||||||
# message = [i for i in input_queue if i["type"] == "message"][0]["content"]
|
|
||||||
start_stt = time.time()
|
|
||||||
message = self.stt.text()
|
message = self.stt.text()
|
||||||
end_stt = time.time()
|
# end_stt = time.time()
|
||||||
self.stt_latency = end_stt - start_stt
|
# self.stt_latency = end_stt - start_stt
|
||||||
print("STT LATENCY", self.stt_latency)
|
# print("STT LATENCY", self.stt_latency)
|
||||||
|
|
||||||
# print(message)
|
# print(message)
|
||||||
end_interpreter = 0
|
|
||||||
|
|
||||||
# print(message)
|
|
||||||
def generate(message):
|
|
||||||
last_lmc_start_flag = self._last_lmc_start_flag
|
|
||||||
self.interpreter.messages = self.active_chat_messages
|
|
||||||
# print("🍀🍀🍀🍀GENERATING, using these messages: ", self.interpreter.messages)
|
|
||||||
# print("🍀 🍀 🍀 🍀 active_chat_messages: ", self.active_chat_messages)
|
|
||||||
print("message is", message)
|
|
||||||
|
|
||||||
for chunk in self.interpreter.chat(message, display=True, stream=True):
|
|
||||||
|
|
||||||
if self._last_lmc_start_flag != last_lmc_start_flag:
|
|
||||||
# self.beeper.stop()
|
|
||||||
break
|
|
||||||
|
|
||||||
# self.add_to_output_queue_sync(chunk) # To send text, not just audio
|
|
||||||
|
|
||||||
content = chunk.get("content")
|
|
||||||
|
|
||||||
# Handle message blocks
|
|
||||||
if chunk.get("type") == "message":
|
|
||||||
if content:
|
|
||||||
# self.beeper.stop()
|
|
||||||
|
|
||||||
# Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer
|
|
||||||
# content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ")
|
|
||||||
# print("yielding this", content)
|
|
||||||
yield content
|
|
||||||
|
|
||||||
# Handle code blocks
|
|
||||||
elif chunk.get("type") == "code":
|
|
||||||
if "start" in chunk:
|
|
||||||
# self.beeper.start()
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Experimental: If the AI wants to type, we should type immediately
|
|
||||||
if (
|
|
||||||
self.interpreter.messages[-1]
|
|
||||||
.get("content", "")
|
|
||||||
.startswith("computer.keyboard.write(")
|
|
||||||
):
|
|
||||||
keyboard.controller.type(content)
|
|
||||||
self._in_keyboard_write_block = True
|
|
||||||
if "end" in chunk and self._in_keyboard_write_block:
|
|
||||||
self._in_keyboard_write_block = False
|
|
||||||
# (This will make it so it doesn't type twice when the block executes)
|
|
||||||
if self.interpreter.messages[-1]["content"].startswith(
|
|
||||||
"computer.keyboard.write("
|
|
||||||
):
|
|
||||||
self.interpreter.messages[-1]["content"] = (
|
|
||||||
"dummy_variable = ("
|
|
||||||
+ self.interpreter.messages[-1]["content"][
|
|
||||||
len("computer.keyboard.write(") :
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send a completion signal
|
|
||||||
end_interpreter = time.time()
|
|
||||||
self.interpreter_latency = end_interpreter - start_interpreter
|
|
||||||
print("INTERPRETER LATENCY", self.interpreter_latency)
|
|
||||||
# self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"})
|
|
||||||
|
|
||||||
# Feed generate to RealtimeTTS
|
# Feed generate to RealtimeTTS
|
||||||
self.add_to_output_queue_sync(
|
self.add_to_output_queue_sync(
|
||||||
{"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True}
|
{"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True}
|
||||||
)
|
)
|
||||||
start_interpreter = time.time()
|
start_interpreter = time.time()
|
||||||
text_iterator = generate(message)
|
text_iterator = self.generate(message, start_interpreter)
|
||||||
|
|
||||||
self.tts.feed(text_iterator)
|
self.tts.feed(text_iterator)
|
||||||
|
|
||||||
self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True)
|
self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True)
|
||||||
|
|
||||||
while True:
|
|
||||||
if self.tts.is_playing():
|
|
||||||
start_tts = time.time()
|
|
||||||
|
|
||||||
break
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
# print("is_playing", self.tts.is_playing())
|
# print("is_playing", self.tts.is_playing())
|
||||||
|
@ -229,14 +204,14 @@ class AsyncInterpreter:
|
||||||
"end": True,
|
"end": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
end_tts = time.time()
|
# end_tts = time.time()
|
||||||
self.tts_latency = end_tts - start_tts
|
# self.tts_latency = end_tts - self.tts.stream_start_time
|
||||||
print("TTS LATENCY", self.tts_latency)
|
# print("TTS LATENCY", self.tts_latency)
|
||||||
self.tts.stop()
|
self.tts.stop()
|
||||||
break
|
break
|
||||||
|
|
||||||
async def _on_tts_chunk_async(self, chunk):
|
async def _on_tts_chunk_async(self, chunk):
|
||||||
# print("SENDING TTS CHUNK")
|
# print("adding chunk to queue")
|
||||||
await self._add_to_queue(self._output_queue, chunk)
|
await self._add_to_queue(self._output_queue, chunk)
|
||||||
|
|
||||||
def on_tts_chunk(self, chunk):
|
def on_tts_chunk(self, chunk):
|
||||||
|
@ -244,4 +219,5 @@ class AsyncInterpreter:
|
||||||
asyncio.run(self._on_tts_chunk_async(chunk))
|
asyncio.run(self._on_tts_chunk_async(chunk))
|
||||||
|
|
||||||
async def output(self):
|
async def output(self):
|
||||||
|
# print("outputting chunks")
|
||||||
return await self._output_queue.get()
|
return await self._output_queue.get()
|
||||||
|
|
|
@ -1,43 +1,38 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
import json
|
import json
|
||||||
from fastapi import FastAPI, WebSocket, Header
|
from fastapi import FastAPI, WebSocket
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
from uvicorn import Config, Server
|
from uvicorn import Config, Server
|
||||||
|
from .i import configure_interpreter
|
||||||
from interpreter import interpreter as base_interpreter
|
from interpreter import interpreter as base_interpreter
|
||||||
from .async_interpreter import AsyncInterpreter
|
from .async_interpreter import AsyncInterpreter
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from openai import OpenAI
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# import sentry_sdk
|
|
||||||
|
|
||||||
base_interpreter.system_message = (
|
|
||||||
"You are a helpful assistant that can answer questions and help with tasks."
|
|
||||||
)
|
|
||||||
base_interpreter.computer.import_computer_api = False
|
|
||||||
base_interpreter.llm.model = "groq/llama3-8b-8192"
|
|
||||||
base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
|
|
||||||
base_interpreter.llm.supports_functions = False
|
|
||||||
base_interpreter.auto_run = True
|
|
||||||
|
|
||||||
os.environ["STT_RUNNER"] = "server"
|
os.environ["STT_RUNNER"] = "server"
|
||||||
os.environ["TTS_RUNNER"] = "server"
|
os.environ["TTS_RUNNER"] = "server"
|
||||||
|
|
||||||
# Parse command line arguments for port number
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser(description="FastAPI server.")
|
|
||||||
parser.add_argument("--port", type=int, default=8000, help="Port to run on.")
|
|
||||||
args = parser.parse_args()
|
|
||||||
"""
|
|
||||||
base_interpreter.tts = "coqui"
|
|
||||||
|
|
||||||
|
async def main(server_host, server_port, tts_service, asynchronous):
|
||||||
async def main(server_host, server_port):
|
if asynchronous:
|
||||||
interpreter = AsyncInterpreter(base_interpreter)
|
base_interpreter.system_message = (
|
||||||
|
"You are a helpful assistant that can answer questions and help with tasks."
|
||||||
|
)
|
||||||
|
base_interpreter.computer.import_computer_api = False
|
||||||
|
base_interpreter.llm.model = "groq/llama3-8b-8192"
|
||||||
|
base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
|
||||||
|
base_interpreter.llm.supports_functions = False
|
||||||
|
base_interpreter.auto_run = True
|
||||||
|
base_interpreter.tts = tts_service
|
||||||
|
interpreter = AsyncInterpreter(base_interpreter)
|
||||||
|
else:
|
||||||
|
configured_interpreter = configure_interpreter(base_interpreter)
|
||||||
|
configured_interpreter.llm.supports_functions = True
|
||||||
|
configured_interpreter.tts = tts_service
|
||||||
|
interpreter = AsyncInterpreter(configured_interpreter)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -107,37 +102,6 @@ async def main(server_host, server_port):
|
||||||
server = Server(config)
|
server = Server(config)
|
||||||
await server.serve()
|
await server.serve()
|
||||||
|
|
||||||
class Rename(BaseModel):
|
|
||||||
input: str
|
|
||||||
|
|
||||||
@app.post("/rename-chat")
|
|
||||||
async def rename_chat(body_content: Rename, x_api_key: str = Header(None)):
|
|
||||||
print("RENAME CHAT REQUEST in PY 🌙🌙🌙🌙")
|
|
||||||
input_value = body_content.input
|
|
||||||
client = OpenAI(
|
|
||||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
|
||||||
api_key=x_api_key,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model="gpt-3.5-turbo",
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"Given the following chat snippet, create a unique and descriptive title in less than 8 words. Your answer must not be related to customer service.\n\n{input_value}",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
temperature=0.3,
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
print(response)
|
|
||||||
completion = response["choices"][0]["message"]["content"]
|
|
||||||
return {"data": {"content": completion}}
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
@ -6,6 +6,8 @@ import os
|
||||||
import importlib
|
import importlib
|
||||||
from source.server.tunnel import create_tunnel
|
from source.server.tunnel import create_tunnel
|
||||||
from source.server.async_server import main
|
from source.server.async_server import main
|
||||||
|
|
||||||
|
# from source.server.server import main
|
||||||
from source.server.utils.local_mode import select_local_model
|
from source.server.utils.local_mode import select_local_model
|
||||||
|
|
||||||
import signal
|
import signal
|
||||||
|
@ -63,7 +65,7 @@ def run(
|
||||||
0.8, "--temperature", help="Specify the temperature for generation"
|
0.8, "--temperature", help="Specify the temperature for generation"
|
||||||
),
|
),
|
||||||
tts_service: str = typer.Option(
|
tts_service: str = typer.Option(
|
||||||
"openai", "--tts-service", help="Specify the TTS service"
|
"elevenlabs", "--tts-service", help="Specify the TTS service"
|
||||||
),
|
),
|
||||||
stt_service: str = typer.Option(
|
stt_service: str = typer.Option(
|
||||||
"openai", "--stt-service", help="Specify the STT service"
|
"openai", "--stt-service", help="Specify the STT service"
|
||||||
|
@ -75,6 +77,9 @@ def run(
|
||||||
mobile: bool = typer.Option(
|
mobile: bool = typer.Option(
|
||||||
False, "--mobile", help="Toggle server to support mobile app"
|
False, "--mobile", help="Toggle server to support mobile app"
|
||||||
),
|
),
|
||||||
|
asynchronous: bool = typer.Option(
|
||||||
|
False, "--async", help="use interpreter optimized for latency"
|
||||||
|
),
|
||||||
):
|
):
|
||||||
_run(
|
_run(
|
||||||
server=server or mobile,
|
server=server or mobile,
|
||||||
|
@ -97,6 +102,7 @@ def run(
|
||||||
local=local,
|
local=local,
|
||||||
qr=qr or mobile,
|
qr=qr or mobile,
|
||||||
mobile=mobile,
|
mobile=mobile,
|
||||||
|
asynchronous=asynchronous,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,14 +122,15 @@ def _run(
|
||||||
context_window: int = 2048,
|
context_window: int = 2048,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.8,
|
||||||
tts_service: str = "openai",
|
tts_service: str = "elevenlabs",
|
||||||
stt_service: str = "openai",
|
stt_service: str = "openai",
|
||||||
local: bool = False,
|
local: bool = False,
|
||||||
qr: bool = False,
|
qr: bool = False,
|
||||||
mobile: bool = False,
|
mobile: bool = False,
|
||||||
|
asynchronous: bool = False,
|
||||||
):
|
):
|
||||||
if local:
|
if local:
|
||||||
tts_service = "piper"
|
tts_service = "coqui"
|
||||||
# llm_service = "llamafile"
|
# llm_service = "llamafile"
|
||||||
stt_service = "local-whisper"
|
stt_service = "local-whisper"
|
||||||
select_local_model()
|
select_local_model()
|
||||||
|
@ -154,6 +161,8 @@ def _run(
|
||||||
main(
|
main(
|
||||||
server_host,
|
server_host,
|
||||||
server_port,
|
server_port,
|
||||||
|
tts_service,
|
||||||
|
asynchronous,
|
||||||
# llm_service,
|
# llm_service,
|
||||||
# model,
|
# model,
|
||||||
# llm_supports_vision,
|
# llm_supports_vision,
|
||||||
|
@ -161,7 +170,6 @@ def _run(
|
||||||
# context_window,
|
# context_window,
|
||||||
# max_tokens,
|
# max_tokens,
|
||||||
# temperature,
|
# temperature,
|
||||||
# tts_service,
|
|
||||||
# stt_service,
|
# stt_service,
|
||||||
# mobile,
|
# mobile,
|
||||||
),
|
),
|
||||||
|
@ -180,7 +188,6 @@ def _run(
|
||||||
system_type = platform.system()
|
system_type = platform.system()
|
||||||
if system_type == "Darwin": # Mac OS
|
if system_type == "Darwin": # Mac OS
|
||||||
client_type = "mac"
|
client_type = "mac"
|
||||||
print("initiating mac device with base device!!!")
|
|
||||||
elif system_type == "Windows": # Windows System
|
elif system_type == "Windows": # Windows System
|
||||||
client_type = "windows"
|
client_type = "windows"
|
||||||
elif system_type == "Linux": # Linux System
|
elif system_type == "Linux": # Linux System
|
||||||
|
@ -196,9 +203,10 @@ def _run(
|
||||||
module = importlib.import_module(
|
module = importlib.import_module(
|
||||||
f".clients.{client_type}.device", package="source"
|
f".clients.{client_type}.device", package="source"
|
||||||
)
|
)
|
||||||
# server_url = "0.0.0.0:8000"
|
|
||||||
client_thread = threading.Thread(target=module.main, args=[server_url])
|
client_thread = threading.Thread(
|
||||||
print("client thread started")
|
target=module.main, args=[server_url, tts_service]
|
||||||
|
)
|
||||||
client_thread.start()
|
client_thread.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Reference in New Issue