gpt4 book ai didi

python - 在 tf 函数内迭代 tf.Tensor 以生成基于 NamedTuple 的数据集项列表

转载 作者:行者123 更新时间:2023-12-03 14:53:10 26 4
gpt4 key购买 nike

我正在使用 typing.NamedTuple基于 tf.data.Dataset 中的元素类型.下面是一个例子。

# You can run all the code in this question by pasting all
# the code blocks consecutively into a Python file

import tensorflow as tf
from typing import *
from random import *
from pprint import *

class Coord(NamedTuple):
x: float
y: float

@classmethod
def random(cls): return cls(gauss(10., 1.), gauss(10., 1.))

class Box(NamedTuple):
min: Coord
max: Coord

@classmethod
def random(cls): return cls(Coord.random(), Coord.random())

class Boxes(NamedTuple):
boxes: List[Box]

@classmethod
def random(cls): return cls([Box.random() for _ in range(randint(3, 5))])

def test_dataset():
for _ in range(randint(3, 5)): yield Boxes.random()

tf_dataset = tf.data.Dataset.from_generator(test_dataset, output_types=(tf.float32,))

如您所知, tf.data.Dataset.from_generator()将数据集元素(最初具有 Boxes 类型)转换为 tf.Tensor 的单元素元组与 (None, 2, 2)形状。例如,数据集的一个元素可能是以下项目:

(<tf.Tensor: shape=(4, 2, 2), dtype=float32, numpy=
array([[[11.642379, 9.937152],
[ 8.998009, 8.387287]],

[[10.649337, 10.028358],
[ 8.507834, 9.84779 ]],

[[11.10263 , 11.3706 ],
[ 9.20623 , 10.44905 ]],

[[ 9.591406, 9.560486],
[ 9.461394, 9.256082]]], dtype=float32)>,)

我有非 @tf.function -带注释的常规 Python 函数,可以将数据转换为原始类型,例如以下函数:

def flip_boxes(boxes: Boxes):
def flip_coord(c: Coord): return Coord(-c.x, c.y)
def flip_box(b: Box): return Box(flip_coord(b.min), flip_coord(b.max))
return Boxes(boxes=list(map(flip_box, boxes.boxes)))

我想将此 Python 函数(以及其他类似的函数)应用于此 tf.data.Dataset通过 tf.data.Dataset.map(map_func) 功能。 Dataset.map预计 map_func是一个函数,在其 tf.Tensor 中采用数据集元素类型的成员格式。原始元素类型为 Boxes其中有一个成员,原来是 boxes: List[Box] .该列表被转换为 (4, 2, 2) -shape Tensor 在创建数据集时位于上方。 tf.data.Dataset.map()时不回变电话 map_func ,Tensor 直接作为第一个参数传递给 map_func . (如果 Boxes 有更多成员,这些成员将作为单独的参数传递给 map_func,并且它们不会作为单个元组传递。)

问题:我实现了什么适配器函数才能使常规 Python 函数(如 flip_boxes )可用于 tf.data.Dataset.map() ?

我尝试迭代并使用 tf.split恢复 List[Boxes]来自输入 tf.Tensor但我遇到了下面作为评论列出的错误消息。
# Question: How do I implement this function?
def to_tf_mappable_function(fn: Callable) -> Callable:

def function(tensor: tf.Tensor):
boxes: List[Box] = [Box(Coord(10.0, 0.0), Coord(10.0, 0.0)), Box(Coord(10.0, 0.0), Coord(10.0, 0.0))]
# TODO calculate `boxes` from `tensor`, not use this dummy constant above

# Trivial Python code does not work, it results in this error on the commented-out line:
# OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed:
# AutoGraph is disabled in this function. Try decorating it directly with @tf.function.
# boxes = [Box(Coord(row[0][0], row[0][1]), Coord(row[1][0], row[1][1])) for row in tensor]
# Decorating any of flip_boxes, to_tf_mappable_function and to_tf_mappable_function.function
# does not eliminate the error.

# I thought tf.split might help, but it results in this error on the commented-out line:
# ValueError: Rank-0 tensors are not supported as the num_or_size_splits argument to split.
# Argument provided: Tensor("cond/Identity:0", shape=(), dtype=int32)
# boxes = tf.split(tensor, len(tensor))

return fn(Boxes(boxes))

return function

tf_dataset = tf_dataset.map(to_tf_mappable_function(flip_boxes))
# The line above should be morally equivalent to `dataset = map(flip_boxes, dataset)`,
# given a `dataset: Iterable[Boxes]` and the builtin `map` function in Python.

也许我没有问正确的问题,但请给我一些懈怠。
* 高级任务是申请 flip_boxes和类似的功能 tf.data.Dataset以有效的方式
* 卡住的地方是找回 List[Box]来自 tf.Tensor它的形状与框坐标列表完全一样,所以也许我的问题应该仅限于这个问题。

最佳答案

我不确定您是否正在寻找更一般的东西,但对于您在这里提出的确切问题,这似乎是实现它的可能方法之一:

# Helper function to translate from tensor back to Boxes type
def boxes_from_tensor(t: tf.Tensor) -> Boxes:
n_boxes = t.shape[0]
t = t.numpy()
boxes = Boxes(boxes=[Box(Coord(t[i,0,0], t[i,0,1]), Coord(t[i,1,0], t[i,1,1])) for i in range(n_boxes)])
return boxes

def to_tf_mappable_function(fn: Callable) -> Callable:
def function(tensor: tf.Tensor):
return tf.py_function(lambda t: fn(boxes_from_tensor(t)), [tensor], tensor.dtype)
return function

tf_dataset = tf_dataset.map(to_tf_mappable_function(flip_boxes))
list(tf_dataset)

关于python - 在 tf 函数内迭代 tf.Tensor 以生成基于 NamedTuple 的数据集项列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62332840/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com