gpt4 book ai didi

mongodb - 使用 keras.utils.Sequence 多处理和数据库 - 何时连接?

转载 作者:可可西里 更新时间:2023-11-01 09:47:56 27 4
gpt4 key购买 nike

我正在使用带有 Tensorflow 后端的 Keras 训练神经网络。数据集不适合 RAM,因此,我将其存储在 Mongo 数据库中,并使用 keras.utils.Sequence 的子类检索批处理。

如果我使用 use_multiprocessing=False 运行 model.fit_generator(),一切正常。

当我打开多处理时,我在生成工作进程或连接到数据库时遇到错误。

如果我在 __init__ 中创建一个连接,我会遇到一个异常,其文本说明了 pickling 锁对象中的错误。对不起,我记不太清了。但是训练甚至没有开始。

如果我在 __get_item__ 中创建连接,训练开始并运行一些 epoch,然后我得到错误 [WinError 10048] Only one usage of each socket address (protocol/network address/端口)通常是允许的

根据 the pyMongo manuals ,它不是 fork 安全的,每个子进程都必须创建自己的数据库连接。我使用 Windows,它不使用 fork ,而是生成进程,但是,恕我直言,区别在这里并不重要。

这解释了为什么无法在 __init__ 中连接。

这是来自 docs 的另一引述:

Create this client once for each process, and reuse it for all operations. It is a common mistake to create a new client for each request, which is very inefficient.

这解释了 __get_item__ 中的错误。

但是,我的类(class)如何理解 Keras 已创建新进程尚不清楚。

这是我的 Sequence 实现的最后一个变体的伪代码(每个请求的新连接):

import pymongo
import numpy as np
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical

class MongoSequence(Sequence):
def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
self._train_set = train_set
self._server = server
self._db = database
self.collection = collection
self._batch_size = batch_size

query = {} # train_set query
self._object_ids = [ smp["_id"] for uid in train_set for smp in self._connect().find(query, {'_id': True})]

def _connect(self):
client = pymongo.MongoClient(self._server)
db = self._client[self._db]
return _db[self._collection]

def __len__(self):
return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

def __getitem__(self, item):
oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
y = np.empty((len(oids), 2), dtype=np.float32)
for i, oid in enumerate(oids):
smp = self._connect().find({'_id': oid}).next()
X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
y[i] = to_categorical(not smp['result'], 2)
return X, y

也就是说,在对象构造上,我根据条件检索所有相关的 ObjectIDs 形成训练集。在调用 __getitem__ 时从数据库中检索实际对象。它们的 ObjectIDs 由列表切片确定。

这段调用 model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... ) 的代码生成了几个 python 进程,每个进程根据日志消息初始化 Tensorflow 后端,训练开始。

但最终异常从名为 connect 的函数中抛出,位于 pymongo 的深处。

不幸的是,我没有存储调用堆栈。错误如上所述,我重复:[WinError 10048] 通常允许每个套接字地址(协议(protocol)/网络地址/端口)的一种用法

我的假设是这段代码创建了太多与服务器的连接,因此,在 __getitem__ 中的连接是错误的。

构造函数中的连接也是错误的,因为它是在主进程中执行的,Mongo文档直接反对它。

Sequence类中还有一个方法,on_epoch_end。但是,我需要在纪元开始而不是结束时建立连接。

引用自 Keras 文档:

If you want to modify your dataset between epochs you may implement on_epoch_end

那么,有什么推荐的吗?文档在这里不是很具体。

最佳答案

看来我找到了解决方案。解决方案是 - 跟踪进程 ID 并在它更改时重新连接

class MongoSequence(Sequence):
def __init__(self, batch_size, train_set, query=None, server=None, database="database", collection="full_set"):
self._server = server
self._db = database
self._collection_name = collection
self._batch_size = batch_size
self._query = query
self._collection = self._connect()

self._object_ids = [ smp["_id"] for uid in train_set for smp in self._collection.find(self._query, {'_id': True})]

self._pid = os.getpid()
del self._collection # to be sure, that we've disconnected
self._collection = None

def _connect(self):
client = pymongo.MongoClient(self._server)
db = client[self._db]
return db[self._collection_name]

def __len__(self):
return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

def __getitem__(self, item):
if self._collection is None or self._pid != os.getpid():
self._collection = self._connect()
self._pid = os.getpid()

oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
y = np.empty((len(oids), 2), dtype=np.float32)
for i, oid in enumerate(oids):
smp = self._connect().find({'_id': oid}).next()
X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
y[i] = to_categorical(not smp['result'], 2)
return X, y

关于mongodb - 使用 keras.utils.Sequence 多处理和数据库 - 何时连接?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49879750/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com