Learning one-hidden-layer neural networks under general input distributions

Weihao Gao, Ashok Vardhan Makkuva, Sewoong Oh, Pramod Viswanath

Research output: Contribution to conference › Paper › peer-review

4 Scopus citations

Abstract

Significant advances have been made recently on training neural networks, where the main challenge is in solving an optimization problem with abundant critical points. However, existing approaches to address this issue crucially rely on a restrictive assumption: the training data is drawn from a Gaussian distribution. In this paper, we provide a novel unified framework to design loss functions with desirable landscape properties for a wide range of general input distributions. On these loss functions, remarkably, stochastic gradient descent theoretically recovers the true parameters with global initializations and empirically outperforms the existing approaches. Our loss function design bridges the notion of score functions with the topic of neural network optimization. Central to our approach is the task of estimating the score function from samples, which is of basic and independent interest to theoretical statistics. Traditional estimation methods (for example, kernel-based methods) fail right at the outset; we bring statistical methods of local likelihood to design a novel estimator of score functions that provably adapts to the local geometry of the unknown density.
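The abstract's central object is the score function of the input density. As a minimal, hedged illustration (not the authors' loss function and not their local-likelihood estimator), the Python sketch below uses the convention S(x) = -d/dx log p(x), computes it in closed form for a hypothetical two-component Gaussian mixture, and numerically checks a standard fact that makes score functions useful for learning: Stein's identity, E[g(x) S(x)] = E[g'(x)] for smooth g whose product with p vanishes at infinity. The mixture parameters, the choice g = tanh, and the function names `mixture_score`, `sample_mixture`, and `stein_check` are assumptions chosen for illustration only.

```python
import math
import random


def mixture_score(x, weights, means, sigma):
    """First-order score S(x) = -d/dx log p(x) of a 1-D Gaussian mixture.

    All components share the standard deviation `sigma`, so the score is the
    responsibility-weighted average of per-component scores (x - m) / sigma**2.
    Log-weights are shifted by their maximum to avoid underflow.
    """
    logs = [math.log(w) - 0.5 * ((x - m) / sigma) ** 2 for w, m in zip(weights, means)]
    top = max(logs)
    resp = [math.exp(v - top) for v in logs]  # unnormalized responsibilities
    total = sum(resp)
    return sum(r * (x - m) for r, m in zip(resp, means)) / (total * sigma ** 2)


def sample_mixture(n, weights, means, sigma, rng):
    """Draw n samples: pick a component mean by weight, then add Gaussian noise."""
    return [rng.gauss(rng.choices(means, weights)[0], sigma) for _ in range(n)]


def stein_check(n=100_000, seed=0):
    """Monte Carlo estimates of both sides of E[g(x) S(x)] = E[g'(x)] with g = tanh."""
    rng = random.Random(seed)
    weights, means, sigma = [0.3, 0.7], [-2.0, 1.0], 0.8  # hypothetical input density
    xs = sample_mixture(n, weights, means, sigma, rng)
    lhs = sum(math.tanh(x) * mixture_score(x, weights, means, sigma) for x in xs) / n
    rhs = sum(1.0 - math.tanh(x) ** 2 for x in xs) / n  # d/dx tanh(x) = 1 - tanh(x)^2
    return lhs, rhs


if __name__ == "__main__":
    lhs, rhs = stein_check()
    print(f"E[g(x) S(x)] ~ {lhs:.4f}, E[g'(x)] ~ {rhs:.4f}")
```

With a single standard Gaussian component (mean 0, sigma 1) the score reduces to S(x) = x, which corresponds to the Gaussian-input setting the abstract describes as the restrictive assumption of earlier work. Here the density is known exactly; in the paper's setting p is unknown and S must be estimated from samples.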

Original language: English (US)
State: Published - 2020
Externally published: Yes
Event: 22nd International Conference on Artificial Intelligence and Statistics, AISTATS 2019 - Naha, Japan
Duration: Apr 16 2019 to Apr 18 2019

Conference

Conference: 22nd International Conference on Artificial Intelligence and Statistics, AISTATS 2019
Country/Territory: Japan
City: Naha
Period: 4/16/19 to 4/18/19

All Science Journal Classification (ASJC) codes

  • Artificial Intelligence
  • Statistics and Probability
