项目简介
ASR 速度革新!将 Whisper 推理生成速度提高 150% ⚡️ 同时带来最小性能损耗的 Medusa Heads 加持的 whisper-medusa 开源 🔥
Medusa 是一个加速 LLM 推理速度的框架,可以与任意微调模型整合,提速 2.2~3.6x 的推理速度
Whisper 是一种高级的编码器-解码器模型,用于语音转录和翻译,通过编码和解码阶段处理音频。鉴于其庞大的规模和缓慢的推理速度,已经提出了诸如 Faster-Whisper 和推测性解码等优化策略来提高性能。我们的 Medusa 模型在 Whisper 的基础上通过每迭代预测多个令牌,显著提高了速度,同时在 WER 上略有下降。我们使用 LibriSpeech 数据集对模型进行训练和评估,证明了与原始 Whisper 模型相比,具有相同比例准确性的强性能速度改进。
Whisper Medusa
架构
培训和评估详情
Whisper Medusa基于带有 10 个美杜莎头的 Whisper 大型模型。它在 LibriSpeech 数据集上进行训练以执行音频翻译。美杜莎头针对英语进行了优化,因此为了获得最佳性能和速度提升,请仅使用英语音频。
平均而言,Whisper Medusa 的生成速度比 Whisper vanilla 快 1.5 倍,WER 相同(分别为 4.2%和 4%)。
Whisper Medusa speedup compared to Whisper vanilla.
Whisper Medusa 与 Whisper vanilla 的加速比。
安装
从创建虚拟环境并激活它开始:
conda create -n whisper-medusa python=3.11 -y
conda activate whisper-medusa
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2
然后安装软件包:
git clone https://github.com/aiola-lab/whisper-medusa.git
cd whisper-medusa
pip install -e .
使用方法
推理可以通过以下代码完成:
import torch
import torchaudio
from whisper_medusa import WhisperMedusaModel
from transformers import WhisperProcessor
model_name = "aiola/whisper-medusa-v1"
model = WhisperMedusaModel.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)
path_to_audio = "path/to/audio.wav"
SAMPLING_RATE = 16000
language = "en"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_speech, sr = torchaudio.load(path_to_audio)
if input_speech.shape[0] > 1:
input_speech = input_speech.mean(dim=0, keepdim=True)
if sr != SAMPLING_RATE:
input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)
input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
input_features = input_features.to(device)
model = model.to(device)
model_output = model.generate(
input_features,
language=language,
)
predict_ids = model_output[0]
pred = processor.decode(predict_ids, skip_special_tokens=True)
print(pred)
模型评估
为了评估模型,我们假设有一个 csv 文件,包含以下列:
-
audio
: 音频文件的路径。
-
sentence
: 对应的转录。
-
language
: 音频文件的语言。
然后运行以下命令:
python whisper_medusa/eval_whisper_medusa.py \
--model-name /path/to/model \
--data-path /path/to/data \
--out-file-path /path/to/output \
--language en
参数描述:参数翻译直接给出,无需附加文本