gpt4 book ai didi

python - 为分组数据的 RNN 生成具有特定长度的序列/批处理

转载 作者:行者123 更新时间:2023-12-04 15:54:13 25 4
gpt4 key购买 nike

当我希望将来自不同组的数据传递到 RNN 时,问题就出现了——大多数示例假设优雅的时间序列,但是在添加组时,我们不能简单地在数据帧上加窗,我们需要在组更改时跳转,以便数据来自集团内部。

这些组只是不同的人,所以我想将它们的顺序保密。例如。浏览网站的用户和我们收集的网页浏览数据。或者它可能是不同的股票及其相关的价格变动。

import pandas as pd
data = {
'group_id': [1,1,1,1,2,2],
'timestep': [1,2,3,4,1,2],
'x': [6,5,4,3,2,1],
'y': [0,1,1,1,0,1]
}
df = pd.DataFrame(data=data)


group_id timestep x y
0 1 1 6 0
1 1 2 5 1
2 1 3 4 1
3 1 4 3 1
4 2 1 2 0
5 2 2 1 1

假设我们想使用 2 个样本的批量大小,并且每个样本都有 3 个时间步长。 RNNSequence.__len__ = 3(下)批处理,但这是不可能的,因为我们最多可以从第一组(即 1 批处理)中获得 2 个样本。第二组只有 2 个时间步,所以迭代是不可能的。

from keras.utils import Sequence

class RNNSequence(Sequence):

def __init__(self, x_set, y_set, batch_size, seq_length):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
self.seq_length = seq_length

def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))

def __getitem__(self, idx):
# get_batch to be coded
return get_batch(idx, self.x, self.y, self.batch_size, self.seq_length)

使用序列获取这些批处理的最有效方法是什么?

我的解决方案实际上是不使用 Sequence,而是使用自定义生成器来输出数据,而事先不知道会有多少批处理。并使用 fit_generator(custom_generator, max_queue_size=batch_size) 代替。这是最有效的方法吗?这里的问题是没有洗牌,这可能是个问题?

batchsize=2,seq_length=3 的期望输出是:

X = [ 
[ [6], [5], [4] ],
[ [5], [4], [3] ]
]

Y = [ 1, 1 ]

最佳答案

看来您不仅需要知道批处理的数量,还需要能够仅在给定批处理编号的情况下输出任何批处理。您可以在 RNNSequence.__init__ 或更早版本中创建所有样本的索引,然后从中组装批处理。在 __getitem__ 中,您可以相应地输出批处理。

这个快速而肮脏的伪代码应该说明示例索引的概念。如果需要,您可能会决定使用 pandas 或 numpy 中的函数等。

# Pseuducode for generating indexes for where samples start.
seq_len = 3
sample_start_ids = []
for group_id, group in enumerate(groups):
for timestep_id, timestep in enumerate(group_timesteps):
# Only add as sample if it is the first
# timestep in the group or if a full sample fits.
if timestep == 1 or timestep <= len(group_timesteps) - seq_len+1:
sample_start_ids.append((group_id, timestep_id))

num_samples = len(sample_start_ids)

# Group the samples into batches of appropriate size.
pass

num_batches = len(your_batches)

关于python - 为分组数据的 RNN 生成具有特定长度的序列/批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52594650/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com