
r - sparklyr: Type error "Column prediction must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually DoubleType"


I am trying to follow this tutorial: https://beta.rstudioconnect.com/content/1518/notebook-classification.html#auc_and_accuracy

I don't understand why this fails, since I simply pasted the code from the site. I also don't know how to convert the column to the correct type. Does anyone have a solution? :)

My data lives in a partition and looks like this:

> partition
$train
# Source: table<sparklyr_tmp_100e145972790> [?? x 9]
# Database: spark_connection
   Survived Pclass Sex     Age SibSp Parch  Fare Embarked Family_Sizes
      <dbl> <chr>  <chr> <dbl> <dbl> <dbl> <dbl> <chr>    <chr>
 1       0. 1      female    2.    1.    2. 152.  S        1
 2       0. 1      female   25.    1.    2. 152.  S        1
 3       0. 1      female   50.    0.    0.  28.7 C        0
 4       0. 1      male     18.    1.    0. 109.  C        1
 5       0. 1      male     19.    1.    0.  53.1 S        1
 6       0. 1      male     19.    3.    2. 263.  S        2
 7       0. 1      male     22.    0.    0. 136.  C        0
 8       0. 1      male     24.    0.    0.  79.2 C        0
 9       0. 1      male     24.    0.    1. 248.  C        1
10       0. 1      male     27.    0.    2. 212.  C        1
# ... with more rows
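For context, a named train/test list like the one printed above is typically produced with sparklyr's sdf_partition(). A minimal sketch, assuming a local Spark connection `sc` and a hypothetical `titanic` Spark table (the original post does not show how its data was loaded):

```r
library(sparklyr)
library(dplyr)

# Hypothetical setup: connect to a local Spark instance and copy an
# in-memory data frame into Spark (here `titanic_df` is a placeholder
# for whatever data frame holds the Titanic data).
sc <- spark_connect(master = "local")
titanic_tbl <- copy_to(sc, titanic_df, "titanic", overwrite = TRUE)

# Split the Spark table into weighted partitions; `partition` becomes a
# list with elements $train and $test, matching the printout above.
partition <- titanic_tbl %>%
  sdf_partition(train = 0.75, test = 0.25, seed = 8585)
```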

Then I fit a single model, for example a logistic regression.

# Create table references
train_tbl <- partition$train
test_tbl <- partition$test

# Model survival as a function of several predictors
ml_formula <- formula(Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare +
Embarked + Family_Sizes)

# Train a logistic regression model
ml_log <- ml_logistic_regression(train_tbl, ml_formula)

# Create a function for scoring
score_test_data <- function(model, data = test_tbl){
  pred <- sdf_predict(model, data)
  select(pred, Survived, prediction)
}

# Calculate the score and AUC metric
ml_score <- score_test_data(ml_log)

Now, ml_score is:

> ml_score
# Source: lazy query [?? x 2]
# Database: spark_connection
   Survived prediction
      <dbl>      <dbl>
 1       0.         1.
 2       0.         0.
 3       0.         0.
 4       0.         0.
 5       0.         0.
 6       0.         0.
 7       0.         0.
 8       0.         0.
 9       0.         0.
10       0.         0.
# ... with more rows

Now I apply the function ml_binary_classification_eval:

ml_binary_classification_eval(ml_score,'Survived','prediction')

And I get this error:

Error: java.lang.IllegalArgumentException: requirement failed: Column prediction must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually DoubleType.
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.ml.util.SchemaUtils$.checkColumnType(SchemaUtils.scala:42)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate(BinaryClassificationEvaluator.scala:82)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at sparklyr.Invoke$.invoke(invoke.scala:102)
at sparklyr.StreamHandler$.handleMethodCall(stream.scala:97)
at sparklyr.StreamHandler$.read(stream.scala:62)
at sparklyr.BackendHandler.channelRead0(handler.scala:52)
at sparklyr.BackendHandler.channelRead0(handler.scala:14)
at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:103)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:244)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:846)
at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:131)
at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:511)
at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:468)
at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:382)
at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:354)
at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:111)
at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:137)
at java.lang.Thread.run(Thread.java:748)

Best Answer

In current versions you have to pass the name of the raw prediction column to ml_binary_classification_evaluator(); by default it is "rawPrediction". The documentation in ?ml_evaluator was incorrect and has since been updated.
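With current sparklyr, the corrected evaluation would look roughly like this (a sketch: ml_predict() supersedes the deprecated sdf_predict(), and its output keeps the rawPrediction column that the evaluator needs):

```r
# Score the test set, keeping all prediction columns. The rawPrediction
# column holds the pre-threshold scores the evaluator expects; the
# `prediction` column only holds the 0/1 labels, hence the type error.
pred <- ml_predict(ml_log, test_tbl)

# Point the evaluator at the raw prediction column explicitly; the
# default metric is areaUnderROC.
ml_binary_classification_evaluator(
  pred,
  label_col          = "Survived",
  raw_prediction_col = "rawPrediction"
)
```

Selecting only Survived and prediction (as score_test_data() does) drops rawPrediction, so keep it in the scored table when you intend to compute AUC.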

Regarding this r / sparklyr "Column prediction must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually DoubleType" error, a similar question was found on Stack Overflow: https://stackoverflow.com/questions/50016746/
