gpt4 book ai didi

Tensorflow:如何通过索引获取张量值并分配新值

转载 作者:行者123 更新时间:2023-12-02 04:29:10 25 4
gpt4 key购买 nike

我正在尝试对 Tensorflow 中张量的最高 K 值进行一些操作。基本上,我想要的是首先获取前 K 值的索引,进行一些操作并分配新值。例如:

A = tf.constant([[1,2,3,4,5],[6,7,8,9,10]])
values, indices = tf.nn.top_k(A, k=3)

在这里,值将是 array([[ 5, 4, 3],[10, 9, 8]],dtype=int32)

我对值做了一些操作后,比如prob=tf.nn.softmax(values),我应该如何根据索引将这个值赋给A?这类似于 numpy A[indices] = prob.似乎无法在 tensorflow 中找到合适的函数来执行此操作。

最佳答案

不幸的是,一旦您想将索引与张量一起使用,Tensorflow 会非常痛苦,因此要实现您的想法,您必须使用一些丑陋的解决方法。我的选择是:

import tensorflow as tf

#First you better use Variable as constant is not designed to be updated
A = tf.Variable(initial_value = [[1,2,3,4,5],[6,7,8,9,10]])

#Create a buffer variable which will store tentative updates,
#initialize it with random values
t = tf.Variable(initial_value = tf.cast(tf.random_normal(shape=[5]),dtype=tf.int32))
values, indices = tf.nn.top_k(A, k=3)

#Create a function for manipulation on the values you want
def val_manipulation(v):
return 2*v+1

#Create a while loop to update each entry of the A one-be-one,
#as scatter_nd_update can update only by slices, but not individual entries
i = tf.constant(0)
#Stop once updated every slice
c = lambda i,x: tf.less(i, tf.shape(A)[0])
#Each iteration update i and
#update every slice of A (A[i]) with updated slice
b = lambda i,x: [i+1,tf.scatter_nd_update(A,[[i]],[tf.scatter_update(tf.assign(t,A[i]),indices[i],val_manipulation(values[i]) )])]
#While loop
r = tf.while_loop(c, b, [i,A])

init = tf.initialize_all_variables()

with tf.Session() as s:
s.run(init)
#Test it!
print s.run(A)
s.run(r)
print s.run(A)

所以基本上你要做的是:

  1. scatter_update 只能使用变量,所以我们从 A 中取出一个切片(作为 A[i])并将这些值存储到缓冲区变量 t
  2. 用所需值更新缓冲区变量中的值
  3. 用更新后的 t 更新 i - A 的第 slice>
  4. 重复 A 的其余条目

最终您应该得到以下输出:

[[ 1  2  3  4  5]  [ 6  7  8  9 10]] 
[[ 1 2 7 9 11] [ 6 7 17 19 21]]

关于Tensorflow:如何通过索引获取张量值并分配新值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50656197/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com