gpt4 book ai didi

python - TensorFlow,什么时候可以使用类似 Python 的负索引?

转载 作者:行者123 更新时间:2023-11-28 19:07:15 25 4
gpt4 key购买 nike

我是 TensorFlow(版本 1.2)的新手,但不是 Python 或 Numpy 的新手。我正在建立一个模型来预测蛋白质分子的形状。我需要将 TensorFlow 的标准 tf.losses.cosine_distance 函数包装在一些额外的代码中,因为我需要停止将某些 NaN 值传播到损失计算中。

我确切地知道哪些单元格将是 NaN。无论我的机器学习系统对这些细胞的预测如何,都不算数。我计划在对损失函数求和之前将 tf.losses.cosine_distance 输出的 NaN 部分变为零。

这是一段工作代码,使用 tf.scatter_nd_update 进行元素分配:

def custom_distance(predict, actual):
with tf.name_scope("CustomDistance"):
loss = tf.losses.cosine_distance(predict, actual, -1,
reduction=tf.losses.Reduction.NONE)
loss = tf.Variable(loss) # individual elements can be modified
indices = tf.constant([[0,0,0],[29,1,0],[29,2,0]])
updates = tf.constant([0., 0., 0.])
loss = tf.scatter_nd_update(loss, indices, updates)
return loss

但是,这只适用于我拥有的一种长度为 30 个氨基酸的蛋白质。如果我有不同长度的蛋白质怎么办?我会有很多。 在 Numpy 中,我只使用 Python 的负索引,并用 -1 代替索引行上的两个 29。 Tensorflow 不会接受。如果我进行替换,我会得到很长的回溯,但我认为其中最重要的部分是:

File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Invalid indices: [0,1] = [-1, 1, 0] is not in [0, 30)

(我还可以修改预测张量,以便在计算损失之前,所讨论的单元格与实际张量完全匹配,但在每种情况下,挑战都是相同的: 分配 TensorFlow 对象中各个元素的值。)

我应该忘记 TensorFlow 中的负索引吗?我正在仔细阅读 TensorFlow 文档以了解解决此问题的正确方法。我假设我可以沿着主轴检索我的输入张量的长度并使用它。但在看到 TensorFlow 和 Numpy 之间的强烈相似之处后,我不得不怀疑这是否笨拙。

感谢您的建议。

最佳答案

它可以与 tensorflow 绑定(bind)到 python 切片运算符一起使用。因此,例如,loss[-1]loss 的有效切片。

在你的例子中,如果你只有三个切片,你可以单独分配它们:

update_op0 = indices[0,0,0].assign(updates[0])
update_op1 = indices[-1,1,0].assign(updates[1])
update_op2 = indices[-1,2,0].assign(updates[2])

如果您有比这更多的切片,或者切片数量可变,则之前的方法不实用。您可以编写一个像这样的小辅助函数来将“正或负索引”转换为“仅正索引”:

def to_pos_idx(idx, x):
# todo: shape & bound checking
idx = tf.convert_to_tensor(idx)
s = tf.shape(x)[:tf.size(idx)]
idx = tf.where(idx < 0, s + idx, idx)
return idx

并像这样修改你的代码:

indices = tf.constant([[0,0,0],[-1,1,0],[-1,2,0]])
indices = tf.map_fn(lambda i: to_pos_idx(i, loss), indices) # transform indices here
updates = tf.constant([0., 0., 0.])
loss = tf.scatter_nd_update(loss, indices, updates)

关于python - TensorFlow,什么时候可以使用类似 Python 的负索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45434781/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com