
python - Training a neural network model on multiple datasets


What I have:

  1. A neural network model
  2. 10 datasets with the same structure

What I want:

  1. Train the model on each dataset separately
  2. Save each resulting model separately

I can train on the datasets individually and save one model at a time, but I want to load all 10 datasets in a single run and produce 10 models from them. The solution may be obvious, but I am fairly new to this. How can I achieve it?
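In other words, what I am after is roughly the following loop, written here as a rough sketch (build_model() is just a placeholder for my actual model, and datasets stands for my 10 loaded (x_train, y_train) pairs):

for i, (x_train, y_train) in enumerate(datasets):
    model = build_model()                       # a fresh model for each dataset
    model.fit(x_train, y_train, epochs=5)       # train on this dataset only
    model.save(f"{i}.model")                    # save each trained model separately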

Thanks in advance.

Best Answer

You can achieve this with one of the standard concepts of concurrency and parallelism, namely multi-threading (or, in some cases, multi-processing).
The simplest way to code it is with Python's concurrent.futures module.

You can call the training function on the model for each dataset, all under a ThreadPoolExecutor, in order to fire off parallel threads that perform the individual trainings.

The code could look like this:


Step 1: Necessary imports
from concurrent.futures import ThreadPoolExecutor, as_completed

import tensorflow as tf
from tensorflow.keras.models import load_model, Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten

Step 2: Create and build the model
def create_model():                                                  # responsible for creating the model
    model = Sequential()
    model.add(Flatten())                                             # adding NN layers
    model.add(Dense(64))
    model.add(Activation('relu'))
    # ........ so on
    model.compile(optimizer='..', loss='..', metrics=[...])         # compiling the model
    return model                                                     # finally returning the model
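For the MNIST example used in Step 5 below, one plausible way to fill in the placeholder layers and compile arguments (an illustrative assumption, not necessarily the architecture the asker has in mind) would be:

def create_model():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28)))                 # MNIST images are 28x28
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dense(10))                                      # 10 digit classes
    model.add(Activation('softmax'))
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',     # integer labels from mnist.load_data()
                  metrics=['accuracy'])
    return model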

Step 3: Define the fit function (performs the model training)
def fit(model, XY_train):                                            # performs model.fit(...parameters...)
    model.fit(XY_train[0], XY_train[1], epochs=5, validation_split=0.3)  # use your already defined x_train, y_train
    return model                                                     # finally returns the trained model

Step 4: Parallel trainer method, which launches the trainings concurrently inside a ThreadPoolExecutor context manager
# trains the provided model on each dataset in parallel using multi-threading
def parallel_trainer(model, XY_train_datasets: list[tuple]):
    with ThreadPoolExecutor(max_workers=len(XY_train_datasets)) as executor:
        futureObjs = [
            executor.submit(fit, model, ds)                  # submit one fit() call per dataset
            for ds in XY_train_datasets                      # iterate through the datasets
        ]

        for i, obj in enumerate(as_completed(futureObjs)):   # iterate through trained models as they finish
            obj.result().save(f"{i}.model")                  # save models (i is completion order)
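Note that the sketch above passes the same model object to every thread, so all threads update shared weights. If each dataset should yield its own independent model (as the question asks), a small variation, assuming create_model() from Step 2 can be called once per dataset, is:

def parallel_trainer_separate(XY_train_datasets: list[tuple]):
    with ThreadPoolExecutor(max_workers=len(XY_train_datasets)) as executor:
        futureObjs = [
            executor.submit(fit, create_model(), ds)          # a fresh model per dataset
            for ds in XY_train_datasets
        ]
        for i, obj in enumerate(as_completed(futureObjs)):
            obj.result().save(f"{i}.model")                   # i is completion order, not dataset order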

Step 5: Create the model, load the datasets, and call the parallel trainer
model = create_model()                                               # create the model

mnist = tf.keras.datasets.mnist                                      # get a dataset - for example, the mnist dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()             # get (x_train, y_train), (x_test, y_test)
datasets = [(x_train, y_train)] * 10                                 # list of datasets (here the same dataset is reused 10 times)

parallel_trainer(model, datasets)                                    # call the parallel trainer
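After the run, each saved model can be reloaded individually with the load_model import from Step 1 (the file names follow the f"{i}.model" pattern used above):

trained_models = [load_model(f"{i}.model") for i in range(len(datasets))]   # reload the saved models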



The whole program

from concurrent.futures import ThreadPoolExecutor, as_completed

import tensorflow as tf
from tensorflow.keras.models import load_model, Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten


def create_model():                                                  # responsible for creating the model
    model = Sequential()
    model.add(Flatten())                                             # adding NN layers
    model.add(Dense(64))
    model.add(Activation('relu'))
    # ........ so on
    model.compile(optimizer='..', loss='..', metrics=[...])         # compiling the model
    return model                                                     # finally returning the model


def fit(model, XY_train):                                            # performs model.fit(...parameters...)
    model.fit(XY_train[0], XY_train[1], epochs=5, validation_split=0.3)  # use your already defined x_train, y_train
    return model                                                     # finally returns the trained model


# trains the provided model on each dataset in parallel using multi-threading
def parallel_trainer(model, XY_train_datasets: list[tuple]):
    with ThreadPoolExecutor(max_workers=len(XY_train_datasets)) as executor:
        futureObjs = [
            executor.submit(fit, model, ds)                  # submit one fit() call per dataset
            for ds in XY_train_datasets                      # iterate through the datasets
        ]

        for i, obj in enumerate(as_completed(futureObjs)):   # iterate through trained models as they finish
            obj.result().save(f"{i}.model")                  # save models (i is completion order)


model = create_model()                                               # create the model

mnist = tf.keras.datasets.mnist                                      # get a dataset - for example, the mnist dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()             # get (x_train, y_train), (x_test, y_test)
datasets = [(x_train, y_train)] * 10                                 # list of datasets (here the same dataset is reused 10 times)

parallel_trainer(model, datasets)                                    # call the parallel trainer
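Since the answer also mentions multi-processing, here is a rough standalone sketch of the same idea with ProcessPoolExecutor; it is not part of the original answer. train_one() is a hypothetical worker that builds, trains, and saves its model entirely inside the worker process, so no model objects have to cross process boundaries, and it assumes create_model() (filled in with real layers and compile arguments) lives in a hypothetical model_setup module:

from concurrent.futures import ProcessPoolExecutor, as_completed

import tensorflow as tf

from model_setup import create_model            # hypothetical module holding the create_model() above


def train_one(index, XY_train):                 # hypothetical per-process worker
    model = create_model()                      # build a fresh model inside this worker process
    model.fit(XY_train[0], XY_train[1], epochs=5, validation_split=0.3)
    model.save(f"{index}.model")                # save from within the worker
    return index


if __name__ == "__main__":                      # guard required by process-based executors
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    datasets = [(x_train, y_train)] * 10        # same dataset reused 10 times, as above
    with ProcessPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(train_one, i, ds) for i, ds in enumerate(datasets)]
        for future in as_completed(futures):
            print(f"finished model {future.result()}")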

Regarding python - training a neural network model on multiple datasets, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/70539674/
