gpt4 book ai didi

python - Tensorflow dataset.batch() 不显示真实的批量大小

转载 作者:行者123 更新时间:2023-11-28 21:37:23 24 4
gpt4 key购买 nike

我想将原来基于队列的数据加载机制更改为tf.data API。

原代码为:

    # Index queue
self.input_idxs = tf.placeholder(tf.int64, shape=[None, 2])
idx_queue = tf.FIFOQueue(1e8, tf.int64)
self.enq_idxs = idx_queue.enqueue_many(self.input_idxs)
get_idx = idx_queue.dequeue()

# Image loading queue
img_queue = tf.FIFOQueue(opt.max_queue_size, task.proc_arg_dtype)
load_data = tf.py_func(task.load_sample_data, [get_idx], task.proc_arg_dtype)
enq_img = img_queue.enqueue(load_data)
init_sample = img_queue.dequeue()

# Preprocessing queue
# (for any preprocessing that can be done with TF operations)
data_queue = tf.FIFOQueue(opt.max_queue_size, task.data_arg_dtype,
shapes=task.data_shape)
enq_data = data_queue.enqueue(task.preprocess(init_sample, train_flag))
self.get_sample = data_queue.dequeue_many(opt.batchsize)

更改后为:

    # Dataset
self.input_idxs = tf.placeholder(tf.int64, shape=[None, 2])
dataset = tf.data.Dataset.from_tensor_slices(self.input_idxs)

def load_sample(idx):
sample = task.load_sample_data(idx)
sample = task.preprocess(sample, train_flag)
return sample

dataset = dataset.map(lambda idx: tf.py_func(load_sample, [idx], task.proc_arg_dtype), num_parallel_calls=self.num_threads)

def gen(dataset):
yield dataset.make_one_shot_iterator().get_next()

dataset = tf.data.Dataset.from_generator(gen, tuple(task.proc_arg_dtype), tuple(task.data_shape))
dataset = dataset.batch(opt.batchsize)
self.iterator = dataset.make_initializable_iterator()
self.get_sample = self.iterator.get_next()

哪里task.proc_arg_dtypetask.data_shape是:

    proc_arg_dtype = [tf.float32, tf.float32, tf.int32, tf.int32, tf.int32, tf.float32, tf.int32, tf.int32, tf.int32]
data_shape = [
[opt.input_res, opt.input_res, 3],
[opt.output_res, opt.output_res, opt.det_inputs],
[2, opt.max_nodes, 2],
[4],
[opt.max_nodes, opt.obj_slots + opt.rel_slots],
[opt.max_nodes, opt.obj_slots, 5],
[opt.max_nodes, opt.rel_slots, 2],
[opt.max_nodes, 7],
[1]
]

自从我找到tf.py_func没有data_shape参数,以便我使用 tf.data.Dataset.from_generator去做吧。 (不确定这是否正确,因为我在运行竞争之前遇到了问题)

问题之前是self.get_sample类似于:

[<tf.Tensor 'IteratorGetNext:0' shape=(8, 512, 512, 3) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(8, 64, 64, 300) dtype=float32>, <tf.Tensor 'IteratorGetNext:2' shape=(8, 2, 200, 2) dtype=int32>, <tf.Tensor 'IteratorGetNext:3' shape=(8, 4) dtype=int32>, <tf.Tensor 'IteratorGetNext:4' shape=(8, 200, 9) dtype=int32>, <tf.Tensor 'IteratorGetNext:5' shape=(8, 200, 3, 5) dtype=float32>, <tf.Tensor 'IteratorGetNext:6' shape=(8, 200, 6, 2) dtype=int32>, <tf.Tensor 'IteratorGetNext:7' shape=(8, 200, 7) dtype=int32>, <tf.Tensor 'IteratorGetNext:8' shape=(8, 1) dtype=int32>]

其中批量大小是第一个维度。但是通过使用dataset.batch(opt.batch_size)self.get_sample

[<tf.Tensor 'IteratorGetNext:0' shape=(?, 512, 512, 3) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(?, 64, 64, 300) dtype=float32>, <tf.Tensor 'IteratorGetNext:2' shape=(?, 2, 200, 2) dtype=int32>, <tf.Tensor 'IteratorGetNext:3' shape=(?, 4) dtype=int32>, <tf.Tensor 'IteratorGetNext:4' shape=(?, 200, 9) dtype=int32>, <tf.Tensor 'IteratorGetNext:5' shape=(?, 200, 3, 5) dtype=float32>, <tf.Tensor 'IteratorGetNext:6' shape=(?, 200, 6, 2) dtype=int32>, <tf.Tensor 'IteratorGetNext:7' shape=(?, 200, 7) dtype=int32>, <tf.Tensor 'IteratorGetNext:8' shape=(?, 1) dtype=int32>]

这不显示真实的批量大小。

最佳答案

目前,要在批处理张量上获得完全定义的静态形状,您需要明确告诉 TensorFlow,如果批处理大小不能均匀划分元素总数,则“丢弃”任何“余数”。为此,请替换以下行:

dataset = dataset.batch(opt.batchsize)

...应用程序 tf.contrib.data.batch_and_drop_remainder() :

dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(opt.batchsize))

关于python - Tensorflow dataset.batch() 不显示真实的批量大小,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49641098/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com