Added service for transcription
parent 2593ef83c3
commit 3e10fb80fa
@@ -1 +1,2 @@
 .env
+.DS_Store
@@ -0,0 +1,9 @@
+# Setup
+
+1. Install Rust and the Python dependencies.
+2. Go to core/stt and run `cargo build --release`.
+3. Download a GGML Whisper model from https://huggingface.co/ggerganov/whisper.cpp.
+4. Copy `.env.example` to `.env` and set `WHISPER_MODEL_PATH` to the model's path.
+5. Run `python core/i_endpoint.py` to start the server.
+6. Run `python core/test_cli.py PATH_TO_FILE` to test sending audio to the service and getting the transcription back over the websocket.
+
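A quick pre-flight check for steps 2-4 before starting the server. This is a sketch, not part of the commit; the binary path below is inferred from how stt.py locates it and may differ in your checkout:

import os
from dotenv import load_dotenv

# Step 4: .env must point at a downloaded GGML model
load_dotenv()
model_path = os.getenv("WHISPER_MODEL_PATH")
assert model_path and os.path.isfile(model_path), "WHISPER_MODEL_PATH missing or wrong"

# Step 2: the release binary must exist (path inferred from stt.py; adjust to your tree)
binary = os.path.join("core", "whisper-rust", "target", "release", "whisper-rust")
assert os.path.isfile(binary), "run `cargo build --release` first"

print("Setup looks OK")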
@ -0,0 +1 @@
|
||||||
|
WHISPER_MODEL_PATH=/path/to/ggml-tiny.en.bin
|
|
@@ -1,6 +1,13 @@
 from fastapi import FastAPI, Request, WebSocket
 import uvicorn
 import redis
+import json
+from dotenv import load_dotenv
+from stt import get_transcription
+import tempfile
+
+# Load environment variables
+load_dotenv()
 
 app = FastAPI()
 
@@ -10,7 +17,7 @@ r = redis.Redis(host='localhost', port=6379, db=0)
 @app.post("/i/")
 async def i(request: Request):
     message = await request.json()
 
     client_host = request.client.host  # Get the client's IP address
 
     message = f"""
@@ -26,5 +33,47 @@ async def i(request: Request):
         "content": message
     })
 
+
+@app.websocket("/a")
+async def a(ws: WebSocket):
+    await ws.accept()
+    audio_file = bytearray()
+    mime_type = None
+
+    try:
+        while True:
+            message = await ws.receive()
+
+            if message['type'] == 'websocket.disconnect':
+                break
+
+            if message['type'] == 'websocket.receive':
+                if 'text' in message:
+                    control_message = json.loads(message['text'])
+                    if control_message.get('action') == 'command' and control_message.get('state') == 'start' and 'mimeType' in control_message:
+                        # This indicates the start of a new audio file
+                        mime_type = control_message.get('mimeType')
+                    elif control_message.get('action') == 'command' and control_message.get('state') == 'end':
+                        # This indicates the end of the audio file
+                        # Process the complete audio file here
+                        transcription = get_transcription(audio_file, mime_type)
+                        await ws.send_json({"transcript": transcription})
+
+                        print("SENT TRANSCRIPTION!")
+
+                        # Reset the bytearray for the next audio file
+                        audio_file = bytearray()
+                        mime_type = None
+                elif 'bytes' in message:
+                    # If it's not a control message, it's part of the audio file
+                    audio_file.extend(message['bytes'])
+    except Exception as e:
+        print(f"WebSocket connection closed with exception: {e}")
+    finally:
+        await ws.close()
+        print("WebSocket connection closed")
+
+
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    with tempfile.TemporaryDirectory():
+        uvicorn.run(app, host="0.0.0.0", port=8000)
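The /a endpoint expects a JSON start command carrying a mimeType, then raw binary chunks, then a JSON end command, and replies with a {"transcript": ...} JSON message. A minimal in-process sketch of that exchange using Starlette's TestClient — the module name and input file are assumptions, and the end command triggers a real transcription, so the model and whisper-rust binary from the setup steps must be in place:

import json
from fastapi.testclient import TestClient
from i_endpoint import app  # module name assumed from the setup steps

client = TestClient(app)
with client.websocket_connect("/a") as ws:
    # Announce a new audio file and its mime type
    ws.send_text(json.dumps({"action": "command", "state": "start", "mimeType": "audio/x-wav"}))
    # Stream the raw audio bytes (a single chunk is enough here)
    with open("sample.wav", "rb") as f:  # sample.wav is illustrative
        ws.send_bytes(f.read())
    # Signal end-of-file and wait for the transcript
    ws.send_text(json.dumps({"action": "command", "state": "end"}))
    print(ws.receive_json())  # -> {"transcript": "..."}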
@@ -0,0 +1,55 @@
+from datetime import datetime
+import os
+import contextlib
+import tempfile
+import ffmpeg
+import subprocess
+
+def convert_mime_type_to_format(mime_type: str) -> str:
+    if mime_type == "audio/x-wav":
+        return "wav"
+    if mime_type == "audio/webm":
+        return "webm"
+
+    return mime_type
+
+@contextlib.contextmanager
+def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
+    temp_dir = tempfile.gettempdir()
+
+    # Create a temporary file with the appropriate extension
+    input_ext = convert_mime_type_to_format(mime_type)
+    input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}")
+    with open(input_path, 'wb') as f:
+        f.write(audio)
+
+    # Export to wav
+    output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
+    ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run()
+
+    print(f"Temporary file path: {output_path}")
+
+    try:
+        yield output_path
+    finally:
+        os.remove(input_path)
+        # os.remove(output_path)
+
+def run_command(command):
+    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    return result.stdout, result.stderr
+
+def get_transcription(audio_bytes: bytearray, mime_type):
+    with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path:
+        model_path = os.getenv("WHISPER_MODEL_PATH")
+        if not model_path:
+            raise EnvironmentError("WHISPER_MODEL_PATH environment variable is not set.")
+
+        output, error = run_command([
+            os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release', 'whisper-rust'),
+            '--model-path', model_path,
+            '--file-path', wav_file_path
+        ])
+
+        print("Transcription result:", output)
+        return output
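get_transcription can also be exercised without the websocket layer. A short sketch — the file name is illustrative, and any container ffmpeg understands should work, since export_audio_to_wav_ffmpeg re-encodes to 16 kHz mono pcm_s16le:

from dotenv import load_dotenv
from stt import get_transcription

load_dotenv()  # supplies WHISPER_MODEL_PATH

# Feed raw file bytes plus their mime type, as the websocket handler does
with open("sample.wav", "rb") as f:  # illustrative input file
    audio = bytearray(f.read())

print(get_transcription(audio, "audio/x-wav"))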
Binary file not shown.
@@ -0,0 +1 @@
+target
File diff suppressed because it is too large
@@ -0,0 +1,14 @@
+[package]
+name = "whisper-rust"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+anyhow = "1.0.79"
+clap = { version = "4.4.18", features = ["derive"] }
+cpal = "0.15.2"
+hound = "3.5.1"
+whisper-rs = "0.10.0"
+whisper-rs-sys = "0.8.0"
@@ -0,0 +1,34 @@
+mod transcribe;
+
+use clap::Parser;
+use std::path::PathBuf;
+use transcribe::transcribe;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// This is the model for Whisper STT
+    #[arg(short, long, value_parser, required = true)]
+    model_path: PathBuf,
+
+    /// This is the wav audio file that will be converted from speech to text
+    #[arg(short, long, value_parser, required = true)]
+    file_path: Option<PathBuf>,
+}
+
+fn main() {
+
+    let args = Args::parse();
+
+    let file_path = match args.file_path {
+        Some(fp) => fp,
+        None => panic!("No file path provided")
+    };
+
+    let result = transcribe(&args.model_path, &file_path);
+
+    match result {
+        Ok(transcription) => print!("{}", transcription),
+        Err(e) => panic!("Error: {}", e),
+    }
+}
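Once built, the binary can also be invoked by hand, the same way stt.py shells out to it, e.g. `./target/release/whisper-rust --model-path /path/to/ggml-tiny.en.bin --file-path output.wav` (paths illustrative); the transcription is printed to stdout.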
@@ -0,0 +1,64 @@
+use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
+use std::path::PathBuf;
+
+
+/// Transcribes the given audio file using the whisper-rs library.
+///
+/// # Arguments
+/// * `model_path` - Path to the Whisper model file
+/// * `file_path` - Path to the audio file to be transcribed
+///
+/// # Returns
+///
+/// A Result containing a String with the transcription if successful, or an error message if not.
+pub fn transcribe(model_path: &PathBuf, file_path: &PathBuf) -> Result<String, String> {
+
+    let model_path_str = model_path.to_str().expect("Not valid model path");
+    // Load a context and model
+    let ctx = WhisperContext::new_with_params(
+        model_path_str,
+        WhisperContextParameters::default(),
+    )
+    .map_err(|_| "failed to load model")?;
+
+    // Create a state
+    let mut state = ctx.create_state().map_err(|_| "failed to create state")?;
+
+    // Create a params object
+    // Note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
+    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
+
+    // Edit parameters as needed
+    params.set_n_threads(1); // Set the number of threads to use
+    params.set_translate(true); // Enable translation
+    params.set_language(Some("en")); // Set the language to translate to English
+    // Disable printing to stdout
+    params.set_print_special(false);
+    params.set_print_progress(false);
+    params.set_print_realtime(false);
+    params.set_print_timestamps(false);
+
+    // Load the audio file (raw bytes, WAV header included, read as i16 samples)
+    let audio_data = std::fs::read(file_path)
+        .map_err(|e| format!("failed to read audio file: {}", e))?
+        .chunks_exact(2)
+        .map(|chunk| i16::from_ne_bytes([chunk[0], chunk[1]]))
+        .collect::<Vec<i16>>();
+
+    // Convert the 16kHz mono i16 samples to the f32 format the model expects
+    let audio_data = whisper_rs::convert_integer_to_float_audio(&audio_data);
+
+    // Run the model
+    state.full(params, &audio_data[..]).map_err(|_| "failed to run model")?;
+
+    // Fetch the results
+    let num_segments = state.full_n_segments().map_err(|_| "failed to get number of segments")?;
+    let mut transcription = String::new();
+    for i in 0..num_segments {
+        let segment = state.full_get_segment_text(i).map_err(|_| "failed to get segment")?;
+        transcription.push_str(&segment);
+        transcription.push('\n');
+    }
+
+    Ok(transcription)
+}
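For reference, the sample decoding above is equivalent to the following Python sketch, with one difference: the stdlib wave module skips the WAV header, while std::fs::read feeds the header bytes through chunks_exact(2) as if they were samples. The file name is illustrative:

import wave

# Read only the PCM data chunk (the Rust code reads the whole file, header included)
with wave.open("output.wav", "rb") as w:
    raw = w.readframes(w.getnframes())

# Pair bytes into little-endian i16 samples and scale to [-1.0, 1.0],
# mirroring i16::from_ne_bytes (on little-endian targets) plus
# whisper_rs::convert_integer_to_float_audio
samples = [
    int.from_bytes(raw[i:i + 2], "little", signed=True) / 32768.0
    for i in range(0, len(raw), 2)
]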
@@ -0,0 +1,40 @@
+import argparse
+import asyncio
+import websockets
+import os
+import json
+
+# Define the function to send an audio file in chunks
+async def send_audio_in_chunks(file_path, chunk_size=4096):
+    async with websockets.connect("ws://localhost:8000/a") as websocket:
+        # Send the start command with the mime type
+        await websocket.send(json.dumps({"action": "command", "state": "start", "mimeType": "audio/webm"}))
+
+        # Open the file in binary mode and send it in chunks
+        with open(file_path, 'rb') as audio_file:
+            chunk = audio_file.read(chunk_size)
+            while chunk:
+                await websocket.send(chunk)
+                chunk = audio_file.read(chunk_size)
+
+        # Send the end command
+        await websocket.send(json.dumps({"action": "command", "state": "end"}))
+
+        # Receive a json message and then close the connection
+        message = await websocket.recv()
+        print("Received message:", json.loads(message))
+        await websocket.close()
+
+# Parse command line arguments
+parser = argparse.ArgumentParser(description="Send a webm audio file to the /a websocket endpoint and print the responses.")
+parser.add_argument("file_path", help="The path to the webm audio file to send.")
+args = parser.parse_args()
+
+# Check if the file exists
+if not os.path.isfile(args.file_path):
+    print(args.file_path)
+    print("Error: The file does not exist.")
+    exit(1)
+
+# Run the asyncio event loop
+asyncio.run(send_audio_in_chunks(args.file_path))
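With the server from i_endpoint.py running, `python core/test_cli.py sample.webm` (file name illustrative) should print something like `Received message: {'transcript': '...'}`.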
@@ -2,4 +2,7 @@ git+https://github.com/KillianLucas/open-interpreter.git
 redis==5.0.1
 fastapi==0.109.0
 uvicorn==0.27.0.post1
 websockets==12.0
+pydub==0.25.1
+numpy==1.26.3
+python-dotenv==1.0.1