
Fine-Tuning an Embedding Model with PyTorch's Hard Triplet Loss

机器学习研究组订阅 · WeChat Official Account · AI · 2024-12-05 19:28


This article walks through fine-tuning an embedding model with PyTorch and the triplet margin loss, focusing on implementation details and code examples. Triplet loss is a contrastive loss function: it optimizes the model by shrinking the distance between an anchor and a positive example while enlarging the distance between the anchor and a negative example.
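PyTorch exposes this loss directly as nn.TripletMarginLoss. The snippet below is only a minimal illustration, with random tensors standing in for real embeddings (the batch size and embedding dimension are example values):

import torch
import torch.nn as nn

# Triplet margin loss: max(d(anchor, positive) - d(anchor, negative) + margin, 0)
criterion = nn.TripletMarginLoss(margin=1.0, p=2)

anchor = torch.randn(16, 768)    # embeddings of the anchor samples
positive = torch.randn(16, 768)  # embeddings of same-class samples
negative = torch.randn(16, 768)  # embeddings of different-class samples

loss = criterion(anchor, positive, negative)
print(loss.item())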


Dataset Preparation and Processing


Embedding models are usually consumed through Sentence Transformers, whose encode() method handles raw text directly. For fine-tuning, however, we work with the Transformers library, so the text has to be converted into the token IDs and attention masks the model expects. Token IDs index words or sub-word pieces in the model's vocabulary, and attention masks stop the model from attending to padding tokens.
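As a quick illustration with the gte-base tokenizer introduced just below (the toy sentence and the tiny max_length are chosen only to make the padding visible):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")

enc = tokenizer("hello world", truncation=True, padding='max_length',
                max_length=8, return_tensors='pt')
print(enc['input_ids'])       # token IDs, padded with the [PAD] id up to max_length
print(enc['attention_mask'])  # 1 for real tokens, 0 for padding tokens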

This article uses the thenlper/gte-base model, so its tokenizer is needed to preprocess the text. The model is based on the BertModel architecture:

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

The Transformers library's AutoTokenizer and AutoModel simplify loading the model; there is no need to deal with the underlying architecture and configuration details by hand.

from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")

# Collect the raw texts and tokenize them
# (df_train / df_dev / df_test are pandas DataFrames with a 'content' column)
train_texts = [df_train.loc[i]['content'] for i in range(df_train.shape[0])]
dev_texts = [df_dev.loc[i]['content'] for i in range(df_dev.shape[0])]
test_texts = [df_test.loc[i]['content'] for i in range(df_test.shape[0])]

train_tokens = []
train_attention_masks = []
dev_tokens = []
dev_attention_masks = []
test_tokens = []
test_attention_masks = []

for sent in tqdm(train_texts):
    encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')
    train_tokens.append(encoding['input_ids'].squeeze(0))
    train_attention_masks.append(encoding['attention_mask'].squeeze(0))

for sent in tqdm(dev_texts):
    encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')
    dev_tokens.append(encoding['input_ids'].squeeze(0))
    dev_attention_masks.append(encoding['attention_mask'].squeeze(0))

for sent in tqdm(test_texts):
    encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')
    test_tokens.append(encoding['input_ids'].squeeze(0))
    test_attention_masks.append(encoding['attention_mask'].squeeze(0))

With the token IDs and attention masks in hand, we store them in a custom PyTorch dataset.

import random
from collections import defaultdict

import torch
from torch.utils.data import Dataset, DataLoader, Sampler, SequentialSampler

# Device for the tensors returned by the dataset (referenced in __getitem__ below)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CustomTripletDataset(Dataset):
    def __init__(self, tokens, attention_masks, labels):
        self.tokens = tokens
        self.attention_masks = attention_masks
        self.labels = torch.Tensor(labels)

        # Map each class label to the indices of its samples, for triplet sampling
        self.label_dict = defaultdict(list)
        for i in range(len(tokens)):
            self.label_dict[int(self.labels[i])].append(i)
        self.unique_classes = list(self.label_dict.keys())

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, index):
        ids = self.tokens[index].to(device)
        ams = self.attention_masks[index].to(device)
        y = self.labels[index].to(device)
        return ids, ams, y

Because training uses a triplet loss, positive and negative examples must be sampled from the dataset. The label_dict dictionary maps every class to the indices of its samples, which makes random sampling easy. A DataLoader serves the dataset:


 train_loader = DataLoader(train_dataset, batch_sampler=train_batch_sampler)
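For reference, the train_dataset used above can be built from the tokens, attention masks and class labels prepared earlier; the 'label' column name below is an assumption about the DataFrames, not something defined in the article:

# Assumption: each DataFrame has a 'label' column with integer class ids
train_labels = df_train['label'].tolist()
dev_labels = df_dev['label'].tolist()

train_dataset = CustomTripletDataset(train_tokens, train_attention_masks, train_labels)
dev_dataset = CustomTripletDataset(dev_tokens, dev_attention_masks, dev_labels)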


Here train_batch_sampler is a custom batch sampler:


import numpy as np


class CustomBatchSampler(SequentialSampler):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.unique_classes = sorted(dataset.unique_classes)
        self.label_dict = dataset.label_dict
        self.num_batches = len(self.dataset) // self.batch_size
        # Each batch holds 4 classes, so every class contributes batch_size // 4 samples
        self.class_size = self.batch_size // 4

    def __iter__(self):
        total_samples_used = 0
        # One weight per class; classes picked more often get a higher weight and
        # therefore a lower probability of being picked again
        # (note: this bookkeeping assumes class labels are consecutive integers starting at 0)
        weights = np.repeat(1, len(self.unique_classes))

        while total_samples_used < len(self.dataset):
            batch = []
            classes = []
            for _ in range(4):
                # Draw a class that is not already present in this batch
                next_selected_class = self._select_class(weights)
                while next_selected_class in classes:
                    next_selected_class = self._select_class(weights)
                weights[next_selected_class] += 1
                classes.append(next_selected_class)

                # Sample class_size indices (without replacement) from that class
                new_choices = self.label_dict[next_selected_class]
                remaining_samples = list(np.random.choice(new_choices,
                                                          min(self.class_size, len(new_choices)),
                                                          replace=False))
                batch.extend(remaining_samples)

            total_samples_used += len(batch)
            yield batch

    def _select_class(self, weights):
        # Inverse-weight sampling: under-used classes are more likely to be selected
        dist = 1 / weights
        dist = dist / np.sum(dist)
        selected = int(np.random.choice(self.unique_classes, p=dist))
        return selected

    def __len__(self):
        return self.num_batches

The custom batch sampler controls how each training batch is composed; this implementation guarantees that every batch contains 4 classes with 8 data points per class. A separate validation sampler keeps the validation batches identical across epochs.
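As a usage sketch (the batch size is an assumption, chosen so the arithmetic matches): a batch size of 32 yields exactly 4 classes per batch with 32 // 4 = 8 samples each. For the validation split, the same sampler pattern can be reused with a fixed random seed so that its batches do not change between epochs.

# batch_size=32 -> 4 classes per batch, each contributing 32 // 4 = 8 samples
train_batch_sampler = CustomBatchSampler(train_dataset, batch_size=32)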

Building the Model


Embedding models are usually Transformer-based and output one embedding per token. To obtain a sentence embedding, the token embeddings have to be aggregated, most commonly with CLS pooling or mean pooling. The gte-base model used here relies on mean pooling, so we take the token embeddings from the model output and average them.

import torch.nn.functional as F
import torch.nn as nn


class EmbeddingModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def average_pool(self, last_hidden_states, attention_mask):
        # Average the token embeddings, ignoring padded positions
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def forward(self, input_ids, attention_mask):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        pooled_output = self.average_pool(last_hidden_state, attention_mask)
        normalized_output = F.normalize(pooled_output, p=2, dim=1)
        return normalized_output


base_model = AutoModel.from_pretrained("thenlper/gte-base")
model = EmbeddingModel(base_model)

The EmbeddingModel class wraps the Hugging Face model and implements mean pooling and L2 normalization of the embeddings.
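As a quick sanity check, the wrapped model maps a batch of token IDs and attention masks to one 768-dimensional, L2-normalized vector per sentence (run on CPU here purely for illustration):

model.eval()
with torch.no_grad():
    input_ids = torch.stack(train_tokens[:4])               # [4, seq_len]
    attention_mask = torch.stack(train_attention_masks[:4])
    embeddings = model(input_ids, attention_mask)
print(embeddings.shape)  # torch.Size([4, 768])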

Training the Model


Inside the training loop, the hardest positive and the hardest negative for each anchor have to be computed dynamically, as sketched below.

 import numpy as np  
def train(model, train_loader, criterion, optimizer, scheduler):
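Only the function signature is shown above. What follows is a minimal sketch of one way such a loop can look, assuming batch-hard mining (the hardest positive and hardest negative per anchor, computed from pairwise distances inside the batch), nn.TripletMarginLoss as the criterion, and a per-step learning-rate scheduler; the details are illustrative rather than the original implementation.

def train(model, train_loader, criterion, optimizer, scheduler):
    # One training epoch with batch-hard triplet mining (illustrative sketch)
    model.train()
    total_loss = 0.0

    for input_ids, attention_masks, labels in train_loader:
        optimizer.zero_grad()
        embeddings = model(input_ids, attention_masks)            # [B, 768], L2-normalized

        # Pairwise Euclidean distances between all embeddings in the batch
        dists = torch.cdist(embeddings, embeddings, p=2)          # [B, B]
        same_class = labels.unsqueeze(0) == labels.unsqueeze(1)   # [B, B] boolean
        diag = torch.eye(len(labels), dtype=torch.bool, device=dists.device)

        # Hardest positive: farthest sample of the same class (excluding the anchor itself)
        pos_dists = dists.masked_fill(~same_class | diag, float('-inf'))
        hardest_pos = pos_dists.argmax(dim=1)

        # Hardest negative: closest sample of a different class
        neg_dists = dists.masked_fill(same_class, float('inf'))
        hardest_neg = neg_dists.argmin(dim=1)

        loss = criterion(embeddings, embeddings[hardest_pos], embeddings[hardest_neg])
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    return total_loss / len(train_loader)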




