We need the following packages:

library(tidyverse)
library(leaps)
library(glmnet)
library(mlbench)
library(rpart)
library(rpart.plot)
library(visNetwork)
library(ISLR)
library(kernlab)
library(randomForest)
library(plotROC)
library(pROC)
library(ranger)
library(caret)

Trees

Exercise 1 (regression tree)

We consider the following dataset

n <- 50
set.seed(1234)
X <- runif(n)
set.seed(5678)
Y <- 1*X*(X<=0.6)+(-1*X+3.2)*(X>0.6)+rnorm(n,sd=0.1)
data1 <- data.frame(X,Y)
ggplot(data1)+aes(x=X,y=Y)+geom_point()+theme_classic()

  1. Fit a tree to explain \(Y\) by \(X\) (use rpart).
tree <- rpart(Y~X,data=data1)
  2. Draw the tree with prp and rpart.plot.
prp(tree)

rpart.plot(tree)

  3. Write the tree regression function.

The regression function \(m(x)\) is estimated by

\[\widehat m(x)=0.31 \mathbf{1}_{x<0.58}+2.4\mathbf{1}_{x\geq 0.58}.\]
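
As a quick check, we can code this rule by hand and compare it with predict (a small sketch; the cut-point 0.58 and the leaf means 0.31 and 2.4 are the rounded values displayed by rpart.plot, so the agreement is only approximate):

m_hat <- function(x) ifelse(x < 0.58, 0.31, 2.4)  # rounded values read off the plot
newd <- data.frame(X = c(0.2, 0.8))
cbind(by_hand = m_hat(newd$X), rpart = predict(tree, newdata = newd))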

  4. Add to the graph of question 1 the partition defined by the tree and the predictions.

We have a partition into two terminal nodes. This partition is defined by the question: “Is \(X\) less than 0.58?”

df1 <- data.frame(x=c(0,0.58),xend=c(0.58,1),y=c(0.31,2.41),yend=c(0.31,2.41))
ggplot(data1)+aes(x=X,y=Y)+geom_point()+geom_vline(xintercept = 0.58,size=1,color="blue")+
  geom_segment(data=df1,aes(x=x,y=y,xend=xend,yend=yend),size=1,color="red")+theme_classic()

Exercise 2 (classification tree)

We consider the following dataset

n <- 50
set.seed(12345)
X1 <- runif(n)
set.seed(5678)
X2 <- runif(n)
Y <- rep(0,n)
set.seed(54321)
Y[X1<=0.45] <- rbinom(sum(X1<=0.45),1,0.85)
set.seed(52432)
Y[X1>0.45] <- rbinom(sum(X1>0.45),1,0.15)
data2 <- data.frame(X1,X2,Y)
library(ggplot2)
ggplot(data2)+aes(x=X1,y=X2,color=Y)+geom_point(size=2)+
  scale_x_continuous(name="")+scale_y_continuous(name="")+theme_classic()

  1. Fit a tree to explain \(Y\) by \(X_1\) and \(X_2\). Draw the tree. What happens?
tree <- rpart(Y~.,data=data2)
rpart.plot(tree)

We observe that the tree is a regression tree, not a classification tree. This is because \(Y\) is a numeric variable. We have to convert \(Y\) into a factor.

data2$Y <- as.factor(data2$Y)
tree <- rpart(Y~.,data=data2)
rpart.plot(tree)

Now, it’s OK!

  2. Write the classification rule and the score function induced by the tree.

The classification rule is \[\widehat g(x)=\mathbf{1}_{X_1<0.44}.\]

The score function is \[\widehat S(x)=\widehat P(Y=1|X=x)=0.83\mathbf{1}_{X_1<0.44}+0.07\mathbf{1}_{X_1\geq 0.44}.\]
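
As in the regression case, we can check the score by hand (a sketch; 0.44, 0.83 and 0.07 are the rounded values displayed by rpart.plot):

S_hat <- function(x1) ifelse(x1 < 0.44, 0.83, 0.07)  # rounded values read off the plot
newd <- data.frame(X1 = c(0.2, 0.8), X2 = c(0.5, 0.5))
cbind(by_hand = S_hat(newd$X1),
      rpart = predict(tree, newdata = newd, type = "prob")[, "1"])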

  3. Add to the graph of question 1 the partition defined by the tree.
ggplot(data2)+aes(x=X1,y=X2,color=Y,shape=Y)+geom_point(size=2)+
  theme_classic()+geom_vline(xintercept = 0.44,size=1,color="blue")

Exercise 3 (categorical input)

We consider the following dataset

n <- 100
X <- factor(rep(c("A","B","C","D"),n))
Y <- rep(0,length(X))
set.seed(1234)
Y[X=="A"] <- rbinom(sum(X=="A"),1,0.9)
Y[X=="B"] <- rbinom(sum(X=="B"),1,0.25)
Y[X=="C"] <- rbinom(sum(X=="C"),1,0.8)
Y[X=="D"] <- rbinom(sum(X=="D"),1,0.2)
Y <- as.factor(Y)
data3 <- data.frame(X,Y)
  1. Fit a tree to explain \(Y\) by \(X\).
tree3 <- rpart(Y~.,data=data3)
rpart.plot(tree3)

  2. Explain how the tree is fitted in this context.

The tree is fitted in the same way as for a continuous variable: we consider all binary partitions of the categorical set \(\{A,B,C,D\}\) and choose the best one.
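
As an illustration, the sketch below (assuming the data3 objects above) enumerates the candidate groupings of the levels and computes the weighted Gini impurity of the two children, which is the criterion minimised by rpart for classification:

gini <- function(y) {p <- mean(y == "1"); 2 * p * (1 - p)}
groups <- c(as.list(levels(data3$X)), combn(levels(data3$X), 2, simplify = FALSE))
# each group defines the partition {group} vs {complement};
# two-level groups appear twice (once from each side), which is harmless
impurity <- sapply(groups, function(g) {
  left <- data3$X %in% g
  mean(left) * gini(data3$Y[left]) + mean(!left) * gini(data3$Y[!left])
})
names(impurity) <- sapply(groups, paste, collapse = "")
sort(impurity)
# with these simulation settings, the best partition should be {A,C} vs {B,D}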

Exercise 4 (pruning)

We consider the dataset Carseats from the ISLR package.

data(Carseats)

The problem is to explain the Sales variable by the other variables.

  1. Fit a tree with rpart to explain Sales by the other variables.
tree <- rpart(Sales~.,data=Carseats)
rpart.plot(tree)

  2. Explain the output of the printcp command.
printcp(tree)
## 
## Regression tree:
## rpart(formula = Sales ~ ., data = Carseats)
## 
## Variables actually used in tree construction:
## [1] Advertising Age         CompPrice   Income      Population  Price      
## [7] ShelveLoc  
## 
## Root node error: 3182.3/400 = 7.9557
## 
## n= 400 
## 
##          CP nsplit rel error  xerror     xstd
## 1  0.250510      0   1.00000 1.00492 0.069530
## 2  0.105073      1   0.74949 0.75877 0.051613
## 3  0.051121      2   0.64442 0.68283 0.046333
## 4  0.045671      3   0.59330 0.64240 0.043550
## 5  0.033592      4   0.54763 0.60051 0.041716
## 6  0.024063      5   0.51403 0.58903 0.039691
## 7  0.023948      6   0.48997 0.59472 0.039561
## 8  0.022163      7   0.46602 0.58972 0.039539
## 9  0.016043      8   0.44386 0.58329 0.039731
## 10 0.014027      9   0.42782 0.57392 0.038516
## 11 0.013145     11   0.39976 0.57780 0.038529
## 12 0.012711     12   0.38662 0.58719 0.038339
## 13 0.012147     13   0.37391 0.58970 0.038419
## 14 0.011888     14   0.36176 0.58850 0.038291
## 15 0.010778     15   0.34987 0.58673 0.038383
## 16 0.010506     16   0.33909 0.57818 0.038886
## 17 0.010000     17   0.32859 0.57320 0.038277

This output gives information on the sequence of 17 nested trees:

  • CP refers to the complexity parameter: the smaller the CP, the larger the tree.
  • nsplit refers to the number of splits of the tree.
  • rel error contains the quadratic risk estimated on the training sample, relative to the root node error (it decreases as the complexity increases; see the small check below).
  • xerror contains the quadratic risk estimated by cross-validation.
  • xstd stands for the estimated standard deviation of the cross-validated quadratic risk.
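
For instance, we can check the rel error column by hand (a sketch, using the tree fitted above): it is the training sum of squares of the pruned tree divided by that of the root node, which is why it starts at 1 and decreases.

sse <- function(fit) sum((Carseats$Sales - predict(fit, newdata = Carseats))^2)
sub.tree  <- prune(tree, cp = tree$cptable[3, "CP"])  # the tree with 2 splits
root.tree <- prune(tree, cp = 1)                      # cp = 1 keeps only the root node
c(by_hand = sse(sub.tree)/sse(root.tree), printed = tree$cptable[3, "rel error"])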

We can obtain a longer sequence if we decrease the default values of cp and minsplit.

tree1 <- rpart(Sales~.,data=Carseats,cp=0.00001,minsplit=2)
printcp(tree1)
## 
## Regression tree:
## rpart(formula = Sales ~ ., data = Carseats, cp = 1e-05, minsplit = 2)
## 
## Variables actually used in tree construction:
##  [1] Advertising Age         CompPrice   Education   Income     
##  [6] Population  Price       ShelveLoc   Urban       US         
## 
## Root node error: 3182.3/400 = 7.9557
## 
## n= 400 
## 
##             CP nsplit  rel error  xerror     xstd
## 1   2.5051e-01      0 1.00000000 1.00455 0.069340
## 2   1.0507e-01      1 0.74948961 0.75617 0.051398
## 3   5.1121e-02      2 0.64441706 0.65640 0.044108
## 4   4.5671e-02      3 0.59329646 0.64608 0.044382
## 5   3.3592e-02      4 0.54762521 0.62575 0.042153
## 6   2.4063e-02      5 0.51403284 0.59178 0.040173
## 7   2.3948e-02      6 0.48997005 0.61073 0.041773
## 8   2.2163e-02      7 0.46602225 0.61050 0.041747
## 9   1.6043e-02      8 0.44385897 0.56163 0.038088
## 10  1.4027e-02      9 0.42781645 0.59766 0.041327
## 11  1.3145e-02     11 0.39976237 0.60609 0.040589
## 12  1.2711e-02     12 0.38661699 0.60264 0.039917
## 13  1.2147e-02     13 0.37390609 0.60251 0.039996
## 14  1.1888e-02     14 0.36175900 0.60307 0.039700
## 15  1.0778e-02     15 0.34987122 0.59875 0.039513
## 16  1.0506e-02     16 0.33909277 0.59720 0.039314
## 17  1.0301e-02     17 0.32858663 0.60048 0.039539
## 18  9.8052e-03     18 0.31828518 0.60215 0.039321
## 19  9.5324e-03     20 0.29867475 0.59688 0.038847
## 20  9.3098e-03     21 0.28914234 0.58509 0.037957
## 21  8.6039e-03     22 0.27983257 0.59364 0.038845
## 22  8.5728e-03     23 0.27122871 0.57979 0.038021
## 23  7.7737e-03     25 0.25408305 0.58077 0.038159
## 24  7.4353e-03     26 0.24630936 0.57043 0.038219
## 25  6.2838e-03     28 0.23143882 0.56120 0.038484
## 26  6.1242e-03     29 0.22515504 0.56843 0.039348
## 27  5.6953e-03     30 0.21903085 0.56462 0.039230
## 28  5.5687e-03     31 0.21333555 0.56750 0.039181
## 29  5.4134e-03     32 0.20776686 0.56992 0.039160
## 30  5.1373e-03     33 0.20235343 0.56912 0.039094
## 31  4.9581e-03     34 0.19721608 0.56093 0.038749
## 32  4.8270e-03     35 0.19225798 0.56086 0.038918
## 33  4.5558e-03     36 0.18743102 0.55355 0.037534
## 34  4.5456e-03     37 0.18287525 0.55178 0.036494
## 35  4.3739e-03     38 0.17832965 0.55426 0.036417
## 36  4.3307e-03     39 0.17395578 0.54874 0.036136
## 37  4.2485e-03     40 0.16962503 0.55823 0.036516
## 38  4.0980e-03     41 0.16537650 0.55192 0.036372
## 39  4.0525e-03     42 0.16127847 0.55285 0.036362
## 40  4.0054e-03     43 0.15722596 0.55230 0.036364
## 41  3.6917e-03     44 0.15322052 0.55021 0.036473
## 42  3.6352e-03     45 0.14952883 0.54755 0.036442
## 43  3.5301e-03     46 0.14589367 0.54445 0.036306
## 44  3.5196e-03     47 0.14236356 0.54687 0.036470
## 45  2.8653e-03     48 0.13884396 0.55608 0.037051
## 46  2.8565e-03     49 0.13597868 0.56478 0.039584
## 47  2.8565e-03     50 0.13312217 0.56478 0.039584
## 48  2.7253e-03     51 0.13026571 0.56183 0.039665
## 49  2.6841e-03     52 0.12754044 0.55113 0.039584
## 50  2.6829e-03     54 0.12217220 0.55447 0.039833
## 51  2.6660e-03     55 0.11948928 0.55837 0.039927
## 52  2.4588e-03     56 0.11682326 0.55918 0.039957
## 53  2.3693e-03     57 0.11436443 0.56518 0.040630
## 54  2.3018e-03     58 0.11199508 0.56799 0.041090
## 55  2.2746e-03     60 0.10739152 0.56779 0.041112
## 56  2.2540e-03     61 0.10511688 0.56776 0.041112
## 57  2.1781e-03     62 0.10286290 0.56921 0.041107
## 58  2.1645e-03     63 0.10068483 0.56958 0.041022
## 59  2.0950e-03     64 0.09852033 0.56900 0.040877
## 60  2.0945e-03     65 0.09642538 0.56643 0.040806
## 61  2.0740e-03     66 0.09433084 0.56643 0.040806
## 62  1.8864e-03     67 0.09225680 0.56347 0.040630
## 63  1.8413e-03     68 0.09037038 0.55582 0.040486
## 64  1.7921e-03     69 0.08852905 0.56068 0.040597
## 65  1.7167e-03     70 0.08673697 0.56170 0.040676
## 66  1.6766e-03     71 0.08502031 0.56936 0.042289
## 67  1.6704e-03     72 0.08334367 0.57048 0.042586
## 68  1.6064e-03     73 0.08167332 0.56832 0.042634
## 69  1.6055e-03     74 0.08006697 0.56869 0.042690
## 70  1.5103e-03     75 0.07846149 0.56819 0.042705
## 71  1.4967e-03     76 0.07695120 0.56867 0.042767
## 72  1.4907e-03     77 0.07545453 0.56912 0.042755
## 73  1.4007e-03     78 0.07396387 0.56785 0.042790
## 74  1.4002e-03     79 0.07256317 0.56724 0.042750
## 75  1.3613e-03     80 0.07116301 0.56578 0.042601
## 76  1.3589e-03     81 0.06980172 0.56668 0.042522
## 77  1.3462e-03     82 0.06844282 0.56901 0.042534
## 78  1.3351e-03     83 0.06709659 0.57107 0.042649
## 79  1.3304e-03     84 0.06576144 0.57276 0.042842
## 80  1.3146e-03     85 0.06443102 0.57369 0.042951
## 81  1.2795e-03     86 0.06311644 0.57160 0.042984
## 82  1.2412e-03     87 0.06183696 0.56826 0.042483
## 83  1.2373e-03     88 0.06059575 0.56710 0.042447
## 84  1.2135e-03     89 0.05935843 0.56688 0.042507
## 85  1.2002e-03     91 0.05693148 0.56720 0.042546
## 86  1.1269e-03     92 0.05573126 0.56481 0.042261
## 87  1.0919e-03     93 0.05460435 0.56104 0.042163
## 88  1.0898e-03     94 0.05351243 0.56062 0.042194
## 89  1.0864e-03     95 0.05242260 0.55813 0.041778
## 90  1.0646e-03     96 0.05133621 0.55894 0.041994
## 91  1.0116e-03     97 0.05027156 0.55803 0.041999
## 92  9.5940e-04     98 0.04925996 0.54697 0.041965
## 93  8.9105e-04     99 0.04830056 0.54764 0.041798
## 94  8.8465e-04    100 0.04740951 0.55178 0.041915
## 95  8.7611e-04    101 0.04652486 0.55178 0.041915
## 96  8.5644e-04    102 0.04564875 0.55599 0.041948
## 97  8.4568e-04    103 0.04479231 0.55423 0.041546
## 98  8.3004e-04    104 0.04394663 0.55304 0.041472
## 99  8.0748e-04    105 0.04311659 0.55457 0.042079
## 100 7.9944e-04    106 0.04230912 0.55353 0.041991
## 101 7.5680e-04    107 0.04150968 0.55385 0.041995
## 102 7.4082e-04    108 0.04075288 0.55442 0.041911
## 103 7.4043e-04    109 0.04001206 0.55442 0.041911
## 104 7.3510e-04    110 0.03927163 0.55484 0.041898
## 105 7.0107e-04    111 0.03853653 0.55115 0.041616
## 106 6.9184e-04    112 0.03783546 0.55243 0.041645
## 107 6.7585e-04    113 0.03714362 0.55320 0.041794
## 108 6.7373e-04    114 0.03646776 0.55515 0.041803
## 109 6.7173e-04    115 0.03579403 0.55515 0.041803
## 110 6.6783e-04    116 0.03512230 0.55515 0.041803
## 111 6.6518e-04    117 0.03445448 0.55515 0.041803
## 112 6.6451e-04    118 0.03378929 0.55515 0.041803
## 113 6.0900e-04    119 0.03312478 0.55448 0.041499
## 114 6.0343e-04    120 0.03251578 0.55767 0.041640
## 115 5.9465e-04    121 0.03191235 0.55670 0.041632
## 116 5.8550e-04    123 0.03072304 0.55839 0.041836
## 117 5.8340e-04    124 0.03013754 0.55812 0.041842
## 118 5.6972e-04    125 0.02955414 0.55917 0.041912
## 119 5.6433e-04    126 0.02898442 0.55950 0.041904
## 120 5.6323e-04    127 0.02842009 0.55946 0.041906
## 121 5.4821e-04    128 0.02785686 0.55903 0.041909
## 122 5.4339e-04    131 0.02621222 0.56010 0.041958
## 123 5.1968e-04    132 0.02566882 0.56287 0.042061
## 124 5.0869e-04    133 0.02514915 0.56068 0.041964
## 125 5.0157e-04    134 0.02464045 0.56243 0.041951
## 126 4.7302e-04    135 0.02413889 0.56139 0.041866
## 127 4.6969e-04    136 0.02366587 0.56188 0.041883
## 128 4.6775e-04    137 0.02319618 0.56141 0.041861
## 129 4.6669e-04    138 0.02272842 0.56021 0.041864
## 130 4.5761e-04    139 0.02226174 0.56066 0.041865
## 131 4.5283e-04    140 0.02180413 0.55858 0.041818
## 132 4.5270e-04    141 0.02135130 0.55858 0.041818
## 133 4.5251e-04    142 0.02089861 0.55858 0.041818
## 134 4.4875e-04    143 0.02044610 0.55810 0.041756
## 135 4.4874e-04    144 0.01999735 0.55796 0.041759
## 136 4.4666e-04    145 0.01954861 0.55796 0.041759
## 137 4.3805e-04    146 0.01910194 0.55837 0.041753
## 138 4.2159e-04    147 0.01866389 0.55877 0.041787
## 139 4.1179e-04    148 0.01824230 0.56067 0.041806
## 140 3.8646e-04    149 0.01783051 0.56106 0.041776
## 141 3.6959e-04    150 0.01744404 0.56180 0.041783
## 142 3.3035e-04    151 0.01707446 0.56121 0.041657
## 143 3.0799e-04    152 0.01674411 0.56542 0.041918
## 144 3.0672e-04    153 0.01643612 0.56529 0.041890
## 145 3.0672e-04    154 0.01612940 0.56529 0.041890
## 146 3.0672e-04    155 0.01582268 0.56529 0.041890
## 147 3.0544e-04    156 0.01551596 0.56529 0.041890
## 148 3.0094e-04    157 0.01521052 0.56485 0.041862
## 149 2.9757e-04    158 0.01490958 0.56596 0.041870
## 150 2.8981e-04    159 0.01461201 0.56611 0.041866
## 151 2.8923e-04    160 0.01432220 0.56599 0.041847
## 152 2.8782e-04    161 0.01403296 0.56599 0.041847
## 153 2.8635e-04    162 0.01374515 0.56599 0.041847
## 154 2.8189e-04    163 0.01345879 0.56546 0.041851
## 155 2.8173e-04    164 0.01317690 0.56421 0.041710
## 156 2.6988e-04    165 0.01289517 0.56660 0.042074
## 157 2.6283e-04    166 0.01262530 0.56654 0.042080
## 158 2.5737e-04    167 0.01236246 0.56394 0.042168
## 159 2.5139e-04    168 0.01210509 0.56464 0.042378
## 160 2.5003e-04    169 0.01185370 0.56464 0.042378
## 161 2.3771e-04    170 0.01160367 0.56479 0.042374
## 162 2.3512e-04    171 0.01136596 0.56502 0.042369
## 163 2.2600e-04    172 0.01113084 0.56323 0.042177
## 164 2.1796e-04    173 0.01090483 0.56390 0.042161
## 165 2.1590e-04    174 0.01068688 0.56399 0.042160
## 166 2.1121e-04    175 0.01047098 0.56355 0.042101
## 167 2.0973e-04    176 0.01025977 0.56355 0.042101
## 168 2.0949e-04    178 0.00984031 0.56355 0.042101
## 169 2.0779e-04    179 0.00963081 0.56362 0.042099
## 170 2.0120e-04    180 0.00942302 0.56540 0.042188
## 171 2.0025e-04    181 0.00922182 0.56557 0.042184
## 172 1.9247e-04    182 0.00902157 0.56427 0.042187
## 173 1.8668e-04    183 0.00882910 0.56498 0.042364
## 174 1.7976e-04    184 0.00864242 0.56465 0.042362
## 175 1.6630e-04    185 0.00846266 0.56593 0.042429
## 176 1.6596e-04    186 0.00829637 0.56580 0.042426
## 177 1.6594e-04    187 0.00813041 0.56580 0.042425
## 178 1.6347e-04    188 0.00796447 0.56526 0.042386
## 179 1.6290e-04    189 0.00780100 0.56528 0.042386
## 180 1.5712e-04    190 0.00763810 0.56528 0.042386
## 181 1.5619e-04    191 0.00748098 0.56499 0.042312
## 182 1.5210e-04    192 0.00732479 0.56549 0.042366
## 183 1.4745e-04    193 0.00717270 0.56637 0.042407
## 184 1.4354e-04    194 0.00702525 0.56640 0.042405
## 185 1.3883e-04    195 0.00688171 0.56554 0.042334
## 186 1.3883e-04    196 0.00674288 0.56559 0.042334
## 187 1.3613e-04    197 0.00660405 0.56695 0.042371
## 188 1.3589e-04    198 0.00646792 0.56695 0.042371
## 189 1.3299e-04    199 0.00633203 0.56674 0.042373
## 190 1.3241e-04    200 0.00619904 0.56680 0.042371
## 191 1.3011e-04    201 0.00606664 0.56780 0.042405
## 192 1.2674e-04    202 0.00593652 0.56703 0.042418
## 193 1.2674e-04    203 0.00580978 0.56703 0.042418
## 194 1.2167e-04    204 0.00568304 0.56751 0.042421
## 195 1.2167e-04    205 0.00556136 0.56947 0.042553
## 196 1.2105e-04    206 0.00543969 0.56947 0.042553
## 197 1.1352e-04    207 0.00531864 0.57068 0.042659
## 198 1.0898e-04    208 0.00520512 0.57011 0.042653
## 199 1.0860e-04    209 0.00509614 0.56965 0.042649
## 200 1.0592e-04    210 0.00498754 0.56987 0.042644
## 201 1.0265e-04    211 0.00488162 0.57012 0.042639
## 202 9.6794e-05    212 0.00477896 0.57130 0.042691
## 203 9.5532e-05    213 0.00468217 0.57190 0.042702
## 204 9.4042e-05    214 0.00458664 0.57205 0.042697
## 205 9.1257e-05    215 0.00449260 0.57212 0.042695
## 206 9.0753e-05    216 0.00440134 0.57198 0.042699
## 207 8.9624e-05    217 0.00431059 0.57198 0.042699
## 208 8.8270e-05    218 0.00422096 0.57196 0.042699
## 209 8.7486e-05    219 0.00413269 0.57290 0.042691
## 210 8.3729e-05    220 0.00404521 0.57271 0.042688
## 211 8.1451e-05    221 0.00396148 0.57627 0.042767
## 212 7.9204e-05    222 0.00388003 0.57555 0.042737
## 213 7.7471e-05    224 0.00372162 0.57480 0.042744
## 214 7.6989e-05    225 0.00364415 0.57480 0.042744
## 215 7.4805e-05    227 0.00349017 0.57426 0.042735
## 216 7.2925e-05    228 0.00341536 0.57441 0.042748
## 217 7.2160e-05    229 0.00334244 0.57529 0.042808
## 218 7.1694e-05    230 0.00327028 0.57529 0.042808
## 219 6.9264e-05    231 0.00319859 0.57625 0.042965
## 220 6.8065e-05    232 0.00312932 0.57500 0.042894
## 221 6.8065e-05    233 0.00306126 0.57500 0.042894
## 222 6.7977e-05    234 0.00299319 0.57500 0.042894
## 223 6.6383e-05    235 0.00292522 0.57500 0.042894
## 224 6.6383e-05    236 0.00285883 0.57472 0.042890
## 225 6.6383e-05    237 0.00279245 0.57472 0.042890
## 226 6.6203e-05    238 0.00272607 0.57472 0.042890
## 227 6.5697e-05    239 0.00265986 0.57431 0.042877
## 228 6.5373e-05    240 0.00259417 0.57431 0.042877
## 229 6.4356e-05    241 0.00252879 0.57447 0.042876
## 230 6.3372e-05    242 0.00246444 0.57485 0.042870
## 231 6.2228e-05    243 0.00240107 0.57485 0.042870
## 232 6.2225e-05    244 0.00233884 0.57485 0.042870
## 233 6.0397e-05    245 0.00227661 0.57485 0.042870
## 234 5.8464e-05    246 0.00221622 0.57305 0.042827
## 235 5.8137e-05    248 0.00209929 0.57331 0.042831
## 236 5.4694e-05    249 0.00204115 0.57380 0.042822
## 237 5.2855e-05    251 0.00193176 0.57320 0.042751
## 238 5.1331e-05    252 0.00187891 0.57335 0.042758
## 239 5.1048e-05    253 0.00182758 0.57302 0.042763
## 240 4.9324e-05    255 0.00172548 0.57259 0.042714
## 241 4.9278e-05    256 0.00167616 0.57259 0.042714
## 242 4.9278e-05    257 0.00162688 0.57259 0.042714
## 243 4.9273e-05    258 0.00157760 0.57259 0.042714
## 244 4.5298e-05    259 0.00152833 0.57201 0.042714
## 245 4.3577e-05    260 0.00148303 0.57136 0.042655
## 246 4.3370e-05    261 0.00143945 0.57173 0.042656
## 247 4.2422e-05    262 0.00139608 0.57173 0.042656
## 248 4.0867e-05    263 0.00135366 0.57245 0.042708
## 249 3.9280e-05    264 0.00131279 0.57322 0.042752
## 250 3.7840e-05    265 0.00127351 0.57331 0.042762
## 251 3.7840e-05    266 0.00123567 0.57331 0.042762
## 252 3.7840e-05    267 0.00119783 0.57331 0.042762
## 253 3.6955e-05    268 0.00115999 0.57331 0.042762
## 254 3.5847e-05    269 0.00112304 0.57377 0.042782
## 255 3.5216e-05    270 0.00108719 0.57377 0.042782
## 256 3.4708e-05    271 0.00105197 0.57388 0.042779
## 257 3.4032e-05    272 0.00101727 0.57388 0.042779
## 258 3.3519e-05    273 0.00098323 0.57320 0.042678
## 259 3.3247e-05    274 0.00094971 0.57311 0.042681
## 260 2.9981e-05    275 0.00091647 0.57343 0.042674
## 261 2.9052e-05    276 0.00088649 0.57331 0.042680
## 262 2.7245e-05    277 0.00085744 0.57305 0.042685
## 263 2.5663e-05    278 0.00083019 0.57267 0.042608
## 264 2.5663e-05    279 0.00080453 0.57267 0.042608
## 265 2.2814e-05    280 0.00077886 0.57366 0.042715
## 266 2.2688e-05    281 0.00075605 0.57382 0.042740
## 267 2.2128e-05    282 0.00073336 0.57375 0.042742
## 268 2.1877e-05    283 0.00071123 0.57385 0.042740
## 269 2.1510e-05    284 0.00068936 0.57385 0.042740
## 270 2.0132e-05    285 0.00066785 0.57383 0.042740
## 271 2.0132e-05    286 0.00064772 0.57381 0.042733
## 272 1.8231e-05    287 0.00062758 0.57381 0.042733
## 273 1.8163e-05    288 0.00060935 0.57461 0.042727
## 274 1.7618e-05    289 0.00059119 0.57461 0.042727
## 275 1.7618e-05    290 0.00057357 0.57440 0.042728
## 276 1.7608e-05    291 0.00055595 0.57440 0.042728
## 277 1.7110e-05    292 0.00053834 0.57441 0.042727
## 278 1.5272e-05    293 0.00052123 0.57478 0.042734
## 279 1.5099e-05    294 0.00050596 0.57496 0.042730
## 280 1.4162e-05    296 0.00047576 0.57496 0.042730
## 281 1.4162e-05    297 0.00046160 0.57483 0.042726
## 282 1.4141e-05    298 0.00044744 0.57483 0.042726
## 283 1.4141e-05    300 0.00041916 0.57483 0.042726
## 284 1.3214e-05    301 0.00040502 0.57471 0.042729
## 285 1.3214e-05    302 0.00039180 0.57479 0.042728
## 286 1.3093e-05    303 0.00037859 0.57479 0.042728
## 287 1.2318e-05    304 0.00036550 0.57437 0.042717
## 288 1.2318e-05    305 0.00035318 0.57436 0.042717
## 289 1.1454e-05    306 0.00034086 0.57438 0.042717
## 290 1.1082e-05    307 0.00032941 0.57432 0.042720
## 291 1.0621e-05    308 0.00031832 0.57467 0.042765
## 292 1.0000e-05    312 0.00027584 0.57467 0.042765
  3. Draw the tree with 8 splits (use prune).
tree2 <- prune(tree,cp=0.016043)
rpart.plot(tree2)
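
Rather than hard-coding the value 0.016043, the cp can be read from the cptable (a small sketch):

cp_8 <- tree$cptable[tree$cptable[, "nsplit"] == 8, "CP"]  # cp of the tree with 8 splits
tree2 <- prune(tree, cp = cp_8)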

  4. The visTree function from the visNetwork package allows us to draw interactive graphs.
visTree(tree)

A Shiny web application is also provided to visualise the sequence of subtrees:

visTreeEditor(Carseats)
  5. Split the dataset into a training set of size 250 and a test set of size 150.
n.train <- 250
set.seed(12345)
perm <- sample(nrow(Carseats))
train <- Carseats[perm[1:n.train],]
test <- Carseats[-perm[1:n.train],]
  6. We fit a sequence of trees on the train sample with
set.seed(12345)
tree <- rpart(Sales~.,data=train,cp=0.00000001,minsplit=2)

In this sequence of trees, select

  • a very simple tree (with 2 or 3 splits)
  • a very large tree
  • an optimal tree (with the classical pruning strategy)
#printcp(tree)
simple.tree <- prune(tree,cp=0.05366748)
large.tree <- prune(tree,cp=0.00000001)
#cp_opt <- tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"]
cp_opt <- tree$cptable %>% as.data.frame() %>% 
  filter(xerror==min(xerror)) %>% 
  dplyr::select(CP) %>% as.numeric()

opt.tree <- prune(tree,cp=cp_opt)
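
Another common choice is the 1-SE rule: instead of the cp minimising xerror, take the simplest tree whose xerror is within one standard error of the minimum (a sketch based on the cptable columns):

cptab <- tree$cptable
i.min <- which.min(cptab[, "xerror"])
thresh <- cptab[i.min, "xerror"] + cptab[i.min, "xstd"]   # 1-SE threshold
cp.1se <- cptab[which(cptab[, "xerror"] <= thresh)[1], "CP"]  # first row = simplest tree
opt.tree.1se <- prune(tree, cp = cp.1se)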
  7. Estimate the quadratic error of these 3 trees with the test sample.

For each tree \(T\) we compute \[\frac{1}{n_{test}}\sum_{i\in test}(Y_i-T(X_i))^2.\]

We first compute a table with the predictions of the three trees:

data.prev <- data.frame(simple=predict(simple.tree,newdata = test),
                        large=predict(large.tree,newdata = test),
                        opt=predict(opt.tree,newdata = test))

We now estimate the quadratic risk with dplyr:

data.prev %>% mutate(obs=test$Sales) %>% 
  summarise_at(1:3,~(mean((obs-.)^2)))

The pruned optimal tree has the smallest estimated quadratic error, although only by a small margin.

Random Forest

Exercise 5 (random forest)

We again consider the spam dataset from the kernlab package.

data(spam)
  1. Explain the following graph.
rf1 <- randomForest(type~.,data=spam)
plot(rf1)

We visualize the error rates (the overall OOB error and the error rate within each class, i.e. one minus the specificity and one minus the sensitivity) as a function of the number of trees in the forest. This is useful to check whether the forest has converged or whether more trees are needed.
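
The curves can also be rebuilt from the err.rate matrix stored in the randomForest object (a sketch with ggplot; the matrix contains the overall OOB error and one column per class):

err <- as.data.frame(rf1$err.rate) %>% mutate(ntree = row_number()) %>%
  gather(key = "type", value = "error", -ntree)
ggplot(err) + aes(x = ntree, y = error, color = type) + geom_line() + theme_classic()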

  2. Fit another random forest with mtry=1. Compare the two forests.
rf2 <- randomForest(type~.,data=spam,mtry=1)
rf1
## 
## Call:
##  randomForest(formula = type ~ ., data = spam) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 7
## 
##         OOB estimate of  error rate: 4.54%
## Confusion matrix:
##         nonspam spam class.error
## nonspam    2712   76  0.02725968
## spam        133 1680  0.07335907
rf2
## 
## Call:
##  randomForest(formula = type ~ ., data = spam, mtry = 1) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 1
## 
##         OOB estimate of  error rate: 8.15%
## Confusion matrix:
##         nonspam spam class.error
## nonspam    2725   63  0.02259684
## spam        312 1501  0.17209046

The OOB error is clearly higher for mtry=1.

  3. Split the data into a training set of size 3000 and a test set of size 1601.
set.seed(1234)
perm <- sample(nrow(spam))
train <- spam %>% slice(perm[1:3000])
test <- spam %>% slice(-perm[1:3000])
  4. Fit 2 random forests on the training set: one with the default value of mtry and one with mtry=1.
rf1 <- randomForest(type~.,data=train)
rf2 <- randomForest(type~.,data=train,mtry=1)
  5. Estimate the misclassification error of the 2 random forests with the test set.
prev <- data.frame(rf1=predict(rf1,newdata=test,type="class"),
                   rf2=predict(rf2,newdata=test,type="class"),Y=test$type)
prev %>% summarize_at(1:2,~mean(.!=Y))
  6. Use the caret package to select the mtry parameter. You can search for this parameter in the grid seq(1,30,by=5).
grille.mtry <- data.frame(mtry=seq(1,30,by=5))
ctrl <- trainControl(method="oob")
library(doParallel) ## to parallelize the computations
cl <- makePSOCKcluster(4)
registerDoParallel(cl)
set.seed(12345)
sel.mtry <- train(type~.,data=train,method="rf",trControl=ctrl,
                  tuneGrid=grille.mtry)
stopCluster(cl)
sel.mtry$bestTune

The selected parameter is close to the default value.
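
This is expected: for classification, the default mtry in randomForest is the (floored) square root of the number of predictors.

floor(sqrt(ncol(spam) - 1))  # = 7, the value reported in the rf1 output above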

  7. Fit a tree on the train sample.
tree <- rpart(type~.,data=train,cp=0.00001,minsplit=3)
plotcp(tree)

cp_opt <- tree$cptable %>% as.data.frame() %>% 
  filter(xerror==min(xerror)) %>% dplyr::select(CP) %>% 
  slice(1) %>% as.numeric()
opt.tree <- prune(tree,cp=cp_opt)
rpart.plot(opt.tree) 

  8. Draw the ROC curves and compute the AUC for the 2 random forests and the tree.
  • With the roc function from the pROC package:
score1 <- predict(rf1,newdata=test,type="prob")[,2]
score2 <- predict(rf2,newdata=test,type="prob")[,2]
score3 <- predict(opt.tree,newdata=test,type="prob")[,2]
a1 <- roc(test$type,score1)
a2 <- roc(test$type,score2)
a3 <- roc(test$type,score3)

plot(a1)
plot(a2,add=TRUE,col="red")
plot(a3,add=TRUE,col="blue")

  • With geom_roc from the plotROC package to obtain a ggplot graph:
score <- data.frame(rf1=score1,rf2=score2,tree=score3,Y=test$type) %>% gather(key="Method",value="score",-Y)

ggplot(score)+aes(d=Y,m=score,color=Method)+geom_roc()+theme_classic()

score %>% group_by(Method) %>% summarize(AUC=pROC::auc(Y,score))
  9. Represent the 10 most important variables of the best random forest with a bar chart.
imp <- randomForest::importance(rf1) %>% as.data.frame()
imp1 <- imp %>% mutate(variable=rownames(imp)) %>% as_tibble() %>%
  arrange(desc(MeanDecreaseGini)) %>% slice(1:10) %>%
  mutate(Var=substr(variable,1,8))
imp1$Var <- factor(as.character(imp1$Var),levels=imp1$Var)
ggplot(imp1)+aes(x=Var,y=MeanDecreaseGini)+geom_bar(stat = "identity")+xlab("")
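
The randomForest package also offers a base-graphics alternative for the same information:

varImpPlot(rf1, n.var = 10)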

  10. Fit a forest on the train dataset with the ranger function from the ranger package. What do you notice?
system.time(rf3 <- ranger(type~.,data=train))
##    user  system elapsed 
##   2.795   0.030   0.455
system.time(rf4 <- randomForest(type~.,data=train))
##    user  system elapsed 
##   5.392   0.055   5.463

The fitting process is clearly faster: ranger is a fast implementation of random forests, well suited to high-dimensional data. More details can be found in the package documentation.
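
Besides speed, the two implementations give similar out-of-bag errors (a quick sketch; prediction.error is the OOB error stored by ranger):

c(ranger = rf3$prediction.error,
  randomForest = rf4$err.rate[nrow(rf4$err.rate), "OOB"])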

Exercise 6

Compare random forest, ridge and lasso on the spam dataset. To do that, compute the classical risks (error probability, ROC…) by 10-fold cross-validation. Be careful with the selection of the tuning parameters.

We first create the folds

set.seed(1234)
Folds <- createFolds(1:nrow(spam),k=10)

and the input matrix for glmnet.

mat.X <- model.matrix(type~.,data=spam)[,-1]

Now, the cross-validation:

prev <- matrix(0,nrow=nrow(spam),ncol=3) %>% as.data.frame()
names(prev) <- c("forest","ridge","lasso")
for (k in 1:10){
  print(k)
  train <- spam %>% slice(-Folds[[k]])
  test <- spam %>% slice(Folds[[k]])
  X.train <- mat.X[-Folds[[k]],]
  Y.train <- spam$type[-Folds[[k]]]
  X.test <- mat.X[Folds[[k]],]
  forest.k <- ranger(type~.,data=train,probability = TRUE)
  ridge.k <- cv.glmnet(X.train,Y.train,family="binomial",alpha=0)
  lasso.k <- cv.glmnet(X.train,Y.train,family="binomial",alpha=1)
  prev[Folds[[k]],1] <- predict(forest.k,data=test)$predictions[,2]
  prev[Folds[[k]],2] <- as.vector(predict(ridge.k,newx=X.test,type="response"))
  prev[Folds[[k]],3] <- as.vector(predict(lasso.k,newx=X.test,type="response"))
}
## [1] 1
## [1] 2
## [1] 3
## [1] 4
## [1] 5
## [1] 6
## [1] 7
## [1] 8
## [1] 9
## [1] 10

We compute ROC curves and AUC.

prev1 <- prev %>% mutate(obs=spam$type) %>%
  gather(key="Method",value="score",-obs)
prev1 %>% group_by(Method) %>% summarize(AUC=pROC::auc(obs,score)) %>%
  arrange(desc(AUC))
ggplot(prev1)+aes(d=obs,m=score,color=Method)+geom_roc()+theme_classic()

And finally the misclassification errors.

round(prev) %>% mutate(obs=recode(spam$type,nonspam="0",spam="1")) %>%
  gather(key="Method",value="pred",-obs) %>% group_by(Method) %>%
  summarise(Err=mean(obs!=pred)) %>% arrange(Err)
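
Note that round(prev) implicitly thresholds the scores at 0.5; an equivalent but more explicit version reads (a sketch):

prev %>% mutate(obs = spam$type) %>%
  gather(key = "Method", value = "score", -obs) %>%
  mutate(pred = ifelse(score > 0.5, "spam", "nonspam")) %>%
  group_by(Method) %>% summarise(Err = mean(pred != obs)) %>% arrange(Err)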