gpt4 book ai didi

python - Tensorflow:feed_dict 的形状错误{}

转载 作者:行者123 更新时间:2023-11-30 22:27:31 24 4
gpt4 key购买 nike

第一次遇到这样的问题。

错误是关于 feed_dict={tfkids: kids, tfkids_fit: kids_fit} 的,似乎需要 reshape kids_fit

谁能帮我解决这个问题吗?

import tensorflow as tf
from tensorflow.contrib.distributions import Normal
import numpy as np
import matplotlib.pyplot as plt

DNA_SIZE = 1
POP_SIZE = 10
LR = 0.1
N_GENERATION = 50

def F(x):
return x**2

def get_fitness(value):
return -value

mean = tf.Variable(tf.constant(13.), dtype=tf.float32)
sigma = tf.Variable(tf.constant(5.), dtype=tf.float32)
N_dist = Normal(loc=mean, scale=sigma)
make_kids = N_dist.sample([POP_SIZE])

tfkids = tf.placeholder(tf.float32, [POP_SIZE, DNA_SIZE])
tfkids_fit = tf.placeholder(tf.float32, [POP_SIZE])
loss = -tf.reduce_mean(N_dist.log_prob(tfkids) * tfkids_fit)
train_op = tf.train.GradientDescentOptimizer(LR).minimize(loss)

x = np.linspace(-20, 20, 100)
plt.plot(x, F(x))

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

plt.ion()
for g in range(N_GENERATION):
kids = sess.run(make_kids)
kids_fit = get_fitness(F(kids))
sess.run(train_op, feed_dict={tfkids: kids, tfkids_fit: kids_fit})

if "plot_points" in globals():
plot_points.remove()

plot_points = plt.scatter(kids, F(kids), s=30)
plt.pause(0.05)

plt.ioff()
plt.show()

尝试测试代码时会出现错误。

ValueError: Cannot feed value of shape (10,) for Tensor 'Placeholder:0', which has shape '(10, 1)'

最佳答案

您的 Placeholder:0tfkids = tf.placeholder(tf.float32, [POP_SIZE, DNA_SIZE])

如您所见,tfkids 形状为 [POP_SIZE, DNA_SIZE] = (10, 1)

您的 kids 变量的 shape = (10)

尽管两个形状都包含 10 个值,但第一个形状有 2 维,而第二个形状为 1 维。

因此,您必须扩展 kids 变量的维度,以便以这种方式与 tfkids 兼容:

sess.run(train_op, feed_dict={tfkids: np.expand_dims(kids, axis=1), tfkids_fit: kids_fit})

np.expand_dims 允许您向 kids 形状添加一维尺寸

关于python - Tensorflow:feed_dict 的形状错误{},我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46892015/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com