make request based on updated chat ctx in anticipation
parent ab8055e0de
commit a2f86afce1
@@ -3,6 +3,7 @@ import json
 import base64
 import traceback
 import io
+import os
 from PIL import Image as PIL_Image
 
 from openai import OpenAI
@@ -11,7 +12,7 @@ from livekit import rtc
 from livekit.agents.pipeline import VoicePipelineAgent
 from livekit.agents.llm.chat_context import ChatContext
 from source.server.livekit.logger import log_message
-
+from livekit.agents.llm.chat_context import ChatImage
 
 
 # Add these constants after the existing ones
@@ -52,20 +53,32 @@ async def handle_instruction_check(
         log_message(f"Violation detected with severity {result['severity_rating']}, triggering assistant response")
 
         # Append violation to chat context
-        violation_text = f"Safety violation detected: {result['violation_summary']}\nRecommendations: {result['recommendations']}"
+        violation_text = f"For the given instructions: {INSTRUCTIONS_PROMPT}\n. Instruction violation frame detected: {result['violation_summary']}\nRecommendations: {result['recommendations']}"
         assistant.chat_ctx.append(
             role="user",
             text=violation_text
         )
 
+        assistant.chat_ctx.append(
+            role="user",
+            images=[
+                ChatImage(image=video_frame)
+            ]
+        )
+
         log_message(f"Added violation to chat context: {violation_text}")
 
+        log_message(f"Current chat context: {assistant.chat_ctx}")
+
         # Trigger assistant response
-        response = f"I noticed that {result['violation_summary']}. {result['recommendations']}"
-        log_message(f"Triggering assistant response: {response}")
+        log_message(f"Triggering assistant response...")
 
         # TODO: instead of saying the predetermined response, we'll trigger an assistant response here
         # we can append the current video frame that triggered the violation to the chat context
-        stream = assistant.llm.chat()
+        stream = assistant.llm.chat(
+            chat_ctx=assistant.chat_ctx,
+            fnc_ctx=assistant.fnc_ctx,
+        )
 
         await assistant.say(stream)
     else:
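Taken together, this hunk replaces the canned spoken reply with an anticipatory LLM turn: the violation text and the frame that triggered it are pushed into the agent's chat context, and the response is generated from that updated context. A condensed sketch of the resulting flow, assuming a VoicePipelineAgent in `assistant` and an rtc.VideoFrame in `video_frame` as in the hunk (this restates the new control flow; it is not additional committed code):

    async def respond_to_violation(assistant, video_frame, violation_text):
        # 1. Record the violation and the offending frame in the shared context.
        assistant.chat_ctx.append(role="user", text=violation_text)
        assistant.chat_ctx.append(role="user", images=[ChatImage(image=video_frame)])
        # 2. Ask the LLM for the next turn based on the updated context,
        #    keeping function calling available via fnc_ctx.
        stream = assistant.llm.chat(
            chat_ctx=assistant.chat_ctx,
            fnc_ctx=assistant.fnc_ctx,
        )
        # 3. Speak the streamed completion.
        await assistant.say(stream)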
@@ -84,7 +97,11 @@ async def check_instruction_violation(
    log_message("Creating new context for instruction check...")

    try:
-        client = OpenAI()
+        # pull this from env.
+        interpreter_server_host = os.getenv('INTERPRETER_SERVER_HOST', 'localhost')
+        interpreter_server_port = os.getenv('INTERPRETER_SERVER_PORT', '8000')
+        base_url = f"http://{interpreter_server_host}:{interpreter_server_port}/"
+        client = OpenAI(base_url=base_url)

        try:
            # Get raw RGBA data
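Note that the OpenAI client accepts base_url only as a keyword argument, so it must be passed by name; `OpenAI(base_url)` positionally raises a TypeError. A minimal sketch of the wiring using the names from the hunk (the port default of 8000 matches the debug URL in the entrypoint below):

    import os
    from openai import OpenAI

    # Point the SDK at the local interpreter server instead of api.openai.com.
    interpreter_server_host = os.getenv('INTERPRETER_SERVER_HOST', 'localhost')
    interpreter_server_port = os.getenv('INTERPRETER_SERVER_PORT', '8000')
    client = OpenAI(base_url=f"http://{interpreter_server_host}:{interpreter_server_port}/")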
@@ -114,7 +131,7 @@ async def check_instruction_violation(
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
-                # append chat context to prompt without images -- we'll need to parse them out
+                # TODO: append chat context to prompt without images -- we'll need to parse them out
                {
                    "role": "user",
                    "content": [
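The hunk is truncated before the content parts, but the surrounding imports (base64, io, PIL) suggest the captured frame is encoded and attached as a data-URL image part in the OpenAI vision message format. A hedged sketch of one plausible continuation; `frame_to_image_part` is a hypothetical helper, not code from this commit:

    import base64
    import io
    from PIL import Image as PIL_Image

    def frame_to_image_part(rgba_bytes, width, height):
        # Wrap the raw RGBA frame as a PNG data URL, the shape
        # chat.completions expects for inline images.
        image = PIL_Image.frombytes("RGBA", (width, height), rgba_bytes)
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{encoded}"},
        }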
@@ -84,7 +84,7 @@ async def entrypoint(ctx: JobContext):
    base_url = f"http://{interpreter_server_host}:{interpreter_server_port}/"

    # For debugging
-    base_url = "http://127.0.0.1:9000/"
+    base_url = "http://127.0.0.1:8000/"

    open_interpreter = openai.LLM(
        model="open-interpreter", base_url=base_url, api_key="x"
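The debugging override above still clobbers the env-derived base_url unconditionally; only the hard-coded port moves from 9000 to 8000, matching INTERPRETER_SERVER_PORT's default. If the override is meant to stay, gating it would preserve the configuration path. A hypothetical guard with an invented env var name, not part of the commit:

    # Hypothetical: apply the debug URL only when explicitly requested.
    if os.getenv('01_DEBUG_LOCAL_INTERPRETER'):
        base_url = "http://127.0.0.1:8000/"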
@@ -93,6 +93,7 @@ async def entrypoint(ctx: JobContext):
    tts_provider = os.getenv('01_TTS', '').lower()
    stt_provider = os.getenv('01_STT', '').lower()

+    # todo: remove this
    tts_provider = "elevenlabs"
    stt_provider = "deepgram"

@@ -100,7 +101,7 @@ async def entrypoint(ctx: JobContext):
    if tts_provider == 'openai':
        tts = openai.TTS()
    elif tts_provider == 'local':
-        tts = openai.TTS(base_url="http://localhost:8000/v1")
+        tts = openai.TTS(base_url="http://localhost:9001/v1")
        print("using local tts")
    elif tts_provider == 'elevenlabs':
        tts = elevenlabs.TTS()
@@ -113,7 +114,7 @@ async def entrypoint(ctx: JobContext):
    if stt_provider == 'deepgram':
        stt = deepgram.STT()
    elif stt_provider == 'local':
-        stt = openai.STT(base_url="http://localhost:8001/v1")
+        stt = openai.STT(base_url="http://localhost:9002/v1")
        print("using local stt")
    else:
        raise ValueError(f"Unsupported STT provider: {stt_provider}. Please set 01_STT environment variable to 'deepgram'.")
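For reference, the provider branches above pair each provider string with a constructor, and the 'local' paths now assume OpenAI-compatible speech servers on the new ports (TTS on 9001, STT on 9002). The same pairing restated as a lookup table, as a readability sketch rather than the committed shape:

    from livekit.plugins import deepgram, elevenlabs, openai

    # Provider string -> factory; the 'local' entries assume OpenAI-compatible
    # speech servers listening on the ports introduced in this commit.
    TTS_FACTORIES = {
        'openai': lambda: openai.TTS(),
        'local': lambda: openai.TTS(base_url="http://localhost:9001/v1"),
        'elevenlabs': lambda: elevenlabs.TTS(),
    }
    STT_FACTORIES = {
        'deepgram': lambda: deepgram.STT(),
        'local': lambda: openai.STT(base_url="http://localhost:9002/v1"),
    }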