add local stt & tts, add anticipation logic, remove video context accumulation
parent bd6f530be7
commit 6110e70430
source/server/livekit/video_processor.py

@@ -1,10 +1,13 @@
-from livekit.rtc import VideoStream
+from livekit.rtc import VideoStream, VideoFrame, VideoBufferType
 from livekit.agents import JobContext
 from datetime import datetime
 import os

-from livekit.rtc import VideoFrame
 import asyncio
+from typing import Callable, Coroutine, Any

+# Interval settings
+INTERVAL = 30  # seconds

 # Define the path to the log file
 LOG_FILE_PATH = 'video_processor.txt'
@@ -20,34 +23,71 @@ def log_message(message: str):
         log_file.write(f"{timestamp} - {message}\n")

 class RemoteVideoProcessor:
     """Processes video frames from a remote participant's video stream."""

     def __init__(self, video_stream: VideoStream, job_ctx: JobContext):
         log_message("Initializing RemoteVideoProcessor")
         self.video_stream = video_stream
         self.job_ctx = job_ctx
-        self.current_frame = None  # Store the latest VideoFrame
+        self.current_frame = None
         self.lock = asyncio.Lock()

+        self.interval = INTERVAL
+        self.video_context = False
+        self.last_capture_time = 0
+
+        # Add callback for safety checks
+        self.on_instruction_check: Callable[[VideoFrame], Coroutine[Any, Any, None]] | None = None

     async def process_frames(self):
-        log_message("Starting to process remote video frames.")
+        """Process incoming video frames."""
         async for frame_event in self.video_stream:
             try:
                 video_frame = frame_event.frame
+                timestamp = frame_event.timestamp_us
+                rotation = frame_event.rotation
+
+                log_message(f"Processing frame at timestamp {timestamp/1000000:.3f}s")
+                log_message(f"Frame details: size={video_frame.width}x{video_frame.height}, type={video_frame.type}")

                 # Store the current frame safely
-                log_message(f"Received frame: width={video_frame.width}, height={video_frame.height}, type={video_frame.type}")
                 async with self.lock:
                     self.current_frame = video_frame

+                if self.video_context and self._check_interrupt(timestamp):
+                    self.last_capture_time = timestamp
+                    # Trigger instruction check callback if registered
+                    if self.on_instruction_check:
+                        await self.on_instruction_check(video_frame)

             except Exception as e:
-                log_message(f"Error processing frame: {e}")
+                log_message(f"Error processing frame: {str(e)}")
+                import traceback
+                log_message(f"Traceback: {traceback.format_exc()}")

+    def register_safety_check_callback(self, callback: Callable[[VideoFrame], Coroutine[Any, Any, None]]):
+        """Register a callback for safety checks"""
+        self.on_instruction_check = callback
+        log_message("Registered instruction check callback")

     async def get_current_frame(self) -> VideoFrame | None:
-        """Retrieve the current VideoFrame."""
-        log_message("called get current frame")
+        """Get the most recent video frame."""
+        log_message("Getting current frame")
         async with self.lock:
-            log_message("retrieving current frame: " + str(self.current_frame))
-            return self.current_frame
+            if self.current_frame is None:
+                log_message("No current frame available")
+            return self.current_frame

+    def set_video_context(self, context: bool):
+        """Set the video context."""
+        log_message(f"Setting video context to: {context}")
+        self.video_context = context

+    def get_video_context(self) -> bool:
+        """Get the video context."""
+        return self.video_context

+    def _check_interrupt(self, timestamp: int) -> bool:
+        """Determine if the video context should be interrupted."""
+        return timestamp - self.last_capture_time > self.interval * 1000000
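
Note on the interval math: frame_event.timestamp_us is in microseconds, which is why _check_interrupt scales self.interval (seconds) by 1000000. A quick sanity check of that comparison, with illustrative values only:

    # Illustrative values; mirrors the comparison in _check_interrupt.
    interval_s = 30                    # INTERVAL above
    last_capture_time = 5_000_000      # 5 s into the stream, in microseconds
    timestamp = 40_000_000             # 40 s into the stream, in microseconds
    should_fire = timestamp - last_capture_time > interval_s * 1_000_000
    assert should_fire                 # 35 s elapsed > 30 s interval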
source/server/livekit/worker.py

@@ -2,10 +2,11 @@ import asyncio
 import numpy as np
 import sys
 import os
 import threading
 from datetime import datetime
 from typing import Literal, Awaitable

-from livekit.agents import JobContext, WorkerOptions, cli
+from livekit.agents import JobContext, WorkerOptions, cli, transcription
+from livekit.agents.transcription import STTSegmentsForwarder
 from livekit.agents.llm import ChatContext
 from livekit import rtc
@@ -13,30 +14,23 @@ from livekit.agents.pipeline import VoicePipelineAgent
 from livekit.plugins import deepgram, openai, silero, elevenlabs, cartesia
 from livekit.agents.llm.chat_context import ChatContext, ChatImage, ChatMessage
 from livekit.agents.llm import LLMStream
 from livekit.agents.stt import SpeechStream

 from source.server.livekit.video_processor import RemoteVideoProcessor

-from source.server.livekit.transcriptions import _forward_transcription
+from source.server.livekit.anticipation import handle_instruction_check
+from source.server.livekit.logger import log_message

 from dotenv import load_dotenv

 load_dotenv()

-# Define the path to the log file
-LOG_FILE_PATH = 'worker.txt'
-DEBUG = os.getenv('DEBUG', 'false').lower() == 'true'
-
-def log_message(message: str):
-    """Append a message to the log file with a timestamp."""
-    if not DEBUG:
-        return
-    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-    with open(LOG_FILE_PATH, 'a') as log_file:
-        log_file.write(f"{timestamp} - {message}\n")

-start_message = """Hi! You can hold the white circle below to speak to me.
+_room_lock = threading.Lock()
+_connected_rooms = set()

-Try asking what I can do."""

+START_MESSAGE = "Hi! You can hold the white circle below to speak to me. Try asking what I can do."

 # This function is the entrypoint for the agent.
 async def entrypoint(ctx: JobContext):
@@ -96,7 +90,7 @@
     base_url = f"http://{interpreter_server_host}:{interpreter_server_port}/"

     # For debugging
-    base_url = "http://127.0.0.1:8000/"
+    base_url = "http://127.0.0.1:9000/"

     open_interpreter = openai.LLM(
         model="open-interpreter", base_url=base_url, api_key="x"
@@ -105,11 +99,18 @@
     tts_provider = os.getenv('01_TTS', '').lower()
     stt_provider = os.getenv('01_STT', '').lower()

+    tts_provider = "elevenlabs"
+    stt_provider = "deepgram"
+
     # Add plugins here
     if tts_provider == 'openai':
         tts = openai.TTS()
+    elif tts_provider == 'local':
+        tts = openai.TTS(base_url="http://localhost:8000/v1")
+        print("using local tts")
     elif tts_provider == 'elevenlabs':
         tts = elevenlabs.TTS()
+        print("using elevenlabs tts")
     elif tts_provider == 'cartesia':
         tts = cartesia.TTS()
     else:
@@ -117,16 +118,20 @@

     if stt_provider == 'deepgram':
         stt = deepgram.STT()
+    elif stt_provider == 'local':
+        stt = openai.STT(base_url="http://localhost:8001/v1")
+        print("using local stt")
     else:
         raise ValueError(f"Unsupported STT provider: {stt_provider}. Please set the 01_STT environment variable to 'deepgram' or 'local'.")
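
With the new 'local' branches, both engines can target OpenAI-compatible servers on localhost (port 8000 for TTS, 8001 for STT, per the code above). A minimal sketch of selecting them, assuming the hardcoded tts_provider/stt_provider debug overrides earlier in the diff are dropped so the env vars are actually honored:

    # Hypothetical selection of the local providers; the worker reads
    # these with os.getenv('01_TTS') / os.getenv('01_STT') at startup.
    import os
    os.environ['01_TTS'] = 'local'   # OpenAI-compatible TTS at http://localhost:8000/v1
    os.environ['01_STT'] = 'local'   # OpenAI-compatible STT at http://localhost:8001/v1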

     ############################################################
     # initialize voice assistant states
     ############################################################
-    push_to_talk = True
+    push_to_talk = False
     current_message: ChatMessage = ChatMessage(role='user')
     submitted_message: ChatMessage = ChatMessage(role='user')
     video_muted = False
+    video_context = False

     tasks = []
     ############################################################
@@ -175,6 +180,7 @@

         if remote_video_processor and not video_muted:
             video_frame = await remote_video_processor.get_current_frame()

             if video_frame:
                 chat_ctx.append(role="user", images=[ChatImage(image=video_frame)])
+            else:
@@ -202,7 +208,15 @@

         # append image if available
         if remote_video_processor and not video_muted:
-            video_frame = await remote_video_processor.get_current_frame()
+            if remote_video_processor.get_video_context():
+                log_message("context is true")
+                log_message("retrieving timeline frame")
+                video_frame = await remote_video_processor.get_timeline_frame()
+            else:
+                log_message("context is false")
+                log_message("retrieving current frame")
+                video_frame = await remote_video_processor.get_current_frame()

             if video_frame:
                 chat_ctx.append(role="user", images=[ChatImage(image=video_frame)])
+                log_message(f"[on_message_received] appended image: {video_frame} to chat_ctx: {chat_ctx}")
@@ -263,6 +277,19 @@
     ############################################################
     # transcribe participant track
     ############################################################
+    async def _forward_transcription(
+        stt_stream: SpeechStream,
+        stt_forwarder: transcription.STTSegmentsForwarder,
+    ):
+        """Forward the transcription and log the transcript in the console"""
+        async for ev in stt_stream:
+            stt_forwarder.update(ev)
+            if ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT:
+                print(ev.alternatives[0].text, end="")
+            elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
+                print("\n")
+                print(" -> ", ev.alternatives[0].text)
+
     async def transcribe_track(participant: rtc.RemoteParticipant, track: rtc.Track):
         audio_stream = rtc.AudioStream(track)
         stt_forwarder = STTSegmentsForwarder(
@@ -297,8 +324,18 @@
         remote_video_stream = rtc.VideoStream(track=track, format=rtc.VideoBufferType.RGBA)
         remote_video_processor = RemoteVideoProcessor(video_stream=remote_video_stream, job_ctx=ctx)
         log_message("remote video processor." + str(remote_video_processor))

+        # Register safety check callback
+        remote_video_processor.register_safety_check_callback(
+            lambda frame: handle_instruction_check(assistant, frame)
+        )
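
The anticipation module itself is not part of this diff, so the shape of handle_instruction_check is not confirmed here. A minimal sketch of what it might look like, assuming it probes the pipeline's LLM with the latest frame and only speaks up on a positive answer (module body and prompt are illustrative, not the committed implementation):

    # Hypothetical sketch of source/server/livekit/anticipation.py.
    from livekit.rtc import VideoFrame
    from livekit.agents.llm.chat_context import ChatImage
    from livekit.agents.pipeline import VoicePipelineAgent

    async def handle_instruction_check(assistant: VoicePipelineAgent, frame: VideoFrame):
        # Ask the LLM whether the latest frame warrants proactive guidance.
        chat_ctx = assistant.chat_ctx.copy()
        chat_ctx.append(role="user", images=[ChatImage(image=frame)])
        chat_ctx.append(role="user", text="Does this frame call for proactive guidance? Answer YES or NO.")
        answer = ""
        async for chunk in assistant.llm.chat(chat_ctx=chat_ctx):
            if chunk.choices and chunk.choices[0].delta.content:
                answer += chunk.choices[0].delta.content
        if answer.strip().upper().startswith("YES"):
            await assistant.say("It looks like you might need a hand with that.")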

+        remote_video_processor.set_video_context(video_context)
+        log_message(f"set video context to {video_context} from queued video context")

         asyncio.create_task(remote_video_processor.process_frames())


     ############################################################
     # on track muted callback
     ############################################################
@@ -329,11 +366,12 @@
         local_participant = ctx.room.local_participant
         await local_participant.publish_data(payload="{CLEAR_CHAT}", topic="chat_context")
         log_message("sent {CLEAR_CHAT} to chat_context for client to clear")
-        await assistant.say(assistant.start_message)
+        await assistant.say(START_MESSAGE)


     @ctx.room.on("data_received")
     def on_data_received(data: rtc.DataPacket):
+        nonlocal video_context
         decoded_data = data.data.decode()
         log_message(f"received data from {data.topic}: {decoded_data}")
         if data.topic == "chat_context" and decoded_data == "{CLEAR_CHAT}":
@@ -349,6 +387,22 @@

             asyncio.create_task(_publish_clear_chat())

+        if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_ON}":
+            if remote_video_processor:
+                remote_video_processor.set_video_context(True)
+                log_message("set video context to True")
+            else:
+                video_context = True
+                log_message("no remote video processor found, queued video context to True")
+
+        if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_OFF}":
+            if remote_video_processor:
+                remote_video_processor.set_video_context(False)
+                log_message("set video context to False")
+            else:
+                video_context = False
+                log_message("no remote video processor found, queued video context to False")

     ############################################################
     # Start the voice assistant with the LiveKit room
@@ -367,7 +421,7 @@
     await asyncio.sleep(1)

     # Greets the user with an initial message
-    await assistant.say(start_message, allow_interruptions=True)
+    await assistant.say(START_MESSAGE, allow_interruptions=True)

     ############################################################
     # wait for the voice assistant to finish
@@ -389,12 +443,21 @@ def main(livekit_url: str):
     # Workers have to be run as CLIs right now.
     # So we need to simulate running "[this file] dev"

+    worker_start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+    log_message(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===")
+    print(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===")

     # Modify sys.argv to set the path to this file as the first argument
     # and 'dev' as the second argument
-    sys.argv = [str(__file__), 'dev']
+    sys.argv = [str(__file__), 'start']
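
Worth noting: in the livekit-agents CLI, 'dev' and 'start' are distinct run modes (development vs. production), so this switch changes runtime behavior, not just the label; the comment above still says 'dev'.

    # Equivalent CLI invocations, for reference:
    #   python [this file] dev     # development mode (previous behavior)
    #   python [this file] start   # production mode (new behavior)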

     # livekit_url = "ws://localhost:7880"
     # Initialize the worker with the entrypoint
     cli.run_app(
-        WorkerOptions(entrypoint_fnc=entrypoint, api_key="devkey", api_secret="secret", ws_url=livekit_url)
+        WorkerOptions(
+            entrypoint_fnc=entrypoint,
+            api_key="devkey",
+            api_secret="secret",
+            ws_url=livekit_url
+        )
     )