gpt4 book ai didi

python - 如何使用不可微的损失函数?

转载 作者:行者123 更新时间:2023-11-30 08:48:06 27 4
gpt4 key购买 nike

我试图在完全连接的神经网络的输出处找到一个密码本,该网络选择的点使得如此生成的密码本之间的最小距离(欧几里德范数)最大化。神经网络的输入是需要映射到输出空间的更高维度的点。

例如,如果输入维度为 2,输出维度为 3,则以下映射(以及任何排列)效果最佳: 00 - 000, 01 - 011, 10 - 101, 11 - 110

import tensorflow as tf
import numpy as np
import itertools


input_bits = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='input_bits')
code_out = tf.placeholder(dtype=tf.float32, shape=[None, 3], name='code_out')
np.random.seed(1331)


def find_code(message):
weight1 = np.random.normal(loc=0.0, scale=0.01, size=[2, 3])
init1 = tf.constant_initializer(weight1)
out = tf.layers.dense(inputs=message, units=3, activation=tf.nn.sigmoid, kernel_initializer=init1)
return out


code = find_code(input_bits)

distances = []
for i in range(0, 3):
for j in range(i+1, 3):
distances.append(tf.linalg.norm(code_out[i]-code_out[j]))
min_dist = tf.reduce_min(distances)
# avg_dist = tf.reduce_mean(distances)

loss = -min_dist

opt = tf.train.AdamOptimizer().minimize(loss)

init_variables = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_variables)

saver = tf.train.Saver()

count = int(1e4)

for i in range(count):
input_bit = [list(k) for k in itertools.product([0, 1], repeat=2)]
code_preview = sess.run(code, feed_dict={input_bits: input_bit})
sess.run(opt, feed_dict={input_bits: input_bit, code_out: code_preview})

由于损失函数本身不可微,所以我收到错误

ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients, between variables 

我是在做一些愚蠢的事情还是有办法避免这种情况?感谢这方面的任何帮助。提前致谢。

最佳答案

你的损失函数对于某些参数必须是可微的。在您的情况下,没有参数,因此您将计算常量函数的导数,该导数为 0。此外,在您的代码中您有以下行:

code = find_code(input_bits)

不再使用。根据代码,我假设您想要更改此行:

distances.append(tf.linalg.norm(code_out[i]-code_out[j]))

至:

distances.append(tf.linalg.norm(code[i]-code_out[j]))

因此,您将使用现有的 tf.layers.dense,从而包含一个可用于计算相对于该参数的损失梯度的参数。

<小时/>

此外,您无需担心 TF 操作是否可微。事实上,所有 TF 操作都是可微的。当涉及到tf.reduce_min()时,请查看 this link .

关于python - 如何使用不可微的损失函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57553280/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com