Introducing the Receiver Operating Characteristic Curve

Jose M Sallan 2021-04-30 8 min read

In a previous post I presented a small binary classification problem, consisting of determining whether an observation of the iris dataset belongs to the versicolor species. Remember that the target variable must be a factor, with the first level associated with the positive case:

library(tidymodels) # provides dplyr, recipes, parsnip, workflows and yardstick

iris <- iris |> 
  mutate(is_versicolor = ifelse(Species == "versicolor", "versicolor", "not_versicolor")) |>
  mutate(is_versicolor = factor(is_versicolor, levels = c("versicolor", "not_versicolor")))
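
We can check that the positive class comes first in the factor levels (a quick sketch, not in the original post):

# "versicolor" should appear first, as yardstick treats the first level as the event by default
levels(iris$is_versicolor)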

The recipe for this problem is straightforward: remove the Species column and correlated predictors, and normalize (center and scale) the remaining predictors:

iris_recipe <- iris |>
  recipe(is_versicolor ~.) |>
  step_rm(Species) |>
  step_corr(all_predictors()) |>
  step_normalize(all_predictors(), -all_outcomes())
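
To see what the recipe actually does to the predictors, a minimal sketch (assuming the tidymodels packages are loaded) is to prep it and bake it on the training data:

iris_recipe |>
  prep() |>                 # estimate the correlation filter and the normalization parameters
  bake(new_data = NULL) |>  # apply the steps to the original data
  glimpse()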

Here I will introduce the receiver operating characteristic (ROC) curve in the context of performance assessment of classification problems. We’ll assess the performance of two predictive models with the ROC curve, and then examine the performance of a bad classifier with the same metric.

Logistic Regression

Let’s begin fitting a logistic regression model:

lr <- logistic_reg() |>
  set_engine("glm")

iris_lr_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(lr)

After fitting the workflow, with predict(iris) we predict the class of each observation:

iris_pred_lr <- iris_lr_wf |>
  fit(iris) |>
  predict(iris)
iris_pred_lr
## # A tibble: 150 × 1
##    .pred_class   
##    <fct>         
##  1 not_versicolor
##  2 not_versicolor
##  3 not_versicolor
##  4 not_versicolor
##  5 not_versicolor
##  6 not_versicolor
##  7 not_versicolor
##  8 not_versicolor
##  9 not_versicolor
## 10 not_versicolor
## # … with 140 more rows

In logistic regression, the class of each observation is assigned from its probability of belonging to each class. This probability is what we use to construct the ROC curve, and we can obtain it using predict with type = "prob":

iris_prob_lr <- iris_lr_wf |>
  fit(iris) |>
  predict(iris, type = "prob")
iris_prob_lr
## # A tibble: 150 × 2
##    .pred_versicolor .pred_not_versicolor
##               <dbl>                <dbl>
##  1           0.0977                0.902
##  2           0.364                 0.636
##  3           0.195                 0.805
##  4           0.245                 0.755
##  5           0.0658                0.934
##  6           0.0259                0.974
##  7           0.0914                0.909
##  8           0.127                 0.873
##  9           0.367                 0.633
## 10           0.303                 0.697
## # … with 140 more rows

Let’s bind the outcomes with the original dataset:

iris_lr <- bind_cols(iris, iris_pred_lr, iris_prob_lr)
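
A quick look (a sketch, not in the original post) at the columns we have just added, which we will use to build the ROC curve:

iris_lr |>
  select(is_versicolor, .pred_class, .pred_versicolor) |>
  head()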

Random Forests

Let’s replicate the same workflow with random forests:

rf <- rand_forest(mode = "classification", trees = 100) |>
  set_engine("ranger")

iris_rf_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(rf)

iris_pred_rf <- iris_rf_wf |>
  fit(iris) |>
  predict(iris)

iris_prob_rf <- iris_rf_wf |>
  fit(iris) |>
  predict(iris, type = "prob")

iris_rf <- bind_cols(iris, iris_pred_rf, iris_prob_rf)

The ROC Curve

To classify an observation as positive or negative from predicted probabilities we need a threshold value, a value between 0 and 1 so that:

  • If an observation’s probability of belonging to the positive class is below the threshold, the observation is labelled as negative.
  • If an observation’s probability of belonging to the positive class is equal to or above the threshold, the observation is labelled as positive (see the sketch after this list).
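
As an illustration, here is a minimal sketch (not part of the original workflow) of applying a custom threshold to the logistic regression probabilities obtained above. The 0.7 threshold and the pred_at_0.7 column name are arbitrary choices for this example:

iris_lr |>
  mutate(pred_at_0.7 = ifelse(.pred_versicolor >= 0.7, "versicolor", "not_versicolor")) |> # label with the custom threshold
  count(is_versicolor, pred_at_0.7) # cross-tabulate true and predicted classes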

In a ROC curve we plot two performance metrics:

  • On the x axis we plot 1 - specificity, where specificity is the proportion of negative observations classified correctly, so 1 - specificity is the false positive rate.
  • On the y axis we plot sensitivity, the proportion of positive observations classified correctly, that is, the true positive rate.

In the ROC curve, each point gives the values of sensitivity and 1 - specificity obtained with a specific threshold value. The ROC curve allows us to assess the balance between sensitivity and specificity of a classifier.
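
To make these definitions concrete, here is a minimal sketch (not in the original post) that computes both metrics for the logistic regression predictions at the usual 0.5 threshold:

iris_lr |>
  mutate(pred_class_0.5 = ifelse(.pred_versicolor >= 0.5, "versicolor", "not_versicolor")) |>
  summarise(
    sensitivity = mean(pred_class_0.5[is_versicolor == "versicolor"] == "versicolor"),        # true positive rate
    specificity = mean(pred_class_0.5[is_versicolor == "not_versicolor"] == "not_versicolor") # true negative rate
  )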

We will always have two points in every ROC curve:

  • If threshold = 1, all observations are classified as negative. Sensitivity is equal to 0 and specificity is equal to 1. This is the (0, 0) point of the ROC curve.
  • If threshold = 0, all observations are classified as positive. Sensitivity is equal to 1 and specificity is equal to 0. This is the (1, 1) point of the ROC curve.

A perfect classifier has sensitivity and specificity equal to one, which corresponds with the (0, 1) point in the ROC curve.

roc_df <- data.frame(x_roc = c(0,0,1), 
                     y_roc = c(0,1,1),
                     threshold = c("threshold=1", "perfect class.", "threshold=0"))

ggplot(roc_df, aes(x_roc,y_roc, label=threshold)) + 
  geom_point(size=2) +
  xlim(-0.1, 1.1) +
  ylim(0, 1.1) +
  geom_text(vjust=0, nudge_y = 0.05) +
  geom_abline(slope=1, intercept = 0, linetype = "dashed", linewidth = 0.2) +
  theme_bw() +
  labs(x="1-specificity", y="sensitivity")

This curve can be summarized into a single number, the area under the ROC curve (AUC). For a reasonable classifier the AUC takes a value between 0.5 and 1, with 1 being the value of a perfect classifier. A classifier with an AUC below 0.5 can be transformed into a better predictor just by reversing the outcome (classifying positives as negatives and vice versa).

Let’s see the performance of logistic regression and random forests using the ROC curve.

ROC Curve for Logistic Regression

Here is a plot of the ROC curve for logistic regression. The points in red correspond to threshold values around 0.5, the usual cutoff value when predicting classes.

roc_curve_lr <- iris_lr |>
  roc_curve(truth = is_versicolor, estimate = .pred_versicolor) |> 
  mutate(x_roc = 1 - specificity, y_roc = sensitivity)

sens_spec_lr <- roc_curve_lr |> 
  filter(.threshold > 0.48, .threshold < 0.52)

roc_curve_lr |>
  ggplot(aes(x_roc, y_roc)) +
  geom_point(size=0.5)  +
  geom_point(data=sens_spec_lr, aes(x_roc, y_roc), size = 3, color = "red") +
  geom_line(linewidth = 0.3) +
  geom_abline(slope=1, intercept = 0, linetype = "dashed", linewidth = 0.2) +
  theme_bw() +
  labs(x="1-specificity", y="sensitivity")

We can obtain the ROC curve faster in tidymodels by doing:

iris_lr |>
  roc_curve(truth = is_versicolor, estimate = .pred_versicolor) |> 
  autoplot()

The AUC is equal to:

iris_lr |>
  roc_auc(is_versicolor, .pred_versicolor)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.809
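
As a sanity check, here is a minimal sketch (not in the original post) that approximates the same value by trapezoidal integration of the roc_curve_lr points computed above; it should essentially match the roc_auc() result:

roc_curve_lr |>
  arrange(x_roc, y_roc) |> # order the points along the x axis
  summarise(auc = sum(diff(x_roc) * (head(y_roc, -1) + tail(y_roc, -1)) / 2)) # trapezoidal rule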

Examining the ROC curve and the AUC we observe that the performance of logistic regression is relatively poor. Let’s compare it with random forests.

ROC Curve for Random Forests

This is the ROC curve for random forests. This curve is close to a perfect classifier:

roc_curve_rf <- iris_rf |>
  roc_curve(truth = is_versicolor, estimate = .pred_versicolor) |> 
  mutate(x_roc = 1 - specificity, y_roc = sensitivity) 

sens_spec_rf <- roc_curve_rf |> 
  filter(.threshold > 0.4, .threshold < 0.6)

roc_curve_rf |>
  ggplot(aes(x_roc, y_roc)) +
  geom_point(size=0.5)  +
  geom_point(data=sens_spec_rf, aes(x_roc, y_roc), size = 3, color = "red") +
  geom_line(linewidth = 0.3) +
  geom_abline(slope=1, intercept = 0, linetype = "dashed", linewidth = 0.2) +
  theme_bw() +
  labs(x="1-specificity", y="sensitivity")

The AUC is much closer to one now:

iris_rf |>
  roc_auc(is_versicolor, .pred_versicolor)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.996

A Bad ROC Curve

Let’s see the behaviour of a very bad classifier. We define each observation’s probability of being versicolor as a random value between zero and one:

iris_awful <- iris |>
  mutate(.pred_versicolor = runif(150, 0, 1))

If we build the ROC curve for this classifier, we observe that it lies around the diagonal:

roc_iris_awful <- iris_awful |>
  roc_curve(truth = is_versicolor, estimate = .pred_versicolor) |> 
  mutate(x_roc = 1 - specificity, y_roc = sensitivity) 

sens_spec_awful <- roc_iris_awful |>
  filter(.threshold > 0.48, .threshold < .52)
  

roc_iris_awful |>
  ggplot(aes(x_roc, y_roc)) +
  geom_point(size=0.5) +
  geom_point(data=sens_spec_awful, aes(x_roc, y_roc), size=3, color="red") +
  geom_line(linewidth = 0.3, color = "#808080") +
  geom_abline(slope=1, intercept = 0, linetype = "dashed", linewidth = 0.2) +
  theme_bw() +
  labs(x="1-specificity", y="sensitivity")

… and that the AUC is close to 0.5, the value expected from a classifier that guesses at random:

iris_awful |>
  roc_auc(is_versicolor, .pred_versicolor)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.571

Assessing Model Performance with ROC and AUC

With a bad classifier, if we want to increase sensitivity (the ability to predict the positive class) we need to reduce specificity (the ability to predict the negative class). That’s why the ROC curve of a bad classifier lies close to the diagonal. A good classifier yields good values of sensitivity and specificity at once. This means that its ROC curve will be closer to the upper left corner, and the area under the curve (AUC) will be closer to one. The ROC curve and the AUC measure how well a classifier is able to separate positive and negative cases. In fact, the ROC curve was developed by radar engineers in World War II to measure the ability of a radar receiver to detect objects depending on the detection threshold.
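
To wrap up, a minimal sketch (not in the original post) that collects the three AUC values computed above in a single table:

bind_rows(
  iris_lr    |> roc_auc(is_versicolor, .pred_versicolor) |> mutate(model = "logistic regression"),
  iris_rf    |> roc_auc(is_versicolor, .pred_versicolor) |> mutate(model = "random forest"),
  iris_awful |> roc_auc(is_versicolor, .pred_versicolor) |> mutate(model = "random probabilities")
)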

Session Info

## R version 4.2.2 Patched (2022-11-10 r83330)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Linux Mint 21.1
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0
## 
## locale:
##  [1] LC_CTYPE=es_ES.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=es_ES.UTF-8        LC_COLLATE=es_ES.UTF-8    
##  [5] LC_MONETARY=es_ES.UTF-8    LC_MESSAGES=es_ES.UTF-8   
##  [7] LC_PAPER=es_ES.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=es_ES.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] yardstick_1.1.0    workflowsets_1.0.0 workflows_1.1.2    tune_1.0.1        
##  [5] tidyr_1.3.0        tibble_3.1.8       rsample_1.1.1      recipes_1.0.4     
##  [9] purrr_1.0.1        parsnip_1.0.3      modeldata_1.1.0    infer_1.0.4       
## [13] ggplot2_3.4.0      dplyr_1.0.10       dials_1.1.0        scales_1.2.1      
## [17] broom_1.0.3        tidymodels_1.0.0  
## 
## loaded via a namespace (and not attached):
##  [1] sass_0.4.5          foreach_1.5.2       jsonlite_1.8.4     
##  [4] splines_4.2.2       prodlim_2019.11.13  bslib_0.4.2        
##  [7] assertthat_0.2.1    highr_0.10          GPfit_1.0-8        
## [10] yaml_2.3.7          globals_0.16.2      ipred_0.9-13       
## [13] pillar_1.8.1        backports_1.4.1     lattice_0.20-45    
## [16] glue_1.6.2          digest_0.6.31       hardhat_1.2.0      
## [19] colorspace_2.1-0    htmltools_0.5.4     Matrix_1.5-1       
## [22] timeDate_4022.108   pkgconfig_2.0.3     lhs_1.1.6          
## [25] DiceDesign_1.9      listenv_0.9.0       bookdown_0.32      
## [28] ranger_0.14.1       gower_1.0.1         lava_1.7.1         
## [31] timechange_0.2.0    farver_2.1.1        generics_0.1.3     
## [34] ellipsis_0.3.2      cachem_1.0.6        withr_2.5.0        
## [37] furrr_0.3.1         nnet_7.3-18         cli_3.6.0          
## [40] survival_3.5-3      magrittr_2.0.3      evaluate_0.20      
## [43] future_1.31.0       fansi_1.0.4         parallelly_1.34.0  
## [46] MASS_7.3-58.2       class_7.3-21        blogdown_1.16      
## [49] tools_4.2.2         lifecycle_1.0.3     munsell_0.5.0      
## [52] compiler_4.2.2      jquerylib_0.1.4     rlang_1.0.6        
## [55] grid_4.2.2          iterators_1.0.14    rstudioapi_0.14    
## [58] labeling_0.4.2      rmarkdown_2.20      gtable_0.3.1       
## [61] codetools_0.2-19    DBI_1.1.3           R6_2.5.1           
## [64] lubridate_1.9.1     knitr_1.42          fastmap_1.1.0      
## [67] future.apply_1.10.0 utf8_1.2.2          parallel_4.2.2     
## [70] Rcpp_1.0.10         vctrs_0.5.2         rpart_4.1.19       
## [73] tidyselect_1.2.0    xfun_0.36

Updated at 2023-03-20 10:41:17.