gpt4 book ai didi

python - pycaffe 中的求解器回调函数是什么,我该如何使用它们?

转载 作者:太空宇宙 更新时间:2023-11-03 15:59:05 26 4
gpt4 key购买 nike

查看this PR ,我看到可以为 caffe.Solver 对象定义 on_starton_gradient 回调。

import caffe
solver = caffe.AdamSolver('solver.prototxt')
solver.add_callback(on_start, on_gradient) # <- ??

on_starton_gradient是什么类型的对象?
这些回调有什么用?
如何使用它们(一个例子会很好...)?

最佳答案

<强>1。在哪里以及如何定义回调?

回调是求解器的一部分,因此在 solver.hpp 中定义文件。确切地说,有一个 Callback类,看起来像这样:

  // Invoked at specific points during an iteration
class Callback {
protected:
virtual void on_start() = 0;
virtual void on_gradients_ready() = 0;

template <typename T>
friend class Solver;
};
const vector<Callback*>& callbacks() const { return callbacks_; }
void add_callback(Callback* value) {
callbacks_.push_back(value);
}

和一个protected vector此类回调的集合,它是 Solver 类的成员。

  vector<Callback*> callbacks_;

因此,这基本上为 Solver 类提供了一个 add_callback 函数,它允许将类型为 Callback 的对象添加到向量.这是为了确保每个回调都有两个方法:on_start()on_gradients_ready()

<强>2。在哪里调用回调?

这发生在 solver.cpp文件,在 step()函数,其中包含主工作循环。这是主循环部分(为简单起见,去掉了很多东西):

while (iter_ < stop_iter) {

for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}

// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();

for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}

ApplyUpdate();

++iter_;
}

<强>3。这个用在什么地方?

此回调功能是在添加多 GPU 支持时实现的。唯一使用回调的地方(据我所知)是在多个 GPU 之间同步求解器:

P2PSyncparallel.hpp 中上课继承自 Solver::Callback 类,并实现 on_start()on_gradients_ready()方法,同步 GPU 并最终累积所有梯度更新。

<强>4。如何使用 Python 的回调?

作为拉取请求 #3020解释,

on_start and on_gradient are python functions.

所以它应该是直接使用。 this Github Gist 中显示了完整的可运行示例我创建。

<强>5。这有什么用?

由于这两个回调函数不接受任何参数,您不能简单地使用它们来跟踪丢失或类似的事情。为此,您必须围绕 Solver 类创建一个包装函数,并使用两个方法调用 add_callback 作为回调函数。这允许您使用 self.solver.net 从回调中访问网络。在下面的示例中,我使用 on_start 回调将数据加载到网络中,并使用 on_gradients_ready 回调打印损失函数。

class SolverWithCallback:
def __init__(self, solver_file):
self.solver = caffe.SGDSolver(solver_file)
self.solver.add_callback(self.load, self.loss)

def solve(self):
self.solver.solve()

def load(self):
inp = np.random.randint(0, 255)
self.solver.net.blobs['data'].data[...] = inp
self.solver.net.blobs['labels'].data[...] = 2 * inp

def loss(self):
print "loss: " + str(self.solver.net.blobs['loss'].data)

if __name__=='__main__':
solver = SolverWithCallback('solver.prototxt')
solver.solve()

关于python - pycaffe 中的求解器回调函数是什么,我该如何使用它们?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41490526/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com