https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhm7q6rv8tiNn0M6Xla7qTF_hveSNLWGZ1SIYu7i-H0ddTuMhJDZVFSqQ27MOgrxa_RAB-QTKpSfhKvqFxUOWEQ8fYzJ_W6PX4VOSEDBqOzGboQTZc0rzxEuV1-TotV4yruTR4SnDBmcgQ/s1600/custom-comp-figure1.png
由
刘若愚 和
Robert Crowe 代表 TFX 团队发布
TensorFlow Extended (
TFX) 是一个用于创建生产就绪的机器学习 (ML) 管道的平台。TFX 由 Google 创建,为 Google 的 ML 服务和应用提供基础,现在 Google 已将 TFX 开源,供任何希望创建生产 ML 管道的人使用。
TFX 可以通过多种方式进行扩展和定制。我们之前介绍过如何通过使用
自定义执行器 来更改
TFX 组件 的行为。在这篇文章中,我们将演示如何通过创建
全新的 TFX 组件 并将其用在 TFX 管道中来定制 TFX。
介绍
TFX 提供了一组
标准组件,这些组件可以链接在一起形成一个标准的 ML 工作流。虽然这满足了许多用例的需求,但仍然有一些标准组件不支持的场景。为了支持这些场景,TFX 可以通过自定义组件进行扩展。
在某些情况下,例如在
之前的博客文章 中,上游和下游语义(组件的输入和输出)与现有组件相同,您可以通过重用现有组件并替换执行器的行为来创建一个新的“半自定义”组件。现有组件可能是标准组件之一,也可能是您或其他人创建的自定义组件。
但是,如果新组件的上游和下游语义与现有组件不同,则需要创建一个新的“完全自定义”自定义组件,这是本博文的主题。
本文的剩余部分将说明如何从一个简单的
HelloWorld 组件
构建一个自定义组件。为了简便起见,HelloWorld 组件只会将所有输入复制为自身的输出,并将它们提供给下游组件,以演示如何使用和发出数据工件。
更新的管道工作流
在开始编码之前,让我们看一下包含新自定义组件的更新工作流。如下图 1 和 2 所示,我们将新的 HelloWorld 组件插入到
ExampleGen
和所有依赖示例数据的下游组件之间。这意味着关于新组件的两个事实
- 它需要将 ExampleGen 的输出作为其输入之一
- 它需要生成与 ExampleGen 输出类型相同的输出,以便原本依赖于 ExampleGen 的组件具有相同类型的输入
|
图 1. 插入新自定义组件之前 |
|
图 2. 插入新自定义组件之后 |
构建自己的自定义组件
接下来,我们将逐步构建新的组件。
通道
TFX 通道是一个抽象概念,它连接数据生产者和数据消费者。从概念上讲,一个组件从通道读取输入工件,并将输出工件写入通道,这些通道将被下游组件用作输入。通道用工件类型进行类型化(如下一节所述),这意味着写入或读取到通道的所有工件都共享相同的工件类型。
ComponentSpec
第一步是定义新组件的输入和输出,以及将在组件执行中使用的其他参数。
ComponentSpec
是我们将定义此契约的类,其中包含详细的类型信息。有三个预期的参数
INPUTS
:一个包含类型化参数的字典,用于将输入工件传递到组件执行器中。通常,输入工件是来自上游组件的输出,因此共享相同的类型。
OUTPUTS
:一个包含类型化参数的字典,用于组件将产生的输出工件。
PARAMETERS
:一个包含附加 ExecutionParameter
项的字典,这些项将被传递到组件执行器中。这些是非工件参数,我们希望在管道 DSL 中灵活地定义它们,并将其传递到执行中。
如前一节所述,我们需要保证
HelloWorld
组件的输入之一与 ExampleGen
输出类型相同,因为它是由其直接传递的。如图 3 所示,'input_data'
是它的规范。
HelloWorld
组件的输出之一与 ExampleGen
输出类型相同,因为它将被传递给下游组件,这些组件原本期望 ExampleGen 输出。如图 3 所示,'output_data'
是它的规范。
在参数规范部分,只声明了
'name'
,用于演示目的。
class HelloComponentSpec(types.ComponentSpec):
"""ComponentSpec for Custom TFX Hello World Component."""
# The following declares inputs to the component.
INPUTS = {
'input_data': ChannelParameter(type=standard_artifacts.Examples),
}
# The following declares outputs from the component.
OUTPUTS = {
'output_data': ChannelParameter(type=standard_artifacts.Examples),
}
# The following declares extra parameters used to create an instance of
# this component
PARAMETERS = {
'name': ExecutionParameter(type=Text),
}
|
图 3. HelloWorld 组件的 ComponentSpec。 |
执行器
接下来,让我们编写新组件执行器的代码。正如我们在上一篇文章中讨论的那样,我们需要创建一个
base_executor.BaseExecutor
的新子类,并覆盖其
Do
函数。
class Executor(base_executor.BaseExecutor):
"""Executor for HelloWorld component."""
...
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
output_dict: Dict[Text, List[types.Artifact]],
exec_properties: Dict[Text, Any]) -> None:
...
split_to_instance = {}
for artifact in input_dict['input_data']:
for split in json.loads(artifact.split_names):
uri = os.path.join(artifact.uri, split)
split_to_instance[split] = uri
for split, instance in split_to_instance.items():
input_dir = instance
output_dir = artifact_utils.get_split_uri(
output_dict['output_data'], split)
for filename in tf.io.gfile.listdir(input_dir):
input_uri = os.path.join(input_dir, filename)
output_uri = os.path.join(output_dir, filename)
io_utils.copy_file(src=input_uri, dst=output_uri, overwrite=True)
|
图 4. HelloWorld 组件的执行器。 |
如
图 4 所示,我们可以使用之前在
ComponentSpec
中定义的相同键来获取输入和输出工件以及执行属性。在获得所有需要的价值观后,我们可以继续使用它们添加更多逻辑,并将输出写入输出工件(
'output_data'
)指向的 URI 中。
在继续下一步之前,不要忘记测试它!我们为您创建了一个方便的
脚本,让您可以在将执行器投入生产之前试用它。您应该编写类似的代码来练习代码的单元测试。与任何生产软件部署一样,在为 TFX 开发时,您应该确保有良好的测试覆盖率和强大的 CI/CD 框架。
组件接口
现在我们已经完成了最复杂的部分,我们需要将这些部分组装成一个组件接口,以便能够在管道中使用该组件。该过程(如
图 5 所示)需要以下步骤
- 使组件接口成为
base_component.BaseComponent
的子类
- 使用之前定义的
HelloComponentSpec
类为类变量 SPEC_CLASS
赋值
- 使用之前定义的
Executor
类为类变量 EXECUTOR_SPEC
赋值
- 使用函数的参数定义
__init__()
函数,以构造 HelloComponentSpec
的实例,并使用该值调用超类函数,以及可选的名称
当组件的实例被创建时,
base_component.BaseComponent
类中的类型检查逻辑将被调用,以确保传入的参数与
HelloComponentSpec
类中定义的参数类型兼容。
from hello_component import executor
class HelloComponent(base_component.BaseComponent):
"""Custom TFX HelloWorld Component."""
SPEC_CLASS = HelloComponentSpec
EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)
def __init__(self,
input_data: channel.Channel,
output_data: channel.Channel,
name: Text):
if not output_data:
examples_artifact = standard_artifacts.Examples()
examples_artifact.split_names = input_data.get()[0].split_names
output_data = channel_utils.as_channel([examples_artifact])
spec = HelloComponentSpec(input_data=input_data,
output_data=output_data, name=name)
super(HelloComponent, self).__init__(spec=spec)
|
图 5. 组件接口。 |
插入 TFX 管道
好消息!经过之前几节的努力,我们全新的组件已经可以使用了。让我们将它插入到我们的芝加哥出租车示例
管道 中。除了添加新组件的实例外,我们还需要
- 当我们实例化原本期望
ExampleGen
输出的组件时,调整参数,使其现在接收我们新组件的输出
- 在构建管道时将新组件实例添加到组件列表中
图 6 突出了这些更改。完整的示例可以在我们的
GitHub 仓库 中找到。
def _create_pipeline():
...
example_gen = CsvExampleGen(input_base=examples)
hello = component.HelloComponent(
input_data=example_gen.outputs['examples'], name=u'HelloWorld')
statistics_gen = StatisticsGen(examples=hello.outputs['output_data'])
return pipeline.Pipeline(
...
components=[example_gen, hello, statistics_gen],
...
)
|
图 6. 使用新组件。 |
更多信息
要了解有关 TFX 的更多信息,请查看
TFX 网站,加入
TFX 讨论组,阅读
TFX 博客,观看我们的
TFX YouTube 播放列表,并
订阅 TensorFlow 频道。