gpt4 book ai didi

python - 如何在 py_function 之后 reshape (图像,标签)数据集

转载 作者:行者123 更新时间:2023-12-05 06:11:55 24 4
gpt4 key购买 nike

我正在尝试读取用于训练的自定义映射数据集。但是在我使用 py_function 映射数据集之后,我得到了未知的形状,例如:

def process_path(file_path):

label = get_label(file_path)
img = tf.io.read_file(file_path)
img = decode_img(img)

print('image shape:', img.shape) #this print correctly: image shape: (180, 180, 3)
print('label shape:', label.shape) #this print correctly: label shape: ()

return img, label

train_ds = train_ds.map(lambda x: tf.py_function(process_path, [x], (tf.float32, tf.int32)))
print(train_ds)
# this print unknown shape <PrefetchDataset shapes: (<unknown>, <unknown>), types: (tf.float32, tf.int32)>

这将使 model.fit() 失败,所以我想将数据集 reshape 为正确的形状,例如:

<BatchDataset shapes: ((None, 180, 180, 3), (None,)), types: (tf.float32, tf.int32)>

使用:

train_ds = tf.reshape(train_ds, ((None, 180, 180, 3), (None,)))

但这会报错:

ValueError: Attempt to convert a value (<MapDataset shapes: (<unknown>, <unknown>), types: (tf.float32, tf.int32)>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.MapDataset'>) to a Tensor.

如何在此步骤中正确分配(图像、标签)形状?

最佳答案

这里不需要py_function。假设您有一个名为 /dogs 的文件夹,其中充满了 jpg。您可以使用这两个小函数来加载和解码。

如果文件名(例如,'dogs\\dog1.jpg')在文件夹 dogs 中,第一个返回 10 否则。

第二个函数也接受一个文件名并将其转换为介于 0 和 1 之间的 float 。然后,它还调整图片的大小。

如果有任何不清楚的地方,请告诉我。

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from glob2 import glob

os.chdir('c:/users/nicol/pictures')

files = glob('*/*jpg')

def get_label(file_path):
split = tf.strings.split(file_path, sep=os.sep)[0]
equal = tf.equal(split, 'dogs')
cast = tf.cast(equal, tf.int32)
return cast

def process_path(file_path):
label = get_label(file_path)
img = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.image.resize(img, size=(180, 180))
return img, label

train_ds = tf.data.Dataset.from_tensor_slices(files).map(process_path)

next(iter(train_ds))
(<tf.Tensor: shape=(180, 180, 3), dtype=float32, numpy=
array([[[1.41176477e-01, 9.41176564e-02, 1.33333340e-01],
[1.41176477e-01, 9.41176564e-02, 1.33333340e-01],
[1.41176477e-01, 9.41176564e-02, 1.33333340e-01],
...,
[2.63300300e-01, 2.76176542e-01, 4.67582583e-01],
[2.46176332e-01, 2.59706050e-01, 4.50785339e-01],
[2.54726082e-01, 2.68909693e-01, 4.59662050e-01]]], dtype=float32)>,
<tf.Tensor: shape=(), dtype=int32, numpy=1>)

get_label 应该返回一个整数,如果不是的话。

关于python - 如何在 py_function 之后 reshape (图像,标签)数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63749852/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com