gpt4 book ai didi

python - 使用 tf.data 批量顺序数据

转载 作者:行者123 更新时间:2023-12-01 08:13:14 27 4
gpt4 key购买 nike

让我们考虑一个有序的玩具数据集,具有两个特征:

  • (例如1、2、3、4、5、111、222、333、444、555)
  • sequence_id(例如0, 0, 0, 0, 0, 1, 1, 1, 1, 1)

该数据基本上由两个连接的扁平序列组成:1, 2, 3, 4, 5(序列0)和111, 222, 333 , 444, 555(序列1)。

我想生成大小为 t 的序列(例如 3),由同一序列 (sequence_id) 中的连续元素组成,我不希望序列包含属于不同 sequence_id 的元素。

例如,在没有任何洗牌的情况下,我想获得以下批处理:

  • 第一批:1,2,3,
  • 第二批:2,3,4,
  • 第三批:3,4,5,
  • 第四批:111、222、333
  • 第五批:222、333、444
  • 第六批:333、444、555
  • 第七批:1,2,3,
  • 等等

我知道如何使用tf.data.Dataset.windowtf.data.Dataset.batch生成序列数据,但我不知道如何防止包含不同 sequence_id 混合的序列(例如,序列 4, 5, 111 不应该有效,因为它混合了序列 0 和序列1)。

以下是我失败的尝试:

import tensorflow as tf

data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
.window(3, 1, drop_remainder=True)\
.repeat(-1)\
.flat_map(lambda x, y: x.batch(3))\
.batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
sess.run(data_it.initializer)
print(sess.run(next_element))

输出:

[[  1   2   3]   # good
[ 2 3 4] # good
[ 3 4 5] # good
[ 4 5 111] # bad – mix of sequence 0 (4, 5) and sequence 1 (111)
[ 5 111 222] # bad
[111 222 333] # good
[222 333 444] # good
[333 444 555] # good
[ 1 2 3] # good
[ 2 3 4]] # good

最佳答案

可以使用filter()判断sequence_id是否一致。由于 filter() 转换当前不支持嵌套数据集作为输入,因此您需要 zip()

import tensorflow as tf

data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
.window(3, 1, drop_remainder=True) \
.flat_map(lambda x, y: tf.data.Dataset.zip((x,y)).batch(3))\
.filter(lambda x,y: tf.equal(tf.size(tf.unique(y)[0]),1))\
.map(lambda x,y:x)\
.repeat(-1)\
.batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
sess.run(data_it.initializer)
print(sess.run(next_element))

[[ 1 2 3]
[ 2 3 4]
[ 3 4 5]
[111 222 333]
[222 333 444]
[333 444 555]
[ 1 2 3]
[ 2 3 4]
[ 3 4 5]
[111 222 333]]

关于python - 使用 tf.data 批量顺序数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55109817/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com