背景 第1回: PyTorch to JAX 移行ガイド(MLP学習編) 第2回: PyTorch to JAX 移行ガイド(GPUでのCNN学習 | BatchNorm編) JAXベースのNNライブラリであるFlaxを用いて、PyTorchのコードをJAXに移行する方法を紹介しています。特に今回はGenerative Adversarial Networks (GAN) の学習を取り上げ、flaxにおける便利機能TrainStateのカスタマイズやJAXにおける乱数の扱いについて学びます。 例によってコードはgistにアップロードしてあります。 サンプル: Gaussian MixtureをGANで学習する 学習データの用意 今回は簡単な問題として、1次元、2つ山のあるGaaussian Mixtureを学習する問題を考えましょう import numpy as np from tens