gpt4 book ai didi

python - 如何在 Tensorflow 2.0 数据集中动态更改批量大小?

转载 作者:行者123 更新时间:2023-12-02 04:22:29 24 4
gpt4 key购买 nike

在 TensorFlow 1.X 中,您可以使用占位符动态更改批量大小。例如

dataset.batch(batch_size=tf.placeholder())
See full example

如何在 TensorFlow 2.0 中做到这一点?

我尝试了以下方法,但不起作用。

import numpy as np
import tensorflow as tf


def new_gen_function():
for i in range(100):
yield np.ones(2).astype(np.float32)


batch_size = tf.Variable(5, trainable=False, dtype=tf.int64)
train_ds = tf.data.Dataset.from_generator(new_gen_function, output_types=(tf.float32)).batch(
batch_size=batch_size)

for data in train_ds:
print(data.shape[0])
batch_size.assign(10)
print(batch_size)

输出

5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
...
...

我正在使用渐变胶带使用自定义训练循环来训练模型。我怎样才能实现这个目标?

最佳答案

我认为你不能像以前在 TF1 中那样。

解决方法可能是通过堆叠单个样本来自己构建批处理:

import tensorflow as tf

ds = tf.data.Dataset.range(10).repeat()
iterator = iter(ds)
for batch_size in range(1, 10):
batch = tf.stack([iterator.next() for _ in range(batch_size)], axis=0)
print(batch)

# tf.Tensor([0], shape=(1,), dtype=int64)
# tf.Tensor([1 2], shape=(2,), dtype=int64)
# tf.Tensor([3 4 5], shape=(3,), dtype=int64)
# tf.Tensor([6 7 8 9], shape=(4,), dtype=int64)
# tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
# tf.Tensor([5 6 7 8 9 0], shape=(6,), dtype=int64)
# tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int64)
# tf.Tensor([8 9 0 1 2 3 4 5], shape=(8,), dtype=int64)
# tf.Tensor([6 7 8 9 0 1 2 3 4], shape=(9,), dtype=int64)

关于python - 如何在 Tensorflow 2.0 数据集中动态更改批量大小?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59368539/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com