gpt4 book ai didi

tensorflow - 将 SageMaker 管道模式与 tfrecords 的 s3 目录一起使用

转载 作者:行者123 更新时间:2023-12-05 05:53:38 26 4
gpt4 key购买 nike

当我使用 Pipe 而不是 File 作为input_mode。我相应地将 TensorFlow Dataset 替换为 PipemodedatasetFile 模式下的训练成功完成。

我的数据由两个 s3 桶组成,每个桶中有多个 tfrecord 文件。尽管仔细阅读了文档,但我对如何在这种情况下使用 Pipemodedataset 没有信心 - 具体来说,如何设置 channel

这是我的 Sagemaker 笔记本设置:

hyperparameters = {
"batch-size": 1,
"pipe_mode": 1,
}

estimator_config = {
"entry_point": "tensorflow_train.py",
"source_dir": "source",
"framework_version": "2.3",
"py_version": "py37",
"instance_type": "ml.p3.2xlarge",
"instance_count": 1,
"role": sagemaker.get_execution_role(),
"hyperparameters": hyperparameters,
"output_path": f"s3://{bucket_name}",
"input_mode": "Pipe",
}

tf_estimator = TensorFlow(**estimator_config)

s3_data_channels = {
"training": f"s3://{bucket_name}/data/training",
"validation": f"s3://{bucket_name}/data/validation",
}

tf_estimator.fit(s3_data_channels)

如果我在 s3_data_channels 上运行 aws s3 ls,我会得到一个 tfrecord 文件列表。

下面是我设置数据集的方式(根据是否选择pipe_mode看if/else语句:

import tensorflow as tf

if __name__ == "__main__":

arg_parser = argparse.ArgumentParser()
...
arg_parser.add_argument("--pipe_mode", type=int, default=0)

arg_parser.add_argument("--train_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAINING"))
arg_parser.add_argument(
"--validation_dir", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION")
)
arg_parser.add_argument("--model_dir", type=str)
args, _ = arg_parser.parse_known_args()

AUTOTUNE = tf.data.experimental.AUTOTUNE

if args.pipe_mode == 1:
from sagemaker_tensorflow import PipeModeDataset
train_ds = PipeModeDataset(channel="training", record_format='TFRecord')
val_ds = PipeModeDataset(channel="validation", record_format='TFRecord')

else:
train_files = tf.data.Dataset.list_files(args.train_dir + '/*tfrecord')
val_files = tf.data.Dataset.list_files(args.validation_dir + '/*tfrecord')
train_ds = tf.data.TFRecordDataset(filenames=train_files, num_parallel_reads=AUTOTUNE)
val_ds = tf.data.TFRecordDataset(filenames=val_files, num_parallel_reads=AUTOTUNE)

train_ds = (
train_ds.map(tfrecord_parser, num_parallel_calls=AUTOTUNE)
.batch(args.batch_size)
.prefetch(AUTOTUNE)
)

val_ds = (
val_ds.map(tfrecord_parser, num_parallel_calls=AUTOTUNE)
.batch(args.batch_size)
.prefetch(AUTOTUNE)
)
...

最佳答案

我遇到了同样的问题,model.fit() 在使用管道模式时无限期地卡住了。经过一些研究并尝试了许多更改后,通过在拟合模型时定义 steps_per_epoch 解决了这个问题。

我想在使用文件模式时它已经知道每个纪元会有多少步,但是对于管道模式你必须手动指定它

关于tensorflow - 将 SageMaker 管道模式与 tfrecords 的 s3 目录一起使用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69843184/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com