gpt4 book ai didi

python - 按 id 过滤 Tensorflow 数据集

转载 作者:行者123 更新时间:2023-12-05 09:34:41 26 4
gpt4 key购买 nike

问题

我正在尝试基于包含我希望子集化的索引的 numpy 数组来过滤 Tensorflow 2.4 数据集。该数据集有 1000 张形状为 (28,28,1) 的图像。

玩具示例代码

m_X_ds = tf.data.Dataset.from_tensor_slices(list(range(1, 21))).shuffle(10, reshuffle_each_iteration=False)
arr = np.array([3, 4, 5])
m_X_ds = tf.gather(m_X_ds, arr) # This is the offending code

错误信息

ValueError: Attempt to convert a value (<ShuffleDataset shapes: (), types: tf.int32>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.ShuffleDataset'>) to a Tensor.

目前的研究

我找到了 thisthis但它们特定于它们的用例,而我正在寻找一种更通用的子集方法(即基于外部派生的索引数组)。

我对 Tensorflow 数据集还很陌生,迄今为止发现学习曲线非常陡峭。希望得到一些帮助。提前致谢!

最佳答案

长话短说

建议使用选项 C,定义如下。

完整答案

创建tf.data.Dataset 对象是为了不必将所有对象都加载到内存中。因此,默认情况下,使用 tf.gather 是行不通的。您可以选择三个选项:

选项 A:将 ds 加载到内存 + tf.gather

如果您想使用收集,您必须将整个数据集加载到内存中,然后选择一个子集:

m_X_ds = list(m_X_ds)  # Load into memory.
m_X_ds = tf.gather(m_X_ds, arr)) # Gather as usual.
print(m_X_ds)
# Example result: <tf.Tensor: shape=(3,), dtype=int32, numpy=array([8, 6, 2], dtype=int32)>

请注意,这并不总是可行的,尤其是对于庞大的数据集。

选项 B:遍历数据集,过滤不需要的样本

您还可以遍历数据集并手动选择具有所需索引的样本。这可以通过 filter 的组合来实现和 tf.py_function

m_X_ds = m_X_ds.enumerate()  # Create index,value pairs in the dataset.

# Create filter function:
def filter_fn(idx, value):
return idx in arr

# The above is not going to work in graph mode
# We are wrapping it with py_function to execute it eagerly
def py_function_filter(idx, value):
return tf.py_function(filter_fn, (idx, value), tf.bool)

# Filter the dataset as usual:
filtered_ds = m_X_ds.filter(py_function_filter)

选项 C:结合选项 B 和 tf.lookup.StaticHashTable

选项 B 很好,除了在转换图张量 -> 急切张量 -> 图张量时你可以预期性能会受到影响。 tf.py_function 很有用,但要付出代价。

相反,我们可以声明一个字典,其中所需的索引将返回 true,而不存在的索引可以返回 false。这个字典可能看起来像这样

my_table = {3: True, 4: True, 5: True}.

我们不能使用张量作为字典键,但我们可以声明一个 tensorflow's hash table让我们检查“好”指数。

m_X_ds = m_X_ds.enumerate()  # Do not repeat this if executed in Option B.

keys_tensor = tf.constant(arr)
vals_tensor = tf.ones_like(keys_tensor) # Ones will be casted to True.

table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
default_value=0) # If index not in table, return 0.


def hash_table_filter(index, value):
table_value = table.lookup(index) # 1 if index in arr, else 0.
index_in_arr = tf.cast(table_value, tf.bool) # 1 -> True, 0 -> False
return index_in_arr

filtered_ds = m_X_ds.filter(hash_table_filter)

无论选择 B 还是 C,剩下的就是从 fileterd 数据集中删除索引。我们可以使用带有 lambda 函数的简单映射:

final_ds = filtered_ds.map(lambda idx,value: value)
for entry in final_ds:
print(entry)

# tf.Tensor(7, shape=(), dtype=int32)
# tf.Tensor(13, shape=(), dtype=int32)
# tf.Tensor(6, shape=(), dtype=int32)

祝你好运。

关于python - 按 id 过滤 Tensorflow 数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66410340/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com