タグ

Pythonと最適輸送に関するxiangzeのブックマーク (2)

  • PyTorchでSliced Wasserstein Distance (SWD)を実装した

    PyTorchでSliced Wasserstein Distance (SWD)を実装してみました。オリジナルの実装はNumpyですが、これはPyTorchで実装しているので、GPU上で計算することができます。来はGANの生成画像を評価するためのものですが、画像の分布不一致を見るためにも使うことができます。 コード こちらのリポジトリにあります。 https://github.com/koshian2/swd-pytorch SWDとは PGGANの論文で使われている画像類似度の評価指標です。GANの評価指標の多く(Inception scoreやFID)が訓練済みInceptionモデルベースであるのに対し、SWDはInception依存ではありません。訓練済みInceptionモデルは大抵ImageNetベースなので、特徴量抽出がドメインによって得意だったり不得意だったりします。し

    PyTorchでSliced Wasserstein Distance (SWD)を実装した
  • Sliced Wasserstein GMM を実装してみた - yokaze.github.io

    最近話題の Sliced Wasserstein Distance (SWD) [Deshpande 2018, Deshpande 2019] を理解するため、 Kolouri らの Sliced Wasserstein Distance for Learning Gaussian Mixture Models (SWGMM) を実装しました。 以前の記事で Wasserstein 距離の解説 を書いたので、こちらも是非ご参照ください。 GitHub に実装を載せました。 あらすじ 従来、観測データと生成モデルの分布を比較するために Kullback-Leibler (KL) 損失が使われてきた。 しかし KL 損失を使うと勾配消失や局所解が発生するため、GAN などの高度なタスクでは上手く動かないことがある。 近年の研究から、Wasserstein 距離が勾配消失に強く、色々なタスクに

  • 1