R语言 将多个不同模型的预测值按组添加到新数据集

fcwjkofz  于 2023-07-31  发布在  其他
关注(0)|答案(2)|浏览(92)

给定分组数据的 Dataframe :

library(tidyverse)

# fake up some grouped data: 
set.seed(123)
dat <- data.frame(x = rnorm(100), 
                  y = rnorm(100), 
                  group = rep(x=letters[1:10],each=10))
head(dat)
> head(dat)
            x           y group
1 -0.56047565 -0.71040656     a
2 -0.23017749  0.25688371     a
3  1.55870831 -0.24669188     a
4  0.07050839 -0.34754260     a
5  0.12928774 -0.95161857     a
6  1.71506499 -0.04502772     a

字符串
我想通过一个(或多个)分组列构建一组独立的模型:

# store models by group in a list
models <- list()
for(i in letters[1:10]) {
  models[[paste0("mdl_",i)]] = lm(y ~ x, dat %>% filter(group == i))
}

names(models)
 [1] "mdl_a" "mdl_b" "mdl_c" "mdl_d" "mdl_e" "mdl_f" "mdl_g" "mdl_h" "mdl_i" "mdl_j"


我可以通过多种方式将模型预测值(拟合值)添加到原始数据框中,这种方式很方便:

# add model predictions (fitted values) column to original data frame
dat <- dat %>%
  group_by(group) %>%
  mutate(fits = lm(y ~ x)$fitted.values)

# verify prediction from stored models and fitted values column match 
# to within a 10-decimal tolerance: 
for(i in letters[1:10]) {
  tmp <- dat %>%
    filter(group == i) %>%
    select(group, x, y, fits)
  tmp$stored_fit = predict(models[[paste0("mdl_",i)]], tmp)
  print(paste("mdl", i, "results match:", all(round(tmp$stored_fit,10) == round(tmp$fits,10))))
}
[1] "mdl a results match: TRUE"
[1] "mdl b results match: TRUE"
[1] "mdl c results match: TRUE"
[1] "mdl d results match: TRUE"
[1] "mdl e results match: TRUE"
[1] "mdl f results match: TRUE"
[1] "mdl g results match: TRUE"
[1] "mdl h results match: TRUE"
[1] "mdl i results match: TRUE"
[1] "mdl j results match: TRUE"


所有这些步骤都在其他问题中讨论过,比如this one
现在,我想在一个新的data.frame上生成这些模型的预测,并将这些预测作为一个列添加到data. frame中。
以下是我尝试过的几件事:

# fake up some new grouped data: 
set.seed(456)
dat2 <- data.frame(x = rnorm(100), 
                   y = rnorm(100), 
                   group = rep(x=letters[1:10],each=10))

方法一(应用):

tmp <- dat2 %>%
  group_by(group) %>%
  nest() # %>%
  # mutate(fits = map())

fits = as.data.frame(apply(X = tmp, MARGIN=1, FUN = function(X) predict(models[[paste0("mdl_",X$group)]], X$data)))
names(fits) = tmp$group
fits <- fits %>% 
  pivot_longer(cols = everything(), names_to = "group.fits") %>% 
  arrange(group.fits)

tmp <- tmp %>%
  unnest(cols = c(data)) %>%
  bind_cols(fits)


感觉很容易出错很不优雅

方法二(for loop,base r):

tmp$fits = NA
for(g in unique(tmp$group)) {
  tmp[tmp$group==g,]$fits = predict(models[[paste0("mdl_",g)]], tmp[tmp$group==g,])
}
tmp


这没有什么特别的错误,除了循环在较大的数据集上是出了名的慢。

方法三(嵌套/Map):

我以为下面这样的东西会起作用,但我在语法上有问题...

dat2 %>%
  group_by(group) %>%
  nest() %>%
  mutate(fits = map(.f = predict(models[[paste0("mdl_",group)]]), data))

mutate(fits = map(.x = data, 
                    .f = predict(models[[paste0("mdl_",group)]],
                                 .x)))

我正在寻找方法3的路线沿着的某个地方的答案--理想情况下,所有这些都在一组dplyr命令中。

deyfvvtc

deyfvvtc1#

选项一:purrr::map2

要沿着方法3,您应该使用map2()来预测每个模型和数据。

dat2 %>%
  nest(.by = group) %>% # .by: {tidyr} >= v1.3.0
  mutate(fits = map2(group, data, ~ predict(models[[paste0("mdl_", .x)]], .y))) %>%
  unnest(c(data, fits))

字符串

选项二:rowwise

您也可以用rowwise()替换map2(),并用list()包围预测值。

dat2 %>%
  nest(.by = group) %>% # .by: {tidyr} >= v1.3.0
  rowwise() %>%
  mutate(fits = list(predict(models[[paste0("mdl_", group)]], data))) %>%
  unnest(c(data, fits))

选项三:group_modify

你甚至不需要nest/unnest{tidyr}。只需利用dplyr::group_modify()

dat2 %>%
  group_by(group) %>%
  group_modify(~ {
    .x %>% mutate(fits = predict(models[[paste0("mdl_", .y$group)]], .x))
  }) %>%
  ungroup()


所有方法返回相同的输出:

# # A tibble: 100 × 4
#    group      x       y   fits
#    <chr>  <dbl>   <dbl>  <dbl>
#  1 a     -1.34   0.118  -0.677
#  2 a      0.622  0.870  -0.287
#  3 a      0.801 -0.0919 -0.252
#  4 a     -1.39   0.0689 -0.686
#  5 a     -0.714 -1.68   -0.552
#  6 a     -0.324  1.12   -0.475
#  7 a      0.691 -1.35   -0.274
#  8 a      0.251 -0.537  -0.361
#  9 a      1.01  -0.370  -0.211
# 10 a      0.573  0.354  -0.297
# # ℹ 90 more rows
# # ℹ Use `print(n = ...)` to see more rows

基准测试

bench::mark(
  `purrr::map2` = {
    dat2 %>%
      nest(.by = group) %>%
      mutate(fits = map2(group, data, ~ predict(models[[paste0("mdl_", .x)]], .y))) %>%
      unnest(c(data, fits))
  }, `dplyr::rowwise` = {
    dat2 %>%
      nest(.by = group) %>%
      rowwise() %>%
      mutate(fits = list(predict(models[[paste0("mdl_", group)]], data))) %>%
      unnest(c(data, fits))
  }, `dplyr::group_modify` = {
    dat2 %>%
      group_by(group) %>%
      group_modify(~ {
        .x %>% mutate(fits = predict(models[[paste0("mdl_", .y$group)]], .x))
      }) %>%
      ungroup()
  },
  iterations = 100, min_time = Inf
)

# # A tibble: 3 × 13
#   expression            min median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
#   <bch:expr>         <bch:> <bch:>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
# 1 purrr::map2        16.7ms 17.3ms      57.0    36.4KB     3.00    95     5      1.67s
# 2 dplyr::rowwise     19.9ms 20.2ms      48.7    40.7KB     3.11    94     6      1.93s
# 3 dplyr::group_modi… 33.1ms 34.3ms      28.8     186KB     3.20    90    10      3.13s

v1uwarro

v1uwarro2#

对于dplyr 1.0.9,purrr 0.3.4:

dat2 %>%
  group_by(group) %>%
  nest() %>%
  ungroup() %>%
  mutate(fits = map2(.x = group, 
                     .y = data, 
                     ~ predict(models[[paste0("mdl_", .x)]], .y))) %>%
  unnest(c(data, fits))

字符串

dat2 %>%
  nest_by(group) %>%
  rowwise() %>%
  mutate(fits = list(predict(models[[paste0("mdl_", group)]], data))) %>%
  unnest(c(data, fits))

相关问题