gpt4 book ai didi

tensorflow - 使用tf.metrics.mean_absolute_error时,获取 'AttributeError: '元组'对象没有属性 'dtype''

转载 作者:行者123 更新时间:2023-12-03 07:45:12 25 4
gpt4 key购买 nike

我想用一个隐藏层训练一个非常简单的网络,但是我似乎无法训练该网络。我一直在标题中得到错误。但是,当我将损失定义为y - a2时,没有问题(除非结果是所有Nan,而不是我期望的结果)。我想念什么?

import tensorflow as tf
import numpy as np

# import data
X = np.array([[0,0,1], #XOR prob
[0,1,1],
[1,0,1],
[1,1,1],])


# output dataset, same as before
y = np.array([[0,1,1,0]]).T


# ----------------design network architecture
# define variables

X = tf.convert_to_tensor(X, dtype=tf.float32) # convert np X to a tensor
y = tf.convert_to_tensor(y, dtype=tf.float32) # convert np y to a tensor
W1 = tf.Variable(tf.random_normal([3, 4]))
W2 = tf.Variable(tf.random_normal([4, 1]))
a1 = tf.matmul(X, W1)
a2 = tf.matmul(a1, W2)

# define operations

# ---------------define loss and select training algorithm
loss = tf.metrics.mean_absolute_error(labels=y, predictions=a2)
#loss = y - a2
optimizer = tf.train.GradientDescentOptimizer(0.1)
train = optimizer.minimize(loss)

# ----------------run graph to train and get result
with tf.Session() as sess:

#initialize variables
sess.run(tf.initialize_all_variables())

for i in range(60000):
sess.run(train)
if i % 10000 == 0:
print("Loss: ", sess.run(loss))

print("Activation: ", sess.run(a2))
print("Loss: ", sess.run(loss))

最佳答案

@伊曼纽尔·阿科萨(Emmanuel Akosah)

这是运行的代码。我刚刚修改了损失函数以及global_variables_initializer()。

import tensorflow as tf
import numpy as np

# import data
X = np.array([[0,0,1], #XOR prob
[0,1,1],
[1,0,1],
[1,1,1],])


# output dataset, same as before
y = np.array([[0,1,1,0]]).T


# ----------------design network architecture
# define variables

X = tf.convert_to_tensor(X, dtype=tf.float32) # convert np X to a tensor
y = tf.convert_to_tensor(y, dtype=tf.float32) # convert np y to a tensor
W1 = tf.Variable(tf.random_normal([3, 4]))
W2 = tf.Variable(tf.random_normal([4, 1]))
a1 = tf.matmul(X, W1)
a2 = tf.matmul(a1, W2)

# define operations

# ---------------define loss and select training algorithm
#loss = tf.metrics.mean_absolute_error(labels=y, predictions=a2) #commented this line
loss=tf.keras.losses.mean_absolute_error(y_true=y, y_pred=a2) #added this line
#loss = y - a2
optimizer = tf.train.GradientDescentOptimizer(0.1)
train = optimizer.minimize(loss)

# ----------------run graph to train and get result
with tf.Session() as sess:

#initialize variables
#sess.run(tf.initialize_all_variables()) #commented this line
sess.run(tf.global_variables_initializer()) #added this line

for i in range(60000):
sess.run(train)
if i % 10000 == 0:
print("Loss: ", sess.run(loss))

print("Activation: ", sess.run(a2))
print("Loss: ", sess.run(loss))

输出如下
Loss:  [0.04997864 1.2521797  1.6842678  0.8864688 ]
Loss: [0.44828027 0.92680335 0.34281957 0.2820967 ]
Loss: [0.44828027 0.92680335 0.34281957 0.2820967 ]
Loss: [0.44828027 0.92680335 0.34281957 0.2820967 ]
Loss: [0.44828027 0.92680335 0.34281957 0.2820967 ]
Loss: [0.44828027 0.92680335 0.34281957 0.2820967 ]
Activation: [[0.44828027]
[0.07319663]
[0.6571804 ]
[0.2820967 ]]
Loss: [0.44828027 0.92680335 0.34281957 0.2820967 ]

我认为最好遵循tensorflow网站上的 this regression tutorial更新代码。您可以使用其他损失函数和优化器,以获得更好的结果。

如果您认为此答案有用,请接受此答案和/或对其进行投票。谢谢!

关于tensorflow - 使用tf.metrics.mean_absolute_error时,获取 'AttributeError: '元组'对象没有属性 'dtype'',我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54355312/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com