電通国際情報サービス、オープンイノベーションラボの比嘉康雄です。 Stable Diffusion(というよりdiffusers)でTPU(JAX / Flax)を使った並列実行バージョンがリリースされたので、早速試してみました。 オリジナルのNotebookはこちら。 僕が作ったNotebookはこちら。 今回は、TPUを使うので、Google Colabに特化しています。自分で1から試す方は、メニューのEdit -> Notebook settingsでTPUを使うように設定してください。 Stable Diffusionのおすすめコンテンツはこちら。 Waifu Diffusion 1.3.5_80000 v2.1 金髪美女写真 v2.1 美少女アニメ画 v2.1 AUTOMATIC1111 v2.0 美少女イラスト v1.5 美少女画検証 美少女アニメ画改善版 美少女を高確率で出す
[IT研修]注目キーワード Python UiPath(RPA) 最新技術動向 Microsoft Azure Docker Kubernetes 第140回 オリジナル論文から学ぶ「JAX」の特徴とその役割 (中井悦司) 2022年11月 はじめに 今回は、2018年に公開された論文「Compiling machine learning programs via high-level tracing」を元にして、機械学習ライブラリーJAXの役割と、その基礎となる考え方を紹介します。 JAXのオリジナル論文 最近、機械学習に関連するオープソースとして、JAXの名前を耳にすることが増えてきました。たとえば、2022年6月に公開されたGoogle Cloudの公式ブログ「EvoJAX: あなたの課題をNeuroevolutionの力で解く」の冒頭には、次のような一節があります。 『JAXはユー
はじめに 深層学習モデルやその学習を実装する際には、多くの場合でPyTorchやKerasなどのフレームワークが使われます。本記事では、Googleより公開されているJaxというフレームワークを用いた深層学習プログラミングを紹介します。 コードは以下に配置しています。 Jaxとは JaxはGoogleから公開されている、自動微分を備えた数値計算ライブラリと言えます。Numpyとほぼ同じように計算処理を実装でき、またGPUやTPUによって高速に演算を実行することもできます。これによって深層学習モデルを実装し、学習することができます。またNumpyと近い使い方ができるので、やろうと思えば深層学習以外の多くのアルゴリズムを実装することもできます。 Jaxにはいくつかの派生ライブラリがあります。深層学習でよく利用されるような畳み込み層やバッチ正規化などはFlaxというライブラリで提供されており、本
はじめに 逐次更新による最適化 大枠 勾配法 数式 勾配法コード例 ニュートン法 数式 ニュートン法のコード例 はじめに 最近、しっかり学ぶ数理最適化を購入しました。 しっかり学ぶ数理最適化 モデルからアルゴリズムまで (KS情報科学専門書) 作者:梅谷 俊治発売日: 2020/10/26メディア: 単行本(ソフトカバー) 1章→3章と読んでいく中で、元々馴染みの深い連続最適化の極々基本的な手法である勾配法とニュートン法を試してみました。実装はJAXを使っています。こいつは現状、最高の自動微分ライブラリだと思っております(深層学習ライブラリという観点ではPyTorchの方が今の所使いやすい)。 普通、機械学習では二次微分なんてパラメータが多すぎてまともに計算できる見込みがないので、純粋なニュートン法なんて絶対に使わないのですが、その圧倒的な性能の高さを確認し、兎にも角にも勾配法の弱さを確認
はじめに モジュールインポート Jax 勾配関数と線形探索関数を準備 最適化実行 PyTorch 線形探索関数準備 最適化実行 結果 はじめに 前回は下記の記事で学習率固定で勾配法を実施しました。 www.hellocybernetics.tech 今回はウルフ条件を満たすような学習率を各更新時にバックステップで探索し、満たすものを見つけたら直ちにその学習率の更新するという形式で勾配法を実施します。 この記事ではJaxとPyTorchで収束までのステップ数や収束先等の結果はほぼ一致しましたが、速度が圧倒的にJaxの方が速く、PyTorchの計算グラフが変なふうになってしまっている可能性があります(こんなPyTorch遅いわけがない…!) どなたか見つけたら教えて下さい…。 モジュールインポート import jax import jax.numpy as jnp from jax impo
December 25 2021 in Julia, Programming, Science, Scientific ML | Tags: automatic differentiation, compilers, differentiable programming, jax, julia, machine learning, pytorch, tensorflow, XLA | Author: Christopher Rackauckas To understand the differences between automatic differentiation libraries, let’s talk about the engineering trade-offs that were made. I would personally say that none of thes
背景 第1回: PyTorch to JAX 移行ガイド(MLP学習編) 第2回: PyTorch to JAX 移行ガイド(GPUでのCNN学習 | BatchNorm編) JAXベースのNNライブラリであるFlaxを用いて、PyTorchのコードをJAXに移行する方法を紹介しています。特に今回はGenerative Adversarial Networks (GAN) の学習を取り上げ、flaxにおける便利機能TrainStateのカスタマイズやJAXにおける乱数の扱いについて学びます。 例によってコードはgistにアップロードしてあります。 サンプル: Gaussian MixtureをGANで学習する 学習データの用意 今回は簡単な問題として、1次元、2つ山のあるGaaussian Mixtureを学習する問題を考えましょう import numpy as np from tens
はじめに 使う関数 autograd with pytorch autograd with jax Jax で単回帰 はじめに PyTorchとjaxの比較用。この手のライブラリを使うには、autogradの使い方を理解することが一番最初の仕事だと思われます。そして、そのautogradの時点で大きく思想が異なっているので、メモしておきます。 使う関数 下記をインポートしている前提で import torch import jax import jax.numpy as np from jax import vmap, jit, grad 下記の2次元上のスカラー関数 $$ f(x, y) = x ^ 2 - y ^ 2 + 2 x y $$ を微分していきます。 def f(x, y): return x**2 - y**2 + 2*x*y と書いておきます。 autograd with
リリース、障害情報などのサービスのお知らせ
最新の人気エントリーの配信
処理を実行中です
j次のブックマーク
k前のブックマーク
lあとで読む
eコメント一覧を開く
oページを開く