
python - Generator stalls Keras fit_generator

Reposted. Author: 行者123. Updated: 2023-11-30 22:20:28

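I have a dataset with 9 columns stored as a CSV file with a header row; the last column is the target variable. I am trying to write a generator to train a model in Keras. The code is below. Training runs through the first epoch, but just before the epoch finishes it stops and hangs forever.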

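# fetch_california_housing is importable directly from sklearn.datasets
# (the sklearn.datasets.california_housing module was removed in newer releases)
from sklearn.datasets import fetch_california_housing
import pandas as pd
import numpy as np

data=fetch_california_housing()
cols=list(data.feature_names)  # the 8 feature names
cols.append('y')               # name of the target column
data=pd.DataFrame(np.column_stack([data.data,data.target.reshape((data.target.shape[0],1))]),columns=cols)
data.to_csv('/media/jma/DATA/calhousing.csv',index=False)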
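To exercise a generator without downloading the dataset, a small file with the same layout (a header row, 8 float feature columns, then a `y` target column) can be written locally. This is a minimal sketch; the helper `write_synthetic_csv`, the column names `f0`..`f7` and the random values are hypothetical placeholders, not the California housing data.

```python
import csv
import os
import tempfile

import numpy as np


def write_synthetic_csv(path, n_rows, n_features=8, seed=0):
    """Hypothetical helper: write a header row, n_features float columns and a final 'y' column."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_rows, n_features))
    y = rng.normal(size=n_rows)
    header = [f"f{i}" for i in range(n_features)] + ["y"]
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        # .tolist() turns NumPy scalars into plain Python floats before writing
        for features, target in zip(X.tolist(), y.tolist()):
            writer.writerow(features + [target])


# Usage: a 10-row placeholder file in a temporary directory.
tmp_dir = tempfile.mkdtemp()
csv_path = os.path.join(tmp_dir, "calhousing_small.csv")
write_synthetic_csv(csv_path, n_rows=10)
```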

Generator code:

import csv 
import numpy as np

def generate_arrays_from_file(file_name,batchsz):
    csvfile = open(file_name)
    reader = csv.reader(csvfile)

    batchCount = 0

    inputs = []
    targets = []

    while True: #infinite loop

        linecounter=0 #which line the reader is reading

        for line in reader:

            if linecounter >0: #is not the header

                inputs.append(line[0:8])
                targets.append(line[8])

                batchCount += 1 # we added

                if batchCount >= batchsz: # we have our mini batch

                    batchCount = 0 #reset batch counter

                    X = np.array(inputs,dtype="float32")
                    y = np.array(targets,dtype="float32")

                    yield (X, y)

                    #reset the lists to hold the batches
                    inputs = []
                    targets = []

            linecounter += 1 #increment the line read

        linecounter = 0 #reset
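The hang can be reproduced without Keras. In the generator above, `reader` is created once, outside `while True`, and a `csv.reader` is a one-shot iterator over the open file: once the first pass reaches the end of the file, every later `for line in reader` finds nothing, so the `while True` loop spins forever without reaching another `yield`. A minimal sketch, using an in-memory `io.StringIO` as a stand-in for the CSV file:

```python
import csv
import io

# In-memory stand-in for the CSV file: a header plus two data rows.
csvfile = io.StringIO("a,b,y\n1,2,3\n4,5,6\n")
reader = csv.reader(csvfile)

first_pass = [line for line in reader]   # consumes the whole file
second_pass = [line for line in reader]  # the same iterator is already exhausted

print(len(first_pass))   # 3 (header + 2 rows)
print(len(second_pass))  # 0: a `for line in reader` body never runs again
```

Inside the generator, the second iteration of `while True` therefore executes an empty `for` loop and immediately starts over, so Keras waits indefinitely for the next batch.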

Run it like this:

from keras.models import Sequential
from keras.layers import Dense


batch_size =100

train_gen=generate_arrays_from_file('/media/jma/DATA/calhousing.csv',batchsz=batch_size)


model = Sequential()
model.add(Dense(32, input_shape=(8,)))
model.add(Dense(1, activation='linear'))
model.compile(optimizer='rmsprop',
              loss='mse', metrics=['mse'])


model.fit_generator(train_gen,steps_per_epoch=data.shape[0] / batch_size, epochs=5, verbose=1)

Epoch 1/5
194/206 [===========================>..] - ETA: 0s - loss: 67100.1775 - mean_squared_error: 67100.1775
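Note that `data.shape[0] / batch_size` is a float in Python 3: the California housing data has 20,640 rows, so it evaluates to 206.4, while `steps_per_epoch` counts whole batches. A minimal sketch of computing an integer step count, either dropping or counting the final partial batch:

```python
import math

n_rows = 20640   # number of samples in the California housing dataset
batch_size = 100

steps_floor = n_rows // batch_size            # full batches only
steps_ceil = math.ceil(n_rows / batch_size)   # also counts the final partial batch

print(n_rows / batch_size)  # 206.4
print(steps_floor)          # 206
print(steps_ceil)           # 207
```

Because this generator only ever yields full batches of `batchsz` rows, the step count only decides how many batches Keras draws per epoch; the progress bar above shows Keras used 206. It does not by itself cause the hang.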

Best answer

What the OP changed:

import csv 
import numpy as np

def generate_arrays_from_file(file_name,batchsz):

    ###################
    ### Moved this: ###
    ###################
    # csvfile = open(file_name)
    # reader = csv.reader(csvfile)
    ### End ###########

    batchCount = 0

    inputs = []
    targets = []

    linecounter=0 #which line the reader is reading

    while True: #infinite loop
        ################
        ### to here: ###
        ################
        with open(file_name, "r") as csvfile:
            for line in csv.reader(csvfile):
                ### End ###########

                if linecounter >0: #is not the header

                    #could process data as well
                    inputs.append(line[0:8])
                    targets.append(line[8])

                    batchCount += 1 # we added

                    if batchCount >= batchsz: # we have our mini batch
                        batchCount = 0 #reset batch counter
                        X = np.array(inputs,dtype="float32")
                        y = np.array(targets,dtype="float32")
                        yield (X, y)

                        #reset the lists to hold the batches
                        inputs = []
                        targets = []

                linecounter += 1 #increment the line read
        linecounter = 0 #reset so the header is skipped again on the next pass
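A condensed, self-contained sketch of the same fix follows: the file is reopened on every pass of `while True`, so the reader never runs dry. It differs from the answer only in style: the header is skipped with `next(reader, None)` instead of a line counter, the batch size is checked with `len(inputs)`, and `n_features=8` is an added parameter (an assumption matching the 8 feature columns here). As in the answer, rows left over at the end of a pass roll into the first batch of the next pass.

```python
import csv
import os
import tempfile

import numpy as np


def generate_arrays_from_file(file_name, batchsz, n_features=8):
    """Yield (X, y) float32 batches forever, re-reading the CSV on every pass.

    Assumes a header row, n_features feature columns, then the target column.
    """
    inputs, targets = [], []
    while True:
        # A fresh file handle and reader on every pass, so the data never runs out.
        with open(file_name, newline="") as csvfile:
            reader = csv.reader(csvfile)
            next(reader, None)  # skip the header row
            for line in reader:
                inputs.append(line[:n_features])
                targets.append(line[n_features])
                if len(inputs) >= batchsz:
                    yield (np.array(inputs, dtype="float32"),
                           np.array(targets, dtype="float32"))
                    inputs, targets = [], []


# Usage: 5 data rows and batch size 2, so 6 batches need three passes over the file.
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "small.csv")
with open(path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([f"f{i}" for i in range(8)] + ["y"])
    for i in range(5):
        writer.writerow([i] * 8 + [i])

gen = generate_arrays_from_file(path, batchsz=2)
batches = [next(gen) for _ in range(6)]
print([b[1].tolist() for b in batches])
# [[0.0, 1.0], [2.0, 3.0], [4.0, 0.0], [1.0, 2.0], [3.0, 4.0], [0.0, 1.0]]
```

Like the answer's version, this still loops forever without yielding if the file contains only a header, so it assumes at least one data row.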

A similar question about "python - Generator stalls Keras fit_generator" can be found on Stack Overflow: https://stackoverflow.com/questions/48844790/
