gpt4 book ai didi

tensorflow2.0 - 尝试使用 tf.data.Dataset 而不是张量向 tf.keras 模型提供数据时出错

转载 作者:行者123 更新时间:2023-12-04 04:10:28 25 4
gpt4 key购买 nike

为什么以下 tf2 tf.keras 模型在直接用张量拟合(fit)时“有效”,但在尝试用 tf.data.Dataset.from_tensor_slices 构建的数据集来拟合相同的张量时会产生 ValueError?

编辑:换句话说,下面的模型已经使用 numpy 数组完成了开发/拟合/测试。这些相同的 numpy 数组需要如何 reshape(?),才能通过 tf.data.Dataset.from_tensor_slices 创建可用于该模型的数据集?

# Minimal reproduction: the same model trains on raw tensors but raises a
# ValueError when trained on an unbatched tf.data.Dataset built from them.
# NOTE(review): `tf` and `hub` are not imported in this snippet; presumably
# `import tensorflow as tf` and `import tensorflow_hub as hub` ran earlier.

# Text-embedding layer loaded from TF Hub; declared output is a 20-dim vector.
embed = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
hub_layer = hub.KerasLayer(embed, output_shape=[20], input_shape=[],
dtype=tf.string, trainable=True, name='hub_layer')

# from tf hub docs. hub_layer takes a 1D tensor of strings.

# Scalar string input per example (shape=()), fed through the hub embedding,
# a 16-unit ReLU layer, and a 4-unit softmax output.
input_tensor = tf.keras.Input(shape=(), name="input_enquiry", dtype=tf.string) # Note tf.string. Ref: https://github.com/tensorflow/hub/issues/483
hub_tensor = hub_layer(input_tensor)
x = tf.keras.layers.Dense(16, activation='relu')(hub_tensor)
main_output = tf.keras.layers.Dense(units=4, activation='softmax', name='main_output')(x)

model = tf.keras.models.Model(inputs=[input_tensor], outputs=[main_output])
model.compile(optimizer='adam', loss=tf.losses.CategoricalCrossentropy(),metrics='acc')

# Input and target
# X holds 2 examples of 1 string each; y holds 2 target rows of length 4.
X = tf.constant([['The quick brown fox'], ['Hello World']])
y = tf.constant([[0,0,0,1], [0,0,1,0]])

# Works OK
model.fit(X, y) # fit on tensors

# One dataset element per row of X (the leading dimension is sliced away).
X_ds = tf.data.Dataset.from_tensor_slices(X)

# Works OK
model.predict(X_ds) # predict on dataset

# One dataset element per row of y; presumably each element has shape (4,).
y_ds = tf.data.Dataset.from_tensor_slices(y)

# Pair each input element with its target element; note the result is NOT batched.
ds = tf.data.Dataset.zip((X_ds, y_ds))

# Fails with ValueError
# Reported message: "Shapes (4, 1) and (1, 4) are incompatible" -- the
# unbatched target has no leading batch dimension when it reaches the loss.
# Chaining .batch(n) onto the zipped dataset is the fix given for this problem.
model.fit(ds)

ValueError:

---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in
30
31 # Fails with ValueError
---> 32 model.fit(ds)
33
34

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside `run_distribute_coordinator` already.

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
846 batch_size=batch_size):
847 callbacks.on_train_batch_begin(step)
--> 848 tmp_logs = train_function(iterator)
849 # Catch OutOfRangeError for Datasets of unknown size.
850 # This blocks until the batch has finished executing.

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
578 xla_context.Exit()
579 else:
--> 580 result = self._call(*args, **kwds)
581
582 if tracing_count == self._get_tracing_count():

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
609 # In this case we have created variables on the first call, so we run the
610 # defunned version which is guaranteed to never create variables.
--> 611 return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
612 elif self._stateful_fn is not None:
613 # Release the lock early so that multiple threads can perform the call

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
2417 """Calls a graph function specialized to the inputs."""
2418 with self._lock:
-> 2419 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
2420 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2421

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2772 and self.input_signature is None
2773 and call_context_key in self._function_cache.missed):
-> 2774 return self._define_function_with_shape_relaxation(args, kwargs)
2775
2776 self._function_cache.missed.add(call_context_key)

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _define_function_with_shape_relaxation(self, args, kwargs)
2704 relaxed_arg_shapes)
2705 graph_function = self._create_graph_function(
-> 2706 args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
2707 self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function
2708

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2665 arg_names=arg_names,
2666 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667 capture_by_value=self._capture_by_value),
2668 self._function_attributes,
2669 # Tell the ConcreteFunction to clean up its graph once it goes out of

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
979 _, original_func = tf_decorator.unwrap(python_func)
980
--> 981 func_outputs = python_func(*func_args, **func_kwargs)
982
983 # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
439 # __wrapped__ allows AutoGraph to swap in a converted function. We give
440 # the function a weak reference to itself to avoid a reference cycle.
--> 441 return weak_wrapped_fn().__wrapped__(*args, **kwds)
442 weak_wrapped_fn = weakref.ref(wrapped_fn)
443

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise

ValueError: in user code:

/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:571 train_function *
outputs = self.distribute_strategy.run(
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:951 run **
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
return fn(*args, **kwargs)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:533 train_step **
y, y_pred, sample_weight, regularization_losses=self.losses)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/compile_utils.py:205 __call__
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/losses.py:143 __call__
losses = self.call(y_true, y_pred)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/losses.py:246 call
return self.fn(y_true, y_pred, **self._fn_kwargs)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/losses.py:1527 categorical_crossentropy
return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:4561 categorical_crossentropy
target.shape.assert_is_compatible_with(output.shape)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py:1117 assert_is_compatible_with
raise ValueError("Shapes %s and %s are incompatible" % (self, other))

ValueError: Shapes (4, 1) and (1, 4) are incompatible

如果我们不使用“.from_tensor_slices”,而是使用“.from_tensors”来创建 X_ds 和 y_ds,那么在 zip 之后一切正常。然而,文档(docs)给我的印象是“.from_tensors”占用大量内存,并不可取。此外,我认为用“.from_tensors”得到的单元素数据集只是把两个二维张量直接提供给模型,而 from_tensor_slices 版本则是由一维元素组成的序列。

最佳答案

针对该问题的具体解决方案是对数据集调用 .batch():

# Batching the zipped dataset restores the leading (batch) dimension that
# from_tensor_slices removed, so inputs and targets regain their original rank.
ds = tf.data.Dataset.zip((X_ds, y_ds)).batch(32) # e.g., a batch size of 32

我的理解(参见 docs)是:以“批(batch)”的形式提供给模型,实际上恢复了被 tf.data.Dataset.from_tensor_slices 方法去掉的数据最外层维度。也就是说,数据恢复为直接使用原始 numpy 数组时的形状。

关于tensorflow2.0 - 尝试使用 tf.data.Dataset 而不是张量来提供 tf.keras 模型时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61829273/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com