並び順

ブックマーク数

期間指定

  • から
  • まで

1 - 30 件 / 30件

新着順 人気順

JAXの検索結果1 - 30 件 / 30件

タグ検索の該当結果が少ないため、タイトル検索結果を表示しています。

JAXに関するエントリは30件あります。 機械学習pythonPyTorch などが関連タグです。 人気エントリには 『JAXによるスケーラブルな機械学習 - ZOZO TECH BLOG』などがあります。
  • JAXによるスケーラブルな機械学習 - ZOZO TECH BLOG

    はじめに こんにちは、ZOZO NEXT ZOZO ResearchのSai Htaung Khamです。ZOZO NEXTは、ファッション領域におけるユーザーの課題を想像しテクノロジーの力で解決すること、より多くの人がファッションを楽しめる世界の創造を目指す企業です。 ZOZO NEXTでは多くのアルゴリズムを研究開発しており、その中でJAXというライブラリを使用しています。JAXは高性能な機械学習のために設計されたPythonのライブラリです。NumPyに似ていますが、より強力なライブラリであると考えることができます。NumPyとは異なり、JAXはマルチGPU、マルチTPU、そして機械学習の研究に非常に有用な自動微分(Autograd)をサポートしています。 JAXはNumPyのAPIのほとんどをミラーリングしているので、NumPyライブラリに慣れている人なら非常に導入しやすいです。A

      JAXによるスケーラブルな機械学習 - ZOZO TECH BLOG
    • Transformer モデルの仕組みを JAX/Flax で実装しながら解説してみる(パート1) - めもめも

      なんの話かと言うと 最近、大規模言語モデルを用いたチャットシステムがよく話題になりますが、言語モデルの性能が大きく向上するきっかけとなったのが、下記の論文で公表された「Transformer」のアーキテクチャーです。 arxiv.org ここでは、JAX/Flax を用いて Transformer を実装しながら、その仕組みを解説していきます。このパート1では、Embedding レイヤーを解説します。 JAX/Flax の使い方を学びたいという方は、こちらの書籍を参照してください。 JAX/Flaxで学ぶディープラーニングの仕組み 作者:中井悦司マイナビ出版Amazon Transformer の全体像 冒頭の論文では、Transformer Encoder と Transformer Decoder を組み合わせた下記のモデルが説明されています。 左側の Encoder でテキストを解

        Transformer モデルの仕組みを JAX/Flax で実装しながら解説してみる(パート1) - めもめも
      • JAX/Flaxを使ってMNISTを学習させてみる | TC3株式会社|GIG INNOVATED.

        本記事は20分程度でお読みいただけます。 こんにちは。TC3データサイエンス部門の梅本です。 普段はPyTorchを使っているのですが、新しいライブラリを触るのも勉強になると思いますので、今日は新進気鋭の深層学習ライブラリであるJAX/Flaxを使って、MNISTを学習させてみようと思います。 はじめに 皆さんご存知の通り、TensorFlow、Keras、PyTorch(Chainer…)と近年は様々な深層学習ライブラリが使われています。最近、JAXというライブラリが話題になっているものの、十分すぎるライブラリがある中でなぜJAXが新たに出てきたのでしょうか?(そしてなぜ使うべきなのか)。この理由には後発ライブラリの強みとして、先行ライブラリの問題点を改良しているという点が挙げられます。現状以下のような利点が挙げられます XLAコンパイルによる高速性 厳密な乱数の管理による再現性の担保

          JAX/Flaxを使ってMNISTを学習させてみる | TC3株式会社|GIG INNOVATED.
        • JAXとPyTorch、どっちが速いのか検証してみた - まったり勉強ノート

          高速化が趣味&仕事なので、最近よく目にするJAXの速度が気になってました。このため、今回は日ごろ使っているPyTorchと比較したので、その結果のまとめを紹介します。 結論 結果だけ知りたい方が多いだろうと思ったので先に結論から書くと、私のPyTorch力では力及ばず、今回の検証では JAXのほうがPyTorchの2.2倍速い という結果でした。ここから詳しく評価について説明します。 評価方法 今回、JAXとPyTorchを比較するにあたり、この前紹介したSmooth Smith Watermanのコードを利用しました。Smooth Smith Watermanについて知りたいという方は以下の記事をご覧ください。 この記事で紹介したJAXコードは論文の著者が頑張って高速化した結果なため、十分最適化された結果であるという認識です。このため、今回はPyTorchのコードを私が作成し、測定を行い

          • JAX入門~高速なNumPyとして使いこなすためのチュートリアル~ - Qiita

            TensorFlow Advent Calendar 2020 10日目の記事です。空いてたので当日飛び入りで参加しました。 この記事では、TensorFlowの関連ライブラリである「JAX」について初歩的な使い方、ハマりどころ、GPU・TPUでの使い方や、画像処理への応用について解説します。 JAXとは https://github.com/google/jax Google製のライブラリで、AutogradとXLAからなる、機械学習のための数値計算ライブラリ。簡単に言うと「自動微分に特化した、GPUやTPUに対応した高速なNumPy」。NumPyとほとんど同じ感覚で書くことができます。自動微分については解説が多いので、この記事では単なる高速なNumPyの部分を中心に書いていきます。 関連記事 JAX Quickstart JAXで始めるディープラーニング JAX : Tutorials

              JAX入門~高速なNumPyとして使いこなすためのチュートリアル~ - Qiita
            • JAXとFlaxを使って、ナウい機械学習をしたい

              JAXとFlaxの基本と、深層学習フレームワークの流れなど

                JAXとFlaxを使って、ナウい機械学習をしたい
              • 今こそはじめるJAX/Flax入門 Part 1

                1. はじめに 2012年から始まった深層学習の発展の過程で、さまざまな学習フレームワークが登場しました。中でもPyTorchとTensorflowは最も広く使われており、それぞれのフレームワークが支持されている背景には、柔軟性、拡張性、そして使いやすさがあります。 一方で、これらのフレームワークはその機能を拡張し続けてきた結果として、全体として非常に巨大で複雑なライブラリになっています。そのため、独自に機能拡張を行いたいユーザーにとっては扱いづらく、性能的にもオーバーヘッドを感じさせることがあります。 そこで新たに出てきたのが「JAX」とその関連ライブラリの組み合わせになります。2019年に登場して以降、特に海外の開発者に支持されてきました。近年注目されている大規模言語モデル(LLM)の分野においても、JAXによるモデルが公開されていることは珍しくなくなりつつあります。 PyTorch(

                  今こそはじめるJAX/Flax入門 Part 1
                • jaxのautogradをpytorchのautogradと比較、単回帰まで(速度比較追加) - HELLO CYBERNETICS

                  はじめに 使う関数 autograd with pytorch autograd with jax Jax で単回帰 はじめに PyTorchとjaxの比較用。この手のライブラリを使うには、autogradの使い方を理解することが一番最初の仕事だと思われます。そして、そのautogradの時点で大きく思想が異なっているので、メモしておきます。 使う関数 下記をインポートしている前提で import torch import jax import jax.numpy as np from jax import vmap, jit, grad 下記の2次元上のスカラー関数 $$ f(x, y) = x ^ 2 - y ^ 2 + 2 x y $$ を微分していきます。 def f(x, y): return x**2 - y**2 + 2*x*y と書いておきます。 autograd with

                    jaxのautogradをpytorchのautogradと比較、単回帰まで(速度比較追加) - HELLO CYBERNETICS
                  • GitHub - kingoflolz/mesh-transformer-jax: Model parallel transformers in JAX and Haiku

                    You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert

                      GitHub - kingoflolz/mesh-transformer-jax: Model parallel transformers in JAX and Haiku
                    • JAXライクなfunctorchで機械学習を速くする – part 1 – Rest Term

                      PyTorch 1.11からβ版として追加された functorch と呼ばれる機能を試してみました。PyTorch 1.9くらいのときから試験版として本体に組み込まれて提供されていましたが、どうやらfunctorchという別モジュールに切り出して提供されるようになったようです。 pytorch/functorch: functorch is JAX-like composable function transforms for PyTorch. functorchとは PyTorch公式サイトには以下のように説明されています。 functorch is a library that adds composable function transforms to PyTorch. It aims to provide composable vmap (vectorization) and a

                        JAXライクなfunctorchで機械学習を速くする – part 1 – Rest Term
                      • Jaxでガウス過程 + NumPyroでハイパーパラメータ推論 - HELLO CYBERNETICS

                        モジュール データ ガウス過程 カーネル関数 予測 決め打ちハイパーパラメータでの予測 MCMC でのハイパーパラメータ推論 モデル 事前分布からのサンプリング 事後分布の推論 予測分布 ガウス過程関連の記事 モジュール import jax.numpy as np import jax from jax import random, grad, vmap, jit, lax import matplotlib.pyplot as plt import seaborn as sns import numpyro import numpyro.distributions as dist from numpyro import plate, sample, handlers from numpyro.infer import MCMC, NUTS, SVI, ELBO plt.style.us

                          Jaxでガウス過程 + NumPyroでハイパーパラメータ推論 - HELLO CYBERNETICS
                        • PyTorch to JAX 移行ガイド(MLP学習編)

                          背景 「JAX最高」「GoogleではみんなJAXやってる」などと巷で言われているが、研の活動をやってると、比較手法がPyTorchで提供されていたり、ちょっと特殊な損失関数とかを使わないといけなかったり、あとはネットワーク魔改造をしたくなったりと、「とりあえずまずはPyTorchでやっとくか…」と思わせる要素がたくさんあり、PyTorchから抜け出せずにいた。 ムムッでもこれは2013年ごろを思い出す…その頃自分はとにかくMatlabで全部書いてて、なかなかPythonに移行出来ずにいた。そんななか「飯の種ネタをPythonで書き始めれば、Pythonできない→成果が出ない→死」なので自動的にPythonを習得できるのでは???と思い、えいやとPythonの海に飛び込んだのである。思えばPyTorchもDockerもそんな感じで飛び込んだが、今こそJAXに飛び込む時なのかもしれない。 移

                            PyTorch to JAX 移行ガイド(MLP学習編)
                          • 拡散生成モデルで学ぶJax/Flaxによる深層学習プログラミング

                            はじめに 深層学習モデルやその学習を実装する際には、多くの場合でPyTorchやKerasなどのフレームワークが使われます。本記事では、Googleより公開されているJaxというフレームワークを用いた深層学習プログラミングを紹介します。 コードは以下に配置しています。 Jaxとは JaxはGoogleから公開されている、自動微分を備えた数値計算ライブラリと言えます。Numpyとほぼ同じように計算処理を実装でき、またGPUやTPUによって高速に演算を実行することもできます。これによって深層学習モデルを実装し、学習することができます。またNumpyと近い使い方ができるので、やろうと思えば深層学習以外の多くのアルゴリズムを実装することもできます。 Jaxにはいくつかの派生ライブラリがあります。深層学習でよく利用されるような畳み込み層やバッチ正規化などはFlaxというライブラリで提供されており、本

                              拡散生成モデルで学ぶJax/Flaxによる深層学習プログラミング
                            • Engineering Trade-Offs in Automatic Differentiation: from TensorFlow and PyTorch to Jax and Julia - Stochastic Lifestyle

                              December 25 2021 in Julia, Programming, Science, Scientific ML | Tags: automatic differentiation, compilers, differentiable programming, jax, julia, machine learning, pytorch, tensorflow, XLA | Author: Christopher Rackauckas To understand the differences between automatic differentiation libraries, let’s talk about the engineering trade-offs that were made. I would personally say that none of thes

                                Engineering Trade-Offs in Automatic Differentiation: from TensorFlow and PyTorch to Jax and Julia - Stochastic Lifestyle
                              • Jax/FlaxでKaggleをやってみよう!

                                こんにちは。いのいちです。 この記事は(の18日目の記事ですこの記事はKaggle Advent Calendr 2022の18日目の記事です。 最近はスプラトゥーンにはまっていて毎日忙しくてあまりコンペに参加できてませんが、2023年はどんどんコンペに参加していきたいなと思っています。私はコンペに参加するときはいつも1つはこれまでやったことないことをやると決めています。そこで新しい挑戦としてJax/Flaxを使ってみようと思い至りました。私が普段参加するComputer Vison系のコンペでは主にPytorchが使用されており、TPUをぶん回すコンペでTensorflowが使われていたりします。Jax/Flaxのnotebookも時々見かけますが、まだベースラインとなるようなnotebookが共有されたり、がっつりJax/Flaxでコンペをやったというのは見たことがありません。しかし、

                                  Jax/FlaxでKaggleをやってみよう!
                                • Getting started with JAX (MLPs, CNNs & RNNs)

                                  Robert Tjarko Lange Evolutionary Meta-Learning @sakana.ai @TU Berlin JAX, Jax, JaX. Twitter seems to know nothing else nowadays (next to COVID-19). If you are like me and want to know what the newest hypetrain is about - welcome to todays blog post! I will walk you through some exciting CS concepts which were new to me (I am not a computer engineer, so this will be an educational experience for yo

                                    Getting started with JAX (MLPs, CNNs & RNNs)
                                  • 自動微分+XLA付き機械学習フレームワークJAXを使用してMNISTを学習させてみる - Morikatron Engineer Blog

                                    こんにちは、エンジニアの竹内です。 深層学習を行う際によく利用されるフレームワークといえばGoogleが開発しているTensorflowとFacebookが開発しているPytorchの2大巨頭に加えて、Kerasなどが挙げられるかと思いますが、今回はそのような選択肢の一つとしてGoogleが新しく開発している*1新進気鋭(?)の機械学習フレームワークJAXを紹介したいと思います。 github.com 今回JAXを紹介するきっかけですが、最近話題になったVision Transformerの公式実装のソースコードを読む際に、モデルの実装にJAXが使用されており、少し気になったので勉強がてら触ってみたというのが経緯です。 github.com ディープラーニングのフレームワークの入門といえばMNISTデータセットを使った画像分類ですので、今回はJAXの入門編としてシンプルな多層パーセプトロン

                                      自動微分+XLA付き機械学習フレームワークJAXを使用してMNISTを学習させてみる - Morikatron Engineer Blog
                                    • GitHub - google/flax: Flax is a neural network library for JAX that is designed for flexibility.

                                      Overview | Quick install | What does Flax look like? | Documentation This README is a very short intro. To learn everything you need to know about Flax, refer to our full documentation. Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community. Flax is bei

                                        GitHub - google/flax: Flax is a neural network library for JAX that is designed for flexibility.
                                      • Using JAX to accelerate our research

                                        Open source Using JAX to accelerate our research Published 4 December 2020 Authors David Budden, Matteo Hessel DeepMind engineers accelerate our research by building tools, scaling up algorithms, and creating challenging virtual and physical worlds for training and testing artificial intelligence (AI) systems. As part of this work, we constantly evaluate new machine learning libraries and framewor

                                          Using JAX to accelerate our research
                                        • JAXとPyTorchで勾配法とニュートン法を試す - HELLO CYBERNETICS

                                          はじめに 逐次更新による最適化 大枠 勾配法 数式 勾配法コード例 ニュートン法 数式 ニュートン法のコード例 はじめに 最近、しっかり学ぶ数理最適化を購入しました。 しっかり学ぶ数理最適化 モデルからアルゴリズムまで (KS情報科学専門書) 作者:梅谷 俊治発売日: 2020/10/26メディア: 単行本(ソフトカバー) 1章→3章と読んでいく中で、元々馴染みの深い連続最適化の極々基本的な手法である勾配法とニュートン法を試してみました。実装はJAXを使っています。こいつは現状、最高の自動微分ライブラリだと思っております(深層学習ライブラリという観点ではPyTorchの方が今の所使いやすい)。 普通、機械学習では二次微分なんてパラメータが多すぎてまともに計算できる見込みがないので、純粋なニュートン法なんて絶対に使わないのですが、その圧倒的な性能の高さを確認し、兎にも角にも勾配法の弱さを確認

                                            JAXとPyTorchで勾配法とニュートン法を試す - HELLO CYBERNETICS
                                          • GitHub - google/maxtext: A simple, performant and scalable Jax LLM!

                                            You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session.

                                              GitHub - google/maxtext: A simple, performant and scalable Jax LLM!
                                            • Jax, PyTorch 直線探索付き勾配法 - HELLO CYBERNETICS

                                              はじめに モジュールインポート Jax 勾配関数と線形探索関数を準備 最適化実行 PyTorch 線形探索関数準備 最適化実行 結果 はじめに 前回は下記の記事で学習率固定で勾配法を実施しました。 www.hellocybernetics.tech 今回はウルフ条件を満たすような学習率を各更新時にバックステップで探索し、満たすものを見つけたら直ちにその学習率の更新するという形式で勾配法を実施します。 この記事ではJaxとPyTorchで収束までのステップ数や収束先等の結果はほぼ一致しましたが、速度が圧倒的にJaxの方が速く、PyTorchの計算グラフが変なふうになってしまっている可能性があります(こんなPyTorch遅いわけがない…!) どなたか見つけたら教えて下さい…。 モジュールインポート import jax import jax.numpy as jnp from jax impo

                                                Jax, PyTorch 直線探索付き勾配法 - HELLO CYBERNETICS
                                              • JAX: High-Performance Array Computing — JAX documentation

                                                JAX: High-Performance Array Computing# JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. If you’re looking to train neural networks, use Flax and start with its documentation. Some associated tools are Optax and Orbax. For an end-to-end transformer library built on JAX,

                                                • Whisper JAX - a Hugging Face Space by sanchit-gandhi

                                                  Discover amazing ML apps made by the community

                                                    Whisper JAX - a Hugging Face Space by sanchit-gandhi
                                                  • GitHub - google-deepmind/mctx: Monte Carlo tree search in JAX

                                                    You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert

                                                      GitHub - google-deepmind/mctx: Monte Carlo tree search in JAX
                                                    • Exploring hyperparameter meta-loss landscapes with Jax

                                                      Exploring hyperparameter meta-loss landscapes with Jax Saturday. February 06, 2021 - 20 mins A common mantra of the deep learning community is to differentiate though all the things, e.g. differentiable renderer, differentiable physics, differentiable programming language[julia, dex, myia], etc. In my own research, I’ve found that, while one can often compute a gradient, it isn’t always the most u

                                                      • GitHub - NVIDIA-Merlin/dataloader: The merlin dataloader lets you rapidly load tabular data for training deep leaning models with TensorFlow, PyTorch or JAX

                                                        You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert

                                                          GitHub - NVIDIA-Merlin/dataloader: The merlin dataloader lets you rapidly load tabular data for training deep leaning models with TensorFlow, PyTorch or JAX
                                                        • JAXで始めるディープラーニング|ディープラーニングネイティブ

                                                          こんにちは。今回はGoogleの開発するディープラーニング向けライブラリ⭐️であるJAXの紹介をします。 https://github.com/google/jax ディープラーニング向けライブラリとしてはTensorflowやPyTorch、最近開発終了が宣言されたChainerなどが有名かと思います。これらは多次元配列と自動微分をサポートした計算ライブラリをコアとしていて、それにニューラルネットの実装を容易にするラッパーなどが付属しています。 GoogleといえばTensorflowが有名ですが、JAXはTensorflowとは何が違うのでしょうか。 JAXを一言で表現すると、高速なautogradです。 もう少し詳しくいうと、多次元配列の計算ライブラリであるnumpyに自動微分とJITがくっついたものです。さらに、GPUやTPUといったアクセラレーター上でも動作します。Tensorf

                                                            JAXで始めるディープラーニング|ディープラーニングネイティブ
                                                          • PyTorchとJAXに対応したKeras3でMNISTを試す|はまち

                                                            バックボーンのフレームワークを、従来のTensorFlowから、デファクトスタンダードになりつつあるPyTorchと、実行効率に優れたJAXも選べるようになったKeras3.0が公開されていたので、さっそくバックボーンをPyTorchやJAXに設定して、手書きアルファベット画像のクラス分け課題のMNISTを試してみました。 23.11.29追記 公式の紹介ページも公開されていました。 https://keras.io/keras_3/ Keras3のインストール、インポート今回はGoogle Colabで試してみます。Keras3は現時点ではPyPI上では、プレビューリリースとしてkeras-coreの名前でインストールできます。 !pip install keras-coreバックエンドの設定(torch, jax, tensorflow) import os os.environ["K

                                                              PyTorchとJAXに対応したKeras3でMNISTを試す|はまち
                                                            • Jax/Flax × TransformersでBERTのfine-tuningをTPUで行う | 株式会社AI Shift

                                                              こんにちは AIチームの戸田です 以前、BERTをfine-tuningする際のTipsとして混合精度の利用や、Uniform Length Batchingをつかった学習効率化を紹介させていただきましたが、今回はTPUを使った高速化について紹介したいと思います。 Flax TPU対応というと、まずGoogleのTensorflowが思い浮かびますが、今回は同じGoogleのニューラルネット学習用フレームワークのFlaxを使います。 FlaxはTensorflowと比較して簡潔に、かつ柔軟に書くことができると言われており、huggingfaceのtransformersもv4.8.0からFlaxをサポートするようになっています。 JAX/Flax has joined the ranks of PyTorch and TensorFlow in 🤗Transformers! Versio

                                                                Jax/Flax × TransformersでBERTのfine-tuningをTPUで行う | 株式会社AI Shift
                                                              1

                                                              新着記事