专栏名称: GitHubStore
分享有意思的开源项目
目录
相关文章推荐
地刊速览  ·  EPSL:古太平洋的缺氧事件 ·  3 小时前  
地刊速览  ·  EPSL:古太平洋的缺氧事件 ·  3 小时前  
51好读  ›  专栏  ›  GitHubStore

ASR 速度革新!将 Whisper 推理生成速度提高 150%

GitHubStore  · 公众号  ·  · 2024-08-14 08:58

正文

项目简介

ASR 速度革新!将 Whisper 推理生成速度提高 150% ⚡️ 同时带来最小性能损耗的 Medusa Heads 加持的 whisper-medusa 开源 🔥

Medusa 是一个加速 LLM 推理速度的框架,可以与任意微调模型整合,提速 2.2~3.6x 的推理速度




Whisper 是一种高级的编码器-解码器模型,用于语音转录和翻译,通过编码和解码阶段处理音频。鉴于其庞大的规模和缓慢的推理速度,已经提出了诸如 Faster-Whisper 和推测性解码等优化策略来提高性能。我们的 Medusa 模型在 Whisper 的基础上通过每迭代预测多个令牌,显著提高了速度,同时在 WER 上略有下降。我们使用 LibriSpeech 数据集对模型进行训练和评估,证明了与原始 Whisper 模型相比,具有相同比例准确性的强性能速度改进。


Whisper Medusa 架构



培训和评估详情

Whisper Medusa基于带有 10 个美杜莎头的 Whisper 大型模型。它在 LibriSpeech 数据集上进行训练以执行音频翻译。美杜莎头针对英语进行了优化,因此为了获得最佳性能和速度提升,请仅使用英语音频。

平均而言,Whisper Medusa 的生成速度比 Whisper vanilla 快 1.5 倍,WER 相同(分别为 4.2%和 4%)。


Whisper Medusa speedup compared to Whisper vanilla.
Whisper Medusa 与 Whisper vanilla 的加速比。



安装

从创建虚拟环境并激活它开始:

conda create -n whisper-medusa python=3.11 -yconda activate whisper-medusapip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118

然后安装软件包:

git clone https://github.com/aiola-lab/whisper-medusa.gitcd whisper-medusapip install -e .


使用方法

推理可以通过以下代码完成:

import torchimport torchaudio
from whisper_medusa import WhisperMedusaModelfrom transformers import WhisperProcessor
model_name = "aiola/whisper-medusa-v1"model = WhisperMedusaModel.from_pretrained(model_name)processor = WhisperProcessor.from_pretrained(model_name)
path_to_audio = "path/to/audio.wav"SAMPLING_RATE = 16000language = "en"device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_speech, sr = torchaudio.load(path_to_audio)if input_speech.shape[0] > 1: # If stereo, average the channels input_speech = input_speech.mean(dim=0, keepdim=True)
if sr != SAMPLING_RATE: input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)
input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_featuresinput_features = input_features.to(device)
model = model.to(device)model_output = model.generate( input_features, language=language,)predict_ids = model_output[0]pred = processor.decode(predict_ids, skip_special_tokens=True)print(pred)

模型评估

为了评估模型,我们假设有一个 csv 文件,包含以下列:

  • audio : 音频文件的路径。

  • sentence : 对应的转录。

  • language : 音频文件的语言。


然后运行以下命令:

python whisper_medusa/eval_whisper_medusa.py \--model-name /path/to/model \--data-path /path/to/data \--out-file-path /path/to/output \--language en

参数描述:参数翻译直接给出,无需附加文本

  • model-name : 当地模型的路径 / Huggingface 存储库。

  • data-path : 数据的路径。

  • out-file-path : 输出文件的路径。







请到「今天看啥」查看全文