gpt4 book ai didi

machine-learning - tensorflow 中的多热编码(谷歌云机器学习,tf estimator api)

转载 作者:行者123 更新时间:2023-11-30 08:34:02 26 4
gpt4 key购买 nike

我有一个像帖子标签这样的功能。因此,对于每个观察,post_tag 特征可能是诸如“奥斯卡、布拉德皮特、奖项”之类的标签选择。我希望能够使用在谷歌云机器学习上运行的估计器 API 将其作为一个功能传递给 tensorflow 模型构建(按照this example,但针对我自己的问题进行了调整)。

我只是不确定如何将其转换为 tensorflow 中的多热编码功能。我正在尝试获得类似于 MultiLabelBinarizer 的内容理想情况下在sklearn中。

我认为this有点相关,但不完全是我需要的。

所以说我有这样的数据:

id,post_tag
1,[oscars,brad-pitt,awards]
2,[oscars,film,reviews]
3,[matt-damon,bourne]

我想将其特征化,作为 tensorflow 中预处理的一部分,如下:

id,post_tag_oscars,post_tag_brad_pitt,post_tag_awards,post_tag_film,post_tag_reviews,post_tag_matt_damon,post_tag_bourne
1,1,1,1,0,0,0,0
2,1,0,0,1,1,0,0
3,0,0,0,0,0,1,1

更新

如果我的 post_tag_list 是输入 csv 中的“oscars、brad-pitt、awards”之类的字符串。如果我尝试这样做:

INPUT_COLUMNS = [
...
tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer('post_tag_list',
tf.range(0, 10, dtype=tf.int64),
tf.string, tf.int64),
default_value=10, name='post_tag_list'),
...]

我收到此错误:

Traceback (most recent call last):
File "/usr/lib/python2.7/runpy.py", line 174, in _run_module_as_main
"__main__", fname, loader, pkg_name)
File "/usr/lib/python2.7/runpy.py", line 72, in _run_code
exec code in run_globals
File "/home/andrew_maguire/localDev/codeBase/pmc-analytical-data-mart/clickmodel/trainer/task.py", line 4, in <module>
import model
File "trainer/model.py", line 49, in <module>
default_value=10, name='post_tag_list'),
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/lookup_ops.py", line 276, in __init__
super(HashTable, self).__init__(table_ref, default_value, initializer)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/lookup_ops.py", line 162, in __init__
self._init = initializer.initialize(self)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/lookup_ops.py", line 348, in initialize
table.table_ref, self._keys, self._values, name=scope)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_lookup_ops.py", line 205, in _initialize_table_v2
values=values, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2632, in create_op
set_shapes_for_outputs(ret)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1911, in set_shapes_for_outputs
shapes = shape_func(op)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1861, in call_with_requiring
return call_cpp_shape_fn(op, require_shape_fn=True)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/common_shapes.py", line 595, in call_cpp_shape_fn
require_shape_fn)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/common_shapes.py", line 659, in _call_cpp_shape_fn_impl
raise ValueError(err.message)
ValueError: Shape must be rank 1 but is rank 0 for 'key_value_init' (op: 'InitializeTableV2') with input shapes: [], [], [10].

如果我将每个 post_tag_list 填充为“oscars,brad-pitt,awards,OTHER,OTHER,OTHER,OTHER,OTHER,OTHER,OTHER”,那么它总是 10 长。这是否是一个潜在的解决方案。

或者我是否需要以某种方式知道我可能在这里传递的所有帖子标签的大小(有点不好定义为一直创建的新标签)。

最佳答案

您尝试过 tf.contrib.lookup.Hashtable 吗?

这是我自己使用的示例用法:https://github.com/TensorLab/tensorfx/blob/master/src/data/_transforms.py#L160以及基于此编写的示例片段:

import tensorflow as tf
session = tf.InteractiveSession()

entries = ['red', 'blue', 'green']
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(entries,
tf.range(0, len(entries), dtype=tf.int64),
tf.string, tf.int64),
default_value=len(entries), name='entries')
tf.tables_initializer().run()

value = tf.constant([['blue', 'red'], ['green', 'red']])
print(table.lookup(value).eval())

我相信查找适用于常规张量和稀疏张量(给定可变长度值列表,您最终可能会得到后者)。

关于machine-learning - tensorflow 中的多热编码(谷歌云机器学习,tf estimator api),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46675108/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com