gpt4 book ai didi

python - 使用 TensorFlow Dataset API 的纪元计数器

转载 作者:太空狗 更新时间:2023-10-29 20:22:24 24 4
gpt4 key购买 nike

我正在将我的 TensorFlow 代码从旧的队列接口(interface)更改为新的 Dataset API .在我的旧代码中,每次在队列中访问和处理新的输入张量时,我都会通过递增 tf.Variable 来跟踪纪元计数。我想使用新的 Dataset API 计算这个纪元,但我在让它工作时遇到了一些麻烦。

由于我在预处理阶段生成了可变数量的数据项,因此在训练循环中递增 (Python) 计数器并不是一件简单的事情 - 我需要根据队列或数据集的输入。

我模仿了以前使用旧队列系统的方式,这是我最终为数据集 API 得到的(简化示例):

with tf.Graph().as_default():

data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
input_tensors = (data,)

epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
trainable=False)

def pre_processing_func(data_):
data_size = tf.constant(0.1, dtype=tf.float32)
epoch_counter_op = tf.assign_add(epoch_counter, data_size)
with tf.control_dependencies([epoch_counter_op]):
# normally I would do data-augmentation here
results = (tf.expand_dims(data_, axis=0),)
return tf.data.Dataset.from_tensor_slices(results)

dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
dataset = dataset_source.flat_map(pre_processing_func)
dataset = dataset.repeat()
# ... do something with 'dataset' and print
# the value of 'epoch_counter' every once a while

但是,这是行不通的。它崩溃并显示一条神秘的错误消息:

 TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32])
are not compatible with expected types ([tf.float32_ref, tf.float32])

仔细检查表明 epoch_counter 变量可能根本无法在 pre_processing_func 中访问。它可能存在于不同的图表中吗?

知道如何修正上面的例子吗?或者如何通过其他方式获取纪元计数器(带小数点,例如 0.4 或 2.9)?

最佳答案

TL;DR:将 epoch_counter 的定义替换为以下内容:

epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
trainable=False, use_resource=True)

tf.data.Dataset 转换中使用 TensorFlow 变量有一些限制。原则限制是所有变量都必须是“资源变量”,而不是旧的“引用变量”;不幸的是,tf.Variable 仍然出于向后兼容性的原因创建“引用变量”。

一般来说,如果可以避免,我不建议在 tf.data 管道中使用变量。例如,您可以使用 Dataset.range() 定义纪元计数器,然后执行如下操作:

epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
(pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat()))

上面的代码片段将纪元计数器作为第二个组件附加到每个值。

关于python - 使用 TensorFlow Dataset API 的纪元计数器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47410778/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com