Loss Landscape Degeneracy and Stagewise Development of Transformers

Authors

Jesse Hoogland =
Timaeus
George Wang =
Timaeus
Matthew Farrugia-Roberts
Timaeus
Liam Carroll
Timaeus
Susan Wei
University of Melbourne
Daniel Murfet
University of Melbourne
See Contributions

Publication Details

Published:
February 4, 2024
Venue:
TMLR

Abstract

We show that in-context learning emerges in transformers in discrete developmental stages, when they are trained on either language modeling or linear regression tasks. We introduce two methods for detecting the milestones that separate these stages, by probing the geometry of the population loss in both parameter space and function space. We study the stages revealed by these new methods using a range of behavioral and structural metrics to establish their validity.

1 Introduction

A striking phenomenon in modern deep learning is the sudden shift in a model’s internal computational structure and associated changes in input/output behavior (e.g., Wei et al., 2022; Olsson et al., 2022; McGrath et al., 2022). As large models become more deeply integrated into real-world applications, understanding this phenomenon is a priority for the science of deep learning.

A key feature of the loss landscape of neural networks is degeneracy—parameters for which some local perturbations do not affect the loss. Motivated by the perspectives of singular learning theory (SLT; Watanabe, 2009) and nonlinear dynamics (Waddington, 1957; Thom, 1972), where degeneracy plays a fundamental role in governing development, we believe that studying degeneracy in the local geometry of the loss landscape is key to understanding the development of structure and behavior in modern deep learning.

Stage                 | LM1   | LM2   | LM3   | LM4   | LM5
End $t$               | 900   | 6.5k  | 8.5k  | 17k   | 50k
$\Delta\hat{\ell}$    | -2.33 | -1.22 | -0.18 | -0.40 | -0.34
$\Delta\hat{\lambda}$ | +26.4 | +22.5 | -1.57 | +8.62 | +1.77

(a) Two-layer attention-only language transformer (LM).

Stage                 | LR1   | LR2   | LR3   | LR4   | LR5
End $t$               | 1k    | 40k   | 126k  | 320k  | 500k
$\Delta\hat{\ell}$    | -0.32 | -2.21 | -0.07 | -0.05 | -0.029
$\Delta\hat{\lambda}$ | +21.4 | +149  | -12.3 | -44.1 | +3.56

(b) In-context linear regression transformer (LR).
Figure 1: Tracking loss landscape degeneracy reveals developmental stages. We train transformer models on both (a) natural language data and (b) synthetic in-context linear regression data. In addition to test loss (top row), we track loss landscape degeneracy as quantified by the local learning coefficient (LLC) (middle row; Section 4). Critical points in the LLC curve mark boundaries between distinct developmental stages (bottom row; warm hues for increasing LLC, cold for decreasing LLC; Section 5). We show in Sections 6 and 7 that most of these stages coincide with the formation of significant internal structures or changes in input/output behavior. The language model first learns to predict using bigram statistics (LM1), then common $n$-grams (LM2), before forming the induction circuit studied by Olsson et al. (2022) (LM3 & LM4). The regression model first learns the optimal context-independent solution (LR1), then acquires robust in-context learning (LR2), then specializes to the pre-training distribution (LR3 & LR4). These stage divisions and interpretations are specific to the above training runs, but we show in Section B.4 that similar divisions arise with different training seeds.

In this paper, we contribute an empirical investigation of the link between degeneracy and development for transformers in two learning settings. We track loss landscape degeneracy along with model structure and behavior throughout training, using the following methodology.

  1. Transformer training (Section 3): We train two transformers, a language model (LM) with around 3M parameters trained on a subset of the Pile (Gao et al., 2020; Xie et al., 2023), and an in-context linear regression model (LR) with around 50k parameters trained on synthetic regression data following Garg et al. (2022).

  2. Degeneracy tracking (Section 4): We quantify loss landscape degeneracy throughout training by estimating the local learning coefficient (LLC; Lau et al., 2025), a measure of degeneracy derived from SLT.

  3. Degeneracy-based stage division (Section 5): Motivated by the singular learning process in Bayesian inference (Watanabe, 2009, §7.6; Chen et al., 2023), we use critical points in the LLC curve to divide training into approximate developmental stages.

  4. Developmental analysis (Sections 6, 7): We track shifts in each model’s internal computational structure and input/output behavior across training, quantified using various setting-specific metrics.

Crucially, we discover that most of the developmental stages identified by changes in loss landscape degeneracy coincide with significant, interpretable shifts in the internal computational structure and input/output behavior of the transformers, showing that the stage division is meaningful. Our investigations are motivated by the hypothesis of a fundamental link between degeneracy and development in deep learning. This hypothesis is theoretically grounded in SLT but has so far been empirically validated only in toy models (Chen et al., 2023). We view the above discoveries as preliminary evidence for this hypothesis in larger models, and an indication of the potential of degeneracy as a lens for understanding modern neural network development. Section 8 discusses this and other implications of our investigation.

Degeneracy and development in singular learning theory

Our hypothesis that degeneracy and development are fundamentally linked is motivated by singular learning theory (SLT; Watanabe, 2009), a framework for studying singular statistical models, a class that includes neural networks (Hagiwara et al., 1993; Watanabe, 2007; Wei et al., 2023). SLT proves that in singular models, Bayesian inference follows the singular learning process, in which degeneracy in the likelihood governs stagewise development in the posterior as the number of samples increases (Watanabe, 2009, §7.6; Lau et al., 2025; Chen et al., 2023). While there are many differences between Bayesian inference and modern neural network training, an analogy to the singular learning process informs our methodology for stage division.

Degeneracy and development in nonlinear dynamics

Further motivation for our hypothesis comes from viewing neural network training as a stochastic dynamical system, in which the population loss is a governing potential encoding the data distribution. It is well understood in nonlinear dynamics that degeneracy in the local geometry of a potential can give rise to stagewise development of system structure (Waddington, 1957; Thom, 1972; cf. Franceschelli, 2010). This connection has been observed in biological systems at significant scale and in the presence of stochasticity (Freedman et al., 2023). We emphasize changes in degeneracy over a stage, whereas in bifurcation theory the focus is more on the degeneracy at stage boundaries (Rand et al., 2021; MacArthur, 2022; Sáez et al., 2022).

Stagewise development in deep learning

The idea that neural network development occurs in stages goes back decades (Raijmakers et al., 1996) and has received renewed attention in modern deep learning (e.g., Wei et al., 2022; Olsson et al., 2022; McGrath et al., 2022; Odonnat et al., 2024; Chen et al., 2024; Edelman et al., 2024). In the case of deep linear networks, we understand theoretically that models learn progressively higher-rank approximations of their data distribution (see, e.g., Baldi & Hornik, 1989; Rogers & McClelland, 2004; Saxe et al., 2019) throughout training. Our findings suggest that studying degeneracy could help generalize this understanding to modern architectures that exhibit more complex internal computational structure, such as transformers.

Studying loss landscape geometry

Given the central role played by the loss landscape in deep learning, it is unsurprising that there have been many attempts to study its geometry.

One approach is to visualize low-dimensional slices of the loss landscape (Erhan et al., 2010; Goodfellow et al., 2014; Lipton, 2016; Li et al., 2018; Tikeng Notsawo et al., 2024). Unfortunately, a random slice is, with high probability, a quadratic form associated with nonzero eigenvalues of the Hessian, and is thus biased against geometric features that we know are important, such as degeneracy (Wei et al., 2023). Moreover, Antognini & Sohl-Dickstein (2018) have emphasized the difficulty of probing the loss landscape of neural networks with dimensionality reduction tools.

Other standard methods of quantifying the geometry of the loss landscape, such as via the Hessian, are insensitive to important aspects of degeneracy. For example, the Hessian trace or maximum eigenvalue quantifies the curvature at a critical point but ignores degenerate dimensions, and the Hessian rank counts the number of degenerate dimensions but fails to distinguish between dimensions by the order of their degeneracy (e.g., quartic vs. zero). In contrast, the LLC is a principled quantitative measure of loss landscape degeneracy. Section B.5 includes experiments showing that Hessian statistics do not reveal the clear stage boundaries revealed by the LLC in our in-context linear regression setting.

3 Training transformers in two settings

We study transformers trained in two learning settings, namely language modeling and in-context linear regression. These settings have been the subject of recent work on the emergence of in-context learning (ICL), a compelling example of a sudden shift in a model’s internal computational structure in modern deep learning (Olsson et al., 2022).

In this section, we describe both settings and introduce their loss functions and data distributions. Common to both settings is a transformer model denoted $f_w$ with parameters $w$, which takes as input a sequence of tokens, also called a context. We describe specific architecture details and training hyperparameters in Sections F.1 and F.2.

Language modeling

Elhage et al. (2021) and Olsson et al. (2022) observed that two-layer attention-only transformers (transformers without MLP layers) form interesting internal computational structures supporting ICL, including induction heads. In order to compare with their behavioral and structural analysis we adopt the same architecture. In Appendix E we also study one-layer attention-only transformers. We note that, while we don't study language models with MLP layers (following prior work), we do use MLP layers for in-context linear regression.

We consider the standard task of next-token prediction for token sequences taken from a subset of the Pile (Gao et al., 2020; Xie et al., 2023). We denote the input context by $S_K = (t_1, \ldots, t_K)$, where $K$ is the context length, and by $S_{\le k}$ the prefix context $(t_1, \ldots, t_k)$ of $S_K$. Our data is a collection of length-$K$ contexts, $\{S_K^i\}_{i=1}^n$. Thus $S_{\le k}^i$ denotes a prefix of the $i$th context, $S_K^i$.

Given the context $S_{\le k}^i$, the transformer model $f_w$ outputs a vector of logits $f_w(S_{\le k}^i)$ such that $\mathrm{softmax}(f_w(S_{\le k}^i))$ is a probability distribution over all tokens (we denote by $\mathrm{softmax}(f_w(S_{\le k}^i))[t]$ the probability of token $t$). The per-token empirical loss for language modeling is then the average cross-entropy between this distribution and the true next token at each index $k \in \{1, \ldots, K-1\}$,

$\ell_{n,k}(w) = \frac{1}{n} \sum_{i=1}^{n} -\log\left(\mathrm{softmax}(f_w(S_{\le k}^i))[t_{k+1}^i]\right).$ (1)

The empirical loss is then $\ell_n(w) = \frac{1}{K-1} \sum_{k=1}^{K-1} \ell_{n,k}(w)$, with the test loss $\hat{\ell}(w)$ defined analogously on a held-out set of examples. The corresponding population loss $\ell(w)$ is defined by taking the expectation with respect to the true distribution of contexts (see also Section A.6).
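
As a concrete illustration, the following sketch computes the per-token losses of equation (1) and their average $\ell_n$ from precomputed model logits. The array shapes and function name are our own illustrative choices, not the paper's code.

    import numpy as np

    def lm_empirical_loss(logits, tokens):
        """Per-token cross-entropy losses of equation (1), averaged over positions.

        logits: array of shape (n, K, V); logits[i, k] are the model's
                next-token logits after reading the prefix S^i_{<=k+1}.
        tokens: integer array of shape (n, K) holding token ids t^i_1..t^i_K.
        """
        n, K, _ = logits.shape
        # numerically stable log-softmax over the vocabulary
        z = logits - logits.max(axis=-1, keepdims=True)
        log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
        per_position = []
        for k in range(K - 1):
            # ell_{n,k}: mean negative log-probability of the true next token
            nll = -log_probs[np.arange(n), k, tokens[:, k + 1]]
            per_position.append(nll.mean())
        return float(np.mean(per_position))  # ell_n: average over the K-1 positions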

In-context linear regression

Following Garg et al. (2022), a number of recent works have explored ICL in the setting of learning simple function classes, such as linear functions. This setting is of interest because optimal (in-context) linear regression is well understood theoretically, and because simple transformers achieve good ICL performance on it in practice (see, e.g., Garg et al., 2022; Raventós et al., 2023).

We consider a standard synthetic in-context linear regression problem. A task is a vector $\mathbf{t} \in \mathbb{R}^D$, and an example is a pair $(x, y) \in \mathbb{R}^D \times \mathbb{R}$. We sample a context by sampling one task $\mathbf{t} \sim \mathcal{N}(0, I_D)$ and then sampling $K$ i.i.d. inputs $x_1, \ldots, x_K \sim \mathcal{N}(0, I_D)$ and outputs $y_1, \ldots, y_K$ with $y_k \sim \mathcal{N}(\mathbf{t}^\top x_k, \sigma^2)$. This results in the context $S_K = (x_1, y_1, \ldots, x_{K-1}, y_{K-1}, x_K)$ with label $y_K$. We denote by $S_{\le k}$ the prefix context $(x_1, y_1, \ldots, x_k)$ of context $S_K$; its label is $y_k$. Section F.2.2 describes how we encode the $x_i$ and $y_i$ as tokens. Our data is a set of contexts $\{(\mathbf{t}_i, S_K^i, y_K^i)\}_{i=1}^n$ sampled i.i.d. as described above.
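
A minimal sketch of this sampling procedure is given below. The dimension D, context length K, and function name are illustrative choices rather than the paper's settings; the noise variance sigma^2 = 0.125 matches the value quoted in Section 7.

    import numpy as np

    def sample_regression_context(D=4, K=16, noise_var=0.125, rng=None):
        """Sample one task and one context for in-context linear regression.

        Draws a task t ~ N(0, I_D), inputs x_1..x_K ~ N(0, I_D), and noisy
        outputs y_k ~ N(t^T x_k, sigma^2). The context shown to the model is
        S_K = (x_1, y_1, ..., x_{K-1}, y_{K-1}, x_K); y_K is the final label.
        """
        rng = rng or np.random.default_rng()
        t = rng.normal(size=D)                                  # task vector
        xs = rng.normal(size=(K, D))                            # inputs
        ys = xs @ t + np.sqrt(noise_var) * rng.normal(size=K)   # noisy outputs
        return t, xs, ys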

Running a context $S_{\le k}^i$ through the transformer yields a prediction $\hat{y}_k^i = f_w(S_{\le k}^i)$, leading to the per-token empirical loss for in-context linear regression for $k \in \{1, \ldots, K\}$,

$\ell_{n,k}(w) = \frac{1}{n} \sum_{i=1}^{n} \left(\hat{y}_k^i - y_k^i\right)^2.$ (2)

The associated empirical loss is $\ell_n(w) = \frac{1}{K} \sum_{k=1}^{K} \ell_{n,k}(w)$. The corresponding test loss $\hat{\ell}(w)$ and population loss $\ell(w)$ are defined analogously as in the language modeling setting.

Figure 2: The local learning coefficient (LLC) measures loss landscape degeneracy. The LLC can be defined in terms of the rate at which the parameter space volume (within a given neighborhood and with a given maximum loss) shrinks as the loss threshold is reduced to that of the local minimum. We show four population loss landscapes for a two-dimensional parameter space with decreasing LLC (increasing degeneracy). In these examples, the local multiplicity is 1. See Section A.2 for a detailed description of each example, as well as several additional examples.

4 Quantifying degeneracy with the local learning coefficient

We track the evolution of degeneracy in the local geometry of the loss landscape throughout training by estimating the local learning coefficient (LLC; Watanabe, 2009; Lau et al., 2025) at model checkpoints. In this section, we review the LLC and the estimation procedure of Lau et al. (2025).

The local learning coefficient (LLC)

Given a local minimum $w^*$ of a population loss $\ell$ (a negative log likelihood), the LLC of $w^*$, denoted $\lambda(w^*)$, is a positive rational number that measures the amount of degeneracy in $\ell$ near $w^*$ (Watanabe, 2009; Lau et al., 2025), i.e., how many ways $w$ can be varied near $w^*$ such that $\ell(w)$ remains equal to $\ell(w^*)$. Formally, the LLC is defined as the volume-scaling rate near $w^*$. This is illustrated in Figure 2, further described in Section A.1, and treated in full detail in Lau et al. (2025). Informally, the LLC is a measure of minimum "flatness." It improves over conventional (second-order) Hessian-based measures of flatness because the LLC is sensitive to more significant, higher-order contributions to volume-scaling.

Estimating the LLC

Lau et al. (2025) introduced an estimator for the LLC based on stochastic-gradient Langevin dynamics (SGLD; Welling & Teh, 2011), which we use in our experiments. Let $w^*$ be a local minimum of the population loss $\ell$. The LLC estimate $\hat{\lambda}(w^*)$ is

$\hat{\lambda}(w^*) = n\beta\left[\mathbb{E}_{w|w^*,\gamma}^{\beta}[\ell_n(w)] - \ell_n(w^*)\right],$ (3)

where $\mathbb{E}_{w|w^*,\gamma}^{\beta}$ denotes the expectation with respect to the localized Gibbs posterior

$p(w; w^*, \beta, \gamma) \propto \exp\left\{-n\beta\,\ell_n(w) - \frac{\gamma}{2}\|w - w^*\|_2^2\right\}$

with inverse temperature $\beta$ (controlling the contribution of the empirical loss landscape) and localization strength $\gamma$ (controlling proximity to $w^*$). The basic idea behind this estimator is the following: the more degenerate the loss landscape, the easier it is for a sampler exploring the Gibbs posterior to find points of low loss, and, in turn, the lower $\hat{\lambda}(w^*)$. Section A.3 discusses technical SGLD details, Section A.4 documents the hyperparameters used in our experiments, and Section A.5 outlines our hyperparameter tuning procedure.
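
To make the estimator concrete, here is a minimal sketch of SGLD-based LLC estimation as in equation (3). It assumes a flattened parameter vector and oracle functions for the empirical loss and a stochastic gradient; the default hyperparameters, including the choice beta = 1/log n, are illustrative rather than the settings documented in Section A.4.

    import numpy as np

    def estimate_llc(w_star, grad_loss, loss, n, beta=None, gamma=100.0,
                     eps=1e-4, steps=2000, rng=None):
        """SGLD estimate of the local learning coefficient, equation (3).

        w_star: flat parameter vector at the checkpoint being measured.
        grad_loss(w): stochastic (minibatch) estimate of grad ell_n(w).
        loss(w): estimate of ell_n(w).
        """
        rng = rng or np.random.default_rng()
        beta = beta if beta is not None else 1.0 / np.log(n)
        w = w_star.copy()
        sampled_losses = []
        for _ in range(steps):
            # gradient of the localized, tempered potential
            # U(w) = n * beta * ell_n(w) + (gamma / 2) * ||w - w_star||^2
            grad_U = n * beta * grad_loss(w) + gamma * (w - w_star)
            # SGLD step: half-step of gradient descent plus Gaussian noise
            w = w - 0.5 * eps * grad_U + np.sqrt(eps) * rng.normal(size=w.shape)
            sampled_losses.append(loss(w))
        # the chain average approximates the localized Gibbs expectation in (3)
        return n * beta * (float(np.mean(sampled_losses)) - loss(w_star))

A practical implementation would typically also discard burn-in draws and average over several chains; Section A.3 covers such technical details.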

Assumptions of LLC estimation

Strictly speaking, the LLC is defined only for loss functions arising as a negative log likelihood, whereas our loss function includes terms from overlapping context prefixes. It is possible to define a negative log likelihood-based loss for transformer training; we show empirically in Section A.6 that this does not have a significant effect on LLC estimates, and so we proceed with overlapping contexts for efficiency.

Moreover, the LLC is defined only for local minima of such loss functions, whereas we note that equation (3) is defined for arbitrary $w^*$ and we apply the estimator throughout training. This approach has precedent in prior work on LLC estimation: Lau et al. (2025) showed that when applied to trained parameters, the estimator accurately recovers the learning coefficient associated with a nearby minimum, and Chen et al. (2023) found that the estimator produces reliable results for parameters throughout training. In our case, we obtain stable estimates throughout training given sufficiently strong localization $\gamma$. See Section A.7 for more details.

Figure 3: In the singular learning process, the Bayesian posterior can shift between neighborhoods with different degeneracy. Watanabe's free energy formula (4) highlights a tradeoff between loss $\ell_n$ (the linear term coefficient) and degeneracy $\lambda$ (the LLC, the logarithmic term coefficient). Consider two local minima $w_1^*, w_2^*$ with neighborhoods $W_1^*, W_2^*$. As the number of samples $n$ increases, if $w_2^*$ has lower loss and higher LLC than $w_1^*$, $W_2^*$ will suddenly achieve lower free energy than $W_1^*$ at some critical sample size $n_{\text{crit}}$, causing the Bayesian posterior to shift from concentrating in $W_1^*$ to $W_2^*$.

5 Degeneracy-based stage division

We use critical points (that is, plateaus, where the first derivative vanishes) in the LLC curve to define stage boundaries that divide training into developmental stages. This approach is motivated by the singular learning process in Bayesian inference, which we review below.

Bayesian local free energy

Let $W^*$ be a neighborhood of a local minimum $w^*$ of the population loss $\ell$ (a negative log likelihood). Given $n$ samples, we can define the local free energy of the neighborhood (Lau et al., 2025),

$F_n(W^*) = -\log \int_{W^*} \exp\left(-n \ell_n(w)\right) \varphi(w)\, dw,$

where $\varphi(w)$ is a prior positive on the neighborhood $W^*$. The lower the local free energy of a neighborhood $W^*$, the higher the Bayesian posterior mass of $W^*$. In fact, by a log-sum-exp approximation, the Bayesian posterior is approximately concentrated on the neighborhood with the lowest local free energy (cf. Chen et al., 2023).

The singular learning process

Watanabe's free energy formula gives, under certain technical conditions, an asymptotic expansion in $n$ of the local free energy (Watanabe, 2018, Theorem 11; Lau et al., 2025):

$F_n(W^*) = n \ell_n(w^*) + \lambda(w^*) \log n + O_p(\log\log n)$ (4)

Here, $\ell_n(w^*)$ is the empirical loss, $\lambda(w^*)$ is the LLC, and the lower-order terms include a constant contribution from the prior mass of $W^*$.

The first two terms in equation (4) create a tradeoff between accuracy ($\ell_n$) and degeneracy ($\lambda$). Moreover, as $n$ increases, the linear term becomes increasingly important relative to the logarithmic term, changing the nature of the tradeoff. At certain critical $n$ the neighborhood with the lowest local free energy may rapidly change to a neighborhood with decreased loss and increased LLC, as illustrated in Figure 3.
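
A small numerical example, keeping only the two leading terms of equation (4) and using made-up loss and LLC values, shows how the preferred neighborhood flips at a critical sample size:

    import numpy as np

    # Hypothetical neighborhoods: W1 has higher loss but lower LLC, W2 the reverse.
    loss1, llc1 = 0.50, 5.0
    loss2, llc2 = 0.45, 20.0

    def free_energy(n, loss, llc):
        # leading terms of Watanabe's free energy formula, equation (4)
        return n * loss + llc * np.log(n)

    ns = np.arange(10, 5000)
    w2_preferred = free_energy(ns, loss1, llc1) > free_energy(ns, loss2, llc2)
    n_crit = ns[np.argmax(w2_preferred)]
    print(n_crit)  # beyond this n the posterior prefers W2 (lower loss, higher LLC)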

A sequence of such posterior transitions between increasingly degenerate neighborhoods is a prime example of the singular learning process (Watanabe, 2009, §7.6). We note that this is not the only possible dynamic—lower-order terms may also play a role in the evolving competition.

LLC plateaus separate developmental stages

While the general connection between the singular learning process in Bayesian inference and stagewise development in deep learning remains to be understood, Chen et al. (2023) showed that, in small autoencoders, both Bayesian inference and stochastic gradient descent undergo rapid transitions between encoding schemes, and these transitions are reflected as sudden changes in the estimated LLC.

This perspective suggests that changes in the loss landscape degeneracy, as measured by the LLC, reflect qualitative changes in the model. In larger models, we expect that these qualitative changes may be more gradual, while still being delineated by brief moments in which the posterior is stably concentrated around a given local minimum. This motivates our approach of identifying plateaus in the estimated LLC curve—brief pauses before and after a given increase or decrease in degeneracy—as stage boundaries which divide training into approximate developmental stages. The resolution of these stage boundaries depends on the density of checkpoints used for LLC estimation and the precision of those estimates.

Results

In our experiments, we identify plateaus in the estimated LLC curve by first lightly smoothing the LLC curve with a Gaussian process to facilitate stable numerical differentiation with respect to log training time. We identify plateaus as approximate zeros of this derivative, namely local minima of the absolute derivative that fall below a small threshold (see Section B.1). Figures 1, B.2 and B.3 show the results. Section B.4 shows that similar stage divisions arise for independent training runs.
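
The sketch below illustrates this plateau-finding step. It substitutes a simple Gaussian kernel smoother for the Gaussian process used in the paper, and the smoothing width and derivative threshold are illustrative values rather than those of Section B.1.

    import numpy as np
    from scipy.ndimage import gaussian_filter1d
    from scipy.signal import argrelmin

    def find_llc_plateaus(steps, llc_estimates, sigma=2.0, threshold=0.05):
        """Candidate stage boundaries from an LLC-versus-training-time curve.

        Returns the training steps at which |d(LLC)/d(log t)| has a local
        minimum falling below the threshold, i.e. approximate plateaus.
        """
        log_t = np.log(np.asarray(steps, dtype=float))
        smoothed = gaussian_filter1d(np.asarray(llc_estimates, dtype=float), sigma)
        deriv = np.gradient(smoothed, log_t)      # derivative w.r.t. log training time
        abs_deriv = np.abs(deriv)
        minima = argrelmin(abs_deriv)[0]          # local minima of |derivative|
        return [steps[i] for i in minima if abs_deriv[i] < threshold]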

6 Results for language modeling

Plateaus in LLC estimates (Figure 1(a)) reveal five developmental stages for our language model. In order to validate that this stage division is meaningful, we search for concomitant changes in the model's input/output behavior and its internal computational structure. In this section, we report a range of setting-specific metrics that reveal the following significant, interpretable changes coinciding with each stage: in LM1 the model learns to predict according to bigram statistics; in LM2 the model learns to predict frequent $n$-grams and use the positional embedding; in LM3 and LM4 the model respectively forms "previous-token heads" and "induction heads" as part of the same induction circuit studied by Olsson et al. (2022). Note that we did not discover significant changes in LM5, and we do not claim that these are the only interesting developmental changes occurring throughout training. There may be other interesting developmental changes that are not captured by our metrics, or that are not significant enough to show up in the LLC curve.

6.1 Stage LM1 (0–900 steps)

Learning bigram statistics

Figure 4(a) shows that the bigram score (the average cross entropy between model logits and empirical bigram frequencies; see Section C.1.1) is minimized around the boundary between LM1 and LM2, with a value only 0.3 nats above the irreducible entropy of the empirical bigram distribution. This suggests that during LM1 the model learns to predict using bigram statistics (the optimal next-token prediction given only the current token).
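
As a rough illustration of this metric, the sketch below computes the average cross-entropy between the empirical bigram distribution and the model's next-token distribution; the evaluation set and averaging details of Section C.1.1 may differ from this simplified version.

    import numpy as np

    def bigram_score(model_log_probs, bigram_probs, tokens):
        """Cross-entropy of the model against empirical bigram statistics.

        model_log_probs: (n, K, V) log-probabilities from the model.
        bigram_probs: (V, V) empirical bigram distribution; bigram_probs[t]
                      is the empirical distribution of the token following t.
        tokens: (n, K) token ids of the evaluation contexts.
        """
        _, K, _ = model_log_probs.shape
        total = 0.0
        for k in range(K - 1):
            # cross-entropy H(bigram(. | t_k), model(. | prefix up to t_k))
            p = bigram_probs[tokens[:, k]]                       # (n, V)
            total += -(p * model_log_probs[:, k, :]).sum(axis=-1).mean()
        return total / (K - 1)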

6.2 Stage LM2 (900–6.5k steps)

Using positional information

During LM2 the positional embedding becomes structurally important. Figure 4(b) shows that here the test loss for the model with the positional embedding zero-ablated diverges from the test loss of the unablated model (see Section C.2.1). By zero-ablation we mean setting the learned positional embeddings to zero during evaluation; conditional on our architecture, this establishes whether the model effectively uses positional information, and a similar method could be used in a model without learned positional embeddings. There is also an uptick in previous-token attention among some first-layer attention heads, shown in green in Figure 4(d).

Learning common nnitalic_n-grams

We define an $n$-gram score as the ratio of final-position token loss on (1) a baseline set of samples from a validation set truncated to $n$ tokens, and (2) a fixed set of common $n$-grams (see Section C.1.2). Concretely, after extracting the top 1000 common $n$-grams, we compute the loss on contexts like

            [<bos_token>, <token_1>, <token_2>, ..., <token_n>]
    

using the loss on <token_n>. To normalize, we take the average loss on the $n$-th token of similar-length contexts drawn from the pretraining distribution and divide this baseline loss by the $n$-gram loss, which gives the $n$-gram score.
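
A minimal sketch of this computation follows; the <bos> token id, the loss_on_final_token helper, and the container types are hypothetical stand-ins for the actual evaluation code.

    import numpy as np

    def ngram_score(loss_on_final_token, common_ngrams, baseline_contexts, bos_id=0):
        """n-gram score: baseline final-token loss divided by n-gram final-token loss.

        loss_on_final_token(context): model loss on the last token of a context.
        common_ngrams: the extracted common n-grams, as tuples of token ids.
        baseline_contexts: length-n prefixes drawn from the pretraining
            distribution (truncated to n tokens).
        Higher scores mean the model predicts common n-grams better than
        typical contexts of the same length.
        """
        ngram_loss = np.mean([loss_on_final_token([bos_id, *ng]) for ng in common_ngrams])
        baseline_loss = np.mean([loss_on_final_token([bos_id, *ctx]) for ctx in baseline_contexts])
        return baseline_loss / ngram_loss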

Figure 4(c) shows a large improvement in $n$-gram score for $n = 3, 4$ during LM2. This suggests that during LM2 the model memorizes and learns to predict common $n$-grams for $n > 2$ (note this requires using the positional encoding and may also involve previous-token heads).

Figure 4: Language model stages coincide with significant structural and behavioral changes. (a) The model learns bigram statistics in LM1, (b) then the positional embedding becomes useful from LM2, (c) enabling the learning of common $n$-grams. Induction circuit formation begins with (d) previous-token heads in LM3, followed by (e) induction heads in LM4, leading to (f) a drop in ICL score indicating the acquisition of in-context learning. Note: in (d, e), $l{:}h$ denotes attention head $h$ in layer $l$; dark lines indicate heads comprising the induction circuit.
Foundations of induction circuit

In this stage, the heads that eventually become previous-token and induction heads in future stages begin to compose (that is, read from and write to a shared residual stream subspace; see Figure C.4 and Section C.2.2). This suggests that the foundations for the induction circuit are laid in advance of any measurable change in model outputs or attention patterns.

6.3 Stages LM3 & LM4 (6.5k–8.5k & 8.5k–17k steps)

Formation of induction circuit as studied in Olsson et al. (2022)

Figure 4(d) shows the previous-token matching score (Section C.2.3) rises over LM3 and LM4 for the two first-layer heads that eventually participate in the induction circuit (as distinguished by their composition scores, Section C.2.2). Figure 4(e) shows that during LM4 there is an increase in the prefix-matching score (Section C.2.4) for the two second-layer induction heads that complete the induction circuit. Figure 4(f) shows a corresponding drop in the ICL score (Section C.1.3) as the model begins to perform in-context learning.
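
As one simple example of an attention-pattern metric, the sketch below scores each head by the average attention it assigns to the immediately preceding token; the paper's previous-token matching and prefix-matching scores (Sections C.2.3 and C.2.4) are computed on particular evaluation distributions and may differ in details.

    import numpy as np

    def previous_token_attention(attention):
        """Average attention from each position to the position directly before it.

        attention: array of shape (heads, K, K) of row-stochastic attention
                   patterns for one context (query position x key position).
        Returns one score per head; values near 1 indicate previous-token heads.
        """
        heads, K, _ = attention.shape
        # attention weight from position k back to position k-1, for k >= 1
        prev = attention[:, np.arange(1, K), np.arange(K - 1)]   # (heads, K-1)
        return prev.mean(axis=-1)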

The LLC decreases during LM3, suggesting an increase in degeneracy (a decrease in model complexity). This may be related to interaction between heads. It would be interesting to study this stage further via mechanistic interpretability.

7 Results for in-context linear regression

Figure 5: In-context linear regression model stages coincide with significant structural and behavioral changes. (a) During LR1, the model learns to make context-independent predictions, $x_k \mapsto \hat{y}_k = 0$. (b) During LR2, ICL performance improves, then during LR3 the model becomes worse at ICL on OOD inputs $x_k \sim \mathcal{N}(0, g I_D)$ for $g > 3$. (c) During LR3 and LR4, layer normalization weights "collapse," possibly contributing to the LLC decrease.

Plateaus in the LLC estimates (Figure 1(b)) reveal five developmental stages for our in-context linear regression model. We validate that this stage division is meaningful by identifying significant, concomitant changes in the model's structure and behavior: in LR1 the model learns to predict without looking at the context; in LR2 the model acquires a robust in-context learning ability; and in LR3 and LR4 the model becomes more fragile to out-of-distribution inputs. We did not discover significant changes in LR5, nor do we claim this is an exhaustive list of developments.

7.1 Stage LR1 (0–1k steps)

Learning to predict without context

Figure 5(a) shows that the mean square prediction over all tokens, $\mathbb{E}[\hat{y}_k^2]$, decreases during LR1, reaching a minimum of $0.1$ (smaller than the target noise $\sigma^2 = 0.125$) slightly after the end of LR1. Similar to how the language model learned bigram statistics in LM1, this suggests the model first learns the optimal context-independent prediction $\hat{y}_k = \bar{\mathbf{t}}^\top x_k$, where $\bar{\mathbf{t}}$ is the mean of the task distribution (zero in this case).

7.2 Stage LR2 (1k–40k steps)

Acquiring in-context learning

Figure 5(b) shows that during LR2 there is a drop in ICL score (Section D.1.2), indicating that the model acquires in-context learning.

Embedding and attention collapse

Section D.2 documents additional changes. Near the end of LR2, token and positional embeddings begin to "collapse," effectively losing singular values and aligning with the same activation subspace (Sections D.2.1 and D.2.2). At the same time, several attention heads form concentrated, input-independent attention patterns (Section D.2.3).

7.3 Stages LR3 & LR4 (40k–126k & 126k–320k steps)

Reduced robustness to input magnitude

While performance continues to improve on typical sequences, Figure 5(b) shows that during LR3 and LR4, the model's in-context learning ability deteriorates for outlier sequences with higher-than-average $|x_k|$.
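
The sketch below shows one way to run this kind of robustness check: it measures the final-token squared error on contexts whose inputs are drawn with inflated variance $g$. The function signature, dimensions, and sample counts are illustrative assumptions rather than the paper's evaluation code.

    import numpy as np

    def ood_icl_loss(predict, D=4, K=16, gain=3.0, noise_var=0.125,
                     num_contexts=256, rng=None):
        """Mean squared error of the final-token prediction when x_k ~ N(0, g I_D).

        predict(xs, ys): the model's prediction for y_K given the context
            (x_1, y_1, ..., x_{K-1}, y_{K-1}, x_K), i.e. all K inputs and
            the first K-1 outputs.
        """
        rng = rng or np.random.default_rng()
        errors = []
        for _ in range(num_contexts):
            t = rng.normal(size=D)
            xs = np.sqrt(gain) * rng.normal(size=(K, D))   # out-of-distribution inputs
            ys = xs @ t + np.sqrt(noise_var) * rng.normal(size=K)
            y_hat = predict(xs, ys[:-1])
            errors.append((y_hat - ys[-1]) ** 2)
        return float(np.mean(errors))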

Layer-normalization collapse

Figure 5(c) shows the individual weights in the final layer normalization module. A large fraction of these weights go to zero in LR3 and LR4. This occurs in tandem with a similar collapse in the weights of the unembedding transforms (Section D.2.4). This results in the model learning to read its prediction $\hat{y}_k$ from a handful of privileged dimensions of the residual stream. Since this means that the network outputs become insensitive to changes in many of the parameters, we conjecture that this explains part of the striking decrease in estimated LLC over these stages (Section D.2.4).

This collapse is most pronounced and affects the largest proportion of weights in the unembedding module, but in LR4 it spreads to earlier layer normalization modules, particularly the layer normalization module before the first attention block (Section D.2.5).

8 Discussion

In this paper, we have examined the development of transformer models in two distinct learning settings. We quantified the changes in loss landscape degeneracy throughout transformer training by estimating the local learning coefficient (LLC). Motivated by the singular learning process in Bayesian inference, we divided these training runs into developmental stages at critical points of the LLC curve. We found that these developmental stages roughly coincided with significant changes in the internal computational structure and the input/output behavior of each model. In this section, we discuss several implications of these findings.

Towards a degeneracy-based understanding of deep learning

That significant structural and behavioral changes show up in the LLC curve is evidence that the development of our transformers is closely linked to loss landscape degeneracy. This finding underscores the potential of loss landscape degeneracy as a crucial lens through which to study the development of deep learning models.

While we studied two distinct learning settings (including language modeling with a nontrivial transformer architecture), it remains necessary to verify the connection between degeneracy and development across a more diverse range of emergent model structures and behaviors. Moreover, future work should investigate this connection in more depth, seeking to establish a causal connection between changes in degeneracy and changes in structure and behavior.

Towards developmental interpretability

We showed that degeneracy can reveal meaningful changes in transformers. We emphasize that our analysis is not exhaustive—we expect only certain “macroscopic” changes, such as the emergence of in-context learning, will have a significant enough effect on loss landscape degeneracy to appear separated by plateaus in the LLC curve. Recent work has extended these ideas by measuring the LLC with respect to network sub-modules and with different data distributions, providing a more refined picture of model development (Wang et al., 2025). We expect this research direction will lead to insights into the development of more complex models.

Loss landscape degeneracy offers a setting-agnostic, “unsupervised” alternative to setting-specific progress measures such as those derived by Barak et al. (2022) or developed using mechanistic insights from similar models by Nanda et al. (2023). Both approaches can reveal developments invisible in the loss, but loss landscape degeneracy is able to detect changes without requiring a mechanistic understanding in advance. Of course, once a change is detected through its effect on degeneracy, it remains to interpret the change.

Case studies in transformer development

We do not claim that the structural and behavioral developments we observed in each setting are universal phenomena. Transformers trained with different architectures, data distributions, algorithms, or hyperparameters may develop differently. Rather, our detailed analysis contributes two “case studies” to the growing empirical literature on the emergence of structure in transformers.

On this note, our observations extend those of Olsson et al. (2022) and Elhage et al. (2021). We show that before the induction circuit forms, our 2-layer language model learns simpler interpretable strategies (based on bigram statistics and common $n$-grams). This shows that a single training run follows a progression akin to that found by Olsson et al. (2022) for fully-developed models of increasing depth (they showed that "0-layer" models learn bigram statistics and 1-layer models learn "skip-trigrams"). A similar progression was observed by Edelman et al. (2024) in a Markovian sequence modeling task.

Moreover, in both settings, we saw that before in-context learning emerges, the model learns to predict tokens using the optimal prediction given only the current token (bigram statistics for language modeling, zero for in-context linear regression with this distribution of tasks).

Development and model complexity

While we have described the LLC as a measure of loss landscape degeneracy, it can also be understood as a measure of model complexity (cf. Section A.2). It is natural for changes in a model's internal structure to show up as a change in complexity. For example, Chen et al. (2024) showed that the emergence of syntactic attention structure coincides with a spike in two model complexity measures, namely the model's Fisher information and the intrinsic dimension (Facco et al., 2017) of the model's embeddings.

Notably, we observe stages in which the LLC decreases, corresponding to a simplification of the computational structure of the model. Such model simplification has empirical precedent, for instance with Chen et al. (2024) and the recent literature on grokking (Power et al., 2022; Nanda et al., 2023; Tikeng Notsawo et al., 2024). In our case, the mechanistic nature of the simplification is not fully clear, with the collapse of various weights and attention patterns arising as candidates in the in-context linear regression setting.

This phenomenon is currently not accounted for by theories of neural network development. In the theory of saddle-to-saddle dynamics, deep linear networks learn progressively more complex approximations of the data (Saxe et al., 2019). Likewise, the example transitions in the singular learning process outlined in Section 5 and Figure 3 describe LLC increases. While decreasing the LLC with the loss held constant would be another way to decrease the free energy according to equation (4), providing a full theoretical account of these stages is an open problem.

Appendix

Appendix A reviews the learning coefficient, providing some simple toy examples contrasting the learning coefficient with Hessian-based measures. This section also discusses SGLD-based LLC estimation including experiment hyperparameters (Section A.4), and offers a detailed example of the calibrations involved in applying LLC estimation to regression transformers to serve as a reference (Section A.5). Appendix B provides further detail on our procedure for LLC-based stage identification, including stages identified in additional training runs and a brief comparison with Hessian statistics. Appendices C and D examine the developmental stages of language models and in-context linear regression in more detail and explain the various metrics we use to track behavioral and structural development. Appendix E describes some additional experiments on a one-layer language model. Appendix F covers transformer training experimental details, such as model architectures, training procedures, and hyperparameters.

To facilitate reproduction of our analyses, we have made our codebase available. A repository containing additional figures and code can be accessed at the URL https://github.com/timaeus-research/icl.


Appendix A The local learning coefficient (LLC)

A.1 Formal Definition of the LLC

In the setting of Section 4, let $B$ be a closed ball around $w^*$ such that $w^*$ is a global minimum on $B$, by which we mean a point with (equal) lowest loss. If there are multiple such global minima, the volume asymptotics are determined by the geometry of one that is most degenerate in the precise sense of SLT, formalised in Lau et al. (2025), roughly corresponding to having the lowest LLC. We call this minimum the maximally degenerate global minimum on $B$. Consider the volume of the set of nearby low-loss parameters,

$V(\epsilon) = \int_B \mathbb{1}\{\ell(w) \le \ell(w^*) + \epsilon\}\, dw.$

As $\epsilon \to 0$, $V(\epsilon)$ is asymptotically equivalent to

$c\, \epsilon^{\lambda(w^*)} \log(1/\epsilon)^{m(w^*)-1},$

where $\lambda(w^*)$ is the LLC, $m(w^*)$ is another geometric quantity called the local multiplicity, and $c > 0$ is a constant.
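
For simple analytic potentials this volume scaling can be checked numerically. The sketch below estimates the exponent $\lambda$ by Monte Carlo integration of $V(\epsilon)$ over a cube around the minimum and a log-log fit; it ignores the $\log(1/\epsilon)$ correction, and the sample sizes and threshold grid are illustrative.

    import numpy as np

    def volume_scaling_exponent(loss, dim, radius=1.0, num_samples=200_000, rng=None):
        """Fit log V(eps) ~ lambda * log eps for a potential with minimum loss 0 at the origin."""
        rng = rng or np.random.default_rng()
        eps_grid = np.geomspace(3e-4, 3e-2, 8)
        # uniform samples from the cube [-radius, radius]^dim around the minimum
        w = rng.uniform(-radius, radius, size=(num_samples, dim))
        losses = np.apply_along_axis(loss, 1, w)
        volumes = [(losses <= eps).mean() for eps in eps_grid]   # V(eps) up to a constant
        slope, _ = np.polyfit(np.log(eps_grid), np.log(volumes), 1)
        return slope   # approximately the LLC lambda for these toy examples

    # e.g. for l(w) = w_1^2 + w_2^4 + w_3^4 the fitted exponent is close to 1:
    # print(volume_scaling_exponent(lambda w: w[0]**2 + w[1]**4 + w[2]**4, dim=3))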

A.2 Interpretations and examples of the LLC

In Section 4, we introduced the LLC as a quantification of geometric degeneracy. In this section, we discuss an additional perspective on the LLC as a count of the "effective" dimensionality of a parameter, and we give additional examples of the LLC. We refer the reader to Watanabe (2009) and Lau et al. (2025) for more discussion.

The LLC has some similarity to an effective parameter count. If the population loss $\ell$ looks like a quadratic form near $w^*$, then $\lambda(w^*) = \frac{d}{2}$ is half the number of parameters, which we can think of as $d$ contributions of $\frac{1}{2}$ from every independent quadratic direction. If there are only $d-1$ independent quadratic directions, and one coordinate $w_i$ such that small variations in $w_i$ near $w_i^*$ do not change the model relative to the truth (this dimension is "unused"), then $\lambda(w^*) = \frac{d-1}{2}$.

The situation becomes more intricate when certain dimensions are degenerate but not completely unused, varying to quartic or higher order near the parameter (rather than being quadratic or flat). While every unused coordinate reduces the LLC by $\frac{1}{2}$, changing the dependency on a coordinate from quadratic ($w_i^2$) to quartic ($w_i^4$), increasing its degeneracy while still "using" it, reduces its contribution to the LLC from $\frac{1}{2}$ to $\frac{1}{4}$.

As a source of intuition, we provide several examples of exact LLCs:

  • $\ell(w_1, w_2, w_3) = a w_1^2 + b w_2^2 + c w_3^2$ with $a, b, c > 0$. This function is nondegenerate, and $\lambda(0,0,0) = \frac{1}{2} + \frac{1}{2} + \frac{1}{2} = \frac{3}{2}$. This is independent of $a, b, c$. That is, the LLC $\lambda$ does not measure curvature. For this reason, it is better to avoid an intuition that centers on "basin broadness," since this tends to suggest that lowering $a, b, c$ should affect the LLC.

  • $\ell(w_1, w_2, w_3) = w_1^2 + w_2^2 + 0$ in $\mathbb{R}^3$ is degenerate, but its level sets are still submanifolds and $\lambda(0,0,0) = \frac{1}{2} + \frac{1}{2}$. In this case the variable $w_3$ is unused, and so does not contribute to the LLC.

  • (w1,w2,w3)=w12+w24+w34\ell(w_{1},w_{2},w_{3})=w_{1}^{2}+w_{2}^{4}+w_{3}^{4}roman_ℓ ( italic_w start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , italic_w start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT , italic_w start_POSTSUBSCRIPT 3 end_POSTSUBSCRIPT ) = italic_w start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT + italic_w start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 4 end_POSTSUPERSCRIPT + italic_w start_POSTSUBSCRIPT 3 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 4 end_POSTSUPERSCRIPT is degenerate and its level sets are, for our purposes, not submanifolds. The singular function germ (,0)(\ell,0)( roman_ℓ , 0 ) is an object of algebraic geometry, and the appropriate mathematical object is not a manifold or a variety but a scheme. The quartic terms contribute 14\tfrac{1}{4}divide start_ARG 1 end_ARG start_ARG 4 end_ARG to the LLC, so that λ(0,0,0)=12+14+14=1\lambda(0,0,0)=\tfrac{1}{2}+\tfrac{1}{4}+\tfrac{1}{4}=1italic_λ ( 0 , 0 , 0 ) = divide start_ARG 1 end_ARG start_ARG 2 end_ARG + divide start_ARG 1 end_ARG start_ARG 4 end_ARG + divide start_ARG 1 end_ARG start_ARG 4 end_ARG = 1. The higher the power of a variable, the greater the degeneracy and the lower the LLC.

Figure 2 offers several additional examples, from left to right:

  • A quadratic potential $\ell_1(w_1,w_2)=w_1^2+w_2^2$, for which the LLC is maximal in two dimensions, $\lambda_1(0,0)=d/2=1$.

  • A quartic potential $\ell_2(w_1,w_2)=w_1^4+w_2^4$, for which the LLC is $\lambda_2(0,0)=1/2$.

  • An even more degenerate potential $\ell_3(w_1,w_2)=w_1^2 w_2^4$, for which $\lambda_3(0,0)=1/4$. We note that Hessian-derived metrics cannot distinguish between this degeneracy and the preceding quartic degeneracy.

  • A qualitatively distinct potential $\ell_4(w_1,w_2)=(w_1-1)^2(w_1^2+w_2^2)^4$ from Lau et al. (2025) with the same LLC at the origin, $\lambda_4(0,0)=1/4$.

While nondegenerate functions can be locally written as quadratic forms by the Morse Lemma (and are thus qualitatively similar to the approximation obtained from their Hessians), there is no simple equivalent for degenerate functions, such as the population losses of deep neural networks.

A.3 Estimating LLCs with SGLD

We follow Lau et al. (2025) in using SGLD to estimate the expectation value of the loss in the estimator of the LLC. For a given choice of weights $w^*$ we sample $C$ independent chains with $T_{\mathrm{SGLD}}$ steps per chain. Each chain $c$ is a sequence of weights $\{w^{(c)}_\tau\}_{\tau=1}^{T_{\mathrm{SGLD}}}$. From these samples, we estimate the expectation $\mathbb{E}^\beta_{w|w^*,\gamma}[\mathcal{O}(w)]$ of an observable $\mathcal{O}$ by

$$\frac{1}{C\,T_{\mathrm{SGLD}}}\sum_{c=1}^{C}\sum_{\tau=1}^{T_{\mathrm{SGLD}}}\mathcal{O}(w_\tau^{(c)}), \tag{5}$$

with an optional burn-in period. Dropping the chain index $c$, each sample in a chain is generated according to:

$$w_{\tau+1} = w_\tau + \Delta w_\tau, \tag{6}$$
$$w_1 = w^*, \tag{7}$$

where the step $\Delta w_\tau$ comes from an SGLD update

$$\Delta w_\tau = \frac{\epsilon}{2}\left(\beta n \nabla \ell_m^{(\tau)}(w_\tau) + \frac{\gamma}{2}\left(w_\tau - w^*\right)\right) + \mathcal{N}(0,\epsilon). \tag{8}$$

In each step $\tau$ we sample a mini-batch of size $m$, and the associated empirical loss, denoted $\ell_m^{(\tau)}$, is used to compute the gradient in the SGLD update. We note that the LLC estimator defined in Equation 3 uses the expectation $\mathbb{E}^\beta[\ell_n(w)]$, which in the current notation means we should take $\mathcal{O}(w)$ to be $\ell_n(w)$. For computational efficiency we follow Lau et al. (2025) in recycling the mini-batch losses $\ell_m(w^{(c)}_\tau)$ computed during the SGLD process. That is, we take $\mathcal{O}=\ell_m^{(\tau)}$ rather than $\mathcal{O}=\ell_n$.
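As an illustration, the following sketch shows one way to implement this sampler and the resulting LLC estimate in PyTorch. It is a minimal reading of Equations (5)–(8), not our exact implementation: the helper `loss_fn(model, batch)` is an assumption, `batches` is assumed to be a list of mini-batches, and the update is applied in the descent direction of the localized, tempered objective.

```python
import itertools
import math
import torch

def estimate_llc(model, loss_fn, batches, n_beta, epsilon=1e-3, gamma=100.0,
                 num_chains=10, num_steps=200):
    """Minimal SGLD-based LLC estimate (sketch of Eqs. 5-8).

    loss_fn(model, batch): returns the mean mini-batch loss (assumed helper).
    n_beta: the product n * beta (the inverse temperature).
    batches: a list of mini-batches, cycled through as needed.
    """
    params = list(model.parameters())
    w_star = [p.detach().clone() for p in params]
    with torch.no_grad():
        init_loss = loss_fn(model, batches[0]).item()  # loss at w*, on one batch

    chain_means = []
    for _ in range(num_chains):
        with torch.no_grad():
            for p, p0 in zip(params, w_star):          # each chain restarts at w* (Eq. 7)
                p.copy_(p0)
        losses = []
        for batch in itertools.islice(itertools.cycle(batches), num_steps):
            loss = loss_fn(model, batch)
            grads = torch.autograd.grad(loss, params)
            with torch.no_grad():
                for p, g, p0 in zip(params, grads, w_star):
                    # drift of the localized, tempered potential (cf. Eq. 8),
                    # followed in the descent direction, plus Gaussian noise N(0, eps)
                    drift = n_beta * g + 0.5 * gamma * (p - p0)
                    p.add_(-0.5 * epsilon * drift
                           + math.sqrt(epsilon) * torch.randn_like(p))
            losses.append(loss.item())                 # recycle the mini-batch losses
        chain_means.append(sum(losses) / len(losses))

    expected_loss = sum(chain_means) / len(chain_means)   # Eq. (5)
    return n_beta * (expected_loss - init_loss)           # the LLC estimate
```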

Time and Space Complexity.

The computational cost per LLC estimate is proportional to that of a standard training step, denoted $S$. We expect a constant number $C\,T_{\mathrm{SGLD}}$ of samples (on the order of $10^2$–$10^4$) to yield robust estimates, independent of model size. Using logarithmically spaced checkpoints, the total computational complexity for generating an LLC curve over the entire training process scales as $O(S\,C\,T_{\mathrm{SGLD}}\log N)$, where $N$ is the total number of training steps. The space complexity incurs a modest overhead compared to standard SGD, linear in the number of parameters, since we store one additional copy of the weights to enable localization.

A.4 LLC estimation experiment details

A.4.1 LLC estimation details for language models

For language models, we use SGLD to sample 20 independent chains with 200 steps per chain and 1 sample per step. For the one-layer model, we used $\epsilon=0.003$ and $\gamma=300$, and for the two-layer model we used $\epsilon=0.001$ and $\gamma=100$. Estimating the LLC across all checkpoints took around 200 GPU hours for the two-layer model on a single A100 and around 125 GPU hours for the one-layer model. For additional runs of the two-layer model, we ran fewer chains, bringing the time down to about 2 TPU hours per training run.

We sampled a separate set of 1 million lines (lines 10m–11m) from the DSIR-filtered Pile, denoted $D_{\text{sgld}}$. The first 100,000 lines from this SGLD set (lines 10m–10.1m) were used as a validation set. The sampling of batches for SGLD mirrored the approach taken during the primary training phase. Each SGLD estimation pass was seeded identically, so that at different checkpoints the SGLD chains encounter the same selection of batches and injected Gaussian noise.

Table 1: Hyperparameters for estimating the LLC for language models.

Hyperparameter | Category | Description/Notes | 1-Layer | 2-Layer
$C$ | Sampler | Number of chains | 20 | 20
$T_{\mathrm{SGLD}}$ | Sampler | Number of SGLD steps per chain | 200 | 200
$\epsilon$ | SGLD | Step size | 0.003 | 0.001
$\gamma$ | SGLD | Localization strength | 300 | 100
$n\beta$ | SGLD | Inverse temperature | 21.7 | 21.7
$m$ | SGLD | Size of each SGLD batch | 100 | 100
$\mu$ | Data | Dataset size for gradient minibatches | 13m | 13m
A.4.2 LLC estimation details for in-context linear regression

For in-context linear regression models, we generate a fixed dataset of $2^{20}$ samples. Using SGLD, we sample 10 independent chains with 5,000 steps per chain, of which the first 1,000 are discarded as burn-in, after which we draw observations once per step, at an inverse temperature $n\beta = 66.7$, step size $\epsilon = 0.0003$, and localization strength $\gamma = 13.3$, over batches of size $m = 1024$. LLC estimation takes up to 72 CPU-hours per training run.

Table 2: LLC estimation hyperparameters. A summary of the hyperparameters involved in estimating the LLC and the default values we use.

Hyperparameter | Category | Description/Notes | Default Value
$C$ | Sampler | Number of chains | 10
$T_{\mathrm{SGLD}}$ | Sampler | Number of SGLD steps per chain | 5000
$\epsilon$ | SGLD | Step size | 0.0003
$\gamma$ | SGLD | Localization strength | 13.3
$n\beta$ | SGLD | Inverse temperature | 66.7
$m$ | SGLD | Size of each SGLD batch | 1024
$\mu$ | Data | Dataset size for gradient minibatches | $2^{20}$

A.5 A guide to SGLD-based LLC estimation

This section walks through some of the hyperparameter choices and sweeps involved in calibrating LLC estimates. We provide it as a reference for others seeking to adjust LLC estimation to novel settings.

Refer to caption
Figure A.1: Past some threshold, the choice of validation set size from which SGLD batches are sampled has little effect on learning coefficient estimates. Estimation hyperparameters are $C=8$, $T_{\mathrm{SGLD}}=2{,}000$, $m=1024$, $\epsilon=0.0003$, $\tilde{\gamma}=0.01$, $\tilde{\beta}=0.01$. Loss is evaluated over gradient minibatches at a representative selection of checkpoints. LLCs quickly converge to a constant value as the size increases.
Refer to caption
Figure A.2: The size of SGLD minibatches has a negligible effect on LLC estimates (at least among the batch sizes considered). Top: Loss is evaluated on the same minibatch as the SGLD gradients. Bottom: Loss is evaluated on a newly sampled minibatch of the same size, distinct from the SGLD gradient minibatch. Estimation hyperparameters are $C=8$, $T_{\mathrm{SGLD}}=2{,}000$, $\mu=2^{20}$.
A.5.1 Varying the temperature

In Lau et al. (2025), the inverse temperature $\beta$ is set to a fixed "optimal" value $\beta^* = 1/\log n$, where $n$ is the number of training samples. In practice, we find that it can be advantageous to sample at a higher temperature.

Since $\beta$ always appears in a product with $n$ (in Equation 8 for the SGLD step and in Equation 3 for the LLC), we can view the inverse temperature as a multiplier that adjusts the effective size of the dataset. In a Bayesian setting, $\beta = 2$ would mean updating twice on each of the samples in the dataset.

The problem with the default choice of $\beta^*$ is that as we increase $n$ we have to decrease the SGLD step size $\epsilon$ to prevent the update from becoming ill-conditioned, which eventually causes the gradient term to suppress the noise term. This, in turn, requires larger batches to suppress the gradient noise and longer chains to sufficiently explore the local posterior (Section A.5.3).

Instead of $n\beta = n/\log n$, we perform LLC estimation at $n\beta = m/\log m$, where $m$ is the SGLD batch size.
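For example, with the language-model SGLD batch size $m = 100$, this gives $n\beta = 100/\log 100 \approx 21.7$ (natural logarithm), matching the inverse temperature reported in Table 1.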

A.5.2 Seeding the random noise

To smooth out the $\hat{\lambda}_t$ curves, we reset the random seed before the LLC estimation run at each checkpoint. This means the sequence of injected Gaussian noise is the same for LLC estimation runs at different checkpoints. Additionally, if the batch size is held constant, the batch schedule will also be constant across different estimation runs. Figure A.3 shows that this does not affect the overall shape of the learning coefficient curves; it simply smooths them out.
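A minimal sketch of this seeding scheme; the `estimate_llc` helper stands in for a full estimation run (as in the sketch of Section A.3), and all names here are ours:

```python
import torch

def llc_curve(checkpoints, estimate_llc, seed=0):
    """Estimate the LLC at each checkpoint, re-seeding the sampler each time so
    that every run sees the same Gaussian noise (and, for a fixed batch size,
    the same batch schedule)."""
    estimates = []
    for model in checkpoints:
        torch.manual_seed(seed)       # identical noise sequence across checkpoints
        estimates.append(estimate_llc(model))
    return estimates
```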

Refer to caption
Figure A.3: Consistently seeding SGLD estimates at each checkpoint smooths out the resulting LLC-over-time curve. Except towards the end of training (this is plotted over a log time axis), the difference is barely noticeable. Variable seeds yield a noisier set of estimates.
A.5.3 Calibrating $\epsilon$, $\beta$, and $\gamma$

As a rule of thumb, $\epsilon$ should be large enough that the $\hat{\lambda}$ estimate converges within the $T_{\mathrm{SGLD}}$ steps of each chain, but not so large that you run into issues with numerical stability and divergent estimates. Subject to this constraint, $\gamma$ should be as small as possible to encourage exploration without enabling the chains to "escape" to nearby better optima, and $\beta$ should be as large as possible (but no greater than $1/\log n$).

To determine the optimal SGLD hyperparameters, we perform a grid sweep over a reparametrization of the SGLD step in terms of $\tilde{\beta}$, $\tilde{\gamma}$, $\varepsilon$:

$$\Delta w_t = \tilde{\beta}\,\nabla\ell_m^{(\tau)} + \tilde{\gamma}\,(w^* - w_t) + \mathcal{N}(0,\varepsilon),$$

where $\tilde{\beta} = \varepsilon\beta n/2$ and $\tilde{\gamma} = \varepsilon\gamma/4$.
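As a sketch, the swept values can be converted back to the sampler's native hyperparameters by inverting the definitions above; the grid of values below is illustrative, not the exact grid we swept:

```python
import itertools

def to_native(epsilon, beta_tilde, gamma_tilde):
    """Invert beta_tilde = eps*beta*n/2 and gamma_tilde = eps*gamma/4."""
    n_beta = 2.0 * beta_tilde / epsilon
    gamma = 4.0 * gamma_tilde / epsilon
    return {"epsilon": epsilon, "n_beta": n_beta, "gamma": gamma}

# illustrative sweep grid (not the values used in the paper)
grid = itertools.product([1e-4, 3e-4, 1e-3],      # epsilon
                         [0.003, 0.01, 0.03],      # beta_tilde
                         [0.003, 0.01, 0.03])      # gamma_tilde
configs = [to_native(*point) for point in grid]
```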

The results of this hyperparameter sweep are illustrated in Figure A.4 for final checkpoints. Separately (not pictured), we check the resulting hyperparameters on a subset of earlier checkpoints. This is necessary because, for example, a set of hyperparameters that is well-behaved at the end of training may lead to failures such as divergent estimates (Figure A.5) earlier in training, when the geometry is more complex and the chains are therefore less stable.

Refer to caption
Figure A.4: Results of a grid sweep over SGLD hyperparameters for model 0 at $t = 27\mathrm{k}$.
A.5.4 LLC traces

As a useful diagnostic when calibrating LLC estimates, we propose an online variant of learning coefficient estimation. When overlaid on individual-chain LLC traces, this helps reveal common failure modes such as divergent estimates, non-converged estimates, and escapes (Figure A.5). These traces display the running estimate of $\hat{\lambda}$ as a function of the number of steps taken in a chain (with the estimate averaged across independent chains).

Refer to caption
(a) Numerical instability
Refer to caption
(b) Non-convergence
Refer to caption
(c) Negative estimates
Figure A.5: Failure modes of SGLD estimation. (a) The gradient term is too large, leading to numerical instability and exploding $\hat{\lambda}$ estimates. (b) $\epsilon$ is too small, so $\hat{\lambda}$ does not converge within each chain. (c) The localization term is too small, which allows the chain to escape to a nearby better optimum, producing negative estimates.

Define $\hat{\lambda}_\tau(w_0)$, the LLC estimate at $w_0$ after $\tau$ steps of a single SGLD chain, as follows (Lau et al., 2025):

$$\hat{\lambda}_\tau(w_0) = n\beta\left(\frac{1}{\tau}\sum_{\tau'=1}^{\tau}\ell_n(w_{\tau'}) - \ell_n(w_0)\right).$$

Rearranging, we obtain

$$\hat{\lambda}_\tau(w_0) = n\beta\left(\frac{1}{\tau}\sum_{\tau'=1}^{\tau}\ell_n(w_{\tau'}) - \ell_n(w_0)\right) \tag{9}$$
$$= n\beta\left(\frac{\tau-1}{\tau}\left(\frac{1}{\tau-1}\sum_{\tau'=1}^{\tau-1}\ell_n(w_{\tau'}) - \ell_n(w_0) + \ell_n(w_0)\right) + \frac{1}{\tau}\ell_n(w_\tau) - \ell_n(w_0)\right) \tag{10}$$
$$= \frac{\tau-1}{\tau}\hat{\lambda}_{\tau-1}(w_0) + n\beta\left(\frac{1}{\tau}\ell_n(w_\tau) + \left(\frac{\tau-1}{\tau}-1\right)\ell_n(w_0)\right) \tag{11}$$
$$= \frac{1}{\tau}\left((\tau-1)\,\hat{\lambda}_{\tau-1}(w_0) + n\beta\bigl(\ell_n(w_\tau) - \ell_n(w_0)\bigr)\right), \tag{12}$$

where

$$\hat{\lambda}_0(w_0) = 0.$$

This can easily be extended to an online estimate over chains by averaging the update $n\beta\left(\ell_n(w_\tau) - \ell_n(w_0)\right)$ over multiple chains.
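A minimal sketch of this online update (Equation 12) in plain Python; the inputs are the recycled per-step losses of a single chain:

```python
def llc_trace(losses, init_loss, n_beta):
    """Running LLC estimate after each SGLD step, following Eq. (12).

    losses:    per-step (recycled mini-batch) losses ell(w_1), ell(w_2), ...
    init_loss: loss at the initial parameter w_0 = w*
    n_beta:    the product n * beta (inverse temperature)
    """
    trace, lam = [], 0.0                       # lambda_hat_0(w_0) = 0
    for tau, loss in enumerate(losses, start=1):
        lam = ((tau - 1) * lam + n_beta * (loss - init_loss)) / tau
        trace.append(lam)
    return trace

# Averaging the per-step traces over chains gives the multi-chain online estimate.
def mean_trace(traces):
    return [sum(step_vals) / len(step_vals) for step_vals in zip(*traces)]
```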

A.6 LLC estimates for a non-log-likelihood-based loss

In the main body, we apply the LLC to empirical loss functions that do not arise as the log likelihood of independent random variables, due to the repeated use of dependent sub-sequences. Here we explain that it is possible to define a proper negative log likelihood over independent observations for the in-context linear regression setting: similar observations can be made in the language modeling setting.

Let $\Pi(k)$ be a probability distribution over the context length $k$. Ideally, the transformer would be trained to make predictions $y_k$ given a context of length $k$, where $k$ is sampled from $\Pi$. With the given distribution over contexts, this leads to a negative log likelihood of the form

$$L(w) = \sum_k p_k L_{[k]}(w) \tag{13}$$

where $p_k$ is the probability of sampling $k$ from $\Pi$ and

$$L_{[k]}(w) = \int q(S_k, y_k \mid \mathbf{t}, k)\, q(\mathbf{t}) \left[f_w(S_k) - y_k\right]^2 dS_k\, dy_k\, d\mathbf{t} \tag{14}$$

using the notation of Section 3, so that $S_k = (x_1, y_1, \ldots, x_{k-1}, y_{k-1}, x_k)$ is a context of length $k$. It is straightforward to check that this negative log likelihood $L$ agrees with the population loss $\ell$ associated with the empirical loss defined in Section 3. However, the empirical quantities $L_n(w)$ and $\ell_n(w)$ defined for a set of samples of size $n$ are not the same.

Since we use the empirical loss $\ell_n$ in our calculation of the estimated LLC, whereas the foundational theory of SLT is written in terms of the empirical negative log likelihood $L_n$, it is natural to wonder how much of a difference this makes in practice. Figure A.6 depicts LLC traces (Section A.5) for a highlighted selection of checkpoints using either a likelihood-based estimate (with variable sequence length) or a loss-based estimate (with fixed sequence length). The relative ordering of complexities does not change, and even the values of the LLC estimates differ little, except at the final checkpoint, which has a higher value for the sub-sequence-based estimate.

Refer to caption
Figure A.6: Loss-based (left) and likelihood-based (right) LLC estimation yield identically ordered LLC estimates. With the exception of the final checkpoint's LLC estimate (which is larger for the loss-based estimate), the values are close to identical. These plots display LLC traces, which show the LLC estimate as a function of SGLD steps; this is a useful tool for calibrating LLC estimation (Section A.5).

A.7 LLC estimates away from local minima

Our methodology for detecting stages is to apply LLC estimation to compute $\hat{\lambda}(w^*)$ at neural network parameters $w^* = w_t$ across training. In the typical case these parameters will not be local minima of the population loss, violating the theoretical conditions under which the LLC is defined.

It is not surprising that the estimator appears to work if $w^*$ is approximately a local minimum. Lau et al. (2025) validated their estimator both at parameters constructed to be local minima of the population loss and at parameters found through training with stochastic gradient descent (possibly not local minima of the empirical loss, let alone the population loss). They showed that in both cases the estimator recovers the true learning coefficient associated with the global minimum of the population loss.

On the other hand, if $w^*$ is far from any local minimum, it is a priori quite surprising that the SGLD-based estimation procedure works at all, since in this situation one might expect the chains to explore directions in which the loss decreases. Nevertheless, Chen et al. (2023) found that, empirically, LLC estimation away from local minima appears to give sensible results in practice. In our case, with sufficient localization we see stable estimates throughout training.

Theoretically accounting for this phenomenon is an interesting open problem. Perhaps there is a notion of stably evolving equilibrium in the setting of neural network training, echoing some of the ideas of Waddington (1957), such that the LLC estimation procedure is effectively giving us the LLC of a different potential to the population loss—a potential for which the current parameter actually is at a critical point. We leave addressing this question to future work.

Appendix B LLC-based stage boundary identification

B.1 Procedure for stage boundary identification

To identify stage boundaries, we look for plateaus in the LLC: checkpoints at which the slope of $\hat{\lambda}(w_t)$ over $t$ vanishes. To mitigate noise in the LLC estimates, we first fit a Gaussian process with some smoothing to the LLC-over-time curve. Then we numerically calculate the slope of this Gaussian process with respect to $\log t$. The logarithm corrects for the fact that the learning coefficient, like the loss, changes less as training progresses. We identify stage boundaries by looking for checkpoints at which this estimated slope equals zero. The results of this procedure are depicted in Figure B.1 for language and Figure B.2 for in-context linear regression.

At a local minimum or maximum of the estimated LLC curve, identifying a plateau from this estimated slope is straightforward, since the derivative crosses zero. However, at a saddle point the slope may not exactly reach zero, so we have to specify a "tolerance" for the absolute value of the derivative, below which we treat the point as an effective plateau.

In this case, we additionally require that the plateau be at a local minimum of the absolute first derivative. Otherwise, we may identify several adjacent points as all constituting a stage boundary.

To summarize, identifying stage boundaries is sensitive to the following choices: the intervals between checkpoints, the amount of smoothing, whether to differentiate with respect to $t$ or $\log t$, and the choice of tolerance. However, once a given choice of these hyperparameters is fixed, stages can be identified automatically, without further human judgment.
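A minimal sketch of this procedure using scikit-learn's Gaussian process regressor; the kernel, evaluation grid, and tolerance below are illustrative choices of ours, not necessarily those used for the figures:

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

def stage_boundaries(steps, llc_estimates, tol=0.05):
    """Return training steps at which the smoothed LLC curve plateaus.
    `steps` are checkpoint steps (assumed > 0); `llc_estimates` are the LLCs."""
    log_t = np.log(np.asarray(steps, dtype=float)).reshape(-1, 1)
    gp = GaussianProcessRegressor(kernel=RBF() + WhiteKernel(), normalize_y=True)
    gp.fit(log_t, np.asarray(llc_estimates, dtype=float))

    grid = np.linspace(log_t.min(), log_t.max(), 1000).reshape(-1, 1)
    smoothed = gp.predict(grid)
    slope = np.gradient(smoothed, grid.ravel())           # d(LLC) / d(log t)

    boundaries = []
    for i in range(1, len(slope) - 1):
        crosses_zero = slope[i - 1] * slope[i + 1] < 0
        near_zero_dip = (abs(slope[i]) < tol
                         and abs(slope[i]) <= abs(slope[i - 1])
                         and abs(slope[i]) <= abs(slope[i + 1]))
        if crosses_zero or near_zero_dip:                  # plateau of the LLC curve
            boundaries.append(float(np.exp(grid[i, 0])))
    return boundaries
```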

B.2 Stage boundary identification details for language model

Figure B.1 displays the test loss and LLC curves from Figure 1(a), in addition to the weight norm over time and the associated slopes. Stage boundaries coincide with where the slope of the LLC crosses zero, that is, where there is a plateau in the LLC.

Refer to caption
Figure B.1: A more detailed version of Figure 1(a) for two-layer language models. Top: Loss, LLC, and weight norm, along with an overlaid Gaussian process fit to these curves (red dotted lines). Bottom: Associated slopes, both numerically estimated finite differences (transparent blue) and the slope of the Gaussian process fit (red dotted lines). Note that stage LM5 may be subdivided into further stages (Section B.1). However, the noise in LLC estimates late in training is high, so we do not draw any conclusions from this.

B.3 Stage boundary identification details for in-context linear regression

Figure B.2 displays the test loss and LLC curves from Figure 1(b), in addition to the weight norm over time and the numerically estimated slopes associated with these three metrics. As in the case of language models, we identify stage boundaries by looking for plateaus in the LLC. Unlike for the language models, here the boundaries LR1–LR2 and LR2–LR3 are clearly visible in the loss.

Refer to caption
Figure B.2: A more detailed version of Figure 1(b) for in-context linear regression. Top: Loss, LLC, and weight norm, along with an overlaid Gaussian process fit to these curves (red dotted lines). Bottom: Associated slopes, both numerically estimated finite differences (transparent blue) and the slope of the Gaussian process fit (red dotted lines). Top middle: Error bars displaying the standard deviation over the 10 SGLD chains are shown in the background. Note that large error bars across chains are to be expected; between different SGLD estimations, the variance is much lower. For example, averaged over training, the standard deviation over different seeds is only 4.2.

B.4 Stage identification for additional training runs

Refer to caption
(a) Two-layer attention-only language transformers.
Refer to caption
(b) In-context linear regression transformers.
Figure B.3: Figure 1(a) and Figure 1(b) for multiple seeds. In both settings, the LLC reveals a consistent set of stages across five seeds. Late-training behavior shows more variance across seeds (see Section B.4).

Figure B.3(a) shows loss and LLC curves for five seeds (differing in model initialization and batch schedule). For each seed, LLC estimation reveals stages LM1–LM4. In three of the five seeds, stage LM5 is subdivided into two additional stages.

Figure B.3(b) shows loss and LLC curves for five seeds (differing in model initialization and batch schedule). For each seed, LLC estimation reveals stages LR1–LR5. There is remarkably little variance across the different seeds.

B.5 Comparison to Hessian statistics

Figure B.4 shows a quantification of the curvature-based notion of flatness captured by the Hessian (in contrast to the degeneracy-based notion of flatness captured by the LLC) for our in-context linear regression transformer. To estimate the trace and maximum eigenvalues shown in this figure, we use the PyHessian library (Yao et al., 2020) over a batch of $m = 1024$ samples.

Crucially, we observe that these Hessian-derived metrics (henceforth, "curvature") and the LLC are not consistently correlated. During the first part of LR2, the LLC and the curvature increase together. Starting at around $t = 20\mathrm{k}$, while the LLC is still increasing, the curvature starts decreasing. In the first part of LR3, both metrics decrease in tandem, but from around $t = 120\mathrm{k}$ onward, the curvature turns around and starts increasing.

The Hessian fails to detect three of the four stage boundaries identified by our LLC-based methodology. Since these Hessian-based metrics are dominated by the largest eigenvalues—the directions of maximum curvature—they fail to capture the finer-grained degeneracy that dominates the LLC. Moreover, LLC estimation is more scalable than estimating the full Hessian: empirically, its cost appears to grow roughly linearly in parameter count, whereas the Hessian is quadratic.
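As a sketch of how such statistics can be computed on a single batch, assuming PyHessian exposes a `hessian(model, criterion, data=...)` interface with `eigenvalues()` and `trace()` methods (we treat the exact API as an assumption; consult the PyHessian documentation):

```python
import torch
from pyhessian import hessian  # PyHessian (Yao et al., 2020); interface assumed

def hessian_stats(model, criterion, batch):
    """Top Hessian eigenvalue and a Hutchinson-style trace estimate (sketch)."""
    inputs, targets = batch
    comp = hessian(model, criterion, data=(inputs, targets),
                   cuda=torch.cuda.is_available())
    top_eigenvalues, _ = comp.eigenvalues(top_n=1)   # power iteration
    trace_estimates = comp.trace()                   # list of stochastic trace samples
    return top_eigenvalues[0], sum(trace_estimates) / len(trace_estimates)
```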

Refer to caption
Figure B.4: Hessian-based statistics reveal only one stage boundary in the development of our in-context linear regression transformer.

Appendix C Developmental analysis of language models

In this section, we present further evidence on the behavioral (Section C.1) and structural (Section C.2) development of the language model over the course of training.

C.1 Behavioral development

C.1.1 Bigram score

We empirically estimate the conditional bigram distribution by counting instances of bigrams over the training data. From this, we obtain the conditional distribution $\tilde{q}(t'|t)$, the likelihood that a token $t'$ follows $t$. The bigram score $B_k^S$ at index $k$ of an input context $S$ is the cross entropy between the model's predictions $p(t_{k+1}|t_k)$ at that position and the empirical bigram distribution,

$$B_k^S = -\sum_{i=1}^{d_{\mathrm{vocab}}} \tilde{q}(t_{k+1}^{(i)} \mid t_k)\, \log p(t_{k+1}^{(i)} \mid t_k), \tag{15}$$

where the $t_{k+1}^{(i)}$ range over the possible second tokens from the tokenizer vocabulary. From this we obtain the average bigram score

$$\bar{B} = \frac{1}{n}\sum_{i=1}^{n} B_{k_i}^{S_i}, \tag{16}$$

where we take fixed random choices of $k_i$ and $S_i$ for $1 \le i \le n = 5{,}000$; this average is displayed over training in Figure 4(a). It is compared against the best-achievable bigram score, which is the entropy of the bigram distribution itself, averaged over the validation set.
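A sketch of this computation (Equations 15–16), assuming a HuggingFace-style causal language model whose output has a `.logits` attribute of shape `[batch, seq, vocab]`, and a precomputed empirical bigram table `q_bigram`; both are assumptions on our part:

```python
import torch
import torch.nn.functional as F

def bigram_score(model, contexts, positions, q_bigram):
    """Average cross entropy between model next-token predictions and the
    empirical conditional bigram distribution.

    contexts:  LongTensor [n, seq_len] of token ids
    positions: LongTensor [n] of indices k_i at which to score
    q_bigram:  FloatTensor [vocab, vocab], rows q(. | t) summing to 1
    """
    with torch.no_grad():
        logits = model(contexts).logits                   # [n, seq, vocab]
    log_p = F.log_softmax(logits, dim=-1)                 # model's p(t_{k+1} | context)
    idx = torch.arange(contexts.size(0))
    log_p_k = log_p[idx, positions]                       # predictions at position k_i
    q_k = q_bigram[contexts[idx, positions]]              # q(. | t_k), shape [n, vocab]
    scores = -(q_k * log_p_k).sum(dim=-1)                 # Eq. (15), one score per sample
    return scores.mean().item()                           # Eq. (16)
```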

C.1.2 $n$-gram scores

In stage LM2 we consider $n$-grams, which are sequences of $n$ consecutive tokens, so 2-grams and bigrams are the same. Specifically, we consider common $n$-grams, which we define heuristically by comparing our 5,000-token-vocabulary tokenizer with the full GPT-2 tokenizer. We use the GPT-2 tokenizer as our heuristic because its vocabulary is constructed iteratively by merging the most frequent pairs of tokens.

We first tokenize the tokens in the full GPT-2 vocabulary with our tokenizer to obtain a list of 50,257 $n$-grams for various $n$. The first 5,000 such $n$-grams are all 1-grams, after which 2-grams begin appearing, then 3-grams, 4-grams, and so on (2-grams and 3-grams may still continue to appear later in the vocabulary). We then define the set of common $n$-grams, for a fixed $n \ge 2$, as the first 1,000 $n$-grams that appear in this list.

If we track the performance on $n$-grams and see it improve, we may ask whether this is simply a consequence of the model learning to use more context in general, rather than specifically improving on the set of $n$-grams being tracked. We measure performance against this baseline by defining an $n$-gram score. For a fixed $n$, we compute the average loss $\ell_{\mathrm{gram}}^n$ of the model on predicting the final tokens of our set of 1,000 $n$-grams, and also the average loss $\ell_{\mathrm{test}}^n$ of the model on a validation set at position $n$ of each validation sequence. The $n$-gram score is then defined to be $\ell_{\mathrm{test}}^n / \ell_{\mathrm{gram}}^n$.

C.1.3 In-context learning score

The in-context learning score is a behavioral measure of the relative performance of a model later in a sequence versus earlier in the sequence. We define $\mathrm{ICL}_{k_1:k_2}$ to be the loss on token $k_2$ minus the loss on token $k_1$, so a more negative score indicates better relative performance later in the sequence. A more negative ICL score does not, however, mean that a model achieves better overall loss on later tokens; it only measures the relative improvement. For the language model, we follow a construction similar to that of Olsson et al. (2022), taking $k_2$ to be the 500th token and $k_1$ to be the 50th token. This is then averaged over a 100k-row validation dataset. The performance of the language model over the course of training can be seen at the bottom of Figure 4(f).
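A sketch of this score, assuming per-token losses on the validation set have already been collected into a tensor of shape `[num_sequences, seq_len]` (token positions 1-indexed as in the text):

```python
import torch

def icl_score(token_losses, k1=50, k2=500):
    """ICL_{k1:k2}: mean loss at token k2 minus mean loss at token k1.
    More negative values indicate better relative performance late in context."""
    return (token_losses[:, k2 - 1] - token_losses[:, k1 - 1]).mean().item()
```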

C.1.4 Visualizing behavioral changes

In Figure C.1, we visualize changes in the model's input/output behavior by comparing model predictions before and after developmental stages and highlighting tokens with the greatest differences.

Refer to caption
Figure C.1: Samples are shown with tokens highlighted to indicate changes in logits during a given range. Red is improved performance (higher logit output for the true next token) and blue is worse. Sample (a): improvement in bigrams (LM1) such as "te/ll, ab/out, des/ire, mot/ion, eng/aged, strugg/le, etc." Sample (b): improvement in common $n$-grams (LM2) such as "L/in/ux, P/y/th/on, h/on/or/able, S/up/reme, dat/ab/ase, f/ram/ew/ork." Sample (c): development of in-context learning via induction circuits (LM3, LM4), visible in the improved predictions for the word "D/urs/ley" after the first time it appears in the context, as initially observed by Olsson et al. (2022).

C.2 Structural development

C.2.1 Positional embedding

In Figure C.2, we measure the effect of the positional embedding on model performance by comparing the model's performance at particular context positions on a validation set over the course of training against its performance on the same validation set with the positional embedding zero-ablated. The full context length is 1024, and we measure test loss at positions 1, 2, 3, 5, 10, 20, 30, 50, 100, 200, 300, 500, and 1000. In the transition from stage LM1 to LM2, the model begins using the learnable positional embedding to improve performance: the difference between test loss with and without the positional ablation is negligible at all measured positions until the LM1–LM2 boundary.

Refer to caption
Figure C.2: The model learns to start using the positional encoding in LM2, at which point performance begins to worsen when the positional encoding is ablated. In both plots, earlier token positions are colored more purple, later token positions are more yellow, and the overall mean loss is colored red. Both sets of per-token losses are shown in both graphs for ease of comparison. Left: original test loss is emphasized. Right: test loss with the positional embedding ablated is emphasized.
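A sketch of the zero-ablation described above; `pos_embedding` is the model's learnable positional embedding parameter (for example, something like `model.pos_embed.W_pos` in a TransformerLens-style implementation, which is an assumption on our part), and `loss_fn` is a hypothetical evaluation helper:

```python
import torch

def loss_with_positional_ablation(model, pos_embedding, batch, loss_fn):
    """Evaluate loss with the learnable positional embedding zeroed out."""
    original = pos_embedding.detach().clone()
    with torch.no_grad():
        pos_embedding.zero_()                 # zero-ablate the positional embedding
        loss = loss_fn(model, batch).item()
        pos_embedding.copy_(original)         # restore the original weights
    return loss
```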

Structurally, we might predict that the positional embeddings should organize themselves in a particular way: in order to understand relative positions, adjacent positions should be embedded close to each other, and far-away positions should be embedded far apart.

In Figure C.3, we examine the development of the positional embedding itself over time from two angles. The first is to take the embeddings of each position in the context and run PCA on those embeddings. As training progresses, the positional embedding PCAs gradually resolve into Lissajous curves, suggesting that the positional embeddings might resemble a random walk (Antognini & Sohl-Dickstein, 2018; Shinn, 2023). However, if we look at the explained variance, we see that it grows very large for PC1, reaching 94.2% at training step 6400. This is much higher than we would expect for Brownian motion, where we would expect about 61% explained variance in PC1 (Antognini & Sohl-Dickstein, 2018).

The second perspective is to look at how the magnitudes of the positional embeddings develop over the context length. Here we observe that the magnitudes have a fairly regular structure. In conjunction with the PCAs and explained variance, we might infer that the positional embeddings look approximately like a (possibly curved) line in $d_{\mathrm{model}} = 256$ dimensional space. A positional embedding organized in this way would make it easier for an attention head to attend to multiple recent tokens, which is necessary if a single head is to learn $n$-grams.

Refer to caption
Figure C.3: Columns progress through training time at training steps 0, 400, 800, 1600, 3200, and 6400. The first three rows are plots of the first three principal components of PCA on the positional embedding weights, while the fourth row shows the explained variance for each of the principal components. The fifth row plots the magnitude of the embedding of each position in the context length of 1024.
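A sketch of both diagnostics from Figure C.3, treating the positional embedding as a NumPy array `W_pos` of shape `[n_ctx, d_model]`:

```python
import numpy as np

def positional_embedding_diagnostics(W_pos, n_components=3):
    """PCA projections, explained variance, and per-position norms of W_pos."""
    X = W_pos - W_pos.mean(axis=0, keepdims=True)     # center before PCA
    U, S, Vt = np.linalg.svd(X, full_matrices=False)
    projections = X @ Vt[:n_components].T             # [n_ctx, n_components]
    explained_variance = (S ** 2) / np.sum(S ** 2)    # fraction per component
    magnitudes = np.linalg.norm(W_pos, axis=1)        # one norm per context position
    return projections, explained_variance[:n_components], magnitudes
```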
C.2.2 Composition scores

Let $W_Q^h$, $W_K^h$, $W_V^h$ be the query, key, and value weights of attention head $h$, respectively. Elhage et al. (2021) describe three types of composition between attention heads in transformer models:

  • Q-Composition: the query matrix $W_Q^h$ of an attention head reads in a subspace affected by a previous head.

  • K-Composition: the key matrix $W_K^h$ of an attention head reads in a subspace affected by a previous head.

  • V-Composition: the value matrix $W_V^h$ of an attention head reads in a subspace affected by a previous head.

If $W_O^h$ is the output matrix of an attention head, then $W_{QK}^h = W_Q^{h\,T} W_K^h$ and $W_{OV}^h = W_O^h W_V^h$. The composition scores are

$$\|M W_{OV}^{h_1}\|_F \,/\, \left(\|M\|_F\, \|W_{OV}^{h_1}\|_F\right), \tag{17}$$

where $M = W_{QK}^{h_2\,T}$, $M = W_{QK}^{h_2}$, and $M = W_{OV}^{h_2}$ for Q-, K-, and V-Composition, respectively. See Figure C.4 for K-composition scores over time between attention heads in the induction circuits.
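A sketch of the K-composition score of Equation (17), using plain PyTorch tensors for the head weight matrices; the shape convention noted in the comments is an assumption of ours:

```python
import torch

def k_composition_score(W_Q2, W_K2, W_O1, W_V1):
    """K-composition score (Eq. 17) between a layer-1 head (W_O1, W_V1)
    and a layer-2 head (W_Q2, W_K2).
    Assumed shapes: W_Q2, W_K2, W_V1: [d_head, d_model]; W_O1: [d_model, d_head]."""
    M = W_Q2.T @ W_K2                # W_QK of the layer-2 head; M = W_QK for K-composition
    W_OV1 = W_O1 @ W_V1              # W_OV of the layer-1 head
    num = torch.linalg.norm(M @ W_OV1)                     # Frobenius norm
    return (num / (torch.linalg.norm(M) * torch.linalg.norm(W_OV1))).item()
```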

Refer to caption
Figure C.4: The K-composition scores (Elhage et al., 2021) between first- and second-layer attention heads. The $h$th attention head in layer $l$ is indexed by $l{:}h$. The attention heads that eventually become previous-token heads are $h = 2, 5$ in layer 1 (subplot rows 2 and 3), and the attention heads that eventually become induction heads are $h = 7, 8$ in layer 2 (subplot columns 2 and 3). The attention heads 1:1 and 2:1 are included for comparison. The induction heads 2:7 and 2:8 begin K-composing with first-layer heads near the start of stage LM2. They continue to compose with the previous-token heads in stages LM3 and LM4 (highlighted in green), while their K-composition scores with the other layer-1 attention heads drop in later stages.
C.2.3 Previous-token matching score

The previous-token matching score is a structural measure of induction head attention. It is the attention score given to $[A]$ by an attention head at $[B]$ in the sequence $\ldots[A][B]$, i.e., how much the head attends to the immediately preceding token.

We compute this score using a synthetic data-generating process, generating 10k fixed random sequences with lengths between 16 and 64. The first token is a special "beginning of string" token, and the remaining tokens are sampled uniformly at random from the rest of the vocabulary.

For each sample in this synthetic dataset, we measure the attention score that an attention head gives to the previous token when at the last token in the sequence. These scores are averaged across the dataset to produce the previous-token matching score for that attention head at a given checkpoint. The progression of previous-token matching scores over time can be seen in Figure 4(d).
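A sketch of this measurement; the helper `get_attention_pattern`, which returns the post-softmax attention pattern of shape `[n_heads, seq_len, seq_len]` for a given layer, is hypothetical (in practice it would be implemented with hooks):

```python
def previous_token_matching_score(model, sequences, layer, head, get_attention_pattern):
    """Average attention from the final token to its immediate predecessor."""
    scores = []
    for seq in sequences:                                    # seq: LongTensor [seq_len]
        pattern = get_attention_pattern(model, seq, layer)   # [n_heads, seq_len, seq_len]
        last = seq.shape[0] - 1
        scores.append(pattern[head, last, last - 1].item())  # attention to previous token
    return sum(scores) / len(scores)
```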

C.2.4 Prefix matching score

The prefix matching score from Olsson et al. (2022) is defined similarly to the previous-token matching score. Given a sequence $[A][B]\ldots[A]$, the prefix matching score of a particular attention head is how much that head attends back to the first instance of $[A]$ when at the second instance of $[A]$.

We compute this score using a synthetic data-generating process. We first generate 10k fixed random sequences of length 128. The first token is always a special "beginning of string" token, and the $[A]$ and $[B]$ tokens are selected and placed randomly: one $[A]$ token is placed in the first half of the sequence, the other in the second half, and the $[B]$ token is placed directly after the first $[A]$ token. The remaining tokens are sampled at random from the tokenizer vocabulary, excluding the $[A]$, $[B]$, and beginning-of-string tokens.

For each sample in this synthetic dataset, we measure the attention score that each attention head assigns to the earlier instance of $[A]$ when at the later instance of $[A]$. These scores are averaged across the dataset to produce the prefix matching score for that attention head at a given checkpoint. The progression of prefix matching scores over time can be seen in Figure 4(e).

Appendix D Developmental analysis of regression transformers

In this section, we present further evidence on the behavioral (Section D.1) and structural (Section D.2) development of the transformer in the setting of in-context linear regression.

D.1 Behavioral development

D.1.1 Task prior score

In addition to training models on a data distribution in which tasks $\mathbf{t}$ are generated on the fly, we examine the setting of Raventós et al. (2023), in which a finite set of $M$ tasks is generated ahead of time and each training sample involves a task selected at random from this set.

Figure D.1 depicts (a) the mean square distance between the model's predictions and the zero prediction, and (b) the mean square distance between the model's predictions and the "task prior" prediction, which uses the component-wise average $\bar{\mathbf{t}}$ over the set of tasks encountered during training. For all models, the minimum distance to the task-prior prediction is lower than the minimum distance to the zero prediction. Hence, we call stage LR1 "learning the task prior" rather than simply learning the zero prediction.

Refer to caption
Figure D.1: Learning the task prior is universal across models trained on very different data distributions. Each line represents a model trained on a data distribution with a different number $M$ of distinct tasks ("task diversity" in Raventós et al., 2023). In addition to using a finite $M$, the models depicted here differ from the other models considered in this paper in that they were trained with a maximum learning rate of 0.01 and (inadvertently) lack an output matrix after the multi-head attention layer.
D.1.2 ICL

We consider two variants of the ICL score: $\operatorname{ICL}_{1:D}$ and $\operatorname{ICL}_{D:K}$.

If the noise term $\sigma^2$ equals zero and both tasks $\mathbf{t}$ and inputs $x_k$ are normalized (i.e., $\mathbf{t}\in S^{D-1}$), then $D-1$ observations of input/output pairs are enough to precisely identify $\mathbf{t}$. Therefore, $\operatorname{ICL}_{1:D}$ measures how successful the model is at initially locating the task. The fact that the tasks and inputs are not normalized changes this only slightly: the task will still sit near $S^{D-1}$, within a shell of vanishing thickness as $D\to\infty$.

Once the task is localized, $\operatorname{ICL}_{D:K}$ measures how successfully the model refines its internal estimate of $\mathbf{t}$ with additional examples, which it can use to reduce the error due to noise.

In terms of implementation, it is not necessary for the model to internally make a distinction between locating and refining its estimate of the task; ridge regression, for example, makes no such distinction. Still, we find the distinction useful for reasoning about the progression of the model. In particular, we note that early in stage LR2, while the model begins to develop ICL for early tokens, it becomes worse at ICL over tokens late in the context. Later, at around 23k steps, $\operatorname{ICL}_{D:K}$ stabilizes, while $\operatorname{ICL}_{1:D}$ continues improving over the entire training run.
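One way to make the distinction operational is to compare per-position losses at the endpoints of each range. The sketch below is our own simplified stand-in for the ICL scores defined in the main text, not the exact definition.

import numpy as np

def icl_score(preds, targets, a, b):
    # preds, targets: (batch, K) per-position predictions and regression targets.
    # a, b:           1-indexed context positions, a < b.
    # Returns the change in mean squared error from position a to position b;
    # negative values indicate the model improves as it sees more examples.
    mse = ((preds - targets) ** 2).mean(axis=0)   # (K,) per-position MSE
    return mse[b - 1] - mse[a - 1]

# e.g. ICL_{1:D} ~ icl_score(preds, ys, 1, D) and ICL_{D:K} ~ icl_score(preds, ys, D, K)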

Figure D.2: ICL scores for the in-context linear regression model. Right: ICL scores between inputs 1 and 4 and inputs 4 and 8 over time. We see that ICL emerges during the first half of LR2. Left: Highlighted ICL score curves from the end of LR1 to halfway through LR2. Note that when the model first starts improving on early tokens, it temporarily becomes worse at predicting later tokens. Note also that the model ceases to become better at later tokens as of the second half of LR2, whereas ICL on early tokens continues to improve throughout training.
D.1.3 OOD generalization

To further investigate behavior in stages LR2 and LR3, we probe the model on data sampled from different distributions than those encountered during training (cf. Raventós et al. (2023), who evaluate models trained on a set of discrete tasks on the “true” distribution consisting of novel tasks). We evaluate behavior on two families of perturbations: “OOD inputs” $x_k$, sampled according to a different scale

$x_k \sim \mathcal{N}(0, g I_D),$ (18)

for some gain parameter $g$, and “OOD tasks”

$\mathbf{t} \sim \mathcal{N}(0, g I_D).$ (19)

Note that these inputs and tasks are not out-of-distribution in the sense of coming from a distribution with a different support than the training distribution. However, samples drawn from these “extreme” distributions are exponentially suppressed under the original training distribution. Figure D.3 plots the normalized MSE for these two distributions over training time.
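A sketch of the evaluation behind Figure D.3, under the assumption that the trained model is wrapped in a callable model_fn(xs, ys) returning per-position predictions; the gain $g$ scales the covariance, the loss is normalized by $g^2$ as in the figure, and the defaults D=4, K=8, sigma2=0.125 mirror the training configuration in Table 4.

import numpy as np

def ood_normalized_mse(model_fn, g, perturb="inputs", D=4, K=8,
                       n_batch=1024, sigma2=0.125, seed=0):
    rng = np.random.default_rng(seed)
    x_std = np.sqrt(g) if perturb == "inputs" else 1.0   # x_k ~ N(0, g I_D)
    t_std = np.sqrt(g) if perturb == "tasks" else 1.0    # t   ~ N(0, g I_D)
    xs = rng.normal(0.0, x_std, size=(n_batch, K, D))
    ts = rng.normal(0.0, t_std, size=(n_batch, D))
    ys = np.einsum("bkd,bd->bk", xs, ts)
    ys += rng.normal(0.0, np.sqrt(sigma2), size=ys.shape)
    preds = model_fn(xs, ys)                             # (n_batch, K) predictions
    return np.mean((preds - ys) ** 2) / g**2             # normalized as in Figure D.3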

Between $t=1\mathrm{k}$ and $t=4\mathrm{k}$, the model's outputs rapidly diminish in scale for out-of-distribution samples, both for $g>1$ and $g<1$, especially for out-of-distribution inputs. While the model is moving away from predicting with the task prior for in-distribution samples, it moves closer to predicting with the task prior for out-of-distribution samples.

Between $t=4\mathrm{k}$ and $t=23\mathrm{k}$, the model recovers on moderately out-of-distribution inputs ($g<10^{1.5}$), with performance remaining close to constant beyond this range. Past this stage, performance improves steadily for out-of-distribution tasks.

For out-of-distribution inputs, performance eventually worsens for some ranges of $g$. Between $t=23\mathrm{k}$ and $t=80\mathrm{k}$, the model further approaches the task prior prediction for extreme out-of-distribution inputs ($g>10^{1.5}$). Subsequently, between $t=75\mathrm{k}$ and $t=130\mathrm{k}$, the model moves away from the task prior prediction for extreme inputs, and performance deteriorates for inputs with $g>10^{0.5}$. As of LR5, performance is roughly constant.

Figure D.3: Performance on extreme inputs over time may reveal additional substages in LR2 and in LR3. Left: The model first becomes better, then worse, at ICL on inputs sampled from $\mathcal{N}(0, g I_D)$ for large $g$. Right: The model continues to improve on ICL for tasks sampled from $\mathcal{N}(0, g I_D)$. Top: Normalized loss (divided by $g^2$) over time for OOD inputs and tasks. Bottom: Average $|\hat{y}|$ over time for OOD inputs and tasks.

D.2 Structural development

D.2.1 Embedding

The embedding matrix $W_E$ is a linear transformation $\mathbb{R}^{D+1}\to\mathbb{R}^{d_{\text{embed}}}$. Plotting the $D+1$ singular values of this matrix, we notice that the embedding partially loses one of its components starting at the end of LR2 (Figure D.4a).

The input “tokens” $x_k$ span a $D$-dimensional subspace of the $(D+1)$-dimensional “token space.” The target tokens $y_k$ span an orthogonal 1-dimensional subspace. The collapse of one of the embedding matrix's singular values means that the model learns to redundantly encode the inputs and targets in the same $D$-dimensional subspace of the space of residual stream activations. The almost order-of-magnitude separation in the magnitudes of the squared singular values means that the $(D+1)$th component of the token embedding explains only 2.9% of the variance in the residual stream activations immediately after the embedding, whereas the dominant components explain roughly 24% each.
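A simple way to inspect this at each checkpoint is to look at the squared singular values of $W_E$ directly; with roughly isotropic token inputs, their normalized values approximate the variance fractions quoted above. This is a sketch under that assumption, not the exact analysis code.

import numpy as np

def embedding_sv_fractions(W_E):
    # W_E: (D + 1, d_embed) embedding matrix.
    svals = np.linalg.svd(W_E, compute_uv=False)      # D + 1 singular values
    fractions = svals**2 / np.sum(svals**2)           # share of each component
    return svals, fractions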

Contributions to degeneracy

Given a linear transformation $T_1:\mathbb{R}^{D_1}\to\mathbb{R}^{D_2}$ followed by another linear transformation $T_2:\mathbb{R}^{D_2}\to\mathbb{R}^{D_3}$, reducing the rank of $T_1$ from $r$ to $r'<r$ renders $D_3(r-r')$ components of the second transformation irrelevant. This would mean a decrease in the learning coefficient of $D_3(r-r')/2$, since a decrease of $d$ in effective dimensionality leads to a decrease of $d/2$ in the LLC. (Note that this is not the only possible way for the LLC to decrease: changing the local loss landscape from quadratic to quartic or some higher power would also lower the LLC, by a fractional amount.) In the actual model, we do not see an exact decrease in the rank, and a layer normalization sits between the linear transformation of the embedding and the linear transformations of each transformer block and the unembedding. It is unclear what the precise relation between structure and degeneracy is in this case (Section D.2.6). Still, suggestively, the onset of embedding collapse coincides with a decrease in the rate of increase of $\hat{\lambda}(w_t)$.
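As a purely hypothetical illustration of the magnitudes this argument predicts (the numbers below are not measured from the model): if the downstream transformation has output dimension $D_3 = 64$ and the effective rank of the embedding drops by one, then

$\Delta\lambda = -\tfrac{1}{2}\,D_3\,(r - r') = -\tfrac{1}{2}\cdot 64\cdot 1 = -32.$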

Figure D.4: Left: The embedding partially “collapses” during the second half of LR2. At the start of stage LR2, the minimum singular value explains only 3% of the variance in residual stream activations due to the sample. By the end of training, it explains half that. Middle: The positional encoding goes through a similar shift during LR3 (that begins earlier, during LR2). Right: The cosine similarity between the 5 rows of $W_{\text{embed}}$ and the projection of those rows onto the subspace spanned by $W_{\text{unembed}}$ shows that the model learns to write token and positional information to the same subspace.
D.2.2 Positional encoding

The positional encoding goes through a similar collapse to the unembedding, starting during the second part of LR2 and continuing into LR3 (Figure D.4b). Additionally, throughout these stages, the subspace spanned by the embedding becomes more aligned with the subspace spanned by the positional encoding (Figure D.4c).

Contributions to degeneracy

For the same reason as with the token embedding, a decrease in the dimensionality of the subspace occupied by activations reduces the effective number of dimensions and thus the learning coefficient. This occurs both as the positional encoding's effective dimensionality decreases (vanishing singular values, Figure D.4b) and as the token embedding subspace and positional embedding subspace align (increasing cosine similarity, Figure D.4c).

D.2.3 Attention collapse

Over the course of training, we observe that some attention heads learn to attend solely (soft attention becomes hard attention) and consistently (the attention pattern becomes content-independent) to certain positions. We call this phenomenon attention collapse, in parallel with the other observed forms of collapse. Not only does this potentially contribute to a decrease in the LLC, but it also makes the attention heads identifiable: we find a self-attention head, previous-token attention heads, previous-$x$-attention heads, and previous-$y$-attention heads.

$x$-attention vs. $y$-attention

For convenience, we separate each attention head into two parts: one for the $x$-tokens and the other for the $y$-tokens.

Attention entropy score

To quantify attention hardness, we use the attention entropy score (Ghader & Monz, 2017; Vig & Belinkov, 2019). Given the attention pattern $\alpha_{k,k'}^{(b,h)}$ for how much token $k$ in head $h$ in block $b$ attends back to token $k'$, its attention entropy score $H_k^{(b,h)}$ is the Shannon entropy over preceding indices $k'\leq k$,

$H_k^{(b,h)} = -\sum_{k'\leq k} \alpha_{k,k'}^{(b,h)} \log_2 \alpha_{k,k'}^{(b,h)}.$ (20)

From this, we compute the normalized entropy $\hat{H}_k^{(b,h)}$, which divides the attention entropy by the maximum entropy for the given context length,

$\hat{H}_k^{(b,h)} = \frac{H_k^{(b,h)}}{\log_2(k)}.$ (21)

This accounts for the entropy being calculated over different numbers of tokens; the result is displayed in Figure D.5. Notably, the identified stages line up closely with the stages visible in these attention entropy curves.
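A sketch of this computation, assuming a causal attention array of shape (batch, heads, seq, seq); a small constant guards the logarithm, and position 0, whose maximum entropy is zero, is normalized by 1 so that its score is simply 0.

import numpy as np

def normalized_attention_entropy(attn, eps=1e-12):
    # attn: (batch, n_heads, seq, seq); each row of the last axis sums to 1
    #       over the allowed (causal) key positions.
    H = -(attn * np.log2(attn + eps)).sum(axis=-1)     # Eq. (20), per query position
    q = np.arange(attn.shape[-1])                      # 0-based query positions
    max_H = np.log2(np.maximum(q + 1, 2))              # log2 of # of attendable tokens
    return (H / max_H).mean(axis=0)                    # Eq. (21), shape (n_heads, seq)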

Figure D.5: Attention hardening as measured by the normalized attention entropy score (Section D.2.3). Block 1 heads 1y/3y and block 2 head 1y harden over training. In combination with the fact that these attention heads become less variable (Figure D.6), this may contribute to a decrease in the LLC (discussed in Section D.2.3). The x-components of the attention heads remain much softer over the entire training run.
Constant attention

Accomplishing constant attention requires the presence of biases in the query and key transformations or, if there are no biases (as is the case for the models we investigated), attending to the positional embedding. With the Shortformer-style positional encoding used for the language models (Section F.1.1), this is straightforward: the positional information is injected directly into the query and key matrices. With the in-context linear regression models, where the positional embedding is added to the residual stream activations, this is less straightforward: achieving constant attention requires separating residual stream activations into orthogonal positional- and input-dependent subspaces, then reading from the former with the query and key weight matrices.

Attention variability score

To quantify how constant the attention pattern is, we measure attention variability (Vig & Belinkov, 2019),

$V_k^{(b,h)} = \frac{\sum_{i=1}^{n}\sum_{k'\leq k}\left|\alpha_{k,k'}^{(b,h)}(S_K^{(i)}) - \bar{\alpha}_{k,k'}^{(b,h)}\right|}{2n\sum_{k'\leq k}\bar{\alpha}_{k,k'}^{(b,h)}},$ (22)

where the division by 2 ensures the variability lies in the range $[0,1]$. This is displayed in Figure D.6. Though attention hardness and variability are in principle independent axes of differentiation, empirically we observe that hard attention is correlated with low variability.

Figure D.6: Attention variability over time. The heads that develop hard attention in Figure D.5 (block 1 heads 1y, 3y, and 4y) also become less variable over time.
Self-attention score

Self-attention is measured by the average amount a token $k$ attends to itself, $\alpha_{k,k}^{(b,h)}$.

Previous-token attention score

Previous-token attention is measured in the same way as in the language model setting (Section C.2), with one difference: we compute the previous-token score not over a synthetic dataset but over a validation batch.

$x$-attention score

The total amount attended to inputs $x_k$, that is, $\alpha_{k,x}^{(b,h)} = \sum_{k'=1}^{K} \alpha_{k,2k'}^{(b,h)}$.

$y$-attention score

Defined analogously as $\alpha_{k,y}^{(b,h)} = \sum_{k'=1}^{K} \alpha_{k,2k'+1}^{(b,h)}$.
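A sketch of the $x$- and $y$-attention scores under the tokenization of Section F.2.2, where $x$-tokens sit at even 0-based positions and $y$-tokens at odd positions; the exact indexing convention in the analysis code may differ.

import numpy as np

def xy_attention_scores(attn):
    # attn: (batch, n_heads, seq, seq) attention over the interleaved context
    #       (x_1, y_1, x_2, y_2, ...).
    x_mass = attn[..., 0::2].sum(axis=-1)     # total attention to x-tokens
    y_mass = attn[..., 1::2].sum(axis=-1)     # total attention to y-tokens
    # Average over the batch and over query positions: one score per head.
    return x_mass.mean(axis=(0, 2)), y_mass.mean(axis=(0, 2))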

Classifying attention heads

Several attention heads are easy to identify by virtue of being both concentrated and consistent. These are depicted in Figure D.7 and include: (B1H3y) previous-token heads (also present in the language model case), (B1H1y) previous-$x$ heads, and (B1H4x, B2H1y) previous-$y$ heads. Other training runs also include self-attention heads.

Figure D.7: Collection of attention heads identified by their consistent and recognizable attention patterns. Left to right: previous-$x$ head, previous-token head, previous-$y$ head, previous-$y$ head.
Contributions to degeneracy

Suppose an attention head $h$ in block $b$ has the constant attention pattern (after the softmax) $A^{(b,h)} = \sum_i \delta_{l(i)\,i}$. That is, for each token $i$, the attention head attends solely to a single earlier token $l(i)\leq i$ and no others. Restricting to single-head attention (the argument generalizes straightforwardly), the final contribution of this attention head to the residual stream is the following (Phuong & Hutter, 2022):

$O = W_O\cdot(V\cdot A),$ (23)

where $A\in\mathbb{R}^{\ell_z\times\ell_x}$ is the attention pattern, $V\in\mathbb{R}^{d_{\text{out}}\times\ell_z}$ is the value matrix, and $W_O$ is the output matrix; the result is added back into the residual stream. Plugging in the hard and constant attention pattern, writing out the matrix multiplication, and filling in the definition of $A$, we get

$O_{ij} = \sum_k (W_O)_{ik} V_{k\,l(j)}\,\delta_{l(j)\,j}.$ (24)

For each column $j$ of $A$, the hard attention picks out a single column $l(j)$ of $V$. Now suppose that there is a token $l'$ that receives no attention from any position $j$; that is, there exists no $j$ such that $l' = l(j)$. Then there is a column $l'$ in $V$ which does not contribute to the result of $V\cdot A$ and, in turn, a column $l'$ in $W_O$ which does not contribute to the output of the head. As discussed for the embedding and layer norm, this decrease in effective dimensionality leads to a decrease in the learning coefficient.

Note that this argument does not hold for all hard and constant attention patterns. It holds solely for attention patterns that consistently ignore some earlier token across all positions, such as the previous-$x$ and previous-$y$ heads, but not the self-attention and previous-token heads. As discussed in Section D.2.6, it remains unclear what exactly the threshold for “ignoring” a token should be before it contributes to degeneracy, and whether any of the heads we examine actually meet this threshold.

D.2.4 Unembedding collapse

The unembedding block consists of a layer normalization $\operatorname{LN}(z)$ followed by a linear transformation $W_U z + b_U$ and finally a projection $\pi_y$ to extract the $y$-component. Given the 64-dimensional vector of activations $z$ in the residual stream right before the unembedding (for a specific token), the full unembedding operation is:

$\pi_y\left[W_U\left(\frac{z-\mathbb{E}[z]}{\sqrt{\mathbb{V}[z]+\epsilon}}\odot\gamma+\beta\right)+b_U\right],$

where $\odot$ denotes element-wise multiplication of two vectors and $\gamma,\beta$ are the layer normalization weights and biases, respectively.

Effective unembedding weights and biases

Moving terms around, we can represent this as

$\left((W_U)_{[0,:]}\odot\gamma\right)\left(\frac{z-\mathbb{E}[z]}{\sqrt{\mathbb{V}[z]+\epsilon}}\right)+\left((W_U)_{[0,:]}\beta\right)+(b_U)_{[0]},$

where we order the outputs so that the $y$-token corresponds to the 0th row. Because we are reading out a single $y$-component, we can express the unembedding transformation in terms of “effective” unembedding weights and biases

$\tilde{W}_U = (W_U)_{[0,:]}\odot\gamma,$
$\tilde{b}_U = (W_U)_{[0,:]}\beta + (b_U)_{[0]}.$
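Extracting these quantities from a checkpoint is a one-liner each; a sketch with illustrative parameter names, not the exact ones used in our training code.

import numpy as np

def effective_unembedding(W_U, b_U, gamma, beta):
    # W_U:   (n_out, d_embed) unembedding weights, y-component in row 0.
    # b_U:   (n_out,) unembedding bias.
    # gamma: (d_embed,) layer norm weights;  beta: (d_embed,) layer norm biases.
    W_tilde = W_U[0, :] * gamma             # element-wise product
    b_tilde = W_U[0, :] @ beta + b_U[0]
    return W_tilde, b_tilde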
Unembedding weights over time

In Figure D.8, we plot $(\gamma,\beta)$, $((W_U)_{[0,:]}, (b_U)_{[0]})$, and $(\tilde{W}_U,\tilde{b}_U)$ as a function of training steps, along with the mean weight over time. These are 64- and 1-dimensional vectors, so we can display the entire set of components. During stage LR3, the majority of the weights $\beta$ and $W_U$ “collapse” to zero. Additionally, the layer normalization biases temporarily experience a large increase in variance before returning to small values. Despite this, the mean of the linear weights, layer normalization biases, and effective weights remains remarkably constant and close to zero throughout the entire process.

Contributions to degeneracy

Suppose that $D$ of the layer normalization weights have vanished, say $\gamma_i = 0$ for $1\leq i\leq D$. Then the corresponding columns of $W_U$ only contribute to the unembedding via their product $(W_U)_{[:,1:D]}\beta_{[1:D]}$ with the first $D$ rows of $\beta$. This creates a typical form of degeneracy studied in SLT and found, for example, in deep linear networks, where we can change the weights to $(W_U)_{[:,1:D]}A,\ A^{-1}\beta_{[1:D]}$ for any invertible $D\times D$ matrix $A$ without changing the function computed by the network. If in addition the $\beta_i$ vanish for $1\leq i\leq D$, then the entries of $(W_U)_{[:,1:D]}$ are completely unconstrained, creating further degeneracy.
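This reparametrization invariance is easy to verify numerically; the following sketch checks, for random matrices, that replacing $(W,\beta)$ with $(WA,\,A^{-1}\beta)$ leaves the product, and hence the computed function, unchanged.

import numpy as np

rng = np.random.default_rng(0)
D, n_out = 8, 5
W = rng.normal(size=(n_out, D))         # stands in for (W_U)_[:, 1:D]
beta = rng.normal(size=D)               # stands in for beta_[1:D]
A = rng.normal(size=(D, D))             # a random matrix is invertible almost surely

original = W @ beta
reparametrized = (W @ A) @ (np.linalg.inv(A) @ beta)
assert np.allclose(original, reparametrized)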

Figure D.8: Unembedding weights over time for the RT1 transformer undergo a “collapse” that begins towards the end of LR2. When these weights reach zero in LR3 and LR4, this may contribute to the observed decrease in the LLC. Top: Weights over time. The outlier in the positive direction is the weight for the $y$-token output. Bottom: Biases over time. Left: Unembedding layer normalization weights over time. Middle: Unembedding linear weights over time (restricted to the $y$-subspace). Right: Effective unembedding weights over time (obtained by element-wise multiplication of the preceding columns, focusing on the bias for only the $y$-token).
D.2.5 Layer normalization collapse

The “collapse” in layer normalization weights is not unique to the unembedding. As depicted in Figure D.9, this behavior occurs in all layer norms except for the second MLP's. The biases also remain centered close to zero even as the variance in biases grows much larger. Unlike the unembedding, these layers begin to change earlier (starting halfway through LR2).

What is most striking about the layer normalization collapse is that it occurs without any explicit regularization (neither weight decay nor dropout). As such, it demonstrates a clear example of implicit regularization, i.e., inductive biases in the optimizer or model that favor simpler solutions.

Figure D.9: Layer norm weights over time. Top: After LR3, the layer normalization collapse expands from the unembedding to earlier layers, most notably in the first pre-attention layer norm. This occurs without explicit regularization and may contribute to the concurrent decrease in LLC. Bottom: During layer normalization collapse, the variance of layer normalization biases increases drastically while the mean of the biases remains relatively constant. Inset: Plotting the fraction of weights or biases whose magnitude is less than 0.1 over time reveals that the collapse is more measured for intermediate layer norms: weights shrink to small values but not extremely close to zero as in the unembedding and first attention layer.
Contributions to degeneracy

In the previous section, we describe how layer norm collapse in the unembedding is linked to an increase in degeneracy because it ensures that parameters in the subsequent linear layer become irrelevant. The same is true for the layer norms which precede the attention and MLP blocks.

D.2.6 Degeneracy and development

In the previous subsections, we provide a set of theoretical arguments for how embedding collapse (Section D.2.1), layer normalization collapse (Section D.2.5), and attention collapse (Section D.2.3) can lead to an increase in degeneracy, even while leaving the implemented function unchanged.

The free energy formula tells us that, for two different solutions (sets of weights) with the same loss, the Bayesian posterior will asymptotically prefer the one with the lower learning coefficient (i.e., higher degeneracy). This suggests that these different forms of collapse may be driven by a bias towards higher degeneracy, as captured in the free energy formula. However, in idealized Bayesian inference, we do not expect the posterior to be concentrated around the neighborhood of an equal-loss-but-higher-degeneracy local minimum to begin with. That this kind of transition arises in practice might be due to one of the various differences between Bayesian inference and gradient-based training.

Actually establishing a causal link between increasing degeneracy and structure development is beyond the scope of this paper. For one, the theoretical arguments hinge on the collapse being complete, that is, the components that go to zero must become exactly zero in the limit where the number of samples used to compute the loss is taken to infinity. In practice, we expect there to be some threshold $\epsilon$ below which we can treat weights as effectively zero. Second, even if these explanations are correct, we do not know that they account for all of the empirically observed decrease in the LLC during these stages. There may be other drivers we missed. Finally, establishing a causal link requires theoretical progress in relating the Bayesian learning process to the SGD learning process. The arguments are suggestive, but currently only a source of intuition for how structure and degeneracy can be related, and a starting point for future research.

Appendix E One-layer language model experiments

We also trained and ran some experiments on a one-layer language model (see Section F.1.1 for details). We aggregate results for the one-layer language model here, mirroring the experiments for the two-layer language model where possible. The early development of the one-layer model has many parallels with the two-layer model. At a single stage boundary, just as in the two-layer model, the one-layer model minimizes its bigram score (see Section C.1.1), begins utilizing the positional embedding to noticeably improve performance (see Section C.2.1), and starts making sudden improvements to the same $n$-gram scores (see Section C.1.2). Remarkably, this occurs at the same checkpoint as in the two-layer model (at 900 training steps).

One key difference, however, is that this occurs at the second stage boundary as discerned by the plateaus of the LLC estimation. We did not closely investigate why the LLC estimation appears to drop between steps 400 and 900 in this model. We do, however, observe an interesting qualitative similarity to the LLC drop in stage LM3 of the two-layer model: in both cases, the drop precedes a noticeable bump in the loss.

Figure E.1: We train a one-layer transformer model in the language setting to compare with the two-layer model. The development of certain behavioral and structural metrics over time closely mirrors the development of the same metrics in the early stages of the two-layer language model. Top: test loss and LLC estimations over time for the one-layer attention-only transformer; compare with Figure 1(a). Bottom: bigram score, test loss with the positional embedding ablated, and $n$-gram scores for the one-layer attention-only transformer; compare with Figure 4(a,b,c).
Figure E.2: A more detailed version of Figure E.1 for the one-layer language model. Top: Loss, LLC, and weight norm, along with an overlaid Gaussian process fit to these curves (red dotted lines). Bottom: Associated slopes, both from numerically estimated finite differences (transparent blue) and from the Gaussian process fit (red dotted lines).

Appendix F Transformer training experiment details

F.1 Language models

F.1.1 Architecture

The language model architectures we consider are one- and two-layer attention-only transformers. They have a context length of 1024, a residual stream dimension of $d_{\text{model}}=256$, $H=8$ attention heads per layer, and include layer normalization layers. We also used a learnable Shortformer positional embedding (Press et al., 2021). The resulting models have a total of $d=3{,}091{,}336$ parameters for $L=1$ and $d=3{,}355{,}016$ parameters for $L=2$. We used an implementation provided by TransformerLens (Nanda & Bloom, 2022).

Component | 1-Layer | 2-Layer
Token Embedding Weights | 1,280,000 | 1,280,000
Positional Embedding Weights | 262,144 | 262,144
Layer 1 Layer Norm Weights | 256 | 256
Layer 1 Layer Norm Bias | 256 | 256
Layer 1 Attention Query Weights | 65,536 | 65,536
Layer 1 Attention Key Weights | 65,536 | 65,536
Layer 1 Attention Value Weights | 65,536 | 65,536
Layer 1 Attention Output Weights | 65,536 | 65,536
Layer 1 Attention Query Bias | 256 | 256
Layer 1 Attention Key Bias | 256 | 256
Layer 1 Attention Value Bias | 256 | 256
Layer 1 Attention Output Bias | 256 | 256
Layer 2 Layer Norm Weights | N/A | 256
Layer 2 Layer Norm Bias | N/A | 256
Layer 2 Attention Query Weights | N/A | 65,536
Layer 2 Attention Key Weights | N/A | 65,536
Layer 2 Attention Value Weights | N/A | 65,536
Layer 2 Attention Output Weights | N/A | 65,536
Layer 2 Attention Query Bias | N/A | 256
Layer 2 Attention Key Bias | N/A | 256
Layer 2 Attention Value Bias | N/A | 256
Layer 2 Attention Output Bias | N/A | 256
Final Layer Norm Weights | 256 | 256
Final Layer Norm Bias | 256 | 256
Unembedding Weights | 1,280,000 | 1,280,000
Unembedding Bias | 5,000 | 5,000
Figure F.1: Attention-only transformers with Shortformer position-infused attention and pre-layer norm. The one-layer model has a total of 3,091,336 trainable parameters, while the two-layer model has 3,355,016.
F.1.2 Tokenization

For tokenization, we used a truncated variant of the GPT-2 tokenizer that cuts the original vocabulary of 50,000 tokens down to 5,000 (Eldan & Li, 2023) to reduce the size of the model. We think this may contribute to the prominence of the plateau at the end of LM1: the frequency of bigram statistics depends on the choice of tokens, and a larger vocabulary leads to bigrams that are individually much less frequent.

F.1.3 Training

The models are trained for a single epoch of 50,000 steps on approximately 5 billion tokens from a resampled subset of the Pile (Gao et al., 2020; Xie et al., 2023), using a batch size of 100. A snapshot was saved every 10 steps for a total of 5,000 checkpoints, though the majority of the analysis used checkpoints every 100 steps. The training time was around 6 GPU-hours per model on an A100. Additional seeds were trained on v4 TPUs at around 1.5 TPU-hours per model.

Training was conducted on the first 10 million lines of the DSIR-filtered Pile (Xie et al., 2023; Gao et al., 2020) but did not exhaust all 10 million lines. The models were trained with weight decay but without dropout. We did not employ a learning rate scheduler.

Table 3: Summary of hyperparameters and their values for transformer language model training experiments.
Hyperparameter | Category | Description/Notes | Value
$n$ | Data | # of training samples | 5,000,000
$T$ | Data | # of training steps | 50,000
$N_{\text{test}}$ | Data | # of test samples | 512
Tokenizer Type | Data | Type of tokenizer | Truncated GPT-2 tokenizer
$D$ | Data | Vocabulary size | 5,000
$K$ | Data | Context size | 1,024
$L$ | Model | # of layers in the model | 2
$H$ | Model | # of heads per layer | 8
$d_{\mathrm{mlp}}$ | Model | MLP hidden layer size | N/A
$d_{\mathrm{embed}}$ | Model | Embedding size | 256
$d_{\text{head}}$ | Model | Head size | 32
seed | Model | Model initialization | 1
m | Training | Batch size | 100
Optimizer Type | Optimizer | Type of optimizer | AdamW
$\eta$ | Optimizer | Learning rate | 0.001
$\lambda_{\mathrm{wd}}$ | Optimizer | Weight decay | 0.05
$\beta_{1,2}$ | Optimizer | Betas | (0.9, 0.999)

F.2 In-context linear regression transformers

F.2.1 Architecture

In the following, $L$ refers to the number of layers (blocks) in the transformer, $H$ is the number of heads in each layer, $D$ is the dimension of the inputs $x\in\mathbb{R}^D$, and $K$ is the number of $(x,y)$ pairs provided to the transformer in-context.

The architecture is a pre-layer-norm decoder-only transformer modeled after NanoGPT (Karpathy, 2022; see also Phuong & Hutter, 2022) with a learnable positional embedding. For the models discussed in the main body, we consider $L=2$, $H=4$ transformers (with $d=51{,}717$ parameters), i.e., two transformer blocks with four attention heads each.

Component | # of Parameters
Token Embedding Weight | 320
Positional Embedding Weight | 1,024
Layer 1 Layer Norm Weight 2 | 64
Layer 1 Layer Norm Bias 1 | 64
Layer 1 Attention Weights | 12,288
Layer 1 Attention Output Weights | 4,096
Layer 1 Layer Norm Weight 1 | 64
Layer 1 Layer Norm Bias 2 | 64
Layer 1 Feed-Forward MLP Weight | 4,096
Layer 1 Feed-Forward MLP Bias | 64
Layer 1 Feed-Forward Output Weight | 4,096
Layer 1 Feed-Forward Output Bias | 64
Layer 2 Layer Norm Weight 1 | 64
Layer 2 Layer Norm Bias 1 | 64
Layer 2 Attention Weights | 12,288
Layer 2 Attention Output Weights | 4,096
Layer 2 Layer Norm Weight 2 | 64
Layer 2 Layer Norm Bias 2 | 64
Layer 2 Feed-Forward MLP Weight | 4,096
Layer 2 Feed-Forward MLP Bias | 64
Layer 2 Feed-Forward Output Weight | 4,096
Layer 2 Feed-Forward Output Bias | 64
Unembedding Layer Norm Weight 1 | 64
Unembedding Layer Norm Bias 1 | 64
Unembedding Weight 2 | 320
Unembedding Bias 2 | 5
Figure F.2: Transformer parameters in the in-context linear regression setting. The model has two transformer blocks for a total of 51,717 trainable parameters.
F.2.2 Tokenization

Running a context $S_K$ through the above model requires an initial encoding or “tokenization” step and a final “projection” step. The context is encoded as a sequence of “tokens” $T_k$ as follows:

$T_k = \left(\begin{pmatrix}0\\ x_1\end{pmatrix}, \begin{pmatrix}y_1\\ 0\\ \vdots\\ 0\end{pmatrix}, \cdots, \begin{pmatrix}0\\ x_k\end{pmatrix}, \begin{pmatrix}y_k\\ 0\\ \vdots\\ 0\end{pmatrix}\right).$

Throughout the main text, we write $f_w(S_k)$ for $f_w(T_k)$. Note that this tokenization includes the final $y_k$ token even though it receives no training signal. For this reason, we omit this token from the attention entropy and variability plots (Figures D.5 and D.6).

The transformer outputs a series of tokens of the same shape as $T_k$. To read out the $\hat{y}_k$ predictions, we take the first component of every other output token, i.e.,

$\pi_Y : \mathbb{R}^{(D+1)\times 2K} \to \mathbb{R}^{K},$ (25)

$\left(\begin{pmatrix}\hat{y}_1\\ \vdots\end{pmatrix}, \begin{pmatrix}\cdot\\ \vdots\end{pmatrix}, \cdots, \begin{pmatrix}\hat{y}_K\\ \vdots\end{pmatrix}, \begin{pmatrix}\cdot\\ \vdots\end{pmatrix}\right) \mapsto (\hat{y}_1,\dots,\hat{y}_K).$ (26)
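A sketch of this encoding and read-out, assuming the context is stored as arrays xs of shape (K, D) and ys of shape (K,); the interleaving and projection conventions follow the equations above.

import numpy as np

def tokenize(xs, ys):
    # Encode K (x, y) pairs as 2K tokens of dimension D + 1: x-tokens carry x_k
    # in the last D components, y-tokens carry y_k in the first component.
    K, D = xs.shape
    tokens = np.zeros((2 * K, D + 1))
    tokens[0::2, 1:] = xs              # x-tokens at even positions
    tokens[1::2, 0] = ys               # y-tokens at odd positions
    return tokens

def project_predictions(outputs):
    # pi_Y: read y_hat_k from the first component of every other output token
    # (the outputs at the x-token positions).
    return outputs[0::2, 0]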
F.2.3 Training

We train from a single seed for each choice of architecture and optimizer hyperparameters using minibatch stochastic gradient descent. We train without explicit regularization and use the Adam optimizer (Kingma & Ba, 2014). The training runs take 1 to 5 TPU-hours on TPUs provided by Google Research. Models are trained from the same initialization and on the same data vectors within each batch (but with different sets of tasks and task orderings).

Models are trained on a single epoch: each of the $T=500{,}000$ batches consists of a new set of sequences with batch size 256. For the LLC estimates, we save 190 checkpoints: 100 are linearly spaced over the training run, and the remaining 90 are logarithmically spaced. We perform LLC estimation and other analyses on these checkpoints.

Table 4: Summary of hyperparameters and their default values for in-context linear regression transformer model training experiments.
Hyperparameter | Category | Description/Notes | Default Value
$n$ | Data | # of training samples | 128,000,000
$B$ | Data | Batch size during training | 256
$T$ | Data | # of training steps | 500k
$N_{\text{test}}$ | Data | # of eval samples | 2048
$D$ | Data | Dimension of the linear regression task (task size) | 4
$K$ | Data | Maximum in-context examples | 8
$\sigma^2$ | Data | Variance of noise in data generation | 0.125
$L$ | Model | # of layers in the model | 2
$H$ | Model | # of attention heads per layer | 4
$d_{\mathrm{mlp}}$ | Model | Size of the hidden layer in the MLP | 64
$d_{\mathrm{embed}}$ | Model | Embedding size | 64
seed | Misc | Training run seeds | {0, 1, 2, 3, 4}
Optimizer Type | Optimizer | Type of optimizer | Adam
$\eta$ | Optimizer | Maximum learning rate | 0.003
$\lambda_{\mathrm{wd}}$ | Optimizer | Weight decay | 0
$\beta_{1,2}$ | Optimizer | Betas | (0.9, 0.999)
Scheduler Type | Scheduler | Type of learning rate scheduler | OneCycleLR
Strategy | Scheduler | Strategy for annealing the learning rate | Linear
% start | Scheduler | Percentage of the cycle when the learning rate is increasing | 0.5

Cite as

@article{hoogland2024loss,
  title = {Loss Landscape Degeneracy and Stagewise Development of Transformers},
  author = {Jesse Hoogland and George Wang and Matthew Farrugia-Roberts and Liam Carroll and Susan Wei and Daniel Murfet},
  year = {2024},
  url = {https://tmlr.infinite-conf.org/paper_pages/45qJyBG8Oj.html}
}