Optimal transport mapping via input convex neural networks

Ashok Vardhan Makkuva, Amirhossein Taghvaei, Jason D. Lee, Sewoong Oh

Research output: Chapter in Book/Report/Conference proceeding › Conference contribution

1 Scopus citation

Abstract

In this paper, we present a novel and principled approach to learn the optimal transport between two distributions, from samples. Guided by the optimal transport theory, we learn the optimal Kantorovich potential which induces the optimal transport map. This involves learning two convex functions, by solving a novel minimax optimization. Building upon recent advances in the field of input convex neural networks, we propose a new framework to estimate the optimal transport mapping as the gradient of a convex function that is trained via minimax optimization. Numerical experiments confirm the accuracy of the learned transport map. Our approach can be readily used to train a deep generative model. When trained between a simple distribution in the latent space and a target distribution, the learned optimal transport map acts as a deep generative model. Although scaling this to a large dataset is challenging, we demonstrate two important strengths over standard adversarial training: robustness and discontinuity. As we seek the optimal transport, the learned generative model provides the same mapping regardless of how we initialize the neural networks. Further, a gradient of a neural network can easily represent discontinuous mappings, unlike standard neural networks that are constrained to be continuous. This allows the learned transport map to match any target distribution with many discontinuous supports and achieve sharp boundaries.
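The abstract's point that the gradient of a convex function can be a discontinuous map can be seen in a one-dimensional toy case. The sketch below is an illustration only, not the paper's method: the potential `f(x) = 0.5*x**2 + |x|`, the function names, and the sample size are all assumptions chosen for clarity, and no neural network or minimax training is involved. The gradient `x + sign(x)` jumps at zero, so it pushes a standard Gaussian onto a target whose support is split into two disconnected pieces.

```python
import numpy as np

def potential(x):
    # Hypothetical convex potential (not from the paper): f(x) = 0.5*x^2 + |x|.
    return 0.5 * x**2 + np.abs(x)

def transport_map(x):
    # Gradient of the potential wherever it is differentiable (x != 0).
    # It jumps from -1 to +1 at x = 0, so the map is discontinuous there.
    return x + np.sign(x)

rng = np.random.default_rng(0)
source = rng.standard_normal(10_000)   # samples from a simple source distribution
target = transport_map(source)         # pushed-forward samples

# Spot-check convexity of the potential with the midpoint inequality.
a = rng.standard_normal(1_000)
b = rng.standard_normal(1_000)
midpoint_ok = bool(np.all(potential((a + b) / 2)
                          <= (potential(a) + potential(b)) / 2 + 1e-12))

# A gradient of a convex function is monotone: sorting the inputs sorts the outputs.
order = np.argsort(source)
is_monotone = bool(np.all(np.diff(target[order]) >= 0))

# No pushed-forward sample lands in the open interval (-1, 1):
# the target support is split into two disconnected pieces.
gap_is_empty = bool(np.all(np.abs(target) >= 1.0))
```

A continuous map could not produce this gap from a connected source such as a Gaussian, which is the limitation of standard generator networks that the abstract refers to.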

Original language: English (US)
Title of host publication: 37th International Conference on Machine Learning, ICML 2020
Editors: Hal Daume, Aarti Singh
Publisher: International Machine Learning Society (IMLS)
Pages: 6628-6637
Number of pages: 10
ISBN (Electronic): 9781713821120
State: Published - 2020
Event: 37th International Conference on Machine Learning, ICML 2020 - Virtual, Online
Duration: Jul 13 2020 to Jul 18 2020

Publication series

Name: 37th International Conference on Machine Learning, ICML 2020
Volume: PartF168147-9

Conference

Conference: 37th International Conference on Machine Learning, ICML 2020
City: Virtual, Online
Period: 7/13/20 to 7/18/20

All Science Journal Classification (ASJC) codes

  • Computational Theory and Mathematics
  • Human-Computer Interaction
  • Software
