gpt4 book ai didi

python - 如何从 numpy 数组的数组中获取 tensorflow 2 中的窗口数据集?

转载 作者:太空宇宙 更新时间:2023-11-04 04:01:54 25 4
gpt4 key购买 nike

假设我有一些数据:

some_data = np.array([[1,2,3,4], [5, 6, 7,8]])

看起来像这样:

array([[1, 2, 3, 4],
[5, 6, 7, 8]])

每一行代表不同的观察结果,因此不应将它们合并。我想创建一个窗口数据集,每个窗口大小为 3,移动 1。当我通过一次观察时,我得到了我想要的,如下所示:

dataset = tf.data.Dataset.from_tensor_slices(some_data[0])
dataset = dataset.window(size=3, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(3))

结果:

for x in dataset:
print(x.numpy())

[1 2 3]
[2 3 4]

但是当我传递整个 numpy 数组时,我什么也得不到。

dataset = tf.data.Dataset.from_tensor_slices(some_data)
dataset = dataset.window(size=3, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(3))

这是我所期望的:

for x in dataset:
print(x.numpy())

[1 2 3]
[2 3 4]
[5 6 7]
[6 7 8]

我想我可以遍历 some_data 并一次传递一个数组,然后连接数据集,但这似乎是一个糟糕的解决方案。正确的做法是什么?

我正在使用 Tensorflow 2.0。谢谢!

最佳答案

当您使用 dataset = tf.data.Dataset.from_tensor_slices(some_data[0]) 时,每一行数据集只有一个元素。

dataset = tf.data.Dataset.from_tensor_slices(some_data[0])
for x in dataset:
print(x.numpy())
1
2
3
4

但是当您使用 dataset = tf.data.Dataset.from_tensor_slices(some_data) 时,数据集的每一行都有四个元素。

dataset = tf.data.Dataset.from_tensor_slices(some_data)
for x in dataset:
print(x.numpy())
[1 2 3 4]
[5 6 7 8]

所以你需要做的是转换每一行并合并它。

import numpy as np
import tensorflow as tf

some_data = np.array([[1,2,3,4], [5, 6, 7,8]])
dataset = tf.data.Dataset.from_tensor_slices(some_data)

def parse_samples(x):
return tf.data.Dataset.from_tensor_slices(x)\
.window(size=3, shift=1, drop_remainder=True)\
.flat_map(lambda window: window.batch(3))

dataset = dataset.flat_map(parse_samples)

for x in dataset:
print(x.numpy())

[1 2 3]
[2 3 4]
[5 6 7]
[6 7 8]

关于python - 如何从 numpy 数组的数组中获取 tensorflow 2 中的窗口数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58121828/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com