Jak wyhodować drzewo … decyzyjne?

Drzewo decyzyjne (ang. decision tree) to bardzo prosty algorytm używany w uczeniu nadzorowanym. Jest na tyle prosty (jak sam zobaczysz), że często jest pierwszym algorytmem używanym jako przykład w uczeniu maszynowym. W tym artykule chciałbym wytłumaczyć jego strukturę oraz pokazać jak możemy użyć drzewa decyzyjnego przy pomocy modułu Scikit-Learn.

Koncepcją, która stoi za drzewem decyzyjnym, jest zadawanie pytań, na które można uzyskać tylko odpowiedzi Tak lub Nie. I tak w zasadzie wygląda używanie zbudowanego drzewa decyzyjnego. Jak więc zbudować takie drzewo? Tutaj też sytuacja jest prosta, ale będzie wymagała wprowadzenia kilku pojęć.

Entropia

Zacznijmy od klasyfikacji. Najtrudniejszym elementem, który będzie nam potrzebny do zrozumienia czym jest i jak działa drzewo decyzyjne w klasyfikacji jest entropia. Tak, entropia która przewija się w fizyce, matematyce, teorii informacji i ogólnie w naukach technicznych. Czym jest więc ta entropia? W przypadku drzew decyzyjnych entropia będzie miarą wymieszania cechy zależnej. Czyli będzie nam mówić jak bardzo nasz zbiór jest nieczysty. Najłatwiej zrozumieć to na przykładzie dwóch klas — wtedy mamy do czynienia z entropią binarną.

Entropia binarna
Entropia binarna — Wikipedia

Na wykresie obok widzimy zależność entropii binarnej od prawdopodobieństwa natrafienia na obiekt klasy 1 w zbiorze. Jeśli prawdopodobieństwo to jest równe 0 lub jest bardzo bliskie 0 to entropia będzie też dość bliska 0. Gdy prawdopodobieństwo rośnie, entropia również będzie rosła, aż do momentu, gdy prawdopodobieństwo osiągnie poziom 0.5. Wtedy entropia osiągnie maksimum. Następnie prawdopodobieństwo będzie rosnąć, ale entropia zacznie maleć. Będzie tak maleć aż do 0, gdy prawdopodobieństwo napotkania obserwacji o klasie 1 będzie równe 1.

Bardzo przydatną „intuicją” będzie w tym przypadku wyobrażenie sobie entropii jako miary nieczystości właśnie. Gdy mamy zbiór gdzie są prawie same obserwacje klasy 0, możemy śmiało powiedzieć, że wylosowana obserwacja pewnie będzie należała do klasy 0. Z naszego punktu widzenia sytuacja jest więc prosta – nasz zbiór jest bardzo mało zabrudzony „śmieciami” z klasy 1. No i entropia takiego zbioru będzie mała. Gdy natomiast w naszym zbiorze jest tyle samo obiektów klasy 0 jak i klasy 1 to w zasadzie nie możemy za bardzo zgadywać na co natrafimy, więc nasz zbiór jest bardzo nieczysty albo wymieszany.

A jak to wygląda od strony technicznej? Entropię w teorii informacji opisujemy w ten sposób:

H(X) = -\sum_{i=1}^n p_i \log_2 p_i

Jeżeli powrócimy do naszego przypadku, gdzie mamy tylko dwie klasy, powyższy wzór stanie się następujący:

H(X) = -p \log p-(1-p) \log_2 (1-p)

Przykładowe dane

Skoro znamy już pojęcie entropii, to możemy przejść do przykładu jej wyznaczenia. Stwórzmy do tego celu zbiór danych. Wyobraźmy sobie, że prowadzimy wydawnictwo które w swoim katalogu ma magazyn o stylu życia. Każdemu subskrybentowi wysyłamy ankietę w której w trzech pytaniach pytamy się o jego gusta – książka czy telewizja, kino czy teatr oraz kawa czy herbata. Następnie po roku subskrypcji sprawdzamy, czy ją odnowił. Taki zbiór danych wyglądałby mniej więcej tak:

wolne napój wyjście po_roku
0 książka herbata teatr Tak
1 książka kawa teatr Nie
2 książka kawa kino Tak
3 książka herbata kino Tak
4 telewizja kawa teatr Nie
5 książka kawa kino Tak
6 telewizja herbata kino Nie
7 telewizja kawa teatr Tak
8 książka kawa kino Nie
9 książka herbata kino Tak
10 telewizja herbata kino Nie
11 telewizja herbata teatr Nie
12 telewizja kawa teatr Nie
13 książka kawa teatr Tak
14 telewizja herbata kino Nie
15 telewizja kawa kino Nie
16 telewizja herbata teatr Tak
17 książka herbata teatr Nie
18 książka herbata teatr Nie
19 telewizja kawa teatr Tak

Cecha która dla nas będzie najbardziej interesujące to po_roku i to od obliczenia entropii względem niej zaczniemy. W naszym przypadku entropia będzie więc wynosić:

H(po\_roku) = -p(Tak)\cdot \log p(Tak)-(1-p(Tak))\cdot \log_2 (1-p(Tak)) \approx -0.45\cdot -1.15200-0.55\cdot -0.862496 = 0.9927728

Uzyskaliśmy wynik 0.9927728 co informuje nas, że odpowiedzi Tak i Nie nie występują tyle samo razy (występowałyby gdybyśmy mieli tutaj 1), ale nie ma tutaj znaczącej przewagi ani jednej z odpowiedzi. I to się zgadza, bo Tak wystąpiło 9 razy a Nie 11 razy.

Podział na podzbiory

Chcielibyśmy przewidzieć czy subskrybent przedłuży subskrypcję już na podstawie wyników ankiety. W tym celu możemy się przyjrzeć wynikom historycznym które uzyskaliśmy i zastanowić się, czy jakieś pytanie które zadawaliśmy pomoże nam dokonać takiej predykcji (przewidzenia czy subskrybent pozostanie). Zastanówmy się, czy np. wybrany napój miał jakiś wpływ na subskrypcję. Podzielmy więc naszych klientów na herbaciarzy i kawoszy i sprawdźmy, czy entropia cechy po_roku zmieni się jakoś znacząco. Zacznijmy od herbaciarzy. W naszym zbiorze danych mamy 10 herbaciarzy którzy wyglądają tak:

wolne napój wyjście po_roku
0 książka herbata teatr Tak
3 książka herbata kino Tak
6 telewizja herbata kino Nie
9 książka herbata kino Tak
10 telewizja herbata kino Nie
11 telewizja herbata teatr Nie
14 telewizja herbata kino Nie
16 telewizja herbata teatr Tak
17 książka herbata teatr Nie
18 książka herbata teatr Nie

Widzimy tutaj 4 osoby które przedłużyły subskrypcję i 6 które się na to nie zdecydowały. Po podstawieniu tych wartości do wzoru na entropię otrzymujemy H_herbata = 0.9709505944546686, czyli lepiej. Spójrzmy również na kawoszy:

wolne napój wyjście po_roku
1 książka kawa teatr Nie
2 książka kawa kino Tak
4 telewizja kawa teatr Nie
5 książka kawa kino Tak
7 telewizja kawa teatr Tak
8 książka kawa kino Nie
12 telewizja kawa teatr Nie
13 książka kawa teatr Tak
15 telewizja kawa kino Nie
19 telewizja kawa teatr Tak

Tutaj sytuacja jest inna – mamy równowagę 5 do 5 więc nasza entropia będzie wynosić okrągłe 1 (H_kawa = 1). W tym wypadku mamy więc pogorszenie naszej sytuacji względem zbioru do którego wpadają wszyscy.

Potencjalny problem

Entropia mówi nam o czystości danego zbioru, ze względu na jakąś cechę. Mamy entropię w zbiorze podstawowym, ale możemy też wyliczyć entropię w podzbiorach. W części podzbiorów entropia się zmniejszy, a w części może się zwiększyć. Z czego to wynika i jak sobie z tym poradzić?

Wyobraźmy sobie, że w naszym zbiorze danych mielibyśmy cechę zagranica. Jedna osoba w naszym zbiorze miałaby tam Tak, a pozostałe Nie. Ta jedna osoba z zagranicy Nie przedłużyłaby subskrypcji po roku. Jeżeli dokonalibyśmy podziału w taki sposób to H_zagranica = 0 a H_lokalni = 0.9980008838722995. Uzyskaliśmy w ten sposób dwa zbiory – jeden z minimalną entropią, a drugi z prawie maksymalną. Sprawdźmy jak się mają wartości entropii w naszym przykładzie (przypominam, że badamy entropię ze względu na wartość po_roku):

H_całość = 0.9927744539878083

H_książka = 0.9709505944546686
H_telewizja = 0.8812908992306927

H_kawa = 1.0
H_herbata = 0.9709505944546686

H_kino = 0.9910760598382222
H_teatr = 0.9940302114769565

H_lokalni = 0.9980008838722995
H_zagranica = 0

W tym momencie możemy pokusić się o stwierdzenie, że skoro w jednym zbiorze udało nam się zmniejszyć entropię, to taki podział będzie bardzo dla nas przydatny, jeśli chcemy dokonywać predykcji. No tak, ale jeśli ktoś zapytałby się nas w jaki sposób dokonujemy predykcji to z dużą pewnością siebie powiedzielibyśmy, że jeśli ktoś jest z zagranicy to pewnie nie przedłuży, a w pozostałych przypadkach prawie nic nie wiemy (9 przedłużyło, 10 nie przedłużyło). No więc mimo iż udało nam się zmniejszyć entropię (pozytywny efekt) w jednym potencjalnym zbiorze to nie pomaga nam to w predykcji dla znacznej większości potencjalnych klientów (negatywny efekt). No i tutaj przychodzi nam na ratunek nowe pojęcie jakim jest przyrost informacji.

Przyrost informacji

Przyrost informacji (ang. information gain) mówi nam o tym – ile informacji udało nam się uzyskać poprzez podział na podzbiory. Wzór na przyrost informacji wykorzystuje wyliczoną entropię, ale uwzględnia także liczność utworzonych podzbiorów. Wylicza się go w następujący sposób:

Przyrost \; informacji(rodzic, dzieci) = entropia(rodzic) - [p(c_1) \cdot entropia (c_1) + p(c_2) \cdot entropia (c_2) + ...]

czyli w przypadku napojów będzie to:

Przyrost \; informacji(calosc,\; podzbiory\; po\; napojach) = H\_calosc - p(kawa) \cdot H\_kawa - p(herbata) \cdot H\_herbata
PI\_napoj = 0.9927744539878083 - 0.5 \cdot 1 - 0.5 \cdot 0.9709505944546686 = 0.00729915676047399

Podsumujmy więc przyrost informacji dla trzech poszczególnych pytań w ankiecie i lokalizacji subskrybenta:

PI_wolne = 0.06665370714512764
PI_napój = 0.00729915676047399
PI_wyjście = 0.00007361074828204917
PI_zagranica = 0.044673614309123866

Jako że interesuje nas tutaj uzyskana informacja ze względu na cechę którą chcielibyśmy przewidywać w przyszłości, interesuje nas wartość najwyższa. Jeśli więc mielibyśmy zadać jedno pytanie, to pytalibyśmy się o to jak nasi subskrybenci spędzają czas wolny. Jeśli wybierają opcję książka to przewidujemy, że na podstawie danych historycznych przedłużą subskrypcję (6 na 10 tak uczyniło), a jeśli w ankiecie zaznaczyli opcję telewizja to przewidujemy, że nie przedłużą subskrypcji (7 na 10 nie przedłużyło).

Wizualizacja procesu

W ten oto sposób dokonaliśmy pierwszego podziału. Możemy się zatrzymać w tym miejscu albo możemy kontynuować. Dobra wiadomość jest taka, że gdybyśmy chcieli kontynuować, to wystarczy, że potraktujemy każdy z dwóch nowych podzbiorów (mole książkowi, telemaniacy) jako zbiór bazowy, i będziemy je dzielić w dokładnie taki sam sposób. Nie będę tego tutaj opisywał, pokażę natomiast jak wygląda efekt takiego powtarzającego się budowania podzbiorów:

Drzewo decyzyjne
Drzewo decyzyjne powstałe na bazie omawianego zbioru danych.

Patrząc od góry, widzimy pole z czterema wierszami:

  1. wolne <= 0.5 mówi nam o kryterium podziału. Jako że zastosowałem bibliotekę (Scikit-Learn) która nie za bardzo chce działać z danymi tekstowymi („książka”, „telewizja”) musiałem je zamienić na cyfry. I tak książka dostała umowną wartość 0 a telewizja 1. Biblioteka ta przezornie (aczkolwiek w naszym przypadku zupełnie zbędnie) ustawiła kryterium podziału całkiem pośrodku tych dwóch wartości.
  2. entropy = 0.993 mówi o entropii w tym zbiorze. Każde pole reprezentuje jakiś (pod)zbiór i dla każdego z nich wyliczana jest entropia.
  3. samples = 20 mówi o liczności zbioru.
  4. value = [9, 11] mówi o obserwowanych klasach w danym zbiorze. 9 obserwacji należało tam do klasy Tak, a 11 do klasy Nie.
  5. class = Nie mówi nam o wyborze klasy, jeśli algorytm miałby decydować w tym miejscu.

Pole to reprezentuje więc nasz podstawowy zbiór. Z tego pola wychodzą dwie strzałki podpisane True i False. Strzałki te mówią gdzie powinniśmy się udać po uzyskaniu odpowiedzi na pytanie czy wolne <= 0.5. Jeśli mielibyśmy nowego subskrybenta to na bazie pierwszego pola od góry przyjęlibyśmy, że Nie przedłuży subskrypcji. Natomiast jeśli otrzymalibyśmy jego ankietę, to moglibyśmy spojrzeć na odpowiedź jak spędza czas wolny. Jeśli wpisał by książka (0) to odpowiedź była by True i wtedy moglibyśmy przejść do pola poniżej i po lewej gdzie wartość class = Tak. I jest to zgodne z tym co omówiliśmy powyżej. Moglibyśmy drążyć temat dalej w taki sam sposób aż doszlibyśmy do któregoś z pól końcowych, tzn. takich które już się dalej nie dzielą. Pola takie mogą się nie dzielić z trzech powodów:

  • pole to już jest czyste i nie ma jak go dalej dzielić – bardzo dobra sytuacja
  • pole to ma dokładnie takie same obserwacje (wyniki ankiet), ale zupełnie inne klasy (jeden przedłużył a drugi nie przedłużył) – sytuacja jest nierozwiązywalna i jest to traktowane jako nie dobra sytuacja
  • algorytm jest ograniczony hiperparametrami (o tym w innym artykule)

W powyższym przypadku mamy siedem takich pól, z których dwa są czyste, trzy nierozwiązywalne i dwa mieszane.

Drzewo

W ten właśnie sposób zbudowaliśmy model który jest nazywany drzewem decyzyjnym – klasyfikacyjnym. Dlatego klasyfikacyjnym, bo dokonujemy klasyfikacji naszych obserwacji. Alternatywą jest drzewo decyzyjne – regresyjne w którym wyznaczana jest konkretna wartość. W przypadku takiego drzewa wyznaczana jest redukcja wariancji zamiast przyrostu informacji.

Dlaczego w ogóle ten algorytm nazywany jest drzewami decyzyjnymi? Otóż jeśli weźmiemy powyższą wizualizację i odwrócimy ją „do góry” nogami to otrzymamy strukturę która wygląda jak drzewo, a pola końcowe to po prostu liście. Ot, cała geneza nazwy.

Python

Jak więc zabrać się za hodowlę takiego drzewa? Możemy zrobić to ręcznie wyznaczając poszczególne wartości (entropia, przyrost informacji) dla każdej z cech, albo możemy też użyć do tego celu wspomnianego modułu Scikit-Learn.

Zacznijmy więc od przygotowania danych:

import pandas as pd
wolne = ['książka', 'książka', 'książka', 'książka', 'telewizja', 'książka', 'telewizja', 'telewizja', 'książka', 
         'książka', 'telewizja', 'telewizja', 'telewizja', 'książka', 'telewizja', 'telewizja', 'telewizja', 'książka', 
         'książka', 'telewizja']
napój = ['herbata', 'kawa', 'kawa', 'herbata', 'kawa', 'kawa', 'herbata', 'kawa', 'kawa', 'herbata', 'herbata', 
         'herbata', 'kawa', 'kawa', 'herbata', 'kawa', 'herbata', 'herbata', 'herbata', 'kawa']
wyjście = ['teatr', 'teatr', 'kino', 'kino', 'teatr', 'kino', 'kino', 'teatr', 'kino', 'kino', 'kino', 'teatr', 'teatr',
           'teatr', 'kino', 'kino', 'teatr', 'teatr', 'teatr', 'teatr']
po_roku = ['Tak', 'Nie', 'Tak', 'Tak', 'Nie', 'Tak', 'Nie', 'Tak', 'Nie', 'Tak', 'Nie', 'Nie', 'Nie', 'Tak', 'Nie', 
           'Nie', 'Tak', 'Nie', 'Nie', 'Tak']
prenumeratorzy = pd.DataFrame({"wolne": wolne, "napój": napój, "wyjście": wyjście, "po_roku":po_roku})

I dostosowania ich do naszych potrzeb (zamiana napisów na liczby):

prenumeratorzy["wolne"], wolne_kody = pd.factorize(prenumeratorzy["wolne"])
prenumeratorzy["napój"], napój_kody = pd.factorize(prenumeratorzy["napój"])
prenumeratorzy["wyjście"], wyjście_kody = pd.factorize(prenumeratorzy["wyjście"])
prenumeratorzy["po_roku"], po_roku_kody = pd.factorize(prenumeratorzy["po_roku"])

Funkcje modelujące w Scikit-Learn potrzebują mieć na wejściu dwa parametry: X – cechy danych obserwacji i y – klasę (bądź wartość) którą będziemy chcieli przewidywać. Przygotujmy więc nasze dane w taki sposób:

X = prenumeratorzy.drop(["po_roku"], axis = 1)
y = prenumeratorzy["po_roku"]

I w tym momencie jesteśmy już gotowi do zbudowania drzewa decyzyjnego dla naszego przypadku:

from sklearn.tree import DecisionTreeClassifier
klasyfikator = DecisionTreeClassifier(criterion = "entropy")
klasyfikator.fit(X = X, y = y)

W tych kilkunastu kilku linijkach wytrenowaliśmy drzewo decyzyjne przy pomocy Scikit-Learn. Prawda, że wygląda to fajnie? Możemy jeszcze stworzyć wizualizację taką jaka jest prezentowana powyżej (wyświetli się, jeśli używać jupyter notebook):

import graphviz
from sklearn import tree
drzewo = tree.export_graphviz(klasyfikator, out_file=None, 
                         feature_names=X.columns,  
                         class_names=po_roku_kody,  
                         filled=True, rounded=True,  
                         special_characters=True)
graf = graphviz.Source(drzewo)
graf

Podsumowanie

Drzewa decyzyjne nie są najefektywniejszymi modelami w uczeniu maszynowym. Jednakże mają kilka zalet które powodują, że korzystanie z nich może mieć sens:

  • łatwe w zrozumieniu i interpretacji – artykuł który przeczytałeś zawiera wszystkie niezbędne informacje abyś mógł interpretować stworzone drzewa. To naprawdę jest takie proste.
  • możliwość używania danych ilościowych i jakościowych – jeśli mamy dane ilościowe (np. wzrost – 1.88, 1,90 itd.) to funkcja modelująca będzie szukać optymalnej wartości do podziału. W przypadku danych jakościowych (dobry, lepszy, najlepszy) wystarczy je odpowiednio zakodować.
  • dane nie wymagają dużych przygotowań – niektóre algorytmy wymagają odpowiednio wyskalowanych danych. Tutaj nie ma takiego problemu – dajemy to co mamy i będzie dobrze.
  • Otrzymany model to white-box – oznacza to, że jesteśmy w pełni w stanie wytłumaczyć, dlaczego model dokonał takiej predykcji – tzn. na jakich wartościach danych cech bazował.

Żeby nie było zbyt różowo, drzewa decyzyjne mają też kilka wad:

  • utworzone drzewa mogą być skomplikowane – przy dużej ilości obserwacji i cech drzewa które powstaną będą miały dużo poziomów i bardzo dużo liści.
  • są zmienne – dodanie albo odjęcie nawet jednej obserwacji może spowodować zbudowanie całkiem innego drzewa. Sytuacja taka nastąpi, gdy przyrost informacji bazujący na nowym zbiorze wskaże inną wartość albo inną cechę jako kryterium pierwszego podziału.
  • mają dość dużą złożoność obliczeniową – funkcja modelująca musi sprawdzić wszystkie możliwe wartości dla wszystkich cech.
  • domyślnie ( bez użycia parametru random_state) w Scikit-Learn mogą być trudne do odtworzenia – z racji powyżej wspomnianej złożoności obliczeniowej programiści Scikit-Learn zastosowali kilka sztuczek, żeby przyspieszyć proces. Bazują one na algorytmach wykorzystujących losowość. Przy pewnych kombinacjach danych ponownie wytrenowane drzewo może różnić się od swojego poprzednika, mimo iż sam algorytm jest w pełni precyzyjny.

W czasie tworzenia rozwiązań bazujących na modelach predykcyjnych warto rozważyć użycie drzew decyzyjnych – szczególnie na początku, gdy poznajemy dane. Z racji tego, że wszystko jest w nich widoczne – możemy się dowiedzieć wiele ciekawych rzeczy o naszym problemie. Przy poszukiwaniu najlepszego rozwiązania które ma wylądować „na produkcji”, drzewa nie będą aż tak pomocne, chociażby z tego powodu, że są prawie najsłabszym algorytmem modelującym. W każdym razie warto mieć gdzieś tę koncepcję z tyłu głowy.

Dodatek

Opisany powyżej przykład jest szerzej przeze mnie omówiony w kursie Wystartuj z Data Science w Pythonie! który jest dostępny na platformie Udemy. Pod tym linkiem jest dostępny notebook którego tam użyłem. Jeśli chciałbyś zakupić dostęp do tego kursu sprawdź tę stronę – może akurat opublikowałem tam kupony zniżkowe 🙂

Dodaj komentarz

Twój adres email nie zostanie opublikowany. Pola, których wypełnienie jest wymagane, są oznaczone symbolem *