gpt4 book ai didi

python - 使用 argmax 在 tensorflow 中切片张量

转载 作者:太空狗 更新时间:2023-10-29 22:28:44 24 4
gpt4 key购买 nike

我想在tensorflow中做一个动态损失函数。我想计算信号 FFT 的能量,更具体地说,只计算最主要峰值周围大小为 3 的窗口。我无法在 TF 中实现,因为它会抛出很多错误,例如 StrideInvalidArgumentError(回溯见上文):Expected begin, end, and strides to be 1D equal size tensors,但取而代之的是形状 [1,64]、[1,64] 和 [1]。

我的代码是这样的:

self.spec = tf.fft(self.signal)
self.spec_mag = tf.complex_abs(self.spec[:,1:33])
self.argm = tf.cast(tf.argmax(self.spec_mag, 1), dtype=tf.int32)
self.frac = tf.reduce_sum(self.spec_mag[self.argm-1:self.argm+2], 1)

因为我正在计算 64 位的批处理和数据维度也是 64 位,所以 self.signal 的形状是 (64,64)。我只想计算 FFT 的交流分量。由于信号是实值的,因此只有一半的频谱可以完成这项工作。因此,self.spec_mag 的形状是 (64,32)

此 fft 中的最大值位于 self.argm,其形状为 (64,1)

现在我想通过以下方式计算最大峰值附近 3 个元素的能量:self.spec_mag[self.argm-1:self.argm+2]

但是,当我运行代码并尝试获取 self.frac 的值时,我遇到了多个错误。

最佳答案

访问 argm 时似乎缺少索引。这里是1的固定版本,64位版本。

import tensorflow as tf
import numpy as np

x = np.random.rand(1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:,1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][(argm[0]-1):(argm[0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())

这里是扩展版本(batch, m, n)

import tensorflow as tf
import numpy as np

x = np.random.rand(1, 1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:, :, 1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 2), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][0][(argm[0][0]-1):(argm[0][0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())

你可能想要修复函数名称,因为我在较新版本的 tensorflow 上编辑了这段代码。

关于python - 使用 argmax 在 tensorflow 中切片张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48680607/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com