gpt4 book ai didi

python-3.x - python- tf.write_file 在 tensorflow 中不工作

转载 作者:行者123 更新时间:2023-12-04 09:20:31 24 4
gpt4 key购买 nike

我有一个代码可以训练物体检测的物体坐标。我使用了 CNN 网络,输出层是回归层(称为 bound_box_output),它返回图像中对象的 (x0,y0, height, width)。在这一层之后,我尝试在丢失步骤之前直接保存图像。

   i = 0
image_decoded = tf.image.decode_jpeg(tf.read_file('3.jpg'), channels=3)
cropped = tf.image.crop_to_bounding_box(image = image_decoded,
offset_height = tf.cast(bound_box_output[i,0], tf.int32),
offset_width = tf.cast(bound_box_output[i,1], tf.int32),
target_height = tf.cast(bound_box_output[i,2], tf.int32),
target_width = tf.cast(bound_box_output[i,3], tf.int32))

enc = tf.image.encode_jpeg(cropped)
fname = tf.constant('4.jpeg')
fwrite = tf.write_file(fname, enc)

然后在 tf.train.SessionRunHook 中运行它

def begin(self):
self._step = -1
self._start_time = time.time()

def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss)
def after_run(self, run_context, run_values):
if self._step % LOG_FREQUENCY == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time

loss_value = run_values.results
examples_per_sec = LOG_FREQUENCY * BATCH_SIZE / duration
sec_per_batch = float(duration / LOG_FREQUENCY)


format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))

if self._step == MAX_STEPS-1:
loss_value = run_values.results
print("The final value of loss is:: ")
print(loss_value)
print(fwrite)
tf.train.SessionRunArgs(fwrite)

问题是它没有将“4.jpeg”图像保存在特定文件夹中

注意:我使用的是tensorflow 1.1.3python3.5

最佳答案

TLDR;将 tf.train.SessionRunArgs(fwrite) 替换为 run_context.session.run(fwrite)

SessionRunArgs实际上并不运行提供的操作。 SessionRunArgsbefore_run() 返回称呼。它们的作用是为下一个 session.run() 调用添加参数。

if self._step  == MAX_STEPS-1:
loss_value = run_values.results
print("The final value of loss is:: ")
print(loss_value)
print(fwrite)
tf.train.SessionRunArgs(fwrite) # problematic line

您正试图在 after_run() 结束时运行 fwrite 操作.但是,它只是实例化了 SessionRunArgs。对象。

实现所需行为的一个选项是利用提供给 after_run()run_context 参数. run_context 的类型为 SessionRunContext ,包含 session 引用的类型。

run_context.session.run(fwrite) 应该可以帮到您。

关于python-3.x - python- tf.write_file 在 tensorflow 中不工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46544088/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com