Merge branch 'main' into feature/debug-logging

# Conflicts:
#	OS/01/conversations/user.json
#	OS/01/device.py
commit 4ea0d0f841
@@ -1,3 +1,5 @@
+ggml-*.bin
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -144,19 +144,23 @@ async def websocket_communication(WS_URL):
         logging.info("Press the spacebar to start/stop recording. Press ESC to exit.")
         asyncio.create_task(message_sender(websocket))
 
-        message_so_far = {"role": None, "type": None, "format": None, "content": None}
+        initial_message = {"role": None, "type": None, "format": None, "content": None}
+        message_so_far = initial_message
 
         while True:
             message = await websocket.recv()
 
-            logging.debug(f"Got this message from the server: {type(message)} {message}")
+            logging.info(f"Got this message from the server: {type(message)} {message}")
 
             if type(message) == str:
                 message = json.loads(message)
 
+            if message.get("end"):
+                logging.info(f"Complete message from the server: {message_so_far}")
+                message_so_far = initial_message
+
             if "content" in message:
                 print(message['content'], end="", flush=True)
-                if any(message_so_far[key] != message[key] for key in message_so_far):
+                if any(message_so_far[key] != message[key] for key in message_so_far if key != "content"):
                     message_so_far = message
                 else:
                     message_so_far["content"] += message["content"]
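The accumulation protocol in this hunk is worth spelling out: chunks stream in one at a time; a chunk whose identifying fields (everything except `content`) differ from the accumulator replaces it, an otherwise-matching chunk has its `content` appended, and an `{"end": True}` sentinel logs the assembled message and resets the accumulator. Note that `message_so_far = initial_message` aliases the dict rather than copying it; the code relies on rebinding (`message_so_far = message`) before any in-place append. A minimal self-contained sketch of that protocol (the example chunks are invented for illustration, not taken from the commit):

```python
# Sketch of the accumulation protocol from the hunk above.
# The `chunks` list is hypothetical illustration data.
initial_message = {"role": None, "type": None, "format": None, "content": None}
message_so_far = initial_message

chunks = [
    {"role": "assistant", "type": "message", "format": None, "content": "Hel"},
    {"role": "assistant", "type": "message", "format": None, "content": "lo!"},
    {"end": True},
]

for message in chunks:
    if message.get("end"):
        print(f"Complete message: {message_so_far}")
        message_so_far = initial_message
        continue
    if "content" in message:
        # Start a new accumulator if any identifying field changed; else append.
        if any(message_so_far[key] != message[key] for key in message_so_far if key != "content"):
            message_so_far = message
        else:
            message_so_far["content"] += message["content"]

# Prints: Complete message: {'role': 'assistant', 'type': 'message',
#                            'format': None, 'content': 'Hello!'}
```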
@@ -0,0 +1,10 @@
+# Generated by Cargo
+# will have compiled files and executables
+debug/
+target/
+
+# These are backup files generated by rustfmt
+**/*.rs.bk
+
+# MSVC Windows builds of rustc generate these, which store debugging information
+*.pdb
File diff suppressed because it is too large
@@ -0,0 +1,14 @@
+[package]
+name = "whisper-rust"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+anyhow = "1.0.79"
+clap = { version = "4.4.18", features = ["derive"] }
+cpal = "0.15.2"
+hound = "3.5.1"
+whisper-rs = "0.10.0"
+whisper-rs-sys = "0.8.0"
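Of the dependencies pinned here, the code in this commit exercises only `clap` (argument parsing in main.rs) and `whisper-rs`/`whisper-rs-sys` (the whisper.cpp bindings in transcribe.rs); `anyhow`, `cpal` (audio capture), and `hound` (WAV reading) are declared but unused by the files shown, presumably staged for later work.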
@@ -0,0 +1,7 @@
+# Setup
+
+To rebuild the `whisper-rust` executable, do the following:
+
+1. Install [Rust](https://www.rust-lang.org/tools/install), cmake, and the Python dependencies with `pip install -r requirements.txt`.
+2. Go to **core/stt** and run `cargo build --release`.
+3. Move the `whisper-rust` executable from `target/release` to this directory.
@@ -0,0 +1,34 @@
+mod transcribe;
+
+use clap::Parser;
+use std::path::PathBuf;
+use transcribe::transcribe;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// This is the model for Whisper STT
+    #[arg(short, long, value_parser, required = true)]
+    model_path: PathBuf,
+
+    /// This is the wav audio file that will be converted from speech to text
+    #[arg(short, long, value_parser, required = true)]
+    file_path: Option<PathBuf>,
+}
+
+fn main() {
+
+    let args = Args::parse();
+
+    let file_path = match args.file_path {
+        Some(fp) => fp,
+        None => panic!("No file path provided")
+    };
+
+    let result = transcribe(&args.model_path, &file_path);
+
+    match result {
+        Ok(transcription) => print!("{}", transcription),
+        Err(e) => panic!("Error: {}", e),
+    }
+}
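A note on the CLI surface: clap's derive macros turn these fields into `-m/--model-path` and `-f/--file-path` flags (underscores become hyphens), so the expected invocation is along the lines of `whisper-rust --model-path /path/to/model.bin --file-path input.wav` (paths illustrative). And since `file_path` is declared `required = true`, clap guarantees it is present after parsing, which makes the `Option<PathBuf>` wrapper and the `None => panic!` arm effectively dead code.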
@@ -0,0 +1,64 @@
+use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
+use std::path::PathBuf;
+
+
+/// Transcribes the given audio file using the whisper-rs library.
+///
+/// # Arguments
+/// * `model_path` - Path to Whisper model file
+/// * `file_path` - A string slice that holds the path to the audio file to be transcribed.
+///
+/// # Returns
+///
+/// A Result containing a String with the transcription if successful, or an error message if not.
+pub fn transcribe(model_path: &PathBuf, file_path: &PathBuf) -> Result<String, String> {
+
+    let model_path_str = model_path.to_str().expect("Not valid model path");
+    // Load a context and model
+    let ctx = WhisperContext::new_with_params(
+        model_path_str, // Replace with the actual path to the model
+        WhisperContextParameters::default(),
+    )
+    .map_err(|_| "failed to load model")?;
+
+    // Create a state
+    let mut state = ctx.create_state().map_err(|_| "failed to create state")?;
+
+    // Create a params object
+    // Note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
+    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
+
+    // Edit parameters as needed
+    params.set_n_threads(1); // Set the number of threads to use
+    params.set_translate(true); // Enable translation
+    params.set_language(Some("en")); // Set the language to translate to English
+    // Disable printing to stdout
+    params.set_print_special(false);
+    params.set_print_progress(false);
+    params.set_print_realtime(false);
+    params.set_print_timestamps(false);
+
+    // Load the audio file
+    let audio_data = std::fs::read(file_path)
+        .map_err(|e| format!("failed to read audio file: {}", e))?
+        .chunks_exact(2)
+        .map(|chunk| i16::from_ne_bytes([chunk[0], chunk[1]]))
+        .collect::<Vec<i16>>();
+
+    // Convert the audio data to the required format (16KHz mono i16 samples)
+    let audio_data = whisper_rs::convert_integer_to_float_audio(&audio_data);
+
+    // Run the model
+    state.full(params, &audio_data[..]).map_err(|_| "failed to run model")?;
+
+    // Fetch the results
+    let num_segments = state.full_n_segments().map_err(|_| "failed to get number of segments")?;
+    let mut transcription = String::new();
+    for i in 0..num_segments {
+        let segment = state.full_get_segment_text(i).map_err(|_| "failed to get segment")?;
+        transcription.push_str(&segment);
+        transcription.push('\n');
+    }
+
+    Ok(transcription)
+}
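One caveat worth flagging: `std::fs::read` loads the entire file, so the WAV container's header bytes are reinterpreted as `i16` samples along with the audio, and nothing here resamples or downmixes. The code is serviceable only because the caller (see the `stt.py` changes below) pre-converts input to 16 kHz mono `pcm_s16le` before invoking this binary; the `hound` crate pinned in Cargo.toml could parse the WAV header properly but is not used by the code shown.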
Binary file not shown.
@@ -157,7 +157,7 @@ async def listener():
                 messages = json.load(file)
             messages.append(message)
             with open(conversation_history_path, 'w') as file:
-                json.dump(messages, file)
+                json.dump(messages, file, indent=4)
 
             accumulated_text = ""
 
@@ -197,7 +197,7 @@ async def listener():
                 await from_user.put(temp_message)
 
                 with open(conversation_history_path, 'w') as file:
-                    json.dump(interpreter.messages, file)
+                    json.dump(interpreter.messages, file, indent=4)
 
                 logging.info("New user message received. Breaking.")
                 break
@@ -206,11 +206,14 @@ async def listener():
             if not from_computer.empty():
 
                 with open(conversation_history_path, 'w') as file:
-                    json.dump(interpreter.messages, file)
+                    json.dump(interpreter.messages, file, indent=4)
 
                 logging.info("New computer message received. Breaking.")
                 break
 
+            else:
+                with open(conversation_history_path, 'w') as file:
+                    json.dump(interpreter.messages, file, indent=4)
 
 async def stream_or_play_tts(sentence):
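A side effect worth noting: writing the history with `indent=4` makes the conversation JSON human-readable and diff-friendly, and reformatting an existing file this way is plausibly what produced the merge conflict in `OS/01/conversations/user.json` recorded in the commit message.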
@@ -1,7 +1,9 @@
 ### SETTINGS
 
 # If ALL_LOCAL is False, we'll use OpenAI's services
+# If setting ALL_LOCAL to true, set the path to the WHISPER local model
 export ALL_LOCAL=False
+# export WHISPER_MODEL_PATH=...
 # export OPENAI_API_KEY=sk-...
 
 # If SERVER_START, this is where we'll serve the server.
@@ -33,18 +35,51 @@ fi
 
 ### START
 
 # DEVICE
 
-if [[ "$DEVICE_START" == "True" ]]; then
+start_device() {
     echo "Starting device..."
     python device.py &
+    DEVICE_PID=$!
+    echo "Device started as process $DEVICE_PID"
+}
+
+# Function to start server
+start_server() {
+    echo "Starting server..."
+    python server.py &
+    SERVER_PID=$!
+    echo "Server started as process $SERVER_PID"
+}
+
+stop_processes() {
+    if [[ -n $DEVICE_PID ]]; then
+        echo "Stopping device..."
+        kill $DEVICE_PID
+    fi
+    if [[ -n $SERVER_PID ]]; then
+        echo "Stopping server..."
+        kill $SERVER_PID
+    fi
+}
+
+# Trap SIGINT and SIGTERM to stop processes when the script is terminated
+trap stop_processes SIGINT SIGTERM
+
+# DEVICE
+# Start device if DEVICE_START is True
+if [[ "$DEVICE_START" == "True" ]]; then
+    start_device
 fi
 
 # SERVER
 
+# Start server if SERVER_START is True
 if [[ "$SERVER_START" == "True" ]]; then
-    python server.py &
+    start_server
 fi
 
+# Wait for device and server processes to exit
+wait $DEVICE_PID
+wait $SERVER_PID
+
 # TTS, STT
 
 # (todo)
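Net effect of this rework: both child processes now run under recorded PIDs, the trailing `wait` calls keep the script in the foreground, and Ctrl-C (SIGINT) or SIGTERM fires `stop_processes`, which kills whichever of the device and server is still running instead of orphaning them.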
OS/01/stt.py (33 changed lines)
@@ -48,6 +48,28 @@ def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
         os.remove(input_path)
         os.remove(output_path)
 
+def run_command(command):
+    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    return result.stdout, result.stderr
+
+def get_transcription_file(wav_file_path: str):
+    model_path = os.getenv("WHISPER_MODEL_PATH")
+    if not model_path:
+        raise EnvironmentError("WHISPER_MODEL_PATH environment variable is not set.")
+
+    output, error = run_command([
+        os.path.join(os.path.dirname(__file__), 'local_stt', 'whisper-rust', 'whisper-rust'),
+        '--model-path', model_path,
+        '--file-path', wav_file_path
+    ])
+
+    print("Exciting transcription result:", output)
+    return output
+
+def get_transcription_bytes(audio_bytes: bytearray, mime_type):
+    with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path:
+        return get_transcription_file(wav_file_path)
+
 def stt_bytes(audio_bytes: bytearray, mime_type="audio/wav"):
     with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path:
         return stt_wav(wav_file_path)
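Taken together, these helpers shell out to the Rust binary built above. A hypothetical driver sketch, with placeholder paths (and assuming `export_audio_to_wav_ffmpeg` behaves as a context manager, e.g. via `@contextlib.contextmanager`, despite its `-> str` annotation):

```python
# Hypothetical usage of the helpers added above; both paths are placeholders.
import os

# get_transcription_file() raises EnvironmentError if this is unset.
os.environ["WHISPER_MODEL_PATH"] = "/models/ggml-base.en.bin"

from stt import get_transcription_file

# The WAV should already be 16 kHz mono pcm_s16le -- the format the
# whisper-rust binary expects (stt_wav() below performs that conversion).
text = get_transcription_file("/tmp/sample.wav")
print(text)
```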
@@ -69,8 +91,15 @@ def stt_wav(wav_file_path: str):
         logging.info(f"Transcription result: {transcript}")
         return transcript
     else:
-        # Local whisper here, given `wav_file_path`
-        pass
+        temp_dir = tempfile.gettempdir()
+        output_path = os.path.join(temp_dir, f"output_stt_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
+        ffmpeg.input(wav_file_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run()
+        try:
+            transcript = get_transcription_file(output_path)
+            print("Transcription result:", transcript)
+        finally:
+            os.remove(output_path)
+        return transcript
 
 def stt(input_data, mime_type="audio/wav"):
     if isinstance(input_data, str):
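This ffmpeg-python conversion is what makes the byte-level reader in `transcribe.rs` workable: `acodec='pcm_s16le', ac=1, ar='16k'` emits exactly the 16 kHz mono 16-bit little-endian samples the Rust side assumes (on a little-endian host, where `from_ne_bytes` matches the WAV byte order).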
@@ -22,11 +22,12 @@ sudo apt-get install portaudio19-dev libav-tools
 python -m pip install -r requirements.txt
 ```
 
-```bash
-cd OS/01
-```
+If you want to run local speech-to-text using Whisper, download the GGML Whisper model from [Huggingface](https://huggingface.co/ggerganov/whisper.cpp). Then, in `OS/01/start.sh`, set `ALL_LOCAL=True` and set `WHISPER_MODEL_PATH` to the path of the model.
 
 ## Usage
 
 ```bash
 cd OS/01
 bash start.sh
 ```