Decision Tree

โดย ชิตพงษ์ กิตตินราดร | มกราคม 2563

Decision tree เป็น Algorithm ที่เป็นที่นิยม ใช้ง่าย เข้าใจง่าย ได้ผลดี และเป็นฐานของ Random Forest ซึ่งเป็นหนึ่งใน Algorithm ที่ดีที่สุดในปัจจุบัน

หลักการพยากรณ์ด้วย Decision tree นั้นเข้าใจง่ายมาก ให้นึกว่า Decision tree คือต้นไม้กลับหัว โดยบนสุดคือราก และส่วนล่างลงมาที่ไม่สามารถแตกไปไหนได้แล้วก็คือใบ เราจะเริ่มด้วยการพิจารณาเริ่มแรกบนจุดเริ่มต้นที่เรียกว่า Root node ถ้าข้อมูลที่พบเป็นไปตามเงื่อนไขนั้น การตัดสินใจก็จะวิ่งไปทางซ้ายของ Root node ไปที่จุดที่เรียกว่า Child node ซึ่งถ้าข้อมูลที่มาตามเส้นทางนี้ตรงตามเงื่อนไขของ Child node นี้ ก็จะถือว่าสิ้นสุด เราเรียกว่า Node สิ้นสุดว่า Leaf node

ย้อนกลับไปยัง Root node ถ้าข้อมูลที่พิจารณาไม่เป็นไปตามเงื่อนไข การตัดสินใจจะวิ่งไปอีกทาง คือทางขวา ไปพบ Child node อีกอันซึ่งก็จะตั้งเงื่อนไขคำถามต่อไป การตัดสินใจก็จะวิ่งไปทางที่ตรงตามเงื่อนไข ทำอย่างนี้ไปเรื่อยๆ จนได้คำตอบ

สำหรับกลไกการเทรนโมเดลเพื่อสร้าง Decision tree จะอธิบายด้านล่าง เมื่อได้อธิบายเรื่องค่า Gini

ก่อนจะไปไกลกว่านั้น จะชี้ให้เห็นว่า Decision tree นั้นไม่ต้องใช้ข้อมูลที่ทำ Feature scaling เพราะไม่ได้มี Optimisation algorithm แบบทั่วไป จึงใช้งานสะดวกมาก

เพื่อความเข้าใจมากขึ้น ลองมา Program ด้วย scikit-learn กันเลย โดยใช้ชุดข้อมูล Iris เหมือนเดิม เพื่อความเรียบง่ายจะเลือก Feature เฉพาะ Petal length และ Petal width:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz

# Load the iris data
iris = datasets.load_iris()
X = iris.data[:, 2:]
y = iris.target

# Split the data into train and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

แล้วก็ลองเทรนโมเดลกันเลย:

# Train and fit the model
tree_clf = DecisionTreeClassifier(max_depth=5).fit(X_train, y_train)

# Evaluate the model's accuracy
print("Train set accuracy = " + str(tree_clf.score(X_train, y_train)))
print("Test set accuracy = " + str(tree_clf.score(X_test, y_test)))

สังเกตว่ามี Argument max-depth นั่นคือตัวควบคุมว่าจะให้ต้นไม้ของเรามีความลึกกี่ขั้น ถ้าเราพบว่าโมเดลของเรา Overfit ก็ควรลองลดจำนวนชั้นของความลึกลง ดังนั้น max-depth จึงเป็น Regularisation hyperparameter ของ Decision trees

สำหรับความคลาดเคลื่อน พบว่าได้ความแม่นยำจาก Method .score ดังนี้:

Train set accuracy = 0.9821428571428571
Test set accuracy = 1.0

คือแม่นยำ 100% ทีเดียว

จากนั้นเรามาลองสร้างกราฟการตัดสินใจ โดยใช้ฟังก์ชัน export_graphviz ในโมดูล sklearn.tree เพื่อ Export กราฟออกมาด้วย Graphviz ซึ่งเป็น Open-source graph visualisation software:

# Export graph
export_graphviz(tree_clf, out_file="iris_tree.dot",
               feature_names=iris.feature_names[2:],
               class_names=iris.target_names,
               rounded=True, filled=True)

ฟังก์ชันนี้จะ Export กราฟออกมาเป็นไฟล์นามสกุล .dot ซึ่งเราต้องใช้คำสั่ง dot จาก Graphviz package เพื่อแปลงไฟล์เป็น .png โดยเรียกคำสั่งจาก Command line:

$ dot -Tpng iris_tree.dot -o iris_tree.png

กราฟที่ได้ออกมาเป็นดังนี้:

Iris dataset GraphViz output

ถ้าอ่านตามแผนภาพก็น่าจะพอเข้าใจคำอธิบายข้างต้นว่า Decision tree พยากรณ์อย่างไร อย่างไรก็ตาม ในแต่ละ Node มีรายละเอียดที่ควรรู้เพื่อเพิ่มความเข้าใจดังนี้:

เพื่อให้เป็นประโยชน์ในการทำความเข้าใจ Algorithm ในส่วนถัดไป เราสามารถบอกได้ว่าค่า gini คำนวนตามสูตรนี้:

โดย คือสัดส่วนว่าจากจำนวนรายการข้อมูลใน Node ที่ นั้นอยู่ใน Class กี่รายการ

มาถึงจุดนี้เราก็พร้อมที่จะมาทำความเข้าใจว่า Decision tree algorithm นั้นสร้างโมเดลได้อย่างไร โดย scikit-learn จะใช้ Algorithm ที่ชื่อ Classification And Regression Tree (CART) ซึ่งทำงานตามลำดับดังนี้:

1) แบ่ง Train set ออกเป็น 2 ส่วนโดยเลือก Class และเงื่อนไข เช่น petal length <= 2.45 โดยค้นหาคู่ของ และ ที่จะได้ Node ที่ "บริสุทธิ์" ที่สุด นั่นคือมีค่า Gini ต่ำที่สุดนั่นเอง เราสามารถแสดง Cost function ที่สอดคล้องกับเงื่อนไขนี้ได้ดังนี้:

2) แยกข้อมูลแต่ละชุดย่อยออกเป็นสองชุดและทำซ้ำข้อ 1) เรื่อยๆ จนกระทั่งถึงความลึก max-depth ที่กำหนด หรือจนกระทั่งไม่พบค่า และ ที่จะลดความไม่บริสุทธิ์ได้อีกต่อไป

ตอนนี้เราก็เข้าใจแล้วว่า Decision tree สร้างโมเดลอย่างไร สุดท้ายเรามาลองพล็อตเส้นแบ่งการตัดสินใจจากตัวอย่างของเรา:

# Plot the decision boundaries
def plot_decision_boundary(clf, X, y, cmap='Paired_r'):
    h = 0.005  # Boundary lines' resolution
    x_min, x_max = X[:,0].min() - 10*h, X[:,0].max() + 10*h
    y_min, y_max = X[:,1].min() - 10*h, X[:,1].max() + 10*h
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    plt.figure(figsize=(7,6))
    plt.contourf(xx, yy, Z, cmap=cmap, alpha=0.25)  # Background
    plt.contour(xx, yy, Z, colors='k', linewidths=0.2)  # Boundary lines
    plt.scatter(X[:,0], X[:,1], c=y, cmap=cmap);  # Data points
    plt.xlabel(iris.feature_names[2])
    plt.ylabel(iris.feature_names[3])

plot_decision_boundary(tree_clf, X, y)

ได้ผลคือ:

Iris decision trees decision boundary

จะเห็นว่า Decision tree สร้างเส้นแบ่งการตัดสินใจที่เป็นเส้น Orthogonal คือแนวระนาบและแนวดิ่งที่ตั้งฉากกันเท่านั้น ดังนั้นในบางกรณีจะเป็นไปไม่ได้เลยที่ Decision tree จะทำนาย Test set ได้แม่นยำ 100%

ข้อสังเกตนี้สอดคล้องกับธรรมชาติของ CART algorithm ที่ค้นหา Parameter ที่ดีที่สุดจากบนลงล่างโดยตัดสินใจจากเงื่อนไขที่พบในลำดับชั้นปัจจุบันเท่านั้น โดยไม่ได้เช็คว่าการตัดสินใจในชั้นนั้นจะส่งผลให้ลำดับชั้นล่างๆ ลงมามีค่าความไม่บริสุทธิ์น้อยที่สุดหรือไม่ เราเรียกพฤติกรรมแบบนี้ว่า Greedy algorithm ซึ่งพฤติกรรมแบบนี้ทำให้ต้องใช้เวลานานจนแทบเป็นอนันต์จึงจะหาต้นไม้ที่ดีที่สุดได้ โดยใช้เวลาถึง ในทางคณิตศาสตร์เรียกปัญหานี้ว่า ปัญหา NP-Complete ดังนั้นเมื่อเราใช้ Decision tree เราจึงต้องยอมรับผลลัพธ์ที่อาจจะไม่สมบูรณ์แบบ แต่ส่วนมากก็ดีพอสำหรับงานส่วนมาก

บทต่อไปเราจะเรียนรู้ Random Forest ซึ่งคือการนำต้นไม้หลายๆ ต้นมารวมกันเป็นป่า เพื่อเพิ่มความแม่นยำในการ Generalise โมเดลให้ทำงานได้ดีขึ้นกับข้อมูลที่โมเดลไม่เคยเห็น

หน้าแรก | บทที่ 8 Support Vector Machines | บทที่ 10 Random Forest

Creative Commons License
This work is licensed under a Creative Commons Attribution 4.0 International License.