- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章tensorflow estimator 使用hook实现finetune方式由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
为了实现finetune有如下两种解决方案:
model_fn里面定义好模型之后直接赋值 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
def
model_fn(features, labels, mode, params):
# .....
# finetune
if
params.checkpoint_path
and
(
not
tf.train.latest_checkpoint(params.model_dir)):
checkpoint_path
=
None
if
tf.gfile.IsDirectory(params.checkpoint_path):
checkpoint_path
=
tf.train.latest_checkpoint(params.checkpoint_path)
else
:
checkpoint_path
=
params.checkpoint_path
tf.train.init_from_checkpoint(
ckpt_dir_or_file
=
checkpoint_path,
assignment_map
=
{params.checkpoint_scope: params.checkpoint_scope}
# 'OptimizeLoss/':'OptimizeLoss/'
)
|
使用钩子 hooks.
可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定 。
1
2
3
4
5
6
7
8
9
10
11
|
# Define the experiment
experiment
=
tf.contrib.learn.Experiment(
estimator
=
estimator,
# Estimator
train_input_fn
=
train_input_fn,
# First-class function
eval_input_fn
=
eval_input_fn,
# First-class function
train_steps
=
params.train_steps,
# Minibatch steps
min_eval_frequency
=
params.eval_min_frequency,
# Eval frequency
# train_monitors=[], # Hooks for training
# eval_hooks=[eval_input_hook], # Hooks for evaluation
eval_steps
=
params.eval_steps
# Use evaluation feeder until its empty
)
|
也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定.
不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等).
1
2
3
4
5
6
7
8
9
10
11
12
13
|
def
model_fn(features, labels, mode, params):
# ....
return
tf.estimator.EstimatorSpec(
mode
=
mode,
predictions
=
predictions,
loss
=
loss,
train_op
=
train_op,
eval_metric_ops
=
eval_metric_ops,
# scaffold=get_scaffold(),
# training_chief_hooks=None
)
|
这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:
当前的模式:
mode: A ModeKeys. Specifies if this is training, evaluation or prediction. 。
计算图 。
predictions: Predictions Tensor or dict of Tensor. 。
loss: Training loss Tensor. Must be either scalar, or with shape [1]. 。
train_op: Op for the training step. 。
eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching. 。
导出策略 。
export_outputs: Describes the output signatures to be exported to 。
SavedModel and used during serving. A dict {name: output} where
name: An arbitrary name for this output. 。
output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. 。
chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等 。
training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training. 。
worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等 。
training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training. 。
指定初始化和saver 。
scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training. 。
evaluation钩子 。
evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation. 。
自定义的钩子如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
|
class
RestoreCheckpointHook(tf.train.SessionRunHook):
def
__init__(
self
,
checkpoint_path,
exclude_scope_patterns,
include_scope_patterns
):
tf.logging.info(
"Create RestoreCheckpointHook."
)
#super(IteratorInitializerHook, self).__init__()
self
.checkpoint_path
=
checkpoint_path
self
.exclude_scope_patterns
=
None
if
(
not
exclude_scope_patterns)
else
exclude_scope_patterns.split(
','
)
self
.include_scope_patterns
=
None
if
(
not
include_scope_patterns)
else
include_scope_patterns.split(
','
)
def
begin(
self
):
# You can add ops to the graph here.
print
(
'Before starting the session.'
)
# 1. Create saver
#exclusions = []
#if self.checkpoint_exclude_scopes:
# exclusions = [scope.strip()
# for scope in self.checkpoint_exclude_scopes.split(',')]
#
#variables_to_restore = []
#for var in slim.get_model_variables(): #tf.global_variables():
# excluded = False
# for exclusion in exclusions:
# if var.op.name.startswith(exclusion):
# excluded = True
# break
# if not excluded:
# variables_to_restore.append(var)
#inclusions
#[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]
variables_to_restore
=
tf.contrib.framework.filter_variables(
slim.get_model_variables(),
include_patterns
=
self
.include_scope_patterns,
# ['Conv'],
exclude_patterns
=
self
.exclude_scope_patterns,
# ['biases', 'Logits'],
# If True (default), performs re.search to find matches
# (i.e. pattern can match any substring of the variable name).
# If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
reg_search
=
True
)
self
.saver
=
tf.train.Saver(variables_to_restore)
def
after_create_session(
self
, session, coord):
# When this is called, the graph is finalized and
# ops can no longer be added to the graph.
print
(
'Session created.'
)
tf.logging.info(
'Fine-tuning from %s'
%
self
.checkpoint_path)
self
.saver.restore(session, os.path.expanduser(
self
.checkpoint_path))
tf.logging.info(
'End fineturn from %s'
%
self
.checkpoint_path)
def
before_run(
self
, run_context):
#print('Before calling session.run().')
return
None
#SessionRunArgs(self.your_tensor)
def
after_run(
self
, run_context, run_values):
#print('Done running one step. The value of my tensor: %s', run_values.results)
#if you-need-to-stop-loop:
# run_context.request_stop()
pass
def
end(
self
, session):
#print('Done with the session.')
pass
|
以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我.
原文链接:https://blog.csdn.net/andylei777/article/details/79074757 。
最后此篇关于tensorflow estimator 使用hook实现finetune方式的文章就讲到这里了,如果你想了解更多关于tensorflow estimator 使用hook实现finetune方式的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我创建了一个简单的钩子(Hook),我安装了它 SetWindowsHookEx(WH_CBT, addr, dll, 0); 完成后,我卸载 UnhookWindowsHookEx(0); 然后我可
我正在使用 React Hooks,当我用 mobx 的观察者包装我的组件时,我收到了这个错误。可能是什么问题?是否可以将 mobx 与 React Hooks 一起使用? import classn
我知道这个问题已经被回答过很多次了。我只是找不到解决我的问题的答案,让我相信,我要么是愚蠢的,要么是我的问题没有被解决,因为它比我更愚蠢。除此之外,这是我的问题: 我正在尝试创建一个功能组件,它从 r
我正在使用 React Navigation 的 useNavigation 钩子(Hook): 在 MyComponent.js 中: import { useNavigation } from "
我想在 gitlab 中使用预提交钩子(Hook)。我做的一切都像文档中一样:https://docs.gitlab.com/ce/administration/custom_hooks.html 在
我最近在和一些人谈论我正在编写的程序时听到了“hook”这个词。尽管我从对话中推断出钩子(Hook)是一种函数,但我不确定这个术语到底意味着什么。我搜索了定义,但找不到好的答案。有人可以让我了解这个术
我正在寻找一个在页面创建或页面更改后调用的钩子(Hook),例如“在导航中隐藏页面”、“停用页面”或“移动/删除页面“ 有人知道吗? 谢谢! 最佳答案 这些 Hook 位于 t3lib/class.t
我正在使用钩子(Hook)将新方法添加到 CalEventLocalServiceImpl 中... 我的代码是.. public class MyCalendarLocalServiceImpl e
编译器将所有 SCSS 文件编译为 STANDALONE(无 Rails)项目中的 CSS 后,我需要一个 Compass Hook 。 除了编辑“compiler.rb”(这不是好的解决方案,因为
我“.get”一个请求并像这样处理响应: resp = requests.get('url') resp = resp.text .. # do stuff with resp 阅读包的文档后,我看到
我们想在外部数据库中存储一些关于提交的元信息。在克隆或 checkout 期间,应引用此数据库,我们将元信息复制到克隆的存储库中的文件中。需要数据库而不是仅仅使用文件是为了索引和搜索等...... 我
我有一个 react 钩子(Hook)useDbReadTable,用于从接受tablename和query初始数据的数据库读取数据。它返回一个对象,除了数据库中的数据之外,还包含 isLoading
在下面的代码中,当我调用 _toggleSearch 时,我同时更新 2 个钩子(Hook)。 toggleSearchIsVisible 是一个简单的 bool 值,但是,setActiveFilt
问题 我想在可由用户添加的表单中实现输入字段的键/值对。 参见 animated gif on dynamic fields . 此外,我想在用户提交表单并再次显示页面时显示保存的数据。 参见 ani
当状态处于 Hook 状态时,它可能会变得陈旧并泄漏内存: function App() { const [greeting, setGreeting] = useState("hello");
const shouldHide = useHideOnScroll(); return shouldHide ? null : something useHideOnScroll 行为应该返回更新后
我正在使用 React-native,在其中,我有一个名为 useUser 的自定义 Hook,它使用 Auth.getUserInfro 方法从 AWS Amplify 获取用户信息,并且然后获取返
我正在添加一个 gitolite 更新 Hook 作为 VREF,并且想知道是否有办法将它应用于除 gitolite-admin 之外的所有存储库。 有一个更简单的方法而不是列出我想要应用 Hook
如何使用带有 react-apollo-hooks 的 2 个 graphql 查询,其中第二个查询取决于从第一个查询中检索到的参数? 我尝试使用如下所示的 2 个查询: const [o, setO
我是 hooks 的新手,到目前为止印象还不错,但是,如果我尝试在函数内部使用 hooks,它似乎会提示(无效的 hook 调用。Hooks can only be called inside o
我是一名优秀的程序员,十分优秀!