gpt4 book ai didi

python - 动态编辑 Tensorflow 对象检测的管道配置

转载 作者:太空宇宙 更新时间:2023-11-04 09:32:06 25 4
gpt4 key购买 nike

我正在使用 tensorflow 对象检测 API,我希望能够在 python 中动态编辑配置文件,如下所示。我想过在 python 中使用 Protocol Buffer 库,但我不确定该怎么做。

model {
ssd {
num_classes: 1
image_resizer {
fixed_shape_resizer {
height: 300
width: 300
}
}
feature_extractor {
type: "ssd_inception_v2"
depth_multiplier: 1.0
min_depth: 16
conv_hyperparams {
regularizer {
l2_regularizer {
weight: 3.99999989895e-05
}
}
initializer {
truncated_normal_initializer {
mean: 0.0
stddev: 0.0299999993294
}
}
activation: RELU_6
batch_norm {
decay: 0.999700009823
center: true
scale: true
epsilon: 0.0010000000475
train: true
}
}
...
...

是否有一种简单/容易的方法可以将 image_resizer -> fixed_shape_resizer 中的高度等字段的特定值从 300 更改为 500?并在不更改任何其他内容的情况下用修改后的值写回文件?

编辑:尽管@DmytroPrylipko 提供的答案适用于配置中的大多数参数,但我在“复合字段”方面遇到了一些问题。

也就是说,如果我们有这样的配置:

train_input_reader: {
label_map_path: "/tensorflow/data/label_map.pbtxt"
tf_record_input_reader {
input_path: "/tensorflow/models/data/train.record"
}
}

然后我添加这一行来编辑 input_path:

 pipeline_config.train_input_reader.tf_record_input_reader.input_path = "/tensorflow/models/data/train100.record"

它抛出错误:

TypeError: Can't set composite field

最佳答案

是的,使用 Protobuf Python API 非常简单:

edit_pipeline.py:

import argparse

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2


def parse_arguments():
parser = argparse.ArgumentParser(description='')
parser.add_argument('pipeline')
parser.add_argument('output')
return parser.parse_args()


def main():
args = parse_arguments()
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()

with tf.gfile.GFile(args.pipeline, "r") as f:
proto_str = f.read()
text_format.Merge(proto_str, pipeline_config)

pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 300
pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 300

config_text = text_format.MessageToString(pipeline_config)
with tf.gfile.Open(args.output, "wb") as f:
f.write(config_text)


if __name__ == '__main__':
main()

我调用脚本的方式:

TOOL_DIR=tool/tf-models/research

(
cd $TOOL_DIR
protoc object_detection/protos/*.proto --python_out=.
)

export PYTHONPATH=$PYTHONPATH:$TOOL_DIR:$TOOL_DIR/slim

python3 edit_pipeline.py pipeline.config pipeline_new.config

复合字段

在重复字段的情况下,您必须将它们视为数组(例如使用extend()append() 方法):

pipeline_config.train_input_reader.tf_record_input_reader.input_path[0] = '/tensorflow/models/data/train100.record'

Eval Input 阅读器错误

这是尝试编辑复合字段的常见错误。 (在 eval_input_reader 的情况下“未找到属性 tf_record_input_reader”)

下面@latida 的回答中提到了它。通过将其设置为数组字段来解决该问题。

pipeline_config.eval_input_reader[0].label_map_path  = label_map_full_path
pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[0] = val_record_path

关于python - 动态编辑 Tensorflow 对象检测的管道配置,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55323907/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com