Abstract
We study the problem of learning hierarchical polynomials over the standard Gaussian distribution with three-layer neural networks. We specifically consider target functions of the form h = g ◦ p where p : ℝ^d → ℝ is a degree k polynomial and g : ℝ → ℝ is a degree q polynomial. This function class generalizes the single-index model, which corresponds to k = 1, and is a natural class of functions possessing an underlying hierarchical structure. Our main result shows that for a large subclass of degree k polynomials p, a three-layer neural network trained via layerwise gradient descent on the square loss learns the target h up to vanishing test error in Õ(d^k) samples and polynomial time. This is a strict improvement over kernel methods, which require Θ̃(d^{kq}) samples, as well as existing guarantees for two-layer networks, which require the target function to be low-rank. Our result also generalizes prior works on three-layer neural networks, which were restricted to the case of p being a quadratic. When p is indeed a quadratic, we achieve the information-theoretically optimal sample complexity Õ(d^2), which is an improvement over prior work (Nichani et al., 2023) requiring a sample size of Θ̃(d^4). Our proof proceeds by showing that during the initial stage of training the network performs feature learning to recover the feature p with Õ(d^k) samples. This work demonstrates the ability of three-layer neural networks to learn complex features and as a result, learn a broad class of hierarchical functions.
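To make the target class concrete, the following is a minimal sketch of one hierarchical polynomial h = g ◦ p in the quadratic case k = 2. Everything specific here is an assumption chosen for illustration and is not taken from the paper: the function name `make_hierarchical_target`, the trace-free symmetric matrix `A` defining p(x) = xᵀAx, the unit-variance normalization, and the cubic link g(z) = z³ − 3z (so q = 3).

```python
import numpy as np

def make_hierarchical_target(d, rng):
    """Build an illustrative target h = g o p on R^d (hypothetical helper).

    p(x) = x^T A x is a degree-2 feature with A symmetric and trace-free,
    so E[p(x)] = tr(A) = 0 for x ~ N(0, I_d). A is scaled so that
    Var[p(x)] = 2 * ||A||_F^2 = 1.
    g(z) = z^3 - 3z is an arbitrary degree-3 link chosen for illustration.
    """
    A = rng.standard_normal((d, d))
    A = (A + A.T) / 2                      # symmetrize
    A -= (np.trace(A) / d) * np.eye(d)     # remove the trace, so p has mean zero
    A /= np.sqrt(2) * np.linalg.norm(A)    # Frobenius norm 1/sqrt(2), unit variance

    def p(X):
        # Row-wise quadratic form x^T A x for a batch X of shape (n, d).
        return np.einsum("ni,ij,nj->n", X, A, X)

    def g(z):
        return z**3 - 3 * z

    def h(X):
        return g(p(X))

    return p, g, h

# Draw a labelled sample (x_i, h(x_i)) with standard Gaussian inputs.
rng = np.random.default_rng(0)
d, n = 8, 1000
p, g, h = make_hierarchical_target(d, rng)
X = rng.standard_normal((n, d))
y = h(X)
```

In this illustration the labels depend on the d-dimensional input only through the scalar feature p(x), which is the hierarchical structure the abstract refers to; a single-index model would instead take p to be linear (k = 1).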
Original language | English (US) |
---|---|
State | Published - 2024 |
Event | 12th International Conference on Learning Representations, ICLR 2024 - Hybrid, Vienna, Austria; Duration: May 7 2024 → May 11 2024 |
Conference
Conference | 12th International Conference on Learning Representations, ICLR 2024 |
---|---|
Country/Territory | Austria |
City | Hybrid, Vienna |
Period | 5/7/24 → 5/11/24 |
All Science Journal Classification (ASJC) codes
- Language and Linguistics
- Computer Science Applications
- Education
- Linguistics and Language