Universal Image Restoration Pre-training via Degradation Classification
论文地址:
https://openreview.net/forum?id=PacBhLzeGO
代码地址:
https://github.com/MILab-PKU/dcpt
背景
图像复原是利用模型将低质量(LQ)图像改进为高质量(HQ)图像的任务,在深度学习时代,图像复原任务可以被进一步理解为:
以低质量图像为条件生成高质量图像
。
通用图像复原(Universal Image Restoration, UIR)任务是图像复原的一项重要的子任务。UIR 试图创造一种方法,使得模型能够自主的应对不同退化,并生成语义、细节纹理一致的高质量图像。可以简单地认为,一个合格的UIR模型应当包含以下两种能力:
-
退化判别:用于提升模型对输入低质量图像的退化的鉴别能力,使得模型能够“自如”地使用自身参数进行自适应复原(这种解释的正确性有待商榷,但已经有大量文献证明退化判别能力的引入有助于图像复原性能增长)
-
生成高质量图像:生成prior将有助于复原能力的提升,尤其在输入图像退化极其严重的情况下。在干净、高质量图像数据集下训练的生成模型,能够促进复原模型恢复出干净、高质量的图像。
这导向了两种不同的通用图像复原方法设计思路:(1)促进退化判别;(2)引入生成Prior。其中前者已经被得到广泛的研究。流行的方法使用输入图像的退化表征作为判别提示,如:梯度、频率、附加参数和经神经网络压缩的抽象特征等等。虽然这些方法通过使用精确有效的退化提示获得了很高的复原性能,但它们
未能利用复原模型本身所蕴含的潜在先验信息
。
DCPT的诞生来源于对复原模型自我退化判别能力的分析。
发现
我们对复原模型自身的退化判别能力进行了分析,并得到三个有趣的发现:
-
-
在一体化(All-in-one)复原任务中训练的模型表现出辨别未知退化的能力;
-
我们进行了一项简单的预实验来说明这三点:我们提取了复原训练过程中网络复原头之前的输出特征,训练过程中,模型仅见到雾霾、雨、高斯噪声三种退化。根据该特征, kNN 分类器将对五种退化类型(包括雾霾、雨天、高斯噪声、运动模糊和弱光)进行分类。
预实验结果如下:
|
|
|
|
|
Acc. on Random initialized (%)
|
|
|
|
|
Acc. on 3D all-in-one trained 200k iterations (%)
|
|
|
|
|
Acc. on 3D all-in-one trained 400k iterations (%)
|
|
|
|
|
Acc. on 3D all-in-one trained 600k iterations (%)
|
|
|
|
|
可以看到四种网络在网络初始化时就表现出52%~71%的分类准确率,且在复原训练过程早期(前200k次迭代)快速收敛到90%以上的分类准确率。
-
当退化数量进一步增多,…
遗憾的是,我们发现复原模型对未知退化的辨别能力会随着退化种类的增多而逐渐减弱。我们将在后续工作中对此进行更充分的讨论。
动机
由于图像复原的核心任务还是以低质量图像为条件生成高质量图像,我们不希望在复原训练过程中出现与该任务存在潜在冲突的其他训练子任务,例如退化分类。于是,我们选择将显式地将该训练阶段提前为“预训练”,并进一步创造了DCPT。
方法
Degradation Classification Pre-Training
(DCPT) 是一个简单且有效的方法,可见下图。
在单次迭代中,它包含两个阶段:退化分类阶段、生成阶段,这两个阶段交替进行。其中,
-
退化分类阶段:通过提取复原网络的深层特征,并将其输入一个轻量级分类器,以对输入图像的退化种类进行分类。
-
生成阶段:我们利用最原始的Autoencoder手段对复原模型的生成能力进行保留。
实现代码也非常简洁:
### train to generate the clean image
encoder.train()
decoder.eval()
optimizer_encoder.zero_grad()
pix_output = encoder(gt, hook=False)
l_total = 0
# pixel loss
if cri_pixel:
l_pix = cri_pixel(pix_output, gt)
l_total += l_pix
### train to classify the degradation
decoder.train()
optimizer_decoder.zero_grad()
hook_outputs = encoder(lq, hook=True)
cls_output = decoder(lq, hook_outputs[::-1])
# classification loss
if cri_cls:
l_cls = cri_cls(cls_output, dataset_idx)
l_total += l_cls
l_total.backward()
optimizer_encoder.step()
optimizer_decoder.step()
需要注意,在预训练结束后,仍需要进行复原任务上的fine-tune。
实验结果
5D All-in-one image restoration
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DCPT-SwinIR
|
28.67
|
35.70
|
31.16
|
26.42
|
20.38
|
28.47
|
|
|
|
|
|
|
|