Added service for transcription
This commit is contained in:
		
							parent
							
								
									2593ef83c3
								
							
						
					
					
						commit
						3e10fb80fa
					
				|  | @ -1 +1,2 @@ | ||||||
| .env | .env | ||||||
|  | .DS_Store | ||||||
|  | @ -0,0 +1,9 @@ | ||||||
|  | 
 | ||||||
|  | # Setup | ||||||
|  | 
 | ||||||
|  | 1. Install Rust and Python dependencies | ||||||
|  | 2. Go to core/stt and run `cargo build --release` | ||||||
|  | 3. Download GGML Whisper model from https://huggingface.co/ggerganov/whisper.cpp | ||||||
|  | 4. Copy .env.example to .env and put the path to model | ||||||
|  | 5. Run `python core/i_endpoint.py` to start the server | ||||||
|  | 6. Run `python core/test_cli.py PATH_TO_FILE` to test sending audio to service and getting transcription back over websocket | ||||||
|  | @ -0,0 +1 @@ | ||||||
|  | WHISPER_MODEL_PATH=/path/to/ggml-tiny.en.bin | ||||||
|  | @ -1,6 +1,13 @@ | ||||||
| from fastapi import FastAPI, Request, WebSocket | from fastapi import FastAPI, Request, WebSocket | ||||||
| import uvicorn | import uvicorn | ||||||
| import redis | import redis | ||||||
|  | import json | ||||||
|  | from dotenv import load_dotenv | ||||||
|  | from stt import get_transcription | ||||||
|  | import tempfile | ||||||
|  | 
 | ||||||
|  | # Load environment variables | ||||||
|  | load_dotenv() | ||||||
| 
 | 
 | ||||||
| app = FastAPI() | app = FastAPI() | ||||||
| 
 | 
 | ||||||
|  | @ -10,7 +17,7 @@ r = redis.Redis(host='localhost', port=6379, db=0) | ||||||
| @app.post("/i/") | @app.post("/i/") | ||||||
| async def i(request: Request): | async def i(request: Request): | ||||||
|     message = await request.json() |     message = await request.json() | ||||||
|      | 
 | ||||||
|     client_host = request.client.host  # Get the client's IP address |     client_host = request.client.host  # Get the client's IP address | ||||||
| 
 | 
 | ||||||
|     message = f""" |     message = f""" | ||||||
|  | @ -26,5 +33,47 @@ async def i(request: Request): | ||||||
|         "content": message |         "content": message | ||||||
|     }) |     }) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | @app.websocket("/a") | ||||||
|  | async def a(ws: WebSocket): | ||||||
|  |     await ws.accept() | ||||||
|  |     audio_file = bytearray() | ||||||
|  |     mime_type = None | ||||||
|  | 
 | ||||||
|  |     try: | ||||||
|  |         while True: | ||||||
|  |             message = await ws.receive() | ||||||
|  | 
 | ||||||
|  |             if message['type'] == 'websocket.disconnect': | ||||||
|  |                 break | ||||||
|  | 
 | ||||||
|  |             if message['type'] == 'websocket.receive': | ||||||
|  |                 if 'text' in message: | ||||||
|  |                     control_message = json.loads(message['text']) | ||||||
|  |                     if control_message.get('action') == 'command' and control_message.get('state') == 'start' and 'mimeType' in control_message: | ||||||
|  |                         # This indicates the start of a new audio file | ||||||
|  |                         mime_type = control_message.get('mimeType') | ||||||
|  |                     elif control_message.get('action') == 'command' and control_message.get('state') == 'end': | ||||||
|  |                         # This indicates the end of the audio file | ||||||
|  |                         # Process the complete audio file here | ||||||
|  |                         transcription = get_transcription(audio_file, mime_type) | ||||||
|  |                         await ws.send_json({"transcript": transcription}) | ||||||
|  |                          | ||||||
|  |                         print("SENT TRANSCRIPTION!") | ||||||
|  | 
 | ||||||
|  |                         # Reset the bytearray for the next audio file | ||||||
|  |                         audio_file = bytearray() | ||||||
|  |                         mime_type = None | ||||||
|  |                 elif 'bytes' in message: | ||||||
|  |                     # If it's not a control message, it's part of the audio file | ||||||
|  |                     audio_file.extend(message['bytes']) | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"WebSocket connection closed with exception: {e}") | ||||||
|  |     finally: | ||||||
|  |         await ws.close() | ||||||
|  |         print("WebSocket connection closed") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     uvicorn.run(app, host="0.0.0.0", port=8000) |     with tempfile.TemporaryDirectory(): | ||||||
|  |         uvicorn.run(app, host="0.0.0.0", port=8000) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,55 @@ | ||||||
|  | from datetime import datetime | ||||||
|  | import os | ||||||
|  | import contextlib | ||||||
|  | import tempfile | ||||||
|  | import ffmpeg | ||||||
|  | import subprocess | ||||||
|  | 
 | ||||||
|  | def convert_mime_type_to_format(mime_type: str) -> str: | ||||||
|  |     if mime_type == "audio/x-wav": | ||||||
|  |         return "wav" | ||||||
|  |     if mime_type == "audio/webm": | ||||||
|  |         return "webm" | ||||||
|  | 
 | ||||||
|  |     return mime_type | ||||||
|  | 
 | ||||||
|  | @contextlib.contextmanager | ||||||
|  | def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str: | ||||||
|  |     temp_dir = tempfile.gettempdir() | ||||||
|  | 
 | ||||||
|  |     # Create a temporary file with the appropriate extension | ||||||
|  |     input_ext = convert_mime_type_to_format(mime_type) | ||||||
|  |     input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}") | ||||||
|  |     with open(input_path, 'wb') as f: | ||||||
|  |         f.write(audio) | ||||||
|  | 
 | ||||||
|  |     # Export to wav | ||||||
|  |     output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav") | ||||||
|  |     ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run() | ||||||
|  | 
 | ||||||
|  |     print(f"Temporary file path: {output_path}") | ||||||
|  | 
 | ||||||
|  |     try: | ||||||
|  |         yield output_path | ||||||
|  |     finally: | ||||||
|  |         os.remove(input_path) | ||||||
|  |         #os.remove(output_path) | ||||||
|  | 
 | ||||||
|  | def run_command(command): | ||||||
|  |     result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) | ||||||
|  |     return result.stdout, result.stderr | ||||||
|  | 
 | ||||||
|  | def get_transcription(audio_bytes: bytearray, mime_type): | ||||||
|  |     with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path: | ||||||
|  |         model_path = os.getenv("WHISPER_MODEL_PATH") | ||||||
|  |         if not model_path: | ||||||
|  |             raise EnvironmentError("WHISPER_MODEL_PATH environment variable is not set.") | ||||||
|  | 
 | ||||||
|  |         output, error = run_command([ | ||||||
|  |             os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release', 'whisper-rust'), | ||||||
|  |             '--model-path', model_path, | ||||||
|  |             '--file-path', wav_file_path | ||||||
|  |         ]) | ||||||
|  | 
 | ||||||
|  |         print("Exciting transcription result:", output) | ||||||
|  |         return output | ||||||
										
											Binary file not shown.
										
									
								
							|  | @ -0,0 +1 @@ | ||||||
|  | target | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -0,0 +1,14 @@ | ||||||
|  | [package] | ||||||
|  | name = "whisper-rust" | ||||||
|  | version = "0.1.0" | ||||||
|  | edition = "2021" | ||||||
|  | 
 | ||||||
|  | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||||
|  | 
 | ||||||
|  | [dependencies] | ||||||
|  | anyhow = "1.0.79" | ||||||
|  | clap = { version = "4.4.18", features = ["derive"] } | ||||||
|  | cpal = "0.15.2" | ||||||
|  | hound = "3.5.1" | ||||||
|  | whisper-rs = "0.10.0" | ||||||
|  | whisper-rs-sys = "0.8.0" | ||||||
|  | @ -0,0 +1,34 @@ | ||||||
|  | mod transcribe; | ||||||
|  | 
 | ||||||
|  | use clap::Parser; | ||||||
|  | use std::path::PathBuf; | ||||||
|  | use transcribe::transcribe; | ||||||
|  | 
 | ||||||
|  | #[derive(Parser, Debug)] | ||||||
|  | #[command(author, version, about, long_about = None)] | ||||||
|  | struct Args { | ||||||
|  |     /// This is the model for Whisper STT
 | ||||||
|  |     #[arg(short, long, value_parser, required = true)] | ||||||
|  |     model_path: PathBuf, | ||||||
|  |     
 | ||||||
|  |     /// This is the wav audio file that will be converted from speech to text
 | ||||||
|  |     #[arg(short, long, value_parser, required = true)] | ||||||
|  |     file_path: Option<PathBuf>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn main() { | ||||||
|  | 
 | ||||||
|  |     let args = Args::parse(); | ||||||
|  | 
 | ||||||
|  |     let file_path = match args.file_path { | ||||||
|  |         Some(fp) => fp, | ||||||
|  |         None => panic!("No file path provided") | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let result = transcribe(&args.model_path, &file_path); | ||||||
|  | 
 | ||||||
|  |     match result { | ||||||
|  |         Ok(transcription) => print!("{}", transcription), | ||||||
|  |         Err(e) => panic!("Error: {}", e), | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,64 @@ | ||||||
|  | use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; | ||||||
|  | use std::path::PathBuf; | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | /// Transcribes the given audio file using the whisper-rs library.
 | ||||||
|  | ///
 | ||||||
|  | /// # Arguments
 | ||||||
|  | /// * `model_path` - Path to Whisper model file
 | ||||||
|  | /// * `file_path` - A string slice that holds the path to the audio file to be transcribed.
 | ||||||
|  | ///
 | ||||||
|  | /// # Returns
 | ||||||
|  | ///
 | ||||||
|  | /// A Result containing a String with the transcription if successful, or an error message if not.
 | ||||||
|  | pub fn transcribe(model_path: &PathBuf, file_path: &PathBuf) -> Result<String, String> { | ||||||
|  | 
 | ||||||
|  |     let model_path_str = model_path.to_str().expect("Not valid model path"); | ||||||
|  |     // Load a context and model
 | ||||||
|  |     let ctx = WhisperContext::new_with_params( | ||||||
|  |         model_path_str, // Replace with the actual path to the model
 | ||||||
|  |         WhisperContextParameters::default(), | ||||||
|  |     ) | ||||||
|  |     .map_err(|_| "failed to load model")?; | ||||||
|  | 
 | ||||||
|  |     // Create a state
 | ||||||
|  |     let mut state = ctx.create_state().map_err(|_| "failed to create state")?; | ||||||
|  | 
 | ||||||
|  |     // Create a params object
 | ||||||
|  |     // Note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
 | ||||||
|  |     let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); | ||||||
|  | 
 | ||||||
|  |     // Edit parameters as needed
 | ||||||
|  |     params.set_n_threads(1); // Set the number of threads to use
 | ||||||
|  |     params.set_translate(true); // Enable translation
 | ||||||
|  |     params.set_language(Some("en")); // Set the language to translate to English
 | ||||||
|  |     // Disable printing to stdout
 | ||||||
|  |     params.set_print_special(false); | ||||||
|  |     params.set_print_progress(false); | ||||||
|  |     params.set_print_realtime(false); | ||||||
|  |     params.set_print_timestamps(false); | ||||||
|  | 
 | ||||||
|  |     // Load the audio file
 | ||||||
|  |     let audio_data = std::fs::read(file_path) | ||||||
|  |         .map_err(|e| format!("failed to read audio file: {}", e))? | ||||||
|  |         .chunks_exact(2) | ||||||
|  |         .map(|chunk| i16::from_ne_bytes([chunk[0], chunk[1]])) | ||||||
|  |         .collect::<Vec<i16>>(); | ||||||
|  | 
 | ||||||
|  |     // Convert the audio data to the required format (16KHz mono i16 samples)
 | ||||||
|  |     let audio_data = whisper_rs::convert_integer_to_float_audio(&audio_data); | ||||||
|  | 
 | ||||||
|  |     // Run the model
 | ||||||
|  |     state.full(params, &audio_data[..]).map_err(|_| "failed to run model")?; | ||||||
|  | 
 | ||||||
|  |     // Fetch the results
 | ||||||
|  |     let num_segments = state.full_n_segments().map_err(|_| "failed to get number of segments")?; | ||||||
|  |     let mut transcription = String::new(); | ||||||
|  |     for i in 0..num_segments { | ||||||
|  |         let segment = state.full_get_segment_text(i).map_err(|_| "failed to get segment")?; | ||||||
|  |         transcription.push_str(&segment); | ||||||
|  |         transcription.push('\n'); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     Ok(transcription) | ||||||
|  | } | ||||||
|  | @ -0,0 +1,40 @@ | ||||||
|  | import argparse | ||||||
|  | import asyncio | ||||||
|  | import websockets | ||||||
|  | import os | ||||||
|  | import json | ||||||
|  | 
 | ||||||
|  | # Define the function to send audio file in chunks | ||||||
|  | async def send_audio_in_chunks(file_path, chunk_size=4096): | ||||||
|  |     async with websockets.connect("ws://localhost:8000/a") as websocket: | ||||||
|  |         # Send the start command with mime type | ||||||
|  |         await websocket.send(json.dumps({"action": "command", "state": "start", "mimeType": "audio/webm"})) | ||||||
|  | 
 | ||||||
|  |         # Open the file in binary mode and send in chunks | ||||||
|  |         with open(file_path, 'rb') as audio_file: | ||||||
|  |             chunk = audio_file.read(chunk_size) | ||||||
|  |             while chunk: | ||||||
|  |                 await websocket.send(chunk) | ||||||
|  |                 chunk = audio_file.read(chunk_size) | ||||||
|  | 
 | ||||||
|  |         # Send the end command | ||||||
|  |         await websocket.send(json.dumps({"action": "command", "state": "end"})) | ||||||
|  | 
 | ||||||
|  |         # Receive a json message and then close the connection | ||||||
|  |         message = await websocket.recv() | ||||||
|  |         print("Received message:", json.loads(message)) | ||||||
|  |         await websocket.close() | ||||||
|  | 
 | ||||||
|  | # Parse command line arguments | ||||||
|  | parser = argparse.ArgumentParser(description="Send a webm audio file to the /a websocket endpoint and print the responses.") | ||||||
|  | parser.add_argument("file_path", help="The path to the webm audio file to send.") | ||||||
|  | args = parser.parse_args() | ||||||
|  | 
 | ||||||
|  | # Check if the file exists | ||||||
|  | if not os.path.isfile(args.file_path): | ||||||
|  |     print(args.file_path) | ||||||
|  |     print("Error: The file does not exist.") | ||||||
|  |     exit(1) | ||||||
|  | 
 | ||||||
|  | # Run the asyncio event loop | ||||||
|  | asyncio.get_event_loop().run_until_complete(send_audio_in_chunks(args.file_path)) | ||||||
|  | @ -2,4 +2,7 @@ git+https://github.com/KillianLucas/open-interpreter.git | ||||||
| redis==5.0.1 | redis==5.0.1 | ||||||
| fastapi==0.109.0 | fastapi==0.109.0 | ||||||
| uvicorn==0.27.0.post1 | uvicorn==0.27.0.post1 | ||||||
| websockets==12.0 | websockets==12.0 | ||||||
|  | pydub==0.25.1 | ||||||
|  | numpy==1.26.3 | ||||||
|  | python-dotenv==1.0.1 | ||||||
		Loading…
	
		Reference in New Issue
	
	 Zohaib Rauf
						Zohaib Rauf