gpt4 book ai didi

python - 如何在自定义操作中改变 Tensorflow 变量?

转载 作者:太空宇宙 更新时间:2023-11-04 02:31:25 25 4
gpt4 key购买 nike

我正在尝试修改简单的 Adding a New Op这样它就不会创建一个新的张量作为返回值,但它实际上会改变输入张量并返回它。我知道这是可能的,因为 scatter Op 正在做同样的事情,但是查看 scatter Op 源代码,鉴于我缺乏 C++ 经验,我不知道该怎么做。

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;


REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});



#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

void Compute(OpKernelContext* context) override {

// Grab the input tensor
Tensor input_tensor = context->mutable_input(0, true);
auto input = input_tensor.flat<int32>();

// We always return the input ref.
context->forward_ref_input_to_ref_output(0, 0);

// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
input(i) = 0;
}
}
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

如果我编译上面的代码并运行一个简单的 Python 脚本来测试它,我会得到以下错误:

Python(14820,0x700003a7e000) malloc: *** error for object 0x7fd5c45a5a88: pointer being freed was not allocated
*** set a breakpoint in malloc_error_break to debug

我需要更改代码中的哪些内容才能满足我的需求?

最佳答案

我觉得你最好修改一下抓取输入输出的过程。实际上根据你的REGISTER_OP,它不是引用输入,所以

context->mutable_input(0, true)

会是

context->input(0)

此外,设置输出将更改为

context->set_output(0, context->input(0))

我觉得设置输出后就可以了。

关于python - 如何在自定义操作中改变 Tensorflow 变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49074392/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com