gpt4 book ai didi

python - tensorflow 简单逻辑回归

转载 作者:行者123 更新时间:2023-11-30 09:00:11 26 4
gpt4 key购买 nike

你好

我只想尝试使用简单的逻辑回归进行二元分类。我的未标记输出数据为 {1,0}//(他/她是否通过了考试)成本函数返回 (NaN)。出了什么问题?

learning_rate = 0.05
total_iterator = 1500
display_per = 100

data = numpy.loadtxt("ex2data1.txt",dtype=numpy.float32,delimiter=",");

training_X = numpy.asarray(data[:,[0,1]]) # 100 x 2

training_X 包含 100 x 2 矩阵作为考试分数。例如 [98.771 4.817]

training_Y = numpy.asarray(data[:,[2]],dtype=numpy.int) # 100 x 1 

training_Y 包含 100x1 数组,[1] [0] [0] [1] 由于 stackoverflow 格式,我无法逐行写入

m = data.shape[0]

x_i = tf.placeholder(tf.float32,[None,2]) # None x 2
y_i = tf.placeholder(tf.float32,[None,1]) # None x 1

W = tf.Variable(tf.zeros([2,1])) # 2 x 1
b = tf.Variable(tf.zeros([1])) # 1 x 1

h = tf.nn.softmax(tf.matmul(x_i,W)+b)

cost = tf.reduce_sum(tf.add(tf.multiply(y_i,tf.log(h)),tf.multiply(1-
y_i,tf.log(1-h)))) / -m

我尝试使用简单的逻辑成本函数。它返回“NaN”。我认为我的成本函数完全是垃圾,使用了 tensorflow 示例的成本函数:

 cost = tf.reduce_mean(-tf.reduce_sum(y_i*tf.log(h), reduction_indices=1))

但效果并不理想。

initializer= tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

with tf.Session() as sess:
sess.run(init)

print("cost: ", sess.run(cost, feed_dict={x_i:training_X,
y_i:training_Y}), "w: ", sess.run(W),"b: ", sess.run(b))

最佳答案

函数tf.nn.softmax期望logits的数量(最后一个维度)等于类的数量(在你的例子中为2{1,0})。由于您的情况的最后一个维度是 1,softmax 将始终返回 1(属于唯一可用类的概率始终为 1,因为不存在其他类)。因此,h 是一个充满 1 的张量,tf.log(1-h) 将返回负无穷大。无穷大乘以零(某些行中的 1-y_i)返回 NaN。

您应该将 tf.nn.softmax 替换为 tf.nn.sigmoid

可能的解决方法是:

h = tf.nn.sigmoid(tf.matmul(x_i,W)+b)
cost = tf.reduce_sum(tf.add(tf.multiply(y_i,tf.log(h)),tf.multiply(1-
y_i,tf.log(1-h)))) / -m

或者更好,您可以使用 tf.sigmoid_cross_entropy_with_logits
在这种情况下,应该按如下方式完成:

h = tf.matmul(x_i,W)+b
cost = tf.reduce_mean(tf.sigmoid_cross_entropy_with_logits(labels=y_i, logits=h))

此函数在数值上比使用 tf.nn.sigmoid 后跟 cross_entropy 函数更稳定,如果 tf.nn.sigmoid 接近 0 或,则 cross_entropy 函数可以返回 NaN 1 由于 float32 的不精确性。

关于python - tensorflow 简单逻辑回归,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43426454/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com