# `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): def save_model_hook(models, weights, output_dir): ... def load_model_hook(models, input_dir): ...
跳过上面这段代码,接下来仍然是日志配置相关的内容。
# Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error()
之后其他版本的训练脚本会有一段设置随机种子的代码,我们给这份脚本补上。
# If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed)
接着,函数会创建输出文件夹。如果我们想把模型推送到在线仓库上,函数还会创建一个仓库。这段代码还出现了一行比较重要的判断语句:
if accelerator.is_main_process:
。在多卡训练时,只有主进程会执行这个条件语句块里的内容。该判断在并行编程中十分重要。很多时候,比如在输出、存取模型时,我们只需要让一个进程执行操作就行了。这个时候就要用到这行判断语句。
# Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub: repo_id = create_repo(...).repo_id
# Initialize the model if args.model_config_name_or_path is None: model = UNet2DModel(...) else: config = UNet2DModel.load_config(args.model_config_name_or_path) model = UNet2DModel.from_config(config)
这份脚本还帮我们写好了维护 EMA(指数移动平均)模型的功能。EMA 模型用于存储模型可学习的参数的局部平均值。有时 EMA 模型的效果会比原模型要好。
# Create EMA for the model. if args.use_ema: ema_model = EMAModel( model.parameters(), model_cls=UNet2DModel, model_config=model.config, ...)
# Load the training split either from the Hub or from a local image folder.
if args.dataset_name is not None:
    # Named dataset from the Hugging Face Hub.
    dataset = load_dataset(
        args.dataset_name,
        args.dataset_config_name,
        cache_dir=args.cache_dir,
        split="train",
    )
else:
    # Local directory of training images.
    dataset = load_dataset(
        "imagefolder",
        data_dir=args.train_data_dir,
        cache_dir=args.cache_dir,
        split="train")
    # See more about loading custom images at
    # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Optionally resume training from a saved accelerate checkpoint.
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint != "latest":
        # An explicit checkpoint was named on the command line.
        # (Fixed: the original excerpt had `path = ..`, a syntax error;
        # the placeholder used throughout this article is `...`.)
        path = ...
    else:
        # Get the most recent checkpoint
        ...

    if path is None:
        # No checkpoint found: fall through to fresh training (elided).
        ...
    else:
        # Restore model/optimizer/scheduler state and report it.
        accelerator.load_state(os.path.join(args.output_dir, path))
        accelerator.print(f"Resuming from checkpoint {path}")
        ...
在每个 epoch 中,函数会重置进度条。接着,函数会进入每一个 batch 的训练迭代。
# Train! for epoch in range(first_epoch, args.num_epochs): model.train() progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) progress_bar.set_description(f"Epoch {epoch}") for step, batch in enumerate(train_dataloader):
如果是继续训练的话,训练开始之前会更新当前的步数
step
。
# Skip steps until we reach the resumed step if args.resume_from_checkpoint and epoch == first_epoch and step if step % args.gradient_accumulation_steps == 0: progress_bar.update(1) continue
接下来,函数会用去噪网络做前向传播。为了让模型能正确累积梯度,我们要用
with accelerator.accumulate(model):
把模型调用与反向传播的逻辑包起来。在这段代码中,我们会先得到模型的输出
model_output
,再根据扩散模型得到损失函数
loss
,最后用 accelerate 库的 API
accelerator
代替原来 PyTorch API 来完成反向传播、梯度裁剪,并完成参数更新、学习率调度器更新、优化器更新。
# Wrap forward + backward in accumulate() so accelerate can decide when
# gradients are actually synchronized across accumulation steps.
with accelerator.accumulate(model):
    # Predict the noise residual
    model_output = model(noisy_images, timesteps).sample

    # Diffusion loss; computation elided in this excerpt.
    loss = ...

    # Backward through accelerate instead of loss.backward() so mixed
    # precision and gradient accumulation are handled for us.
    accelerator.backward(loss)

    # Only clip when gradients were really synchronized this step.
    if accelerator.sync_gradients:
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
在确认一步训练真正结束(即梯度已同步、参数已更新)后,函数会更新与步数相关的变量。
# sync_gradients is True once per real optimizer update (i.e. at the end
# of a gradient-accumulation cycle), so this bookkeeping runs once per step.
if accelerator.sync_gradients:
    if args.use_ema:
        # Fold the freshly updated weights into the EMA average.
        ema_model.step(model.parameters())
    progress_bar.update(1)
    global_step += 1
# Periodic checkpointing; main process only to avoid concurrent writes.
# (Fixed: the original excerpt began with `f accelerator...`, a typo for `if`.)
if accelerator.is_main_process:
    if global_step % args.checkpointing_steps == 0:
        # Enforce the rolling checkpoint limit before saving a new one.
        if args.checkpoints_total_limit is not None:
            checkpoints = os.listdir(args.output_dir)
            checkpoints = [
                d for d in checkpoints if d.startswith("checkpoint")]
            # Sort by the global step encoded in "checkpoint-<step>".
            checkpoints = sorted(
                checkpoints, key=lambda x: int(x.split("-")[1]))

            # Deletion of the oldest checkpoints is elided in this excerpt.
            if len(checkpoints) >= args.checkpoints_total_limit:
                ...

        save_path = os.path.join(
            args.output_dir, f"checkpoint-{global_step}")
        accelerator.save_state(save_path)
        logger.info(f"Saved state to {save_path}")
# End-of-epoch hooks; main process only.
if accelerator.is_main_process:
    # Periodically sample validation images (body shown in a later excerpt).
    if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
        ...

    # Periodically save the full model (body shown in a later excerpt).
    if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
        ...
脚本默认的验证方法是随机生成图片,并用日志库保存图片。生成图片的方法是使用标准 Diffusers 采样流水线
DDPMPipeline
。由于此时模型
model
可能被包裹成了一个用于多卡训练的 PyTorch 模块,需要用相关 API 把
model
解包成普通 PyTorch 模块
unet
。如果使用了 EMA 模型,为了避免对 EMA 模型的干扰,此处需要先保存 EMA 模型参数,采样结束再还原参数。
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
    # Unwrap the (possibly DDP-wrapped) training model back to a plain module.
    unet = accelerator.unwrap_model(model)

    if args.use_ema:
        # Temporarily swap EMA weights in for sampling; the live training
        # weights are stashed and restored below.
        ema_model.store(unet.parameters())
        ema_model.copy_to(unet.parameters())

    # NOTE(review): `pipeline` is built from `unet` in code elided from this
    # excerpt — confirm against the full script.
    # Fixed seed so validation images are comparable across epochs.
    generator = torch.Generator(device=pipeline.device).manual_seed(0)
    # run pipeline in inference (sample random noise and denoise)
    images = pipeline(...).images

    if args.use_ema:
        # Put the live training weights back after sampling.
        ema_model.restore(unet.parameters())

    # denormalize the images and save to tensorboard
    images_processed = (images * 255).round().astype("uint8")

    # Logger-specific image upload; bodies elided in this excerpt.
    if args.logger == "tensorboard":
        ...
    elif args.logger == "wandb":
        ...
在保存模型时,脚本同样会先用去噪模型
model
构建一个流水线,再调用流水线的保存方法
save_pretrained
将扩散模型的所有组件(去噪模型、噪声调度器)保存下来。
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
    # save the model
    # Unwrap before saving so the checkpoint holds a plain module, not a
    # distributed-training wrapper.
    unet = accelerator.unwrap_model(model)

    if args.use_ema:
        # Save with EMA weights; restoration of the live weights happens
        # in code elided from this excerpt.
        ema_model.store(unet.parameters())
        ema_model.copy_to(unet.parameters())
# old if cfg.model_config is None: model = UNet2DModel(...) else: config = UNet2DModel.load_config(cfg.model_config) model = UNet2DModel.from_config(config)
# Create EMA for the model. if cfg.use_ema: ema_model = EMAModel(...) ...
# new trainer.init_modules(enable_xformers, cfg.gradient_checkpointing)
# The config must have a "base" key base_cfg_dict = data_dict.pop('base')
# The config must have one another model config assert len(data_dict) == 1 model_key = next(iter(data_dict)) model_cfg_dict = data_dict[model_key] model_cfg_cls = __TYPE_CLS_DICT[model_key]