gpt4 book ai didi

python - TensorFlow 中的成本不变

转载 作者:行者123 更新时间:2023-11-30 09:43:59 25 4
gpt4 key购买 nike

我正在尝试在 tensorflow 中构建一个神经网络以更好地学习该库,并且我的损失值没有改变。这是我的代码:

import tensorflow as tf
import numpy as np
import pandas as pd
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

all_data = pd.read_csv('/projects/data/testfile.csv')
all_data = all_data.values

size_layer1 = 1
size_layer2 = 10
size_layer3 = 1

labels = all_data[:, 9]; labels = tf.convert_to_tensor(labels, np.float32); labels = tf.reshape(labels, [985, 1])
data = all_data[:, 6]; data = tf.convert_to_tensor(data, np.float32)
theta1 = tf.Variable(tf.zeros([size_layer2, size_layer1])); theta1 = tf.reshape(theta1, [10, 1])
theta2 = tf.Variable(tf.zeros([size_layer3, size_layer2])); theta2 = tf.reshape(theta2, [1, 10])

a1 = data; a1 = tf.reshape(a1, [1, 985])
z2 = tf.matmul(theta1, a1)
a2 = tf.nn.relu(z2)
z3 = tf.matmul(theta2, a2)
a3 = tf.nn.sigmoid(z3)
h = tf.transpose(a3)

cost = tf.losses.mean_squared_error(labels, h)
train = tf.train.GradientDescentOptimizer(0.01).minimize(cost)

init = tf.global_variables_initializer()

with tf.Session() as sess:
sess.run(init)
for i in range(10):
sess.run(train)
print(sess.run(cost))

我的整个数据集是 985x12,但大多数列都是文本,所以我隔离了两列。我知道神经网络不应该这样使用,具有 1:10:1 节点系统和实数标签,但我并不是想优化网络,只是学习语言。我知道我应该使用特征缩放/均值归一化,但正如我所说,我并不是真的试图完美地优化神经网络。这是我的输出:

73948990000.0
73948990000.0
73948990000.0
73948990000.0
73948990000.0
73948990000.0
73948990000.0
73948990000.0
73948990000.0
73948990000.0

我已经尝试了很多事情。最初,我的成本函数是普通的交叉熵,但由于我的数据是实数值,所以我将其更改为均方误差。我也尝试过更改优化器,但它没有改变任何东西。问题是我没有尝试很好地设计网络并且使用了糟糕的架构,还是其他原因?

最佳答案

初始权重theta1theta2是零数组,不能用于训练。权重用于计算在训练期间更新权重的增量值,这会将增量清零,因此权重不会改变。此外,如果所有权重都是相同的值(除了零),它们将具有相同的增量,这也会阻止学习。因此,初始权重需要是随机数。

尝试使用它来初始化随机权重:

theta1 = tf.get_variable('theta1', shape=(size_layer2, size_layer1), initializer=tf.contrib.layers.xavier_initializer())
theta2 = tf.get_variable('theta2', shape=(size_layer3, size_layer2), initializer=tf.contrib.layers.xavier_initializer())

关于python - TensorFlow 中的成本不变,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54907610/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com