Consider the data set given below:
Python Code
Install the following package:
pip install pydotplus
To conduct Decision Tree Regression, the following Python code will be useful:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
boston = pd.read_csv(r'C:\Users\user\Desktop\Data2.csv') # please use your source file location
boston
plt.scatter(x=boston['F'], y=boston['N'], color='brown')
plt.xlabel('Average number of rooms per dwelling')
plt.ylabel('Median Value of Home')
plt.show()
from sklearn.model_selection import train_test_split
x = pd.DataFrame(boston['F'])
y = pd.DataFrame(boston['N'])
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.20, random_state=1)
from sklearn.tree import DecisionTreeRegressor
# Note: scikit-learn 1.0 and later name this criterion 'squared_error' instead of 'mse'
regressor = DecisionTreeRegressor(criterion='mse', random_state=100, max_depth=4, min_samples_leaf=1)
regressor.fit(x_train, y_train)
DecisionTreeRegressor(max_depth=4, random_state=100)
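If you want a quick look at the fitted splits before exporting the dot file, sklearn.tree.export_text (available in scikit-learn 0.21 and later) prints the tree as indented text; a minimal sketch, assuming the single feature column is named 'F':
from sklearn.tree import export_text
# Print the fitted regression tree as plain-text rules
print(export_text(regressor, feature_names=['F']))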
from sklearn.tree import export_graphviz
import pydotplus
export_graphviz(regressor, out_file= 'reg_tree.dot')
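The pydotplus package installed earlier can also turn the exported dot file into an image locally, provided the Graphviz executables are installed on your machine; a minimal sketch (the output file name reg_tree.png is only an example):
import pydotplus
# Load the exported dot file and render it to a PNG image
graph = pydotplus.graph_from_dot_file('reg_tree.dot')
graph.write_png('reg_tree.png')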
y_pred = regressor.predict(x_test)
print("Y Predicted Value",y_pred)
print("Y Actual Value",y_test)
Y Predicted Value [36.2 20.6 18.725 21.6 ]
Y Actual Value N
3 33.4
16 23.1
6 22.9
10 15.0
from sklearn.metrics import mean_squared_error
mse = mean_squared_error(y_test, y_pred)  # argument order is (y_true, y_pred)
rmse = np.sqrt(mse)
print("Root Mean squared Error", rmse)
Root Mean squared Error 4.332453837030466
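Besides RMSE, other standard regression metrics from sklearn.metrics can be reported on the same test split; a short sketch (the exact values will depend on your data and split):
from sklearn.metrics import mean_absolute_error, r2_score
# Mean absolute error and coefficient of determination on the test set
print("Mean Absolute Error", mean_absolute_error(y_test, y_pred))
print("R squared", r2_score(y_test, y_pred))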
By pasting the contents of the dot file exported with the command export_graphviz(regressor, out_file='reg_tree.dot') into Graphviz (for example the online viewer at http://www.webgraphviz.com/), we can draw the decision tree.
The data from the dot file is as follows:
digraph Tree {
node [shape=box] ;
0 [label="X[0] <= 6.861\nmse = 39.411\nsamples = 16\nvalue = 23.8"] ;
1 [label="X[0] <= 6.134\nmse = 11.479\nsamples = 13\nvalue = 21.169"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[0] <= 5.733\nmse = 2.008\nsamples = 9\nvalue = 19.311"] ;
1 -> 2 ;
3 [label="mse = 0.0\nsamples = 1\nvalue = 16.5"] ;
2 -> 3 ;
4 [label="X[0] <= 5.977\nmse = 1.147\nsamples = 8\nvalue = 19.662"] ;
2 -> 4 ;
5 [label="mse = 0.445\nsamples = 4\nvalue = 20.6"] ;
4 -> 5 ;
6 [label="mse = 0.092\nsamples = 4\nvalue = 18.725"] ;
4 -> 6 ;
7 [label="X[0] <= 6.296\nmse = 7.543\nsamples = 4\nvalue = 25.35"] ;
1 -> 7 ;
8 [label="mse = 0.0\nsamples = 1\nvalue = 27.1"] ;
7 -> 8 ;
9 [label="X[0] <= 6.425\nmse = 8.696\nsamples = 3\nvalue = 24.767"] ;
7 -> 9 ;
10 [label="mse = 0.0\nsamples = 1\nvalue = 21.6"] ;
9 -> 10 ;
11 [label="mse = 5.522\nsamples = 2\nvalue = 26.35"] ;
9 -> 11 ;
12 [label="X[0] <= 7.166\nmse = 0.5\nsamples = 3\nvalue = 35.2"] ;
0 -> 12 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
13 [label="mse = 0.0\nsamples = 1\nvalue = 36.2"] ;
12 -> 13 ;
14 [label="mse = 0.0\nsamples = 2\nvalue = 34.7"] ;
12 -> 14 ;
}
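If Graphviz is not available, the same tree can also be drawn directly with matplotlib using sklearn.tree.plot_tree (scikit-learn 0.21 and later); a minimal sketch, again assuming the feature column is named 'F':
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# Draw the fitted regression tree without needing Graphviz
plt.figure(figsize=(12, 6))
plot_tree(regressor, feature_names=['F'], filled=True)
plt.show()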
Mrs. Divya D, Research Scholar, Division of IT, School of Engineering, Cochin University of Science and Technology.