FID 是一种衡量图像生成模型质量的指标。对于这种常见的指标,一般都能找到好用的 PyTorch 计算接口。然而,当我用 PyTorch 的官方库 TorchEval 来算 FID 指标时,却发现它的结果和多数非官方库无法对齐。我花了不少时间,总算把 TorchEval 的 FID 计算接口修好了。在这篇文章中,我将分享有关 FID 计算的知识以及我调试 TorchEval 的经历,并总结用 pytorch-fid, torch-fidelity, TorchEval 算 FID 的方法。文章最后,我还会分享一个偶然发现的用于反映模型训练时的当前 FID 的方法。FID 指标简介
FID 的全称是 Fréchet Inception Distance,它用于衡量两个图像分布之间的差距。如果令一个图像分布是训练集,再用生成模型输出的图像构成另一个分布,那么 FID 指标就表示了生成出来的图片和训练集整体上的相似度,也就间接反映了模型对训练集的拟合程度。FID 名字中的 Fréchet Distance 是一种描述两个样本分布的距离的指标,其定位和 KL 散度一样,但某些情况下会比 KL 散度更加合适。FID 用来算 Fréchet Distance 的样本来自预训练 InceptionV3 模型,它名称中的 Inception 由此而来。
计算 FID 的过程如下:
- 准备两个图片文件夹。一般一个是训练集,另一个存储了生成模型随机生成的图片。
- 用预训练的 InceptionV3 模型把每个输入图片转换成一个 2048 维的向量。
- 把均值、协方差代入进下面这个算 Fréchet Distance 的公式,就得到了 FID。
实际上,在用 FID 的时候我们完全不用管它的原理,只要知道它的值越小就越好,并且会调用相关接口即可。需注意的是,由于 FID 是一种和集合相关的指标,算 FID 时一定要给足图片。在构建自己模型的输出集合时,至少得有 10000 张图片,推荐生成 50000 张。否则 FID 的结果会不准确。
用 PyTorch 计算 FID 的第三方库
由于 FID 的计算需要用到一个预训练的 InceptionV3 模型,只有在模型实现完全一致的情况下,FID 的输出结果才是可比的。因此,所有论文汇报的 FID 都基于提出 FID 的作者的官方实现。这份官方实现是用 TensorFlow 写的,后来也有完全等价的 PyTorch 实现。在这一节里,我们就来学习如何用这些基于 PyTorch 的库算 FID。
GitHub 上点赞最多的 PyTorch FID 库是 pytorch-fid
。这个库被 FID 官方仓库推荐,且 Stable Diffusion 论文也用了这个库,结果绝对可靠。使用该库的方法很简单,只需要先安装它。
pip install pytorch-fid
再准备好两个用于计算 FID 的文件夹,将文件夹路径传给脚本即可。
python -m pytorch_fid path/to/dataset1 path/to/dataset2
另一个较为常见的用 PyTorch 算指标的库叫做 torch-fidelity
。它用起来和 pytorch-fid
一样简单。一开始,需要用 pip 安装它。
pip install torch-fidelity
之后,同样是准备好两个图片文件夹,将文件夹路径传给脚本。
fidelity --gpu 0 --fid --input1 path/to/dataset1 --input2 path/to/dataset2
除了命令行脚本外,torch-fidelity
还提供了 Python API。我们可以在 Python 脚本里加入算 FID 的代码。
import torch_fidelity
metrics_dict = torch_fidelity.calculate_metrics(
input1='path1',
input2='path2',
fid=True
)
print(metrics_dict)
torch-fidelity
还提供了其他便捷的功能。比如直接以某个生成模型为 API 的输入 input1
,而不是先把图像生成到一个文件夹里,再把文件夹路径传给 input1
。同时,torch-fidelity
还支持计算其他指标,我们只需要在命令行脚本或者 API 里多加几个参数就行了。
修正 TorchEval 里的 FID 计算接口
尽管这些第三方库已经足够好用了,我还是想用 PyTorch 官方近年来推出的指标计算库 TorchEval 来算 FID 指标。原因有两点:
- 我的项目其他地方都是用 PyTorch 官方库实现的 (
torch
以及 torchvision
),算指标也用官方库会让整体代码风格更加统一。我已经用 TorchEval 算了 PSNR、SSIM,使用体验还可以。 - 目前,似乎只有 TorchEval 支持在线更新指标的值。也就是说,我可以先生成一部分图片,储存算 FID 需要的中间结果;再生成一部分图片,最终计算此前所有图片与训练集的 FID。这种计算方法的好处我会在文章后面介绍。
以前我都是用 pytorch-fid 来算 FID。而当我换成用 TorchEval 后,却发现结果对不齐。于是,漫长的调试之路开始了。
当你有两块时间不一样的手表时,应该怎样确认时间呢?答案是,再找到第三块表。如果三块表中能有两块表时间一样,那么它们的时间就是正确的。一开始,我并不能确定是哪个库写错了,所以我又测试了 torch-fidelity 的结果。实验发现,torch-fidelity 和 pytorch-fid 的结果是一致的。并且我去确认了 Stable Diffusion 的论文,其中用来计算 FID 的库也是 pytorch-fid。看来,是 TorchEval 结果不对。
像 FID 这么常见的指标,大家的中间计算过程肯定都没错,就是一些细微的预处理不太一样。抱着这样的想法,我随意地比对了一下二者的代码,很快就发现 TorchEval 把输入尺寸调成 [299, 299]
了,而 pytorch-fid 没做。可删掉这段代码,程序直接报错了。我深入阅读了 pytorch-fid 的代码,发现它的写法和 TorchEval 不一样,把调整尺寸为 [299, 299]
写到了另一个地方。且通过调查发现,InceptionV3 网络的输入尺寸必须是 [299, 299]
的,是我孤陋寡闻了。唉,看来这次的调试不能太随意啊。
我准备拿出我的真实实力来调 bug。我认真整理了一下算 FID 的步骤,将其主要过程总结为以下几步:
- 用 InceptionV3 算两个数据集输出的均值、协方差
最后那个算距离的过程不涉及任何神经网络,输出该是什么就是什么。这一块是最不容易出错,且最容易调试的。于是,我决定先排除第三步是否对齐。我把 TorchEval 得到的均值、协方差存下来,用 pytorch-fid 算距离。发现结果和原 TorchEval 的输出差不多。看来算距离这一步没有问题。
接下来,我很自然地想到是不是均值和协方差算错了。我存下了两个库得到的均值、协方差,算了两个库输出之间的误差。结果发现,均值的误差在 0.09 左右,协方差的误差在 0.0002 左右。图像的数据范围在 0~1 之间,0.09 算是一个很大的误差了。可见,第一步和第二步一定存在着无法对齐的部分。
模型输出不同,最容易想到的是模型权重不同。于是,我尝试交换使用二者的模型权重,再比较输出的 FID。两个库的模型定义不太一样,不能直接换模型文件名。我用强大的代码魔改实力强行让新权重分别都跑起来了。结果非常神奇,算上之前的两个 FID,我一共得到了 4 个不一样的 FID 结果。也就是说,A 库 A 模型、B 库 B 模型、A 库 B 模型,B 库 A 模型,结果均不一样。
我被这两个库气得不行,决定认真研究对比二者的模型定义。眼尖的我发现初始化 pytorch-fid 的 InceptionV3 时有一个参数叫 use_fid_inception
。作者对此的注释写道:「如果设置为 true,则用 TensorFlow 版 FID 实现;否则,用 torchvision 版 Inception 模型。TensorFlow 的 FID Inception 模型和 torchvision 的在权重和结构上有细微的差别。如果你要计算 FID,强烈推荐将此值设置为 true,以得到和其他论文可比的结果。」总结来说,TorchEval 用的是 torchvision 里的标准 PyTorch 版 InceptionV3,而 pytorch-fid 在标准 PyTorch 版 InceptionV3 外又封装了一层,改了一些模块的定义。为什么要改这些东西呢?这是因为原来的 FID Inception 模型是在 TensorFlow 里实现的,需要改一些结构来将 PyTorch 模型对齐过去。除了模型结构外,二者的权重也有一定差别。大家都是用 TensorFlow 版模型算 FID,一切都应该以 pytorch-fid 的为准。这个 TorchEval 太离谱了,我也懒得认真修改了,直接注释掉 TorchEval 里原 FIDInceptionV3
的定义,然后大笔一挥:
from pytorch_fid.inception import \
InceptionV3 as FIDInceptionV3
按理说,这下权重和模型结构都对齐了。FID 计算的第一、第二步绝对不会有错。而开始的结果表明,FID 计算的第三步也没有错。那么,两个库就应该对齐了。我激动地又测了 TorchEval 的结果,发现结果还是无法对齐!
这不应该啊?难道哪步测错了?人生就是在不断自我怀疑中度过的。而怀疑自我,首先会怀疑最久远的自我。所以,我感觉是最早测第三步的时候有问题。之前我是把 TorchEval 的均值、协方差放到 pytorch-fid 里,结果与 TorchEval 自己的输出一致。这次我反过来,把 pytorch-fid 的均值、协方差放到 TorchEval 的算距离函数里算。这次,我第一次见到 TorchEval 输出了正确的 FID。由此可见,第三步没错。难道是均值和协方差又没对齐了?
自我怀疑开始进一步推进,我开始怀疑第二步输出的均值、协方差还是没有对齐。我再次计算了 pytorch-fid 和 TorchEval 的输出之间的误差,发现误差这次仅有 1e-16,可以认为没有区别。我花了很多时间复习协方差的计算,想找出 TorchEval 里的 bug。可是越学习,越觉得 TorchEval 写得很对。这一回,我找不到错误了。
调试代码,不怕到处有错,而怕「没错却有错」。「没错」,指的是每一步中间步骤都找不到错误;「有错」,指的是最终结果还是错了。没有错误,就得创造错误。我开启了随机乱调模式,希望能触发一个错误。回忆一下,算 FID 要用到两个数据集,一般一个是训练集,一个是模型输出的集合。在 TorchEval 最后一步算距离时,我乱改代码,让一个集合的均值、协方差不变,即来自原 TorchEval 的 Inception 模型的输出;而让另一个的集合的均值、协方差来自 pytorch-fid。理论上说,如果两个库的均值、协方差是对齐的,那么这次输出的 FID 也应该是正确的。欸,这回代码报错了,运行不了。报错说数据精度不统一。原来,TorchEval 的输出精度是 float32,而 pytorch-fid 的输出精度是 float64。之前测试距离计算函数时,数据要么全来自 TorchEval,要么全来自 pytorch-fid,所以没报过这个错。可是这个错只是一个运行上的错误,稍微改改就好了。
我把 pytorch-fid 相关数据的精度统一成了 float32。这下代码跑起来了,可 FID 不对了。调试过程中,如果上一次成功,而这一次失败,则应该想办法把代码退回上一次的,再次测试。因此,我又修改了最后用 TorchEval 计算距离的数据来源,让所有数据都来自 pytorch-fid。可是,修改后,FID 输出没变,还是错的。
为什么两轮测试之前,我全用 pytorch-fid 的输出、TorchEval 的距离计算函数没有错,这次却错了?到底是哪里不同?当测试两份差不多的代码后,一份对了,一份错了,那么错误就可以定位到两份代码的差异处。仔细回顾一下我的调试经历,相信你可以推理出 bug 出自哪了。
没错!我仔细比对了当前代码和我记忆中两轮测试前的代码,仅发现了一处不同——我把 pytorch-fid 的输出数据的精度改成了 float32。把精度改回 float64 就对了。同样,如果把 TorchEval 的输出数据的精度改成 float64,再扔进 TorchEval 的距离计算函数里算,结果也是对的。问题出在 TorchEval 的距离计算函数的数据精度上。
定位到了 bug 的位置,再找出 bug 的原因就很简单了。对比 pytorch-fid 的距离计算函数和 TorchEval 的,可以发现二者描述的计算公式完全相同。然而,pytorch-fid 是用 NumPy 算的,而 TorchEval 是用 PyTorch 算的。算 FID 的距离时,会涉及矩阵特征值等较为复杂的运算,它们对数据精度要求较高。像 NumPy 这种久经考验的库应该会自动把数据变成高精度再计算,而 PyTorch 就没做这么多细腻的处理了。
汇总一下我调试的结论。TorchEval 在权重初始化、模型计算、距离计算这三步中均有错误。前两步没有让 InceptionV3 模型和普遍使用的 TensorFlow 版对齐,最后一步没有考虑输入精度,用了不够稳定的 PyTorch API 来做复杂矩阵运算。要用 TorchEval 算出正确的 FID,需要做以下修改:
- 安装 pytorch-fid 和 TorchEval
- 打开 torcheval/metrics/image/fid.py
- 注释掉
FIDInceptionV3
类,在文件开头加上 from pytorch_fid.inception import InceptionV3 as FIDInceptionV3
- 在
FrechetInceptionDistance
类的构造函数中,在定义所有浮点数据时加上 dtype=torch.float64
这里点名批评 TorchEval。开源的时候吹得天花乱坠,结果根本没人用,这么简单的有关 FID 的 bug 也发现不了。我发了一个修正此 bug 的相关 issue https://github.com/pytorch/torcheval/issues/192
,截至目前还是没有官方人员回复。这个库的开发水平实在太逆天了,希望他们能尽快维护好。
在线计算 FID
前文提到,我用 TorchEval 的原因是它支持在线计算 FID。具体来说,可以建立一个 FID 管理类,之后用 update
方法来不断往某个集合加入新图片,并随时使用 compute
方法算出当前所有图片的 FID。我之前写代码忘了清空旧图片的中间结果时发现了一个相关应用。经我使用下来,这种应用非常有用,我们可以用它高效估计训练时的当前 FID。
回顾一下,要得到准确的 FID 值,一般需要 50000 张图片。而训练图像生成模型时,如果每次验证都要生成这么多图片,则大部分时间都会消耗在验证上了。为了加快 FID 的验证,我发现可以用一种 「全局 FID」来近似表示当前的模型拟合情况。具体来说,我先用训练集的所有图片初始化 FID 的集合 1 的中间结果,再在模型训练中每次验证时随机生成 500 张图片,将其中间结果加到 FID 的集合 2 中,并输出一次当前 FID。这样,随着训练不断推进,算 FID 的图片的数量会逐渐满足 50000 张的要求,但是这些图片并不是来自同一个模型,而是来自不同训练程度的模型。这样得到的 FID 仅能大致反映当前的真实 FID 值,有时偏高、有时偏低。但经我测试发现,这种全局 FID 的相对关系很能反映最终的真实 FID 的相对关系。训练两个不同超参的模型时,如果一个全局 FID 较大,那它最终的 FID 一般也会较大。同时,如果训练一切正常,则全局 FID 会随验证轮数单调递减(因为图片数量变多,且拟合情况不会变差)。如果某一次验证时全局 FID 增加了,则模型也一定在这段时间里变差了。通过这种验证方式,我们能够大致评估模型在训练中的拟合情况。这应该是一种很容易想到的工程技巧,但由于分享自己训练生成模型的经验帖较少,且重要性不足以写进论文,我没有在任何地方看到有人介绍这种技巧。
总结
FID 是评估图像生成模型的重要指标。通过 pytorch-fid 等库,我们能轻松地用 PyTorch 计算两个图像分布间的 FID。而通过计算输出分布和训练分布之间的 FID,我们就能评估当前模型的拟合情况。
FID 的计算本身是很简单的。所以在介绍 FID 的计算方法之外,我分享了我调试 TorchEval 的漫长过程。这段经历很有意思,我学到了不少调 bug 的新知识。此前我从来没想到过数据精度竟然会大幅影响某个值的结果。这段经历启示我们,做一些复杂运算时,不要用 PyTorch 算,最好拿 NumPy 等更稳定的库来计算。如果你调 bug 的经验不足,这段经历也能给你许多参考。
文章最后我分享了一种算全局 FID 的方法。它可以高效反映生成模型在训练时的拟合情况。该功能很容易实现,感兴趣的话可以自己尝试一下。