gpt4 book ai didi

python - tensorflow 1.10+ : passing epoch to estimator input_fn?

转载 作者:行者123 更新时间:2023-12-01 07:43:05 26 4
gpt4 key购买 nike

tf.estimator input_fn 的签名可能如下所示:

def input_fn(files:list, params:dict):
dataset = tf.data.TFRecordDataset(files)
.map(lambda record: parse_record_fn(record))

if params['mode'] == 'train':
# train specific things
# ...

这样的定义允许人们构建所有的 input_fn ,如下所示:

train_fn = lambda: input_fn(files['training_set'], {**params, **{"mode": "train"}})
valid_fn = lambda: input_fn(files['validation_set'], {**params, **{"mode": "eval"}})
test_fn = lambda: input_fn(files['test_set'], {**params, **{"mode": "test"}})


train_spec = tf.estimator.TrainSpec(input_fn=train_fn, ...)
eval_spec = tf.estimator.EvalSpec(input_fn=valid_fn, ...)

我的问题是如何更改 input_fn 签名以允许基于纪元的变化。我知道这可能会带来瓶颈,但如果我能做这样的事情那就太好了:


def input_fn(...):
# see above

epoch = params["epoch"]
if epoch % 100 == 0:
# modify or make a new dataset

# ...
return dataset.make_one_shot_iterator().get_next()

关键是确保 input_fn 仍然兼容:

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

最佳答案

我不知道有任何选项提供epoch数字作为参数。

也就是说,根据定义,纪元是输入函数的一个特征,因此我们应该能够处理输入函数内的所有内容,而不是完全访问训练参数。所以我认为你可能可以通过一点点的摆弄来实现你所需要的。

例如,如果我有 2 个数据集:ds1ds2 说并且希望在“epoch”数字不能被整除时使用 ds1 100 那么我可以通过执行以下操作来创建一个新数据集:

dataset = ds1.repeat(99).concatenate(ds2)

由于数据集默认是延迟加载的,所以我不需要担心内存影响(我不会将 100 倍的数据加载到内存中)。

显然,这确实对数据集的大小有影响,因此您需要考虑评估操作/回调等之间的步骤策略,但这应该很容易调整。

关于python - tensorflow 1.10+ : passing epoch to estimator input_fn?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56594469/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com