multiROC

Calculating and Visualizing ROC and PR Curves Across Multi-Class Classifications

Receiver operating characteristic (ROC) and precision-recall (PR) curves are widely used to compare binary classifiers in many fields. However, many real-world problems involve more than two classes (e.g., the tumor, node, and metastasis staging system for cancer), and these require an evaluation strategy for multiclass classifiers. This package fills that gap: it calculates multiclass ROC-AUC and PR-AUC with confidence intervals and generates publication-quality figures of multiclass ROC and PR curves.

A user-friendly website is available at https://metabolomics.cc.hawaii.edu/software/multiROC/.

Please cite our paper once it is published: (Submitted).

Install multiROC from GitHub:

install.packages('devtools')
require(devtools)
install_github("WandeRum/multiROC")
require(multiROC)

Install multiROC from CRAN:

install.packages('multiROC')
require(multiROC)

This demo compares random forest and multinomial logistic regression on the iris dataset.

require(multiROC)
data(iris)
head(iris)
3.2 60% training data and 40% testing data
set.seed(123456)
total_number <- nrow(iris)
train_idx <- sample(total_number, round(total_number*0.6))
train_df <- iris[train_idx, ]
test_df <- iris[-train_idx, ]
# Random forest with 100 trees; keep the predicted class probabilities
rf_res <- randomForest::randomForest(Species~., data = train_df, ntree = 100)
rf_pred <- predict(rf_res, test_df, type = 'prob')
rf_pred <- data.frame(rf_pred)
colnames(rf_pred) <- paste0(colnames(rf_pred), "_pred_RF")
3.4 Multinomial logistic regression
mn_res <- nnet::multinom(Species ~., data = train_df)
mn_pred <- predict(mn_res, test_df, type = 'prob')
mn_pred <- data.frame(mn_pred)
colnames(mn_pred) <- paste0(colnames(mn_pred), "_pred_MN")
3.5 Merge true labels and predicted values
# One-hot encode the true labels in base R: one 0/1 column per species
true_label <- data.frame(model.matrix(~ Species - 1, data = test_df))
colnames(true_label) <- paste0(levels(test_df$Species), "_true")
# Columns follow the <group>_true / <group>_pred_<method> naming used by test_data
final_df <- cbind(true_label, rf_pred, mn_pred)
roc_res <- multi_roc(final_df, force_diag = TRUE)
pr_res <- multi_pr(final_df, force_diag = TRUE)
plot_roc_df <- plot_roc_data(roc_res)
plot_pr_df <- plot_pr_data(pr_res)

require(ggplot2)
ggplot(plot_roc_df, aes(x = 1 - Specificity, y = Sensitivity)) +
  geom_path(aes(color = Group, linetype = Method), linewidth = 1.5) +
  # Chance-level diagonal from (0, 0) to (1, 1)
  annotate("segment", x = 0, y = 0, xend = 1, yend = 1,
           colour = 'grey', linetype = 'dotdash') +
  theme_bw() +
  theme(plot.title = element_text(hjust = 0.5),
        legend.justification = c(1, 0), legend.position = c(.95, .05),
        legend.title = element_blank(),
        legend.background = element_rect(fill = NULL, linewidth = 0.5,
                                         linetype = "solid", colour = "black"))

ggplot(plot_pr_df, aes(x = Recall, y = Precision)) +
  geom_path(aes(color = Group, linetype = Method), linewidth = 1.5) +
  theme_bw() +
  theme(plot.title = element_text(hjust = 0.5),
        legend.justification = c(1, 0), legend.position = c(.95, .05),
        legend.title = element_blank(),
        legend.background = element_rect(fill = NULL, linewidth = 0.5,
                                         linetype = "solid", colour = "black"))

library(multiROC)
data(test_data)
head(test_data)
##   G1_true G2_true G3_true G1_pred_m1 G2_pred_m1 G3_pred_m1 G1_pred_m2 G2_pred_m2 G3_pred_m2
## 1       1       0       0  0.8566867  0.1169520 0.02636133  0.4371601  0.1443851 0.41845482
## 2       1       0       0  0.8011788  0.1505448 0.04827643  0.3075236  0.5930025 0.09947397
## 3       1       0       0  0.8473608  0.1229815 0.02965766  0.3046363  0.4101367 0.28522698
## 4       1       0       0  0.8157730  0.1422322 0.04199482  0.2378494  0.5566147 0.20553591
## 5       1       0       0  0.8069553  0.1472971 0.04574766  0.4067347  0.2355822 0.35768312
## 6       1       0       0  0.6894488  0.2033285 0.10722271  0.1063048  0.4800507 0.41364450

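This example dataset contains true labels and predicted scores for three groups (G1, G2, G3) and two classifiers (m1, m2). Each <group>_true column is a 0/1 indicator of the true group, and each <group>_pred_<method> column holds the score that the method assigns to that group.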
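Because the results are labelled by these names (e.g., m1.G1 in section 4.2.1), the naming convention matters. As a rough illustration, the sketch below splits the names and checks that every method has a prediction column for every group. The helper parse_multiroc_columns is hypothetical and not part of multiROC; it only encodes the layout visible in test_data.

```r
# Hypothetical helper, not part of multiROC: parse "<group>_true" and
# "<group>_pred_<method>" column names and check that the layout is complete.
parse_multiroc_columns <- function(df) {
  cols <- colnames(df)
  true_cols <- grep("_true$", cols, value = TRUE)
  pred_cols <- grep("_pred_", cols, value = TRUE)
  groups  <- sub("_true$", "", true_cols)
  methods <- unique(sub(".*_pred_", "", pred_cols))
  for (m in methods) {
    # every group needs a prediction column for this method
    missing_cols <- setdiff(paste0(groups, "_pred_", m), pred_cols)
    if (length(missing_cols) > 0) {
      stop("method ", m, " lacks columns: ", paste(missing_cols, collapse = ", "))
    }
  }
  list(groups = groups, methods = methods)
}

# A two-row toy frame in the test_data layout (values are arbitrary)
toy <- data.frame(G1_true = c(1, 0), G2_true = c(0, 1),
                  G1_pred_m1 = c(0.9, 0.2), G2_pred_m1 = c(0.1, 0.8))
res <- parse_multiroc_columns(toy)
str(res)
```

Dropping a prediction column (for example, toy[, -4]) makes the helper stop with a message naming the missing column.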

4.1 multi_roc and multi_pr functions
roc_res <- multi_roc(test_data, force_diag = TRUE)
pr_res <- multi_pr(test_data, force_diag = TRUE)

The functions multi_roc and multi_pr are the core functions for calculating multiclass ROC-AUC and PR-AUC.
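For intuition about what is computed for each group, here is a minimal base-R sketch of one binary (group vs. the rest) ROC curve and PR curve with trapezoidal areas. It is not the package's code: binary_curves and trapezoid are hypothetical names, every distinct score is used as a threshold, and anchoring the ROC curve at (0, 0) and the PR curve at (recall 0, precision 1) is an assumed convention that multiROC may not share, so its numbers can differ.

```r
# Sketch only (not multiROC's code): ROC and PR curve points for one
# binary problem, e.g. "G1 vs. the rest", plus trapezoidal areas.
binary_curves <- function(labels, scores) {
  thr <- sort(unique(scores), decreasing = TRUE)  # one threshold per distinct score
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  tp <- sapply(thr, function(t) sum(scores >= t & labels == 1))
  fp <- sapply(thr, function(t) sum(scores >= t & labels == 0))
  data.frame(threshold   = thr,
             sensitivity = tp / n_pos,      # true positive rate, also recall
             specificity = 1 - fp / n_neg,  # 1 - false positive rate
             precision   = tp / (tp + fp))
}

# Area under the piecewise-linear curve through (x, y); x must be non-decreasing
trapezoid <- function(x, y) {
  sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)
}

labels <- c(1, 1, 0, 1, 0, 0)
scores <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.2)
cur <- binary_curves(labels, scores)

# ROC: start at (FPR = 0, TPR = 0); the lowest threshold already gives (1, 1)
roc_auc <- trapezoid(c(0, 1 - cur$specificity), c(0, cur$sensitivity))
# PR: start at (recall = 0, precision = 1), an assumed convention
pr_auc <- trapezoid(c(0, cur$sensitivity), c(1, cur$precision))
c(roc_auc = roc_auc, pr_auc = pr_auc)
```

With these toy values the ROC-AUC is 8/9 (the share of positive/negative pairs ranked correctly) and the PR-AUC is 65/72 under the stated anchoring.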

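Arguments of multi_roc and multi_pr:

  data: a data frame of true labels (<group>_true columns) and predicted scores (<group>_pred_<method> columns), in the same layout as test_data;
  force_diag: logical; if TRUE, the curves are forced to pass through (0, 0) and (1, 1).
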
Outputs of multi_roc:

Outputs of multi_pr:

4.2.1 List of AUC results
##     m1.G1     m1.G2     m1.G3  m1.macro  m1.micro     m2.G1     m2.G2     m2.G3  m2.macro  m2.micro
## 0.7233607 0.5276190 0.9751462 0.7420609 0.8824221 0.3237705 0.3723810 0.4020468 0.3665670 0.4174394

This list shows the following AUC information:

  1. AUC of G1 vs. the rest in classifier m1;
  2. AUC of G2 vs. the rest in classifier m1;
  3. AUC of G3 vs. the rest in classifier m1;
  4. macro-averaged AUC of classifier m1;
  5. micro-averaged AUC of classifier m1;
  6. AUC of G1 vs. the rest in classifier m2;
  7. AUC of G2 vs. the rest in classifier m2;
  8. AUC of G3 vs. the rest in classifier m2;
  9. macro-averaged AUC of classifier m2;
  10. micro-averaged AUC of classifier m2.
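To make the macro and micro entries concrete, the sketch below contrasts the two averaging schemes in base R on a toy classifier. It is not multiROC's implementation: macro is taken here as the unweighted mean of the per-group AUCs, and micro as a single AUC over all pooled sample-by-group indicator/score pairs. multiROC's own averaging may differ in detail (for example, averaging curves rather than AUC values), so do not expect identical numbers. auc_rank is a hypothetical helper using the rank (Mann-Whitney) form of the AUC.

```r
# Sketch contrasting macro and micro averaging for one classifier (not multiROC's code)
auc_rank <- function(labels, scores) {
  # Mann-Whitney form of the AUC; average ranks handle ties
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  r <- rank(scores)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# One-hot truth (rows = samples, columns = groups G1, G2, G3) and predicted scores
truth <- rbind(c(1, 0, 0), c(1, 0, 0), c(0, 1, 0),
               c(0, 1, 0), c(0, 0, 1), c(0, 0, 1))
score <- rbind(c(0.8, 0.1, 0.1), c(0.4, 0.3, 0.3), c(0.2, 0.7, 0.1),
               c(0.3, 0.25, 0.45), c(0.1, 0.2, 0.7), c(0.2, 0.2, 0.6))

# Macro: compute each group-vs-rest AUC, then take the unweighted mean
per_group <- sapply(seq_len(ncol(truth)), function(k) auc_rank(truth[, k], score[, k]))
macro_auc <- mean(per_group)

# Micro: pool all sample-by-group indicator/score pairs into one binary problem
micro_auc <- auc_rank(as.vector(truth), as.vector(score))
c(macro = macro_auc, micro = micro_auc)
```

Here the per-group AUCs are 1, 0.875, and 1, so the macro value is 23/24 (about 0.958), while pooling gives a micro value of 67/72 (about 0.931).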
roc_ci_res <- roc_ci(test_data, conf = 0.95, type = 'bca', R = 100, index = 4)
pr_ci_res <- pr_ci(test_data, conf = 0.95, type = 'bca', R = 100, index = 4)
## BOOTSTRAP CONFIDENCE INTERVAL CALCULATIONS
## Based on 100 bootstrap replicates
## 
## CALL : 
## boot.ci(boot.out = res_boot, conf = conf, type = type, index = index)
## 
## Intervals : 
## Level       BCa          
## 95%   ( 0.649,  0.861 )  
## Calculations and Intervals on Original Scale
## Some BCa intervals may be unstable


## BOOTSTRAP CONFIDENCE INTERVAL CALCULATIONS
## Based on 100 bootstrap replicates
## 
## CALL : 
## boot.ci(boot.out = res_boot, conf = conf, type = type, index = index)
## 
## Intervals : 
## Level       BCa          
## 95%   ( 0.4242,  0.6547 )  
## Calculations and Intervals on Original Scale
## Warning : BCa Intervals used Extreme Quantiles
## Some BCa intervals may be unstable

The functions roc_ci and pr_ci calculate bootstrap confidence intervals for multiclass ROC-AUC and PR-AUC.

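Arguments of roc_ci and pr_ci:

  data: a data frame in the same layout as test_data;
  conf: confidence level of the interval (e.g., 0.95);
  type: type of bootstrap interval, passed to boot.ci (e.g., 'bca' or 'basic');
  R: number of bootstrap replicates;
  index: position of the AUC of interest in the list of AUC results (section 4.2.1).
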
Here, as an example, we set index = 4 to calculate the 95% CI of the macro-averaged AUC of classifier m1, based on 100 bootstrap replicates.
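The printed CALL shows that these functions hand the bootstrap results to boot.ci. For intuition, the sketch below builds a simpler percentile bootstrap interval for a single AUC in base R. It is an illustration under stated assumptions, not multiROC's procedure: it resamples within each class (so every replicate contains positives and negatives), uses the percentile method rather than BCa, and the names auc_rank and bootstrap_auc_ci are hypothetical.

```r
# Sketch (not multiROC's code): percentile bootstrap CI for one AUC
auc_rank <- function(labels, scores) {
  # Mann-Whitney form of the AUC; average ranks handle ties
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  r <- rank(scores)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

bootstrap_auc_ci <- function(labels, scores, R = 100, conf = 0.95) {
  pos <- which(labels == 1)
  neg <- which(labels == 0)
  boot_auc <- replicate(R, {
    # Resample within each class so every replicate has positives and negatives
    idx <- c(pos[sample.int(length(pos), replace = TRUE)],
             neg[sample.int(length(neg), replace = TRUE)])
    auc_rank(labels[idx], scores[idx])
  })
  alpha <- (1 - conf) / 2
  ci <- quantile(boot_auc, c(alpha, 1 - alpha), names = FALSE)
  c(auc = auc_rank(labels, scores), lower = ci[1], upper = ci[2])
}

set.seed(123)
labels <- rep(c(1, 0), each = 30)
scores <- c(rnorm(30, mean = 1), rnorm(30, mean = 0))  # simulated scores
res <- bootstrap_auc_ci(labels, scores, R = 100, conf = 0.95)
res
```

As with R = 100 in the examples here, a small number of replicates gives a quick but noisy interval; BCa intervals in particular can be unstable with few replicates, as the printed warnings note.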

roc_auc_with_ci_res <- roc_auc_with_ci(test_data, conf= 0.95, type='bca', R = 100)
roc_auc_with_ci_res
pr_auc_with_ci_res <- pr_auc_with_ci(test_data, conf= 0.95, type='bca', R = 100)
pr_auc_with_ci_res
##         Var       AUC  lower CI higher CI
## 1     m1.G1 0.7233607 0.5555556 0.8406849
## 2     m1.G2 0.5276190 0.3141490 0.6991112
## 3     m1.G3 0.9751462 0.9156118 0.9969245
## 4  m1.macro 0.7420420 0.6162905 0.8469999
## 5  m1.micro 0.8824221 0.7942627 0.9318145
## 6     m2.G1 0.3237705 0.2039669 0.4888149
## 7     m2.G2 0.3723810 0.2322795 0.5126755
## 8     m2.G3 0.4020468 0.2266853 0.5944778
## 9  m2.macro 0.3665214 0.2796633 0.4970809
## 10 m2.micro 0.4174394 0.3449170 0.5036827


##         Var        AUC   lower CI higher CI
## 1     m1.G1 0.81104090 0.69394085 0.9219133
## 2     m1.G2 0.18898097 0.09755974 0.3404294
## 3     m1.G3 0.67479141 0.34789377 0.9871966
## 4  m1.macro 0.54968868 0.43208030 0.6663112
## 5  m1.micro 0.75125213 0.61651803 0.8635095
## 6     m2.G1 0.60633468 0.49548427 0.7510816
## 7     m2.G2 0.13298786 0.06840618 0.2092138
## 8     m2.G3 0.08150105 0.03931745 0.1388464
## 9  m2.macro 0.27320882 0.23863708 0.3009619
## 10 m2.micro 0.27540471 0.23380391 0.3087388

The functions roc_auc_with_ci and pr_auc_with_ci calculate confidence intervals for all multiclass ROC-AUCs and PR-AUCs at once and return a data frame with the AUC, lower CI, and higher CI of every method and group.

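Arguments of roc_auc_with_ci and pr_auc_with_ci:

  data: a data frame in the same layout as test_data;
  conf, type, R: as in roc_ci and pr_ci (no index is needed, since an interval is computed for every AUC).
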
5.1 Convert the ROC and PR results to a ggplot2-friendly format
plot_roc_df <- plot_roc_data(roc_res)
plot_pr_df <- plot_pr_data(pr_res)
ggplot(plot_roc_df, aes(x = 1 - Specificity, y = Sensitivity)) +
  geom_path(aes(color = Group, linetype = Method)) +
  annotate("segment", x = 0, y = 0, xend = 1, yend = 1,
           colour = 'grey', linetype = 'dotdash') +
  theme_bw() +
  theme(plot.title = element_text(hjust = 0.5),
        legend.justification = c(1, 0), legend.position = c(.95, .05),
        legend.title = element_blank(),
        legend.background = element_rect(fill = NULL, linewidth = 0.5,
                                         linetype = "solid", colour = "black"))
ggplot(plot_pr_df, aes(x = Recall, y = Precision)) +
  geom_path(aes(color = Group, linetype = Method), linewidth = 1.5) +
  theme_bw() +
  theme(plot.title = element_text(hjust = 0.5),
        legend.justification = c(1, 0), legend.position = c(.95, .05),
        legend.title = element_blank(),
        legend.background = element_rect(fill = NULL, linewidth = 0.5,
                                         linetype = "solid", colour = "black"))
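A "ggplot2-friendly format" here means one long data frame with one row per curve point, tagged by group and method, so that a single geom_path() call can draw every curve. The sketch below shows that reshaping in base R. The nested-list input and the helper curves_to_long are hypothetical and need not match what plot_roc_data actually receives; only the column names (Specificity, Sensitivity, Group, Method) are taken from the aes() mappings in the plotting code.

```r
# Hypothetical sketch: flatten curves[[method]][[group]] =
# list(Specificity = ..., Sensitivity = ...) into one long data frame.
curves_to_long <- function(curves) {
  rows <- list()
  for (m in names(curves)) {
    for (g in names(curves[[m]])) {
      pts <- curves[[m]][[g]]
      # one row per curve point; Group and Method are recycled to every row
      rows[[length(rows) + 1]] <- data.frame(
        Specificity = pts$Specificity,
        Sensitivity = pts$Sensitivity,
        Group = g,
        Method = m
      )
    }
  }
  do.call(rbind, rows)
}

# Toy curves for two methods (values are arbitrary)
curves <- list(
  m1 = list(G1 = list(Specificity = c(1, 0.5, 0), Sensitivity = c(0, 0.8, 1)),
            G2 = list(Specificity = c(1, 0), Sensitivity = c(0, 1))),
  m2 = list(G1 = list(Specificity = c(1, 0), Sensitivity = c(0, 1)))
)
long_df <- curves_to_long(curves)
long_df
```

Mapping color = Group and linetype = Method on such a frame is what lets one plot overlay all groups and classifiers.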

To send comments, suggestions, or bug reports about multiROC, please email Runmin Wei (wander1021@gmail.com) or open an issue at https://github.com/WandeRum/multiROC/issues

GPL-3

