创建自定义 TFX 组件
2020 年 1 月 22 日
刘若愚Robert Crowe 代表 TFX 团队发布

TensorFlow Extended (TFX) 是一个用于创建生产就绪的机器学习 (ML) 管道的平台。TFX 由 Google 创建,为 Google 的 ML 服务和应用提供基础,现在 Google 已将 TFX 开源,供任何希望创建生产 ML 管道的人使用。

TFX 可以通过多种方式进行扩展和定制。我们之前介绍过如何通过使用 自定义执行器 来更改 TFX 组件 的行为。在这篇文章中,我们将演示如何通过创建 全新的 TFX 组件 并将其用在 TFX 管道中来定制 TFX。

介绍

TFX 提供了一组 标准组件,这些组件可以链接在一起形成一个标准的 ML 工作流。虽然这满足了许多用例的需求,但仍然有一些标准组件不支持的场景。为了支持这些场景,TFX 可以通过自定义组件进行扩展。

在某些情况下,例如在 之前的博客文章 中,上游和下游语义(组件的输入和输出)与现有组件相同,您可以通过重用现有组件并替换执行器的行为来创建一个新的“半自定义”组件。现有组件可能是标准组件之一,也可能是您或其他人创建的自定义组件。

但是,如果新组件的上游和下游语义与现有组件不同,则需要创建一个新的“完全自定义”自定义组件,这是本博文的主题。

本文的剩余部分将说明如何从一个简单的 HelloWorld 组件 构建一个自定义组件。为了简便起见,HelloWorld 组件只会将所有输入复制为自身的输出,并将它们提供给下游组件,以演示如何使用和发出数据工件。

更新的管道工作流

在开始编码之前,让我们看一下包含新自定义组件的更新工作流。如下图 1 和 2 所示,我们将新的 HelloWorld 组件插入到 ExampleGen 和所有依赖示例数据的下游组件之间。这意味着关于新组件的两个事实
  • 它需要将 ExampleGen 的输出作为其输入之一
  • 它需要生成与 ExampleGen 输出类型相同的输出,以便原本依赖于 ExampleGen 的组件具有相同类型的输入
图 1. 插入新自定义组件之前
图 2. 插入新自定义组件之后

构建自己的自定义组件

接下来,我们将逐步构建新的组件。

通道

TFX 通道是一个抽象概念,它连接数据生产者和数据消费者。从概念上讲,一个组件从通道读取输入工件,并将输出工件写入通道,这些通道将被下游组件用作输入。通道用工件类型进行类型化(如下一节所述),这意味着写入或读取到通道的所有工件都共享相同的工件类型。

ComponentSpec

第一步是定义新组件的输入和输出,以及将在组件执行中使用的其他参数。ComponentSpec 是我们将定义此契约的类,其中包含详细的类型信息。有三个预期的参数
  • INPUTS:一个包含类型化参数的字典,用于将输入工件传递到组件执行器中。通常,输入工件是来自上游组件的输出,因此共享相同的类型。
  • OUTPUTS:一个包含类型化参数的字典,用于组件将产生的输出工件。
  • PARAMETERS:一个包含附加 ExecutionParameter 项的字典,这些项将被传递到组件执行器中。这些是非工件参数,我们希望在管道 DSL 中灵活地定义它们,并将其传递到执行中。
如前一节所述,我们需要保证
  • HelloWorld 组件的输入之一与 ExampleGen 输出类型相同,因为它是由其直接传递的。如图 3 所示,'input_data' 是它的规范。
  • HelloWorld 组件的输出之一与 ExampleGen 输出类型相同,因为它将被传递给下游组件,这些组件原本期望 ExampleGen 输出。如图 3 所示,'output_data' 是它的规范。
在参数规范部分,只声明了 'name',用于演示目的。
class HelloComponentSpec(types.ComponentSpec):
  """ComponentSpec for Custom TFX Hello World Component."""
  # The following declares inputs to the component.
  INPUTS = {
    'input_data': ChannelParameter(type=standard_artifacts.Examples),
  }
  # The following declares outputs from the component.
  OUTPUTS = {
    'output_data': ChannelParameter(type=standard_artifacts.Examples),
  }
  # The following declares extra parameters used to create an instance of
  # this component
  PARAMETERS = {
    'name': ExecutionParameter(type=Text),
  }
图 3. HelloWorld 组件的 ComponentSpec。

执行器

接下来,让我们编写新组件执行器的代码。正如我们在上一篇文章中讨论的那样,我们需要创建一个 base_executor.BaseExecutor 的新子类,并覆盖其 Do 函数。
class Executor(base_executor.BaseExecutor):
  """Executor for HelloWorld component."""
  ...  
  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    ...
    split_to_instance = {}
    for artifact in input_dict['input_data']:
      for split in json.loads(artifact.split_names):
        uri = os.path.join(artifact.uri, split)
        split_to_instance[split] = uri
    for split, instance in split_to_instance.items():
      input_dir = instance
      output_dir = artifact_utils.get_split_uri(
          output_dict['output_data'], split)
      for filename in tf.io.gfile.listdir(input_dir):
        input_uri = os.path.join(input_dir, filename)
        output_uri = os.path.join(output_dir, filename)
        io_utils.copy_file(src=input_uri, dst=output_uri, overwrite=True)
图 4. HelloWorld 组件的执行器。
图 4 所示,我们可以使用之前在 ComponentSpec 中定义的相同键来获取输入和输出工件以及执行属性。在获得所有需要的价值观后,我们可以继续使用它们添加更多逻辑,并将输出写入输出工件('output_data')指向的 URI 中。

在继续下一步之前,不要忘记测试它!我们为您创建了一个方便的 脚本,让您可以在将执行器投入生产之前试用它。您应该编写类似的代码来练习代码的单元测试。与任何生产软件部署一样,在为 TFX 开发时,您应该确保有良好的测试覆盖率和强大的 CI/CD 框架。

组件接口

现在我们已经完成了最复杂的部分,我们需要将这些部分组装成一个组件接口,以便能够在管道中使用该组件。该过程(如图 5 所示)需要以下步骤
  1. 使组件接口成为 base_component.BaseComponent 的子类
  2. 使用之前定义的 HelloComponentSpec 类为类变量 SPEC_CLASS 赋值
  3. 使用之前定义的 Executor 类为类变量 EXECUTOR_SPEC 赋值
  4. 使用函数的参数定义 __init__() 函数,以构造 HelloComponentSpec 的实例,并使用该值调用超类函数,以及可选的名称
当组件的实例被创建时,base_component.BaseComponent 类中的类型检查逻辑将被调用,以确保传入的参数与 HelloComponentSpec 类中定义的参数类型兼容。
from hello_component import executor
class HelloComponent(base_component.BaseComponent):
  """Custom TFX HelloWorld Component."""
  SPEC_CLASS = HelloComponentSpec
  EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

  def __init__(self,
               input_data: channel.Channel,
               output_data: channel.Channel,
               name: Text):
    if not output_data:
      examples_artifact = standard_artifacts.Examples()
      examples_artifact.split_names = input_data.get()[0].split_names
      output_data = channel_utils.as_channel([examples_artifact])
    spec = HelloComponentSpec(input_data=input_data,
                              output_data=output_data, name=name)
    super(HelloComponent, self).__init__(spec=spec)
图 5. 组件接口。

插入 TFX 管道

好消息!经过之前几节的努力,我们全新的组件已经可以使用了。让我们将它插入到我们的芝加哥出租车示例 管道 中。除了添加新组件的实例外,我们还需要
  • 当我们实例化原本期望 ExampleGen 输出的组件时,调整参数,使其现在接收我们新组件的输出
  • 在构建管道时将新组件实例添加到组件列表中
图 6 突出了这些更改。完整的示例可以在我们的 GitHub 仓库 中找到。
def _create_pipeline():
  ...
  example_gen = CsvExampleGen(input_base=examples)
  hello = component.HelloComponent(
      input_data=example_gen.outputs['examples'], name=u'HelloWorld')
  statistics_gen = StatisticsGen(examples=hello.outputs['output_data'])
  return pipeline.Pipeline(
      ...
      components=[example_gen, hello, statistics_gen],
      ...
  )
图 6. 使用新组件。

更多信息

要了解有关 TFX 的更多信息,请查看 TFX 网站,加入 TFX 讨论组,阅读 TFX 博客,观看我们的 TFX YouTube 播放列表,并 订阅 TensorFlow 频道。
下一篇文章
Creating a Custom TFX Component

刘若愚Robert Crowe 代表 TFX 团队发布

TensorFlow Extended (TFX) 是一个用于创建生产就绪的机器学习 (ML) 管道的平台。TFX 由 Google 创建,为 Google 的 ML 服务和应用提供基础,现在 Google 已将 TFX 开源,供任何希望创建生产 ML 管道的人使用。

TFX 可以通过多种方式进行扩展和定制...