
Linear Regression using Tidymodels

In the previous section, we saw how to fit a model with lm(), make predictions with predict(), and retrieve key model statistics using summary().

We will now use the linear_reg() function, which is part of the parsnip package from the tidymodels framework. The advantage of this function is not only its coverage of regularized models via engines such as glmnet, but also its support for different back-ends including Stan, keras, and Spark (see also ?linear_reg).

linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars)

Note that the coefficients are identical to those of the vanilla lm() fit:

lm(mpg ~ ., data = mtcars)
## 
## Call:
## lm(formula = mpg ~ ., data = mtcars)
## 
## Coefficients:
## (Intercept)          cyl         disp           hp         drat  
##    12.30337     -0.11144      0.01334     -0.02148      0.78711  
##          wt         qsec           vs           am         gear  
##    -3.71530      0.82104      0.31776      2.52023      0.65541  
##        carb  
##    -0.19942

Additionally, we can inspect the call that parsnip performs under the hood using translate():

linear_reg() %>%
  set_engine("lm") %>% 
  translate()
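As mentioned above, the engine can be swapped without touching the rest of the specification. A minimal sketch, assuming the glmnet package is installed (the glmnet engine requires the penalty to be set; the values 0.1 and mixture = 1 are arbitrary choices for illustration):

```r
library(parsnip)

# Sketch: the same model specification with a regularized back-end.
# penalty is the lambda value; mixture = 1 corresponds to the lasso.
fit_glmnet <- linear_reg(penalty = 0.1, mixture = 1) %>%
  set_engine("glmnet") %>%
  fit(mpg ~ ., data = mtcars)
fit_glmnet
```

Everything downstream (predict(), translate(), etc.) works the same way regardless of the chosen engine.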

Replacement for model formula

Instead of a model formula, you can also specify the predictors and the outcome as data frames using fit_xy():

linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(x = select(mtcars, -mpg), y = select(mtcars, mpg))
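The outcome passed to fit_xy() does not have to be a one-column data frame; a plain vector works as well. A short sketch of the same fit without dplyr:

```r
library(parsnip)

# Sketch: predictors as a data frame, outcome as a plain vector
linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(x = mtcars[, setdiff(names(mtcars), "mpg")], y = mtcars$mpg)
```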

Parsnip predict

Compared to the vanilla lm() fitting function, the linear_reg() fit looks more complicated at first. In return, linear_reg() supports different engines/back-ends, which can be switched via the argument of set_engine().

Predictions from the fitted model can again be obtained with predict():

mod <- linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars)
predict(mod, new_data = mtcars[1:5, ])

The tidymodels predict() function differs from the base predict() in two ways:

  1. The output is a tibble rather than a vector.
  2. The argument for new data is named new_data (snake_case) and is not optional - it must be specified explicitly.
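Both points can be seen in a short sketch: the result is a one-column tibble named .pred with one row per row of new_data:

```r
library(parsnip)

mod <- linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars)

preds <- predict(mod, new_data = mtcars)
names(preds)  # ".pred"
nrow(preds)   # 32, one prediction per row of new_data
```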

Summary

parsnip does not provide a model summary as detailed as the vanilla summary() output on a regression model object. The plain-text model output includes the model name, the fitting time, the function call object, and the coefficients:

mod <- linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars)
mod
## parsnip model object
## 
## Fit time:  2ms 
## 
## Call:
## stats::lm(formula = formula, data = data)
## 
## Coefficients:
## (Intercept)          cyl         disp           hp         drat  
##    12.30337     -0.11144      0.01334     -0.02148      0.78711  
##          wt         qsec           vs           am         gear  
##    -3.71530      0.82104      0.31776      2.52023      0.65541  
##        carb  
##    -0.19942

However, it does not include the statistical significance of the coefficients, the variance explained \(R^2\), etc. To also retrieve these model-specific statistics, you can directly access the underlying model object, which parsnip stores in a list data structure, as can be seen here:

str(mod, max.level = 1)
## List of 5
##  $ lvl    : NULL
##  $ spec   :List of 5
##   ..- attr(*, "class")= chr [1:2] "linear_reg" "model_spec"
##  $ fit    :List of 12
##   ..- attr(*, "class")= chr "lm"
##  $ preproc:List of 1
##  $ elapsed: 'proc_time' Named num [1:5] 0.002 0 0.002 0 0
##   ..- attr(*, "names")= chr [1:5] "user.self" "sys.self" "elapsed" "user.child" ...
##  - attr(*, "class")= chr [1:2] "_lm" "model_fit"

To access the fitted model object via $fit and compute the standard summary, you can use:

summary(mod$fit)
## 
## Call:
## stats::lm(formula = formula, data = data)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -3.4506 -1.6044 -0.1196  1.2193  4.6271 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)  
## (Intercept) 12.30337   18.71788   0.657   0.5181  
## cyl         -0.11144    1.04502  -0.107   0.9161  
## disp         0.01334    0.01786   0.747   0.4635  
## hp          -0.02148    0.02177  -0.987   0.3350  
## drat         0.78711    1.63537   0.481   0.6353  
## wt          -3.71530    1.89441  -1.961   0.0633 .
## qsec         0.82104    0.73084   1.123   0.2739  
## vs           0.31776    2.10451   0.151   0.8814  
## am           2.52023    2.05665   1.225   0.2340  
## gear         0.65541    1.49326   0.439   0.6652  
## carb        -0.19942    0.82875  -0.241   0.8122  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 2.65 on 21 degrees of freedom
## Multiple R-squared:  0.869,  Adjusted R-squared:  0.8066 
## F-statistic: 13.93 on 10 and 21 DF,  p-value: 3.793e-07
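Alternatively, the broom package (also part of tidymodels) returns these statistics as tidy tibbles, which is often easier to work with programmatically than the printed summary() output. A sketch, assuming broom is installed:

```r
library(parsnip)
library(broom)

mod <- linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars)

tidy(mod)         # coefficients with std. errors, t statistics, p-values
glance(mod$fit)   # model-level statistics such as R^2 and adjusted R^2
```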

Plot

To finally plot the model predictions in a tidy fashion together with the input data using ggplot2, we now focus on univariate regression models with a single predictor. The steps are:

  1. As a predictor we choose the weight wt and fit a new model mod_wt.
  2. Bind the original data with the predicted tibble.
  3. Create a scatter plot including the regression line.
library(dplyr)
library(tibble)
library(ggplot2)

mod_wt <- linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ wt, data = mtcars)

mtcars %>% 
  bind_cols(predict(mod_wt, new_data = mtcars)) %>%
  ggplot(aes(x = wt)) + 
  geom_line(aes(y = .pred)) + 
  geom_point(aes(y = mpg))