ELI5 i czarne pudełka

Gdy tworzymy model, któremu jest bardzo blisko do typowego białego pudełka (np. drzewo decyzyjne), to bardzo łatwo jest nam zorientować się, które cechy naszych obserwacji są istotne. Nieco inaczej sprawa ma się w modelach, które bardziej przypominają czarne pudełka. Modele takie, mimo iż niczego nie ukrywają, nie są w stanie wskazać nam które informacje były dla nich najważniejsze. Dostajemy od nich odpowiedź niejako z komentarzem „bo tak”.

Czy możemy coś sensownego z tym zrobić? Okazuje się, że tak. Nawet jeżeli funkcja modelująca nie będzie chciała (mogła) nam podać feature_importances_, sami będziemy mogli wyznaczyć najbardziej przydatne kolumny. Użyjemy do tego algorytmu permutation importance.

Permutation Importance

Permutation importance to algorytm, który wskazuje nam ważność danej kolumny w ramce danych w problemie typu uczenie nadzorowane. W takim uczeniu maszynowym, mamy co najmniej jedną kolumnę zależną (zwaną również target), którą chcemy połączyć z innym kolumnami ramki danych. Łączymy ją po to, żeby później, gdy będziemy mieć wszystkie pozostałe dane móc ją przewidzieć.

W zależności od tego, jaką funkcję modelującą wybierzemy, nasz model będzie korzystał w różny sposób z poszczególnych kolumn, aby przewidzieć wartości w naszej kolumnie zależnej. Czasem będzie korzystał z wszystkich kolumn w różnych proporcjach, a czasem będzie silnie polegał tylko na małym ułamku z nich. Wszystko będzie zależało od typu danych i funkcji modelującej wraz z parametrami, które wybierzemy. Nam natomiast najczęściej zależy na odnalezieniu najbardziej przydatnych kolumn.

Skoro rozwiązanie naszego problemu to łączenie poszczególnych kolumn z kolumną zależną, możemy sprawdzić, jak skuteczne będą predykcje, jeśli zepsujemy poszczególne kolumny.

Przypomnijmy sobie sytuację z artykułu ELI5 i białe pudełka. Funkcją modelującą, którą tam użyłem, było drzewo decyzyjne, a dane dotyczyły nowotworów. Wytrenowałem tam drzewo decyzyjne i zastosowałem funkcję explain_weights w celu sprawdzenia które cechy są według mojego modelu najważniejsze. Okazało się, że składową predykcji, która miała wagę aż 70%, była kolumna mean concave points.

Wyliczanie ręczne

Okej, skoro ta kolumna jest „niby” taka ważna, to jeżeli ją przemieszamy, to powinno to znacznie zepsuć skuteczność naszego klasyfikatora. Pójdźmy o krok dalej i w pętli, przy każdym jej przejściu przemieszajmy jedną kolumnę i sprawdźmy wtedy dokładność klasyfikatora. W kodzie Pythonowym będzie to wyglądać mniej więcej tak:


from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
import eli5

base_accuracy = accuracy_score(y_true = y_test, y_pred = estimator_tree.predict(X_test))

base_accuracy

results = pd.DataFrame()
for column in X_test.columns:
    X_test_permuted = pd.DataFrame(X_test, copy = True)
    X_test_permuted[column] = shuffle(X_test_permuted[column], random_state = 42).values
    accuracy = accuracy_score(y_true = y_test, y_pred = estimator_tree.predict(X_test_permuted))
    accuracy_after_permutation = accuracy / base_accuracy
    results = results.append({"perm_column": column, "perm_decrease": 1-accuracy_after_permutation},
                             ignore_index=True)

eli5_df = eli5.format_as_dataframe(eli5.explain_weights(estimator_tree, feature_names = list(X_train.columns), 
                                                        top = None))

results_sorted = results.sort_values(by = "perm_decrease", ascending=False)
results_sorted.reset_index(inplace=True, drop=True)
results_sorted["eli5"] = eli5_df["feature"]
results_sorted["eli5_imp"] = eli5_df["weight"]

W tej pętli, zamiast wyliczać nową dokładność, wyliczamy, o ile zmniejszyła się względem bazowej. Ta kolumna, która będzie miała największy spadek, będzie dla nas najważniejsza. Wyniki, które uzyskaliśmy, ułożyłem właśnie według tego spadku i dokleiłem do nich wyniki wydobyte z explain_weights, dla łatwiejszego naocznego porównania. Wyglądają one tak:

perm_column perm_decrease eli5 eli5_imp
0 mean concave points 0.197080 mean concave points 0.708984
1 worst radius 0.080292 worst texture 0.117653
2 worst perimeter 0.043796 worst radius 0.060099
3 worst texture 0.043796 worst area 0.035168
4 worst area 0.029197 worst perimeter 0.029283
5 worst smoothness 0.007299 concave points error 0.017345
6 concave points error 0.007299 area error 0.013011
7 mean radius 0.000000 worst smoothness 0.010041
8 concavity error 0.000000 texture error 0.006833
9 worst symmetry 0.000000 smoothness error 0.001584
10 worst concave points 0.000000 radius error 0.000000
11 worst concavity 0.000000 mean smoothness 0.000000
12 worst compactness 0.000000 mean compactness 0.000000
13 fractal dimension error 0.000000 mean area 0.000000
14 symmetry error 0.000000 mean concavity 0.000000
15 compactness error 0.000000 mean perimeter 0.000000
16 mean texture 0.000000 mean symmetry 0.000000
17 smoothness error 0.000000 mean texture 0.000000
18 area error 0.000000 mean fractal dimension 0.000000
19 perimeter error 0.000000 worst fractal dimension 0.000000
20 texture error 0.000000 perimeter error 0.000000
21 radius error 0.000000 worst symmetry 0.000000
22 mean fractal dimension 0.000000 compactness error 0.000000
23 mean symmetry 0.000000 concavity error 0.000000
24 mean concavity 0.000000 symmetry error 0.000000
25 mean compactness 0.000000 fractal dimension error 0.000000
26 mean smoothness 0.000000 worst compactness 0.000000
27 mean area 0.000000 worst concavity 0.000000
28 mean perimeter 0.000000 worst concave points 0.000000
29 worst fractal dimension 0.000000 mean radius 0.000000

Jak powinniśmy interpretować tę tabelę?

  • perm_column zawiera ułożone w kolejności nazwy kolumn, które w naszej pętli wyszły jako najbardziej istotne.
  • perm_decrease zawiera informację, o ile spadła dokładność predykcji. 1 byłoby wartością maksymalną, a 0 wartością minimalną.
  • eli5 to ułożone w kolejności nazwy kolumn, które moduł eli5 wybrał jako najważniejsze.
  • eli5_imp to ważność cechy według eli5. Wszystkie ważności powinny się zsumować do 1.

Z racji, że perm_decreaseeli5_imp są w innych skalach, nie za bardzo możemy porównywać te kolumny. Nie powinniśmy też tego robić, dlatego że, w tym wypadku eli5 wylicza tę ważność na danych treningowych, a my wyliczamy na danych testowych. To, co chciałbym porównać to ważność kolumn. Widzimy tutaj, że pierwsze 5 kolumn jest takie same, różnią się nieco kolejnością. Jest to całkiem dobra przesłanka, że gdybyśmy nie mieli feature_importances_, to moglibyśmy dojść do podobnych wniosków.

Wyliczanie przy pomocy eli5

Wiemy już więc, jak funkcjonuje permutation importance. Potrafimy też napisać kawałek kodu, który wyliczy nam tę wartość. Ale czy będziemy musieli zawsze ręcznie z tym walczyć? Okazuje się, że nie. Moduł eli5 ma również odpowiednią funkcję, która nam wyliczy tę wartość. Funkcja ta to (a jakże) PermutationImportance. Funkcja ta, oprócz tego, że dokona odpowiednich permutacji i wyliczy zmianę dokładności, powtórzy to kilka razy i poda nam jeszcze odchylenie standardowe. Mała rzecz, a cieszy 🙂

Jak to będzie wyglądać w naszym przypadku?


from sklearn.metrics import make_scorer
accuracy_scorer = make_scorer(accuracy_score)

from eli5.sklearn import PermutationImportance

perm = PermutationImportance(estimator_tree, random_state=42, scoring = accuracy_scorer)
perm.fit(X_test, y_test)
eli5_perm = eli5.explain_weights(perm, feature_names = list(X_test.columns), top = None)
eli5_perm_df = eli5.format_as_dataframe(eli5_perm)

results_sorted["eli5_perm"] = eli5_perm_df["feature"]
results_sorted["eli5_perm_decr"] = eli5_perm_df["weight"]

Po wyznaczeniu permutation importance od razu dokleiłem dwie kolumny do wcześniej uzyskanej ramki danych. Wygląda ona teraz tak:

perm_column perm_decrease eli5 eli5_imp eli5_perm eli5_perm_decr
0 mean concave points 0.197080 mean concave points 0.708984 mean concave points 0.278322
1 worst radius 0.080292 worst texture 0.117653 worst radius 0.127273
2 worst perimeter 0.043796 worst radius 0.060099 worst area 0.044755
3 worst texture 0.043796 worst area 0.035168 worst texture 0.036364
4 worst area 0.029197 worst perimeter 0.029283 worst perimeter 0.030769
5 worst smoothness 0.007299 concave points error 0.017345 area error 0.012587
6 concave points error 0.007299 area error 0.013011 worst smoothness 0.008392
7 mean radius 0.000000 worst smoothness 0.010041 concave points error 0.008392
8 concavity error 0.000000 texture error 0.006833 mean compactness 0.000000
9 worst symmetry 0.000000 smoothness error 0.001584 mean symmetry 0.000000
10 worst concave points 0.000000 radius error 0.000000 mean concavity 0.000000
11 worst concavity 0.000000 mean smoothness 0.000000 worst fractal dimension 0.000000
12 worst compactness 0.000000 mean compactness 0.000000 mean smoothness 0.000000
13 fractal dimension error 0.000000 mean area 0.000000 radius error 0.000000
14 symmetry error 0.000000 mean concavity 0.000000 mean area 0.000000
15 compactness error 0.000000 mean perimeter 0.000000 mean perimeter 0.000000
16 mean texture 0.000000 mean symmetry 0.000000 mean texture 0.000000
17 smoothness error 0.000000 mean texture 0.000000 mean fractal dimension 0.000000
18 area error 0.000000 mean fractal dimension 0.000000 smoothness error 0.000000
19 perimeter error 0.000000 worst fractal dimension 0.000000 texture error 0.000000
20 texture error 0.000000 perimeter error 0.000000 perimeter error 0.000000
21 radius error 0.000000 worst symmetry 0.000000 worst symmetry 0.000000
22 mean fractal dimension 0.000000 compactness error 0.000000 compactness error 0.000000
23 mean symmetry 0.000000 concavity error 0.000000 concavity error 0.000000
24 mean concavity 0.000000 symmetry error 0.000000 symmetry error 0.000000
25 mean compactness 0.000000 fractal dimension error 0.000000 fractal dimension error 0.000000
26 mean smoothness 0.000000 worst compactness 0.000000 worst compactness 0.000000
27 mean area 0.000000 worst concavity 0.000000 worst concavity 0.000000
28 mean perimeter 0.000000 worst concave points 0.000000 worst concave points 0.000000
29 worst fractal dimension 0.000000 mean radius 0.000000 mean radius 0.000000

Dwie dodatkowe kolumny:

  • eli5_perm to ułożone w kolejności nazwy kolumn, które moduł eli5 wyznaczył jako najważniejsze przy pomocy funkcji PermutationImportance.
  • eli5_perm_decr to spadek dokładności wyliczony przy pomocy funkcji PermutationImportance.

Mamy tutaj podobną sytuację jak wcześniej – pierwsze 5 kolumn jest takie samo, mamy jednak nieco inną kolejność. W tym wypadku możemy już porównywać kolumny eli5_perm_decr i perm_decrease, bo bazują na tych samych danych i oznaczają to samo.

Podsumowanie

Jak widać, możemy obyć się bez atrybutu feature_importances_, co jest szczególnie przydatne, jeśli funkcja modelująca nie dostarcza nam tych informacji. Może to być szczególnie przydatne, jeśli np. budujemy model, który jest złożeniem wielu innych modeli. No bo permutation importance wyliczamy już po treningu, bazując na danych testowych. Sądzę więc, że chociażby z tego powodu jest to fajne narzędzie przydatne w procesie feature engineering.

Pełny kod użyty w tym artykule znajduje się tutaj.

 

Dodaj komentarz

Twój adres email nie zostanie opublikowany. Pola, których wypełnienie jest wymagane, są oznaczone symbolem *