Log transformations in numerical prediction

Jose M Sallan 2022-06-10 11 min read

When we need to predict a continuous variable, it is common to apply a log transformation when the variable is right-skewed. Here I will discuss why and how we should do so in the context of a prediction job with the tidymodels workflow, and how we can obtain a predicted value of the original variable.

Apart from tidymodels I am using corrplot to display a correlation matrix, patchwork to present two plots together and kableExtra to present HTML tables.

library(tidymodels)
library(corrplot)
library(patchwork)
library(kableExtra)

I will use the diamonds dataset included with ggplot2, which contains the prices and other attributes of almost 54,000 diamonds:

diamonds %>% glimpse()
## Rows: 53,940
## Columns: 10
## $ carat   <dbl> 0.23, 0.21, 0.23, 0.29, 0.31, 0.24, 0.24, 0.26, 0.22, 0.23, 0.…
## $ cut     <ord> Ideal, Premium, Good, Premium, Good, Very Good, Very Good, Ver…
## $ color   <ord> E, E, E, I, J, J, I, H, E, H, J, J, F, J, E, E, I, J, J, J, I,…
## $ clarity <ord> SI2, SI1, VS1, VS2, SI2, VVS2, VVS1, SI1, VS2, VS1, SI1, VS1, …
## $ depth   <dbl> 61.5, 59.8, 56.9, 62.4, 63.3, 62.8, 62.3, 61.9, 65.1, 59.4, 64…
## $ table   <dbl> 55, 61, 65, 58, 58, 57, 57, 55, 61, 61, 55, 56, 61, 54, 62, 58…
## $ price   <int> 326, 326, 327, 334, 335, 336, 336, 337, 337, 338, 339, 340, 34…
## $ x       <dbl> 3.95, 3.89, 4.05, 4.20, 4.34, 3.94, 3.95, 4.07, 3.87, 4.00, 4.…
## $ y       <dbl> 3.98, 3.84, 4.07, 4.23, 4.35, 3.96, 3.98, 4.11, 3.78, 4.05, 4.…
## $ z       <dbl> 2.43, 2.31, 2.31, 2.63, 2.75, 2.48, 2.47, 2.53, 2.49, 2.39, 2.…

Our job will be predicting the price of diamonds from their attributes.

Exploratory analysis

Let’s start by examining the correlations between the numerical variables.

cor_diamonds <- diamonds %>% 
  select(where(is.numeric)) %>% 
  cor()

Let’s use corrplot to examine the correlations:

corrplot.mixed(cor_diamonds, order = "hclust")

We observe that carat, x, y and z are highly correlated with price and among themselves. In line with other analyses, I have kept carat and discarded x, y and z.
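
As a quick check (a minimal sketch using the cor_diamonds matrix computed above), we can inspect the correlations of price with these variables directly:

# correlations of price with carat and the three diamond dimensions
cor_diamonds["price", c("carat", "x", "y", "z")]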

Let’s go now to the distribution of price:

ggplot(diamonds, aes(price)) +
  geom_histogram(bins = 30) +
  theme_minimal() +
  labs(title = "Distribution of price")

The histogram shows that price has a right-skewed distribution, as it has a long tail on the right-hand side. This means that some diamonds have a very large price compared with the rest of the distribution. Something similar usually happens with housing prices, income and other variables relevant in economics.
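
A quick numerical check of that skew (a minimal sketch) is to compare the mean and the median of price: the long right tail pulls the mean well above the median.

# in a right-skewed distribution, the mean is larger than the median
diamonds %>%
  summarise(mean_price = mean(price),
            median_price = median(price))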

Let’s examine the distribution of the base-10 logarithm of price.

ggplot(diamonds, aes(price)) +
  geom_histogram(bins = 30) +
  theme_minimal() +
  scale_x_log10() +
  labs(title = "Distribution of logarithm of price")

We now observe a roughly normal, although bimodal, distribution. In general, the log of a right-skewed distribution looks closer to a normal distribution. This means that we will obtain better estimates in a regression model using the log transformation, as the residuals of the model will tend to be normal.
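
To see how much the transformation reduces the skew, we can compute the moment coefficient of skewness for price and for its base-10 logarithm. This is a sketch with a hand-rolled skewness function, to avoid adding dependencies:

# moment coefficient of skewness: mean of cubed deviations over sd cubed
skewness <- function(x) mean((x - mean(x))^3) / sd(x)^3

skewness(diamonds$price)        # clearly positive: right-skewed
skewness(log10(diamonds$price)) # much closer to zero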

Let’s add a log_price variable to the dataset, which we will use later in the prediction job.

diamonds <- diamonds %>%
  mutate(log_price = log10(price))

To close the exploratory data analysis, let’s examine the relationship between carat and price.

ggplot(diamonds, aes(carat, price)) +
  geom_point() +
  geom_smooth() +
  scale_x_log10() +
  theme_minimal() +
  labs(title = "Relationship between carat and price")

This relationship is nonlinear, so the model could benefit from the addition of a quadratic term.
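
As a quick check of that intuition (a minimal sketch outside the tidymodels workflow), we can compare the fit of a linear and a quadratic term for carat with plain lm():

# compare R squared of a linear and a quadratic fit of price on carat
fit_linear    <- lm(price ~ carat, data = diamonds)
fit_quadratic <- lm(price ~ poly(carat, 2), data = diamonds)

summary(fit_linear)$r.squared
summary(fit_quadratic)$r.squared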

Predicting the price

To assess the benefit of using the log variable, I will first predict the original price variable. The workflow starts by splitting the data into train and test sets, stratifying by the target variable. I am doing this to obtain a similar distribution of price in the train and test sets.

set.seed(1111)
split_price <- initial_split(diamonds, prop = 0.9, strata = price)
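
To verify that the stratified split yields similar distributions of price (a quick sketch), we can compare a few quantiles of price in the train and test sets:

# price quantiles should be similar in both sets thanks to stratification
quantile(training(split_price)$price, probs = c(0.25, 0.50, 0.75))
quantile(testing(split_price)$price, probs = c(0.25, 0.50, 0.75))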

The rec_price recipe includes the transformations for this model:

  • removing variables with step_rm.
  • adding a quadratic term to carat with step_poly.
  • transforming categorical variables to a set of dummies generated with one hot encoding with step_dummy.
  • removing near zero variance variables with step_nzv.

rec_price <- training(split_price) %>%
  recipe(price ~ .) %>%
  step_rm(log_price, x, y, z) %>%
  step_poly(carat, degree = 2) %>%
  step_dummy(all_nominal(), one_hot = TRUE) %>%
  step_nzv(all_predictors()) %>%
  step_rm(color_1)

Here is the result of the recipe:

rec_price %>% prep() %>% juice() %>% glimpse()
## Rows: 48,545
## Columns: 21
## $ depth        <dbl> 59.8, 56.9, 62.4, 63.3, 62.8, 62.3, 61.9, 65.1, 59.4, 64.…
## $ table        <dbl> 61, 65, 58, 58, 57, 57, 55, 61, 61, 55, 56, 61, 54, 62, 5…
## $ price        <int> 326, 327, 334, 335, 336, 336, 337, 337, 338, 339, 340, 34…
## $ carat_poly_1 <dbl> -0.005631217, -0.005439625, -0.004864846, -0.004673253, -…
## $ carat_poly_2 <dbl> 0.006116938, 0.005640942, 0.004280652, 0.003849787, 0.005…
## $ cut_2        <dbl> 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, …
## $ cut_3        <dbl> 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, …
## $ cut_4        <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, …
## $ cut_5        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, …
## $ color_2      <dbl> 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, …
## $ color_3      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ color_4      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ color_5      <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ color_6      <dbl> 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, …
## $ color_7      <dbl> 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, …
## $ clarity_2    <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, …
## $ clarity_3    <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, …
## $ clarity_4    <dbl> 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ clarity_5    <dbl> 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ clarity_6    <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
## $ clarity_7    <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …

Let’s define a straightforward linear model…

lm <- linear_reg(mode = "regression") %>%
  set_engine("lm")

… and fit the model and the recipe with a workflow:

model_price <- workflow() %>%
  add_recipe(rec_price) %>%
  add_model(lm) %>%
  fit(training(split_price))

We obtain the predicted values of price in pred_price …

pred_price <- model_price %>%
  predict(training(split_price)) %>%
  bind_cols(training(split_price))

… and here is the plot of real versus predicted values:

ggplot(pred_price, aes(price, .pred)) +
  geom_point() +
  geom_abline(slope = 1, intercept = 0, color = "red") +
  theme_minimal() +
  labs(title = "Real versus predicted values (price model)", x = "price", y = "pred. price")

The result of this model is problematic, as it returns negative prices for some observations. We can avoid that with a log transformation: the logarithm is only defined for positive values, but it can take any real value, so a prediction of the log always translates back into a positive price.
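
To quantify the issue (a minimal sketch), we can count how many observations of the train set get a negative predicted price:

# number of observations with a negative predicted price
pred_price %>%
  summarise(n_negative = sum(.pred < 0))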

Finally, let’s define the metrics we will be using in this job:

metric_prediction <- metric_set(rsq, mae, rmse)

The values of the metrics on the train set for this model are:

metrics_price <- pred_price %>%
  metric_prediction(truth = price, estimate = .pred) %>%
  mutate(.model = "price") %>%
  select(-.estimator)

metrics_price
## # A tibble: 3 × 3
##   .metric .estimate .model
##   <chr>       <dbl> <chr> 
## 1 rsq         0.899 price 
## 2 mae       814.    price 
## 3 rmse     1264.    price

Predicting with log

Let’s do the same job, but predicting the log of the target variable log_price. Let’s start with a new split, now with log_price as the stratifying variable.

set.seed(1111)
split_log <- initial_split(diamonds, prop = 0.9, strata = log_price) 

The recipe is quite similar to the previous model:

rec_log <- training(split_log) %>%
  recipe(log_price ~ .) %>%
  step_rm(price, x, y, z) %>%
  step_poly(carat, degree = 2) %>%
  step_dummy(all_nominal(), one_hot = TRUE) %>%
  step_nzv(all_predictors()) %>%
  step_rm(color_1)

I am fitting the same lm model, but with the rec_log recipe:

model_log <- workflow() %>%
  add_recipe(rec_log) %>%
  add_model(lm) %>%
  fit(training(split_log))

And then we can predict the train set. Note that we will obtain a prediction of the log of price in the .pred column.

pred_log <- model_log %>%
  predict(training(split_log)) %>%
  bind_cols(training(split_log))

If we want a predicted value of price, we need to undo the log transformation in the predicted variable. We do that in .pred_price.

pred_log <- pred_log %>%
  mutate(.pred_price = 10^.pred)
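
As a quick sanity check (a sketch), we can confirm that the back-transformed predictions are always positive, since 10 raised to any real number is positive:

# the back-transformed predictions cannot be negative
any(pred_log$.pred_price <= 0)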

Here are the plots of real versus predicted values in the train set, for the log of price and for price itself. We can see that we no longer obtain negative values of price.

log <- ggplot(pred_log, aes(log_price, .pred)) +
  geom_point() +
  geom_abline(slope = 1, intercept = 0, color = "red") +
  theme_minimal() +
  labs(title = "Log real versus log predicted (log model)", x = "price (log)", y = "pred. price (log)")

real <- ggplot(pred_log, aes(price, .pred_price)) +
  geom_point() +
  geom_abline(slope = 1, intercept = 0, color = "red") +
  theme_minimal() +
  labs(title = "Real versus predicted (log model)", x = "price", y = "pred. price")

log + real

Let’s examine the metrics for the log of price:

metrics_log <- pred_log %>%
  metric_prediction(truth = log_price, estimate = .pred) %>%
  mutate(.model = "log") %>%
  select(-.estimator)

metrics_log
## # A tibble: 3 × 3
##   .metric .estimate .model
##   <chr>       <dbl> <chr> 
## 1 rsq        0.965  log   
## 2 mae        0.0582 log   
## 3 rmse       0.0822 log

And then let’s examine the metrics for price itself:

metrics_log_price <- pred_log %>%
  metric_prediction(truth = price, estimate = .pred_price) %>%
  mutate(.model = "price with log") %>%
  select(-.estimator)

metrics_log_price
## # A tibble: 3 × 3
##   .metric .estimate .model        
##   <chr>       <dbl> <chr>         
## 1 rsq         0.924 price with log
## 2 mae       528.    price with log
## 3 rmse     1123.    price with log

Comparing metrics

Let’s compare how well each model performs by putting the metrics of both models side by side. We can do that because tidymodels returns the metrics as data frames that can be combined easily.

bind_rows(metrics_price, metrics_log_price) %>%
  mutate(.estimate = format(.estimate, digits = 3)) %>%
  pivot_wider(id_cols = ".metric", names_from = ".model", values_from = ".estimate") %>%
  kbl(align = c("l", "r", "r")) %>%
  kable_styling(bootstrap_options = c("striped", "condensed", "responsive", "hover"), full_width = FALSE)

.metric      price   price with log
rsq          0.899            0.924
mae        813.541          528.007
rmse      1263.751         1123.007

For the training set, the log model performs better than the price model. Let’s see how well each model performs in the test set.

pred_price_test <- model_price %>%
  predict(testing(split_price)) %>%
  bind_cols(testing(split_price))

pred_log_test <- model_log %>%
  predict(testing(split_log)) %>%
  bind_cols(testing(split_log)) %>%
  mutate(.pred_price = 10^.pred)

metrics_price_test <- pred_price_test %>%
  metric_prediction(truth = price, estimate = .pred) %>%
  mutate(.model = "price") %>%
  select(-.estimator)

metrics_log_test <- pred_log_test %>%
  metric_prediction(truth = price, estimate = .pred_price) %>%
  mutate(.model = "price with log") %>%
  select(-.estimator)

bind_rows(metrics_price_test, metrics_log_test) %>%
  mutate(.estimate = format(.estimate, digits = 3)) %>%
  pivot_wider(id_cols = ".metric", names_from = ".model", values_from = ".estimate") %>%
  kbl(align = c("l", "r", "r")) %>%
  kable_styling(bootstrap_options = c("striped", "condensed", "responsive", "hover"), full_width = FALSE)

.metric      price   price with log
rsq          0.908            0.923
mae        811.561          537.652
rmse      1226.000         1167.739

Again, the model with the log transformation performs better in the test set.

Session info

## R version 4.2.0 (2022-04-22)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Linux Mint 19.2
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/openblas/libblas.so.3
## LAPACK: /usr/lib/x86_64-linux-gnu/libopenblasp-r0.2.20.so
## 
## locale:
##  [1] LC_CTYPE=es_ES.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=es_ES.UTF-8        LC_COLLATE=es_ES.UTF-8    
##  [5] LC_MONETARY=es_ES.UTF-8    LC_MESSAGES=es_ES.UTF-8   
##  [7] LC_PAPER=es_ES.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=es_ES.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] kableExtra_1.3.4   patchwork_1.1.1    corrplot_0.92      yardstick_0.0.9   
##  [5] workflowsets_0.2.1 workflows_0.2.6    tune_0.2.0         tidyr_1.2.0       
##  [9] tibble_3.1.6       rsample_0.1.1      recipes_0.2.0      purrr_0.3.4       
## [13] parsnip_0.2.1      modeldata_0.1.1    infer_1.0.0        ggplot2_3.3.5     
## [17] dplyr_1.0.9        dials_0.1.1        scales_1.2.0       broom_0.8.0       
## [21] tidymodels_0.2.0  
## 
## loaded via a namespace (and not attached):
##  [1] nlme_3.1-157       lubridate_1.8.0    webshot_0.5.3      httr_1.4.2        
##  [5] DiceDesign_1.9     tools_4.2.0        backports_1.4.1    bslib_0.3.1       
##  [9] utf8_1.2.2         R6_2.5.1           rpart_4.1.16       mgcv_1.8-40       
## [13] DBI_1.1.2          colorspace_2.0-3   nnet_7.3-17        withr_2.5.0       
## [17] tidyselect_1.1.2   compiler_4.2.0     rvest_1.0.2        cli_3.3.0         
## [21] xml2_1.3.3         labeling_0.4.2     bookdown_0.26      sass_0.4.1        
## [25] systemfonts_1.0.4  stringr_1.4.0      digest_0.6.29      svglite_2.1.0     
## [29] rmarkdown_2.14     pkgconfig_2.0.3    htmltools_0.5.2    parallelly_1.31.1 
## [33] lhs_1.1.5          highr_0.9          fastmap_1.1.0      rlang_1.0.2       
## [37] rstudioapi_0.13    farver_2.1.0       jquerylib_0.1.4    generics_0.1.2    
## [41] jsonlite_1.8.0     magrittr_2.0.3     Matrix_1.4-1       Rcpp_1.0.8.3      
## [45] munsell_0.5.0      fansi_1.0.3        GPfit_1.0-8        lifecycle_1.0.1   
## [49] furrr_0.3.0        stringi_1.7.6      pROC_1.18.0        yaml_2.3.5        
## [53] MASS_7.3-57        plyr_1.8.7         grid_4.2.0         parallel_4.2.0    
## [57] listenv_0.8.0      crayon_1.5.1       lattice_0.20-45    splines_4.2.0     
## [61] knitr_1.39         pillar_1.7.0       future.apply_1.9.0 codetools_0.2-18  
## [65] glue_1.6.2         evaluate_0.15      blogdown_1.9       vctrs_0.4.1       
## [69] foreach_1.5.2      gtable_0.3.0       future_1.25.0      assertthat_0.2.1  
## [73] xfun_0.30          gower_1.0.0        prodlim_2019.11.13 viridisLite_0.4.0 
## [77] class_7.3-20       survival_3.2-13    timeDate_3043.102  iterators_1.0.14  
## [81] hardhat_0.2.0      lava_1.6.10        globals_0.14.0     ellipsis_0.3.2    
## [85] ipred_0.9-12