gpt4 book ai didi

python - 在 mxnet 中使用我自己的 python 数据迭代器时出错

转载 作者:太空宇宙 更新时间:2023-11-03 15:50:43 24 4
gpt4 key购买 nike

我正在尝试创建自己的数据迭代器以与 mxnet 一起使用。当我运行它时,我收到错误:

Traceback (most recent call last):
File "train.py", line 24, in <module>
batch_end_callback = mx.callback.Speedometer(batch_size, 1) # output progress for each 200 data batches
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/model.py", line 811, in fit
sym_gen=self.sym_gen)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/model.py", line 236, in _train_multi_device
executor_manager.load_data_batch(data_batch)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 410, in load_data_batch
self.curr_execgrp.load_data_batch(data_batch)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 257, in load_data_batch
_load_data(data_batch, self.data_arrays)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 93, in _load_data
_load_general(batch.data, targets)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 89, in _load_general
d_src[slice_idx].copyto(d_dst)
AttributeError: 'numpy.ndarray' object has no attribute 'copy'

我认为这与我返回数据的方式有关。请参阅下面我的数据迭代器代码:

from mxnet.io import DataIter, DataDesc
import csv
from random import shuffle
import numpy as np
from cv2 import imread, resize

class MyData(DataIter):
def __init__(self, root_dir, flist_name, batch_size, size=(256,256), shuffle=True):
super(MyData, self).__init__()
self.batch_size = batch_size
self.root_dir = root_dir
self.flist_name = flist_name
self.size = size
self.shuffle = shuffle

self.data = []
with open(flist_name, 'rb') as csvfile:
csvreader = csv.reader(csvfile)
for row in csvreader:
self.data.append(row)
self.num_data = len(self.data)
self.provide_data = [DataDesc('data', (self.batch_size, 6, self.size[0], self.size[1]), np.float32)]
self.provide_label = [DataDesc('Pa_label', (self.batch_size, 1), np.float32)]
self.reset()

def reset(self):
"""Reset the iterator. """
self.cursor = 0
if self.shuffle:
shuffle(self.data)

def iter_next(self):
"""Iterate to next batch.
Returns
-------
has_next : boolean
Whether the move is successful.
"""
self.cursor += self.batch_size
success = self.cursor < self.num_data
return success

def getdata(self):
"""Get data of current batch.
Returns
-------
data : NDArray
The data of current batch.
"""
datalist = self.data[self.cursor:self.cursor+self.batch_size]
ret = np.ndarray(shape=(0,6,self.size[0],self.size[1]), dtype=np.float32)
for data_row in datalist:
img1 = resize(imread(data_row[0]), self.size)
img2 = resize(imread(data_row[1]), self.size)
img1 = np.rollaxis(img1, 2)
img2 = np.rollaxis(img2, 2)
img = np.concatenate((img1, img2), 0)
imge = np.expand_dims(img,0)
ret = np.append(ret, imge, 0)

print ret.shape
pad = self.batch_size - ret.shape[0]
if pad > 0:
ret = np.append(ret, np.zeros((pad, 6, self.size[0], self.size[1])), 0)
return ret

def getlabel(self):
"""Get label of current batch.
Returns
-------
label : NDArray
The label of current batch.
"""
datalist = self.data[self.cursor:self.cursor+self.batch_size]
ret = np.ndarray(shape=(0,1,1,1), dtype=np.float32)
for data_row in datalist:
label = np.ndarray(shape=(1,1,1,1), dtype=np.float32)
label[0,0,0,0] = float(data_row[2]) / float(data_row[5])
np.append(ret, label, 0)

pad = self.batch_size - ret.shape[0]
np.append(ret, np.zeros((pad, 1,1,1)), 0)
return ret

def getindex(self):
"""Get index of the current batch.
Returns
-------
index : numpy.array
The index of current batch
"""
return self.cursor

def getpad(self):
"""Get the number of padding examples in current batch.
Returns
-------
pad : int
Number of padding examples in current batch
"""
if self.cursor + self.batch_size > self.num_data:
return self.cursor + self.batch_size - self.num_data
else:
return 0

最佳答案

numpy.ndarray 没有 copyto 方法。尝试使用 mx.ndarray。

关于python - 在 mxnet 中使用我自己的 python 数据迭代器时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41323159/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com