>>> from datasets import load_dataset

>>> dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train[:1%]")
>>> sample = dataset[1]
>>> sample["image"].show()
>>> sample["question"]
'how many families?'
>>> sample["rejected"]
'The image does not provide any information about families.'
>>> sample["chosen"]
'The image shows a Union Organization table setup with 18,000 families.'
def format(example):
    # Prepare the input for the chat template
    prompt = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": example["question"]}],
        },
    ]
    chosen = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["chosen"]}],
        },
    ]
    rejected = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["rejected"]}],
        },
    ]
    # Apply the chat template
    prompt = processor.apply_chat_template(prompt, tokenize=False)
    chosen = processor.apply_chat_template(chosen, tokenize=False)
    rejected = processor.apply_chat_template(rejected, tokenize=False)
    # Resize the image to ensure it fits within the maximum allowable
    # size of the processor to prevent OOM errors.
    max_size = processor.image_processor.size["longest_edge"]
    example["image"].thumbnail((max_size, max_size))
    return {"images": [example["image"]], "prompt": prompt, "chosen": chosen, "rejected": rejected}
# Apply the formatting function to the dataset,
# remove columns to end up with only "images", "prompt", "chosen", "rejected" columns
dataset = dataset.map(format, remove_columns=dataset.column_names)
from datasets import features

# Make sure that the images are decoded; this prevents storing raw bytes.
# More info here https://github.com/huggingface/blog/pull/2148#discussion_r1667400478
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True))  # to avoid bytes
dataset = dataset.cast(f)
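As an optional sanity check (an illustrative sketch, not part of the original pipeline), you can confirm that the cast worked and that the "images" column now holds decoded PIL images rather than raw bytes:

# Illustrative sanity check: after the cast, "images" should be a sequence of
# decoded PIL images, not raw bytes.
print(dataset.features["images"])      # Sequence(Image(decode=True, ...))
print(type(dataset[0]["images"][0]))   # a PIL.Image subclass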
With the formatting done, let's take a look at the first example:
>>> dataset[1]
{'images': [<PIL.JpegImagePlugin.JpegImageFile image mode=L size=980x812 at 0x154505570>],
 'prompt': 'User:<image>how many families?<end_of_utterance>\n',
 'rejected': 'Assistant: The image does not provide any information about families.<end_of_utterance>\n',
 'chosen': 'Assistant: The image shows a Union Organization table setup with 18,000 families.<end_of_utterance>\n'}
# dpo_idefics2-8b.py
from datasets import features, load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
import torch
from trl import DPOConfig, DPOTrainer
from peft import LoraConfig
def main():
    # Load the model and processor
    model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b", torch_dtype=torch.bfloat16)
    processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)
    # Load the dataset
    dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train")
    def format(example):
        # Prepare the input for the chat template
        prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}]
        chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}]
        rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}]
        # Apply the chat template
        prompt = processor.apply_chat_template(prompt, tokenize=False)
        chosen = processor.apply_chat_template(chosen, tokenize=False)
        rejected = processor.apply_chat_template(rejected, tokenize=False)
        # Resize the image to ensure it fits within the maximum allowable
        # size of the processor to prevent OOM errors. Halving the longest
        # edge further reduces memory usage.
        max_size = processor.image_processor.size["longest_edge"] // 2
        example["image"].thumbnail((max_size, max_size))
        return {"images": [example["image"]], "prompt": prompt, "chosen": chosen, "rejected": rejected}
    # Apply the formatting function to the dataset
    dataset = dataset.map(format, remove_columns=dataset.column_names, num_proc=32)
    # Make sure that the images are decoded; this prevents storing raw bytes.
    # More info here https://github.com/huggingface/blog/pull/2148#discussion_r1667400478
    f = dataset.features
    f["images"] = features.Sequence(features.Image(decode=True))
    dataset = dataset.cast(f)
    # Train the model
    training_args = DPOConfig(
        output_dir="idefics2-8b-dpo",
        bf16=True,
        gradient_checkpointing=True,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=32,
        num_train_epochs=1,
        dataset_num_proc=32,  # tokenization will use 32 processes
        dataloader_num_workers=32,  # data loading will use 32 workers
        logging_steps=10,
    )
    trainer = DPOTrainer(
        model,
        ref_model=None,  # not needed when using peft
        args=training_args,
        train_dataset=dataset,
        tokenizer=processor,
        peft_config=LoraConfig(target_modules="all-linear"),
    )
    trainer.train()
if __name__ == "__main__":
    main()
Launch the script to start training, then wait for the results 🚀
accelerate launch dpo_idefics2-8b.py
Results
Training takes a few hours. Once it has finished, we can take a look at the training curves:
Learning curves
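If you are not tracking the run with an experiment logger, one rough way to plot these curves yourself is to read the log_history saved in the last checkpoint's trainer_state.json. The snippet below is only an illustrative sketch: it assumes the default Trainer checkpoint layout under the output_dir used above and that DPOTrainer logs a "rewards/accuracies" metric.

# plot_dpo_curves.py -- illustrative sketch, not part of the original post.
# Assumes checkpoints in idefics2-8b-dpo/checkpoint-*/trainer_state.json and
# the "rewards/accuracies" key logged by DPOTrainer.
import glob
import json

import matplotlib.pyplot as plt

# Pick the last checkpoint written during training.
checkpoints = sorted(glob.glob("idefics2-8b-dpo/checkpoint-*"), key=lambda p: int(p.rsplit("-", 1)[1]))
with open(f"{checkpoints[-1]}/trainer_state.json") as f:
    log_history = json.load(f)["log_history"]

# Keep the logged steps that contain the DPO reward accuracy.
steps = [entry["step"] for entry in log_history if "rewards/accuracies" in entry]
accuracies = [entry["rewards/accuracies"] for entry in log_history if "rewards/accuracies" in entry]

plt.plot(steps, accuracies)
plt.xlabel("step")
plt.ylabel("rewards/accuracies")
plt.savefig("reward_accuracies.png")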
In DPO, we focus on several metrics to assess the quality of the training: