gpt4 book ai didi

python - 将张量作为数组进行迭代 Tensorflow

转载 作者:行者123 更新时间:2023-12-01 01:51:56 25 4
gpt4 key购买 nike

我正在尝试将预测图像保存在我用 Tensorflow 编写的 CNN 网络上。在我的代码中,y_pred_cls 包含我的预测标签,y_pred_cls 是维度为 1 x 批量大小的张量。现在,我想将 y_pred_cls 作为数组进行迭代,并创建一个包含 pred 类、真实类和一些索引号的文件名,然后找出与预测标签相关的图像,并使用 imsave 保存为图像。

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
train_writer.add_graph(sess.graph)



print("{} Start training...".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print("{} Open Tensorboard at --logdir {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), tensorboard_dir))

for epoch in range(FLAGS.num_epochs):
print("{} Epoch number: {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch + 1))
step = 1

# Start training
while step < train_batches_per_epoch:
batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
opt, train_acc = sess.run([optimizer, accuracy], feed_dict={x: batch_xs, y_true: batch_ys})

# Logging
if step % FLAGS.log_step == 0:
s = sess.run(sum, feed_dict={x: batch_xs, y_true: batch_ys})
train_writer.add_summary(s, epoch * train_batches_per_epoch + step)

step += 1

# Epoch completed, start validation
print("{} Start validation".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
val_acc = 0.
val_count = 0
cm_running_total = None

for _ in range(val_batches_per_epoch):
batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
acc, loss , conf_m= sess.run([accuracy, cost, tf.confusion_matrix(y_true_cls, y_pred_cls, FLAGS.num_classes)],
feed_dict={x: batch_tx, y_true: batch_ty})



if cm_running_total is None:
cm_running_total = conf_m
else:
cm_running_total += conf_m


val_acc += acc
val_count += 1

val_acc /= val_count

s = tf.Summary(value=[
tf.Summary.Value(tag="validation_accuracy", simple_value=val_acc),
tf.Summary.Value(tag="validation_loss", simple_value=loss)
])

val_writer.add_summary(s, epoch + 1)
print("{} -- Training Accuracy = {:.4%} -- Validation Accuracy = {:.4%} -- Validation Loss = {:.4f}".format(
datetime.now().strftime('%Y-%m-%d %H:%M:%S'), train_acc, val_acc, loss))

# Reset the dataset pointers
val_preprocessor.reset_pointer()
train_preprocessor.reset_pointer()

print("{} Saving checkpoint of model...".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

# save checkpoint of the model
checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch.ckpt' + str(epoch+1))
save_path = saver.save(sess, checkpoint_path)
print("{} Model checkpoint saved at {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), checkpoint_path))

batch_tx、batch_ty分别是我的RGB数据和标签。

提前致谢。

最佳答案

要将张量中的数据提取到 python 变量中,请使用

标签 = sess.run(y_pred_cls)

这将为您提供一个用于单热向量标签的数组或用于标量标签的 int 变量。

要将数组保存到图像中,您可以使用 PIL 库

from PIL import Image
img = Image.fromarray(data, 'RGB')
img.save('name.png')

其余部分应该简单明了,

  1. 从batch_tx、batch_ty 和 y_pred_cls 张量中提取数据
  2. 迭代每个三元组
  3. 从当前的x创建一个RGB图像
  4. 创建一个格式为 name = str(y)+'_'+str(y_hat)
  5. 的字符串
  6. 保存您的图片

如果您在执行这些步骤时遇到困难,我可以进一步帮助您

关于python - 将张量作为数组进行迭代 Tensorflow,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50640254/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com