machine-learning - Building a nonlinear model with ReLUs in TensorFlow

Reposted. Author: 行者123. Updated: 2023-11-30 08:36:31

I am trying to build a simple nonlinear model in TensorFlow. I created this sample data:

x_data = np.arange(-100, 100).astype(np.float32)
y_data = np.abs(x_data + 20.)

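I would think this shape could easily be reconstructed with a couple of ReLUs, but I cannot figure out how.

My approach so far is to wrap linear components in ReLUs, but this does not run: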

W1 = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
W2 = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b1 = tf.Variable(tf.zeros([1]))
b2 = tf.Variable(tf.zeros([1]))

y = tf.nn.relu(W1 * x_data + b1) + tf.nn.relu(W2 * x_data + b2)

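Any ideas on how to express this model using ReLUs in TensorFlow?

Best answer

I think you are asking how to combine ReLUs into a working model? Two options are shown below:

Option 1) Feed ReLU1 into ReLU2

This is probably the preferred approach. Note that r1 is the input to r2.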

x = tf.placeholder('float', shape=[None, 1])
y_ = tf.placeholder('float', shape=[None, 1])

W1 = weight_variable([1, hidden_units])
b1 = bias_variable([hidden_units])
r1 = tf.nn.relu(tf.matmul(x, W1) + b1)

# Input of r1 into r2 (which is just y)
W2 = weight_variable([hidden_units, 1])
b2 = bias_variable([1])
y = tf.nn.relu(tf.matmul(r1,W2)+b2) # ReLU2
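
To make the shapes in Option 1 concrete, here is a minimal NumPy sketch of the same forward pass. It is not part of the original answer: the weights are hand-picked for illustration (an assumption, not trained values), and hidden_units is set to 2 to keep the arithmetic easy to follow. With these values the stacked ReLUs reproduce the target exactly, since relu(x + 20) + relu(-(x + 20)) equals |x + 20|.

```python
import numpy as np

def relu(z):
    # Elementwise max(z, 0), the same function tf.nn.relu computes
    return np.maximum(z, 0.0)

# Same sample data as the question, shaped [N, 1] like the placeholder x
x = np.arange(-100, 100).astype(np.float32).reshape(-1, 1)
y_true = np.abs(x + 20.0)

# Hand-picked weights (assumption for illustration, not trained values)
W1 = np.array([[1.0, -1.0]], dtype=np.float32)    # shape [1, hidden_units]
b1 = np.array([20.0, -20.0], dtype=np.float32)    # shape [hidden_units]
W2 = np.array([[1.0], [1.0]], dtype=np.float32)   # shape [hidden_units, 1]
b2 = np.array([0.0], dtype=np.float32)            # shape [1]

r1 = relu(x @ W1 + b1)   # [N, 1] @ [1, 2] -> [N, 2]
y = relu(r1 @ W2 + b2)   # [N, 2] @ [2, 1] -> [N, 1]

print(r1.shape, y.shape)
print(np.allclose(y, y_true))
```

Whether gradient descent actually finds weights like these from a random start is a separate question; the sketch only shows that the architecture can represent the target.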

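Option 2) Add ReLU1 and ReLU2

Option 2 is what the original question listed, but I don't know whether it is what you really want... Read through the complete working example below and try it out. I think you will find that it does not model the data well.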

x = tf.placeholder('float', shape=[None, 1])
y_ = tf.placeholder('float', shape=[None, 1])

W1 = weight_variable([1, hidden_units])
b1 = bias_variable([hidden_units])
r1 = tf.nn.relu(tf.matmul(x, W1) + b1)

# Add r1 to r2 -- won't be able to reduce the error.
W2 = weight_variable([1, hidden_units])
b2 = bias_variable([hidden_units])
r2 = tf.nn.relu(tf.matmul(x, W2) + b2)
y = tf.add(r1,r2) # Again, ReLU2 is just y
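
One detail worth checking in Option 2 as written is the output shape: y has one column per hidden unit, while y_ has a single column, so y - y_ is computed by broadcasting. The NumPy sketch below mirrors the shapes only. It is an illustration added here, not code from the original answer; the seed and the 5-row input are arbitrary assumptions, and it relies on TensorFlow's elementwise operators following NumPy-style broadcasting.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

hidden_units = 10
batch = 5
rng = np.random.default_rng(0)  # fixed seed, chosen only for repeatability

x = np.arange(batch, dtype=np.float32).reshape(batch, 1)  # [5, 1]
y_ = np.abs(x + 20.0)                                      # [5, 1]

# Small random weights and 0.1 biases, loosely mirroring the helper functions
W1 = rng.normal(0.0, 0.1, size=(1, hidden_units))
W2 = rng.normal(0.0, 0.1, size=(1, hidden_units))
b1 = np.full(hidden_units, 0.1)
b2 = np.full(hidden_units, 0.1)

y = relu(x @ W1 + b1) + relu(x @ W2 + b2)  # [5, 10]
diff = y - y_                               # y_ is broadcast across all 10 columns

print(y.shape, y_.shape, diff.shape)
```

In other words, every one of the hidden_units columns is compared against the same target column, which is different from Option 1, where the second layer reduces the hidden units to a single output.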

Complete working example

Below is a complete working example. It uses Option 1 by default, but Option 2 is also included, commented out.

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Configure the matplotlib backend to plot inline (IPython/Jupyter only;
# remove this line when running as a plain Python script)
%matplotlib inline


episodes = 55
batch_size = 5
hidden_units = 10
learning_rate = 1e-3

def weight_variable(shape):
    # Small random initial weights
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    # Slightly positive initial bias
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


# Produce the data
x_data = np.arange(-100, 100).astype(np.float32)
y_data = np.abs(x_data + 20.)

# Plot it.
plt.plot(y_data)
plt.ylabel('y_data')
plt.show()

# Might want to randomize the data
# np.random.shuffle(x_data)
# y_data = np.abs(x_data + 20.)

# reshape data ...
x_data = x_data.reshape(200, 1)
y_data = y_data.reshape(200, 1)

# create placeholders to pass the data to the model
x = tf.placeholder('float', shape=[None, 1])
y_ = tf.placeholder('float', shape=[None, 1])

W1 = weight_variable([1, hidden_units])
b1 = bias_variable([hidden_units])
r1 = tf.nn.relu(tf.matmul(x, W1) + b1)

# Input of r1 into r2 (which is just y)
W2 = weight_variable([hidden_units, 1])
b2 = bias_variable([1])
y = tf.nn.relu(tf.matmul(r1,W2)+b2)

# OPTION 2
# Add r1 to r2 -- won't be able to reduce the error.
#W2 = weight_variable([1, hidden_units])
#b2 = bias_variable([hidden_units])
#r2 = tf.nn.relu(tf.matmul(x, W2) + b2)
#y = tf.add(r1,r2)


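# Note: despite the name, this is the summed (not mean) squared error over the batch
mean_square_error = tf.reduce_sum(tf.square(y-y_))
training = tf.train.AdamOptimizer(learning_rate).minimize(mean_square_error)

sess = tf.InteractiveSession()
# initialize_all_variables() was deprecated in later TensorFlow releases
# in favor of tf.global_variables_initializer()
sess.run(tf.initialize_all_variables())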

min_error = np.inf
for _ in range(episodes):
    # Iterate through the rows with a sliding window of batch_size rows
    for i in range(x_data.shape[0]-batch_size+1):
        _, error = sess.run([training, mean_square_error], feed_dict={x: x_data[i:i+batch_size], y_: y_data[i:i+batch_size]})
        if error < min_error:
            min_error = error
            if min_error < 3:
                print(error)
            #print(error)
        #print(error, x_data[i:i+batch_size], y_data[i:i+batch_size])


# error = sess.run([training, mean_square_error], feed_dict={x: x_data[i:i+batch_size], y_:y_data[i:i+batch_size]})
# if error != None:
#     print(error)


sess.close()

print("\n\nmin_error:", min_error)

It may be easier to view this in a Jupyter notebook.
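
The training loop in the example above slices the data with a sliding window rather than in disjoint batches. The short sketch below reproduces only the index arithmetic, using the same row count and batch_size as the example. The non-overlapping variant at the end is a hypothetical alternative added for comparison, not something the original answer uses.

```python
n_rows = 200
batch_size = 5

# Same index arithmetic as the training loop: one window per starting row
windows = [(i, i + batch_size) for i in range(n_rows - batch_size + 1)]

print(len(windows))   # number of training steps per episode
print(windows[0])     # first slice bounds (start, stop)
print(windows[-1])    # last slice bounds (start, stop)

# Consecutive windows share batch_size - 1 rows
overlap = windows[0][1] - windows[1][0]
print(overlap)

# Hypothetical alternative: disjoint batches that each cover batch_size new rows
disjoint = [(i, i + batch_size) for i in range(0, n_rows, batch_size)]
print(len(disjoint))
```

With the sliding window, most rows are visited batch_size times per episode, and because the data is sorted by x, each step sees only neighboring points. Shuffling the data, as the commented-out lines in the example suggest, would change that.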

Regarding machine-learning - Building a nonlinear model with ReLUs in TensorFlow, a similar question was found on Stack Overflow: https://stackoverflow.com/questions/36637901/
