gpt4 book ai didi

tensorflow - Tensorflow 的 K-Means - 图断开错误

转载 作者:行者123 更新时间:2023-12-05 06:47:32 26 4
gpt4 key购买 nike

我正在尝试编写一个运行 KMeans 的函数在数据集上并输出聚类质心。我的目标是在自定义中使用它 keras层,所以我使用 TensorFlow以张量作为输入数据集的 KMeans 实现。

然而,我的问题是,即使作为独立函数,我也无法使其正常工作。问题来自于 KMeans接受提供小批量而不是普通张量的生成器函数,但是当我使用闭包来执行此操作时,我得到一个 graph disconnected错误:

import tensorflow as tf                                           # version: 2.4.1
from tensorflow.compat.v1.estimator.experimental import KMeans

@tf.function
def KMeansCentroids(inputs, num_clusters, steps, use_mini_batch=False):
# `inputs` is a 2D tensor

def input_fn():
# Each one of the lines below results in the same "Graph Disconnected" error. Tuples don't really needed but just to be consistent with the documentation
return (inputs, None)
return (tf.data.Dataset.from_tensor_slices(inputs), None)
return (tf.convert_to_tensor(inputs), None)

kmeans = KMeans(
num_clusters=num_clusters,
use_mini_batch=use_mini_batch)

kmeans.train(input_fn, steps=steps) # This is where the error happens
return kmeans.cluster_centers()

>>> x = tf.random.uniform((100, 2))
>>> c = KMeansCentroids(x, 5, 10)

准确的错误是:

ValueError:Tensor("strided_slice:0", shape=(), dtype=int32)must be from the same graph asTensor("Equal:0", shape=(), dtype=bool)(graphs are FuncGraph(name=KMeansCentroids, id=..) and <tensorflow.python.framework.ops.Graph object at ...>).

  • 如果我使用 numpy数据集并在函数内部转换为张量,代码就可以正常工作。
  • 另外,制作 input_fn()直接返回tf.random.uniform((100, 2)) (忽略输入参数),将再次工作。这就是为什么我猜测 tensorflow 不支持闭包,因为它需要在开始时构建计算图。
    但我不知道如何解决这个问题。可能是由于 KMeans 是 compat.v1.experimental 而导致的版本错误模块?

请注意 documentation of KMeans input_fn() 的状态:

The function should construct and return one of the following:

  • A tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels) with same constraints as below.
  • A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor. Both features and labels are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.

最佳答案

您面临的问题更多是关于在创建的图形之外调用张量。基本上,当您调用 .train 函数时,将创建一个新图,其中包含 input_fn 中定义的图和 model_fn 中定义的图

kmeans.train(input_fn, steps=steps)

并且,之后所有来自这些函数之外的张量都将被视为局外人,不会成为这个新图表的一部分。这就是为什么您在尝试使用外部张量时遇到 graph disconnected 错误。要解决此问题,您需要在这些图中创建必要的张量。

import tensorflow as tf                                        
from tensorflow.compat.v1.estimator.experimental import KMeans

@tf.function
def KMeansCentroids(num_clusters, steps, use_mini_batch=False):
def input_fn(batch_size):
pinputs = tf.random.uniform((100, 2))
dataset = tf.data.Dataset.from_tensor_slices((pinputs))
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)

kmeans = KMeans(
num_clusters=num_clusters,
use_mini_batch=use_mini_batch)

kmeans.train(input_fn = lambda: input_fn(5),
steps=steps)

return kmeans.cluster_centers()

c = KMeansCentroids(5, 10)

这里有更多信息供阅读,1 .仅供引用,我用几个版本的 tf > 2 测试了您的代码,我认为这与版本错误或其他问题无关。


在这里为 future 的读者重新提及。在 Keras 层中使用 KMeans 的替代方法:

关于tensorflow - Tensorflow 的 K-Means - 图断开错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67088567/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com