> str(iris)'data.frame': 150 obs. of 5 variables: $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ... $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ... $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... $ Species : Factor w/ 3 levels 'setosa','versicolor',..: 1 1 1 1 1 1 1 1 1 1 ...> set.seed(1234)> ind <- sample(2, nrow(iris), replace=TRUE, prob=c(0.7, 0.3))> trainData <- iris[ind==1,]> testData <- iris[ind==2,]
> library(party)> myFormula <- Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width> iris_ctree <- ctree(myFormula, data=trainData)> # check the prediction> table(predict(iris_ctree), trainData$Species) setosa versicolor virginica setosa 40 0 0 versicolor 0 37 3 virginica 0 1 31
> print(iris_ctree) Conditional inference tree with 4 terminal nodesResponse: Species Inputs: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width Number of observations: 112 1) Petal.Length <= 1.9; criterion = 1, statistic = 104.643 2)* weights = 40 1) Petal.Length > 1.9 3) Petal.Width <= 1.7; criterion = 1, statistic = 48.939 4) Petal.Length <= 4.4; criterion = 0.974, statistic = 7.397 5)* weights = 21 4) Petal.Length > 4.4 6)* weights = 19 3) Petal.Width > 1.7 7)* weights = 32 > plot(iris_ctree)> # 圖略
> plot(iris_ctree, type='simple')
> #圖略
> # predict on test data> testPred <- predict(iris_ctree, newdata = testData)> table(testPred, testData$Species) testPred setosa versicolor virginica setosa 10 0 0 versicolor 0 12 2 virginica 0 0 14
① ctree()不能很好地處理缺失值,含有缺失值的觀測有時被劃分到左子樹,有時劃到右子樹,這是由缺失值的替代規則決定的。
② 訓練集和測試集需出自同一個數據集,即它們的表結構、含有的變量要一致,無論決策樹最終是否用到了全部的變量。
③ 如果訓練集和測試集的分類變量的水平值不一致,對測試集的預測會失敗。解決此問題的方法是根據訓練集中分類變量的水平值,顯式地設置測試集中對應變量的水平值。
> data('bodyfat', package = 'TH.data')> dim(bodyfat)[1] 71 10> attributes(bodyfat)$names [1] 'age' 'DEXfat' 'waistcirc' 'hipcirc' 'elbowbreadth' [6] 'kneebreadth' 'anthro3a' 'anthro3b' 'anthro3c' 'anthro4' $row.names [1] '47' '48' '49' '50' '51' '52' '53' '54' '55' '56' '57' '58' '59' '60' [15] '61' '62' '63' '64' '65' '66' '67' '68' '69' '70' '71' '72' '73' '74' [29] '75' '76' '77' '78' '79' '80' '81' '82' '83' '84' '85' '86' '87' '88' [43] '89' '90' '91' '92' '93' '94' '95' '96' '97' '98' '99' '100' '101' '102'[57] '103' '104' '105' '106' '107' '108' '109' '110' '111' '112' '113' '114' '115' '116'[71] '117'$class[1] 'data.frame'> bodyfat[1:5,] age DEXfat waistcirc hipcirc elbowbreadth kneebreadth anthro3a anthro3b anthro3c47 57 41.68 100.0 112.0 7.1 9.4 4.42 4.95 4.5048 65 43.29 99.5 116.5 6.5 8.9 4.63 5.01 4.4849 59 35.41 96.0 108.5 6.2 8.9 4.12 4.74 4.6050 58 22.79 72.0 96.5 6.1 9.2 4.03 4.48 3.9151 60 36.42 89.5 100.5 7.1 10.0 4.24 4.68 4.15 anthro447 6.1348 6.3749 5.8250 5.6651 5.91
> set.seed(1234)> ind <- sample(2, nrow(bodyfat), replace=TRUE, prob=c(0.7, 0.3))> bodyfat.train <- bodyfat[ind==1,]> bodyfat.test <- bodyfat[ind==2,]> # train a decision tree> library(rpart)> myFormula <- DEXfat ~ age + waistcirc + hipcirc + elbowbreadth + kneebreadth> bodyfat_rpart <- rpart(myFormula, data = bodyfat.train, control = rpart.control(minsplit = 10))> attributes(bodyfat_rpart)$names [1] 'frame' 'where' 'call' [4] 'terms' 'cptable' 'method' [7] 'parms' 'control' 'functions' [10] 'numresp' 'splits' 'variable.importance'[13] 'y' 'ordered' $xlevelsnamed list()$class[1] 'rpart'> print(bodyfat_rpart$cptable) CP nsplit rel error xerror xstd1 0.67272638 0 1.00000000 1.0194546 0.187243822 0.09390665 1 0.32727362 0.4415438 0.108530443 0.06037503 2 0.23336696 0.4271241 0.093628954 0.03420446 3 0.17299193 0.3842206 0.090305395 0.01708278 4 0.13878747 0.3038187 0.072955566 0.01695763 5 0.12170469 0.2739808 0.065996427 0.01007079 6 0.10474706 0.2693702 0.066136188 0.01000000 7 0.09467627 0.2695358 0.06620732> print(bodyfat_rpart)n= 56 node), split, n, deviance, yval * denotes terminal node 1) root 56 7265.0290000 30.94589 2) waistcirc< 88.4 31 960.5381000 22.55645 4) hipcirc< 96.25 14 222.2648000 18.41143 8) age< 60.5 9 66.8809600 16.19222 * 9) age>=60.5 5 31.2769200 22.40600 * 5) hipcirc>=96.25 17 299.6470000 25.97000 10) waistcirc< 77.75 6 30.7345500 22.32500 * 11) waistcirc>=77.75 11 145.7148000 27.95818 22) hipcirc< 99.5 3 0.2568667 23.74667 * 23) hipcirc>=99.5 8 72.2933500 29.53750 * 3) waistcirc>=88.4 25 1417.1140000 41.34880 6) waistcirc< 104.75 18 330.5792000 38.09111 12) hipcirc< 109.9 9 68.9996200 34.37556 * 13) hipcirc>=109.9 9 13.0832000 41.80667 * 7) waistcirc>=104.75 7 404.3004000 49.72571 *
> plot(bodyfat_rpart)> text(bodyfat_rpart, use.n=T)> #圖略
> opt <- which.min(bodyfat_rpart$cptable[,'xerror'])> cp <- bodyfat_rpart$cptable[opt, 'CP']> bodyfat_prune <- prune(bodyfat_rpart, cp = cp)> print(bodyfat_prune)n= 56 node), split, n, deviance, yval * denotes terminal node 1) root 56 7265.02900 30.94589 2) waistcirc< 88.4 31 960.53810 22.55645 4) hipcirc< 96.25 14 222.26480 18.41143 8) age< 60.5 9 66.88096 16.19222 * 9) age>=60.5 5 31.27692 22.40600 * 5) hipcirc>=96.25 17 299.64700 25.97000 10) waistcirc< 77.75 6 30.73455 22.32500 * 11) waistcirc>=77.75 11 145.71480 27.95818 * 3) waistcirc>=88.4 25 1417.11400 41.34880 6) waistcirc< 104.75 18 330.57920 38.09111 12) hipcirc< 109.9 9 68.99962 34.37556 * 13) hipcirc>=109.9 9 13.08320 41.80667 * 7) waistcirc>=104.75 7 404.30040 49.72571 *> plot(bodyfat_prune)> text(bodyfat_prune, use.n=T)> #圖略
> DEXfat_pred <- predict(bodyfat_prune, newdata=bodyfat.test)> xlim <- range(bodyfat$DEXfat)> plot(DEXfat_pred ~ DEXfat, data=bodyfat.test, xlab='Observed', ylab='Predicted', ylim=xlim, xlim=xlim)> abline(a=0, b=1)
> #圖略
> ind <- sample(2, nrow(iris), replace=TRUE, prob=c(0.7, 0.3))> trainData <- iris[ind==1,]> testData <- iris[ind==2,]
> library(randomForest)> rf <- randomForest(Species ~ ., data=trainData, ntree=100, proximity=TRUE)> table(predict(rf), trainData$Species) setosa versicolor virginica setosa 36 0 0 versicolor 0 31 1 virginica 0 1 35> print(rf)Call: randomForest(formula = Species ~ ., data = trainData, ntree = 100, proximity = TRUE) Type of random forest: classification Number of trees: 100No. of variables tried at each split: 2 OOB estimate of error rate: 1.92%Confusion matrix: setosa versicolor virginica class.errorsetosa 36 0 0 0.00000000versicolor 0 31 1 0.03125000virginica 0 1 35 0.02777778> attributes(rf)$names [1] 'call' 'type' 'predicted' 'err.rate' [5] 'confusion' 'votes' 'oob.times' 'classes' [9] 'importance' 'importanceSD' 'localImportance' 'proximity' [13] 'ntree' 'mtry' 'forest' 'y' [17] 'test' 'inbag' 'terms' $class[1] 'randomForest.formula' 'randomForest'
> plot(rf)> #圖略
> importance(rf) MeanDecreaseGiniSepal.Length 6.485090Sepal.Width 1.380624Petal.Length 32.498074Petal.Width 28.250058> varImpPlot(rf)> #圖略
> irisPred <- predict(rf, newdata=testData)> table(irisPred, testData$Species) irisPred setosa versicolor virginica setosa 14 0 0 versicolor 0 17 3 virginica 0 1 11> plot(margin(rf, testData$Species))
> #圖略