#install.packages("ISLR") # if you don't have the package install it.
#install.packages("tree")# if you don't have the package install it.
library(ISLR)
library(tree)
# Classification tree on the ISLR Carseats data: label each store as high-sales
# (Sales >= 8) or not, fit a tree on a random half, measure test error, then
# prune the tree to the size chosen by cross-validation.

# Work on an explicit copy of the package data instead of attach()-ing it, so
# every column reference is unambiguous and the package's Carseats is untouched.
carseats_df <- ISLR::Carseats
range(carseats_df$Sales)

# tree() fits a classification tree only when the response is a factor. A
# character column (what data.frame() produces by default since R 4.0) makes
# tree() fail, so build `high` as a factor with fixed level order.
high <- factor(ifelse(carseats_df$Sales >= 8, "YES", "NO"),
               levels = c("NO", "YES"))
carseats_df$high <- high
str(carseats_df)
table(high)  # class balance; range() is not defined for factors
names(carseats_df)

# `high` is derived from Sales, so Sales must be removed from the predictors.
# Drop it by name rather than by column position.
carseats_df$Sales <- NULL

# Random 50/50 train/test split (seeded for reproducibility).
set.seed(2)
n_obs <- nrow(carseats_df)
train <- sample(seq_len(n_obs), n_obs / 2)
test <- -train
training_data <- carseats_df[train, ]
testing_data <- carseats_df[test, ]
testing_high <- testing_data$high

# Fit the full tree on the training half and draw it with full level labels.
tree_model <- tree(high ~ ., data = training_data)
plot(tree_model)
text(tree_model, pretty = 0)

# Predict: test-set misclassification rate of the unpruned tree.
tree_pred <- predict(tree_model, testing_data, type = "class")
error_rate <- mean(tree_pred != testing_high)
error_rate  # the original comment recorded ~28%

# Pruning: cross-validate over subtree sizes. With FUN = prune.misclass,
# cv_tree$dev holds the cross-validated misclassification count for each
# candidate size in cv_tree$size.
set.seed(9)
cv_tree <- cv.tree(tree_model, FUN = prune.misclass)
cv_tree
names(cv_tree)
plot(cv_tree$size, cv_tree$dev, type = "b")

# Choose the size with the lowest CV error (the smallest such tree on ties)
# instead of a hard-coded size. This guarantees the size is in the pruning
# sequence.
best_size <- min(cv_tree$size[cv_tree$dev == min(cv_tree$dev)])
pruned_model <- prune.misclass(tree_model, best = best_size)
plot(pruned_model)
text(pruned_model, pretty = 0)

# Test-set misclassification rate of the pruned tree.
tree_pred2 <- predict(pruned_model, testing_data, type = "class")
error_rate2 <- mean(tree_pred2 != testing_high)
error_rate2  # the original comment recorded ~23% with a fixed best = 9
#install.packages("tree")# if you don't have the package install it.
library(ISLR)
library(tree)

# Repeat of the Carseats classification-tree analysis. The data are reloaded
# from the ISLR package rather than re-attaching the global `Carseats`. Once
# that object has been augmented earlier in a session, re-attaching it used to
# add a duplicate `high.1` column (leaking the response into the predictors)
# and `[, -1]` then dropped CompPrice instead of Sales.
carseats_df <- ISLR::Carseats
range(carseats_df$Sales)

# tree() needs a factor response for classification. A character column (the
# data.frame() default since R 4.0) makes it fail.
high <- factor(ifelse(carseats_df$Sales >= 8, "YES", "NO"),
               levels = c("NO", "YES"))
carseats_df$high <- high
str(carseats_df)
table(high)  # class balance; range() is not defined for factors
names(carseats_df)

# `high` is derived from Sales, so Sales must not be used as a predictor.
carseats_df$Sales <- NULL

# Random 50/50 train/test split (seeded for reproducibility).
set.seed(2)
n_obs <- nrow(carseats_df)
train <- sample(seq_len(n_obs), n_obs / 2)
test <- -train
training_data <- carseats_df[train, ]
testing_data <- carseats_df[test, ]
testing_high <- testing_data$high

# Fit the full tree on the training half and draw it with full level labels.
tree_model <- tree(high ~ ., data = training_data)
plot(tree_model)
text(tree_model, pretty = 0)

# Predict: test-set misclassification rate of the unpruned tree.
tree_pred <- predict(tree_model, testing_data, type = "class")
error_rate <- mean(tree_pred != testing_high)
error_rate  # the original comment recorded ~28%

# Pruning: cross-validate over subtree sizes. With FUN = prune.misclass,
# cv_tree$dev is the cross-validated misclassification count per size.
set.seed(9)
cv_tree <- cv.tree(tree_model, FUN = prune.misclass)
cv_tree
names(cv_tree)
plot(cv_tree$size, cv_tree$dev, type = "b")

# Smallest tree among the sizes with the lowest CV error (was hard-coded to 9).
best_size <- min(cv_tree$size[cv_tree$dev == min(cv_tree$dev)])
pruned_model <- prune.misclass(tree_model, best = best_size)
plot(pruned_model)
text(pruned_model, pretty = 0)

# Test-set misclassification rate of the pruned tree.
tree_pred2 <- predict(pruned_model, testing_data, type = "class")
error_rate2 <- mean(tree_pred2 != testing_high)
error_rate2  # the original comment recorded ~23% with a fixed best = 9
# No comments:
# Post a Comment