Przejdź do treści

Pipeline w SciKit Learn – własny estymator

Dzisiaj napiszemy własny (a raczej poznamy mechanikę działania) oraz nauczymy się szukać najlepszych hyper parametrów dla modelu (właściwie: całego pipeline) w zwarty sposób.

W pierszej części zobaczyliśmy jak zbudować pipeline dla danych i modeli. Dzięki temu dostaliśmy możliwość zmiany sposobu transformacji danych i zmiany modeli w ramach prostego, zwartego kodu. Ale co jeśli potrzebujemy jakiś transformator którego nie ma w pakietach?

Tak jak poprzednio – nasze działania oprzemy na Pythonie i pakiecie scikit learn. Podobne rozwiązania można znaleźć w R (jeśli tego szukasz – zainteresuj się Tidymodels).

Aby zobaczyć co się dzieje wewnątrz kolejnych budowanych przez nas metod zbudujemy sobie prosty zestaw danych. W dzisiejszym ćwiczeniu nie chodzi o znalezienie konkretnego modelu czy też najlepszego wyniku – dane mogą być więc dowolne, ważne żebyśmy widzieli na nich efekty działań naszego kodu.

Własny transformer/estymator

Budujemy pierwszy transformer. Tutaj przyda się podstawowa wiedza na temat programowania obiektowego – co to jest klasa, co to są metody tej klasy, co to jest dziedziczenie i jak to wygląda w Pythonie. Zakładam, że znasz te podstawy.

Nasz transformer potrzebuje dwóch metod: fit() oraz transform(). Wiemy jak działają transformery i modele w scikit-learn, prawda? Uczymy je na danych treningowych poprzez wywołanie metody fit() a potem stosujemy przekształcenie do danych treningowych poprzez transform() (często stosuje się też od razu uczenie i przekształcenie wywołując fit_transform()), zaś dane testowe (czy też nowe dane) traktujemy jedynie przez transfor().

Często modele/transformery mają jakieś parametry. Podaje się je podczas budowania klasy – czyli wywołuje się konstruktora klasy __init__() z odpowiednimi parametrami. Konstruktor "zapamiętuje" w ramach obiektu te parametry (w Pythonie jest to po prostu ustawienie wartości zmiennych self.cośtam dostępnych w ramach całego obiektu). Tak wygląda teoria, konkrety tłumaczy kod poniżej:

Mając zbudowaną klasę możemy sprawdzić jak ona się zachowuje. Najpierw "samodzielnie", a później upakujemy ją w pipeline.

Co się stanie jak utworzymy obiekt naszej klasy? Powinna wykonać się metoda __init__() – spróbujmy od razu podać parametry:

I zadziałało zgodnie z planem – wypisały się podane przez nas parametry, a fitted jest jeszcze nie zdefiniowany. Ale czy na pewno?

Zgadza się. Zatem wyuczmy (dofitujmy) nasz obiekt na danych przygotowanych wyżej:

Wszystko się zgadza. A co się stanie jak zrobimy transformację?

Poza wypisaniem jakichś parametrów obiektu dostaliśmy tablicę pięciu wartości równych temu co jest w fitted. I właśnie o to chodziło. Tutaj jest tablica długa na tyle na ile mamy rekordów w danych – musimy dostać listę, którą da się porównać z wartościami Y naszych danych.

Ale miało być o pipeline’ach – zatem użymy naszego transformera, a za klasyfikator weźmiemy najprostszy DummyClassifier() w dodatku skonfigurowany tak, aby zawsze odpowiadał wartością 0, nie ważne jakie są cechy konkretnej próbki (to ułatwi nam porównanie wyników).

Co się wydarzyło? A no tylko tyle, że wywołaliśmy konstruktora.

Co zrobi fitowanie całego pipeline’u?

Wykonało się najpierw fit() a potem transform().

Po fitowaniu możemy ocenić nasz model:

Widzimy że wykonał się transformer z już wyuczonymi wartościami fitted. Odpowiedzią w tym przypadku jest accuracy – mamy do czynienia z modelem klasyfikującym.

Oczywiście otrzymana wartość predykcji zależy od tego jak wylosował nam się w ramach fit parametr fitted, ale w tym konkretnym przypadku zawsze będziemy mieć 1/3 skuteczności (bo mamy 2 zera i 4 jedynki w danych, a DummyClassifier skonfigurowaliśmy tak, aby zawsze zwracał zero).

Do czego to wszystko może być potrzebne? Do kazdej sytuacji, w której nie mamy gotowego rozwiązania. Nie mamy (a przynajmniej nie kojarzę) transformatora który na przykład zamieni nam zapis liczb dziesiętnych z 12.345,67 na 12345.67 i zmieni otrzymaną wartość na float. Oczywiście przykłady można mnożyć – to jeden z najprostszych.

Szukanie najlepszych hyper parametrów

No dobrze – wiemy jak przygotować swoje transformatory danych (spróbuj analogicznie przygotować estymatory!), a nawet dać im możliwość kręcenia śrubkami w postaci hyper parametrów. Ale jak znaleźć najlepszą kombinację hyper parametrów?

Użyjemy przeszukiwania po siatce wszystkich parametrów. SciKit Learn ma to na dzień dobry:

Do tego ćwiczenia przygotujemy inną wersję naszego transformatora – takiego, który podczas fittowania niczego nie robi, ale podczas transformacji już coś się dzieje (i jest to zależne od hyper parametrów):

Podobnie jak już wiele razy wcześniej – budujemy pipeline i siatkę parametrów do przeszukania. Dla naszego przykładu nie będziemy zmieniać parametrów klasyfikatora (i tak jak poprzednio zawsze dostaniemy – tym razem – jedynkę) – dzięki temu zobaczymy jakie są kolejne przebiegi po siatce.

No to szukamy. Przy okazji (parametr cv) włączamy walidację krzyżową (cross validation) – dla każdej kombinacji parametrów podzielimy zbiór na dwa foldy.

Każdy przebieg to jeden wiersz w powyższym listingu. I w wierszu tym widzimy które z elementów służyły jako dane treningowe oraz jakie hyper parametry zostały podane na "rurociąg". Nie widzimy wyniku modelu dla takich parametrów (za moment zobaczymy), ale możemy szybko znaleźć najlepszy wynik:

Wszytkie wyniki oczywiście też. Ale uwaga – są to wartości dla danej kombinacji parametrów (odpowiednio uśrednione) a nie konkretnego przebiegu w ramach puli parametrów. To nawet lepiej – mamy wynik bardziej stabilny (uwzględniający cross validation):

Dość długa i nieczytelna ta lista, najprościej uzyskać zestaw najlepszych hyper parametrów przez

Wytrenowany obiekt typu GridSearch() jest tym samym co wytrenowany pipeline czy też gotowy estymator, tak więc już na nim możemy użyć metod do predykcji (oczywiście skorzysta wtedy z najlepszych hyper parametrów):

Wydawać się może dziwne, że wszystkie kombinacje parametrów dały 66% skuteczności. Ale spójrz na dobór próbek do foldów – mamy zestaw T1, T2 i T3 i drugi zestaw to T3, T4, T5. Patrząc na tabelę z danymi:

widać, że w tych kombinacjach zawsze mamy dwie jedynki i jedno zero (a DummyClassifier zawsze zwraca 1) więc zawsze w 2/3 "pasuje".

Czy przeszukiwanie całej siatki hyper parametrów to najlepszy sposób? Na to pytanie odpowiemy sobie w części kolejnej.

Dodaj komentarz

Twój adres e-mail nie zostanie opublikowany. Wymagane pola są oznaczone *