Added service for transcription
This commit is contained in:
		
							parent
							
								
									2593ef83c3
								
							
						
					
					
						commit
						3e10fb80fa
					
				|  | @ -1 +1,2 @@ | |||
| .env | ||||
| .DS_Store | ||||
|  | @ -0,0 +1,9 @@ | |||
| 
 | ||||
| # Setup | ||||
| 
 | ||||
| 1. Install Rust and Python dependencies | ||||
| 2. Go to core/stt and run `cargo build --release` | ||||
| 3. Download GGML Whisper model from https://huggingface.co/ggerganov/whisper.cpp | ||||
| 4. Copy .env.example to .env and put the path to model | ||||
| 5. Run `python core/i_endpoint.py` to start the server | ||||
| 6. Run `python core/test_cli.py PATH_TO_FILE` to test sending audio to service and getting transcription back over websocket | ||||
|  | @ -0,0 +1 @@ | |||
| WHISPER_MODEL_PATH=/path/to/ggml-tiny.en.bin | ||||
|  | @ -1,6 +1,13 @@ | |||
| from fastapi import FastAPI, Request, WebSocket | ||||
| import uvicorn | ||||
| import redis | ||||
| import json | ||||
| from dotenv import load_dotenv | ||||
| from stt import get_transcription | ||||
| import tempfile | ||||
| 
 | ||||
| # Load environment variables | ||||
| load_dotenv() | ||||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
|  | @ -26,5 +33,47 @@ async def i(request: Request): | |||
|         "content": message | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| @app.websocket("/a") | ||||
| async def a(ws: WebSocket): | ||||
|     await ws.accept() | ||||
|     audio_file = bytearray() | ||||
|     mime_type = None | ||||
| 
 | ||||
|     try: | ||||
|         while True: | ||||
|             message = await ws.receive() | ||||
| 
 | ||||
|             if message['type'] == 'websocket.disconnect': | ||||
|                 break | ||||
| 
 | ||||
|             if message['type'] == 'websocket.receive': | ||||
|                 if 'text' in message: | ||||
|                     control_message = json.loads(message['text']) | ||||
|                     if control_message.get('action') == 'command' and control_message.get('state') == 'start' and 'mimeType' in control_message: | ||||
|                         # This indicates the start of a new audio file | ||||
|                         mime_type = control_message.get('mimeType') | ||||
|                     elif control_message.get('action') == 'command' and control_message.get('state') == 'end': | ||||
|                         # This indicates the end of the audio file | ||||
|                         # Process the complete audio file here | ||||
|                         transcription = get_transcription(audio_file, mime_type) | ||||
|                         await ws.send_json({"transcript": transcription}) | ||||
|                          | ||||
|                         print("SENT TRANSCRIPTION!") | ||||
| 
 | ||||
|                         # Reset the bytearray for the next audio file | ||||
|                         audio_file = bytearray() | ||||
|                         mime_type = None | ||||
|                 elif 'bytes' in message: | ||||
|                     # If it's not a control message, it's part of the audio file | ||||
|                     audio_file.extend(message['bytes']) | ||||
|     except Exception as e: | ||||
|         print(f"WebSocket connection closed with exception: {e}") | ||||
|     finally: | ||||
|         await ws.close() | ||||
|         print("WebSocket connection closed") | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     with tempfile.TemporaryDirectory(): | ||||
|         uvicorn.run(app, host="0.0.0.0", port=8000) | ||||
|  | @ -0,0 +1,55 @@ | |||
| from datetime import datetime | ||||
| import os | ||||
| import contextlib | ||||
| import tempfile | ||||
| import ffmpeg | ||||
| import subprocess | ||||
| 
 | ||||
| def convert_mime_type_to_format(mime_type: str) -> str: | ||||
|     if mime_type == "audio/x-wav": | ||||
|         return "wav" | ||||
|     if mime_type == "audio/webm": | ||||
|         return "webm" | ||||
| 
 | ||||
|     return mime_type | ||||
| 
 | ||||
| @contextlib.contextmanager | ||||
| def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str: | ||||
|     temp_dir = tempfile.gettempdir() | ||||
| 
 | ||||
|     # Create a temporary file with the appropriate extension | ||||
|     input_ext = convert_mime_type_to_format(mime_type) | ||||
|     input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}") | ||||
|     with open(input_path, 'wb') as f: | ||||
|         f.write(audio) | ||||
| 
 | ||||
|     # Export to wav | ||||
|     output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav") | ||||
|     ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run() | ||||
| 
 | ||||
|     print(f"Temporary file path: {output_path}") | ||||
| 
 | ||||
|     try: | ||||
|         yield output_path | ||||
|     finally: | ||||
|         os.remove(input_path) | ||||
|         #os.remove(output_path) | ||||
| 
 | ||||
| def run_command(command): | ||||
|     result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) | ||||
|     return result.stdout, result.stderr | ||||
| 
 | ||||
| def get_transcription(audio_bytes: bytearray, mime_type): | ||||
|     with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path: | ||||
|         model_path = os.getenv("WHISPER_MODEL_PATH") | ||||
|         if not model_path: | ||||
|             raise EnvironmentError("WHISPER_MODEL_PATH environment variable is not set.") | ||||
| 
 | ||||
|         output, error = run_command([ | ||||
|             os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release', 'whisper-rust'), | ||||
|             '--model-path', model_path, | ||||
|             '--file-path', wav_file_path | ||||
|         ]) | ||||
| 
 | ||||
|         print("Exciting transcription result:", output) | ||||
|         return output | ||||
										
											Binary file not shown.
										
									
								
							|  | @ -0,0 +1 @@ | |||
| target | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -0,0 +1,14 @@ | |||
| [package] | ||||
| name = "whisper-rust" | ||||
| version = "0.1.0" | ||||
| edition = "2021" | ||||
| 
 | ||||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||
| 
 | ||||
| [dependencies] | ||||
| anyhow = "1.0.79" | ||||
| clap = { version = "4.4.18", features = ["derive"] } | ||||
| cpal = "0.15.2" | ||||
| hound = "3.5.1" | ||||
| whisper-rs = "0.10.0" | ||||
| whisper-rs-sys = "0.8.0" | ||||
|  | @ -0,0 +1,34 @@ | |||
| mod transcribe; | ||||
| 
 | ||||
| use clap::Parser; | ||||
| use std::path::PathBuf; | ||||
| use transcribe::transcribe; | ||||
| 
 | ||||
| #[derive(Parser, Debug)] | ||||
| #[command(author, version, about, long_about = None)] | ||||
| struct Args { | ||||
|     /// This is the model for Whisper STT
 | ||||
|     #[arg(short, long, value_parser, required = true)] | ||||
|     model_path: PathBuf, | ||||
|     
 | ||||
|     /// This is the wav audio file that will be converted from speech to text
 | ||||
|     #[arg(short, long, value_parser, required = true)] | ||||
|     file_path: Option<PathBuf>, | ||||
| } | ||||
| 
 | ||||
| fn main() { | ||||
| 
 | ||||
|     let args = Args::parse(); | ||||
| 
 | ||||
|     let file_path = match args.file_path { | ||||
|         Some(fp) => fp, | ||||
|         None => panic!("No file path provided") | ||||
|     }; | ||||
| 
 | ||||
|     let result = transcribe(&args.model_path, &file_path); | ||||
| 
 | ||||
|     match result { | ||||
|         Ok(transcription) => print!("{}", transcription), | ||||
|         Err(e) => panic!("Error: {}", e), | ||||
|     } | ||||
| } | ||||
|  | @ -0,0 +1,64 @@ | |||
| use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; | ||||
| use std::path::PathBuf; | ||||
| 
 | ||||
| 
 | ||||
| /// Transcribes the given audio file using the whisper-rs library.
 | ||||
| ///
 | ||||
| /// # Arguments
 | ||||
| /// * `model_path` - Path to Whisper model file
 | ||||
| /// * `file_path` - A string slice that holds the path to the audio file to be transcribed.
 | ||||
| ///
 | ||||
| /// # Returns
 | ||||
| ///
 | ||||
| /// A Result containing a String with the transcription if successful, or an error message if not.
 | ||||
| pub fn transcribe(model_path: &PathBuf, file_path: &PathBuf) -> Result<String, String> { | ||||
| 
 | ||||
|     let model_path_str = model_path.to_str().expect("Not valid model path"); | ||||
|     // Load a context and model
 | ||||
|     let ctx = WhisperContext::new_with_params( | ||||
|         model_path_str, // Replace with the actual path to the model
 | ||||
|         WhisperContextParameters::default(), | ||||
|     ) | ||||
|     .map_err(|_| "failed to load model")?; | ||||
| 
 | ||||
|     // Create a state
 | ||||
|     let mut state = ctx.create_state().map_err(|_| "failed to create state")?; | ||||
| 
 | ||||
|     // Create a params object
 | ||||
|     // Note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
 | ||||
|     let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); | ||||
| 
 | ||||
|     // Edit parameters as needed
 | ||||
|     params.set_n_threads(1); // Set the number of threads to use
 | ||||
|     params.set_translate(true); // Enable translation
 | ||||
|     params.set_language(Some("en")); // Set the language to translate to English
 | ||||
|     // Disable printing to stdout
 | ||||
|     params.set_print_special(false); | ||||
|     params.set_print_progress(false); | ||||
|     params.set_print_realtime(false); | ||||
|     params.set_print_timestamps(false); | ||||
| 
 | ||||
|     // Load the audio file
 | ||||
|     let audio_data = std::fs::read(file_path) | ||||
|         .map_err(|e| format!("failed to read audio file: {}", e))? | ||||
|         .chunks_exact(2) | ||||
|         .map(|chunk| i16::from_ne_bytes([chunk[0], chunk[1]])) | ||||
|         .collect::<Vec<i16>>(); | ||||
| 
 | ||||
|     // Convert the audio data to the required format (16KHz mono i16 samples)
 | ||||
|     let audio_data = whisper_rs::convert_integer_to_float_audio(&audio_data); | ||||
| 
 | ||||
|     // Run the model
 | ||||
|     state.full(params, &audio_data[..]).map_err(|_| "failed to run model")?; | ||||
| 
 | ||||
|     // Fetch the results
 | ||||
|     let num_segments = state.full_n_segments().map_err(|_| "failed to get number of segments")?; | ||||
|     let mut transcription = String::new(); | ||||
|     for i in 0..num_segments { | ||||
|         let segment = state.full_get_segment_text(i).map_err(|_| "failed to get segment")?; | ||||
|         transcription.push_str(&segment); | ||||
|         transcription.push('\n'); | ||||
|     } | ||||
| 
 | ||||
|     Ok(transcription) | ||||
| } | ||||
|  | @ -0,0 +1,40 @@ | |||
| import argparse | ||||
| import asyncio | ||||
| import websockets | ||||
| import os | ||||
| import json | ||||
| 
 | ||||
| # Define the function to send audio file in chunks | ||||
| async def send_audio_in_chunks(file_path, chunk_size=4096): | ||||
|     async with websockets.connect("ws://localhost:8000/a") as websocket: | ||||
|         # Send the start command with mime type | ||||
|         await websocket.send(json.dumps({"action": "command", "state": "start", "mimeType": "audio/webm"})) | ||||
| 
 | ||||
|         # Open the file in binary mode and send in chunks | ||||
|         with open(file_path, 'rb') as audio_file: | ||||
|             chunk = audio_file.read(chunk_size) | ||||
|             while chunk: | ||||
|                 await websocket.send(chunk) | ||||
|                 chunk = audio_file.read(chunk_size) | ||||
| 
 | ||||
|         # Send the end command | ||||
|         await websocket.send(json.dumps({"action": "command", "state": "end"})) | ||||
| 
 | ||||
|         # Receive a json message and then close the connection | ||||
|         message = await websocket.recv() | ||||
|         print("Received message:", json.loads(message)) | ||||
|         await websocket.close() | ||||
| 
 | ||||
| # Parse command line arguments | ||||
| parser = argparse.ArgumentParser(description="Send a webm audio file to the /a websocket endpoint and print the responses.") | ||||
| parser.add_argument("file_path", help="The path to the webm audio file to send.") | ||||
| args = parser.parse_args() | ||||
| 
 | ||||
| # Check if the file exists | ||||
| if not os.path.isfile(args.file_path): | ||||
|     print(args.file_path) | ||||
|     print("Error: The file does not exist.") | ||||
|     exit(1) | ||||
| 
 | ||||
| # Run the asyncio event loop | ||||
| asyncio.get_event_loop().run_until_complete(send_audio_in_chunks(args.file_path)) | ||||
|  | @ -3,3 +3,6 @@ redis==5.0.1 | |||
| fastapi==0.109.0 | ||||
| uvicorn==0.27.0.post1 | ||||
| websockets==12.0 | ||||
| pydub==0.25.1 | ||||
| numpy==1.26.3 | ||||
| python-dotenv==1.0.1 | ||||
		Loading…
	
		Reference in New Issue
	
	 Zohaib Rauf
						Zohaib Rauf