专栏名称: 极市平台
极市平台是由深圳极视角推出的专业的视觉算法开发与分发平台,为视觉开发者提供多领域实景训练数据库等开发工具和规模化销售渠道。本公众号将会分享视觉相关的技术资讯,行业动态,在线分享信息,线下活动等。 网站: http://cvmart.net/
目录
相关文章推荐
简七读财  ·  3个有效方法,新年好运连连 ·  19 小时前  
哔哩哔哩  ·  几款最强AI玩狼人杀,谁能封神? ·  昨天  
格上财富  ·  上海为什么要新建这么多高铁站 ·  2 天前  
格上财富  ·  芒格:把一天中最好的时间卖给自己 ·  3 天前  
哔哩哔哩  ·  猫和老鼠来B站,自己鬼畜自己 ·  3 天前  
51好读  ›  专栏  ›  极市平台

定制适合自己的 Diffusers 扩散模型训练脚本

极市平台  · 公众号  ·  · 2024-07-25 22:00

正文

↑ 点击 蓝字 关注极市平台
作者丨天才程序员周弈帆
来源丨天才程序员周弈帆
编辑丨极市平台

极市导读

本文介绍了如何定制Diffusers库的训练脚本,通过重构使其更兼容多种模型训练,分享了作者的GitHub代码示例。作者提出了一套设计原则,通过创建数据类和接口类来解耦训练逻辑,简化添加新任务的过程,并讨论了重构代码的体会。 >> 加入极市CV技术交流群,走在计算机视觉的最前沿

Diffusers 库为社区用户提供了多种扩散模型任务的训练脚本。每个脚本都平铺直叙,没有多余的封装,把训练的绝大多数细节都写在了一个脚本里。这种设计既能让入门用户在不阅读源码的前提下直接用脚本训练,又方便高级用户直接修改脚本。

可是,这种设计就是最好的吗?关于训练脚本的最佳设计风格,社区用户们往往各执一词。有人更喜欢更贴近 PyTorch 官方示例的写法,而有人会喜欢用 PyTorch Lightning 等封装度高、重复代码少的库。而在我看来,选择哪种风格的训练脚本,确实是个人喜好问题。但是,在开始使用训练脚本之前,我们要从细节入手,理解训练脚本到底要做哪些事。学懂了之后,不管是用别人的训练库,还是定制适合自己的训练脚本,都是很轻松的。不管怎么说,Diffusers 的这种训练脚本是一份很好的学习素材。

当然,我在用 Diffusers 的训练脚本时,发现一旦涉及多类任务的训练,比如既要能训练 Stable Diffusion,又要能训练 VAE,那么这份脚本就会用起来比较困难,而写两份训练脚本又会有很大的冗余。Diffusers 的训练脚本依然有改进的空间。

在这篇文章中,我会主要面向想系统性学习扩散模型训练框架的读者,先详细介绍 Diffusers 官方训练脚本,再分享我重构训练脚本的过程,使得脚本能够更好地兼容多类模型的训练。文章的末尾,我会展示几个简单的扩散模型训练实例。

在阅读本文时,建议大家用电脑端,一边看源代码一边读文章。「官方训练脚本细读」一节细节较多,初次阅读时可以快速浏览,看完「训练脚本内容总结」中的流程图,再回头仔细看一遍。

准备源代码

我们将以最简单的 DDPM 官方训练脚本 examples/unconditional_image_generation/train_unconditional.py 为例,学习训练脚本的通用写法。 examples 文件夹在位于 Diffusers 官方 GitHub 仓库中,用 pip 安装的 Diffusers 可能没有这个文件夹,最好是手动 clone 官方仓库,再在本地查看这个文件夹。使用 Diffusers 训练时,可能还要安装其他库。官方在不同的训练教程里给了不同的安装指令,建议大家都安装上。

cd examples/text_to_image
pip install -r requirements.txt
pip install diffusers[training]

我为本教程准备的脚本在仓库 https://github.com/SingleZombie/DiffusersExample 中。请 clone 这个仓库,再切换到 TrainingScript 目录下。 train_official.py 是原官方训练脚本 train_unconditional.py train_0.py 是第一次修改后的训练脚本 , train_1.py 是第二次修改后的训练脚本。

官方训练脚本细读

先拉到文件的最底部,我们能在这找到程序的入口。在 parse_args 函数中,脚本会用 argparse 库解析命令行参数,并将所有参数保存在 args 里。 args 会传进 main 函数里。稍后我们看到所有 args. 打头的变量调用,都表明该变量来自于命令行参数。

if __name__ == "__main__":
    args = parse_args()
    main(args)

接着,我们正式开始学习训练主函数。一开始,函数会配置 accelerate 库及日志记录器。

logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(
    project_dir=args.output_dir, logging_dir=logging_dir)

# a big number for high resolution or big dataset
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))
accelerator = Accelerator(...)

if args.logger == "tensorboard":
    if not is_tensorboard_available():
        ...

elif args.logger == "wandb":
    if not is_wandb_available():
        ...
    import wandb

在配置日志的中途,函数插入了一段修改模型存取逻辑的代码。为了让我们阅读代码的顺序与实际运行顺序一致,我们等待会用到了这段代码时再回头来读。

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
    def save_model_hook(models, weights, output_dir):
        ...
    def load_model_hook(models, input_dir):
        ...

跳过上面的代码,还是日志配置。

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
    datasets.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()

之后其他版本的训练脚本会有一段设置随机种子的代码,我们给这份脚本补上。

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)

接着,函数会创建输出文件夹。如果我们想把模型推送到在线仓库上,函数还会创建一个仓库。这段代码还出现了一行比较重要的判断语句: if accelerator.is_main_process: 。在多卡训练时,只有主进程会执行这个条件语句块里的内容。该判断在并行编程中十分重要。很多时候,比如在输出、存取模型时,我们只需要让一个进程执行操作就行了。这个时候就要用到这行判断语句。

# Handle the repository creation
if accelerator.is_main_process:
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.push_to_hub:
        repo_id = create_repo(...).repo_id

准备完辅助工具后,函数开始准备模型。输入参数里的 model_config_name_or_path 表示预定义的模型配置文件。如果该配置文件不存在,则函数会用默认的配置创建一个 DDPM 的 U-Net 模型。在写我们自己的训练脚本时,我们需要在这个地方初始化我们需要的所有模型。比如训练 Stable Diffusion 时,除了 U-Net,需要在此处准备 VAE、CLIP 文本编码器。

# Initialize the model
if args.model_config_name_or_path is None:
    model = UNet2DModel(...)
else:
    config = UNet2DModel.load_config(args.model_config_name_or_path)
    model = UNet2DModel.from_config(config)

这份脚本还帮我们写好了维护 EMA(指数移动平均)模型的功能。EMA 模型用于存储模型可学习的参数的局部平均值。有时 EMA 模型的效果会比原模型要好。

# Create EMA for the model.
if args.use_ema:
    ema_model = EMAModel(
        model.parameters(),
        model_cls=UNet2DModel,
        model_config=model.config,
        ...)

此处函数还会根据 accelerate 配置自动设置模型的精度。

weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
    args.mixed_precision = accelerator.mixed_precision
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
    args.mixed_precision = accelerator.mixed_precision

函数还会尝试启用 xformers 来提升 Attention 的效率。PyTorch 在 2.0 版本也加入了类似的 Attention 优化技术。如果你的显卡性能有限,且 PyTorch 版本小于 2.0,可以考虑使用 xformers

if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        ...

准备了 U-Net 后,函数会准备噪声调度器,即定义扩散模型的细节。

注意,扩散模型不是一个神经网络,而是一套定义了加噪、去噪公式的模型。扩散模型中需要一个去噪模型来去噪,去噪模型一般是一个神经网络。

# Initialize the scheduler
accepts_prediction_type = "prediction_type" in set(
    inspect.signature(DDPMScheduler.__init__).parameters.keys())
if  accepts_prediction_type:
    noise_scheduler = DDPMScheduler(...)
else:
    noise_scheduler = DDPMScheduler(...)

准备完所有扩散模型组件后,函数开始准备其他和训练相关的模块。其他版本的训练脚本会在这个地方加一段缓存梯度和自动放缩学习率的代码,我们给这份脚本补上。

if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()

if args.scale_lr:
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

函数先准备的训练模块是优化器。这里默认使用的优化器是 AdamW

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

函数随后会准备训练集。这个脚本用 HuggingFace 的 datasets 库来管理数据集。我们既可以读取在线数据集,也可以读取本地的图片文件夹数据集。自定义数据集的方法可以参考 https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 。

if args.dataset_name is not None:
    dataset = load_dataset(
        args.dataset_name,
        args.dataset_config_name,
        cache_dir=args.cache_dir,
        split="train",
    )
else:
    dataset = load_dataset(
        "imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
    # See more about loading custom images at
    # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

有了数据集后,函数会继续准备 PyTorch 的 DataLoader。在这一步中,除了定义 DataLoader 外,我们还要编写数据预处理的方法。下面这段代码的编写顺序和执行顺序不同,我们按执行顺序来整理一遍下面的代码:

  1. 将预定义的预处理函数传给数据集对象 `dataset.set_transform(transform_images)`。在使用数据集里的数据时,才会调用这个函数预处理图像。
  2. 使用 PyTorch API 定义 DataLoader。`train_dataloader = ...`
  3. 每次用 DataLoader 获取数据时,一个数据词典 `examples` 会被传入预处理函数 `transform_images`。`examples` 里既包含了图像数据,也包含了数据的各种标签。而对于无约束图像生成任务,我们只需要图像数据,因此可以直接通过词典的 `"image"` 键得到 PIL 格式的图像数据。用 `convert("RGB")` 把图像转成三通道后,该 PIL 图像会被传入预处理流水线。
  4. 图像预处理流水线 `augmentations` 是用 Torchvision 里的 `transform` API 定义的。默认的流水线包括短边缩放至指定分辨率、按分辨率裁剪、随机反转、归一化。
  5. 处理过的数据会被存到词典的 `"input"` 键里。
# Preprocessing the datasets and DataLoaders creation.
augmentations = transforms.Compose(
    [
        transforms.Resize(
            args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(
            args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

def transform_images(examples):
    images = [augmentations(image.convert("RGB"))
                for image in examples["image"]]
    return {"input": images}

logger.info(f"Dataset size: {len(dataset)}")

dataset.set_transform(transform_images)
train_dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)

在准备工作的最后,函数会准备学习率调度器。

# Initialize the learning rate scheduler
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=(len(train_dataloader) * args.num_epochs),
)

准备完了所有模块,函数会调用 accelerate 库来把所有模块变成适合并行训练的模块。

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

if args.use_ema:
    ema_model.to(accelerator.device)

之后函数还会用 accelerate 库配置训练日志。默认情况下日志名 run 由当前脚本名决定。如果不想让之前的日志被覆盖的话,可以让日志名 run 由当前的时间决定。

if accelerator.is_main_process:
    run = os.path.split(__file__)[-1].split(".")[0]
    accelerator.init_trackers(run)

马上就要开始训练了。在此之前,函数会准备全局变量并记录日志。注意,这里函数会算一次总的 batch 数,它由输入 batch 数、进程数(显卡数)、梯度累计步数共同决定。梯度累计是一种用较少的显存实现大 batch 训练的技术。使用这项技术时,训练梯度不会每步优化,而是累计了若干步后再优化。

total_batch_size = args.train_batch_size * \
    accelerator.num_processes * args.gradient_accumulation_steps
num_update_steps_per_epoch = math.ceil(
    len(train_dataloader) / args.gradient_accumulation_steps)
max_train_steps = args.num_epochs * num_update_steps_per_epoch

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(dataset)}")
logger.info(f"  Num Epochs = {args.num_epochs}")
logger.info(
    f"  Instantaneous batch size per device = {args.train_batch_size}")
logger.info(
    f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(
    f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {max_train_steps}")

global_step = 0
first_epoch = 0

在开始训练前,如果设置了 args.resume_from_checkpoint ,则函数会读取之前训练过的权重。负责读取训练权重的函数是 load_state

if args.resume_from_checkpoint:
    if args.resume_from_checkpoint != "latest":
        path = ..
    else:
        # Get the most recent checkpoint
        ...

    if path is None:
        ...
    else:
        accelerator.load_state(os.path.join(args.output_dir, path))
        accelerator.print(f"Resuming from checkpoint {path}")
        ...

在每个 epoch 中,函数会重置进度条。接着,函数会进入每一个 batch 的训练迭代。


# Train!
for epoch in range(first_epoch, args.num_epochs):
    model.train()
    progress_bar = tqdm(total=num_update_steps_per_epoch,
                        disable=not accelerator.is_local_main_process)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(train_dataloader):

如果是继续训练的话,训练开始之前会更新当前的步数 step

# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step     if step % args.gradient_accumulation_steps == 0:
        progress_bar.update(1)
    continue

训练的一开始,函数会从数据的 "input" 键里取出图像数据。此处的键名是我们之前在数据预处理函数 transform_images 里写的。

clean_images = batch["input"].to(weight_dtype)

之后函数会设置扩散模型训练中的其他变量,包含随机噪声、时刻。由于本文的重点并不是介绍扩散模型的原理,这段代码我们就快速略过。

noise = torch.randn(...)
timesteps =...
noisy_images = noise_scheduler.add_noise(
    clean_images, noise, timesteps)

接下来,函数会用去噪网络做前向传播。为了让模型能正确累计梯度,我们要用 with accelerator.accumulate(model): 把模型调用与反向传播的逻辑包起来。在这段代码中,我们会先得到模型的输出 model_output ,再根据扩散模型得到损失函数 loss ,最后用 accelerate 库的 API accelerator 代替原来 PyTorch API 来完成反向传播、梯度裁剪,并完成参数更新、学习率调度器更新、优化器更新。

with accelerator.accumulate(model):
    # Predict the noise residual
    model_output = model(noisy_images, timesteps).sample

    loss = ...

    accelerator.backward(loss)

    if accelerator.sync_gradients:
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

确保一步训练结束后,函数会更新和步数相关的变量。

if accelerator.sync_gradients:
    if args.use_ema:
        ema_model.step(model.parameters())
    progress_bar.update(1)
    global_step += 1

在这个地方,函数还会尝试保存模型。默认情况下,每 args.checkpointing_steps 步保存一次中间结果。确认要保存后,函数会算出当前的保存点名称,并根据最大保存点数 checkpoints_total_limit 决定是否要删除以前的保存点。做完准备后,函数会调用 save_state 保存当前训练时的所有中间变量。

f accelerator.is_main_process:
    if global_step % args.checkpointing_steps == 0:
        if args.checkpoints_total_limit is not None:
            checkpoints = os.listdir(args.output_dir)
            checkpoints = [
                d for d in checkpoints if d.startswith("checkpoint")]
            checkpoints = sorted(
                checkpoints, key=lambda x: int(x.split("-")[1]))

            if len(checkpoints) >= args.checkpoints_total_limit:
                ...

            save_path = os.path.join(
            args.output_dir, f"checkpoint-{global_step}")
            accelerator.save_state(save_path)
            logger.info(f"Saved state to {save_path}")

在这个地方,主函数开头设置的存取模型回调函数终于派上用场了。在调用 save_state 时,会自动触发下面的回调函数来保存模型。如果不加下面的代码,所有模型默认会以 .safetensor 的形式存下来。而用了下面的代码后,模型能够被 save_pretrained 存进一个文件夹里,就像其他标准 Diffusers 模型一样。

这里的输入参数 models 来自于之前的 accelerator.prepare ,感兴趣可以去阅读文档或源码。

def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        if args.use_ema:
            ema_model.save_pretrained(
                os.path.join(output_dir, "unet_ema"))

        for i, model in enumerate(models):
            model.save_pretrained(os.path.join(output_dir, "unet"))

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()

与上面的这段代码对应,脚本还提供了读取文件的回调函数。它会在继续中断的训练后调用 load_state 时被调用。

def load_model_hook(models, input_dir):
    if args.use_ema:
        load_model = EMAModel.from_pretrained(
            os.path.join(input_dir, "unet_ema"), UNet2DModel)
        ema_model.load_state_dict(load_model.state_dict())
        ema_model.to(accelerator.device)
        del load_model

    for i in range(len(models)):
        # pop models so that they are not loaded again
        model = models.pop()

        # load diffusers style into model
        load_model = UNet2DModel.from_pretrained(
            input_dir, subfolder="unet")
        model.register_to_config(**load_model.config)

        model.load_state_dict(load_model.state_dict())
        del load_model

两个回调函数需要用下面的代码来设置。

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

回到最新的代码处。训练迭代的末尾,脚本会记录当前步的日志。

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
if args.use_ema:
    logs["ema_decay"] = ema_model.cur_decay_value
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)

执行完了一个 epoch 后,脚本调用 accelerate API 保证所有进程均训练完毕。

progress_bar.close()
accelerator.wait_for_everyone()

此处脚本可能会在主进程中验证模型或保存模型。如果当前是最后一个 epoch,或者达到了配置指定的验证/保存时刻,脚本就会执行验证/保存。


if accelerator.is_main_process:
    if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
        ...

    if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
        ...

脚本默认的验证方法是随机生成图片,并用日志库保存图片。生成图片的方法是使用标准 Diffusers 采样流水线 DDPMPipeline 。由于此时模型 model 可能被包裹成了一个用于多卡训练的 PyTorch 模块,需要用相关 API 把 model 解包成普通 PyTorch 模块 unet 。如果使用了 EMA 模型,为了避免对 EMA 模型的干扰,此处需要先保存 EMA 模型参数,采样结束再还原参数。

if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
    unet = accelerator.unwrap_model(model)
    if args.use_ema:
        ema_model.store(unet.parameters())
        ema_model.copy_to(unet.parameters())

    pipeline = DDPMPipeline(
        unet=unet,
        scheduler=noise_scheduler,
    )

    generator = torch.Generator(device=pipeline.device).manual_seed(0)
    # run pipeline in inference (sample random noise and denoise)
    images = pipeline(...).images

    if args.use_ema:
        ema_model.restore(unet.parameters())

    # denormalize the images and save to tensorboard
    images_processed = (images * 255).round().astype("uint8")

    if args.logger == "tensorboard":
        ...
    elif args.logger == "wandb":
        ...

在保存模型时,脚本同样会先用去噪模型 model 构建一个流水线,再调用流水线的保存方法 save_pretrained 将扩散模型的所有组件(去噪模型、噪声调度器)保存下来。

if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
    # save the model
    unet = accelerator.unwrap_model(model)

    if args.use_ema:
        ema_model.store(unet.parameters())
        ema_model.copy_to(unet.parameters())

    pipeline = DDPMPipeline(
        unet=unet,
        scheduler=noise_scheduler,
    )

    pipeline.save_pretrained(args.output_dir)

    if args.use_ema:
        ema_model.restore(unet.parameters())

    if args.push_to_hub:
        upload_folder(...)

一个 epoch 训练的代码就到此结束了。所有 epoch 的训练结束后,脚本调用 API 结束训练。这个 API 会自动关闭所有的日志库。训练代码到这里也就结束了。

accelerator.end_training()

训练脚本内容总结

大概熟悉了一遍这份训练脚本后,我们可以用下面的流程图概括训练脚本的执行顺序和主要内容。

去掉命令行参数

我不喜欢用命令行参数传训练参数,而喜欢把训练参数写进配置文件里,理由有:

  • 我一般会直接在命令行里手敲命令。如果命令行参数过多,我则会把要运行的命令及其参数保存在某文件里。这样还不如把参数写在另外的文件里。
  • 将大量参数藏在一个词典 args 里,而不是把所有需用的参数在某处定义好,是一种很差的编程方式。各个参数将难以追踪。

在正式重构脚本之前,我做的第一步是去掉脚本中原来的命令行参数,将所有参数先塞进一个数据类里面。脚本将只留一个命令行参数,表示参数配置文件的路径。具体做法如下:

先编写一个存命令行参数的数据类。这个类是一个 Python 的 dataclass 。Python 中 dataclass 是一种专门用来放数据的类。定义数据类时,我们只需要定义类中所有数据的类型及默认值,不需要编写任何方法。初始化数据类时,我们只需要传一个词典或列表。一个示例如下(示例来源 https://www.geeksforgeeks.org/understanding-python-dataclasses/):

from dataclasses import dataclass
 
# A class for holding an employees content
@dataclass
class employee:
 
    # Attributes Declaration
    # using Type Hints
    name: str
    emp_id: str
    age: int
    city: str
 
 
emp1 = employee("Satyam""ksatyam858", 21, 'Patna')
emp2 = employee("Anurag""au23", 28, 'Delhi')
emp3 = employee({"name""Satyam"
   "emp_id""ksatyam858"
   "age": 21, 
   "city"'Patna'})
 
print("employee object are :")
print(emp1)
print(emp2)
print(emp3)

我们可以用 dataclass 编写一个存储所有命令行参数的数据类,该类开头内容如下:

from dataclasses import dataclass

@dataclass
class BaseTrainingConfig:
    # Dir
    logging_dir: str
    output_dir: str

    # Logger and checkpoint
    logger: str = 'tensorboard'
    checkpointing_steps: int = 500
    checkpoints_total_limit: int = 20
    valid_epochs: int = 100
    valid_batch_size: int = 1
    save_model_epochs: int = 100
    resume_from_checkpoint: str = None

之后在训练脚本里,我们可以把旧的命令行参数全删了,再加一个命令行参数 cfg ,表示训练配置文件的路径。我们可以用 omegaconf 打开这个配置文件,得到一个词典 data_dict ,再用这个词典构建配置文件 cfg 。接下来,只需要把原来代码里所有 args. 改成 cfg. 就行了。

from omegaconf import OmegaConf
from training_cfg_0 import BaseTrainingConfig

parser = argparse.ArgumentParser()
parser.add_argument('cfg'type=str)
args = parser.parse_args()

data_dict = OmegaConf.load(args.cfg)
cfg = BaseTrainingConfig(**data_dict)

第一次修改过的训练脚本为 train_0.py ,配置文件类在 training_cfg_0.py 里,示例配置文件为 cfg_0.json ,一个简单 DDPM 模型配置写在 unet_cfg 目录里。可以直接运行下面的命令测试此训练脚本。

python train_0.py cfg_0.json

在配置文件里,我们只需要改少量的训练参数就行了。如果想知道还有哪些参数可以改,可以去查看 training_cfg_0.py 文件。

{
    "logging_dir""logs",
    "output_dir""models/ddpm_0",

    "model_config""unet_cfg",
    "num_epochs": 10,
    "train_batch_size": 64,
    "checkpointing_steps": 5000,
    "valid_epochs": 1,
    "valid_batch_size": 4,
    "dataset_name""ylecun/mnist",
    "resolution": 32,
    "learning_rate": 1e-4
}

读者感兴趣的话也可以尝试这样改一遍代码。这样做会强迫自己读一遍训练脚本,让自己更熟悉这份代码。

适配多种任务的训练脚本

如果只是训练一种任务,Diffusers 的这种训练脚本还算好用。但如果我们想用完全相同的训练流程训练多种任务,这种脚本的弊端就暴露出来了:

  • 各任务的官方示例脚本本身就不完全统一。比如有的训练脚本支持设置随机种子,有的不支持。
  • 一旦想修改训练过程,就得同时修改所有任务的脚本。这不符合编程中「代码复用」的思想。

为此,我想重构一下官方训练脚本,将训练流程和每种任务的具体训练过程解耦开,让一份训练脚本能够被多种任务使用。于是,我又从头过了一遍训练脚本,将代码分成两类:所有任务都会用到的代码、仅 DDPM 训练会用到的代码。如下图所示,我用红字表示了训练脚本中应该由具体任务决定的部分。

根据这个划分规则,我将仅和 DDPM 相关的代码剥离出来,并用一个描述某具体任务的训练器接口类的方法调用代替原有代码。这样,每次换一个训练任务,只需要重新实现一个训练器类就行了。如下图所示,原流程图中所有红字的内容都可以由接口类的方法代替。对于不同任务,我们需要实现不同的训练器类。

具体在代码中,我写了一个接口类 Trainer

class Trainer(metaclass=ABCMeta):
    def __init__(self, weight_dtype, accelerator, logger, cfg):
        self.weight_dtype = weight_dtype
        self.accelerator = accelerator
        self.logger = logger
        self.cfg = cfg

    @abstractmethod
    def init_modules(self,
                     enable_xformer: bool = False,
                     gradient_checkpointing: bool = False):
        pass

    @abstractmethod
    def init_optimizers(self, train_batch_size):
        pass

    @abstractmethod
    def init_lr_schedulers(self, gradient_accumulation_steps, num_epochs):
        pass

    def set_dataset(self, dataset, train_dataloader):
        self.dataset = dataset
        self.train_dataloader = train_dataloader

    @abstractmethod
    def prepare_modules(self):
        pass

    @abstractmethod
    def models_to_train(self):
        pass

    @abstractmethod
    def training_step(self, global_step, batch) -> dict:
        pass

    @abstractmethod
    def validate(self, epoch, global_step):
        pass

    @abstractmethod
    def save_pipeline(self):
        pass

    @abstractmethod
    def save_model_hook(self, models, weights, output_dir):
        pass

    @abstractmethod
    def load_model_hook(self, models, input_dir):
        pass

根据类型名和初始化参数可以创建具体的训练器。

def create_trainer(type, weight_dtype, accelerator, logger, cfg_dict) -> Trainer:
    from ddpm_trainer import DDPMTrainer
    from sd_lora_trainer import LoraTrainer

    __TYPE_CLS_DICT = {
        'ddpm': DDPMTrainer,
        'lora': LoraTrainer
    }

    return __TYPE_CLS_DICT[type](weight_dtype, accelerator, logger, cfg_dict)

原来训练脚本里的具体训练逻辑被接口类方法调用代替。比如:

# old
if cfg.model_config is None:
    model = UNet2DModel(...)
else:
    config = UNet2DModel.load_config(cfg.model_config)
    model = UNet2DModel.from_config(config)

# Create EMA for the model.
if cfg.use_ema:
    ema_model = EMAModel(...)
...

# new
trainer.init_modules(enable_xformers, cfg.gradient_checkpointing)

原来仅和 DDPM 训练相关的代码全被我搬到了 DDPMTrainer 类中。与之对应,除了代码需要搬走外,原配置文件里的数据也需要搬走。我在 DDPMTrainer 类里加了一个 DDPMTrainingConfig 数据类,用来存对应的配置数据。

@dataclass
class DDPMTrainingConfig:
    # Diffuion Models
    model_config: str
    ddpm_num_steps: int = 1000
    ddpm_beta_schedule: str = 'linear'
    prediction_type: str = 'epsilon'
    ddpm_num_inference_steps: int = 100
    ...

因此,我们需要用稍微复杂一点的方式来创建配置文件。现在全局训练配置和任务配置放在两组配置里。配置文件最外层除 "base" 外的那个键表明了训练器的类型。

{
    "base": {
        "logging_dir""logs",
        "output_dir""models/ddpm_1",
        "checkpointing_steps": 5000,
        "valid_epochs": 1,
        "dataset_name""ylecun/mnist",
        "resolution": 32,
        "train_batch_size": 64,
        "num_epochs": 10
    },
    "ddpm": {
        "model_config""unet_cfg",
        "learning_rate": 1e-4,
        "valid_batch_size": 4
    }
}
__TYPE_CLS_DICT = {
    'base': BaseTrainingConfig,
    'ddpm': DDPMTrainingConfig,
    'lora': LoraTrainingConfig
}


def load_training_config(config_path: str) -> Dict[str, BaseTrainingConfig]:
    data_dict = OmegaConf.load(config_path)

    # The config must have a "base" key
    base_cfg_dict = data_dict.pop('base')

    # The config must have one another model config
    assert len(data_dict) == 1
    model_key = next(iter(data_dict))
    model_cfg_dict = data_dict[model_key]
    model_cfg_cls = __TYPE_CLS_DICT[model_key]

    return {'base': BaseTrainingConfig(**base_cfg_dict),
            model_key: model_cfg_cls(**model_cfg_dict)}

这样改完过后,训练脚本开头也需要稍作更改,其他地方保持不变。

from training_cfg_1 import BaseTrainingConfig, load_training_config
from trainer import Trainer, create_trainer

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('cfg'type=str)
    args = parser.parse_args()

    cfgs = load_training_config(args.cfg)
    cfg: BaseTrainingConfig = cfgs.pop('base')
    trainer_type = next(iter(cfgs))
    trainer_cfg_dict = cfgs[trainer_type]

    ...

    trainer: Trainer = create_trainer(
        trainer_type, weight_dtype, accelerator, cfg.logger, trainer_cfg_dict)

这次修改过的训练脚本为 train_1.py ,配置文件类在 training_cfg_1.py 里,DDPM 训练器在 TrainingScript/ddpm_trainer.py 里,示例配置文件为 cfg_1.json 。可以直接运行下面的命令测试此训练脚本。







请到「今天看啥」查看全文