gpt4 book ai didi

python - Tensorflow,如何连接具有不同批处理大小的多个数据集

转载 作者:行者123 更新时间:2023-12-03 19:00:56 25 4
gpt4 key购买 nike

想象一下我有:

  • 数据集 1 的数据 [5, 5, 5, 5, 5]
  • 数据集 2 包含数据 [4, 4]

  • 我想从两个数据集中提取批次并将它们连接起来,以便获得大小为 3 的批次,其中:
  • 我读取了批量大小为 2 的数据集 1
  • 我读取了批大小为 1 的数据集 2。

  • 如果某些数据集首先被清空,我还想读取最后一批。
    在这种情况下,我会得到 [5, 5, 4], [5, 5, 4], [5] 作为我的最终结果。

    我怎样才能做到这一点?
    我在这里看到了答案: Tensorflow how to generate unbalanced combined data sets

    这是一个很好的尝试,但如果其中一个数据集在其他数据集之前被清空,则它不起作用(因为当您尝试从数据集中获取首先被清空的元素时, tf.errors.OutOfRangeError 被抢先输出,而我没有得到最后一批)。因此我只得到 [5, 5, 4], [5, 5, 4]

    我想过用 tf.contrib.data.choose_from_datasets :
    ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
    ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4]).batch(1)
    choice_dataset = [1, 2, 1, 2, 1]
    ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
    ds = ds.apply(tf.contrib.data.unbatch())
    ds = ds.batch(3, drop_remainder=False)

    这种工作,但相当不雅(有unbatch和batch);此外,我对批次中的确切内容并没有很好的控制。 (例如,如果 ds1 是 [7] * 7,批大小为 2,而 ds2 是 [2, 2],批大小为 1,我会得到 [7, 7, 1], [7, 7, 1], [7 , 7, 7]. 但是如果我真的想要 [7, 7, 1], [7, 7, 1], [7, 7], [7] 呢?即保持每个数据集中的元素数量固定.

    还有其他更好的解决方案吗?

    我的另一个想法是尝试使用 tf.data.Dataset.flat_map :
    ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
    ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4])
    batch_sizes = [2, 1]
    def concat(*inputs):
    concat = partial(functools.reduce, lambda x, y: x.concatenate(y))
    datasets = [tf.data.Dataset.from_tensors(input) for input in inputs]
    datasets = [dataset.batch(batch_size) for batch_size, dataset in zip(batch_sizes, datasets)]
    return concat(datasets)
    dataset = (tf.data.Dataset
    .zip((ds1, ds2))
    .flat_map(_concat_and_batch)
    .batch(sum(batch_sizes)))

    但它似乎不起作用..

    最佳答案

    如果您不介意在构建新数据集期间运行 session ,您可以执行以下操作:

    import tensorflow as tf
    import numpy as np

    ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
    ds2 = tf.data.Dataset.from_tensor_slices([4,4])

    ds1 = ds1.batch(2)
    ds2 = ds2.batch(1)

    iter1 = ds1.make_one_shot_iterator()
    iter2 = ds2.make_one_shot_iterator()

    batch1 = iter1.get_next()
    batch2 = iter2.get_next()

    sess = tf.Session()

    # define a generator that will sess.run both datasets, and will return the concatenation of both
    def GetBatch():
    while True:
    try:
    b1 = sess.run(batch1)
    except tf.errors.OutOfRangeError:
    b1 = None
    try:
    b2 = sess.run(batch2)
    except tf.errors.OutOfRangeError:
    b2 = None
    if (b1 is None) and (b2 is None):
    break
    elif b1 is None:
    yield b2
    elif b2 is None:
    yield b1
    else:
    yield np.concatenate((b1,b2))

    # create a dataset from the above generator
    ds = tf.data.Dataset.from_generator(GetBatch,tf.int32)

    请注意,如果您愿意(例如,在函数内部),可以隐藏\封装上述 session ,例如:
    iter = ds.make_one_shot_iterator()
    batch = iter.get_next()

    sess2 = tf.Session()

    while True:
    print(sess2.run(batch))

    关于python - Tensorflow,如何连接具有不同批处理大小的多个数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52897168/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com