gpt4 book ai didi

python - 自定义损失函数 : logits and targets must have the same shape ((? , 1) vs (45000,))

转载 作者:行者123 更新时间:2023-11-30 09:04:18 26 4
gpt4 key购买 nike

我的模型将二元分类器中的所有内容预测为 0。我们总共有 4000 个 true 和 41000 个 false。因此,我们正在尝试制作一个自定义损失函数。

我收到的错误是:

(logits.get_shape(),targets.get_shape()))

ValueError:logits 和目标必须具有相同的形状 ((?, 1) vs (45000,))

代码如下所示:

combined = tf.keras.layers.concatenate([modelRNN.output, modelCNN.output])

final_dense = tf.keras.layers.Dense(10, activation='relu')(combined) #ff kijken of dit slim is
final_dense = tf.keras.layers.Dense(1, activation='sigmoid')(final_dense)

final_model = tf.keras.Model(inputs=[modelCNN.input, modelRNN.input], outputs=final_dense)

targets = match_train
logits = final_dense
pos_weight = (45000 - 4539) / 4539


custom_loss = tf.nn.weighted_cross_entropy_with_logits(
targets,
logits,
pos_weight,
)


final_model.compile(optimizer='adam',
loss=custom_loss,
metrics=['accuracy'])

初始数组的形状是:

modelCNN = (45000, 28, 28, 1) float64
modelRNN = (45000, 93, 13) float64
labels = (45000,1) boolean

通过注释中的代码部分解决了问题。我现在收到一个以前没有过的错误。它说:

TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

File "<ipython-input-6-42327e5a4b50>", line 3, in <module>
metrics=['accuracy'])

File "C:\Users\Tijev\Anaconda3\envs\tfp3.6\lib\site-packages\tensorflow\python\training\checkpointable\base.py", line 442, in _method_wrapper
method(self, *args, **kwargs)

File "C:\Users\Tijev\Anaconda3\envs\tfp3.6\lib\site-packages\tensorflow\python\keras\engine\training.py", line 215, in compile
loss = loss or {}

最佳答案

将标签 reshape 为 2D 张量。

targets = np.asarray(match_train).astype('float32').reshape((-1,1))

来源:Tensorflow estimator ValueError: logits and labels must have the same shape ((?, 1) vs (?,))

关于python - 自定义损失函数 : logits and targets must have the same shape ((? , 1) vs (45000,)),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56367068/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com