gpt4 book ai didi

python - ValueError : Error when checking input: expected input_1 to have 4 dimensions, 但得到形状为 (6243, 256, 256) 的数组

转载 作者:行者123 更新时间:2023-12-01 19:50:00 25 4
gpt4 key购买 nike

我想在训练数据集上附加标签,我这样做

def one_hot_label(img):
label = img
if label == 'A':
ohl = np.array([1, 0])
elif label == 'B':
ohl = np.array([0, 1])
return ohl

def train_data_with_label():
train_images = []
for i in tqdm(os.listdir(train_data)):
path_pre = os.path.join(train_data, i)
for img in os.listdir(path_pre):
if img.endswith('.jpg'):
path = os.path.join(path_pre, img)
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
train_images.append([np.array(img), one_hot_label(i)])
shuffle(train_images)
return train_images

但是在Keras上执行输入时返回错误

training_images = train_data_with_label()
tr_img_data = np.array([i[0] for i in training_images])
tr_lbl_data = np.array([i[1] for i in training_images])

model = Sequential()
model.add(InputLayer(input_shape=(256, 256, 1)))

谁能帮我解决这个问题吗?

最佳答案

您的输入层需要一个形状为 (batch_size, 256, 256, 1) 的数组,但看起来您正在传递形状为 (batch_size, 256, 256) 的数据。您可以尝试按如下方式 reshape 训练数据:

tr_img_data = np.expand_dims(tr_img_data, axis=-1) 

关于python - ValueError : Error when checking input: expected input_1 to have 4 dimensions, 但得到形状为 (6243, 256, 256) 的数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54798952/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com