我们在 TF-GAN 上的暑期代码项目
2022 年 1 月 10 日

作者:Nived P AMargaret Maynard-ReidJoel Shor

Google 暑期代码项目 是一个每年夏季将学生开发者引入开源项目的计划。本文介绍了 Amrita 工程学院本科生 Nived PA 去年夏天提出的,对 TensorFlow GAN 库 (TF-GAN) 进行的增强。Nived 项目的目标是通过添加新的教程和为库本身添加新功能来改进 TF-GAN 库。

本文概述了 TF-GAN 以及我们去年夏天的成就。我们将从学生和导师的角度分享我们的经验,并逐步介绍 Nived 创建的新的教程之一,一个 ESRGAN TensorFlow 实现,并向您展示如何轻松使用 TF-GAN 来帮助训练和评估。

什么是 TF-GAN?

TF-GAN 为训练 GAN 提供了常见的构建块和基础设施支持,并提供易于使用且标准的评估技术。使用 TF-GAN 可以帮助开发人员和研究人员节省使用常见 GAN 工具的时间,并避免实现中的常见陷阱。此外,TF-GAN 还提供了一系列著名的示例,其中包括来自图像和音频领域的 GAN,以及 GPU 和 TPU 支持。

2017 年推出 以来,该团队已 更新了基础设施 以与 TensorFlow 2.0 协同工作,并发布了 自学 GAN 课程,该课程在 2020 年吸引了超过 15 万人的观看,以及关于 GAN 的 ML 技术讲座。该项目本身已被下载了 数百万次。使用 TF-GAN 的论文已获得数千次引用(例如 12345)。

TF-GAN 库可以分为多个独立的部分,即 核心功能损失评估示例。所有这些不同的部分都可以用来简化 GAN 的训练或评估过程。

项目范围

TF-GAN 上的 Google 暑期代码项目 2021 旨在将更多最新的 GAN 模型作为示例添加到库中,并额外添加更多教程笔记本,这些笔记本在训练和评估最先进的 GAN 模型(如 ESRGAN)的同时探索 TF-GAN 的各种功能。通过这个项目,新的损失函数也被添加到库中,可以改进 GAN 的训练过程。接下来,我们将逐步介绍 ESRGAN 代码,并演示如何使用 TF-GAN 来帮助训练和评估。

如果您不熟悉 GAN,那么一个好的起点是阅读这篇由 Margaret(该项目的导师)撰写的 GAN 简介 文章,这些 GAN 教程 在 tensorflow.org 上,以及上面提到的机器学习速成课程上的自学 GAN 课程

使用 TF-GAN 的 ESRGAN

图像超分辨率是 GAN 的一个重要用例。超分辨率是从给定的低分辨率 (LR) 图像重建高分辨率 (HR) 图像的过程。超分辨率可以应用于解决诸如照片编辑之类的现实世界问题。

SRGAN 论文(使用生成对抗网络实现逼真的单图像超分辨率)引入了单图像超分辨率的概念,并使用残差块和感知损失来实现。ESRGAN(增强型超分辨率生成对抗网络)论文通过引入没有批归一化的残差-残差密集块 (RRDB) 作为基本构建块,使用相对损失并改进感知损失,对 SRGAN 进行了增强。

现在让我们逐步介绍如何使用 TensorFlow 2 实现 ESRGAN 并使用 TF-GAN 评估其性能。Colab 笔记本有两个版本:一个使用 GPU,另一个 使用 TPU。我们将介绍 Colab 笔记本 TPU 版本。

先决条件

首先,让我们确保我们已准备好使用 Colab TPU 和 Google Cloud Storage 存储桶。

  1. Colab TPU
  2. 要在 Colab 中启用 TPU 运行时,请转到编辑 → 笔记本设置或运行时 → 更改运行时类型,然后从硬件加速器下拉菜单中选择“TPU”。

  3. Google Cloud Storage 存储桶

为了使用 TPU 进行训练,我们需要首先设置一个 Google Cloud Storage 存储桶,以在训练期间存储数据集和模型权重。请参阅 Google Cloud 文档,了解有关 创建存储桶 的内容。创建存储桶后,让我们从 Colab 进行身份验证,以便您可以授予 Google Cloud SDK 访问存储桶的权限

bucket = 'enter-your-bucket-name-here'
tpu_address = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

from google.colab import auth
auth.authenticate_user()

tf.config.experimental_connect_to_host(tpu_address)
tensorflow_gcs_config.configure_gcs_from_colab_auth()

您将被提示在浏览器中关注一个链接,以验证与存储桶的连接。单击该链接将带您进入一个新的浏览器选项卡。按照那里的说明获取验证代码,然后返回 Colab 笔记本输入代码。现在您应该能够在笔记本的剩余部分访问存储桶。

训练参数

现在我们已经为 Colab 启用了 TPU 并设置了 GCS 云存储桶以存储训练数据和模型权重,我们首先定义一些将从数据加载到模型训练中使用的参数,例如批次大小、HR 图像分辨率以及将图像缩放到 LR 的比例等。

Params = {
   'batch_size' : 32,    # Number of image samples used in each training step         
   'hr_dimension' : 256,          # Dimension of a High Resolution (HR) Image
   'scale' : 4, # Factor by which Low Resolution (LR) Images to be downscaled.
   'data_name': 'div2k/bicubic_x4',       # Dataset name - loaded using tfds.
   'trunk_size' : 11,            # Number of Residual blocks used in Generator
   ...
}

数据

我们使用 DIV2K 数据集:多样化的 2k 分辨率高质量图像。我们将使用 TensorFlow Datasets (tfds) API 将数据加载到我们的云存储桶中。

我们需要高分辨率 (HR) 和低分辨率 (LR) 数据进行训练。因此,我们将下载原始图像并将它们缩放到 96x96 用于 HR,并将它们缩放到 28x28 用于 LR。

注意:数据下载和缩放以存储在云存储桶中可能需要 30 多分钟。

可视化数据集

让我们可视化下载并缩放的数据集

img_lr, img_hr = next(iter(train_ds))

lr = Image.fromarray(np.array(img_lr)[0].astype(np.uint8))
lr = lr.resize([256, 256])
display(lr)

hr = Image.fromarray(np.array(img_hr)[0].astype(np.uint8))
hr = hr.resize([256, 256])
display(hr)

模型架构

我们将首先定义生成器架构、鉴别器架构和损失函数;然后将所有内容组合在一起形成 ESRGAN 模型。

生成器 - 与大多数 GAN 生成器一样,ESRGAN 生成器会对输入进行几次上采样。与众不同的是没有批归一化的残差-残差块 (RRDB)。

在生成器中,我们定义了用于创建 Conv 块、Dense 块、RRDB 块用于上采样的函数。然后,我们定义了一个函数来创建生成器网络,如下所示,使用 Keras Functional API

def generator_network(filter=32,
                     trunk_size=Params['trunk_size'],
                     out_channels=3):
 lr_input = layers.Input(shape=(None, None, 3))
  
 x = layers.Conv2D(filter, kernel_size=[3,3], strides=[1,1],
                   padding='same', use_bias=True)(lr_input)
 x = layers.LeakyReLU(0.2)(x)
  ref = x
 for i in range(trunk_size):
     x = rrdb(x)

 x = layers.Conv2D(filter, kernel_size=[3,3], strides=[1,1],
                   padding='same', use_bias = True)(x)
 x = layers.Add()([x, ref])

 x = upsample(x, filter)
 x = upsample(x, filter)
  x = layers.Conv2D(filter, kernel_size=3, strides=1,
                   padding='same', use_bias=True)(x)
 x = layers.LeakyReLU(0.2)(x)
 hr_output = layers.Conv2D(out_channels, kernel_size=3, strides=1,
                           padding='same', use_bias=True)(x)

 model = tf.keras.models.Model(inputs=lr_input, outputs=hr_output)
 return model

鉴别器

鉴别器是一个相当简单的 CNN,具有 Conv2DBatchNormalizationLeakyReLUDense 层。同样,使用 Keras Functional API。

def discriminator_network(filters = 64, training=True):
 img = layers.Input(shape = (Params['hr_dimension'], Params['hr_dimension'], 3))
  
 x = layers.Conv2D(filters, [3,3], 1, padding='same', use_bias=False)(img)
 x = layers.BatchNormalization()(x)
 x = layers.LeakyReLU(alpha=0.2)(x)
 
 x = layers.Conv2D(filters, [3,3], 2, padding='same', use_bias=False)(x)
 x = layers.BatchNormalization()(x)
 x = layers.LeakyReLU(alpha=0.2)(x)
 
 x = _conv_block_d(x, filters *2)
 x = _conv_block_d(x, filters *4)
 x = _conv_block_d(x, filters *8)
  x = layers.Flatten()(x)
 x = layers.Dense(100)(x)
 x = layers.LeakyReLU(alpha=0.2)(x)
 x = layers.Dense(1)(x)
 
 model = tf.keras.models.Model(inputs = img, outputs = x)
 return model

损失函数

ESRGAN 模型使用三个损失函数来确保视觉质量和指标(如峰值信噪比 (PSNR))之间的平衡,并鼓励生成器生成具有自然纹理的更逼真的图像

  1. 像素损失 - 生成图像与真实图像之间的像素损失。
  2. 对抗性损失(使用 相对 GAN) - 为 G 和 D 计算。
  3. 感知损失 - 使用预先训练的 VGG-19 网络计算。

让我们深入了解这里的对抗性损失 ,因为这是最复杂的损失,并且它是在该项目中添加到 TF-GAN 库中的一个函数。

在 GAN 中,鉴别器网络将输入数据分类为真实或虚假。生成器经过训练,以生成虚假数据并欺骗鉴别器错误地将其分类为真实数据。随着生成器提高虚假数据为真实数据的概率,真实数据为真实数据的概率也应该下降。正如这篇论文 中指出的,这是标准 GAN 缺失的一个属性,并且引入了相对鉴别器来克服这个问题。相对平均鉴别器估计给定真实数据比虚假数据更逼真的概率(平均而言)。这提高了生成数据的质量和模型在训练过程中的稳定性。在 TF-GAN 库中,请参阅relativistic_generator_loss 和 relativistic_discriminator_loss,了解此损失函数的实现。

def ragan_generator_loss(d_real, d_fake):
 real_logits = d_real - tf.reduce_mean(d_fake)
 fake_logits = d_fake - tf.reduce_mean(d_real)
  real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
     labels=tf.zeros_like(real_logits), logits=real_logits)) 
 fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
     labels=tf.ones_like(fake_logits), logits=fake_logits))

 return real_loss + fake_loss

def ragan_discriminator_loss(d_real, d_fake):
 def get_logits(x, y):
   return x - tf.reduce_mean(y)
  real_logits = get_logits(d_real, d_fake)
 fake_logits = get_logits(d_fake, d_real)

 real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
         labels=tf.ones_like(real_logits), logits=real_logits))
 fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
     labels=tf.zeros_like(fake_logits), logits=fake_logits))

 return real_loss + fake_loss

训练

ESRGAN 模型分两个阶段进行训练

  • 阶段 1:单独训练生成器网络,旨在通过减少 L1 损失来提高生成图像的 PSNR 值。
  • 阶段 2:继续训练相同的生成器模型以及鉴别器网络。在第二阶段,生成器减少了 L1 损失、相对平均 GAN (RaGAN) 损失(指示生成图像看起来有多逼真)以及论文中提出的改进的感知损失。

如果从头开始,阶段 1 训练可以在免费的 colab TPU 上在一个小时内完成,而阶段 2 则需要大约 2-3 个小时才能获得良好的结果。因此,在训练过程中保存权重/检查点是重要的步骤。

阶段 1 训练

以下是阶段 1 训练的步骤

  • 定义生成器及其优化器
  • 从训练数据集中获取 LR、HR 图像对
  • 将 LR 图像输入到生成器网络
  • 使用生成的图像和 HR 图像计算 L1 损失
  • 计算梯度值并将其应用于优化器
  • 为了获得更好的性能,在每个衰减步骤之后更新优化器的学习率

第二阶段训练

在这个训练阶段

  • 加载在第一阶段训练的生成器网络
  • 定义训练过程中可能会有用的检查点
  • 使用 VGG-19 预训练网络计算感知损失

然后我们定义训练步骤如下

  • 将 LR 图像输入到生成器网络
  • 计算生成器和鉴别器的 L1 损失、感知损失和对抗损失。
  • 使用获得的梯度值更新两个网络的优化器
  • 为了获得更好的性能,在每个衰减步骤之后更新优化器的学习率
  • TF-GAN 的图像网格函数用于在验证步骤中显示生成的图像

有关完整代码实现,请参阅 Colab 笔记本

在训练过程中,我们可视化 3 张图像:LR 图像、HR 图像(生成)和 HR 图像(训练数据),以及这些指标:生成器损失、鉴别器损失和 PSNR。

步骤 0

生成器损失 = 0.636057436466217

鉴别器损失 = 0.0191921629011631

PSNR:20.95576286315918

以下是训练结束时的更多结果,看起来还不错。

评估

现在训练已经完成,我们将使用 3 个指标来评估 ESRGAN 模型:Fréchet Inception Distance (FID)、Inception Scores 和 Peak signal-to-noise ratio (PSNR)。

FIDInception Scores 是用于评估 GAN 模型性能的两个常用指标。Peak Signal-to- Noise Ratio (PSNR) 用于量化两幅图像之间的相似性,并用于对超分辨率模型进行基准测试。

我们没有从头开始编写代码来计算每个指标,而是使用 TF-GAN 库来轻松地评估我们的 GAN 实现,以获得 FID 和 Inception Scores。然后,我们利用 `tf.image` 模块来计算 PSNR 值,用于评估超分辨率算法。

为什么我们需要 TF-GAN 库进行评估?

GAN 的标准评估指标,例如 Inception Scores、Frechet Distance 或 Kernel Distance,在 TF-GAN 评估 中可用。此类指标的各种实现可能容易出错,这会导致不可靠的评估分数。通过使用 TF-GAN,可以避免此类错误,并简化 GAN 评估。为了评估 ESRGAN 模型,我们使用了来自 TF-GAN 库的 Inception Score (tfgan.eval.inception_score) 和 Frechet Distance Score (tfgan.eval.frechet_inception_distance)。


以下是我们在代码中如何使用 tf-gan 进行评估。

首先,我们需要安装 tf-gan 库,它应该是笔记本开头导入的一部分。然后,我们导入库。

!pip install tensorflow-gan
import tensorflow_gan as tfgan

现在,我们已准备好使用该库进行 ESRGAN 评估!

Fréchet inception distance (FID)

@tf.function
def get_fid_score(real_image, gen_image):
 size = tfgan.eval.INCEPTION_DEFAULT_IMAGE_SIZE

 resized_real_images = tf.image.resize(real_image, [size, size], method=tf.image.ResizeMethod.BILINEAR)
 resized_generated_images = tf.image.resize(gen_image, [size, size], method=tf.image.ResizeMethod.BILINEAR)
  num_inception_images = 1
 num_batches = Params['batch_size'] // num_inception_images
  fid = tfgan.eval.frechet_inception_distance(resized_real_images, resized_generated_images, num_batches=num_batches)
 return fid
Inception Scores
@tf.function
def get_inception_score(images, gen, num_inception_images = 8):
 size = tfgan.eval.INCEPTION_DEFAULT_IMAGE_SIZE
 resized_images = tf.image.resize(images, [size, size], method=tf.image.ResizeMethod.BILINEAR)

 num_batches = Params['batch_size'] // num_inception_images
 inc_score = tfgan.eval.inception_score(resized_images, num_batches=num_batches)

 return inc_score
Peak Signal-to- Noise Ratio (PSNR)
def get_psnr(real, generated):
  psnr_value = tf.reduce_mean(tf.image.psnr(generated, real, max_val=256.0))
  return psnr_value

GSoC 体验

以下是我们用自己的话说出的 Google Summer of Code 2021 体验

Nived

作为一名学生,Google Summer of Code 为我提供了一个参与 TensorFlow 令人兴奋的开源项目的機會,我在这段期间得到的指导非常宝贵。我学到了很多关于实施各种 GAN 模型、编写教程笔记本、使用 Cloud TPU 训练模型以及使用 Google Cloud Platform 等工具的知识。我在整个项目过程中得到了 Margaret 和 Joel 的大力支持,这使项目得以顺利进行。从一开始,他们的建议就帮助定义了项目的范围,在编码阶段,Margaret 和我每周见面一次,以消除我所有的疑虑,并解决我遇到的各种问题。Joel 还帮助审查了对 TF-GAN 库所做的所有 PR。GSoC 确实是一个参与各种有趣的 TensorFlow 库的绝佳方式,我期待着继续为社区做出有价值的贡献。

Margaret

作为项目导师,我从项目选择阶段就开始参与。指导 Nived 并与 Joel 合作进行 TF-GAN 是一段令人满意的经历。Nived 在使用 TensorFlow 2 和 TF-GAN 实施 ESRGAN 论文方面做得非常出色。Nived 和我花了很多时间查看各种文本到图像 GAN,以选择一个可以在 GSoC 时间范围内实施的 GAN。除了编写 ESRGAN 教程外,他还为文本到图像生成中的 ControlGAN 取得了很大进展。我希望这个项目能帮助其他人学习如何使用 TF-GAN 库,并为 TF-GAN 和其他开源 TensorFlow 项目做出贡献。

Joel

作为一名非官方的技术导师,我对 Nived 的独立性和有效工作能力印象深刻。我觉得我更像是与一位初级同事而不是实习生一起工作,因为我帮助提供了技术和项目指导,但最终是 Nived 做出了决定。我认为令人印象深刻的结果反映了这一点:Nived 拥有了这个项目,我认为,结果,示例和 Colab 比它们原本可能的样子更完善、更连贯。此外,Nived 成功地克服了在家办公的多时区现实!

下一步

在 GSoC 编码阶段,ESRGAN 模型的实现已完成,Python 代码和 Colab 笔记本已合并到 TF-GAN 存储库中。ControlGAN 模型用于文本到图像生成的实现仍在进行中。一旦 ControlGAN 的实现完成,我们计划将其扩展到为艺术生成或图像编辑等领域的一些实际应用提供服务。我们还计划编写教程,以探索解决文本到图像翻译任务的不同模型。

如果您想为 TF-GAN 做出贡献,您可以联系 `[email protected]` 以提出项目或补充。除非您之前曾为 OSS Google 项目做出贡献,否则通常最好在提交大型拉取请求之前先咨询一下某个人。我们期待着看到您的贡献并与您合作!

致谢

我们要感谢 GSoC 项目委员会及其支持,特别是 TensorFlow 团队的 Josh Gordon。

非常感谢机器学习 (ML) Google 开发者专家 (GDE) 计划、Google Cloud Platform 和 TensorFlow Research Cloud 的支持。

下一篇文章
Our Summer of Code Project on TF-GAN

Nived P AMargaret Maynard-ReidJoel Shor 发表Google Summer of Code 是一个每年夏天将学生开发者带入开源项目的计划。本文介绍了去年夏天由阿姆利塔工程学院的本科生 Nived PA 提出的对 TensorFlow GAN 库 (TF-GAN) 的增强。Nived 项目的目标是改进 TF-GAN 库…