在自然语言处理领域,
重排序模型(Reranker Models)
扮演着重要角色。无论是检索增强生成、语义搜索,还是文本相似性评估,重排序模型都能提供关键支持。
而今天,我们要聊聊如何使用 Sentence Transformers v4 来训练和微调这些强大的模型。
Sentence Transformers 是一个强大的 Python 库,广泛应用于各种自然语言处理任务,包括但不限于检索增强生成、语义搜索、语义文本相似性、释义挖掘等。
最近,它迎来了 v4.0 版本更新,为重排序模型(也称为交叉编码器模型)带来了全新的训练方法,这与 v3.0 版本为嵌入模型引入的训练方法类似。
https://huggingface.co/blog/train-reranker
在今天的内容中,我们将深入探讨如何使用 Sentence Transformers v4 微调重排序模型,使其在你的数据上超越所有现有的选项。不仅如此,这种方法还可以从零开始训练出极其强大的新重排序模型。
重排序模型是什么?
重排序模型通常是基于
交叉编码器(Cross Encoder)
架构实现的。它的核心思想是把两段文本(比如一个问题和一个文档,或者两个句子)放在一起处理,通过一个共享的神经网络,最终给出一个分数,表示这两段文本的相关性。这和另一种常见的模型——Sentence Transformers(也叫双编码器或嵌入模型)很不一样。
Sentence Transformers 是把每段文本单独转换成向量,然后通过计算向量之间的距离来判断相似性。而重排序模型则是让两段文本“相互关注”,一起通过网络处理。这种设计让重排序模型在判断文本相关性时更加精准,但代价是计算速度会慢一些。
重排序模型适合“精挑细选”?
重排序模型虽然计算速度慢,但它在某些场景下非常有用。比如,当你已经有了一个初步的搜索结果列表(比如通过 Sentence Transformers 搜索出来的),但这个列表可能包含很多相似的内容,这时候就需要重排序模型来“精挑细选”,找出最相关的结果。
微调重排序模型:让模型更懂你
虽然重排序模型本身已经很强大,但它的“通用性”也带来了一个问题:它可能在很多领域都表现得不错,但在你最关心的领域却未必能达到最佳效果。比如,一个通用的重排序模型可能对各种话题都能处理,但如果你的业务是医学研究,它可能就不够精准。
训练重排序模型
训练一个强大的重排序模型就像搭建一座桥梁,需要多个关键组件的协同合作。
Hugging Face Datasets Hub 是一个强大的数据集资源库,提供了大量的公开数据集,这些数据集可以直接用于训练和评估重排序模型。许多数据集已经被标注为与 Sentence Transformers 兼容,这意味着它们可以直接用于训练重排序模型。
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
print(train_dataset)
"""
Dataset({
features: ['query', 'answer'],
num_rows: 100231
})
"""
当然也可以通过本地文件加载:
from datasets import load_dataset
dataset = load_dataset("csv", data_files="my_file.csv")
# or
dataset = load_dataset("json", data_files="my_file.json")
数据集格式必须与你选择的损失函数(以及模型)相匹配,否则训练过程可能会出错,或者模型无法正确学习。
损失函数通常需要特定数量的输入列。在数据集中,所有不命名为
“label”
或
“scores”
的列都被视为输入列。
Hard Negatives Mining
在训练重排序模型时,负样本的质量往往决定了模型的性能。负样本是指那些与查询不相关的文本片段,模型需要学会将这些片段的得分降低。负样本可以分为两类:软负样本和硬负样本。
-
软负样本(Soft Negatives)
:与查询完全无关的文本片段,也称为“容易的负样本”。例如,查询是“苹果公司在哪里成立的?”软负样本可能是“阿肯色州的 Cache River Bridge 是一座跨越 Cache River 的桥梁。”
-
硬负样本(Hard Negatives)
:看起来可能与查询相关,但实际上并不相关的文本片段。例如,查询是“苹果公司在哪里成立的?”硬负样本可能是“富士苹果是一种在 20 世纪 30 年代末开发的苹果品种,于 1962 年推向市场。”
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import mine_hard_negatives
# Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
train_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
print(train_dataset)
# Mine hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
train_dataset,
embedding_model,
num_negatives=5, # How many negatives per question-answer pair
range_min=10, # Skip the x most similar samples
range_max=100, # Consider only the x most similar samples
max_score=0.8, # Only consider samples with a similarity score of at most x
margin=0.1, # Similarity between query and negative samples should be x lower than query-positive similarity
sampling_strategy="top", # Randomly sample negatives from the range
batch_size=4096, # Use a batch size of 4096 for the embedding model
output_format="labeled-pair", # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
use_faiss=True, # Using FAISS is recommended to keep memory usage low (pip install faiss-gpu or pip install faiss-cpu)
)
print(hard_train_dataset)
print(hard_train_dataset[1])
损失函数是训练过程中的“指南针”,它告诉模型哪些方向是正确的,哪些是错误的。具体来说,损失函数的作用包括:
-
衡量模型性能
:损失函数通过计算模型的输出和真实标签之间的差异,来衡量模型的性能。差异越小,说明模型的性能越好。
-
指导优化过程
:损失函数的值会通过优化算法(如梯度下降)来调整模型的参数,使损失值最小化。不同的损失函数可能导致模型学习到不同的特征。
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import CachedMultipleNegativesRankingLoss
# Load a model to train/finetune
model = CrossEncoder("xlm-roberta-base", num_labels=1) # num_labels=1 is for rerankers
# Initialize the CachedMultipleNegativesRankingLoss, which requires pairs of
# related texts or triplets
loss = CachedMultipleNegativesRankingLoss(model)
# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/gooaq", split="train")
Training Arguments:训练关键参数
训练参数就像是训练过程中的“调音器”,它们可以帮助你控制训练的速度、稳定性和效果。通过调整这些参数,你可以:
-
加快训练速度
:通过合理设置学习率、批次大小等参数,可以显著提高训练效率。
-
优化模型性能
:通过调整训练轮数、学习率调度等参数,可以让模型更好地学习数据中的特征。
-
监控训练过程
:通过设置日志记录、评估间隔等参数,可以实时了解训练进度和模型性能。
from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments
args = CrossEncoderTrainingArguments(
# Required parameter:
output_dir="models/reranker-MiniLM-msmarco-v1",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True,
# Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # losses that use "in-batch negatives" benefit from no duplicates
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
save_total_limit=2,
logging_steps=100,
run_name="reranker-MiniLM-msmarco-v1", # Will be used in W&B if `wandb` is installed
)
多数据集训练(Multi-Dataset Training)
多数据集训练是提升通用模型性能的有效方法。通过在多个数据集上同时训练,模型可以学习到更广泛的数据特征,从而提高泛化能力。
CrossEncoderTrainer
支持多数据集训练,并且允许对每个数据集使用不同的损失函数。
-
准备数据集
:使用一个字典来存储多个数据集,字典的键是数据集的名称,值是
datasets.Dataset
实例。
-
选择损失函数
:如果需要对每个数据集使用不同的损失函数,可以准备一个字典,将数据集名称映射到对应的损失函数。
-
设置采样策略
:通过
CrossEncoderTrainingArguments
的
multi_dataset_batch_sampler
参数选择采样策略。