gpt4 book ai didi

tensorflow - 如何强制自动编码器中的瓶颈产生二进制值?

转载 作者:行者123 更新时间:2023-12-04 10:22:45 24 4
gpt4 key购买 nike

我尝试强制自动编码器中的瓶颈层产生二进制值。我是通过使用 tensorflow.cond 来实现的在自定义损失函数中,惩罚所有不是 0 或 1 的值。但是这种方法非常慢。有没有更好的方法来执行此操作?

def custom_loss(weight):
def loss(y_true, y_pred):
reconstruction_loss = binary_crossentropy(y_true, y_pred)

def binarize_loss(value):
return tf.cond(tf.reduce_mean(value) > 0.5, lambda: tf.abs(value - 1), lambda: tf.abs(value))

binarized_loss_value = tf.map_fn(binarize_loss, neckLayer.output)
return reconstruction_loss + (K.mean(binarized_loss_value , axis=-1) * weight)

return loss

最佳答案

我可能会摆脱 tf.cond语句,因为您可以使用简单的算术来做您想做的事:

def loss(y_true, y_pred):
reconstruction_loss = binary_crossentropy(y_true, y_pred)

binary_neck_loss = tf.abs(0.5 - tf.abs(0.5 - neckLayer.output))

return reconstruction_loss + (K.mean(binary_neck_loss , axis=-1) * weight)

当然,我不确切知道您的数据的形状,但您应该能够从那里推断出来。

关于tensorflow - 如何强制自动编码器中的瓶颈产生二进制值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60771990/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com