gpt4 book ai didi

python - 如何将本地数据加载到我的tensorflow/keras-model中?

转载 作者:太空宇宙 更新时间:2023-11-03 20:34:56 25 4
gpt4 key购买 nike

我是 tensorflow 和 keras 的新手,不知道如何加载数据以使模型适合。

我尝试使用从图像路径和标签列表构建的 tf.dataset,但无济于事。我知道下面代码中的模型本身可能不太适合我的任务。我只是想尝试一下并学习如何建立模型并训练它。我的图像有多种格式(tiff、png、jpg)并且具有不同的尺寸。这就是为什么我需要调整它们的大小并将它们转换为 numpy 数组。我根据这个线程尝试过:TensorFlow: training on my own image

import tensorflow as tf
import random
import numpy
import cv2
from PIL import Image
from pathlib import Path
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator

training_data_path = Path("/home/xxxx/validation_data")

validation_data_path = Path("/home/xxxx/validation_data")
test_data_paths = Path("/home/xxxx/test_data")
validation_image_paths = list(validation_data_path.glob("**/*"))
label_array = ["DIS","ANG", "FEA", "SAD", "SUR", "JOY", "NEU"]
label_to_index = dict((name, index) for index,name in enumerate(label_array))

def getLabelDict(image_paths):

all_image_labels = [label_to_index[Path(path).absolute().name[0:3]]
for path in image_paths]
return all_image_labels

def getLabelList(image_paths):
all_img_labels = list()
for path in image_paths:
all_img_labels.append(Path(path).absolute().name[0:3])
return all_img_labels


def preProcessPath(path):
return path.absolute().name


def get_ds(data_path):
image_paths = list(data_path.glob("**/*"))
img_paths = tf.constant(image_paths)

dataset = tf.data.Dataset.from_tensor_slices((img_paths, getLabelList(image_paths)))
for path in image_paths:
dataset.map(getPic(path))

return dataset

def getPic(path):
image = Image.open(path).convert('RGB')
image = image.resize((256,256,3))
array = numpy.array(image.getdata())
array = array.reshape((256,256,3))
return array


os.environ['CUDA_VISIBLE_DEVICES'] = '-1'


ds_inputs, ds_labels = get_ds(test_data_paths).make_one_shot_iterator()
val_inputs, val_labels= validation_data=get_ds(validation_data_path).make_one_shot_iterator()

model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(256,256,3)),
tf.keras.layers.Dense(128, activation=tf.nn.relu),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

model.fit(ds_inputs, epochs=1, steps_per_epoch=3,validation_data=val_inputs)

编辑:删除代码中不必要的行

现在我收到一个类型错误:TypeError:无法将类型对象转换为张量。内容:

最佳答案

代码中的问题:

  1. PIL resize :将大小视为 2 元组(宽度、高度)。您传入了 3 个值。
  2. get_ds:标签应编码为数字而不是字符串
  3. 模型架构:由于您使用 3 channel 图像作为神经网络的输入,因此必须首先将它们展平。

工作代码---(在评论中解释)

label_array = ["DIS","ANG", "FEA", "SAD", "SUR", "JOY", "NEU"]
label_to_index = dict((name, index) for index,name in enumerate(label_array))

# Takes as input path to image file and returns
# resized 3 channel RGB image of as numpy array of size (256, 256, 3)
def getPic(img_path):
return np.array(Image.open(img_path).convert('RGB').resize((256,256),Image.ANTIALIAS))

# returns the Label of the image based on its first 3 characters
def get_label(img_path):
return Path(img_path).absolute().name[0:3]

# Return the images and corresponding labels as numpy arrays
def get_ds(data_path):
img_paths = list()
# Recursively find all the image files from the path data_path
for img_path in glob.glob(data_path+"/**/*"):
img_paths.append(img_path)
images = np.zeros((len(img_paths),256,256,3))
labels = np.zeros(len(img_paths))

# Read and resize the images
# Get the encoded labels
for i, img_path in enumerate(img_paths):
images[i] = getPic(img_path)
labels[i] = label_to_index[get_label(img_path)]

return images,labels

# Model Architecture
model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(256,256,3)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation=tf.nn.relu),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

# Load the train and validation data
train_X, train_y = get_ds("./images/")
val_X, val_y = get_ds("./v_images/")

# Finally train it
model.fit(train_X,train_y, validation_data=(val_X,val_y))

# Predictions
model.predict(val_X)

关于python - 如何将本地数据加载到我的tensorflow/keras-model中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57230587/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com