2021 年 9 月 13 日 - 发布者:Elie Bursztein 和 Owen Vallis,Google 今天,我们发布了 TensorFlow Similarity 的第一个版本,这是一个 Python 包,旨在使用 TensorFlow 轻松快速地训练相似性模型。查找相关项目的的能力在现实世界中有着广泛的应用,从查找相似外观的服装,到识别当前正在播放的歌曲,到帮助寻找丢失的宠物……
发布者:Elie Bursztein 和 Owen Vallis,Google
今天,我们发布了 TensorFlow Similarity 的第一个版本,这是一个 Python 包,旨在使用 TensorFlow 轻松快速地训练相似性模型。
在 Oxford IIIT Pet 数据集 上训练的相似性模型生成的嵌入中执行的最近邻搜索示例 |
查找相关项目的的能力在现实世界中有着广泛的应用,从查找相似外观的服装,到识别当前正在播放的歌曲,到帮助寻找丢失的宠物。更普遍地说,能够快速检索相关项目是许多核心信息系统的重要组成部分,例如多媒体搜索、推荐系统和聚类管道。
相似性模型学习输出嵌入,这些嵌入将项目投影到一个度量空间中,在这个空间中,相似的项目彼此靠近,而不同的项目彼此远离。 |
在幕后,许多这些系统都由使用对比学习训练的深度学习模型提供支持。对比学习教会模型学习一个嵌入空间,在这个空间中,相似的示例彼此靠近,而不相似的示例彼此远离,例如,属于同一类的图像被拉到一起,而不同的类被相互推开。在我们的示例中,所有来自同一动物品种的图像被拉到一起,而不同的品种则相互推开。
使用 Tensorflow Similarity 投影仪对 Oxford-IIIT Pet 数据集进行可视化 |
当应用于整个数据集时,对比损失 使模型能够学习如何将项目投影到嵌入空间中,以便嵌入之间的距离代表输入示例的相似程度。在训练结束时,你会得到一个聚类良好的空间,其中相似项目之间的距离很小,而不同项目之间的距离很大。例如,如上所示,在 Oxford-IIIT Pet 数据集 上训练相似性模型会导致有意义的聚类,其中相似外观的品种彼此靠近,猫和狗明显分离。
查找相关项目涉及计算查询图像嵌入,执行 ANN 搜索以查找相似的项目,并获取相似的项目元数据,包括图像字节。 |
一旦模型训练完成,我们就会构建一个索引,其中包含我们要使其可搜索的各种项目的嵌入。然后在查询时,TensorFlow Similarity 利用 快速近似最近邻搜索 (ANN) 在亚线性时间内从索引中检索最接近的匹配项目。这种快速查找利用了 TensorFlow Similarity 学习了一个度量嵌入空间,其中嵌入点之间的距离是有效 距离度量 的函数。这些距离度量满足 三角不等式,使空间适合近似最近邻搜索,并导致高检索精度。
其他方法,例如使用 模型特征提取,需要使用精确最近邻搜索来查找相关项目,并且可能不如训练过的相似性模型准确。这会阻止这些方法扩展,因为执行精确搜索需要在搜索索引的大小上进行二次时间。相比之下,TensorFlow Similarity 内置的近似最近邻索引系统,它依赖于 NMSLIB,使搜索数百万个已索引的项目成为可能,并在几分之一秒内检索出前 K 个相似匹配。
除了准确性和检索速度之外,相似性模型的另一个主要优势是它们允许你向索引添加无限数量的新类,而无需重新训练。相反,你只需要计算新类的代表性项目的嵌入,并将它们添加到索引中。这种动态添加新类别的能力在解决不同项目数量事先未知、不断变化或非常大的问题时特别有用。这方面的例子包括允许用户发现与他们过去喜欢的歌曲相似的最新发布的音乐。
TensorFlow Similarity 提供了所有必要的组件,使相似性训练评估和查询变得直观和容易。特别是,如下所示,TensorFlow Similarity 引入了 SimilarityModel(),这是一个新的 Keras 模型,它原生支持嵌入索引和查询。这使你能够快速有效地执行端到端训练和评估。
一个在 MNIST 数据上进行训练、索引和搜索的最小示例可以用不到 20 行代码编写
from tensorflow.keras import layers
# Embedding output layer with L2 norm
from tensorflow_similarity.layers import MetricEmbedding
# Specialized metric loss
from tensorflow_similarity.losses import MultiSimilarityLoss
# Sub classed keras Model with support for indexing
from tensorflow_similarity.models import SimilarityModel
# Data sampler that pulls datasets directly from tf dataset catalog
from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler
# Nearest neighbor visualizer
from tensorflow_similarity.visualization import viz_neigbors_imgs
# Data sampler that generates balanced batches from MNIST dataset
sampler = TFDatasetMultiShotMemorySampler(dataset_name='mnist', classes_per_batch=10)
# Build a Similarity model using standard Keras layers
inputs = layers.Input(shape=(28, 28, 1))
x = layers.Rescaling(1/255)(inputs)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
outputs = MetricEmbedding(64)(x)
# Build a specialized Similarity model
model = SimilarityModel(inputs, outputs)
# Train Similarity model using contrastive loss
model.compile('adam', loss=MultiSimilarityLoss())
model.fit(sampler, epochs=5)
# Index 100 embedded MNIST examples to make them searchable
sx, sy = sampler.get_slice(0,100)
model.index(x=sx, y=sy, data=sx)
# Find the top 5 most similar indexed MNIST examples for a given example
qx, qy = sampler.get_slice(3713, 1)
nns = model.single_lookup(qx[0])
# Visualize the query example and its top 5 neighbors
viz_neigbors_imgs(qx[0], qy[0], nns)
即使上面的代码片段使用了次优模型,它仍然产生了良好的匹配结果,其中最近邻看起来明显像查询的数字,如以下屏幕截图所示
此初始版本侧重于提供所有必要的组件,以帮助你构建基于对比学习的相似性模型,例如损失、索引、批次采样器、指标和教程。TF Similarity 还使使用 Keras API 和使用现有的 Keras 架构变得容易。展望未来,我们计划在此坚实的基础上构建,以支持半监督和自监督方法,例如 BYOL、SWAV 和 SimCLR。
你可以立即开始使用 TF Similarity 进行实验,方法是前往 Hello World 教程。有关更多信息,你可以查看 项目 Github。
2021 年 9 月 13 日 - 发布者:Elie Bursztein 和 Owen Vallis,Google 今天,我们发布了 TensorFlow Similarity 的第一个版本,这是一个 Python 包,旨在使用 TensorFlow 轻松快速地训练相似性模型。查找相关项目的的能力在现实世界中有着广泛的应用,从查找相似外观的服装,到识别当前正在播放的歌曲,到帮助寻找丢失的宠物……