Przejdź do treści

JAX

logo JAX

Wprowadzenie

JAX to stworzona przez firmę Google biblioteka do akceleracji obliczeń numerycznych i tworzenia sieci neuronowych dla języka Python. JAX nadaje się jako osobny framework do tworzenia nowych, autorskich rozwiązań z wykorzystaniem uczenia głębokiego (deep learning), głównie w celach badawczych przez zaawansowanych użytkowników. Może być także wykorzystany jako element wspomagający pracę innych frameworków jak Tensorflow i Pytorch. Dodatkowo, stanowi wartościową bibliotekę do wykonywania innych obliczeń naukowych, niekoniecznie związanych z sieciami neuronowymi.

Dostępność

Została napisana w języku Python i C++. Twórcy byli nastawieni głównie na wykorzystanie akceleratorów, stąd JAX nie jest dobrze zoptymalizowany do pracy na CPU. Niektóre operacje mogą być wręcz wolniejsze niż w tradycyjnym NumPy. JAX umożliwia trening rozproszony na wielu akceleratorach i/lub węzłach. Zaletę stanowi prostota implementacji treningu rozproszonego, wymagająca tylko nieznacznych zmian w kodzie.

Szczegóły

W odróżnieniu od innych bibliotek, JAX jest oparty głównie na paradygmacie programowania funkcyjnego. Jego stosowanie jest wymagane, aby w pełni wykorzystać możliwości tej biblioteki.

Głównymi funkcjonalnościami dostarczanymi przez JAX w celu przyspieszenia obliczeń są: automatyczne różniczkowanie (autograd) dla dowolnych funkcji, automatyczna wektoryzacja i dystrybucja danych na wiele akceleratorów, kompilacja JIT (just-in-time), XLA (accelerated linear algebra).

JAX oferuje głównie funkcjonalności niskopoziomowe. Nie posiada wysokopoziomowo zdefiniowanych wielu funkcjonalności takich jak konkretne implementacje warstw sieci podobnych do tych spotykanych w Keras czy Pytorch. JAX wymaga własnego tworzenia niektórych rozwiązań, ze względu na mniejszą ilość zaimplementowanych gotowych funkcjonalności. Z jednej strony to komplikuje i wydłuża pracę nad kodem, z drugiej zaś pozwala na znaczną elastyczność.

Twórcy starali się zachować kompatybilność API z operacjami dostępnymi w NumPy. Jednak zachowanie niektórych funkcji różni się w porównaniu do NumPy. Największą zaletą JAX jest wydajność – framework ten pozwala na szybsze działanie modeli niż Tensorflow (nieznacznie) i Pytorch (znacznie).

Informacje o wydaniu

Obecna wersja to 0.4.6, wydana w marcu 2023 roku. Kod źródłowy jest dostępny publicznie. Nowe wersje pojawiają się raz w roku, z rozszerzeniami pojawiającymi się kilka razy w roku. Warto zaznaczyć, że JAX jest młodą biblioteką, nadal określaną przez twórców jako eksperymentalną. Gdy twórcy wycofują funkcjonalność, jest ona wspierana przez minimum trzy miesiące od ogłoszenia zmiany. Po tym okresie może zostać w dowolnym momencie usunięta z kolejnych wersji.

Linki


Ostatnia aktualizacja: 29 lipca 2024