gpt4 book ai didi

python - 过滤数据集以仅获取特定类别的图像

转载 作者:太空宇宙 更新时间:2023-11-04 02:00:55 24 4
gpt4 key购买 nike

我想为 n-shot 学习准备 omniglot 数据集。因此我需要来自 10 个类(字母)的 5 个样本

重现代码

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

builder = tfds.builder("omniglot")
# assert builder.info.splits['train'].num_examples == 60000
builder.download_and_prepare()
# Load data from disk as tf.data.Datasets
datasets = builder.as_dataset()
dataset, test_dataset = datasets['train'], datasets['test']


def resize(example):
image = example['image']
image = tf.image.resize(image, [28, 28])
image = tf.image.rgb_to_grayscale(image, )
image = image / 255
one_hot_label = np.zeros((51, 10))
return image, one_hot_label, example['alphabet']


def stack(image, label, alphabet):
return (image, label), label[-1]

def filter_func(image, label, alphabet):
# get just images from alphabet in array, not just 2
arr = np.array(2,3,4,5)
result = tf.reshape(tf.equal(alphabet, 2 ), [])
return result

# correct size
dataset = dataset.map(resize)
# now filter the dataset for the batch
dataset = dataset.filter(filter_func)
# infinite stream of batches (classes*samples + 1)
dataset = dataset.repeat().shuffle(1024).batch(51)
# stack the images together
dataset = dataset.map(stack)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(32)

for i, (image, label) in enumerate(tfds.as_numpy(dataset)):
print(i, image[0].shape)

现在我想使用 filter 函数过滤数据集中的图像。tf.equal 让我按一个类过滤,我想要数组中的张量之类的东西。

您是否看到使用过滤器功能执行此操作的方法?或者这是错误的方法,还有更简单的方法吗?

我想创建一批 51 张图像和相应的标签,它们来自相同的 N=10 个类。从每个类(class),我需要 K=5 个不同的图像和一个额外的图像(我需要对其进行分类)。每批 N*K+1 (51) 张图像应该来自 10 个新的随机类别。

非常感谢您。

最佳答案

要仅保留特定标签,请使用此谓词:

dataset = datasets['train']

def predicate(x, allowed_labels=tf.constant([0, 1, 2])):
label = x['label']
isallowed = tf.equal(allowed_labels, tf.cast(label, allowed_labels.dtype))
reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))
return tf.greater(reduced, tf.constant(0.))

dataset = dataset.filter(predicate).batch(20)

for i, x in enumerate(tfds.as_numpy(dataset)):
print(x['label'])
# [1 0 0 1 2 1 1 2 1 0 0 1 2 0 1 0 2 2 0 1]
# [1 0 2 2 0 2 1 2 1 2 2 2 0 2 0 2 1 2 1 1]
# [2 1 2 1 0 1 1 0 1 2 2 0 2 0 1 0 0 0 0 0]

allowed_labels 指定您要保留的标签。所有不在此张量中的标签都将被过滤掉。

关于python - 过滤数据集以仅获取特定类别的图像,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55731774/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com