Scientia Conditorium

[책리뷰] JAX/Flax로 딥러닝 레벨업 본문

서평/IT-책

[책리뷰] JAX/Flax로 딥러닝 레벨업

크썸 2024. 9. 21. 20:40

[책리뷰] JAX/Flax로 딥러닝 레벨업

 

 

전반적인 소감 및 마음에 드는 부분

흔히들 딥러닝 프레임워크 하면 Tensorflow나 PyTorch를 떠올린다. 인공지능 개발 경쟁이 한창인 요즘 JAX는 구글 딥마인드에서 개발, 유지관리하고 사용하는 또 다른 딥러닝 프레임워크다. 텐서플로우나 파이토치를 알려주는 책은 많지만 JAX 를 알려주는 책이 24년 9월 기준 이 책이 유일하다. 이 책에서도 쓸데없는 파이썬 기초 프로그래밍을 알려주는 페이지 없이 JAX/Flax 에 대해서만 다룬다. 책의 머릿말에도 나오듯이 JAX/Flax LAB은 '모두의 연구소'에서 활동하는 AI 리서처와 현업자로 구성되어 있는데, 해당 구성원들이 집필한 책이다. 아무리 공식 문서를 보고 공부하는 게 좋다고 하더라도 이렇게 번역하고 한국어로 책이 나온 점은 대단히 감사하다.

 

 

대상 독자 및 책 난이도

JAX에 관심 있는 모든 개발자와 학생들을 대상으로 한다. 특히 파이썬을 통한 고성능 계산과 머신러닝에 익숙하거나 관심이 있는 분들도 대상 독자다. 이 책을 보기 위해 적어도 파이썬 프로그래밍 기초 지식이 있어야 하며, 선형대수 기본 개념과 머신러닝/딥러닝 기초 이론이 필요하다. 여기에 함수형 프로그래밍이 대략 어떤건지 간단하게나마 알고 있으면 더욱 좋다. 개인적으로는 텐서플로우와 파이토치로 딥러닝을 공부해본 개발자를 대상으로 한 느낌이다.

 

텐서플로우나 파이토치로 딥러닝 공부를 해본 경험이 있다면 무리없이 따라할 수 있다. 파이썬 라이브러리 import 할 때, JAX로 대체하는 것이 대부분이라고 보면 될 것이다. 그러나 처음부터 JAX로 시작한다고 하면 어려울 수 있다.

 

다루는 내용과 범위

JAX/Flax 프레임워크에 대해서만 다룬다. 책의 목차는 다음과 같다.

  1. JAX/Flax 공부하기 전에
    1. JAX/Flax 소개와 예시
  2. JAX 특징
    1. Numpy에서부터 JAX 시작하기
    2. JAX의 JIT 컴파일
    3. 자동 벡터화/자동 미분
    4. JAX의 난수
    5. pytree 사용하기
    6. JAX에서의 병렬처리
    7. 상태를 유지하는 연산
  3. Flax 소개
    1. Flax CNN 튜토리얼
    2. 심화 튜토리얼
  4. JAX/Flax 활용한 딥러닝 모델 만들기
    1. 순수 JAX로 구현한 CNN
    2. ResNet
    3. DCGAN
    4. CLIP
    5. DistilGPT2 미세조정 학습
  5. TPU 환경 설정

 

 

결론

딥러닝 모델 학습에 막대한 리소스가 필요한 대규모 모델 연구에 깊이 관여 중이라면 JAX 를 고려해볼만 하다. 따라서 대규모 모델 연구 즉, LLM 같은 걸 연구하는 사람들에게 추천한다. 지난 약 1년 동안 나온 논문과 릴리스를 살펴보면 구글에서 많은 연구가 JAX로 옮겨갔다는 사실을 알 수 있다고 한다. 모든 것을 충족시키는 만능 프레임워크는 없다. 어떤 문제를 해결할 것인가, 배포할 모델의 크기, 플랫폼 등등에 따라 달라진다. 조그마한 모델을 학습시키기 위해 JAX를 고려할 필요는 없다는 의미이다. 그러나 대규모 연산의 확장성을 염두에 둔다면 JAX도 고려해볼만 하다.