gpt4 book ai didi

python - 从 Tensorflow PrefetchDataset 中提取目标

转载 作者:行者123 更新时间:2023-12-03 09:33:01 32 4
gpt4 key购买 nike

我仍在学习 tensorflow 和 keras,我怀疑这个问题有一个非常简单的答案,我只是因为不熟悉而错过了。
我有一个 PrefetchDataset目的:

> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>
...由特征和目标组成。我可以使用 for 迭代它环形:
> for example in tf_test:
> print(example[0].numpy())
> print(example[1].numpy())
> exit()
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
[-0.22 -0.54 -0.14 ... 0.33 -0.55]
[-0.60 -0.02 -1.41 ... 0.21 -0.63]
...
[-0.03 -0.91 -0.12 ... 0.77 -0.23]
[-0.76 -1.48 -0.15 ... 0.38 -0.35]
[-0.55 -0.08 -0.69 ... 0.44 -0.36]]
[0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
...
0 1 1 0]
然而,这是非常缓慢的。我想要做的是访问与类标签相对应的张量,并将其转换为一个 numpy 数组、列表或任何类型的可迭代对象,这些可迭代对象可以输入到 scikit-learn 的分类报告和/或混淆矩阵中:
> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
[0.14]
[0.00]
...
[0.32]
[0.03]
[0.00]]
> y_pred_list = [int(x[0]) for x in y_pred] # assumes value >= 0.5 is positive prediction
> y_true = [] # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list)
...或访问数据,使其可用于 tensorflow 的混淆矩阵:
> labels = []                                           # what I need help with
> predictions = y_pred_list # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions)
在这两种情况下,以计算成本不高的方式从原始对象中获取目标数据的一般能力将非常有帮助(并且可能有助于我的潜在直觉:tensorflow 和 keras)。
任何建议将不胜感激。

最佳答案

您可以使用 list(ds) 将其转换为列表然后使用 tf.data.Dataset.from_tensor_slices(list(ds)) 将其重新编译为普通数据集.从那里你的噩梦再次开始,但至少这是其他人以前做过的噩梦。
请注意,对于更复杂的数据集(例如嵌套字典),您需要在调用 list(ds) 后进行更多的预处理。 ,但这应该适用于您询问的示例。
这远不是一个令人满意的答案,但不幸的是,该类(class)完全没有记录,并且标准 Dataset 技巧都不起作用。

关于python - 从 Tensorflow PrefetchDataset 中提取目标,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62436302/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com