From a2f86afce133410a2d86d7a74d63e0fbe9d5289d Mon Sep 17 00:00:00 2001 From: Ben Xu Date: Wed, 1 Jan 2025 03:52:24 -0500 Subject: [PATCH] make request based on updated chat ctx in anticipation --- .../source/server/livekit/anticipation.py | 31 ++++++++++++++----- software/source/server/livekit/worker.py | 7 +++-- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/software/source/server/livekit/anticipation.py b/software/source/server/livekit/anticipation.py index 383e668..e41e008 100644 --- a/software/source/server/livekit/anticipation.py +++ b/software/source/server/livekit/anticipation.py @@ -3,6 +3,7 @@ import json import base64 import traceback import io +import os from PIL import Image as PIL_Image from openai import OpenAI @@ -11,7 +12,7 @@ from livekit import rtc from livekit.agents.pipeline import VoicePipelineAgent from livekit.agents.llm.chat_context import ChatContext from source.server.livekit.logger import log_message - +from livekit.agents.llm.chat_context import ChatImage # Add these constants after the existing ones @@ -52,20 +53,32 @@ async def handle_instruction_check( log_message(f"Violation detected with severity {result['severity_rating']}, triggering assistant response") # Append violation to chat context - violation_text = f"Safety violation detected: {result['violation_summary']}\nRecommendations: {result['recommendations']}" + violation_text = f"For the given instructions: {INSTRUCTIONS_PROMPT}\n. Instruction violation frame detected: {result['violation_summary']}\nRecommendations: {result['recommendations']}" assistant.chat_ctx.append( role="user", text=violation_text ) + + assistant.chat_ctx.append( + role="user", + images=[ + ChatImage(image=video_frame) + ] + ) log_message(f"Added violation to chat context: {violation_text}") + + + log_message(f"Current chat context: {assistant.chat_ctx}") # Trigger assistant response - response = f"I noticed that {result['violation_summary']}. {result['recommendations']}" - log_message(f"Triggering assistant response: {response}") + log_message(f"Triggering assistant response...") # TODO: instead of saying the predetermined response, we'll trigger an assistant response here # we can append the current video frame that triggered the violation to the chat context - stream = assistant.llm.chat() + stream = assistant.llm.chat( + chat_ctx=assistant.chat_ctx, + fnc_ctx=assistant.fnc_ctx, + ) await assistant.say(stream) else: @@ -84,7 +97,11 @@ async def check_instruction_violation( log_message("Creating new context for instruction check...") try: - client = OpenAI() + # pull this from env. + interpreter_server_host = os.getenv('INTERPRETER_SERVER_HOST', 'localhost') + interpreter_server_port = os.getenv('INTERPRETER_SERVER_PORT', '8000') + base_url = f"http://{interpreter_server_host}:{interpreter_server_port}/" + client = OpenAI(base_url) try: # Get raw RGBA data @@ -114,7 +131,7 @@ async def check_instruction_violation( response = client.chat.completions.create( model="gpt-4o-mini", messages=[ - # append chat context to prompt without images -- we'll need to parse them out + # TODO: append chat context to prompt without images -- we'll need to parse them out { "role": "user", "content": [ diff --git a/software/source/server/livekit/worker.py b/software/source/server/livekit/worker.py index 275b4b0..d06bd2c 100644 --- a/software/source/server/livekit/worker.py +++ b/software/source/server/livekit/worker.py @@ -84,7 +84,7 @@ async def entrypoint(ctx: JobContext): base_url = f"http://{interpreter_server_host}:{interpreter_server_port}/" # For debugging - base_url = "http://127.0.0.1:9000/" + base_url = "http://127.0.0.1:8000/" open_interpreter = openai.LLM( model="open-interpreter", base_url=base_url, api_key="x" @@ -93,6 +93,7 @@ async def entrypoint(ctx: JobContext): tts_provider = os.getenv('01_TTS', '').lower() stt_provider = os.getenv('01_STT', '').lower() + # todo: remove this tts_provider = "elevenlabs" stt_provider = "deepgram" @@ -100,7 +101,7 @@ async def entrypoint(ctx: JobContext): if tts_provider == 'openai': tts = openai.TTS() elif tts_provider == 'local': - tts = openai.TTS(base_url="http://localhost:8000/v1") + tts = openai.TTS(base_url="http://localhost:9001/v1") print("using local tts") elif tts_provider == 'elevenlabs': tts = elevenlabs.TTS() @@ -113,7 +114,7 @@ async def entrypoint(ctx: JobContext): if stt_provider == 'deepgram': stt = deepgram.STT() elif stt_provider == 'local': - stt = openai.STT(base_url="http://localhost:8001/v1") + stt = openai.STT(base_url="http://localhost:9002/v1") print("using local stt") else: raise ValueError(f"Unsupported STT provider: {stt_provider}. Please set 01_STT environment variable to 'deepgram'.")