gpt4 book ai didi

tensorflow - 了解对象检测 API 中的数据增强

转载 作者:行者123 更新时间:2023-12-03 17:34:42 24 4
gpt4 key购买 nike

我正在使用对象检测 API 来训练不同的数据集,我想知道是否有可能在训练期间获得到达网络的样本图像。

我问这个是因为我试图找到数据增强选项( here the options )的良好组合,但添加它们的结果更糟。查看在训练中到达网络的内容将非常有帮助。

另一个问题是是否有可能获得 API 来帮助平衡类,以防传递的数据集使它们不平衡。

谢谢!

最佳答案

对的,这是可能的。简而言之,您需要获取 tf.data.Dataset 的实例。然后,您可以迭代它并以 NumPy 数组的形式获取网络输入数据。然后使用 PIL 或 OpenCV 将其保存为图像文件是微不足道的。
假设您使用 TF2,伪代码如下:

ds = ... get dataset object somehow

sample_num = 0
for features, _ in ds:
images = features[fields.InputDataFields.image] # is a [batch_size, H, W, C] float32 tensor with preprocessed images
batch_size = images.shape[0]
for i in range(batch_size):
image = np.array(images[i] * 255).astype(np.uint8) # assuming input data is only scaled to [0..1]
cv2.imwrite(output_path, image)

sample_num += 1
if sample_num >= MAX_SAMPLES:
break
这里的技巧是获取 Dataset 实例。 Google 对象检测 API 非常复杂,但我想您应该从调用 train_input 开始函数在这里: https://github.com/tensorflow/models/blob/3c8b6f1e17e230b68519fd8d58c4dd9e9570d789/research/object_detection/inputs.py#L763
它需要描述训练、train_input 和模型的管道配置子部分。
您可以在此处找到有关如何使用管道的一些代码片段: Dynamically Editing Pipeline Config for Tensorflow Object Detection
import argparse

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2


def parse_arguments():
parser = argparse.ArgumentParser(description='')
parser.add_argument('pipeline')
parser.add_argument('output')
return parser.parse_args()


def main():
args = parse_arguments()
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()

with tf.gfile.GFile(args.pipeline, "r") as f:
proto_str = f.read()
text_format.Merge(proto_str, pipeline_config)

关于tensorflow - 了解对象检测 API 中的数据增强,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47294531/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com