python - Change of dataset type along the execution stack

Reposted. Author: 行者123. Updated: 2023-12-03 23:45:35

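The problem is that a dataset changes from one type to another at different points of the execution stack. For example, if I add a new dataset class with additional member attributes of interest (inheriting from one of the classes in ops.data.dataset_ops, such as UnaryDataset), then at a later point of execution (the client_update function) the dataset has been converted to the _VariantDataset type, so any added attributes are lost. So the question is how to preserve the member attributes of a newly defined dataset class during execution. Below is the emnist example, where the type changes from ParallelMapDataset to _VariantDataset.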
In the client_datasets function at line 194 of training_utils.py, I made a small modification to print the type of each dataset:

def client_datasets(round_num):
  sampled_clients = sample_clients_fn(round_num)
  sampled_client_datasets = []
  for client in sampled_clients:
    # Build the client's dataset once and reuse it for both the result and the print.
    d = train_dataset.create_tf_dataset_for_client(client)
    sampled_client_datasets.append(d)
    tf.print('CLIENT DATASETS: ', d, type(d))
  return sampled_client_datasets
The output is:
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
Then, inside the tf.function client_update at line 178 of fed_avg_schedule.py, which is called for each client, the dataset has a different type:
@tf.function
def client_update(model,
                  dataset,
                  initial_weights,
                  client_optimizer,
                  client_weight_fn=None):
  """Updates client model.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    initial_weights: A `tff.learning.Model.weights` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer` object.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.

  Returns:
    A `ClientOutput`.
  """

  tf.print('CLIENT UPDATE: ', dataset, type(dataset))
  ...
The output is:
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
I may be wrong, but after some tracing I found that at some point the function _to_components(self, value) of DatasetSpec is called to perform the conversion:
def _to_components(self, value):
  return value._variant_tensor  # pylint: disable=protected-access
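The conversion described above can be illustrated with a small pure-Python toy model. It is not TensorFlow code: `ToyDatasetSpec`, `ToyVariantDataset`, and `MetadataDataset` are hypothetical stand-ins that only mimic the reported behavior, where a spec decomposes a dataset into its components (here just a handle playing the role of `_variant_tensor`) and later rebuilds a generic dataset from those components alone. Anything the components do not capture, such as an extra Python attribute on a subclass, does not survive the round trip.

```python
class ToyVariantDataset:
    """Generic dataset rebuilt from components (plays the role of _VariantDataset)."""

    def __init__(self, variant_handle, element_spec):
        self._variant_tensor = variant_handle
        self.element_spec = element_spec


class ToyDatasetSpec:
    """Hypothetical stand-in for DatasetSpec: it only knows the element spec."""

    def __init__(self, element_spec):
        self.element_spec = element_spec

    def to_components(self, value):
        # Mirrors _to_components above: keep only the variant handle.
        return value._variant_tensor

    def from_components(self, components):
        # Rebuilds a *generic* dataset; the original subclass is not recorded.
        return ToyVariantDataset(components, self.element_spec)


class MetadataDataset(ToyVariantDataset):
    """Hypothetical user subclass carrying extra Python-side metadata."""

    def __init__(self, variant_handle, element_spec, client_id):
        super().__init__(variant_handle, element_spec)
        self.client_id = client_id


def round_trip(dataset):
    """Simulates passing a dataset across a function/serialization boundary."""
    spec = ToyDatasetSpec(dataset.element_spec)
    return spec.from_components(spec.to_components(dataset))


original = MetadataDataset("handle-0", ("float32", "int32"), client_id="client_a")
rebuilt = round_trip(original)
print(type(original).__name__, hasattr(original, "client_id"))  # MetadataDataset True
print(type(rebuilt).__name__, hasattr(rebuilt, "client_id"))    # ToyVariantDataset False
```

In this model, the only way to keep the metadata is to pass it as a separate value next to the dataset, which is the direction the answer below takes.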
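EDIT - following the suggested answer
Here are the changes I made to the simple_fedavg example after pulling the latest version of the federated repository.
First, I added/modified the following lines in build_fed_avg_process in simple_fedavg_tff.py: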
server_message_type = server_message_fn.type_signature.result
tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
meta_data_type = tff.SequenceType(tf.string)

@tff.tf_computation(tf_dataset_type, meta_data_type, server_message_type)
def client_update_fn(tf_dataset, meta_data, server_message):
  model = model_fn()
  client_optimizer = client_optimizer_fn()
  return client_update(model, tf_dataset, meta_data, server_message, client_optimizer)

@tff.tf_computation((tf_dataset_type, meta_data_type))
def extract_data_metadata_fn(tf_dataset_metadata_tuple):
  x, y = tf_dataset_metadata_tuple
  return x, y

federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType((tf_dataset_type, meta_data_type), tff.CLIENTS)

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.data.Dataset` with placement
      `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and `tf.Tensor` of average loss.
  """
  server_message = tff.federated_map(server_message_fn, server_state)
  server_message_at_client = tff.federated_broadcast(server_message)

  data_set, meta_data = tff.federated_map(extract_data_metadata_fn, federated_dataset)

  #client_outputs = tff.federated_map(client_update_fn, (federated_dataset, server_message_at_client))
  client_outputs = tff.federated_map(client_update_fn, (data_set, meta_data, server_message_at_client))
In simple_fedavg_tf.py, I added the following line to print the metadata:
@tf.function
def client_update(model, dataset, meta_data, server_message, client_optimizer):
  """Performs client local training of `model` on `dataset`.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    meta_data: Per-client metadata passed alongside `dataset`.
    server_message: A `BroadcastMessage` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A `ClientOutput`.
  """
  tf.print(meta_data)

  model_weights = model.weights
  initial_weights = server_message.model_weights
  client_ids = server_message.client_ids
  tff.utils.assign(model_weights, initial_weights)
In the main file emnist_simple_fedavg.py, I modified the following lines of the main training loop in the main function:
meta_data = ['a','b','c','d']
server_state, train_metrics = iterative_process.next(server_state, (sampled_train_data, meta_data))
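A likely contributor to the error below is how the argument is structured. A federated value placed at `tff.CLIENTS` is passed from Python as a list with one member per client, so the tuple `(sampled_train_data, meta_data)` appears to be read as the values of two clients rather than as two parallel lists (the `zip(value, children)` frame in the traceback is where members are matched to clients). Separately, `tff.SequenceType(tf.string)` declares a sequence, which the executor expects to be a `tf.data.Dataset`, so a plain Python string in that slot produces `Expected ... DatasetV2 ... found str`; a single string per client would presumably be declared as `tff.TensorType(tf.string)` instead. The sketch below only shows how the per-client list could be built in plain Python: `pair_client_inputs` is a hypothetical helper, and strings stand in for the real `tf.data.Dataset` objects.

```python
from typing import Any, List, Sequence, Tuple


def pair_client_inputs(client_datasets: Sequence[Any],
                       client_metadata: Sequence[str]) -> List[Tuple[Any, str]]:
    """Builds one (dataset, metadata) tuple per client (hypothetical helper)."""
    # A length mismatch (e.g. 4 metadata strings for 10 sampled clients) is an error.
    if len(client_datasets) != len(client_metadata):
        raise ValueError("Need exactly one metadata entry per client dataset.")
    return list(zip(client_datasets, client_metadata))


# Strings stand in for the tf.data.Dataset objects of the sampled clients.
sampled_train_data = ["dataset_a", "dataset_b", "dataset_c", "dataset_d"]
meta_data = ["a", "b", "c", "d"]

federated_arg = pair_client_inputs(sampled_train_data, meta_data)
print(len(federated_arg), federated_arg[0])  # 4 ('dataset_a', 'a')
# With the real objects, the call would then presumably look like:
# server_state, train_metrics = iterative_process.next(server_state, federated_arg)
```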
This did not solve it; I get the following error:
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.py", line 176, in <module>
app.run(main)
File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.py", line 166, in main
server_state, train_metrics = iterative_process.next(server_state, (sampled_train_data, sampled_clients.tolist()))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/utils/function_utils.py", line 563, in __call__
return context.invoke(self, arg)
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 49, in wrapped_f
return Retrying(*dargs, **dkw).call(f, *args, **kw)
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 206, in call
return attempt.get(self._wrap_exception)
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 247, in get
six.reraise(self.value[0], self.value[1], self.value[2])
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/six/__init__.py", line 693, in reraise
raise value
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 200, in call
attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 215, in invoke
_ingest(executor, unwrapped_arg, arg.type_signature)))
File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
return future.result()
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
return await coro
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 99, in _ingest
ingested = await asyncio.gather(*ingested)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 104, in _ingest
return await executor.create_value(val, type_spec)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 289, in create_value
value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/caching_executor.py", line 245, in create_value
await cached_value.target_future
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
self._target_executor.create_value(value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
return await coro
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federating_executor.py", line 383, in create_value
return await self._strategy.compute_federated_value(value, type_spec)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federated_resolving_strategy.py", line 275, in compute_federated_value
for v, c in zip(value, children)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 282, in create_value
*[self.create_value(val, t) for (_, val), t in zip(v_el, type_spec)])
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 289, in create_value
value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/caching_executor.py", line 245, in create_value
await cached_value.target_future
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
self._target_executor.create_value(value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
return await coro
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 464, in create_value
return EagerValue(value, self._tf_function_cache, type_spec, self._device)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 367, in __init__
type_spec, device)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 335, in to_representation_for_type
type_conversions.TF_DATASET_REPRESENTATION_TYPES)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/py_typecheck.py", line 41, in check_type
type_string(type_spec), type_string(type(target))))
TypeError: Expected tensorflow.python.data.ops.dataset_ops.DatasetV2 or tensorflow.python.data.ops.dataset_ops.DatasetV1, found str.
E0721 23:53:29.388700 139706363909952 base_events.py:1285] Task was destroyed but it is pending!
task: <Task pending coro=<trace.<locals>.async_trace() running at /root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py:200> wait_for=<Future pending cb=[_chain_future.<locals>._call_check_cancel() at /usr/lib/python3.6/asyncio/futures.py:403, <TaskWakeupMethWrapper object at 0x7f0f7c07eca8>()]> cb=[<TaskWakeupMethWrapper object at 0x7f0f7c07e648>()]>

Best answer

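The new dataset Python class would need to support serialization. This is necessary because TensorFlow Federated is designed to run computations on machines other than the one on which they were authored (for example, smartphones in the cross-device federated learning setting). Those machines may not be running Python and therefore would not understand the newly created subclass, so the serialization layer would need to be updated. However, this is fairly low-level, and there may be other ways to achieve the intended goal.
A shot in the dark: if the goal is to provide metadata together with the dataset to each client, it may be easier to change the function signature of the iterative process returned by fed_avg_schedule.build_fed_avg_process so that it accepts a (dataset, metadata structure) tuple for each client.
The current signature of the next computation is (using the TFF type shorthand introduced in Custom Federated Algorithms, Part 1: Introduction to the Federated Core):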

(<ServerState@SERVER, Datasets@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)
(where the definitions of ServerState, Datasets, and Metrics are determined by the model and the dataset)
Instead, we might want a signature that looks like this:
(<ServerState@SERVER, <Datasets, Metadata>@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)
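In plain-Python terms, the proposed signature can be sketched as a round function that receives a server state plus one (dataset, metadata) pair per client. This is only a toy model of the data flow, not TFF code: `ClientInput`, `ServerState`, `run_one_round`, and `toy_client_update` here are hypothetical simplifications, lists of floats stand in for `tf.data.Dataset` objects, and the "aggregation" is a plain mean.

```python
from typing import Callable, List, NamedTuple, Sequence, Tuple


class ClientInput(NamedTuple):
    """One client's share of <Datasets, Metadata>@CLIENTS (hypothetical)."""
    dataset: Sequence[float]  # stands in for a tf.data.Dataset
    metadata: str


class ServerState(NamedTuple):
    """Hypothetical, heavily simplified server state."""
    round_num: int


def run_one_round(server_state: ServerState,
                  federated_input: List[ClientInput],
                  client_update: Callable[[Sequence[float], str, int], float],
                  ) -> Tuple[ServerState, float]:
    """Toy model of <ServerState@SERVER, <Datasets, Metadata>@CLIENTS> -> ..."""
    # Broadcast: every client sees the same server message.
    server_message = server_state.round_num
    # Per-client map: each client gets its own dataset *and* its own metadata.
    client_outputs = [client_update(c.dataset, c.metadata, server_message)
                      for c in federated_input]
    # Aggregate on the "server" (a plain mean in this toy).
    metric = sum(client_outputs) / len(client_outputs)
    return ServerState(server_state.round_num + 1), metric


def toy_client_update(dataset, metadata, server_message):
    """Hypothetical client step: reports the mean of its local values."""
    print(f"client {metadata}: {len(dataset)} examples, round {server_message}")
    return sum(dataset) / len(dataset)


state, avg = run_one_round(
    ServerState(0),
    [ClientInput([1.0, 2.0, 3.0], "a"), ClientInput([4.0, 6.0], "b")],
    toy_client_update)
print(state, avg)  # ServerState(round_num=1) 3.5
```

In real TFF code, the broadcast and the per-client step would correspond to `tff.federated_broadcast` and `tff.federated_map`, with aggregation done by a federated operator such as `tff.federated_mean`; the toy only mirrors how metadata travels next to each client's data instead of inside a dataset subclass.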
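To do this, we can do the following:
  • Update the parameter type of run_one_round to be a tuple of tf_dataset_type and the metadata structure.
  • Plumb the new argument through the tff.federated_map call.
  • Add the new parameter to client_update_fn.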
A similar question about "python - Change of dataset type along the execution stack" can be found on Stack Overflow: https://stackoverflow.com/questions/62993389/
