
python - How can I get around the memory limit in this script?


I am trying to normalize my dataset, which is 1.7 GB. I have 14 GB of RAM and hit the limit very quickly.

This happens while computing the mean/std of the training data. The training data itself takes up most of the memory once loaded into RAM (13.8 GiB), so the mean gets computed, but when the script reaches the next line and computes the std, it crashes.

Here is the script:

import caffe
import leveldb
import numpy as np
from caffe.proto import caffe_pb2
import cv2
import sys
import time

direct = 'examples/svhn/'
db_train = leveldb.LevelDB(direct+'svhn_train_leveldb')
db_test = leveldb.LevelDB(direct+'svhn_test_leveldb')
datum = caffe_pb2.Datum()

#using the whole dataset for training which is 604,388
size_train = 604388 #normal training set is 73257
size_test = 26032
data_train = np.zeros((size_train, 3, 32, 32))
label_train = np.zeros(size_train, dtype=int)

print 'Reading training data...'
i = -1
for key, value in db_train.RangeIter():
    i = i + 1
    if i % 1000 == 0:
        print i
    if i == size_train:
        break
    datum.ParseFromString(value)
    label = datum.label
    data = caffe.io.datum_to_array(datum)
    data_train[i] = data
    label_train[i] = label

print 'Computing statistics...'
print 'calculating mean...'
mean = np.mean(data_train, axis=(0,2,3))
print 'calculating std...'
std = np.std(data_train, axis=(0,2,3))

#np.savetxt('mean_svhn.txt', mean)
#np.savetxt('std_svhn.txt', std)

print 'Normalizing training'
for i in range(3):
    print i
    data_train[:, i, :, :] = data_train[:, i, :, :] - mean[i]
    data_train[:, i, :, :] = data_train[:, i, :, :]/std[i]


print 'Outputting training data'
leveldb_file = direct + 'svhn_train_leveldb_normalized'
batch_size = size_train

# create the leveldb file
db = leveldb.LevelDB(leveldb_file)
batch = leveldb.WriteBatch()
datum = caffe_pb2.Datum()

for i in range(size_train):
    if i % 1000 == 0:
        print i

    # save in datum
    datum = caffe.io.array_to_datum(data_train[i], label_train[i])
    keystr = '{:0>5d}'.format(i)
    batch.Put(keystr, datum.SerializeToString())

    # write batch
    if (i + 1) % batch_size == 0:
        db.Write(batch, sync=True)
        batch = leveldb.WriteBatch()
        print (i + 1)

# write last batch
if (i + 1) % batch_size != 0:
    db.Write(batch, sync=True)
    print 'last batch'
    print (i + 1)
#explicitly freeing memory to avoid hitting the limit!
#del data_train
#del label_train

print 'Reading test data...'
data_test = np.zeros((size_test, 3, 32, 32))
label_test = np.zeros(size_test, dtype=int)
i = -1
for key, value in db_test.RangeIter():
    i = i + 1
    if i % 1000 == 0:
        print i
    if i == size_test:
        break
    datum.ParseFromString(value)
    label = datum.label
    data = caffe.io.datum_to_array(datum)
    data_test[i] = data
    label_test[i] = label

print 'Normalizing test'
for i in range(3):
    print i
    data_test[:, i, :, :] = data_test[:, i, :, :] - mean[i]
    data_test[:, i, :, :] = data_test[:, i, :, :]/std[i]

#Zero Padding
#print 'Padding...'
#npad = ((0,0), (0,0), (4,4), (4,4))
#data_train = np.pad(data_train, pad_width=npad, mode='constant', constant_values=0)
#data_test = np.pad(data_test, pad_width=npad, mode='constant', constant_values=0)

print 'Outputting test data'
leveldb_file = direct + 'svhn_test_leveldb_normalized'
batch_size = size_test

# create the leveldb file
db = leveldb.LevelDB(leveldb_file)
batch = leveldb.WriteBatch()
datum = caffe_pb2.Datum()

for i in range(size_test):
    # save in datum
    datum = caffe.io.array_to_datum(data_test[i], label_test[i])
    keystr = '{:0>5d}'.format(i)
    batch.Put(keystr, datum.SerializeToString())

    # write batch
    if (i + 1) % batch_size == 0:
        db.Write(batch, sync=True)
        batch = leveldb.WriteBatch()
        print (i + 1)

# write last batch
if (i + 1) % batch_size != 0:
    db.Write(batch, sync=True)
    print 'last batch'
    print (i + 1)

How can I make it consume less memory so that I can run the script?

Best Answer

Why not compute the statistics on a subset of the original data? For example, here we compute the mean and std for just 100 points:

import numpy as np

sample_size = 100
data_train = np.random.rand(1000, 20, 10, 10)

# Take a subset of the training data (replace=False avoids picking the same index twice)
idxs = np.random.choice(data_train.shape[0], sample_size, replace=False)
data_train_subset = data_train[idxs]

# Compute stats over the subset
mean = np.mean(data_train_subset, axis=(0,2,3))
std = np.std(data_train_subset, axis=(0,2,3))

If your data is 1.7 GB, you are unlikely to need all of it to get an accurate estimate of the mean and std.
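(Why the original script dies at the std line: np.std over a 13.8 GiB float64 array has to materialize at least one temporary array the size of its input, roughly doubling the footprint, whereas np.mean only needs running sums.)

If you do want statistics over the full training set without ever holding it in RAM, you can instead accumulate per-channel sums in a single pass over the database. This is a minimal sketch, not part of the original answer, reusing db_train, datum, size_train, and the imports from the question's script:

# One-pass per-channel mean/std: accumulate sum and sum of squares,
# so memory use stays constant regardless of dataset size.
channel_sum = np.zeros(3)
channel_sqsum = np.zeros(3)
pixel_count = 0

i = -1
for key, value in db_train.RangeIter():
    i = i + 1
    if i == size_train:
        break
    datum.ParseFromString(value)
    data = caffe.io.datum_to_array(datum).astype(np.float64)  # shape (3, 32, 32)
    channel_sum += data.sum(axis=(1, 2))
    channel_sqsum += (data ** 2).sum(axis=(1, 2))
    pixel_count += data.shape[1] * data.shape[2]

mean = channel_sum / pixel_count
std = np.sqrt(channel_sqsum / pixel_count - mean ** 2)

This matches np.mean/np.std with axis=(0,2,3) (population std). The sum-of-squares formula can lose precision when the mean is large relative to the spread, but for 0-255 pixel values accumulated in float64 it is fine in practice.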

Also, could you reduce the bit depth of your data type? I'm not sure what data type caffe.io.datum_to_array returns, but you could do:

data = caffe.io.datum_to_array(datum).astype(np.float32)

to make sure the data is in float32 format. (If the data is currently float64, this will save you half the space.)
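The same idea applies to the preallocated buffers in the question's script, which np.zeros creates as float64 by default: 604388 × 3 × 32 × 32 elements at 8 bytes each is about 13.8 GiB, matching the figure in the question. Declaring them float32 up front halves that (a sketch reusing the question's variable names):

# ~13.8 GiB as float64 -> ~6.9 GiB as float32
data_train = np.zeros((size_train, 3, 32, 32), dtype=np.float32)
data_test = np.zeros((size_test, 3, 32, 32), dtype=np.float32)

Relatedly (also not part of the original answer), the normalization loops allocate a full temporary copy of each channel slice on the right-hand side; the in-place operators avoid that extra allocation:

for i in range(3):
    data_train[:, i, :, :] -= mean[i]
    data_train[:, i, :, :] /= std[i]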

About python - How can I get around the memory limit in this script?, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/39892920/
