2020 年 3 月 10 日 - 由 Google Cloud 开发者倡导 Reza Rokni 代表 TFX 和 Dataflow 团队发布
TFX 的核心任务是让模型能够从研究阶段迁移到生产阶段,创建和管理生产管道。许多模型将使用大量数据构建,需要多个主机并行工作,以满足生产管道对处理和服务的需要。
我们…
AnalyzeAndTransformDataset
,最后通过两个 TFX 组件 ExampleGen
和 StatisticsGen
。preprocessing_fn,
的详细信息,请参考教程。目前,我们只需要知道它正在转换传递给该函数的数据点。virtualenv tfx-beam --python=python3
source tfx-beam/bin/activate
pip install tfx
def main():
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
transformed_dataset, transform_fn = (
(raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(
preprocessing_fn))
transformed_data, transformed_metadata = transformed_dataset
print('\nRaw data:\n{}\n'.format(pprint.pformat(raw_data)))
print('Transformed data:\n{}'.format(pprint.pformat(transformed_data)))
if __name__ == '__main__':
main()
注意result = pass_this | 'name this step' >> to_this_call
正在调用方法 to_this_call
并传递名为 pass_this
的对象,并且 此操作将在堆栈跟踪中被称为 name this step
。 beam_impl.Context
包含在 beam.Pipeline
中。这使我们能够传递参数,例如“--runner
”。对于快速本地测试,您可以使用 --runner
设置为 DirectRunner 运行下面的示例。import apache_beam as beam
argv=['--runner=DirectRunner']
def main():
with beam.Pipeline(argv=argv) as p:
# Ignore the warnings
with beam_impl.Context(temp_dir=tempfile.mkdtemp()):
input = p | beam.Create(raw_data)
transformed_dataset, transform_fn = (
(input, raw_data_metadata)
| beam_impl.AnalyzeAndTransformDataset(preprocessing_fn))
transformed_dataset[0] |"Print Transformed Dataset" >> beam.Map(print)
if __name__ == '__main__':
main()
接下来,我们将切换到使用 Dataflow 运行器。由于 Dataflow 是在 Google Cloud 上运行的完全托管的运行器,因此我们需要为管道提供一些环境信息。这包括 Google Cloud 项目以及管道使用的临时文件的存储位置。# Setup our Environment
## The location of Input / Output between various stages ( TFX Components )
## This will also be the location for the Metadata
### Can be used when running the pipeline locally
#LOCAL_PIPELINE_ROOT =
### In production you want the input and output to be stored on non-local location
#GOOGLE_CLOUD_STORAGE_PIPELINE_ROOT=
#GOOGLE_CLOUD_PROJECT =
#GOOGLE_CLOUD_TEMP_LOCATION =
# Will need setup.py to make this work with Dataflow
#
# import setuptools
#
# setuptools.setup(
# name='demo',
# version='0.0',
# install_requires=['tfx==0.21.1'],
# packages=setuptools.find_packages(),)
SETUP_FILE = "./setup.py"
argv=['--project={}'.format(GOOGLE_CLOUD_PROJECT),
'--temp_location={}'.format(GOOGLE_CLOUD_TEMP_LOCATION),
'--setup_file={}'.format(SETUP_FILE),
'--runner=DataflowRunner']
def main():
with beam.Pipeline(argv=argv) as p:
with beam_impl.Context(temp_dir=GOOGLE_CLOUD_TEMP_LOCATION):
input = p | beam.Create(raw_data)
transformed_data, transformed_metadata = (
(input, raw_data_metadata)
| beam_impl.AnalyzeAndTransformDataset(preprocessing_fn))
if __name__ == '__main__':
main()
为了了解 TFX 隐藏了多少工作,下面是管道处理的图表的可视化表示。我们不得不缩小图像以使它全部显示,因为有很多变换!def createExampleGen(query: Text):
# Output 2 splits: train:eval=3:1.
output = example_gen_pb2.Output(
split_config=example_gen_pb2.SplitConfig(splits=[
example_gen_pb2.SplitConfig.Split(
name='train', hash_buckets=3),
example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
]))
return BigQueryExampleGen(query=query, output_config=output)
除了要运行的 SQL 查询之外,BigQueryExampleGen 代码还通过 SplitConfig 对象传递配置信息。bigquery-public-data.chicago_taxi_trips.taxi_trips.
query="""
SELECT
pickup_community_area,
fare,
EXTRACT(MONTH FROM trip_start_timestamp) trip_start_month,
EXTRACT(HOUR FROM trip_start_timestamp) trip_start_hour,
EXTRACT(DAYOFWEEK FROM trip_start_timestamp) trip_start_day,
UNIX_Millis(trip_start_timestamp) trip_start_ms_timestamp,
pickup_latitude,
pickup_longitude,
dropoff_latitude,
dropoff_longitude,
trip_miles,
pickup_census_tract,
dropoff_census_tract,
payment_type,
company,
trip_seconds,
dropoff_community_area,
tips
FROM
`bigquery-public-data.chicago_taxi_trips.taxi_trips`
LIMIT 100
"""
请注意,使用了 LIMIT 100,这将限制输出为 100 条记录,从而使我们能够快速测试代码的正确性。def createStatisticsGen(bigQueryExampleGen: BigQueryExampleGen):
# Computes statistics over data for visualization and example validation.
return StatisticsGen(examples=bigQueryExampleGen.outputs['examples'])
由于 StatisticsGen 需要 ExampleGen 的输出,因此这两个步骤之间存在依赖关系。这种生产者-消费者模式贯穿大多数生产 ML 管道。为了自动化这个管道,我们需要一些东西来协调这些依赖关系。BeamDagRunner
使用 Apache Beam 进行编排。这意味着我们以两种不同的角色使用 Beam - 作为处理数据的执行引擎,以及作为对 TFX 任务进行排序的编排器。# Used for setting up the orchestration
from tfx.orchestration import pipeline
from tfx.orchestration import metadata
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
以下代码创建了我们的管道对象,准备由 BeamDagRunner 执行。from typing import Text
from typing import Type
def createTfxPipeline(pipeline_name: Text, pipeline_root: Text, query: Text,
beam_pipeline_args) -> pipeline.Pipeline:
output = example_gen_pb2.Output(
# Output 2 splits: train:eval=3:1.
split_config=example_gen_pb2.SplitConfig(splits=[
example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=3),
example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
]))
# Brings data into the pipeline or otherwise joins/converts training data.
example_gen = BigQueryExampleGen(query=query, output_config=output)
# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
return pipeline.Pipeline(
pipeline_name=pipeline_name,
pipeline_root=pipeline_root,
components=[
example_gen, statistics_gen
],
metadata_connection_config=metadata.sqlite_metadata_connection_config(
os.path.join(".", 'metadata', pipeline_name,'metadata.db')),
enable_cache=False,
additional_pipeline_args=beam_pipeline_args)
要测试该代码,请使用本地 DirectRunner 中的查询“LIMIT 100
”。tfx_pipeline = createTfxPipeline(
pipeline_name="my_first_directRunner_pipeline",
pipeline_root=LOCAL_PIPELINE_ROOT,
query=query,
beam_pipeline_args= {
'beam_pipeline_args':[
'--project={}'.format(GOOGLE_CLOUD_PROJECT),
'--runner=DirectRunner']})
BeamDagRunner().run(tfx_pipeline)
您可以看到使用 tfdv 生成的结果,并将其输出到 LOCAL_PIPELINE_ROOT
中;import os
import tensorflow_data_validation as tfdv
stats = tfdv.load_statistics(os.path.join(LOCAL_PIPELINE_ROOT,"StatisticsGen","statistics","","train","stats_tfrecord"))
tfdv.visualize_statistics(stats)
这对于一百条记录来说效果很好,但是如果目标是处理数据集中所有 187,002,0025 行怎么办?为此,我们将管道从 DirectRunner 切换到生产 Dataflow 运行器。还设置了一些额外的环境参数,例如运行管道的 Google Cloud 项目。tfx_pipeline = createTfxPipeline(
pipeline_name="my_first_dataflowRunner_pipeline",
pipeline_root=GOOGLE_CLOUD_STORAGE_PIPELINE_ROOT,
query=query,
beam_pipeline_args={
'beam_pipeline_args':[
'--project={}'.format(GOOGLE_CLOUD_PROJECT)
,
'--temp_location={}'.format(GOOGLE_CLOUD_TEMP_LOCATION),
'--setup_file=./setup.py',
'--runner=DataflowRunner']})
BeamDagRunner().run(tfx_pipeline)
BeamDagRunner
负责将 ExampleGen
和 StatisticsGen
提交为独立的管道,并确保 ExampleGen 首先成功完成,然后再开始 StatisticsGen。 Dataflow 服务会自动负责启动工作器、自动伸缩、在工作器发生故障时重试、集中式日志记录以及监控。 自动伸缩基于各种信号,包括吞吐率,如下所示; Dataflow 监控控制台会显示有关管道的各种指标,例如工作器的 CPU 使用率。 下面我们看到机器的使用率随着它们上线而不断提高,大多数工作器的使用率始终保持在 90% 以上: Apache Beam 支持自定义计数器,这允许开发人员在他们的管道中创建指标。 TFX 团队利用这一点为各种组件创建了有用的信息计数器。 下面我们可以看到 StatisticsGen 运行期间记录的一些计数器。 过滤关键字“num_*_feature
”,大约有十亿个整数和浮点数特征值。