
python - IndexError: index is out of bounds for axis 0 with size

Reposted. Author: 行者123. Updated: 2023-11-30 09:53:36

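I have the arrays x_train and targets_train. I want to shuffle the training data, split it into smaller batches, and use those batches as training data. My original data has 1000 rows, and I try to use 250 of them at a time: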

    x_train = np.memmap('/home/usr/train', dtype='float32', mode='r', shape=(1000, 1, 784))
    # print(x_train)
    targets_train = np.memmap('/home/usr/train_label', dtype='int32', mode='r', shape=(1000, 1))
    train_idxs = [i for i in range(x_train.shape[0])]
    np.random.shuffle(train_idxs)


    num_batches_train = 4
    def next_batch(start, train, labels, batch_size=250):
        newstart = start + batch_size
        if newstart > train.shape[0]:
            newstart = 0
        idxs = train_idxs[start:start + batch_size]
        # print(idxs)
        return train[idxs, :], labels[idxs, :], newstart


    # x_train_lab = x_train[:200]
    # # x_train = np.array(targets_train)
    # targets_train_lab = targets_train[:200]
    for i in range(num_batches_train):
        x_train, targets_train, newstart = next_batch(i*batch_size, x_train, targets_train, batch_size=250)

The problem is that when I shuffle the training data and try to access the batches, I get this error:

        return train[idxs, :], labels[idxs, :], newstart
    IndexError: index 250 is out of bounds for axis 0 with size 250

Does anyone know what I am doing wrong?

Best answer

(Edit - my first guess, about newstart, has been removed)

In this line:

    x_train, targets_train, newstart = next_batch(i*batch_size, x_train, targets_train, batch_size=250)

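you change the size of x_train on every iteration, yet keep using the train_idxs array that was created for the full-size array.

Pulling random values from x_train in batches is one thing, but the selection array has to stay consistent with the array it indexes.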
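One way to keep the selection array consistent is to build the permutation once from the full array and only ever read from that array. The following is a minimal sketch of that idea, not the asker's code: `iter_batches`, `x_full`, and `y_full` are hypothetical names, and small in-memory arrays stand in for the memmap files.

```python
import numpy as np

def iter_batches(data, labels, batch_size, rng):
    """Yield (data_batch, label_batch) pairs covering one shuffled pass."""
    # One permutation, sized from the full array and never rebuilt mid-pass.
    order = rng.permutation(data.shape[0])
    for start in range(0, order.size, batch_size):
        idxs = order[start:start + batch_size]
        # data and labels are only read here, never reassigned.
        yield data[idxs], labels[idxs]

rng = np.random.default_rng(0)
x_full = np.arange(20, dtype=np.float32).reshape(20, 1)  # stand-in data
y_full = np.arange(20, dtype=np.int32).reshape(20, 1)    # stand-in labels

batches = list(iter_batches(x_full, y_full, batch_size=5, rng=rng))
for x_batch, y_batch in batches:
    print(x_batch.shape, y_batch.ravel())
```

Because the generator never rebinds `data`, the index array and the array it indexes cannot drift apart in size.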

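This question should probably be closed for lack of a minimal, verifiable example. Having to guess, and to build a small testable example in the hope of reproducing the problem, is frustrating.

https://stackoverflow.com/help/mcve

If my current guess is wrong, a few intermediate print statements would make the problem clear.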
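As an illustration of the kind of intermediate check meant here, the sketch below prints the shape of the array next to the largest index about to be used on it. `selection_fits` is a hypothetical helper, not part of NumPy, and the arrays are made-up stand-ins.

```python
import numpy as np

def selection_fits(train, idxs):
    """True if every index in idxs is valid for axis 0 of train."""
    idxs = np.asarray(idxs)
    n = train.shape[0]
    # NumPy accepts indices in the range -n .. n-1 along an axis of length n.
    return bool(np.all((idxs >= -n) & (idxs < n)))

full = np.zeros((20, 1))
batch = np.zeros((5, 1))            # what x_train becomes after one overwrite
idxs = np.array([13, 5, 2, 15, 1])  # indices drawn for the full array

for name, arr in [("full", full), ("batch", batch)]:
    print(name, arr.shape, "max index:", idxs.max(), "fits:", selection_fits(arr, idxs))
```

Printing the shape and the index range side by side is usually enough to see that the array shrank while the indices did not.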

==========================

Reducing the code to a simple case:

    import numpy as np
    x_train = np.arange(20).reshape(20,1)
    train_idxs = np.arange(x_train.shape[0])
    np.random.shuffle(train_idxs)

    num_batches_train = 4
    batch_size=5
    def next_batch(start, train):
        idxs = train_idxs[start:start + batch_size]
        print(train.shape, idxs)
        return train[idxs, :]

    for i in range(num_batches_train):
        x_train = next_batch(i*batch_size, x_train)
        print(x_train)

Running it produces:

    1658:~/mypy$ python3 stack39919181.py
    (20, 1) [ 7 18  3  0  9]
    [[ 7]
     [18]
     [ 3]
     [ 0]
     [ 9]]
    (5, 1) [13  5  2 15  1]
    Traceback (most recent call last):
      File "stack39919181.py", line 14, in <module>
        x_train = next_batch(i*batch_size, x_train)
      File "stack39919181.py", line 11, in next_batch
        return train[idxs, :]
    IndexError: index 13 is out of bounds for axis 0 with size 5

I fed the (5,1) x_train back into next_batch, but tried to index it as if it were the original array.

Changing the iteration to:

    for i in range(num_batches_train):
        x_batch = next_batch(i*batch_size, x_train)
        print(x_batch)

makes it run and produce 4 batches of 5 rows each.
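The same change can be carried back to the code in the question. This is a sketch under several assumptions, since the original files are not available: zero-filled in-memory arrays stand in for the two memmaps, `batch_size` is defined at module level (the question's loop uses it without defining it), the wrap-around test uses `>=` instead of `>`, and the returned start position is fed into the next call. Those last three are my adjustments, not something the question states.

```python
import numpy as np

# Stand-ins with the shapes from the question (assumption: replaces the memmap files).
x_train = np.zeros((1000, 1, 784), dtype='float32')
targets_train = np.arange(1000, dtype='int32').reshape(1000, 1)

# Index array built once for the full-size data.
train_idxs = np.arange(x_train.shape[0])
np.random.shuffle(train_idxs)

batch_size = 250
num_batches_train = 4

def next_batch(start, train, labels, batch_size=250):
    newstart = start + batch_size
    if newstart >= train.shape[0]:
        newstart = 0  # wrap around once the last full batch has been taken
    idxs = train_idxs[start:start + batch_size]
    return train[idxs, :], labels[idxs, :], newstart

start = 0
seen = []
for i in range(num_batches_train):
    # Separate names for the batch; x_train and targets_train stay full-size.
    x_batch, targets_batch, start = next_batch(start, x_train, targets_train, batch_size)
    seen.append(targets_batch)
    print(i, x_batch.shape, targets_batch.shape, start)
```

The only essential change is the pair of new names on the left-hand side of the assignment; the rest is there to make the sketch self-contained.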

Regarding python - IndexError: index is out of bounds for axis 0 with size, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/39919181/
