gpt4 book ai didi

python - tensorflow : Transform class name to class index

转载 作者:行者123 更新时间:2023-11-28 18:15:33 25 4
gpt4 key购买 nike

我正在使用 tensorflow 研究机器学习。

问题:

我不知道如何将类名转换为类索引。

示例:

预期映射:

Car  ---> 0
Bike ---> 1
Boat ---> 2

代码:

#!/usr/bin/env python3.6

import tensorflow as tf

names = [
"Car",
"Bus",
"Boat"
]

_, class_name = tf.TextLineReader(skip_header_lines=1).read(
tf.train.string_input_producer(tf.gfile.Glob("input_file.csv"))
)

# I want to know if it is possible to do that :
# print(sess.run(class_name)) --> "Car"
# class_index = f(class_name, names)
# print(sess.run(class_index)) --> 0

输入文件.csv :

class_name
Car
Car
Boat
Bike
...

最佳答案

最简单的方法是这样的:

class_index = tf.reduce_min(tf.where(tf.equal(names, class_name)))

请注意,它工作正常,虽然该类出现在 names 中,但返回 263 − 1,当它不存在时(如 Bike在你的例子中)。您可以避免这种影响,但删除 tf.reduce_min,但在这种情况下,class_index 将计算为数组,而不是标量。

完整的可运行代码:

names = ["Car", "Bus", "Boat"]

_, class_name = tf.TextLineReader(skip_header_lines=1).read(
tf.train.string_input_producer(tf.gfile.Glob("input_file.csv"))
)
class_index = tf.reduce_min(tf.where(tf.equal(names, class_name)))

with tf.Session() as session:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

for i in range(4):
print(class_name.eval()) # Car, Car, Boat, Bike
for i in range(4):
print(class_index.eval()) # 0, 0, 2, 9223372036854775807

coord.request_stop()
coord.join(threads)

关于python - tensorflow : Transform class name to class index,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48651523/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com