如何避免predict()输出的列名不一致?

7rtdyuoh  于 2023-07-31  发布在  其他
关注(0)|答案(1)|浏览(104)

上下文:

tidymodels元包允许简化机器学习方法。我正试图在我的数据上使用它,并使用DALEX包探索模型。

问题:

当我在工作流中使用tune::last_fit()DALEXtra::explain_tidymodels()时,我会收到以下错误:
错误:无法计算rmse()
由错误引起:!无法为不存在的列设置子集。列.pred不存在。

2:未知或未初始化的列:.pred
当我查看这个工作流和测试数据集上predict()的输出时,输出是一个tibble,其中有一列名为.pred_res。这很奇怪,因为通常该列被命名为.pred。我只在线性模型中观察到了这种行为。

**问题:**如何避免predict()的输出名称不一致?

  • 试验:**我试图找出predict()输出名称发生变化的原因,但我无法使用示例数据集得到相同的“错误”名称,因此问题可能来自我的数据。我还没有弄明白为什么。

可复制示例:

library(tidyverse)
library(tidymodels)

set.seed(55)
split_index <- initial_split(df, prop = 0.80)
train_data <- training(split_index)
test_data  <- testing(split_index)

lm_mod <- linear_reg() %>% set_engine("lm")
rec_ex <- recipe(var_01 ~ ., data = train_data) %>% 
  step_zv(all_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>%
  step_normalize(all_numeric_predictors())
wflow <- workflow(rec_ex, lm_mod)

model_fitted <- fit(wflow, data = train_data)
predict(model_fitted, new_data = test_data)
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from rank-deficient fit; attr(*, "non-estim") has
#> doubtful cases
#> # A tibble: 6 × 1
#>   .pred_res
#>       <dbl>
#> 1      6.97
#> 2      9.10
#> 3      6.70
#> 4      4.68
#> 5      8.53
#> 6      4.04

# latter errors
tune::last_fit(wflow, split_index, metrics = metric_set(rmse, rsq))
#> → A | warning: prediction from rank-deficient fit; attr(*, "non-estim") has doubtful cases
#> There were issues with some computations   A: x1                                                 → B | error:   Failed to compute `rmse()`.
#>                Caused by error:
#>                ! Can't subset columns that don't exist.
#>                ✖ Column `.pred` doesn't exist.
#> There were issues with some computations   A: x1There were issues with some computations   A: x1   B: x1
#> Warning: All models failed. Run `show_notes(.Last.tune.result)` for more
#> information.
#> # Resampling results
#> # Manual resampling 
#> # A tibble: 1 × 5
#>   splits         id               .metrics .notes           .predictions
#>   <list>         <chr>            <list>   <list>           <list>      
#> 1 <split [24/6]> train/test split <NULL>   <tibble [1 × 3]> <NULL>      
#> 
#> There were issues with some computations:
#> 
#>   - Error(s) x1: Failed to compute `rmse()`. Caused by error: ! Can't subset colum...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.
# 
expl <- DALEXtra::explain_tidymodels(model = model_fitted, data = train_data[-1], y = train_data[["var_01"]])
#> Preparation of a new explainer is initiated
#>   -> model label       :  workflow  (  default  )
#>   -> data              :  24  rows  10  cols 
#>   -> data              :  tibble converted into a data.frame 
#>   -> target variable   :  24  values 
#>   -> predict function  :  yhat.workflow  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from rank-deficient fit; attr(*, "non-estim") has
#> doubtful cases
#> Warning: Unknown or uninitialised column: `.pred`.
#>   -> model_info        :  package tidymodels , ver. 1.1.0 , task regression (  default  )
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from rank-deficient fit; attr(*, "non-estim") has
#> doubtful cases

#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): Unknown or uninitialised column: `.pred`.
#> Warning in min(y_hat): aucun argument trouvé pour min ; Inf est renvoyé
#> Warning in mean.default(y_hat): l'argument n'est ni numérique, ni logique :
#> renvoi de NA
#> Warning in max(y_hat): aucun argument pour max ; -Inf est renvoyé
#>   -> predicted values  :  numerical, min =  Inf , mean =  NA , max =  -Inf  
#>   -> residual function :  difference between y and yhat (  default  )
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from rank-deficient fit; attr(*, "non-estim") has
#> doubtful cases
#> Warning: Unknown or uninitialised column: `.pred`.
#> Warning in min(residuals): aucun argument trouvé pour min ; Inf est renvoyé
#> Warning in max(residuals): aucun argument pour max ; -Inf est renvoyé
#>   -> residuals         :  numerical, min =  Inf , mean =  NaN , max =  -Inf  
#>   A new explainer has been created!

字符串
创建df的数据:

df <- structure(list(var_01 = c(7.19629231495312, 6.40984259572719, 
                                6.27257177433827, 8.00374317718997, 9.01579517195628, 10.4956246669712, 
                                7.1507318272229, 8.63141379388915, 3.5783397329538, 9.80138831562098, 
                                5.67664288362839, 5.86270848138085, 7.70754823877525, 4.54353336735003, 
                                4.41906983044935, 5.7861067955948, 4.65799688139184, 9.81761822083165, 
                                7.96694789250631, 7.76621541494846, 10.3171208898987, 6.57979209451482, 
                                6.28851433743331, 7.10593824215654, 6.3189033905231, 6.84749525108044, 
                                4.89179975550183, 4.27178831250932, 7.346738509904, 4.73880904124239
), var_02 = c(0.0250696386912475, 0.0917431290077227, 0.0917431290077227, 
              0.398148162370976, 0.0250696386912475, 0, 0, 0, 0.0177304971816102, 
              0.0250696386912475, 0.0177304971816102, 0.0177304971816102, 0.0778688561527045, 
              0.0917431290077227, 0.0869565304907941, 0.00303030313658471, 
              0.0825688161069504, 0.0250696386912475, 0.0573770519019928, 0, 
              0.00303030313658471, 0.0177304971816102, 0.0888157928551334, 
              0, 0, 0.427152334251301, 1, 1, 0.00303030313658471, 0.0869565304907941
), var_03 = c(0.0557103082027722, 0.908256870992277, 0.908256870992277, 
              0.601851837629024, 0.0557103082027722, 0, 1, 0, 0, 0.974930361308753, 
              0, 0, 0.922131143847295, 0.908256870992277, 0.913043469509206, 
              0, 0.91743118389305, 0.974930361308753, 0.942622948098007, 0, 
              0, 0.98226950281839, 0.911184207144867, 1, 0, 0.572847665748699, 
              0, 0, 0, 0.913043469509206), var_04 = c(0.91922005310598, 0, 
                                                      0, 0, 0.91922005310598, 1, 0, 1, 0.98226950281839, 0, 0.98226950281839, 
                                                      0.98226950281839, 0, 0, 0, 0.996969696863415, 0, 0, 0, 1, 0.996969696863415, 
                                                      0, 0, 0, 1, 0, 0, 0, 0.996969696863415, 0), var_05 = c(0.502890173410405, 
                                                                                                             0.81651376146789, 0.81651376146789, 0.605678233438486, 0.502890173410405, 
                                                                                                             0.452991452991453, 0.773722627737226, 0.452991452991453, 0.390681003584229, 
                                                                                                             0.502890173410405, 0.390681003584229, 0.390681003584229, 0.379912663755459, 
                                                                                                             0.81651376146789, 0.809090909090909, 0.60377358490566, 0.81651376146789, 
                                                                                                             0.502890173410405, 0.379912663755459, 0.452991452991453, 0.60377358490566, 
                                                                                                             0.390681003584229, 0.435215946843854, 0.773722627737226, 0.452991452991453, 
                                                                                                             0.576271186440678, 0.754098360655738, 0.754098360655738, 0.60377358490566, 
                                                                                                             0.809090909090909), var_06 = c(0.00578034682080925, 0, 0, 0.0189274447949527, 
                                                                                                                                            0.00578034682080925, 0.0427350427350427, 0.072992700729927, 0.0427350427350427, 
                                                                                                                                            0.0501792114695341, 0.00578034682080925, 0.0501792114695341, 
                                                                                                                                            0.0501792114695341, 0.0262008733624454, 0, 0, 0.0188679245283019, 
                                                                                                                                            0, 0.00578034682080925, 0.0262008733624454, 0.0427350427350427, 
                                                                                                                                            0.0188679245283019, 0.0501792114695341, 0.0465116279069767, 0.072992700729927, 
                                                                                                                                            0.0427350427350427, 0.0203389830508475, 0.0819672131147541, 0.0819672131147541, 
                                                                                                                                            0.0188679245283019, 0), var_07 = c(43.5001740470859, 32.3527481198312, 
                                                                                                                                                                               25.1446396085035, 22.114003075575, 39.782754, 46.6101768, 54.7703728, 
                                                                                                                                                                               38.5200794, 49.8644040600424, 14.2498672, 53.1737503261588, 88.4927937438296, 
                                                                                                                                                                               131.189808315445, 44.1463295668831, 41.2608870959689, 38.0888602065136, 
                                                                                                                                                                               35.1336068510951, 50.3895112010419, 133.873532494894, 48.7908703333333, 
                                                                                                                                                                               20.9324783797359, 92.8490890224401, 6.02135625, 50.4673972, 48.5801272, 
                                                                                                                                                                               26.7198413722397, 17.2102265, 24.571311, 22.8074026807395, 7.51621
                                                                                                                                            ), var_08 = c(7996.35552336138, 6643.27476793249, 5888.67438138256, 
                                                                                                                                                          1024.71316538883, 7091.4, 1296.6, 1523.6, 1071.55, 7433.87521488188, 
                                                                                                                                                          2724.64, 8834.26080278937, 10138.9031845642, 10069.4584107327, 
                                                                                                                                                          7154.99668831169, 7328.75436873337, 1376.55921008303, 7214.29298790454, 
                                                                                                                                                          8526.14402724905, 9300.53665976904, 1813.05, 934.217809988519, 
                                                                                                                                                          10449.5589122055, 1286.61458333333, 1403.9, 1351.4, 1431.71466876972, 
                                                                                                                                                          3475.8, 5428.3, 864.356504082237, 1582.36), var_09 = c(84.9264705882353, 
                                                                                                                                                                                                                 90.5544147843942, 102.810304449649, 19.9332843971281, 78.25311942959, 
                                                                                                                                                                                                                 10.5236452653833, 10.5236452653833, 10.5236452653833, 62.9476584022039, 
                                                                                                                                                                                                                 70.5544933078394, 69.8924731182796, 49.3576017130621, 38.9678553737951, 
                                                                                                                                                                                                                 71.4748784440843, 79.0408525754885, 15.2615972702899, 90.1437371663244, 
                                                                                                                                                                                                                 78.1725888324873, 35.0397998597178, 15.7333626712446, 18.7407520633081, 
                                                                                                                                                                                                                 48.7421383647799, 92.3076923076923, 10.5236452653833, 10.5236452653833, 
                                                                                                                                                                                                                 22.9351465712713, 89.3555119684218, 97.973282744254, 16.068780950685, 
                                                                                                                                                                                                                 90.3157894736842), var_10 = c(2.12256267409471, 2.84403669724771, 
                                                                                                                                                                                                                                               2.84403669724771, 2.16975308641975, 2.12256267409471, 1.81742738589212, 
                                                                                                                                                                                                                                               2.06934306569343, 1.81742738589212, 1.50709219858156, 2.12256267409471, 
                                                                                                                                                                                                                                               1.50709219858156, 1.50709219858156, 1.38729508196721, 2.84403669724771, 
                                                                                                                                                                                                                                               2.7, 2.13181818181818, 2.84403669724771, 2.12256267409471, 1.38729508196721, 
                                                                                                                                                                                                                                               1.81742738589212, 2.13181818181818, 1.50709219858156, 1.60197368421053, 
                                                                                                                                                                                                                                               2.06934306569343, 1.81742738589212, 2.12251655629139, 2.39344262295082, 
                                                                                                                                                                                                                                               2.39344262295082, 2.13181818181818, 2.7), var_11 = c(10.6604456824513, 
                                                                                                                                                                                                                                                                                                    12.8660550458716, 12.8660550458716, 11.9376543209877, 10.6604456824513, 
                                                                                                                                                                                                                                                                                                    9.23112033195021, 13.7715328467153, 9.23112033195021, 9.52872340425532, 
                                                                                                                                                                                                                                                                                                    10.6604456824513, 9.52872340425532, 9.52872340425532, 8.44303278688525, 
                                                                                                                                                                                                                                                                                                    12.8660550458716, 12.4608695652174, 11.8209090909091, 12.8660550458716, 
                                                                                                                                                                                                                                                                                                    10.6604456824513, 8.44303278688525, 9.23112033195021, 11.8209090909091, 
                                                                                                                                                                                                                                                                                                    9.52872340425532, 9.82927631578947, 13.7715328467153, 9.23112033195021, 
                                                                                                                                                                                                                                                                                                    11.8251655629139, 13.2737704918033, 13.2737704918033, 11.8209090909091, 
                                                                                                                                                                                                                                                                                                    12.4608695652174)), row.names = c(NA, -30L), class = c("tbl_df", 
                                                                                                                                                                                                                                                                                                                                                           "tbl", "data.frame"), na.action = structure(c(`46` = 46L), class = "omit"))

Session信息

R version 4.3.1 (2023-06-16 ucrt)

tidymodels 1.1.0 ──
✔ broom        1.0.5     ✔ rsample      1.1.1
✔ dials        1.2.0     ✔ tune         1.1.1
✔ infer        1.0.4     ✔ workflows    1.1.3
✔ modeldata    1.1.0     ✔ workflowsets 1.0.1
✔ parsnip      1.1.0     ✔ yardstick    1.2.0
✔ recipes      1.0.6

dldeef67

dldeef671#

你得到这个错误,因为你正在使用CRAN版本的{parsnip}和4.3.0或更高版本的R。(2023-07-18)
此错误已在https://github.com/tidymodels/parsnip/pull/987中修复,并将很快在CRAN上发布。
发生这种情况是因为您拟合的模型是秩不足的,而4.3.0更改了lm()模型的预测方法,使其在发生这种情况时提供更多的信息。{parsnip}没有预料到输出的变化,因此您看到的bug。

library(tidyverse)
library(tidymodels)

df <- tibble(
  var_01 = c(
    7.19629231495312, 6.40984259572719,
    6.27257177433827, 8.00374317718997, 9.01579517195628, 10.4956246669712,
    7.1507318272229, 8.63141379388915, 3.5783397329538, 9.80138831562098,
    5.67664288362839, 5.86270848138085, 7.70754823877525, 4.54353336735003,
    4.41906983044935, 5.7861067955948, 4.65799688139184, 9.81761822083165,
    7.96694789250631, 7.76621541494846, 10.3171208898987, 6.57979209451482,
    6.28851433743331, 7.10593824215654, 6.3189033905231, 6.84749525108044,
    4.89179975550183, 4.27178831250932, 7.346738509904, 4.73880904124239
  ),
  var_02 = c(
    0.0250696386912475, 0.0917431290077227, 0.0917431290077227,
    0.398148162370976, 0.0250696386912475, 0, 0, 0, 0.0177304971816102,
    0.0250696386912475, 0.0177304971816102, 0.0177304971816102, 0.0778688561527045,
    0.0917431290077227, 0.0869565304907941, 0.00303030313658471,
    0.0825688161069504, 0.0250696386912475, 0.0573770519019928, 0,
    0.00303030313658471, 0.0177304971816102, 0.0888157928551334,
    0, 0, 0.427152334251301, 1, 1, 0.00303030313658471, 0.0869565304907941
  ),
  var_03 = c(
    0.0557103082027722, 0.908256870992277, 0.908256870992277,
    0.601851837629024, 0.0557103082027722, 0, 1, 0, 0, 0.974930361308753,
    0, 0, 0.922131143847295, 0.908256870992277, 0.913043469509206,
    0, 0.91743118389305, 0.974930361308753, 0.942622948098007, 0,
    0, 0.98226950281839, 0.911184207144867, 1, 0, 0.572847665748699,
    0, 0, 0, 0.913043469509206
  ),
  var_04 = c(
    0.91922005310598, 0,
    0, 0, 0.91922005310598, 1, 0, 1, 0.98226950281839, 0, 0.98226950281839,
    0.98226950281839, 0, 0, 0, 0.996969696863415, 0, 0, 0, 1, 0.996969696863415,
    0, 0, 0, 1, 0, 0, 0, 0.996969696863415, 0
  ),
  var_05 = c(
    0.502890173410405,
    0.81651376146789, 0.81651376146789, 0.605678233438486, 0.502890173410405,
    0.452991452991453, 0.773722627737226, 0.452991452991453, 0.390681003584229,
    0.502890173410405, 0.390681003584229, 0.390681003584229, 0.379912663755459,
    0.81651376146789, 0.809090909090909, 0.60377358490566, 0.81651376146789,
    0.502890173410405, 0.379912663755459, 0.452991452991453, 0.60377358490566,
    0.390681003584229, 0.435215946843854, 0.773722627737226, 0.452991452991453,
    0.576271186440678, 0.754098360655738, 0.754098360655738, 0.60377358490566,
    0.809090909090909
  ),
  var_06 = c(
    0.00578034682080925, 0, 0, 0.0189274447949527,
    0.00578034682080925, 0.0427350427350427, 0.072992700729927, 0.0427350427350427,
    0.0501792114695341, 0.00578034682080925, 0.0501792114695341,
    0.0501792114695341, 0.0262008733624454, 0, 0, 0.0188679245283019,
    0, 0.00578034682080925, 0.0262008733624454, 0.0427350427350427,
    0.0188679245283019, 0.0501792114695341, 0.0465116279069767, 0.072992700729927,
    0.0427350427350427, 0.0203389830508475, 0.0819672131147541, 0.0819672131147541,
    0.0188679245283019, 0
  ),
  var_07 = c(
    43.5001740470859, 32.3527481198312,
    25.1446396085035, 22.114003075575, 39.782754, 46.6101768, 54.7703728,
    38.5200794, 49.8644040600424, 14.2498672, 53.1737503261588, 88.4927937438296,
    131.189808315445, 44.1463295668831, 41.2608870959689, 38.0888602065136,
    35.1336068510951, 50.3895112010419, 133.873532494894, 48.7908703333333,
    20.9324783797359, 92.8490890224401, 6.02135625, 50.4673972, 48.5801272,
    26.7198413722397, 17.2102265, 24.571311, 22.8074026807395, 7.51621
  ),
  var_08 = c(
    7996.35552336138, 6643.27476793249, 5888.67438138256,
    1024.71316538883, 7091.4, 1296.6, 1523.6, 1071.55, 7433.87521488188,
    2724.64, 8834.26080278937, 10138.9031845642, 10069.4584107327,
    7154.99668831169, 7328.75436873337, 1376.55921008303, 7214.29298790454,
    8526.14402724905, 9300.53665976904, 1813.05, 934.217809988519,
    10449.5589122055, 1286.61458333333, 1403.9, 1351.4, 1431.71466876972,
    3475.8, 5428.3, 864.356504082237, 1582.36
  ),
  var_09 = c(
    84.9264705882353,
    90.5544147843942, 102.810304449649, 19.9332843971281, 78.25311942959,
    10.5236452653833, 10.5236452653833, 10.5236452653833, 62.9476584022039,
    70.5544933078394, 69.8924731182796, 49.3576017130621, 38.9678553737951,
    71.4748784440843, 79.0408525754885, 15.2615972702899, 90.1437371663244,
    78.1725888324873, 35.0397998597178, 15.7333626712446, 18.7407520633081,
    48.7421383647799, 92.3076923076923, 10.5236452653833, 10.5236452653833,
    22.9351465712713, 89.3555119684218, 97.973282744254, 16.068780950685,
    90.3157894736842
  ),
  var_10 = c(
    2.12256267409471, 2.84403669724771,
    2.84403669724771, 2.16975308641975, 2.12256267409471, 1.81742738589212,
    2.06934306569343, 1.81742738589212, 1.50709219858156, 2.12256267409471,
    1.50709219858156, 1.50709219858156, 1.38729508196721, 2.84403669724771,
    2.7, 2.13181818181818, 2.84403669724771, 2.12256267409471, 1.38729508196721,
    1.81742738589212, 2.13181818181818, 1.50709219858156, 1.60197368421053,
    2.06934306569343, 1.81742738589212, 2.12251655629139, 2.39344262295082,
    2.39344262295082, 2.13181818181818, 2.7
  ),
  var_11 = c(
    10.6604456824513,
    12.8660550458716, 12.8660550458716, 11.9376543209877, 10.6604456824513,
    9.23112033195021, 13.7715328467153, 9.23112033195021, 9.52872340425532,
    10.6604456824513, 9.52872340425532, 9.52872340425532, 8.44303278688525,
    12.8660550458716, 12.4608695652174, 11.8209090909091, 12.8660550458716,
    10.6604456824513, 8.44303278688525, 9.23112033195021, 11.8209090909091,
    9.52872340425532, 9.82927631578947, 13.7715328467153, 9.23112033195021,
    11.8251655629139, 13.2737704918033, 13.2737704918033, 11.8209090909091,
    12.4608695652174
  )
)

set.seed(55)
split_index <- initial_split(df, prop = 0.80)
train_data <- training(split_index)
test_data  <- testing(split_index)

lm_mod <- linear_reg() %>% set_engine("lm")
rec_ex <- recipe(var_01 ~ ., data = train_data) %>% 
  step_zv(all_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>%
  step_normalize(all_numeric_predictors())
wflow <- workflow(rec_ex, lm_mod)

model_fitted <- fit(wflow, data = train_data)
predict(model_fitted, new_data = test_data)
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response", : prediction from rank-deficient fit; consider predict(.,
#> rankdeficient="NA")
#> # A tibble: 6 × 1
#>   .pred
#>   <dbl>
#> 1  6.97
#> 2  9.10
#> 3  6.70
#> 4  4.68
#> 5  8.53
#> 6  4.04

字符串
创建于2023-07-18,使用reprex v2.0.2

相关问题