add local stt & tts, add anticipation logic, remove video context accumulation
parent bd6f530be7
commit 6110e70430
@@ -1,10 +1,13 @@
-from livekit.rtc import VideoStream
+from livekit.rtc import VideoStream, VideoFrame, VideoBufferType
 from livekit.agents import JobContext
 from datetime import datetime
 import os
 
-from livekit.rtc import VideoFrame
 import asyncio
+from typing import Callable, Coroutine, Any
+
+# Interval settings
+INTERVAL = 30  # seconds
 
 # Define the path to the log file
 LOG_FILE_PATH = 'video_processor.txt'
@@ -20,34 +23,71 @@ def log_message(message: str):
         log_file.write(f"{timestamp} - {message}\n")
 
 class RemoteVideoProcessor:
     """Processes video frames from a remote participant's video stream."""
 
     def __init__(self, video_stream: VideoStream, job_ctx: JobContext):
+        log_message("Initializing RemoteVideoProcessor")
         self.video_stream = video_stream
         self.job_ctx = job_ctx
-        self.current_frame = None  # Store the latest VideoFrame
+        self.current_frame = None
         self.lock = asyncio.Lock()
 
+        self.interval = INTERVAL
+        self.video_context = False
+        self.last_capture_time = 0
+
+        # Add callback for safety checks
+        self.on_instruction_check: Callable[[VideoFrame], Coroutine[Any, Any, None]] | None = None
+
     async def process_frames(self):
-        log_message("Starting to process remote video frames.")
+        """Process incoming video frames."""
        async for frame_event in self.video_stream:
             try:
                 video_frame = frame_event.frame
+                timestamp = frame_event.timestamp_us
+                rotation = frame_event.rotation
+
+                log_message(f"Processing frame at timestamp {timestamp/1000000:.3f}s")
+                log_message(f"Frame details: size={video_frame.width}x{video_frame.height}, type={video_frame.type}")
 
-                # Store the current frame safely
-                log_message(f"Received frame: width={video_frame.width}, height={video_frame.height}, type={video_frame.type}")
                 async with self.lock:
                     self.current_frame = video_frame
+
+                    if self.video_context and self._check_interrupt(timestamp):
+                        self.last_capture_time = timestamp
+                        # Trigger instruction check callback if registered
+                        if self.on_instruction_check:
+                            await self.on_instruction_check(video_frame)
 
             except Exception as e:
-                log_message(f"Error processing frame: {e}")
+                log_message(f"Error processing frame: {str(e)}")
+                import traceback
+                log_message(f"Traceback: {traceback.format_exc()}")
 
+
+    def register_safety_check_callback(self, callback: Callable[[VideoFrame], Coroutine[Any, Any, None]]):
+        """Register a callback for safety checks"""
+        self.on_instruction_check = callback
+        log_message("Registered instruction check callback")
 
     async def get_current_frame(self) -> VideoFrame | None:
-        """Retrieve the current VideoFrame."""
-        log_message("called get current frame")
+        """Get the most recent video frame."""
+        log_message("Getting current frame")
         async with self.lock:
-            log_message("retrieving current frame: " + str(self.current_frame))
-            return self.current_frame
+            if self.current_frame is None:
+                log_message("No current frame available")
+            return self.current_frame
-
-
+
+    def set_video_context(self, context: bool):
+        """Set the video context."""
+        log_message(f"Setting video context to: {context}")
+        self.video_context = context
+
+    def get_video_context(self) -> bool:
+        """Get the video context."""
+        return self.video_context
+
+    def _check_interrupt(self, timestamp: int) -> bool:
+        """Determine if the video context should be interrupted."""
+        return timestamp - self.last_capture_time > self.interval * 1000000
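A note on the interval math in `_check_interrupt` above: `frame_event.timestamp_us` is in microseconds, so the 30-second `INTERVAL` is scaled by 1,000,000. A minimal, self-contained sketch of the same gating logic (the class and method names here are illustrative, not from the diff):

```python
# Illustrative re-statement of the _check_interrupt gating above.
# Timestamps are microseconds, so seconds are scaled by 1_000_000.
INTERVAL = 30  # seconds, as in the diff

class IntervalGate:
    def __init__(self, interval_s: int = INTERVAL):
        self.interval_us = interval_s * 1_000_000
        self.last_capture_time = 0  # microseconds

    def should_fire(self, timestamp_us: int) -> bool:
        """True at most once per interval; mirrors _check_interrupt plus
        the last_capture_time update done inside process_frames."""
        if timestamp_us - self.last_capture_time > self.interval_us:
            self.last_capture_time = timestamp_us
            return True
        return False

gate = IntervalGate()
assert gate.should_fire(31_000_000)       # 31s after start: fires
assert not gate.should_fire(45_000_000)   # only 14s later: suppressed
```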
@@ -2,10 +2,11 @@ import asyncio
 import numpy as np
 import sys
 import os
+import threading
 from datetime import datetime
 from typing import Literal, Awaitable
 
-from livekit.agents import JobContext, WorkerOptions, cli
+from livekit.agents import JobContext, WorkerOptions, cli, transcription
 from livekit.agents.transcription import STTSegmentsForwarder
 from livekit.agents.llm import ChatContext
 from livekit import rtc
@@ -13,30 +14,23 @@ from livekit.agents.pipeline import VoicePipelineAgent
 from livekit.plugins import deepgram, openai, silero, elevenlabs, cartesia
 from livekit.agents.llm.chat_context import ChatContext, ChatImage, ChatMessage
 from livekit.agents.llm import LLMStream
 from livekit.agents.stt import SpeechStream
 
 from source.server.livekit.video_processor import RemoteVideoProcessor
-from source.server.livekit.transcriptions import _forward_transcription
+from source.server.livekit.anticipation import handle_instruction_check
+from source.server.livekit.logger import log_message
 
 from dotenv import load_dotenv
 
 load_dotenv()
 
-# Define the path to the log file
-LOG_FILE_PATH = 'worker.txt'
-DEBUG = os.getenv('DEBUG', 'false').lower() == 'true'
-
-def log_message(message: str):
-    """Append a message to the log file with a timestamp."""
-    if not DEBUG:
-        return
-    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-    with open(LOG_FILE_PATH, 'a') as log_file:
-        log_file.write(f"{timestamp} - {message}\n")
-
-start_message = """Hi! You can hold the white circle below to speak to me.
-Try asking what I can do."""
+_room_lock = threading.Lock()
+_connected_rooms = set()
+
+START_MESSAGE = "Hi! You can hold the white circle below to speak to me. Try asking what I can do."
+
 
 # This function is the entrypoint for the agent.
 async def entrypoint(ctx: JobContext):
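The inline `log_message` helper deleted above now comes from `source.server.livekit.logger`. That module is not part of this diff; a plausible reconstruction, assuming it keeps the `DEBUG` gate and format of the deleted inline version:

```python
# Hypothetical source/server/livekit/logger.py -- not shown in this
# diff; reconstructed from the inline helper deleted above.
import os
from datetime import datetime

LOG_FILE_PATH = 'worker.txt'
DEBUG = os.getenv('DEBUG', 'false').lower() == 'true'

def log_message(message: str):
    """Append a message to the log file with a timestamp."""
    if not DEBUG:
        return
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    with open(LOG_FILE_PATH, 'a') as log_file:
        log_file.write(f"{timestamp} - {message}\n")
```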
@@ -96,7 +90,7 @@ async def entrypoint(ctx: JobContext):
     base_url = f"http://{interpreter_server_host}:{interpreter_server_port}/"
 
     # For debugging
-    base_url = "http://127.0.0.1:8000/"
+    base_url = "http://127.0.0.1:9000/"
 
     open_interpreter = openai.LLM(
         model="open-interpreter", base_url=base_url, api_key="x"
@@ -105,11 +99,18 @@
     tts_provider = os.getenv('01_TTS', '').lower()
     stt_provider = os.getenv('01_STT', '').lower()
 
+    tts_provider = "elevenlabs"
+    stt_provider = "deepgram"
+
     # Add plugins here
     if tts_provider == 'openai':
         tts = openai.TTS()
+    elif tts_provider == 'local':
+        tts = openai.TTS(base_url="http://localhost:8000/v1")
+        print("using local tts")
     elif tts_provider == 'elevenlabs':
         tts = elevenlabs.TTS()
+        print("using elevenlabs tts")
     elif tts_provider == 'cartesia':
         tts = cartesia.TTS()
     else:
@@ -117,16 +118,20 @@
 
     if stt_provider == 'deepgram':
         stt = deepgram.STT()
+    elif stt_provider == 'local':
+        stt = openai.STT(base_url="http://localhost:8001/v1")
+        print("using local stt")
     else:
         raise ValueError(f"Unsupported STT provider: {stt_provider}. Please set 01_STT environment variable to 'deepgram'.")
 
     ############################################################
     # initialize voice assistant states
     ############################################################
-    push_to_talk = True
+    push_to_talk = False
     current_message: ChatMessage = ChatMessage(role='user')
     submitted_message: ChatMessage = ChatMessage(role='user')
     video_muted = False
+    video_context = False
 
     tasks = []
     ############################################################
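Note that the hardcoded `tts_provider = "elevenlabs"` / `stt_provider = "deepgram"` assignments added above override the `01_TTS` / `01_STT` environment variables, so the new `'local'` branches are unreachable until those debug lines are removed. A condensed sketch of the selection logic, assuming OpenAI-compatible speech servers on `localhost:8000` (TTS) and `localhost:8001` (STT) for the `'local'` providers, exactly as in the diff:

```python
# Sketch of the provider selection added above, factored into helpers.
# The 'local' branches assume OpenAI-compatible speech endpoints.
import os
from livekit.plugins import deepgram, openai, elevenlabs, cartesia

def build_tts(provider: str):
    if provider == 'openai':
        return openai.TTS()
    if provider == 'local':
        return openai.TTS(base_url="http://localhost:8000/v1")  # local TTS server
    if provider == 'elevenlabs':
        return elevenlabs.TTS()
    if provider == 'cartesia':
        return cartesia.TTS()
    raise ValueError(f"Unsupported TTS provider: {provider}")

def build_stt(provider: str):
    if provider == 'deepgram':
        return deepgram.STT()
    if provider == 'local':
        return openai.STT(base_url="http://localhost:8001/v1")  # local STT server
    raise ValueError(f"Unsupported STT provider: {provider}")

# e.g. tts = build_tts(os.getenv('01_TTS', '').lower())
```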
@@ -175,6 +180,7 @@
 
                 if remote_video_processor and not video_muted:
                     video_frame = await remote_video_processor.get_current_frame()
+
                     if video_frame:
                         chat_ctx.append(role="user", images=[ChatImage(image=video_frame)])
                     else:
@@ -202,7 +208,15 @@
 
             # append image if available
             if remote_video_processor and not video_muted:
-                video_frame = await remote_video_processor.get_current_frame()
+                if remote_video_processor.get_video_context():
+                    log_message("context is true")
+                    log_message("retrieving timeline frame")
+                    video_frame = await remote_video_processor.get_timeline_frame()
+                else:
+                    log_message("context is false")
+                    log_message("retrieving current frame")
+                    video_frame = await remote_video_processor.get_current_frame()
+
                 if video_frame:
                     chat_ctx.append(role="user", images=[ChatImage(image=video_frame)])
                     log_message(f"[on_message_received] appended image: {video_frame} to chat_ctx: {chat_ctx}")
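`get_timeline_frame()` is called here but is not defined in the `RemoteVideoProcessor` hunks shown in this diff. One possible shape, purely an assumption, if it returns the frame captured at the last anticipation interval:

```python
# Hypothetical -- get_timeline_frame() is not among the visible hunks.
import asyncio
from livekit.rtc import VideoFrame

class RemoteVideoProcessorSketch:
    def __init__(self):
        self.lock = asyncio.Lock()
        self.current_frame: VideoFrame | None = None

    async def get_timeline_frame(self) -> VideoFrame | None:
        """Assumed: the frame snapshotted at the last interval boundary."""
        async with self.lock:
            # a real implementation would likely return a snapshot taken
            # when _check_interrupt() fired, not just the latest frame
            return self.current_frame
```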
@@ -263,6 +277,19 @@
     ############################################################
     # transcribe participant track 
     ############################################################
+    async def _forward_transcription(
+        stt_stream: SpeechStream,
+        stt_forwarder: transcription.STTSegmentsForwarder,
+    ):
+        """Forward the transcription and log the transcript in the console"""
+        async for ev in stt_stream:
+            stt_forwarder.update(ev)
+            if ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT:
+                print(ev.alternatives[0].text, end="")
+            elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
+                print("\n")
+                print(" -> ", ev.alternatives[0].text)
+
     async def transcribe_track(participant: rtc.RemoteParticipant, track: rtc.Track):
         audio_stream = rtc.AudioStream(track)
         stt_forwarder = STTSegmentsForwarder(
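The remainder of `transcribe_track` falls outside this hunk. A sketch of how the forwarder above is typically driven, under the assumption that it follows the standard livekit-agents flow (open an STT stream, forward its events, feed it audio frames from the track):

```python
# Assumed continuation of transcribe_track (not shown in the diff):
async def drive_transcription(stt, audio_stream, stt_forwarder):
    stt_stream = stt.stream()
    # forward interim/final events while audio is being pushed
    forward_task = asyncio.create_task(
        _forward_transcription(stt_stream, stt_forwarder)
    )
    async for ev in audio_stream:
        stt_stream.push_frame(ev.frame)  # feed track audio into the STT stream
```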
@@ -297,8 +324,18 @@
             remote_video_stream = rtc.VideoStream(track=track, format=rtc.VideoBufferType.RGBA)
             remote_video_processor = RemoteVideoProcessor(video_stream=remote_video_stream, job_ctx=ctx)
             log_message("remote video processor." + str(remote_video_processor))
+
+            # Register safety check callback
+            remote_video_processor.register_safety_check_callback(
+                lambda frame: handle_instruction_check(assistant, frame)
+            )
+
+            remote_video_processor.set_video_context(video_context)
+            log_message(f"set video context to {video_context} from queued video context")
+
             asyncio.create_task(remote_video_processor.process_frames())
 
+
     ############################################################
     # on track muted callback
     ############################################################
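The anticipation module itself is not among this commit's visible hunks; a plausible signature for `handle_instruction_check`, inferred from the lambda above (the body is an assumption):

```python
# Hypothetical shape of source/server/livekit/anticipation.py's entrypoint.
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.rtc import VideoFrame

async def handle_instruction_check(assistant: VoicePipelineAgent, frame: VideoFrame) -> None:
    """Assumed: inspect the latest frame and have the assistant react
    (e.g. speak a warning) if the current instruction needs attention."""
    ...
```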
@@ -329,11 +366,12 @@
         local_participant = ctx.room.local_participant
         await local_participant.publish_data(payload="{CLEAR_CHAT}", topic="chat_context")
         log_message("sent {CLEAR_CHAT} to chat_context for client to clear")
-        await assistant.say(assistant.start_message)
+        await assistant.say(START_MESSAGE)
 
 
     @ctx.room.on("data_received")
     def on_data_received(data: rtc.DataPacket):
+        nonlocal video_context
         decoded_data = data.data.decode()
         log_message(f"received data from {data.topic}: {decoded_data}")
         if data.topic == "chat_context" and decoded_data == "{CLEAR_CHAT}":
@@ -349,6 +387,22 @@
 
             asyncio.create_task(_publish_clear_chat())
 
+        if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_ON}":
+            if remote_video_processor:
+                remote_video_processor.set_video_context(True)
+                log_message("set video context to True")
+            else:
+                video_context = True
+                log_message("no remote video processor found, queued video context to True")
+
+        if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_OFF}":
+            if remote_video_processor:
+                remote_video_processor.set_video_context(False)
+                log_message("set video context to False")
+            else:
+                video_context = False
+                log_message("no remote video processor found, queued video context to False")
+
 
     ############################################################
     # Start the voice assistant with the LiveKit room
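For reference, the client side of this protocol publishes the same sentinel strings on the `video_context` topic, mirroring the `publish_data(payload=..., topic=...)` call used for `{CLEAR_CHAT}` above; a minimal sketch (the helper name is illustrative):

```python
# Client-side counterpart (sketch): toggle the agent's video context.
from livekit import rtc

async def toggle_video_context(room: rtc.Room, enabled: bool) -> None:
    payload = "{VIDEO_CONTEXT_ON}" if enabled else "{VIDEO_CONTEXT_OFF}"
    await room.local_participant.publish_data(payload=payload, topic="video_context")
```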
@@ -367,7 +421,7 @@
     await asyncio.sleep(1)
 
     # Greets the user with an initial message
-    await assistant.say(start_message, allow_interruptions=True)
+    await assistant.say(START_MESSAGE, allow_interruptions=True)
 
     ############################################################
     # wait for the voice assistant to finish
@@ -389,12 +443,21 @@ def main(livekit_url: str):
     # Workers have to be run as CLIs right now.
     # So we need to simualte running "[this file] dev"
 
+    worker_start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+    log_message(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===")
+    print(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===")
+
     # Modify sys.argv to set the path to this file as the first argument
     # and 'dev' as the second argument
-    sys.argv = [str(__file__), 'dev']
+    sys.argv = [str(__file__), 'start']
 
+    # livekit_url = "ws://localhost:7880"
     # Initialize the worker with the entrypoint
     cli.run_app(
-        WorkerOptions(entrypoint_fnc=entrypoint, api_key="devkey", api_secret="secret", ws_url=livekit_url)
+        WorkerOptions(
+            entrypoint_fnc=entrypoint,
+            api_key="devkey",
+            api_secret="secret",
+            ws_url=livekit_url
+        )
-
     )
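Usage sketch: launching the worker against a local LiveKit server, matching the `devkey`/`secret` pair hardcoded above and the URL from the commented-out hint in the diff (`main` is assumed importable from this module):

```python
# Run the worker locally; 'start' mode is injected via sys.argv in main().
if __name__ == "__main__":
    main(livekit_url="ws://localhost:7880")
```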