gpt4 book ai didi

numpy - 在 tensorflow 中编写自定义成本函数

转载 作者:行者123 更新时间:2023-12-04 08:45:38 24 4
gpt4 key购买 nike

我正在尝试在 tensorflow 中编写自己的成本函数,但显然我无法“切片”张量对象?

import tensorflow as tf
import numpy as np

# Establish variables
x = tf.placeholder("float", [None, 3])
W = tf.Variable(tf.zeros([3,6]))
b = tf.Variable(tf.zeros([6]))

# Establish model
y = tf.nn.softmax(tf.matmul(x,W) + b)

# Truth
y_ = tf.placeholder("float", [None,6])

def angle(v1, v2):
return np.arccos(np.sum(v1*v2,axis=1))

def normVec(y):
return np.cross(y[:,[0,2,4]],y[:,[1,3,5]])

angle_distance = -tf.reduce_sum(angle(normVec(y_),normVec(y)))
# This is the example code they give for cross entropy
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

我收到以下错误: TypeError: Bad slice index [0, 2, 4] of type <type 'list'>

最佳答案

目前,tensorflow can't gather on axes other than the first - it's requested .

但是对于你在这种特定情况下想要做什么,你可以转置,然后收集0,2,4,然后转置回来。它不会很快疯狂,但它有效:

tf.transpose(tf.gather(tf.transpose(y), [0,2,4]))

对于当前gather 实现中的一些限制,这是一个有用的解决方法。

(但您不能在 tensorflow 节点上使用 numpy 切片也是正确的 - 您可以运行它并对输出进行切片,并且您还需要在运行之前初始化这些变量。:)。你正在以一种不起作用的方式混合 tf 和 np 。
x = tf.Something(...)

是一个 tensorflow 图对象。 Numpy 不知道如何处理这些对象。
foo = tf.run(x)

又回到了python可以处理的对象。

您通常希望将损失计算保持在纯 tensorflow 中,因此在 tf.您可能需要长时间执行 arccos,因为 tf 没有它的功能。

关于numpy - 在 tensorflow 中编写自定义成本函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33684441/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com