LEARNING HIERARCHICAL POLYNOMIALS WITH THREE-LAYER NEURAL NETWORKS

Zihao Wang, Eshaan Nichani, Jason D. Lee

Research output: Contribution to conference › Paper › peer-review

Abstract

We study the problem of learning hierarchical polynomials over the standard Gaussian distribution with three-layer neural networks. We specifically consider target functions of the form h = g ◦ p, where p : R^d → R is a degree-k polynomial and g : R → R is a degree-q polynomial. This function class generalizes the single-index model, which corresponds to k = 1, and is a natural class of functions possessing an underlying hierarchical structure. Our main result shows that for a large subclass of degree-k polynomials p, a three-layer neural network trained via layerwise gradient descent on the square loss learns the target h up to vanishing test error in Õ(d^k) samples and polynomial time. This is a strict improvement over kernel methods, which require Θ̃(d^{kq}) samples, as well as over existing guarantees for two-layer networks, which require the target function to be low-rank. Our result also generalizes prior work on three-layer neural networks, which was restricted to the case of p being a quadratic. When p is indeed a quadratic, we achieve the information-theoretically optimal sample complexity Õ(d^2), an improvement over prior work (Nichani et al., 2023) requiring a sample size of Θ̃(d^4). Our proof proceeds by showing that during the initial stage of training, the network performs feature learning to recover the feature p with Õ(d^k) samples. This work demonstrates the ability of three-layer neural networks to learn complex features and, as a result, learn a broad class of hierarchical functions.
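To make the setting in the abstract concrete, the following is a minimal illustrative sketch, not the authors' implementation: data x ~ N(0, I_d), a hierarchical target h(x) = g(p(x)) with p a degree-k polynomial and g a degree-q polynomial, and a three-layer network trained in two layerwise stages on the square loss. The particular choices of p and g, the network widths, learning rates, and iteration counts below are all hypothetical.

import torch
import torch.nn as nn

d = 50                       # input dimension
n = 5000                     # sample size (the paper's guarantee scales as roughly d^k)

# Example target: p is a quadratic form (k = 2), g is a degree-3 polynomial (q = 3).
A = torch.randn(d, d) / d ** 0.5

def p(x):                    # p : R^d -> R
    return torch.einsum("bi,ij,bj->b", x, A, x)

def g(z):                    # g : R -> R
    return z ** 3 - 2 * z

X = torch.randn(n, d)        # standard Gaussian inputs
y = g(p(X))                  # hierarchical labels h(x) = g(p(x))

# Three-layer network: the inner layers aim to learn the feature p(x);
# the outer layer fits g on top of the learned feature.
class ThreeLayerNet(nn.Module):
    def __init__(self, d, width=512):
        super().__init__()
        self.inner = nn.Sequential(nn.Linear(d, width), nn.ReLU(),
                                   nn.Linear(width, width), nn.ReLU())
        self.outer = nn.Linear(width, 1)

    def forward(self, x):
        return self.outer(self.inner(x)).squeeze(-1)

net = ThreeLayerNet(d)
loss_fn = nn.MSELoss()

# Stage 1 (feature learning): update only the inner layers.
opt1 = torch.optim.SGD(net.inner.parameters(), lr=1e-3)
for _ in range(200):
    opt1.zero_grad()
    loss_fn(net(X), y).backward()
    opt1.step()

# Stage 2: fit the outer layer on the learned features.
opt2 = torch.optim.SGD(net.outer.parameters(), lr=1e-2)
for _ in range(200):
    opt2.zero_grad()
    loss_fn(net(X), y).backward()
    opt2.step()

print("train MSE:", loss_fn(net(X), y).item())

The two-stage loop is only meant to mirror the layerwise training schedule described in the abstract; the paper's actual guarantee concerns the sample complexity of recovering p and fitting g, not this toy configuration.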

Original language: English (US)
State: Published - 2024
Event: 12th International Conference on Learning Representations, ICLR 2024 - Hybrid, Vienna, Austria
Duration: May 7, 2024 – May 11, 2024

Conference

Conference: 12th International Conference on Learning Representations, ICLR 2024
Country/Territory: Austria
City: Hybrid, Vienna
Period: 5/7/24 – 5/11/24

All Science Journal Classification (ASJC) codes

  • Language and Linguistics
  • Computer Science Applications
  • Education
  • Linguistics and Language
