专栏名称: 雷克世界
赛迪研究院(CCID)主办的新媒体平台,每天跟你聊聊机器人、人工智能、智能制造领域的那些你想知道的事……
目录
相关文章推荐
51好读  ›  专栏  ›  雷克世界

Facebook发布的LR-GAN如何生成图像?这里有一篇Pytorch教程

雷克世界  · 公众号  · 机器人  · 2017-08-04 16:39

正文

原文来源GitHub

「机器人圈」编译:嗯~阿童木呀


本文主要是介绍2017年国际学习表征会议(ICLR 2017)的论文《LR-GAN:分层递归生成对抗网络进行图像生成》在pytorch上的实现。该论文由弗吉尼亚理工大学、FAIR(脸书AI实验室)、佐治亚理工学院合著。该文的第一作者Jianwei Yang近日在GitHub上发文,说明该论文如何用Pytorch实现。




在论文中,我们提出,鉴于图像本身带有结构和内容,可采用LR-GAN(分层递归生成对抗网络)以递归的方式逐层生成图像。如下图所示,LR-GAN首先生成背景图像,然后生成具有外观、姿态和形状的前景。之后,LR-GAN将以一种与情境相关的方式将前景置于相应的背景之上,从而生成一个完整的自然图像。整个模型都是无监督的,并且采用梯度下降的方法,以端到端的方式进行训练。



通过这种方式,LR-GAN便可以明显减少背景和前景之间的混合。而无论是定性或定量的比较,都表明,相较于基准DCGAN模型,LR-GAN可以产生更好、更清晰的图像。


免责声明


这是基于Pytorch的LR-GAN代码。它是以Pytorch DCGAN为基础进行开发的。我们的原始代码是在第一作者实习期间基于Torch实现的。本文所呈现的所有结果都是基于Torch代码获得的,由于版权限制不能将其发布。此项目旨在重现论文中所得到的结果。


引文


如果你发现此代码对你有所帮助,想要了解更多,请引用以下文章:


@article{yang2017lr,
    title={LR-GAN: Layered recursive generative adversarial networks for image generation},
    author={Yang, Jianwei and Kannan, Anitha and Batra, Dhruv and Parikh, Devi},
    journal={ICLR},
    year={2017}
}

实验基础


1.PyTorch,请使用正确的命令安装PyTorch,同时确保你也安装了torchvision。


2. STNM(Spatial transformer network with mask,具有掩码的空间变换神经网络)。我们已在此项目中提供了这个模块。但是如果你想对此做些改变,请参考此项目


训练LR-GAN


准备


将此项目置于你自己的机器上,并确保Pytorch已成功安装。然后,你需要创建一个保存训练集的数据集文件夹;一个保存生成结果的图像文件夹,以及一个保存模型(生成器和鉴别器)的模型文件夹:


$ mkdir datasets
$ mkdir images
$ mkdir models


接下来,你就可以尝试在数据集上训练LR-GAN模型:1)MNIST-ONE; 2)MNIST-TWO; 3)CUB-200; 4)CIFAR-10。样本图像如下所示:

 





在数据集文件夹中,分别为所有这些数据集创建子文件夹:


$ mkdir datasets/mnist-one
$ mkdir datasets/mnist-two
$ mkdir datasets/cub200
$ mkdir datasets/cifar10

训练


我们的模型是分别在MNIST-ONE、MNIST-TWO、CIFAR-10和CUB-200四个数据集上进行训练的。接下来我们将分别介绍在不同数据集上进行训练的操作和结果。


1.MNIST-ONE,我们首先在MNIST-ONE上进行实验,可以从这里下载。将其解压缩到datasets / mnist-one文件夹中,然后运行以下命令:


$ python train.py \
      --dataset mnist-one \
      --dataroot datasets/mnist-one \
      --ntimestep 2 \
      --imageSize 32 \
      --maxobjscale 1.2 \
      --niter 50 \
      --session 1

其中,ntimestep指定递归层的数量,例如,2表示一个背景层和一个前景层;imageSize是指训练图像的比例尺寸;maxobjscale是指最大目标(前景)比例,值越大,目标尺寸越小;session指定了训练会话;niter指定训练时期的数量。以下是在训练周期为50,使用训练模型的随机生成结果:


生成背景图像


前景图像


前景掩码


最终图像


2.CUB200,我们在64x64的CUB200上运行。这是处理过的数据集。同样,将其下载并解压缩到datasets / cub200。然后,运行以下命令:


$ python train.py \
      --dataset cub200 \
      --dataroot datasets/cub200 \
      --ntimestep 2 \
      --imageSize 64 \
      --ndf 128 \
      --ngf 128 \
      --maxobjscale 1.2 \
      --niter 200 \
      --session 1

基于上述指令,我们获得了与论文中所提及的相同的模型。以下是随机生成的图像: 






由此我们可以看到,此中布局类似于MNIST-ONE。正如我们所看到的那样,生成器产生鸟形掩码,从而使最终的图像更为清晰。


3.CIFAR-10,CIFAR-10可以使用pytorch dataloader自动下载。我们生成只需要使用两个时间步长。为了训练模型,需要运行:


$ python train.py \
      --dataset cifar10 \
      --dataroot datasets/cifar10 \
      --ntimestep 2 \
      --imageSize 32 \
      --maxobjscale 1.2 \
      --niter 100 \
      --session 1

以下是一些随机采样的生成结果:


生成背景图像


前景图像


前景掩码


 最终的图像


从上面几幅图片中,我们可以清楚地发现一些生成的马形,鸟形和船形掩码图像,同时,最终生成的图像也更为清晰。


4.MNIST-TWO,图像为64x64,且包含两位数字。我们使用以下命令训练模型:


$ python train.py \
      --dataset mnist-two \
      --dataroot datasets/mnist-two \
      --ntimestep 3 \
      --imageSize 64 \
      --maxobjscale 2 \
      --niter 50 \
      --session 1










可以看到,整体布局与我们论文中的布局是一样的。


5.LFW,我们训练的是64x64的图像,图像资源可以从这里下载。同样,我们需要将其解压缩到文件夹datasets / lfw。我们使用以下命令训练模型:


$ python train.py \
      --dataset lfw \
      --dataroot datasets/lfw \
      --ntimestep 2 \
      --imageSize 64 \
      --maxobjscale 1.3 \
      --niter 100 \
      --session 1

以下是生成结果:

 

真实的图像


生成背景图像


生成前景图像


掩码图像


最终图像


 测试LR-GAN


训练结束后,检查点将保存到models 中。你可以再附加两个选项(netGevaluate)到用于训练模型的命令中。以cifar10为例,其代码将如下所示:


$ python train.py \
      --dataset cifar10 \
      --dataroot datasets/cifar10 \
      --ntimestep 2 \
      --imageSize 32 \
      --maxobjscale 1.2 \
      --niter 100 \
      --session 1
      --netG models/cifar10_netG_s_1_epoch_100.pth
      --evaluate True

然后,你将在images文件夹中获得模型的会话为1、周期为100的生成结果。


有关此实验的更多理论信息,可点击链接查看完整论文

原文链接:https://github.com/jwyang/lr-gan.pytorch/blob/master/README.md


回复转载获得授权,微信搜索ROBO_AI关注公众号


欢迎加入


中国人工智能产业创新联盟在京成立 近200家成员单位共推AI发展



关注“机器人圈”后不要忘记置顶

我们还在搜狐新闻、机器人圈官网、腾讯新闻、网易新闻、一点资讯、天天快报、今日头条、QQ公众号…

↓↓↓点击阅读原文查看中国人工智能产业创新联盟手册