Abstract
While deep generative models have succeeded in image processing, natural language processing, and reinforcement learning, training models that involve discrete random variables remains challenging because gradient estimates for such variables have high variance. Most variance reduction approaches rely on Monte Carlo sampling, which requires time-consuming resampling and multiple function evaluations. We propose a Gapped Straight-Through (GST) estimator to reduce the variance without incurring resampling overhead. This estimator is inspired by the essential properties of Straight-Through Gumbel-Softmax. We identify these properties and show via an ablation study that they are essential. Experiments demonstrate that the proposed GST estimator outperforms strong baselines on two discrete deep generative modeling tasks, MNIST-VAE and ListOps.
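For readers unfamiliar with the Straight-Through Gumbel-Softmax estimator that the abstract names as its inspiration, the following is a minimal sketch of that baseline, not of the proposed GST estimator. It assumes a single categorical variable, a temperature `tau`, and an upstream gradient supplied by hand; the function names (`st_gumbel_softmax`, `st_gradient`) are hypothetical and chosen for illustration only.

```python
import math
import random


def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gumbel(rng):
    # Gumbel(0, 1) sample via the inverse CDF; clamp u away from 0 to avoid log(0).
    u = max(rng.random(), 1e-12)
    return -math.log(-math.log(u))


def st_gumbel_softmax(logits, tau, rng):
    # Forward pass: perturb the logits with Gumbel noise, then return both
    # the discrete one-hot sample (used downstream) and the relaxed sample
    # (used only to define the surrogate gradient).
    perturbed = [l + sample_gumbel(rng) for l in logits]
    soft = softmax([p / tau for p in perturbed])
    k = max(range(len(soft)), key=soft.__getitem__)
    hard = [1.0 if i == k else 0.0 for i in range(len(soft))]
    return hard, soft


def st_gradient(soft, tau, upstream):
    # Backward pass (straight-through): pretend the forward output was the
    # relaxed sample and push the upstream gradient through the softmax
    # Jacobian, d soft_i / d logit_j = soft_i * (delta_ij - soft_j) / tau.
    dot = sum(u * s for u, s in zip(upstream, soft))
    return [s * (u - dot) / tau for s, u in zip(soft, upstream)]


# Illustrative values only (assumptions, not taken from the paper).
rng = random.Random(0)
hard, soft = st_gumbel_softmax([1.0, 0.5, -0.5], tau=1.0, rng=rng)
grad = st_gradient(soft, tau=1.0, upstream=[0.2, -1.0, 0.3])
```

The estimate is obtained from a single sample and a single evaluation of the downstream function, which is the property the abstract contrasts with Monte Carlo approaches that need resampling. In an automatic differentiation framework the same effect is usually achieved by returning the hard sample in the forward pass while routing gradients through the relaxed one.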
Field | Value |
---|---|
Original language | English (US) |
Pages (from-to) | 6059-6073 |
Number of pages | 15 |
Journal | Proceedings of Machine Learning Research |
Volume | 162 |
State | Published - 2022 |
Externally published | Yes |
Event | 39th International Conference on Machine Learning, ICML 2022 - Baltimore, United States |
Duration | Jul 17 2022 → Jul 23 2022 |
All Science Journal Classification (ASJC) codes
- Artificial Intelligence
- Software
- Control and Systems Engineering
- Statistics and Probability