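**Context:** The `tidymodels` metapackage simplifies machine learning workflows. I am trying to use it on my data and to explore the resulting model with the `DALEX` package.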
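**Problem:** When I use `tune::last_fit()` or `DALEXtra::explain_tidymodels()` on the workflow, I get the following error:
Error: Failed to compute `rmse()`.
Caused by error: ! Can't subset columns that don't exist. Column `.pred` doesn't exist.
or
2: Unknown or uninitialised column: `.pred`.
When I look at the output of `predict()` for this workflow on the test dataset, it is a tibble with a single column named `.pred_res`. This is odd, because that column is normally named `.pred`. I have only observed this behaviour with linear models.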
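**Question:** How can I avoid the inconsistent column names in the output of `predict()`?
**Attempts:** I tried to find out why the `predict()` output name changes, but I could not reproduce the "wrong" name with an example dataset, so the problem probably comes from my data. I have not yet figured out why.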
Reproducible example:
library(tidyverse)
library(tidymodels)
set.seed(55)
split_index <- initial_split(df, prop = 0.80)
train_data <- training(split_index)
test_data <- testing(split_index)
lm_mod <- linear_reg() %>% set_engine("lm")
rec_ex <- recipe(var_01 ~ ., data = train_data) %>%
step_zv(all_predictors()) %>%
step_dummy(all_nominal_predictors()) %>%
step_normalize(all_numeric_predictors())
wflow <- workflow(rec_ex, lm_mod)
model_fitted <- fit(wflow, data = train_data)
predict(model_fitted, new_data = test_data)
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from rank-deficient fit; attr(*, "non-estim") has
#> doubtful cases
#> # A tibble: 6 × 1
#> .pred_res
#> <dbl>
#> 1 6.97
#> 2 9.10
#> 3 6.70
#> 4 4.68
#> 5 8.53
#> 6 4.04
# later errors
tune::last_fit(wflow, split_index, metrics = metric_set(rmse, rsq))
#> → A | warning: prediction from rank-deficient fit; attr(*, "non-estim") has doubtful cases
#> There were issues with some computations A: x1 → B | error: Failed to compute `rmse()`.
#> Caused by error:
#> ! Can't subset columns that don't exist.
#> ✖ Column `.pred` doesn't exist.
#> There were issues with some computations A: x1There were issues with some computations A: x1 B: x1
#> Warning: All models failed. Run `show_notes(.Last.tune.result)` for more
#> information.
#> # Resampling results
#> # Manual resampling
#> # A tibble: 1 × 5
#> splits id .metrics .notes .predictions
#> <list> <chr> <list> <list> <list>
#> 1 <split [24/6]> train/test split <NULL> <tibble [1 × 3]> <NULL>
#>
#> There were issues with some computations:
#>
#> - Error(s) x1: Failed to compute `rmse()`. Caused by error: ! Can't subset colum...
#>
#> Run `show_notes(.Last.tune.result)` for more information.
#
expl <- DALEXtra::explain_tidymodels(model = model_fitted, data = train_data[-1], y = train_data[["var_01"]])
#> Preparation of a new explainer is initiated
#> -> model label : workflow ( default )
#> -> data : 24 rows 10 cols
#> -> data : tibble converted into a data.frame
#> -> target variable : 24 values
#> -> predict function : yhat.workflow will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from rank-deficient fit; attr(*, "non-estim") has
#> doubtful cases
#> Warning: Unknown or uninitialised column: `.pred`.
#> -> model_info : package tidymodels , ver. 1.1.0 , task regression ( default )
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from rank-deficient fit; attr(*, "non-estim") has
#> doubtful cases
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): Unknown or uninitialised column: `.pred`.
#> Warning in min(y_hat): aucun argument trouvé pour min ; Inf est renvoyé
#> Warning in mean.default(y_hat): l'argument n'est ni numérique, ni logique :
#> renvoi de NA
#> Warning in max(y_hat): aucun argument pour max ; -Inf est renvoyé
#> -> predicted values : numerical, min = Inf , mean = NA , max = -Inf
#> -> residual function : difference between y and yhat ( default )
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from rank-deficient fit; attr(*, "non-estim") has
#> doubtful cases
#> Warning: Unknown or uninitialised column: `.pred`.
#> Warning in min(residuals): aucun argument trouvé pour min ; Inf est renvoyé
#> Warning in max(residuals): aucun argument pour max ; -Inf est renvoyé
#> -> residuals : numerical, min = Inf , mean = NaN , max = -Inf
#> A new explainer has been created!
Data to create `df`:
df <- structure(list(var_01 = c(7.19629231495312, 6.40984259572719,
6.27257177433827, 8.00374317718997, 9.01579517195628, 10.4956246669712,
7.1507318272229, 8.63141379388915, 3.5783397329538, 9.80138831562098,
5.67664288362839, 5.86270848138085, 7.70754823877525, 4.54353336735003,
4.41906983044935, 5.7861067955948, 4.65799688139184, 9.81761822083165,
7.96694789250631, 7.76621541494846, 10.3171208898987, 6.57979209451482,
6.28851433743331, 7.10593824215654, 6.3189033905231, 6.84749525108044,
4.89179975550183, 4.27178831250932, 7.346738509904, 4.73880904124239
), var_02 = c(0.0250696386912475, 0.0917431290077227, 0.0917431290077227,
0.398148162370976, 0.0250696386912475, 0, 0, 0, 0.0177304971816102,
0.0250696386912475, 0.0177304971816102, 0.0177304971816102, 0.0778688561527045,
0.0917431290077227, 0.0869565304907941, 0.00303030313658471,
0.0825688161069504, 0.0250696386912475, 0.0573770519019928, 0,
0.00303030313658471, 0.0177304971816102, 0.0888157928551334,
0, 0, 0.427152334251301, 1, 1, 0.00303030313658471, 0.0869565304907941
), var_03 = c(0.0557103082027722, 0.908256870992277, 0.908256870992277,
0.601851837629024, 0.0557103082027722, 0, 1, 0, 0, 0.974930361308753,
0, 0, 0.922131143847295, 0.908256870992277, 0.913043469509206,
0, 0.91743118389305, 0.974930361308753, 0.942622948098007, 0,
0, 0.98226950281839, 0.911184207144867, 1, 0, 0.572847665748699,
0, 0, 0, 0.913043469509206), var_04 = c(0.91922005310598, 0,
0, 0, 0.91922005310598, 1, 0, 1, 0.98226950281839, 0, 0.98226950281839,
0.98226950281839, 0, 0, 0, 0.996969696863415, 0, 0, 0, 1, 0.996969696863415,
0, 0, 0, 1, 0, 0, 0, 0.996969696863415, 0), var_05 = c(0.502890173410405,
0.81651376146789, 0.81651376146789, 0.605678233438486, 0.502890173410405,
0.452991452991453, 0.773722627737226, 0.452991452991453, 0.390681003584229,
0.502890173410405, 0.390681003584229, 0.390681003584229, 0.379912663755459,
0.81651376146789, 0.809090909090909, 0.60377358490566, 0.81651376146789,
0.502890173410405, 0.379912663755459, 0.452991452991453, 0.60377358490566,
0.390681003584229, 0.435215946843854, 0.773722627737226, 0.452991452991453,
0.576271186440678, 0.754098360655738, 0.754098360655738, 0.60377358490566,
0.809090909090909), var_06 = c(0.00578034682080925, 0, 0, 0.0189274447949527,
0.00578034682080925, 0.0427350427350427, 0.072992700729927, 0.0427350427350427,
0.0501792114695341, 0.00578034682080925, 0.0501792114695341,
0.0501792114695341, 0.0262008733624454, 0, 0, 0.0188679245283019,
0, 0.00578034682080925, 0.0262008733624454, 0.0427350427350427,
0.0188679245283019, 0.0501792114695341, 0.0465116279069767, 0.072992700729927,
0.0427350427350427, 0.0203389830508475, 0.0819672131147541, 0.0819672131147541,
0.0188679245283019, 0), var_07 = c(43.5001740470859, 32.3527481198312,
25.1446396085035, 22.114003075575, 39.782754, 46.6101768, 54.7703728,
38.5200794, 49.8644040600424, 14.2498672, 53.1737503261588, 88.4927937438296,
131.189808315445, 44.1463295668831, 41.2608870959689, 38.0888602065136,
35.1336068510951, 50.3895112010419, 133.873532494894, 48.7908703333333,
20.9324783797359, 92.8490890224401, 6.02135625, 50.4673972, 48.5801272,
26.7198413722397, 17.2102265, 24.571311, 22.8074026807395, 7.51621
), var_08 = c(7996.35552336138, 6643.27476793249, 5888.67438138256,
1024.71316538883, 7091.4, 1296.6, 1523.6, 1071.55, 7433.87521488188,
2724.64, 8834.26080278937, 10138.9031845642, 10069.4584107327,
7154.99668831169, 7328.75436873337, 1376.55921008303, 7214.29298790454,
8526.14402724905, 9300.53665976904, 1813.05, 934.217809988519,
10449.5589122055, 1286.61458333333, 1403.9, 1351.4, 1431.71466876972,
3475.8, 5428.3, 864.356504082237, 1582.36), var_09 = c(84.9264705882353,
90.5544147843942, 102.810304449649, 19.9332843971281, 78.25311942959,
10.5236452653833, 10.5236452653833, 10.5236452653833, 62.9476584022039,
70.5544933078394, 69.8924731182796, 49.3576017130621, 38.9678553737951,
71.4748784440843, 79.0408525754885, 15.2615972702899, 90.1437371663244,
78.1725888324873, 35.0397998597178, 15.7333626712446, 18.7407520633081,
48.7421383647799, 92.3076923076923, 10.5236452653833, 10.5236452653833,
22.9351465712713, 89.3555119684218, 97.973282744254, 16.068780950685,
90.3157894736842), var_10 = c(2.12256267409471, 2.84403669724771,
2.84403669724771, 2.16975308641975, 2.12256267409471, 1.81742738589212,
2.06934306569343, 1.81742738589212, 1.50709219858156, 2.12256267409471,
1.50709219858156, 1.50709219858156, 1.38729508196721, 2.84403669724771,
2.7, 2.13181818181818, 2.84403669724771, 2.12256267409471, 1.38729508196721,
1.81742738589212, 2.13181818181818, 1.50709219858156, 1.60197368421053,
2.06934306569343, 1.81742738589212, 2.12251655629139, 2.39344262295082,
2.39344262295082, 2.13181818181818, 2.7), var_11 = c(10.6604456824513,
12.8660550458716, 12.8660550458716, 11.9376543209877, 10.6604456824513,
9.23112033195021, 13.7715328467153, 9.23112033195021, 9.52872340425532,
10.6604456824513, 9.52872340425532, 9.52872340425532, 8.44303278688525,
12.8660550458716, 12.4608695652174, 11.8209090909091, 12.8660550458716,
10.6604456824513, 8.44303278688525, 9.23112033195021, 11.8209090909091,
9.52872340425532, 9.82927631578947, 13.7715328467153, 9.23112033195021,
11.8251655629139, 13.2737704918033, 13.2737704918033, 11.8209090909091,
12.4608695652174)), row.names = c(NA, -30L), class = c("tbl_df",
"tbl", "data.frame"), na.action = structure(c(`46` = 46L), class = "omit"))
Session info
R version 4.3.1 (2023-06-16 ucrt)
tidymodels 1.1.0 ──
✔ broom 1.0.5 ✔ rsample 1.1.1
✔ dials 1.2.0 ✔ tune 1.1.1
✔ infer 1.0.4 ✔ workflows 1.1.3
✔ modeldata 1.1.0 ✔ workflowsets 1.0.1
✔ parsnip 1.1.0 ✔ yardstick 1.2.0
✔ recipes 1.0.6
1 Answer
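You get this error because you are using the CRAN version of {parsnip} together with R 4.3.0 or later (as of 2023-07-18).
The bug has been fixed in https://github.com/tidymodels/parsnip/pull/987, and the fix will be released on CRAN soon.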
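Until the fixed version reaches CRAN, one option is to install the development version of {parsnip} from GitHub. The snippet below is a minimal sketch; it assumes that the fix from the pull request above is on the default branch of the `tidymodels/parsnip` repository and that you are willing to use the {pak} installer (my choice here, not something the answer prescribes).

```r
# Assumption: the fix is available on the default branch of tidymodels/parsnip.
# install.packages("pak")        # only needed if pak is not installed yet
pak::pak("tidymodels/parsnip")

# Restart R afterwards, then check which version is installed
packageVersion("parsnip")
```

Once a release containing the fix is on CRAN, a regular `install.packages("parsnip")` should be enough.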
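This happens because the model you fitted is rank-deficient, and R 4.3.0 changed the prediction method for `lm()` models so that it returns more information when that is the case. {parsnip} did not anticipate this change in the output, hence the bug you are seeing.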
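To check whether a linear model is rank-deficient, you can inspect the underlying `lm` object. The sketch below uses base R only and a hypothetical toy data frame (`toy`, `x1` to `x3`, and `y` are made up for illustration) in which the three predictors sum to 1 in every row, which makes them linearly dependent with the intercept.

```r
# Toy data (hypothetical): x1 + x2 + x3 == 1 in every row, so together with
# the intercept the design matrix is rank-deficient.
set.seed(1)
toy <- data.frame(x1 = runif(20, 0, 0.5), x2 = runif(20, 0, 0.5))
toy$x3 <- 1 - toy$x1 - toy$x2
toy$y  <- rnorm(20)

fit <- lm(y ~ x1 + x2 + x3, data = toy)

# A fit is rank-deficient when its rank is below the number of coefficients
is_rank_deficient <- fit$rank < length(coef(fit))

# lm() reports coefficients it could not estimate as NA
aliased <- names(coef(fit))[is.na(coef(fit))]

# alias() describes the linear dependencies among the model terms
alias(fit)
```

For a fitted workflow such as `model_fitted` in the question, `extract_fit_engine(model_fitted)` should return the underlying `lm` object, so the same checks can be applied to it. In the question's data, `var_02`, `var_03` and `var_04` appear to sum to 1 in each row, which would be one plausible source of the rank deficiency.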
Created on 2023-07-18 with reprex v2.0.2