背景 第1回: PyTorch to JAX 移行ガイド(MLP学習編) 第2回: PyTorch to JAX 移行ガイド(GPUでのCNN学習 | BatchNorm編) JAXベースのNNライブラリであるFlaxを用いて、PyTorchのコードをJAXに移行する方法を紹介しています。特に今回はGenerative Adversarial Networks (GAN) の学習を取り上げ、flaxにおける便利機能TrainStateのカスタマイズやJAXにおける乱数の扱いについて学びます。 例によってコードはgistにアップロードしてあります。 サンプル: Gaussian MixtureをGANで学習する 学習データの用意 今回は簡単な問題として、1次元、2つ山のあるGaaussian Mixtureを学習する問題を考えましょう import numpy as np from tens
![PyTorch to JAX 移行ガイド(GANの学習|TrainStateのカスタマイズ編)](https://cdn-ak-scissors.b.st-hatena.com/image/square/c0d36354d5c33524356c5598e1b8a01433a4233c/height=288;version=1;width=512/https%3A%2F%2Fres.cloudinary.com%2Fzenn%2Fimage%2Fupload%2Fs--Jqd2iA8L--%2Fc_fit%252Cg_north_west%252Cl_text%3Anotosansjp-medium.otf_55%3APyTorch%252520to%252520JAX%252520%2525E7%2525A7%2525BB%2525E8%2525A1%25258C%2525E3%252582%2525AC%2525E3%252582%2525A4%2525E3%252583%252589%2525EF%2525BC%252588GAN%2525E3%252581%2525AE%2525E5%2525AD%2525A6%2525E7%2525BF%252592%2525EF%2525BD%25259CTrainState%2525E3%252581%2525AE%2525E3%252582%2525AB%2525E3%252582%2525B9%2525E3%252582%2525BF%2525E3%252583%25259E%2525E3%252582%2525A4%2525E3%252582%2525BA%2525E7%2525B7%2525A8%2525EF%2525BC%252589%252Cw_1010%252Cx_90%252Cy_100%2Fg_south_west%252Cl_text%3Anotosansjp-medium.otf_37%3Ayonetaniryo%252Cx_203%252Cy_121%2Fg_south_west%252Ch_90%252Cl_fetch%3AaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tL2EtL0FPaDE0R2llTW5CU2h3U0RzZmV2VUIzU1VWOXRLTGdxeVV4aW5DVkhrRncyelE9czk2LWM%3D%252Cr_max%252Cw_90%252Cx_87%252Cy_95%2Fv1627283836%2Fdefault%2Fog-base-w1200-v2.png)