R语言 ggplot的stat_function()给出错误结果

sh7euo9m  于 2023-02-06  发布在  其他
关注(0)|答案(1)|浏览(134)

我生成了一些数据来进行回归分析:

library(tidyverse)
library(nnet)

# Simulate incomes per transport mode ------------
set.seed(100)

# Mean income of each mode; the order here fixes the order of the random
# draws, so it must stay helicopter, car, bus, bike.
mode_means <- c(helicopter = 35, car = 30, bus = 25, bike = 20)
incomes_by_mode <- lapply(mode_means, function(m) rnorm(20, mean = m, sd = 3))

# Long format: one row per simulated observation, with its `mode` label and
# its `income` value.
transportation_data <- as.data.frame(incomes_by_mode) %>%
  pivot_longer(cols = everything(), names_to = "mode", values_to = "income")

# Fit the multinomial logit of mode on income ----
transportation_regression <- multinom(mode ~ income, data = transportation_data)

到目前为止,一切顺利,现在我想用stat_function绘制回归结果(基于收入选择某种交通方式的概率):

# Split the fitted coefficients by linear index: elements 1-3 are used as the
# intercepts and elements 4-6 as the income slopes of the three non-reference
# modes (bus, car, helicopter, going by the layer labels below).
ins <- coef(transportation_regression)[1:3]
betas <- coef(transportation_regression)[4:6]

# NOTE: this is the version that reproduces the problem discussed in this
# post; it is intentionally left unchanged. Each `fun` is written as if `x`
# were a single income value. When `x` is a vector, `ins + betas * x` recycles
# the length-3 vectors along `x` (the source of the "longer object length is
# not a multiple of shorter object length" warning), and `sum()` collapses
# everything to one scalar, so the denominator is the same for every `x`.
transportation_data %>%
  ggplot(aes(x = income))+
  # Reference mode (bike): 1 / (1 + sum of the other modes' odds).
  stat_function(fun = function(x) { 1 / (1 + sum(exp(ins + betas * x))) }, aes(color = "bike"))+
  # Non-reference modes: own odds over the same denominator.
  stat_function(fun = function(x) { exp(ins[1] + betas[1] * x) / (1 + sum(exp(ins + betas * x))) }, aes(color = "bus"))+
  stat_function(fun = function(x) { exp(ins[2] + betas[2] * x) / (1 + sum(exp(ins + betas * x))) }, aes(color = "car"))+
  stat_function(fun = function(x) { exp(ins[3] + betas[3] * x) / (1 + sum(exp(ins + betas * x))) }, aes(color = "helicopter"))


我得到了下面这个输出,它显然是错误的;同时还出现了一条警告 `Warning: longer object length is not a multiple of shorter object length`,我不明白它是什么意思。
当我使用相同的函数,但首先预测数据点时,一切都很好:

# Income grid on which the probabilities are evaluated.
income <- seq(0, 50, 0.1)

# One row per income value; columns are bike, bus, car, helicopter.
# The function is applied to one scalar income at a time, so `ins + betas * x`
# has length 3 and `sum()` adds up exactly the three non-reference odds.
result <- unname(t(vapply(
  income,
  function(x) {
    odds <- exp(ins + betas * x)
    c(1, odds[1], odds[2], odds[3]) / (1 + sum(odds))
  },
  numeric(4)
)))

# Reshape to long format (default column names V1..V4) and draw one line per
# mode.
cbind(income, as.data.frame(result)) %>%
  pivot_longer(cols = V1:V4) %>%
  ggplot(aes(x = income, y = value, color = name)) +
  geom_line()

为什么ggplot中的stat_function()不工作?

e4eetjau

e4eetjau1#

我认为这只是对函数工作原理的误解,下面是一个使用stat_function()生成正确结果的示例:

library(tidyverse)
library(nnet)

# Simulate incomes per transport mode ------------
set.seed(100)

# Mean income of each mode; the order here fixes the order of the random
# draws, so it must stay helicopter, car, bus, bike.
mode_means <- c(helicopter = 35, car = 30, bus = 25, bike = 20)
incomes_by_mode <- lapply(mode_means, function(m) rnorm(20, mean = m, sd = 3))

# Long format: one row per simulated observation, with its `mode` label and
# its `income` value.
transportation_data <- as.data.frame(incomes_by_mode) %>%
  pivot_longer(cols = everything(), names_to = "mode", values_to = "income")

# Fit the multinomial logit of mode on income ----
transportation_regression <- multinom(mode ~ income, data = transportation_data)
#> # weights:  12 (6 variable)
#> initial  value 110.903549 
#> iter  10 value 48.674542
#> iter  20 value 46.980349
#> iter  30 value 46.766625
#> iter  40 value 46.734782
#> iter  50 value 46.732249
#> final  value 46.732163 
#> converged

# Intercepts (elements 1-3) and income slopes (elements 4-6) of the three
# non-reference modes, taken from the coefficient matrix by linear index.
fit_coefs <- coef(transportation_regression)
ins <- fit_coefs[1:3]
betas <- fit_coefs[4:6]

# Odds of non-reference mode `k` relative to the reference mode. `ins[k]` and
# `betas[k]` are scalars, so this is elementwise in a vector `x`.
mode_odds <- function(k, x) {
  exp(ins[k] + betas[k] * x)
}

# Denominator shared by all four probabilities, also elementwise in `x`.
total_odds <- function(x) {
  1 + mode_odds(1, x) + mode_odds(2, x) + mode_odds(3, x)
}

transportation_data %>%
  ggplot(aes(x = income)) +
  stat_function(
    fun = function(x) 1 / total_odds(x),
    aes(color = "bike")
  ) +
  stat_function(
    fun = function(x) mode_odds(1, x) / total_odds(x),
    aes(color = "bus")
  ) +
  stat_function(
    fun = function(x) mode_odds(2, x) / total_odds(x),
    aes(color = "car")
  ) +
  stat_function(
    fun = function(x) mode_odds(3, x) / total_odds(x),
    aes(color = "helicopter")
  )

原来的代码有两个问题。以第一个 stat_function() 调用为例:

# Excerpt: the first ("bike") layer from the question, quoted verbatim to
# illustrate the two problems described below.
stat_function(fun = function(x) { 
         1 / (1 + sum(exp(ins + betas * x))) }, 
     aes(color = "bike"))

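我们期望 sum(exp(ins + betas * x)) 对每一个 x 都等于 exp(ins[1] + betas[1] * x) + exp(ins[2] + betas[2] * x) + exp(ins[3] + betas[3] * x),但实际并非如此。stat_function() 传给函数的 x 是一个向量(默认 101 个点),于是 R 会循环使用(recycle)ins 和 betas,把它们扩展成与 x 一样长的向量,然后逐元素地将 betas 乘以 x,再加上 ins。由于 x 的长度不是 3 的整数倍,就出现了 "longer object length is not a multiple of shorter object length" 这条警告。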
另一个问题是exp(ins ...)周围的sum()不是对行求和,而是对输出的所有行和列求和,得到一个标量值。
您还可以使用矩阵计算使其更加通用:

# Coefficient matrix of the fitted model: one row per non-reference mode
# (bus, car, helicopter); column 1 is the intercept, column 2 the income slope.
b <- coef(transportation_regression)

# Predicted probability of one transport mode at every income in `x`.
#
# Numerators and denominator are both derived from the coefficient matrix, so
# the snippet no longer depends on the separate `ins` / `betas` vectors and
# works for any number of non-reference modes.
#
# x         numeric vector of incomes (stat_function() passes a whole vector)
# mode      name of the mode whose probability is returned
# coefs     coefficient matrix as returned by coef() on the multinom fit
# reference name of the reference mode, i.e. the one without a row in `coefs`
#
# Returns a numeric vector of the same length as `x`.
mode_prob <- function(x, mode, coefs = b, reference = "bike") {
  # Linear predictors: one row per income, one column per non-reference mode.
  eta <- cbind(1, x) %*% t(coefs)
  # Odds relative to the reference mode, whose own odds are exp(0) = 1.
  odds <- cbind(1, exp(eta))
  colnames(odds) <- c(reference, rownames(coefs))
  # rowSums() sums within each income value, never across incomes.
  odds[, mode] / rowSums(odds)
}

transportation_data %>%
  ggplot(aes(x = income)) +
  stat_function(
    fun = function(x) mode_prob(x, "bike"),
    aes(color = "bike")
  ) +
  stat_function(
    fun = function(x) mode_prob(x, "bus"),
    aes(color = "bus")
  ) +
  stat_function(
    fun = function(x) mode_prob(x, "car"),
    aes(color = "car")
  ) +
  stat_function(
    fun = function(x) mode_prob(x, "helicopter"),
    aes(color = "helicopter")
  )

创建于 2023-02-04,使用 reprex 包 (v2.0.1)

相关问题