Transformers Provably Learn Sparse Token Selection While Fully-Connected Nets Cannot

Zixuan Wang, Stanley Wei, Daniel Hsu, Jason D. Lee

Research output: Contribution to journal › Conference article › peer-review


Abstract

The transformer architecture has prevailed in various deep learning settings due to its exceptional ability to select and compose structural information. Motivated by these capabilities, Sanford et al. (2023) proposed the sparse token selection task, on which transformers excel while fully-connected networks (FCNs) fail in the worst case. Building on this task, we strengthen the FCN lower bound to an average-case setting and establish an algorithmic separation of transformers over FCNs. Specifically, a one-layer transformer trained with gradient descent provably learns the sparse token selection task and, surprisingly, exhibits strong out-of-distribution length generalization. We provide empirical simulations to justify our theoretical findings.
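The abstract refers to the sparse token selection task of Sanford et al. (2023) without restating it. The minimal sketch below illustrates one common formulation of that task (q-sparse averaging): the target maps a token sequence and a size-q index set to the mean of the selected tokens, which a single attention layer can realize by putting weight roughly 1/q on each selected position. The exact encoding, variable names, and the uniform-attention check are illustrative assumptions, not reproduced from the paper.

```python
# Illustrative sketch of a sparse token selection (q-sparse averaging) instance.
# Hedged: the data encoding below is an assumption for exposition only.
import numpy as np

rng = np.random.default_rng(0)
d, N, q = 8, 16, 3                        # token dimension, sequence length, sparsity

X = rng.standard_normal((N, d))           # token sequence
S = rng.choice(N, size=q, replace=False)  # sparse index set specifying which tokens to select

target = X[S].mean(axis=0)                # ground truth: average of the selected tokens

# An attention head solving the task would place mass ~1/q on each position in S
# and ~0 elsewhere; the attention-weighted value average then matches the target.
attn = np.zeros(N)
attn[S] = 1.0 / q
print(np.allclose(attn @ X, target))      # True
```

In the paper's setting, the claim is that gradient descent on a one-layer transformer provably reaches such a solution, and the learned model generalizes to sequence lengths beyond those seen in training.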

Original language: English (US)
Pages (from-to): 51854-51912
Number of pages: 59
Journal: Proceedings of Machine Learning Research
Volume: 235
State: Published - 2024
Externally published: Yes
Event: 41st International Conference on Machine Learning, ICML 2024 - Vienna, Austria
Duration: Jul 21 2024 - Jul 27 2024

All Science Journal Classification (ASJC) codes

  • Artificial Intelligence
  • Software
  • Control and Systems Engineering
  • Statistics and Probability

