Fixed local whisper
This commit is contained in:
parent
166706d203
commit
4c82587db8
|
@ -16,11 +16,11 @@ import subprocess
|
||||||
|
|
||||||
class Stt:
|
class Stt:
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
service_directory = config["service_directory"]
|
self.service_directory = config["service_directory"]
|
||||||
install(service_directory)
|
install(self.service_directory)
|
||||||
|
|
||||||
def stt(self, audio_file_path):
|
def stt(self, audio_file_path):
|
||||||
return stt(audio_file_path)
|
return stt(self.service_directory, audio_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -109,14 +109,12 @@ def run_command(command):
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
return result.stdout, result.stderr
|
return result.stdout, result.stderr
|
||||||
|
|
||||||
def get_transcription_file(wav_file_path: str):
|
def get_transcription_file(service_directory, wav_file_path: str):
|
||||||
local_path = os.path.join(os.path.dirname(__file__), 'model')
|
local_path = os.path.join(service_directory, 'model')
|
||||||
whisper_rust_path = os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release')
|
whisper_rust_path = os.path.join(service_directory, 'whisper-rust', 'target', 'release')
|
||||||
model_name = os.getenv('WHISPER_MODEL_NAME')
|
model_name = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin')
|
||||||
if not model_name:
|
|
||||||
raise EnvironmentError("WHISPER_MODEL_NAME environment variable is not set.")
|
|
||||||
|
|
||||||
output, error = run_command([
|
output, _ = run_command([
|
||||||
os.path.join(whisper_rust_path, 'whisper-rust'),
|
os.path.join(whisper_rust_path, 'whisper-rust'),
|
||||||
'--model-path', os.path.join(local_path, model_name),
|
'--model-path', os.path.join(local_path, model_name),
|
||||||
'--file-path', wav_file_path
|
'--file-path', wav_file_path
|
||||||
|
@ -124,28 +122,16 @@ def get_transcription_file(wav_file_path: str):
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_transcription_bytes(audio_bytes: bytearray, mime_type):
|
|
||||||
with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path:
|
|
||||||
return get_transcription_file(wav_file_path)
|
|
||||||
|
|
||||||
def stt_bytes(audio_bytes: bytearray, mime_type="audio/wav"):
|
def stt_wav(service_directory, wav_file_path: str):
|
||||||
with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path:
|
|
||||||
return stt_wav(wav_file_path)
|
|
||||||
|
|
||||||
def stt_wav(wav_file_path: str):
|
|
||||||
temp_dir = tempfile.gettempdir()
|
temp_dir = tempfile.gettempdir()
|
||||||
output_path = os.path.join(temp_dir, f"output_stt_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
|
output_path = os.path.join(temp_dir, f"output_stt_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
|
||||||
ffmpeg.input(wav_file_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run()
|
ffmpeg.input(wav_file_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run()
|
||||||
try:
|
try:
|
||||||
transcript = get_transcription_file(output_path)
|
transcript = get_transcription_file(service_directory, output_path)
|
||||||
finally:
|
finally:
|
||||||
os.remove(output_path)
|
os.remove(output_path)
|
||||||
return transcript
|
return transcript
|
||||||
|
|
||||||
def stt(input_data, mime_type="audio/wav"):
|
def stt(service_directory, input_data):
|
||||||
if isinstance(input_data, str):
|
return stt_wav(service_directory, input_data)
|
||||||
return stt_wav(input_data)
|
|
||||||
elif isinstance(input_data, bytearray):
|
|
||||||
return stt_bytes(input_data, mime_type)
|
|
||||||
else:
|
|
||||||
raise ValueError("Input data should be either a path to a wav file (str) or audio bytes (bytearray)")
|
|
|
@ -68,9 +68,7 @@ def run_command(command):
|
||||||
def get_transcription_file(wav_file_path: str):
|
def get_transcription_file(wav_file_path: str):
|
||||||
local_path = os.path.join(os.path.dirname(__file__), 'local_service')
|
local_path = os.path.join(os.path.dirname(__file__), 'local_service')
|
||||||
whisper_rust_path = os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release')
|
whisper_rust_path = os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release')
|
||||||
model_name = os.getenv('WHISPER_MODEL_NAME')
|
model_name = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin')
|
||||||
if not model_name:
|
|
||||||
raise EnvironmentError("WHISPER_MODEL_NAME environment variable is not set.")
|
|
||||||
|
|
||||||
output, error = run_command([
|
output, error = run_command([
|
||||||
os.path.join(whisper_rust_path, 'whisper-rust'),
|
os.path.join(whisper_rust_path, 'whisper-rust'),
|
||||||
|
|
|
@ -40,7 +40,7 @@ def run(
|
||||||
|
|
||||||
if local:
|
if local:
|
||||||
tts_service = "piper"
|
tts_service = "piper"
|
||||||
llm_service = "llamafile"
|
# llm_service = "llamafile"
|
||||||
stt_service = "local-whisper"
|
stt_service = "local-whisper"
|
||||||
|
|
||||||
if not server_url:
|
if not server_url:
|
||||||
|
|
Loading…
Reference in New Issue