
Decision Tree Regression - Supervised Learning

Consider the data set given below:

[Data set not reproduced here: column F is the average number of rooms per dwelling and column N is the median value of the home, as used in the code that follows.]

Python Code

Install the following package:

pip install pydotplus

To conduct the Decision Tree Regression, the following Python code will be useful.


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Load the data set; please use your own source file location
boston = pd.read_csv(r'C:\Users\user\Desktop\Data2.csv')
boston  # display the data frame
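
Before fitting anything, it can help to take a quick look at the loaded data; an optional check (not in the original post) using standard pandas methods:

print(boston.head())       # first few rows of the data set
print(boston.describe())   # summary statistics for the columns
print(boston.shape)        # number of rows and columns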


# Scatter plot of median home value (N) against average number of rooms (F)
plt.scatter(x=boston['F'], y=boston['N'], color='brown')
plt.xlabel('Average number of rooms per dwelling')
plt.ylabel('Median Value of Home')


[Scatter plot: Median Value of Home (N) against Average number of rooms per dwelling (F)]

from sklearn.model_selection import train_test_split

x = pd.DataFrame(boston['F'])   # feature: average number of rooms per dwelling
y = pd.DataFrame(boston['N'])   # target: median value of home

# 80/20 train/test split; random_state fixes the split so results are reproducible
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.20, random_state=1)
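
With test_size=0.20, 80% of the rows go to training and 20% to testing; the outputs below (16 samples at the tree's root, 4 test predictions) indicate a 20-row data set split into 16 training and 4 test rows. A quick check, not part of the original code:

print(x_train.shape, x_test.shape)   # expected here: (16, 1) (4, 1)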

from sklearn.tree import DecisionTreeRegressor

# Note: criterion='mse' was renamed to 'squared_error' in scikit-learn 1.0
# and removed in 1.2; on older versions use criterion='mse' instead
regressor = DecisionTreeRegressor(criterion='squared_error', random_state=100,
                                  max_depth=4, min_samples_leaf=1)
regressor.fit(x_train, y_train)


DecisionTreeRegressor(max_depth=4, random_state=100)
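
A decision tree regressor predicts a constant value within each leaf, so its fit is a step function. As a minimal sketch (not part of the original post), reusing the variables defined above, the predictions can be plotted over a fine grid of room values:

# Plot the piecewise-constant fit of the tree over the range of F
x_grid = pd.DataFrame(np.arange(x['F'].min(), x['F'].max(), 0.01), columns=['F'])
plt.scatter(x_train, y_train, color='brown')                # training points
plt.plot(x_grid, regressor.predict(x_grid), color='blue')   # step-function fit
plt.xlabel('Average number of rooms per dwelling')
plt.ylabel('Median Value of Home')
plt.show()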


from sklearn.tree import export_graphviz

# Export the fitted tree in Graphviz dot format, then predict on the test set
export_graphviz(regressor, out_file='reg_tree.dot')
y_pred = regressor.predict(x_test)
print("Y Predicted Value", y_pred)
print("Y Actual Value", y_test)


Y Predicted Value [36.2   20.6   18.725 21.6  ]
Y Actual Value        N
3   33.4
16  23.1
6   22.9
10  15.0

from sklearn.metrics import mean_squared_error

# mean_squared_error takes (y_true, y_pred); RMSE is its square root
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
print("Root Mean squared Error", rmse)

Root Mean squared Error 4.332453837030466
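
RMSE is in the same units as the target, so here the predictions are off by roughly 4.3 units of median home value. As a sketch (not in the original post), scikit-learn also provides complementary metrics that can be computed on the same predictions:

from sklearn.metrics import mean_absolute_error, r2_score

print("Mean Absolute Error", mean_absolute_error(y_test, y_pred))
print("R squared", r2_score(y_test, y_pred))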


By pasting the contents of the exported dot file (produced by the command export_graphviz(regressor, out_file='reg_tree.dot')) into a Graphviz viewer such as WebGraphviz (http://www.webgraphviz.com/), we can draw the decision tree.
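
Alternatively, the pydotplus package installed earlier can render the dot file locally, assuming the Graphviz binaries are also installed on the system:

import pydotplus

# Read the exported dot file and write the tree out as a PNG image
graph = pydotplus.graph_from_dot_file('reg_tree.dot')
graph.write_png('reg_tree.png')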

The data from the dot file is as follows:

digraph Tree {
node [shape=box] ;
0 [label="X[0] <= 6.861\nmse = 39.411\nsamples = 16\nvalue = 23.8"] ;
1 [label="X[0] <= 6.134\nmse = 11.479\nsamples = 13\nvalue = 21.169"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[0] <= 5.733\nmse = 2.008\nsamples = 9\nvalue = 19.311"] ;
1 -> 2 ;
3 [label="mse = 0.0\nsamples = 1\nvalue = 16.5"] ;
2 -> 3 ;
4 [label="X[0] <= 5.977\nmse = 1.147\nsamples = 8\nvalue = 19.662"] ;
2 -> 4 ;
5 [label="mse = 0.445\nsamples = 4\nvalue = 20.6"] ;
4 -> 5 ;
6 [label="mse = 0.092\nsamples = 4\nvalue = 18.725"] ;
4 -> 6 ;
7 [label="X[0] <= 6.296\nmse = 7.543\nsamples = 4\nvalue = 25.35"] ;
1 -> 7 ;
8 [label="mse = 0.0\nsamples = 1\nvalue = 27.1"] ;
7 -> 8 ;
9 [label="X[0] <= 6.425\nmse = 8.696\nsamples = 3\nvalue = 24.767"] ;
7 -> 9 ;
10 [label="mse = 0.0\nsamples = 1\nvalue = 21.6"] ;
9 -> 10 ;
11 [label="mse = 5.522\nsamples = 2\nvalue = 26.35"] ;
9 -> 11 ;
12 [label="X[0] <= 7.166\nmse = 0.5\nsamples = 3\nvalue = 35.2"] ;
0 -> 12 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
13 [label="mse = 0.0\nsamples = 1\nvalue = 36.2"] ;
12 -> 13 ;
14 [label="mse = 0.0\nsamples = 2\nvalue = 34.7"] ;
12 -> 14 ;
}
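
For scikit-learn 0.21 and later, the fitted tree can also be drawn directly with sklearn.tree.plot_tree, with no Graphviz dependency; a minimal sketch reusing the fitted regressor from above:

from sklearn.tree import plot_tree

plt.figure(figsize=(16, 8))
plot_tree(regressor, feature_names=['F'], filled=True)   # draw the fitted tree
plt.show()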

Acknowledgement

Mrs. Divya D, Research Scholar, Division of IT, School of Engineering, Cochin University of Science and Technology.
