
java - Deeplearning4j: predicting used-car prices

Reposted. Author: 太空宇宙. Updated: 2023-11-04 09:38:18

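I want to predict the prices of used cars, and I have historical data on cars that have already been sold. I scaled the numeric values to 0-1 and one-hot encoded the other features.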
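The preprocessing described above (numeric columns scaled to 0-1, categorical columns one-hot encoded) can be sketched in plain Java. This is a minimal illustration, not the asker's code and not a DL4J API; the class name `Preprocess`, the sample values and the category list are hypothetical.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Sketch: min-max scaling for numeric columns and one-hot encoding for
 * categorical columns. Names and example data are hypothetical.
 */
final class Preprocess {

    /** Scales every value to [0, 1] using the column's own min and max. */
    static double[] minMaxScale(double[] column) {
        double min = Arrays.stream(column).min().getAsDouble();
        double max = Arrays.stream(column).max().getAsDouble();
        double range = max - min;
        double[] scaled = new double[column.length];
        for (int i = 0; i < column.length; i++) {
            // A constant column has range 0; map it to 0 instead of dividing by zero.
            scaled[i] = range == 0 ? 0.0 : (column[i] - min) / range;
        }
        return scaled;
    }

    /** Returns a vector with a single 1.0 at the position of value in categories. */
    static double[] oneHot(String value, List<String> categories) {
        int index = categories.indexOf(value);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown category: " + value);
        }
        double[] encoded = new double[categories.size()];
        encoded[index] = 1.0;
        return encoded;
    }
}
```

In practice the min and max would usually be computed on the training rows only and then reused for test rows and for new cars at prediction time. ND4J's `NormalizerMinMaxScaler` is a library option that does this when it is fit on a training `DataSetIterator`.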

Data:

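// Imports for the DataVec / DL4J / ND4J classes used below. RegressionEvaluation is in
// org.nd4j.evaluation.regression in recent releases (org.deeplearning4j.eval in older ones).
// RestResponse, ControllerBase, JSONObject, buildNetwork( ) and testResults( ) belong to the
// asker's project and are not shown in the question.
import java.io.File ;
import java.io.IOException ;
import java.util.Map ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import org.datavec.api.records.reader.RecordReader ;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader ;
import org.datavec.api.split.FileSplit ;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator ;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork ;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener ;
import org.nd4j.evaluation.regression.RegressionEvaluation ;
import org.nd4j.linalg.api.ndarray.INDArray ;
import org.nd4j.linalg.dataset.DataSet ;
import org.nd4j.linalg.dataset.SplitTestAndTrain ;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator ;
import org.springframework.web.bind.annotation.RequestBody ;

public RestResponse<JSONObject> buildModelDl4j( HttpServletRequest request, HttpServletResponse response, @RequestBody Map<String, String> json ) throws IOException
{
    // batchSize, indexToCalc (label column index), numAttr, nEpochs and splitRate are not declared
    // in the question; they are assumed to be fields or values parsed from json elsewhere.
    RestResponse<JSONObject> restResponse = ControllerBase.getRestResponse( request, response, null ) ;

    String path = "\\HOME_EXCHANGE\\uploads\\" + json.get( "filePath" ) ;

    int numLinesToSkip = 1 ; // skip the CSV header row
    char delimiter = ',' ;

    RecordReader recordReader = new CSVRecordReader( numLinesToSkip, delimiter ) ;

    try
    {
        recordReader.initialize( new FileSplit( new File( path ) ) ) ;
    }
    catch( InterruptedException e )
    {
        // Don't carry on with an uninitialized reader: restore the interrupt flag and fail
        Thread.currentThread( ).interrupt( ) ;
        throw new IOException( "Interrupted while initializing the CSV reader", e ) ;
    }

    // One label column (labelIndexFrom == labelIndexTo == indexToCalc); regression = true
    DataSetIterator iter = new RecordReaderDataSetIterator( recordReader, batchSize, indexToCalc, indexToCalc, true ) ;
    json.put( "numAttr", String.valueOf( numAttr ) ) ;

    // ds.shuffle( ) ; //TODO should I shuffle the data ?

    MultiLayerNetwork net = buildNetwork( json ) ;

    net.init( ) ;

    net.setListeners( new ScoreIterationListener( 30 ) ) ; // log the score every 30 iterations

    DataSet testData = null ;

    for( int i = 0; i < nEpochs; i++ )
    {
        iter.reset( ) ;

        // Split every mini-batch and fit only on its training part
        while( iter.hasNext( ) )
        {
            DataSet ds = iter.next( ) ;
            SplitTestAndTrain testAndTrain = ds.splitTestAndTrain( splitRate / 100.0 ) ;
            DataSet trainingData = testAndTrain.getTrain( ) ;
            testData = testAndTrain.getTest( ) ;
            net.fit( trainingData ) ;
        }

        iter.reset( ) ;

        // Print results for the held-out part of (at most) the first 3 batches
        int cnt = 0 ;
        while( iter.hasNext( ) && cnt++ < 3 )
        {
            DataSet ds = iter.next( ) ;
            SplitTestAndTrain testAndTrain = ds.splitTestAndTrain( splitRate / 100.0 ) ;
            testData = testAndTrain.getTest( ) ;
            String testResults = testResults( net, testData, indexToCalc ) ;
            System.err.println( "Test results: [" + i + "] \n" + testResults ) ;
        }
    }

    // At this point testData is only the held-out part of the last batch examined above
    RegressionEvaluation eval = new RegressionEvaluation( ) ;
    INDArray output = net.output( testData.getFeatures( ) ) ;
    eval.eval( testData.getLabels( ), output ) ;
    System.out.println( eval.stats( ) ) ;

    String testResults = testResults( net, testData, indexToCalc ) ;

    JSONObject result = new JSONObject( ) ; // used below but never declared in the question
    result.put( "testResults", testResults ) ;

    System.err.println( "Test results last: \n" + testResults ) ;

    restResponse.setData( result ) ;

    return restResponse ;
}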
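If the price column was also scaled to 0-1, the MSE and MAE that `eval.stats()` prints are in scaled units, not currency. A minimal sketch of converting back, assuming plain min-max scaling with hypothetical `priceMin`/`priceMax` values taken from the training data (`PriceErrors` is likewise a hypothetical name, not part of DL4J):

```java
/**
 * Sketch: report errors in price units rather than scaled units.
 * Assumes the price label was min-max scaled to [0, 1] using priceMin and
 * priceMax (hypothetical names) computed on the training data.
 */
final class PriceErrors {

    /** Inverts min-max scaling: price = scaled * (max - min) + min. */
    static double unscale(double scaled, double priceMin, double priceMax) {
        return scaled * (priceMax - priceMin) + priceMin;
    }

    /** Mean absolute error, in the same currency units as priceMin/priceMax. */
    static double mae(double[] scaledLabels, double[] scaledPredictions,
                      double priceMin, double priceMax) {
        checkLengths(scaledLabels, scaledPredictions);
        double sum = 0.0;
        for (int i = 0; i < scaledLabels.length; i++) {
            double actual = unscale(scaledLabels[i], priceMin, priceMax);
            double predicted = unscale(scaledPredictions[i], priceMin, priceMax);
            sum += Math.abs(actual - predicted);
        }
        return sum / scaledLabels.length;
    }

    /** Root mean squared error, in the same currency units. */
    static double rmse(double[] scaledLabels, double[] scaledPredictions,
                       double priceMin, double priceMax) {
        checkLengths(scaledLabels, scaledPredictions);
        double sumSq = 0.0;
        for (int i = 0; i < scaledLabels.length; i++) {
            double diff = unscale(scaledLabels[i], priceMin, priceMax)
                    - unscale(scaledPredictions[i], priceMin, priceMax);
            sumSq += diff * diff;
        }
        return Math.sqrt(sumSq / scaledLabels.length);
    }

    private static void checkLengths(double[] labels, double[] predictions) {
        if (labels.length == 0 || labels.length != predictions.length) {
            throw new IllegalArgumentException(
                    "Labels and predictions must be non-empty and of equal length");
        }
    }
}
```

Because min-max scaling is linear, the MAE and RMSE in price units are just the scaled values multiplied by `priceMax - priceMin`; the helper only makes that conversion explicit.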

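I build the model with parameters passed from the front end, read the data from a CSV file, and then train the model. Am I doing this correctly? How should I use the test and training data? The examples use two different approaches: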

// Approach 1: split each mini-batch, fit on the training part, keep the test part
DataSet ds = iter.next( ) ;
SplitTestAndTrain testAndTrain = ds.splitTestAndTrain( splitRate / 100.0 ) ;
DataSet trainingData = testAndTrain.getTrain( ) ;
DataSet testData = testAndTrain.getTest( ) ;
net.fit( trainingData ) ;

// Approach 2: fit on the whole iterator once per epoch (no data is held out)
for( int i = 0; i < nEpochs; i++ )
{
    net.fit( iter ) ;
    iter.reset( ) ;
}

Which one is the correct approach?
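The question's loop splits every mini-batch separately, and the final `RegressionEvaluation` only sees the held-out slice of the last batch examined. A common alternative is to shuffle all rows once with a fixed seed and set aside a test set before training starts. The sketch below is plain Java over a generic row list; `HoldoutSplit` and `Split` are hypothetical names, and writing the two parts to separate CSV files, each read by its own `RecordReaderDataSetIterator`, is just one possible way to feed them to DL4J.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Sketch: shuffle all rows once with a fixed seed, then hold out a fixed
 * test set before any training happens. Generic over the row type.
 */
final class HoldoutSplit {

    static final class Split<T> {
        final List<T> train;
        final List<T> test;

        Split(List<T> train, List<T> test) {
            this.train = train;
            this.test = test;
        }
    }

    /** trainFraction is e.g. 0.8 for an 80/20 split (the question's splitRate / 100.0). */
    static <T> Split<T> split(List<T> rows, double trainFraction, long seed) {
        if (trainFraction <= 0.0 || trainFraction >= 1.0) {
            throw new IllegalArgumentException("trainFraction must be in (0, 1)");
        }
        List<T> shuffled = new ArrayList<>(rows);        // leave the caller's list untouched
        Collections.shuffle(shuffled, new Random(seed)); // same seed -> same split every run
        int numTrain = (int) Math.round(trainFraction * shuffled.size());
        return new Split<>(
                new ArrayList<>(shuffled.subList(0, numTrain)),
                new ArrayList<>(shuffled.subList(numTrain, shuffled.size())));
    }
}
```

With a fixed seed the test rows stay the same across runs and are never passed to `fit`, so the evaluation numbers are comparable between hyperparameter settings.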

Best answer

I build model with parameters passed from front-end, I read data from csv files then train the model. Am I doing the right thing? How should I use test and train data?

A better approach is to use DataSetIteratorSplitter, as follows:

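import org.deeplearning4j.datasets.iterator.DataSetIteratorSplitter;
// ratio is the fraction of batches used for training, e.g. 0.8 for 80% train / 20% test
DataSetIteratorSplitter dataSetIteratorSplitter = new DataSetIteratorSplitter(dataSetIterator, totalNumBatches, ratio);
multiLayerNetwork.fit(dataSetIteratorSplitter.getTrainIterator(), epochCount);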

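totalNumBatches is the total number of examples divided by the mini-batch size. For example, if you have 10,000 examples and put 8 samples in each mini-batch, there are 1,250 batches in total.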
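The arithmetic above can be written as a small helper. When the example count is not a multiple of the batch size, the iterator typically returns a smaller final batch, so ceiling division gives the number of batches actually produced. `trainBatches` assumes the splitter rounds the training share down; that rounding rule is an assumption for illustration, not taken from the DL4J source, and `BatchMath` is a hypothetical name.

```java
/**
 * Sketch of the batch arithmetic behind totalNumBatches.
 * The round-down rule in trainBatches() is an assumption, not DL4J's code.
 */
final class BatchMath {

    /** Mini-batches per epoch; a final partial batch counts as one (ceiling division). */
    static long totalBatches(long numExamples, int batchSize) {
        if (numExamples < 0 || batchSize <= 0) {
            throw new IllegalArgumentException("numExamples >= 0 and batchSize > 0 required");
        }
        return (numExamples + batchSize - 1) / batchSize;
    }

    /** Batches that would go to training for a given ratio, rounding down (assumed). */
    static long trainBatches(long totalBatches, double ratio) {
        return (long) (totalBatches * ratio);
    }
}
```

For the answer's numbers, `totalBatches(10000, 8)` is 1,250 and a ratio of 0.8 leaves 1,000 batches for training and 250 for testing. Assuming the first batches go to training, a CSV sorted by, say, sale date or price would give a test set drawn from one end of that ordering, so shuffling the rows beforehand is worth considering.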

Regarding "java - Deeplearning4j: predicting used-car prices", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/56236060/
