gpt4 book ai didi

python - 如何循环张量对象直到满足条件

转载 作者:行者123 更新时间:2023-12-01 06:22:10 26 4
gpt4 key购买 nike

我有一个像这样的张量:

masked_bad_col = [[False  True  True False  True  True  True  True  True  True  True False]]

我想循环遍历这个张量,直到所有元素都为True。所以我有另一个函数,它将更新这个张量,我们称之为唯一性

def uniqueness():

'blah blah blha'
return tensor1, updated_masked_bad_col

我查看了文档并了解到我可以使用 tf.while_loop 来做到这一点。尽管如此,我找不到任何关于 bool 值的例子。这是我到目前为止所做的:

tensor1, _ = tf.while_loop(masked_bad_col != True, uniqueness)

这显然是不正确的,但不知道如何使用masked_bad_col的每个元素作为继续循环uniqueness函数的条件。

更新 1这是我试图在循环中调用的方法:

corpus = load_corpus('path_to_corpus/train.corpus')
topics = []
vocab, docs = corpus['vocab'], corpus['docs']
number_of_topics = 0
encoder_model = load_keras_model(
'path_to_model/encoder_model',
custom_objects={"KCompetitive": KCompetitive})
weights = encoder_model.get_weights()[0]
for idx in range(encoder_model.output_shape[1]):
token_idx = np.argsort(weights[:, idx])[::-1][:20]
topics.append([(revdict(vocab)[x]) for x in token_idx])
number_of_topics += 1

nparr = np.asarray(topics)
# print nparr.shape

unique, indices, count = np.unique(nparr, return_inverse=True, return_counts=True)

tensor1 = (np.sum(count[indices].reshape(nparr.shape), axis=1).reshape(1, nparr.shape[0]) / (
number_of_topics * 20))

def uniqueness_score():
corpus = load_corpus('path_to_corpus/train.corpus')
topics = []
vocab, docs = corpus['vocab'], corpus['docs']
number_of_topics = 0
encoder_model = load_keras_model(
'path_to_model/encoder_model',
custom_objects={"KCompetitive": KCompetitive})
weights = encoder_model.get_weights()[0]
for idx in range(encoder_model.output_shape[1]):
token_idx = np.argsort(weights[:, idx])[::-1][:20]
topics.append([(revdict(vocab)[x]) for x in token_idx])
number_of_topics += 1

nparr = np.asarray(topics)

unique, indices, count = np.unique(nparr, return_inverse=True, return_counts=True)

tensor1 = (np.sum(count[indices].reshape(nparr.shape), axis=1).reshape(1, nparr.shape[0]) / (
number_of_topics * 20))
return tensor1

这就是我在 while_loop 中调用此方法的方式

with tf.Session() as sess:

tensor2, _ = tf.while_loop(
# Loop condition (negated goal condition)
lambda tensor1: ~tf.math.reduce_all(tensor1 > tf.reduce_mean(tensor1)),
# Loop body
lambda tensor1: uniqueness_score(),
# Loop variables
[tensor1])
# Returned loop value
print(tensor2.eval())

最佳答案

我想我或多或少知道你想要什么,但我不确定我是否需要 bool 数组。如果您想要执行一些迭代过程,计算或检索某些值直到它们满足某些条件,则无需额外的数组即可实现。例如,请参阅此循环对一些随机值进行采样,直到所有值都满足条件:

import tensorflow as tf

# Draw five random numbers until all are > 0.5
with tf.Graph().as_default(), tf.Session() as sess:
tf.random.set_random_seed(0)
# Initial values, here simply initialized to zero
tensor1 = tf.zeros([5], dtype=tf.float32)
# Loop
tensor1 = tf.while_loop(
# Loop condition (negated goal condition)
lambda tensor1: ~tf.math.reduce_all(tensor1 > 0.5),
# Loop body
lambda tensor1: tf.random.uniform(tf.shape(tensor1), dtype=tensor1.dtype),
# Loop variables
[tensor1])
# Returned loop value
print(tensor1.eval())
# [0.7778928 0.9396961 0.572209 0.6187117 0.89615726]

看看这是否有帮助,如果您仍然不确定如何将其应用于您的特定案例,请留下评论。

<小时/>

编辑:再次看到你的问题,你的唯一性函数计算了tensor1和掩码,所以也许更相似的类似代码是这样的:

import tensorflow as tf

def sample_numbers(shape, dtype):
tensor1 = tf.random.uniform(shape, dtype=dtype)
mask = tensor1 > 0.5
return tensor1, mask

# Draw five random numbers until all are > 0.5
with tf.Graph().as_default(), tf.Session() as sess:
tf.random.set_random_seed(0)
# Initial values, here simply initialized to zero
tensor1 = tf.zeros([5], dtype=tf.float32)
mask = tf.zeros(tf.shape(tensor1), dtype=tf.bool)
# Loop
tensor1, _ = tf.while_loop(
# Loop condition (negated goal condition)
lambda tensor1, mask: ~tf.math.reduce_all(mask),
# Loop body
lambda tensor1, mask: sample_numbers(tf.shape(tensor1), tensor1.dtype),
# Loop variables
[tensor1, mask])
# Returned loop value
print(tensor1.eval())
# [0.95553064 0.5170193 0.69573617 0.9501506 0.99776053]

关于python - 如何循环张量对象直到满足条件,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60311184/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com