20

前沿科技探索家 · 2021年03月23日

基于飞桨复现 GLCLC 模型,对残次图片实现图像补全

本次复现使用的数据集是CelebA人脸数据集,这是一个大规模的人脸属性数据集,是由香港中文大学汤晓鸥教授实验室公布的大型人脸识别数据集,拥有超过20万张名人图像,已下载放置在此项目的数据集中,人脸属性有40多种。

本文项目代码github地址:

https://github.com/Eric-Hjx/P...

模型摘要

在此篇论文中,作者们提出了Globally and Locally Consistent Image Completion方法,可以使得图像的缺失部分自动补全,局部和整图保持一致。作者通过全卷积网络,可以补全图片中任何形状的缺失,为了保持补全后的图像与原图的一致性,作者使用全局(整张图片)和局部(缺失补全部分)两种鉴别器来训练。全局鉴别器查看整个图像以评估它是否作为整体是连贯的,而局部鉴别器仅查看以完成区域为中心的小区域来确保所生成的补丁的局部一致性。

接着对图像补全网络训练以欺骗两个内容鉴别器网络,这要求它生成总体以及细节上与真实无法区分的图像。我们证明了我们的方法可以用来完成各种各样的场景。此外,与PatchMatch等基于补丁的方法相比,我们的方法可以生成图像中未出现的碎片,这使我们能够自然地完成具有熟悉且高度特定的结构(如面部)的对象的图像。

1.jpg

该论文的方法,完全以卷积网络作为基础,使用了GAN网络的思路,设计了两部分(三个网络),一部分用于生成图像,即补全网络,一部分用于鉴别生成图像是否与原图像一致,即全局鉴别器和局部鉴别器。网络结构图如下所示:

1.jpg

网络介绍:

补全网络:补全网络是完全卷积的,目的是用来修复图像。
全局鉴别器:以完整的图像作为输入,识别场景的全局一致性。
局部鉴别器:只关注完成区域周围的一个小区域,以判断更详细的外观质量。

基于飞桨实现GLCLC算法

下面我们基于飞桨开源深度学习框架动手实现 GLCLC 算法,介绍神经网络代码实现内容,主要使用了卷积、反卷积、空洞卷积、正则、激活函数等方法搭建了补全网络及鉴别网络。

1. 补全网络结构

补全网络部分,作者采用12层卷积网络对输入图像进行encoding,得到一张原图16分之一大小的网格。然后再对该网格采用4层卷积网络进行decoding。为了保证生成区域尽量不模糊,文中降低分辨率的操作是使用strided convolution 的方式进行的,而且只用了两次,将图片的size 变为原来的四分之一。同时在中间层还使用了空洞卷积来增大感受野,在尽量获取更大范围内的图像信息的同时不损失额外的信息,从而得到复原图像。下表为补全网络各层参数分布情况。

输入为RGB图像与二进制掩码(需要填充的区域以1填充)的组合图像;输出为RGB图像。

1.jpg

搭建补全网络

def generator(x):

# conv1
conv1 = fluid.layers.conv2d(input=x,num_filters=64,filter_size=5,dilation=1,stride=1,padding='SAME',name='generator_conv1',data_format='NHWC')
conv1 = fluid.layers.batch_norm(conv1, momentum=0.99, epsilon=0.001)
conv1 = fluid.layers.relu(conv1, name=None)
# conv2
conv2 = fluid.layers.conv2d(input=conv1,num_filters=128,filter_size=3,dilation=1,stride=2,padding='SAME',name='generator_conv2',data_format='NHWC')
conv2 = fluid.layers.batch_norm(conv2, momentum=0.99, epsilon=0.001)
conv2 = fluid.layers.relu(conv2, name=None)
# conv3
conv3 = fluid.layers.conv2d(input=conv2,num_filters=128,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv3',data_format='NHWC')
conv3 = fluid.layers.batch_norm(conv3, momentum=0.99, epsilon=0.001)
conv3 = fluid.layers.relu(conv3, name=None)
# conv4
conv4 = fluid.layers.conv2d(input=conv3,num_filters=256,filter_size=3,dilation=1,stride=2,padding='SAME',name='generator_conv4',data_format='NHWC')
conv4 = fluid.layers.batch_norm(conv4, momentum=0.99, epsilon=0.001)
conv4 = fluid.layers.relu(conv4, name=None)
# conv5
conv5 = fluid.layers.conv2d(input=conv4,num_filters=256,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv5',data_format='NHWC')
conv5 = fluid.layers.batch_norm(conv5, momentum=0.99, epsilon=0.001)
conv5 = fluid.layers.relu(conv5, name=None)
# conv6
conv6 = fluid.layers.conv2d(input=conv5,num_filters=256,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv6',data_format='NHWC')
conv6 = fluid.layers.batch_norm(conv6, momentum=0.99, epsilon=0.001)

conv6 = fluid.layers.relu(conv6, name=None)

# 空洞卷积
# dilated1
dilated1 = fluid.layers.conv2d(input=conv6,num_filters=256,filter_size=3,dilation=2,padding='SAME',name='generator_dilated1',data_format='NHWC')
dilated1 = fluid.layers.batch_norm(dilated1, momentum=0.99, epsilon=0.001)
dilated1 = fluid.layers.relu(dilated1, name=None)
# dilated2
dilated2 = fluid.layers.conv2d(input=dilated1,num_filters=256,filter_size=3,dilation=4,padding='SAME',name='generator_dilated2',data_format='NHWC') #stride=1
dilated2 = fluid.layers.batch_norm(dilated2, momentum=0.99, epsilon=0.001)
dilated2 = fluid.layers.relu(dilated2, name=None)
# dilated3
dilated3 = fluid.layers.conv2d(input=dilated2,num_filters=256,filter_size=3,dilation=8,padding='SAME',name='generator_dilated3',data_format='NHWC')
dilated3 = fluid.layers.batch_norm(dilated3, momentum=0.99, epsilon=0.001)
dilated3 = fluid.layers.relu(dilated3, name=None)
# dilated4
dilated4 = fluid.layers.conv2d(input=dilated3,num_filters=256,filter_size=3,dilation=16,padding='SAME',name='generator_dilated4',data_format='NHWC')
dilated4 = fluid.layers.batch_norm(dilated4, momentum=0.99, epsilon=0.001)
dilated4 = fluid.layers.relu(dilated4, name=None)
# conv7
conv7 = fluid.layers.conv2d(input=dilated4,num_filters=256,filter_size=3,dilation=1,name='generator_conv7',data_format='NHWC')
conv7 = fluid.layers.batch_norm(conv7, momentum=0.99, epsilon=0.001)
conv7 = fluid.layers.relu(conv7, name=None)
# conv8
conv8 = fluid.layers.conv2d(input=conv7,num_filters=256,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv8',data_format='NHWC')
conv8 = fluid.layers.batch_norm(conv8, momentum=0.99, epsilon=0.001)
conv8 = fluid.layers.relu(conv8, name=None)
# deconv1
deconv1 = fluid.layers.conv2d_transpose(input=conv8, num_filters=128, output_size=[64,64],stride = 2,name='generator_deconv1',data_format='NHWC')

deconv1 = fluid.layers.batch_norm(deconv1, momentum=0.99, epsilon=0.001)

deconv1 = fluid.layers.relu(deconv1, name=None)
# conv9
conv9 = fluid.layers.conv2d(input=deconv1,num_filters=128,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv9',data_format='NHWC')
conv9 = fluid.layers.batch_norm(conv9, momentum=0.99, epsilon=0.001)
conv9 = fluid.layers.relu(conv9, name=None)
# deconv2
deconv2 = fluid.layers.conv2d_transpose(input=conv9, num_filters=64, output_size=[128,128],stride = 2,name='generator_deconv2',data_format='NHWC')
deconv2 = fluid.layers.batch_norm(deconv2, momentum=0.99, epsilon=0.001)
deconv2 = fluid.layers.relu(deconv2, name=None)
# conv10
conv10 = fluid.layers.conv2d(input=deconv2,num_filters=32,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv10',data_format='NHWC')
conv10 = fluid.layers.batch_norm(conv10, momentum=0.99, epsilon=0.001)
conv10 = fluid.layers.relu(conv10, name=None)
# conv11
x = fluid.layers.conv2d(input=conv10,num_filters=3,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv11',data_format='NHWC')
x = fluid.layers.tanh(x)
return x
2. 内容鉴别器

内容鉴别器分为了两个部分,一个全局鉴别器(Global Discriminator)以及一个局部鉴别器(Local Discriminator)。全局鉴别器是将一张完整的图像作为输入数据,对图像的全局一致性做出判断;局部鉴别器仅在以填充区域为中心的原图像四分之一大小区域上观测,对此部分图像的一致性做出判断。通过采用上述两个不同的鉴别器,可以使得最终的网络,不但可以对图像全局一致性做判断,并且能够通过局部鉴别方法,优化生成图的细节,最终能产生更好的图片填充效果。

在原文中,作者设定的全局鉴别网络输入是256X256X3的图片,局部网络输入是128X128X3的图片。原始论文中,全局网络和局部网络都会通过使用5X5的卷积层、2X2的stride降低图像分辨率,通过全连接,分别得到一个1024维的向量。然后,作者将全局和局部两个鉴别器的输出连接成一个2048维向量,再通过一个全连接,然后用sigmoid函数对整体的图像的一致性进行打分判别。但在本次实验,为了能降低训练难度,设定全局鉴别网络输入是128X128X3的图片,局部网络输入是64X64X3的图片。

1.jpg

搭建内容鉴别器

   def discriminator(global_x, local_x):
def global_discriminator(x):
    # conv1
    conv1 = fluid.layers.conv2d(input=x,num_filters=64,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv1',data_format='NHWC')
    conv1 = fluid.layers.batch_norm(conv1, momentum=0.99, epsilon=0.001)
    conv1 = fluid.layers.relu(conv1, name=None)
    # conv2
    conv2 = fluid.layers.conv2d(input=conv1,num_filters=128,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv2',data_format='NHWC')
    conv2 = fluid.layers.batch_norm(conv2, momentum=0.99, epsilon=0.001)
    conv2 = fluid.layers.relu(conv2, name=None)
    # conv3
    conv3 = fluid.layers.conv2d(input=conv2,num_filters=256,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv3',data_format='NHWC')
    conv3 = fluid.layers.batch_norm(conv3, momentum=0.99, epsilon=0.001)
    conv3 = fluid.layers.relu(conv3, name=None)
    # conv4
    conv4 = fluid.layers.conv2d(input=conv3,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv4',data_format='NHWC')
    conv4 = fluid.layers.batch_norm(conv4, momentum=0.99, epsilon=0.001)
    conv4 = fluid.layers.relu(conv4, name=None)
    # conv5
    conv5 = fluid.layers.conv2d(input=conv4,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv5',data_format='NHWC')
    conv5 = fluid.layers.batch_norm(conv5, momentum=0.99, epsilon=0.001)
    conv5 = fluid.layers.relu(conv5, name=None)
    # conv6
    conv6 = fluid.layers.conv2d(input=conv5,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv6',data_format='NHWC')
    conv6 = fluid.layers.batch_norm(conv6, momentum=0.99, epsilon=0.001)

conv6 = fluid.layers.relu(conv6, name=None)

    # fc
    x = fluid.layers.fc(input=conv6, size=1024,name='discriminator_global_fc1')
    return x

def local_discriminator(x):
    # conv1
  conv1  =fluid.layers.conv2d(input=x,num_filters=64,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv1',data_format='NHWC')
    conv1 = fluid.layers.batch_norm(conv1, momentum=0.99, epsilon=0.001)
    conv1 = fluid.layers.relu(conv1, name=None)
    # conv2
    conv2 = fluid.layers.conv2d(input=conv1,num_filters=128,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv2',data_format='NHWC')
    conv2 = fluid.layers.batch_norm(conv2, momentum=0.99, epsilon=0.001)
    conv2 = fluid.layers.relu(conv2, name=None)
    # conv3
    conv3 = fluid.layers.conv2d(input=conv2,num_filters=256,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv3',data_format='NHWC')
    conv3 = fluid.layers.batch_norm(conv3, momentum=0.99, epsilon=0.001)
    conv3 = fluid.layers.relu(conv3, name=None)
    # conv4
    conv4 = fluid.layers.conv2d(input=conv3,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv4',data_format='NHWC')
    conv4 = fluid.layers.batch_norm(conv4, momentum=0.99, epsilon=0.001)
    conv4 = fluid.layers.relu(conv4, name=None)
    # conv5
    conv5 = fluid.layers.conv2d(input=conv4,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv5',data_format='NHWC')
    conv5 = fluid.layers.batch_norm(conv5, momentum=0.99, epsilon=0.001)
    conv5 = fluid.layers.relu(conv5, name=None)
    # fc
    x = fluid.layers.fc(input=conv5, size=1024,name='discriminator_lobal_fc1')
    return x

global_output = global_discriminator(global_x)
local_output = local_discriminator(local_x)
print('global_output',global_output.shape)
print('local_output',local_output.shape)
output = fluid.layers.concat([global_output, local_output], axis=1)
output = fluid.layers.fc(output, size=1,name='discriminator_concatenation_fc1')

return output


3. 损失函数

生成网络使用weighted Mean Squared Error (MSE)作为损失函数,计算原图与生成图像像素之间的差异,表达式如下所示:

1.png

鉴别器网络使用GAN损失函数,其目标是最大化生成图像和原始图像的相似概率,表达式如下所示:
1.png

最后结合两者损失,形成下式:
1.png

网络训练

原文作者使用4个K80 GPU,使用的输入图像大小是256256,训练了2个月才训练完成。

本项目为了缩短训练时间,仅采用了此论文核心思想、网络结构、优化目标等,并对训练方式及部分细节做了简化。使用的输入图像大小:128128,训练方式设定为:先训练生成器再将生成器和判别器一起训练。

生成器优先迭代次数

NUM_TRAIN_TIMES_OF_DG = 100

总迭代轮次

epoch = 200

step_num = int(len(x_train) / BATCH_SIZE)

np.random.shuffle(x_train)

for pass_id in range(epoch):

训练生成器
if pass_id <= NUM_TRAIN_TIMES_OF_DG:
    g_loss_value = 0
    for i in tqdm.tqdm(range(step_num)):
        x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        points_batch, mask_batch = get_points()
        # print(x_batch.shape)
        # print(mask_batch.shape)
        dg_loss_n = exe.run(dg_program,
                             feed={'x': x_batch, 
                                    'mask':mask_batch,},
                             fetch_list=[dg_loss])[0]
        g_loss_value += dg_loss_n
    print('Pass_id:{}, Completion loss: {}'.format(pass_id, g_loss_value))

    np.random.shuffle(x_test)
    x_batch = x_test[:BATCH_SIZE]

    completion_n = exe.run(dg_program, 
                    feed={'x': x_batch, 
                           'mask': mask_batch,},
                    fetch_list=[completion])[0][0]
修复图片
    sample = np.array((completion_n + 1) * 127.5, dtype=np.uint8)
    # 原图
    x_im = np.array((x_batch[0] + 1) * 127.5, dtype=np.uint8)
    # 挖空洞输入图
    input_im_data = x_im * (1 - mask_batch[0])
    input_im = np.array(input_im_data + np.ones_like(x_im) * mask_batch[0] * 255, dtype=np.uint8)
    output_im = np.concatenate((x_im,input_im,sample),axis=1)
    #print(output_im.shape)
    cv2.imwrite('./output/pass_id:{}.jpg'.format(pass_id), cv2.cvtColor(output_im, cv2.COLOR_RGB2BGR))
    # 保存模型
    save_pretrain_model_path = 'models/'
    # 创建保持模型文件目录
    #os.makedirs(save_pretrain_model_path)
    fluid.io.save_params(executor=exe, dirname=save_pretrain_model_path, main_program=dg_program)
生成器判断器一起训练
else:
    g_loss_value = 0
    d_loss_value = 0
    for i in tqdm.tqdm(range(step_num)):
        x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        points_batch, mask_batch = get_points()
        dg_loss_n = exe.run(dg_program,
                             feed={'x': x_batch, 
                                    'mask':mask_batch,},
                             fetch_list=[dg_loss])[0]
        g_loss_value += dg_loss_n

        completion_n = exe.run(dg_program, 
                            feed={'x': x_batch, 
                                    'mask': mask_batch,},
                            fetch_list=[completion])[0]
        local_x_batch = []
        local_completion_batch = []
        for i in range(BATCH_SIZE):
            x1, y1, x2, y2 = points_batch[i]
            local_x_batch.append(x_batch[i][y1:y2, x1:x2, :])
            local_completion_batch.append(completion_n[i][y1:y2, x1:x2, :])
        local_x_batch = np.array(local_x_batch)
        local_completion_batch = np.array(local_completion_batch)
        d_loss_n  = exe.run(d_program,
                            feed={'x': x_batch, 'mask': mask_batch, 'local_x': local_x_batch, 'global_completion': completion_n, 'local_completion': local_completion_batch},
                            fetch_list=[d_loss])[0]
        d_loss_value += d_loss_n
    print('Pass_id:{}, Completion loss: {}'.format(pass_id, g_loss_value))
    print('Pass_id:{}, Discriminator loss: {}'.format(pass_id, d_loss_value))

    np.random.shuffle(x_test)
    x_batch = x_test[:BATCH_SIZE]
    completion_n = exe.run(dg_program, 
                    feed={'x': x_batch, 
                            'mask': mask_batch,},
                    fetch_list=[completion])[0][0]
    # 修复图片
    sample = np.array((completion_n + 1) * 127.5, dtype=np.uint8)
    # 原图
    x_im = np.array((x_batch[0] + 1) * 127.5, dtype=np.uint8)
    # 挖空洞输入图
    input_im_data = x_im * (1 - mask_batch[0])
    input_im = np.array(input_im_data + np.ones_like(x_im) * mask_batch[0] * 255, dtype=np.uint8)
    output_im = np.concatenate((x_im,input_im,sample),axis=1)
    #print(output_im.shape)
    cv2.imwrite('./output/pass_id:{}.jpg'.format(pass_id), cv2.cvtColor(output_im, cv2.COLOR_RGB2BGR))
    # 保存模型
    save_pretrain_model_path = 'models/'
    # 创建保持模型文件目录
    #os.makedirs(save_pretrain_model_path)
    fluid.io.save_params(executor=exe, dirname=save_pretrain_model_path, main_program = dg_program)
    
    
结果展示

1.jpg

项目总结

整个训练过程,花了9小时左右,共训练了100次补全网络+45次补全网络和鉴别网络。

Image Completion Result 中的 Input 是挖洞后输入补全网络的图像,在 Output 看到, Input 图像上挖的洞已经被补上了,这说明现在的训练结果已经能在一定程度上补全图像的缺失部分了。由于本项目实现时在硬件及时间方面受限,因此对原文中的方法进行了简化,训练方法和数据样本处理较原论文有所调整做了调整,无法达到原论文效果,但相较于原作者两个月的训练时间对比,这样的训练方式也是可取的。

如想到达到原论文的精准的小伙伴,可以在本项目基础上修改训练策略~在此附上原论文训练程序图
1.jpg

本项目使用了飞桨开源深度学习框架,在AI Studio上完成了数据处理、模型训练、效果预测等整个工作过程,非常感谢AI Studio给我们提供的GPU在线训练环境,对于在深度学习道路上硬件条件上不足的学生来说简直是非常大的帮助。

如果你对这个小实验感兴趣,也可以自己来尝试一下,整个项目包括数据集与相关代码已公开在AI Studio上,欢迎小伙伴们Fork。

https://aistudio.baidu.com/ai...

如在使用过程中有问题,可加入飞桨官方QQ群进行交流:1108045677。

如果您想详细了解更多飞桨的相关内容,请参阅以下文档。

**·飞桨开源框架项目地址·

GitHub: https://github.com/PaddlePadd... Gitee: https://gitee.com/paddlepaddl...

·飞桨官网地址·

https://www.paddlepaddle.org.cn/**

推荐阅读
关注数
12971
内容数
325
带你捕获最前沿的科技信息,了解最新鲜的科技资讯
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息