gpt4 book ai didi

python - 将cv.imread()数据发送到Keras模型

转载 作者:行者123 更新时间:2023-12-02 16:34:18 26 4
gpt4 key购买 nike

我在尝试弄清楚如何将数据从cv.imread()发送到我的机器学习模型时遇到了麻烦。

从我的图像读取功能中,我得到一个numpy数组的列表,其中包含维度为(256、256、3)的图像。

# image reading
res_img = []
for i in files:
img = cv2.imread(os.path.join("temp", i))
res = cv2.resize(img, (256, 256))
res_img.append(res)
return res_img

然后将其存储在数据框中并发送到模型。但是,该数据帧被检测为具有维度(56,1),其中56是我的数据的长度,维度为1,因为每个numpy数组都被检测为1个对象。

# train model
model = create_model(trainX)
model_history = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=..., batch_size=...)

# create model
def create_model(data):
model = Sequential()
model.add(Conv2D(32, kernel_size=4, activation='relu', input_shape=(256, 256, 3)))
...
return model

但是,这返回
ValueError: Error when checking input: expected conv2d_input to have 4 dimensions, but got array with shape (56, 1)

我尝试的另一件事是将数据中的所有numpy数组合并为一个大numpy数组,该数组确实具有正确的尺寸

trainX_arr = []
trainX = trainX.to_numpy()
for i in trainX:
trainX_arr.append(i)
trainX_arr = np.asarray(trainX_arr)

这确实给出了正确的形状:

print(trainX_arr.shape)
# (56, 256, 256, 3)

但是,发送模型时返回
ValueError: No data provided for "conv2d_input". Need data for each key in: ['conv2d_input']

我认为是因为输入不是数据帧。最后,我尝试在第一步中合并numpy数组,然后将其存储在数据框中,如下所示

res_img = []
for i in files:
img = cv2.imread(os.path.join("temp", i))
res = cv2.resize(img, (256, 256))
res_img.append(res)
img_arr = []
for i in res_img:
img_arr.append(i)
img_arr = np.asarray(img_arr)
return img_arr

但是,当尝试将其插入数据框时:

df.insert(0, "x", img_arr)

它返回
ValueError: Wrong number of dimensions. values.ndim != ndim [4 != 2]

我认为这是因为数据框无法容纳多维数组,但这使我回到了起点。为了使它正常工作,我对于应该实际做什么感到非常困惑,我们将不胜感激。

最佳答案

我设法使其工作,在这里我的第二种方法将numpy数组组合成具有有效尺寸的大数组,从而能够成功地训练模型。我不确定为什么以前无法使用,但是这是我的代码可以正常工作:

定义一个函数(任意称为numpyfy)以执行此数组合并

import os

import cv2
import numpy as np
import pandas as pd
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from keras.engine.input_layer import Input
from keras.models import Model
from sklearn.model_selection import train_test_split


# download images
def process_link(link_list):
counter = 0
for i in link_list:
if i.find("jpg") != -1:
ext = ".jpg"
elif i.find("png") != -1:
ext = ".png"
f = open(os.path.join("temp", str(counter) + ext), "wb")
f.write(requests.get(i).content)
f.close()
counter += 1
files = os.listdir("temp") # images stored in temp directory
res_img = []
for i in files:
img = cv2.imread(os.path.join("temp", i))
res = cv2.resize(img, (256, 256))
res_img.append(res)
return res_img

# process data
def process_data():
# link_list and y are a list of: links to images, and data labels, respectively
df = pd.DataFrame()
df.insert(0, "Y", y])
df.insert(0, "img", process_link(link_list))
(train, test) = train_test_split(clean_df, test_size=0.25, random_state=42)
return (train, test)

# numpy array combining
def numpyfy(df):
arr = []
df_numpy = df.to_numpy()
print(df_numpy[:2])
for i in df_numpy:
arr.append(i)
arr = np.asarray(arr)
#print(arr.shape), returns 4 dimensional array
return arr

# Deep learning model, changed to use the Functional API
def create_model():
input1 = Input(shape=(256, 256, 3))
conv1 = Conv2D(32, (3, 3), input_shape=(3, 256, 256), activation="relu")(inpu1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = Conv2D(32, (3, 3), activation="relu")(pool1)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
flat1 = Flatten()(pool2)
dense1 = Dense(16, activation="relu")(flat1)
dense2 = Dense(1, activation="sigmoid")(dense1)
model = Model(inputs=input1, outputs=dense2)
model.compile(loss='mse', optimizer='adadelta', metrics=['mse', 'mae'])
return model

# train model
def train_model():
(train, test) = process_data() #Returns 2 dataframes (train, test)
train_img, test_img = numpyfy(train["img"]), numpyfy(test["img"])
model = create_model()
model.fit(train_img, train["Y"], validation_data=(test_img, test["Y"]),
epochs=epochs, batch_size=batch_size)

最后,我不确定我做了什么不同的操作并没有引发错误,但这是可行的。

关于python - 将cv.imread()数据发送到Keras模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61284143/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com