gpt4 book ai didi

python - 将从 keras.backend.argmax 返回的张量作为索引传递给 keras.backend,gather 预计为 'An integer tensor of indices.'

转载 作者:行者123 更新时间:2023-11-30 09:18:10 27 4
gpt4 key购买 nike

我正在尝试实现自定义损失函数

def lossFunction(self,y_true,y_pred):

maxi=K.argmax(y_true)

return K.mean((K.max(y_true) -(K.gather(y_pred,maxi)))**2)

训练时出现以下错误

<小时/>

InvalidArgumentError (see above for traceback): indices[5] = 51 is not in [0, 32) [[Node: loss/dense_3_loss/Gather = Gather[Tindices=DT_INT64, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](dense_3/BiasAdd, metrics/acc/ArgMax)]]

<小时/>

模型总结

<小时/>
_________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_1 (InputLayer) (None, 64, 50, 1) 0
____________________________________________________________________________________________________
input_2 (InputLayer) (None, 64, 50, 1) 0
____________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 32, 25, 16) 272 input_1[0][0]
____________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 32, 25, 16) 272 input_2[0][0]
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 16, 12, 16) 0 conv2d_1[0][0]
____________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 16, 12, 16) 0 conv2d_2[0][0]
____________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 15, 11, 32) 2080 max_pooling2d_1[0][0]
____________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 15, 11, 32) 2080 max_pooling2d_2[0][0]
____________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 8, 6, 32) 0 conv2d_3[0][0]
____________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 8, 6, 32) 0 conv2d_4[0][0]
____________________________________________________________________________________________________
flatten_1 (Flatten) (None, 1536) 0 max_pooling2d_3[0][0]
____________________________________________________________________________________________________
flatten_2 (Flatten) (None, 1536) 0 max_pooling2d_4[0][0]
____________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 3072) 0 flatten_1[0][0]
flatten_2[0][0]
____________________________________________________________________________________________________
input_3 (InputLayer) (None, 256) 0
____________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 3328) 0 concatenate_1[0][0]
input_3[0][0]
____________________________________________________________________________________________________
dense_1 (Dense) (None, 512) 1704448 concatenate_2[0][0]
____________________________________________________________________________________________________
dense_2 (Dense) (None, 256) 131328 dense_1[0][0]
____________________________________________________________________________________________________
dense_3 (Dense) (None, 256) 65792 dense_2[0][0]
====================================================================================================
Total params: 1,906,272
Trainable params: 1,906,272
Non-trainable params: 0

最佳答案

Argmax 是从最后一个轴获取,而 Gather 是从第一个轴获取。两个轴上的元素数量不同,因此这是预期的。

如果仅在类上工作,请使用最后一个轴,因此我们将围绕收集方法进行怪异:

def lossFunction(self,y_true,y_pred):

maxi=K.argmax(y_true) #ok

#invert the axes
y_pred = K.permute_dimensions(y_pred,(1,0))

return K.mean((K.max(y_true,axis=-1) -(K.gather(y_pred,maxi)))**2)

关于python - 将从 keras.backend.argmax 返回的张量作为索引传递给 keras.backend,gather 预计为 'An integer tensor of indices.',我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49885298/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com