From Global to Local: A Scalable Benchmark for Local Posterior Sampling

Authors

Rohan Hitchcock
University of Melbourne
Jesse Hoogland
Timaeus
See Contributions

Publication Details

Published:
July 29, 2025

Abstract

Degeneracy is an inherent feature of the loss landscape of neural networks, but it is not well understood how stochastic gradient MCMC (SGMCMC) algorithms interact with this degeneracy. In particular, current global convergence guarantees for common SGMCMC algorithms rely on assumptions which are likely incompatible with degenerate loss landscapes. In this paper, we argue that this gap requires a shift in focus from global to local posterior sampling, and, as a first step, we introduce a novel scalable benchmark for evaluating the local sampling performance of SGMCMC algorithms. We evaluate a number of common algorithms, and find that RMSProp-preconditioned SGLD is most effective at faithfully representing the local geometry of the posterior distribution. Although we lack theoretical guarantees about global sampler convergence, our empirical results show that we are able to extract non-trivial local information in models with up to O(100M) parameters.

1 Introduction

Neural networks have highly complex loss landscapes which are non-convex and have non-unique degenerate minima. When a neural network is used as the basis for a Bayesian statistical model, the complex geometry of the loss landscape makes sampling from the Bayesian posterior using Markov chain Monte Carlo (MCMC) algorithms difficult. Much of the research in this area has focused on whether MCMC algorithms can adequately explore the global geometry of the loss landscape by visiting sufficiently many minima. Comparatively little attention has been paid to whether MCMC algorithms adequately explore the local geometry near minima. Our focus is on Stochastic Gradient MCMC (SGMCMC) algorithms, such as Stochastic Gradient Langevin Dynamics (SGLD; Welling and Teh 2011), applied to large models like neural networks. For these models, the local geometry near critical points is highly complex and degenerate, so local sampling is a non-trivial problem.

In this paper we argue for a shift in focus from global to local posterior sampling for complex models like neural networks. Our main contributions are:

  • We identify open theoretical problems regarding the convergence of SGMCMC algorithms for posteriors with degenerate loss landscapes. We survey existing theoretical guarantees for the global convergence of SGMCMC algorithms, noting their common incompatibility with the degenerate loss landscapes characteristic of these models (e.g., deep linear networks). Additionally, we highlight some important negative results which suggest that global convergence may in fact not occur. Despite this, empirical results show that SGMCMC algorithms are able to extract non-trivial local information from the posterior; a phenomenon which currently lacks theoretical explanation.

  • We introduce a novel, scalable benchmark for evaluating the local sampling performance of SGMCMC algorithms. Recognizing the challenge of obtaining theoretical convergence guarantees in degenerate loss landscapes, this benchmark assesses a sampler’s ability to capture known local geometric invariants related to volume scaling. Specifically, it leverages deep linear networks (DLNs), where a key invariant controlling volume scaling — the local learning coefficient (LLC; Lau et al., 2024) — can be computed analytically. This provides a ground-truth for local posterior geometry.

  • We evaluate common SGMCMC algorithms and find RMSProp-preconditioned SGLD (Li et al., 2016) to be the most effective at capturing local posterior features, scaling successfully up to models with O(100M) parameters. This offers practical guidance for researchers and practitioners. Importantly, our benchmark shows that although we lack theoretical guarantees about global sampling, SGMCMC samplers are able to extract non-trivial local information about the posterior distribution.

Figure 1: From global to local posterior sampling. Left: Neural network posteriors are often erroneously simplified as isolated Gaussian modes. Middle: Neural network posterior distributions are highly degenerate, where parameter changes often don't affect posterior density. Right: Local sampling must handle these degeneracies, which raises open theoretical and practical questions about the guarantees and effectiveness of local posterior exploration.

2 Background

In Section 2.1 we discuss existing theoretical results about the convergence of SGLD and related sampling algorithms. We highlight some important negative results from the literature, which suggest that global convergence of algorithms like SGLD is unlikely to occur in loss landscapes with degeneracy, and show that existing global convergence guarantees for SGLD rely on assumptions that likely do not hold for neural networks. In Section 2.2 we discuss applications and related work.

2.1 Problems with global sampling

We consider the problem of sampling from an absolutely continuous probability distribution $\pi(w)$ on $\mathbb{R}^d$. In Bayesian statistics we often consider the tempered posterior distribution

$$\pi(w) \propto \varphi(w)\prod_{i=1}^{n} p(X_i \mid w)^{\beta} = \varphi(w)\exp(-n\beta L_n(w)) \tag{2.1}$$

where $\{p(x \mid w)\}_{w \in \mathbb{R}^d}$ is a statistical model, $\varphi(w)$ the prior, $D_n = \{X_1, \ldots, X_n\}$ a dataset drawn independently from some distribution $q(x)$, $L_n(w) = -\tfrac{1}{n}\sum_{i=1}^{n}\log p(X_i \mid w)$ is the empirical negative log-likelihood, and $\beta > 0$ is a fixed parameter called the inverse temperature. For any such distribution we can consider the overdamped Langevin diffusion, the stochastic differential equation

$$dW_t = \tfrac{1}{2}\nabla \log \pi(W_t)\,dt + dB_t \tag{2.2}$$

where $B_t$ is standard Brownian motion. Under fairly mild assumptions on $\pi(w)$ (see Roberts and Tweedie, 1996, Theorem 2.1), (2.2) has well-behaved solutions and $\pi(w)$ is its stationary distribution. The idea of using the forward Euler-Maruyama discretisation of (2.2) to sample from $\pi(w)$ was first proposed by Parisi (1981) in what is now known as the Unadjusted Langevin Algorithm (ULA), where for $t = 0, 1, \ldots$ we take

$$w_{t+1} = w_t + \Delta w_t \qquad \text{where} \qquad \Delta w_t = \tfrac{\epsilon}{2}\nabla \log \pi(w_t) + \sqrt{\epsilon}\,\eta_t \tag{2.3}$$

where $\epsilon > 0$ is the step size and $\eta_0, \eta_1, \ldots$ are a sequence of iid standard normal random vectors in $\mathbb{R}^d$. Stochastic Gradient Langevin Dynamics (SGLD; Welling and Teh 2011) is obtained by replacing $\nabla \log \pi(w_t)$ in (2.3) with a stochastic estimate $g(w_t, U_t)$, where $U_0, U_1, \ldots$ are independent random variables. When $\pi(w)$ is given by (2.1), usually $g(w, U)$ is a mini-batch estimate of the log-likelihood gradient $g(w, U) = -\tfrac{n\beta}{m}\sum_{i=1}^{m}\nabla \log p(X_{u_i} \mid w)$, where $U = (u_1, \ldots, u_m)$ selects a random subset of the dataset $D_n$.
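
The update (2.3), with the gradient replaced by a stochastic estimate, is straightforward to implement. Below is a minimal illustrative sketch in NumPy (our own, not the paper's implementation from Appendix D.3):

```python
import numpy as np

def sgld_step(w, grad_log_post, step_size, rng):
    """One SGLD/ULA-style update: w <- w + (eps/2) * g(w) + sqrt(eps) * eta,
    where `grad_log_post` is a (possibly stochastic) estimate of grad log pi(w)."""
    eta = rng.standard_normal(w.shape)  # standard normal noise
    return w + 0.5 * step_size * grad_log_post(w) + np.sqrt(step_size) * eta

# Example: sampling from a standard Gaussian, where grad log pi(w) = -w.
rng = np.random.default_rng(0)
w = np.zeros(3)
for _ in range(1000):
    w = sgld_step(w, lambda w: -w, step_size=1e-2, rng=rng)
```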

Degeneracy can cause samplers to diverge.

Issues with ULA when $\log \pi(w)$ has degenerate critical points were first noted by Roberts and Tweedie (1996, Section 3.2), who show that ULA fails to converge to $\pi(w)$ for a certain class of distributions in which $\log \pi(w)$ is a polynomial with degenerate critical points. Mattingly et al. (2002) relate the convergence of forward Euler-Maruyama discretisations of stochastic differential equations (and hence the convergence of ULA) to a global Lipschitz condition on $\nabla \log \pi(w)$, giving examples of distributions which do not satisfy the global Lipschitz condition and for which ULA diverges at any step size. Hutzenthaler et al. (2011, Theorem 1) show that if $\nabla \log \pi(w)$ grows any faster than linearly then ULA diverges. This is a strong negative result on the global convergence of ULA, since in models like deep linear networks the presence of degenerate critical points causes super-linear growth away from these critical points.

The current theory makes strong assumptions about the loss landscape.

Given the above results, it is no surprise that results showing SGLD is well-behaved in a global sense rely either on global Lipschitz conditions on $\nabla \log \pi(w)$, or on other conditions which control the growth of this gradient. Asymptotic and non-asymptotic global convergence properties of SGLD are studied in Teh et al. (2015) and Vollmer et al. (2015). These results rely on the existence of a Lyapunov function $V: \mathbb{R}^d \to [1, \infty)$ with globally bounded second derivatives which satisfies

$$\|\nabla V(w)\|^2 + \|\nabla \log \pi(w)\|^2 \leq C V(w) \qquad \text{for all } w \in \mathbb{R}^d \tag{2.4}$$

for some $C > 0$ (see Teh et al., 2015, Assumption 4). The bounded second derivatives of $V(w)$ impose strong conditions on the growth of $\log \pi(w)$ and are incompatible with the deep linear network regression model described in Section 3.2. A more general approach to analysing the convergence of diffusion-based SGMCMC algorithms is described in Chen et al. (2015), though in the case of SGLD the necessary conditions for this method imply that (2.4) holds (see Chen et al., 2015, Appendix C).

Other results about the global convergence of SGLD and ULA rely on a global Lipschitz condition on $\nabla \log \pi(w)$ (also called $\alpha$-smoothness), which supposes that there exists a constant $\alpha > 0$ such that

$$\|\nabla \log \pi(w_1) - \nabla \log \pi(w_2)\| \leq \alpha \|w_1 - w_2\| \qquad \text{for all } w_1, w_2 \in \mathbb{R}^d. \tag{2.5}$$

This is a common assumption when studying SGLD in the context of stochastic optimization (Raginsky et al., 2017; Tzen et al., 2018; Xu et al., 2020; Zou et al., 2021; Zhang et al., 2022) and also in the study of forward Euler-Maruyama discretisations of diffusion processes, including ULA (Kushner, 1987; Borkar and Mitter, 1999; Durmus and Moulines, 2016; Brosse et al., 2018; Dalalyan and Karagulyan, 2019; Cheng et al., 2020). Other results rely on assumptions which preclude the possibility of divergence (Gelfand and Mitter, 1991; Higham et al., 2002).

Conditions which impose strong global restrictions on the growth rate of $\nabla \log \pi(w)$ often do not hold when $\log \pi(w)$ has degenerate critical points. For concreteness, consider the example of an $M$-layer deep linear network learning a regression task (full details are given in Section 3.2). Lemma 2.1 shows that assumptions (2.4) and (2.5) do not hold when $M > 1$, which is exactly the situation in which $L_n(w)$ has degenerate critical points.

Lemma 2.1.

The negative log-likelihood function $L_n(w)$ for the deep linear network regression task described in Section 3.2 is a polynomial in $w$ of degree $2M$ with probability one, where $M$ is the number of layers of the network.

Proof.

We give a self-contained statement and proof in Appendix B.
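
For intuition (an illustrative computation of ours, not part of the appendix), take the scalar case $M = 2$ with $H_0 = H_1 = H_2 = 1$, so $\boldsymbol{w} = (w_1, w_2)$ and $f(x; \boldsymbol{w}) = w_2 w_1 x$:

$$L_n(w_1, w_2) = \frac{1}{n}\sum_{i=1}^{n}\big(Y_i - w_2 w_1 X_i\big)^2,$$

a degree-$4 = 2M$ polynomial whose gradient contains cubic terms such as $w_2^2 w_1$ and $w_1^2 w_2$. This super-linear gradient growth is exactly what rules out the global Lipschitz condition (2.5) and the Lyapunov condition (2.4).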

All is not lost.

Despite the negative results discussed above and the lack of theoretical guarantees, Lau et al. (2024) shows empirically that SGLD can be used to obtain good local measurements of the geometry of $L_n(w)$ when $\pi(w)$ has the form in (2.1). This forms the basis of our benchmark in Section 3.2, and we show similar results for a variety of SGMCMC samplers in Section 4. In the absence of theoretical guarantees about sampler convergence, we can empirically verify that samplers can recover important geometric invariants of the log-likelihood. We emphasise that this empirical phenomenon is unexplained by the global convergence results discussed above and presents an open theoretical problem.

Some work proposes modifications to ULA or SGLD which aim to address potential convergence issues (Gelfand and Mitter, 1993; Lamba et al., 2006; Hutzenthaler et al., 2012; Sabanis, 2013; Sabanis and Zhang, 2019; Brosse et al., 2019). Finally, Zhang et al. (2018) is notable for its local analysis of SGLD, studying escape from and convergence to local minima in the context of stochastic optimization.

2.2 The need for local sampling

The shift toward local posterior sampling has immediate practical implications in areas such as interpretability and Bayesian deep learning.

Interpretability

SGMCMC algorithms play a central role in approaches to interpretability based on singular learning theory (SLT). SLT (see Watanabe, 2009, 2018) is a mathematical theory of Bayesian learning which properly accounts for degeneracy in the model’s log-likelihood function, and so is the correct Bayesian learning theory for neural networks (see Wei et al., 2022). It provides us with statistically relevant geometric invariants such as the local learning coefficient (LLC; Lau et al., 2024), which has been estimated using SGLD in large neural networks. When tracked over training, changes in the LLC correspond to qualitative changes in the model’s behaviour (Chen et al., 2023; Hoogland et al., 2025; Carroll et al., 2025). This approach has been used in models as large as 100M parameter language models, providing empirically useful results for model interpretability (Wang et al., 2024). SGMCMC algorithms are also required to estimate local quantities other than the LLC (Baker et al., 2025).

Bayesian deep learning

Some approaches to Bayesian deep learning involve first training a neural network using a standard optimization method, and then sampling from the posterior distribution in a neighbourhood of the parameter found via standard training. This is done for purposes such as uncertainty quantification during prediction. Bayesian deep ensembling involves independently training multiple copies of the network in parallel, with the aim of obtaining several distinct high-likelihood solutions. Methods such as MultiSWAG (Wilson and Izmailov, 2020) then incorporate local posterior samples from a neighbourhood of each training solution when making predictions.

3 Methodology

In Section 3.1 we discuss how the local geometry of the expected negative log-likelihood $L(w)$ affects the posterior distribution (2.1), focusing on volume scaling of sublevel sets of $L(w)$. In Section 3.2 we describe a specific benchmark for local sampling which involves estimating the local learning coefficient of deep linear networks. Details of experiments are given in Section 3.3.

3.1 Measurements of local posterior geometry

In this section we consider the setting of Section 2.1, in particular the tempered posterior distribution $\pi(w)$ in (2.1) and the geometry of the expected negative log-likelihood function $L(w) = -\mathbf{E}\log p(X \mid w)$ where $X \sim q(x)$.

Volume scaling in the loss landscape.

An important geometric quantity of $L(w)$ from the perspective of $\pi(w)$ is the volume of sublevel sets

$$V(\epsilon, w_0) = \operatorname{vol}\{w \in W \mid L(w) \leq L(w_0) + \epsilon\} \tag{3.1}$$

where $\epsilon > 0$, $w_0$ is a minimum of $L(w)$ and $W \subseteq \mathbb{R}^d$ is a neighbourhood of $w_0$. The volume $V(\epsilon, w_0)$ quantifies the number of parameters near $w_0$ which achieve close to minimum loss within $W$, and so is closely related to how the posterior distribution $\pi(w)$ concentrates. (This is notwithstanding the difference between $L_n(w)$ and $L(w)$, which requires a careful technical treatment that goes beyond the scope of this paper; see Watanabe (2009, Chapter 5).)

We can consider the rate of volume scaling by taking $\epsilon \to 0$. When $w_0$ is a non-degenerate critical point, volume scaling is determined by the spectrum of the Hessian $H(w_0)$ at $w_0$ and we have

$$V(\epsilon, w_0) \approx |\det H(w_0)|^{-1/2}\,\epsilon^{d/2} \qquad \text{as } \epsilon \to 0 \tag{3.2}$$

where $d$ is the dimension of parameter space. However, when $w_0$ is a degenerate critical point we have $\det H(w_0) = 0$ and this formula no longer holds; the second-order Taylor expansion of $L(w)$ at $w_0$ used to derive (3.2) no longer provides sufficient geometric information to understand how volume changes as $\epsilon \to 0$. In general we have

$$V(\epsilon, w_0) \approx c\,\epsilon^{\lambda(w_0)}(-\log \epsilon)^{m(w_0) - 1} \qquad \text{as } \epsilon \to 0 \tag{3.3}$$

for some $c > 0$, where $\lambda(w_0) \in \mathbb{Q}$ is the local learning coefficient (LLC) and $m(w_0) \in \mathbb{N}$ is its multiplicity within $W$ (see Lau et al., 2024, Section 3, Appendix A). When $w_0$ is non-degenerate and the only critical point in $W$, then $\lambda(w_0) = d/2$, $c = |\det H(w_0)|^{-1/2}$ and $m(w_0) = 1$, and (3.2) is recovered from (3.3).
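
As a standard illustrative example (ours, not from the paper), compare the potentials $L(w) = w_1^2 + w_2^2$ and $L(w) = w_1^2 w_2^2$ on the neighbourhood $W = [-1, 1]^2$ of the origin, for small $\epsilon$:

$$\operatorname{vol}\{w_1^2 + w_2^2 \leq \epsilon\} = \pi\epsilon, \qquad \operatorname{vol}\{w_1^2 w_2^2 \leq \epsilon\} = 4\sqrt{\epsilon} + 2\sqrt{\epsilon}\log(1/\epsilon) \sim \epsilon^{1/2}(-\log\epsilon).$$

The non-degenerate quadratic therefore has $\lambda = 1 = d/2$ with $m = 1$, while the degenerate potential has $\lambda = 1/2 < d/2$ with multiplicity $m = 2$, illustrating the behaviour summarised in Figure 2.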

Figure 2: The Local Learning Coefficient (LLC) captures the local geometry of the posterior. We illustrate volume-scaling behaviour near minima for various simple potentials, highlighting their local geometry. The LLC, defined as the volume-scaling exponent, quantifies the extent of degeneracy. For non-degenerate minima (left two examples), the LLC is always $d/2$, where $d$ is the number of parameters, and the Hessian determinant is nonzero. In contrast, degenerate minima (right two examples) have different LLCs (both less than $d/2$) and a Hessian determinant of 0, reflecting a geometry fundamentally distinct from Gaussian. Estimating the LLC thus serves as a benchmark for assessing a sampler's capacity to explore complex, degenerate posteriors.
Measuring volume scaling via sampling.

An SGMCMC algorithm which is producing good posterior samples from a neighbourhood of $w_0$ should reflect the correct volume-scaling rate $\lambda(w_0)$ in (3.3). In other words, it should produce good estimates of the LLC. The LLC can be estimated by sampling from the posterior distribution (2.1) without direct access to $L(w)$.

Definition 3.1 (Lau et al., 2024, Definition 1).

Let $w_0$ be a local minimum of $L(w)$ and let $W$ be an open, connected neighbourhood of $w_0$ such that the closure $\overline{W}$ is compact, $L(w_0) = \inf_{w \in \overline{W}} L(w)$, and $\lambda(w_0) \leq \lambda(w)$ for any $w \in \overline{W}$ which satisfies $L(w) = L(w_0)$. The local learning coefficient estimator $\hat{\lambda}(w_0)$ at $w_0$ is

$$\hat{\lambda}(w_0) = n\beta\big(\mathbf{E}^{\beta}_{w}[L_n(w)] - L_n(w_0)\big) \tag{3.4}$$

where $\mathbf{E}^{\beta}_{w}$ is an expectation over the tempered posterior (2.1) with prior $\varphi(w)$ and support $\overline{W}$.

The following theorem and Theorem 3.4 in the next section rely on a number of technical conditions, which we give in Definition E.1 in Appendix E:

Theorem 3.2 (Watanabe, 2013, Theorem 4).

Let $w_0$ be a local minimum of $L(w)$ and consider the local learning coefficient estimator $\hat{\lambda}(w_0)$. Let the inverse temperature in (2.1) be $\beta = \beta_0 / \log n$ for some $\beta_0 > 0$. Then, assuming the fundamental conditions of SLT (Definition E.1), we have

  1. $\hat{\lambda}(w_0)$ is an asymptotically unbiased estimator of $\lambda(w_0)$ as $n \to \infty$.

  2. If $q(x) = p(x \mid w_0)$ for some $w_0 \in W$ then

    $$\operatorname{Var}(\hat{\lambda}(w_0)) \leq \frac{\lambda(w_0)\beta_0}{2\log(n)} + O\!\left(\frac{\lambda(w_0)^{1/2}\beta_0^{3/2}}{\log(n)^{3/2}}\right) + O\!\left(\frac{\beta_0^{2}}{\log(n)^{2}}\right) \qquad \text{as } n \to \infty.$$
Enforcing locality in practice.

In Lau et al. (2024), locality of $\hat{\lambda}(w_0)$ is enforced by using a Gaussian prior $\varphi(w) = (\gamma/\pi)^{d/2}\exp(-\gamma\|w - w_0\|^2)$ centred at $w_0$, where $\gamma > 0$ is a hyperparameter of the estimator. We use the same prior in our experiments, acknowledging that this deviates from the theory above because it is not compactly supported. The way we use the prior in each sampling algorithm is made explicit in the pseudocode in Section D.3.
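
To make the role of the localizing prior concrete, the sketch below (our own illustration; the paper's actual pseudocode is in Section D.3) shows how its gradient, $\nabla \log \varphi(w) = -2\gamma(w - w_0)$, enters an SGLD-style drift alongside the mini-batch loss gradient:

```python
import numpy as np

def localized_sgld_step(w, w0, grad_minibatch_loss, n, beta, gamma, step_size, rng):
    """SGLD drift for the tempered, localized posterior (2.1):
    grad log pi(w) ~= -n*beta*grad L_m(w) - 2*gamma*(w - w0)."""
    drift = -n * beta * grad_minibatch_loss(w) - 2.0 * gamma * (w - w0)
    noise = rng.standard_normal(w.shape)
    return w + 0.5 * step_size * drift + np.sqrt(step_size) * noise
```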

3.2 Deep linear network benchmark

As noted in Section 2.1, we lack theoretical convergence guarantees for SGMCMC algorithms in models such as neural networks. In the absence of these theoretical guarantees, we can instead empirically verify that samplers respect certain geometric invariants of the log-likelihood function. The local learning coefficient (LLC) from Section 3.1 is a natural choice.

We do not have ground-truth values for the LLC for most systems; the only known method for computing it exactly (i.e., other than via the estimator in Definition 3.1) involves computing a resolution of singularities (Hironaka, 1964), making the problem intractable in general. However, LLC values have recently been computed for deep linear networks (DLNs; Aoyagi, 2024). DLNs provide a scalable setting where the ground-truth LLC values are known.

Definition 3.3.

A deep linear network (DLN) with $M$ layers of sizes $H_0, \ldots, H_M$ is a family of functions $f(-; \boldsymbol{w}): \mathbb{R}^N \to \mathbb{R}^{N'}$ parametrised by tuples of matrices $\boldsymbol{w} = (W_1, \ldots, W_M)$ where $W_l$ is an $H_l \times H_{l-1}$ matrix. We define $f(x; \boldsymbol{w}) = Wx$ where $W = W_M W_{M-1} \cdots W_1$, $N = H_0$ and $N' = H_M$.

The learning task takes the form of a regression task using the parametrised family of DLN functions defined in Definition 3.3, with the aim being to learn the function specified by an identified parameter $\boldsymbol{w}_0$. We fix integers $M$ and $H_0, \ldots, H_M$ for the number of layers and layer sizes of a DLN architecture, and let $N = H_0$ and $N' = H_M$. We fix a prior $\varphi(\boldsymbol{w})$ on the set of all matrices parametrizing the DLN. Consider an input distribution $q(x)$ on $\mathbb{R}^N$ and a fixed input dataset $X_1, \ldots, X_n$ drawn independently from $q(x)$. For a parameter $\boldsymbol{w}$, we consider noisy observations of the function's behaviour $f(X_i; \boldsymbol{w}) + N_i$, where the $N_i \sim \mathcal{N}(0, \sigma^2 I_{N'})$ are independent. The likelihood of observing $(x, y)$ generated using $q(x)$ and the parameter $\boldsymbol{w}$ is

$$p(x, y \mid \boldsymbol{w}) = \frac{1}{(2\pi\sigma^2)^{N'/2}}\exp\left(-\frac{1}{2\sigma^2}\|y - f(x; \boldsymbol{w})\|^2\right)q(x). \tag{3.5}$$

We fix $\boldsymbol{w}_0 = (W_1^{(0)}, \ldots, W_M^{(0)})$ and consider the task of learning the distribution $q(x, y) = p(x, y \mid \boldsymbol{w}_0)$. Note that while we identify a single true parameter $\boldsymbol{w}_0$, when $M > 1$ there are infinitely many $\boldsymbol{w}_0' \neq \boldsymbol{w}_0$ such that $f(x; \boldsymbol{w}_0') = f(x; \boldsymbol{w}_0)$. The dataset is $D_n = \{(X_1, Y_1), \ldots, (X_n, Y_n)\}$ where $Y_i = f(X_i; \boldsymbol{w}_0) + N_i$. We define

$$L(\boldsymbol{w}) = \int_{\mathbb{R}^N}\|f(x; \boldsymbol{w}) - f(x; \boldsymbol{w}_0)\|^2 q(x)\,dx \tag{3.6}$$

which is, up to an additive constant, the expected negative log-likelihood $-\mathbf{E}\left[\log p(X, Y \mid \boldsymbol{w})\right]$ where $(X, Y) \sim q(x, y)$. This can be estimated using the dataset $D_n$ as

$$L_n(\boldsymbol{w}) = \frac{1}{n}\sum_{i=1}^{n}\|Y_i - f(X_i; \boldsymbol{w})\|^2. \tag{3.7}$$
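
For concreteness, a minimal sketch of the DLN forward map and the empirical loss (3.7) might look as follows (illustrative NumPy code, not the paper's implementation):

```python
import numpy as np

def dln_forward(x, weights):
    """f(x; w) = W_M ... W_1 x, with weights = [W_1, ..., W_M]."""
    out = x
    for W in weights:
        out = W @ out
    return out

def empirical_loss(weights, X, Y):
    """L_n(w) = (1/n) sum_i ||Y_i - f(X_i; w)||^2 for X of shape (n, H_0), Y of shape (n, H_M)."""
    preds = np.stack([dln_forward(x, weights) for x in X])
    return float(np.mean(np.sum((Y - preds) ** 2, axis=1)))
```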

While a DLN with $M \geq 2$ layers expresses exactly the same functions as a DLN with $M = 1$ layer, the geometry of its parameter space is significantly more complex. When $M \geq 2$ the loss landscape has degenerate critical points, whereas when $M = 1$ all critical points are non-degenerate. Aoyagi (2024) recently computed the LLC for DLNs, and we give this result in Theorem 3.4 below. We also refer readers to Lehalleur and Rimányi (2024).

Theorem 3.4 (Aoyagi, 2024, Theorem 1).

Consider an $M$-layer DLN with layer sizes $H_0, \ldots, H_M$ learning the regression task described above. Let $r = \operatorname{rank}(W_M^{(0)} W_{M-1}^{(0)} \cdots W_1^{(0)})$ and $\Delta_i = H_i - r$. There exists a set of indices $\Sigma \subseteq \{0, 1, \ldots, M\}$ which satisfies:

  1. $\max\{\Delta_\sigma \mid \sigma \in \Sigma\} < \min\{\Delta_{\overline{\sigma}} \mid \overline{\sigma} \in \Sigma^c\}$

  2. $\sum_{\sigma \in \Sigma}\Delta_\sigma \geq \ell \cdot \max\{\Delta_\sigma \mid \sigma \in \Sigma\}$

  3. $\sum_{\sigma \in \Sigma}\Delta_\sigma < \ell \cdot \max\{\Delta_{\overline{\sigma}} \mid \overline{\sigma} \in \Sigma^c\}$

where $\ell = |\Sigma| - 1$ and $\Sigma^c = \{0, 1, \ldots, M\} \setminus \Sigma$. Assuming the fundamental conditions of SLT (Definition E.1), the LLC $\lambda(\boldsymbol{w}_0)$ at the identified true parameter $\boldsymbol{w}_0$ is

$$\lambda(\boldsymbol{w}_0) = \tfrac{1}{2}\big(r(H_0 + H_M) - r^2\big) + \tfrac{1}{4\ell}a(\ell - a) - \tfrac{1}{4\ell}(\ell - 1)\Big(\sum_{\sigma \in \Sigma}\Delta_\sigma\Big)^2 + \tfrac{1}{2}\sum_{(\sigma, \sigma') \in \Sigma^{(2)}}\Delta_\sigma\Delta_{\sigma'}$$

where $a = \sum_{\sigma \in \Sigma}\Delta_\sigma - \ell\left(\left\lceil \tfrac{1}{\ell}\sum_{\sigma \in \Sigma}\Delta_\sigma \right\rceil - 1\right)$, $x \mapsto \lceil x \rceil$ is the ceiling function, and $\Sigma^{(2)} = \{(\sigma, \sigma') \in \Sigma \times \Sigma \mid \sigma < \sigma'\}$ is the set of all 2-combinations of $\Sigma$.
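
As a small sanity check (our own worked example, under the usual empty-set conventions for the conditions above), take $M = 2$, $H_0 = H_1 = H_2 = 1$ and a rank-zero true parameter, so $r = 0$ and $\Delta_0 = \Delta_1 = \Delta_2 = 1$. Condition 1 then forces $\Sigma = \{0, 1, 2\}$, so $\ell = 2$, $\sum_\sigma \Delta_\sigma = 3$ and $a = 3 - 2(\lceil 3/2 \rceil - 1) = 1$, giving

$$\lambda(\boldsymbol{w}_0) = 0 + \tfrac{1}{8}\cdot 1 \cdot 1 - \tfrac{1}{8}\cdot 1 \cdot 9 + \tfrac{1}{2}\cdot 3 = \tfrac{1}{2},$$

which agrees with the direct volume computation for the degenerate potential $w_1^2 w_2^2$ in Section 3.1.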

3.3 Deep linear network experiments

We generate learning problems in the same way as Lau et al. (2024). We consider four classes of DLN architecture, 100K, 1M, 10M and 100M, whose names correspond approximately to the number of parameters of DLNs within that class. Each class is defined by integers $M_{\text{min}}$, $M_{\text{max}}$ and $H_{\text{min}}$, $H_{\text{max}}$, specifying the minimum and maximum number of layers, and minimum and maximum layer size respectively (see Table 1 in Appendix D). A learning problem is generated by randomly generating an architecture within a given class, and then randomly generating a low-rank true parameter $\boldsymbol{w}_0$ (a detailed procedure is given in Section D.1). The input distribution is uniform on $[-10, 10]^N$ where $N$ is the number of neurons in the input layer.

We estimate the LLC by running a given SGMCMC algorithm for $T$ steps starting at $\boldsymbol{w}_0$. We provide pseudocode for our implementation in Section D.3; notably, the adaptive elements of AdamSGLD and RMSPropSGLD are based only on the statistics of the loss gradient and not the prior.

This results in a sequence $\boldsymbol{w}_0, \boldsymbol{w}_1, \ldots, \boldsymbol{w}_T$ of parameters. For all SGMCMC algorithms we use mini-batch estimates $L_{m,t}(\boldsymbol{w}_t)$ of $L_n(\boldsymbol{w}_t)$, where $L_{m,t}(\boldsymbol{w}_t) = \tfrac{1}{m}\sum_{j=1}^{m}\|f(X_{u_{j,t}}; \boldsymbol{w}_t) - Y_{u_{j,t}}\|^2$ and $U_t = (u_{1,t}, \ldots, u_{m,t})$ defines the $t$-th batch of the dataset $D_n$. Rather than using $L_n(\boldsymbol{w}_t)$ to estimate the LLC, we instead assume that $\overline{L} \approx \mathbf{E}^{\beta}_{\boldsymbol{w}}L_n(\boldsymbol{w})$ where $\overline{L} = \tfrac{1}{T - B}\sum_{t=B}^{T}L_{m,t}(\boldsymbol{w}_t)$ for some number of 'burn-in' steps $B$. Inspired by Definition 3.1, we then estimate the LLC at $\boldsymbol{w}_0$ as

$$\overline{\lambda}(\boldsymbol{w}_0) := n\beta\big(\overline{L} - L_{m,0}(\boldsymbol{w}_0)\big). \tag{3.8}$$

To improve this estimate, one could run $C > 1$ independent sampling chains and average the resulting estimates; however, in this case we take $C = 1$, preferring instead to run more experiments with different architectures. We give the hyperparameters used in LLC estimation in Table 2 in Appendix D.
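
A minimal sketch of the estimate in (3.8), computed from a recorded trace of mini-batch losses (our own illustration; names are hypothetical):

```python
import numpy as np

def estimate_llc(minibatch_losses, init_loss, n, beta, burn_in):
    """lambda_bar = n * beta * (mean post-burn-in mini-batch loss - loss at w_0).

    `minibatch_losses[t]` is L_{m,t}(w_t) along the chain; `init_loss` is L_{m,0}(w_0).
    """
    L_bar = float(np.mean(minibatch_losses[burn_in:]))
    return n * beta * (L_bar - init_loss)
```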

4 Results

Figure 3: Adaptive samplers like RMSPropSGLD and AdamSGLD achieve superior performance in estimating the local learning coefficient of 100M parameter deep linear networks. Top left: the mean relative error $(\hat{\lambda} - \lambda)/\lambda$ versus the step size of the sampler; bars indicate standard deviation. Bottom left: the proportion of estimated values which were NaN, indicating the sampler encountered numerical issues. RMSPropSGLD and AdamSGLD are less sensitive to step size, producing more accurate results across a larger range of step size values. Right: the mean relative error versus the standard deviation of the relative error, plotting only points where $<10\%$ of estimates are NaN. RMSPropSGLD and AdamSGLD achieve a superior mean-variance trade-off. These observations are more pronounced in the 10M and 1M models; see Figures 5 and 6.

To assess how well they capture the local posterior geometry, we apply the benchmark described in Section 3.2 to the following samplers: SGLD (Welling and Teh, 2011), SGLD with RMSProp preconditioning (RMSPropSGLD; Li et al. 2016), an "Adam-like" adaptive SGLD (AdamSGLD; Kim et al. 2020), SGHMC (Chen et al., 2014) and SGNHT (Ding et al., 2014). We give pseudocode for our implementation of these samplers in Section D.3. We are interested not only in the absolute performance of a given sampler, but also in how sensitive its estimates are to the chosen step size $\epsilon$. To assess the performance of a sampler at a given step size, we primarily consider the relative error $(\hat{\lambda} - \lambda)/\lambda$ of an LLC estimate $\hat{\lambda}$ with true value $\lambda$.

RMSPropSGLD and AdamSGLD are less sensitive to step-size.

Figures 3, 5, 6 and 7 display the mean relative error (averaged over different networks generated from the model class) versus the step size of each sampler. While all samplers seem able to achieve empirically unbiased LLC estimates (relative error $= 0$) for some step size, RMSPropSGLD and AdamSGLD have a wider range of step size values which produce accurate results. For SGLD, at the step size values where the relative error is close to zero, a significant fraction of the LLC estimates are also diverging. In experiments on language models (Figure 4), RMSPropSGLD is also stable across a wider range of step sizes. This results in more consistent LLC estimates compared to SGLD.

RMSPropSGLD and AdamSGLD achieve a superior mean-variance tradeoff.

In Figures 3, 5, 6 and 7 we plot the mean of the relative error versus the standard deviation of the relative error for each sampler at different step sizes. We see that RMSPropSGLD and AdamSGLD obtain a superior combination of good mean performance and lower variance compared to the other samplers.

RMSPropSGLD and AdamSGLD are better at preserving order.

In many applications of LLC estimation, observing relative changes in the LLC (e.g. over training) is more important than determining absolute values (Chen et al., 2023; Hoogland et al., 2025; Wang et al., 2024). In these cases, the order of a sampler's estimates should reflect the order of the true LLCs. We compute the order preservation rate of each sampler, which we define as the proportion of all pairs of estimates whose order matches the order of the corresponding true LLCs. We plot this quantity versus the step size in Figure 8 in Appendix C. Again, RMSPropSGLD and AdamSGLD achieve superior performance to the other samplers, with a higher order preservation rate across a wider range of step sizes.
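
As an illustration, a minimal implementation of this metric (ours; ties are ignored for simplicity) could look like:

```python
from itertools import combinations

def order_preservation_rate(true_llcs, estimated_llcs):
    """Fraction of pairs (i, j) whose estimated LLCs are ordered like the true LLCs."""
    pairs = list(combinations(range(len(true_llcs)), 2))
    agree = sum(
        1
        for i, j in pairs
        if (true_llcs[i] - true_llcs[j]) * (estimated_llcs[i] - estimated_llcs[j]) > 0
    )
    return agree / len(pairs)
```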

RMSPropSGLD step size is easier to tune.

Above a certain step size, RMSPropSGLD experiences rapid performance degradation, with LLC estimates that are orders of magnitude larger than the maximum theoretical value of $d/2$ and large spikes in the loss trace of the sampler. In contrast, AdamSGLD experiences a more gradual performance degradation, and its loss traces do not obviously suggest that the step size is set too high. We see this catastrophic drop-off in performance as an advantage of RMSPropSGLD, as it provides a clear signal that hyperparameters are incorrectly tuned in the absence of ground-truth LLC values. This clear signal is not present for AdamSGLD.

Figure 4: RMSPropSGLD stabilizes sampling chains for an attention head in a four-layer attention-only transformer trained on the Pile (Section D.4), leading to more consistent and reliable LLC estimates.

5 Discussion

In this paper we introduced a scalable benchmark for evaluating the local sampling performance of SGMCMC algorithms. This benchmark is based on how well samplers estimate the local learning coefficient (LLC) for deep linear networks (DLNs). Since the LLC is the local volume-scaling rate for the log-likelihood function, this directly assesses how well samplers explore the local posterior.

Towards future empirical benchmarks.

The stochastic optimization literature has accumulated numerous benchmarks for assessing (optimization) performance on non-convex landscapes (e.g., Rastrigin, Ackley, and Griewank functions; Plevris and Solorzano 2022). However, these benchmarks focus primarily on multimodality and often ignore degeneracy. Our current work takes DLNs as an initial step towards developing a more representative, degeneracy-aware benchmark. A key limitation is that DLNs represent only one class of degenerate models and may not capture all forms of degeneracy encountered in general (see Lehalleur and Rimányi, 2024). Developing a wider set of degeneracy-aware benchmarks therefore remains an important direction for future research.

Towards future theoretical guarantees.

In Section 2.1, we establish that global convergence guarantees for sampling algorithms like SGLD rely on assumptions which provably do not hold for certain model classes (e.g. deep linear networks) with degenerate loss landscapes, and are unlikely to be compatible with degeneracy in general. Shifting from global to local convergence guarantees, which properly account for degeneracy, provides one promising way forward.

This shift may have broader implications beyond sampling. Many current convergence guarantees for stochastic optimizers make similar assumptions that may fail for degenerate landscapes (e.g., global Lipschitz or Polyak-Łojasiewicz conditions; Rebjock and Boumal 2024). Generally, the role of degeneracy in shaping the dynamics of sampling and optimization methods is not well understood.

Open problem: A theoretical explanation for the empirical success of local SGMCMC.

In this paper, we observed empirically that SGMCMC algorithms can successfully estimate the LLC despite the lack of theoretical convergence guarantees. The strong assumptions employed by the convergence results discussed in Section 2.1 arise, in some sense, because the goal is to prove global convergence to a posterior with support $\mathbb{R}^d$. It is possible that convergence results similar to those in Section 2.1 could be proved for compact parameter spaces; however, this does not explain the success of local sampling we observe, because our experiments are not in a compactly supported setting. This also places our empirical results outside the setting of SLT (in particular Theorem 3.2 and Theorem 3.4). We see understanding precisely what determines the "effective support" of SGMCMC sampling chains in practice as a central issue in explaining why these samplers work.

Appendix

This appendix contains supplementary details, proofs, and experimental results supporting the main text. Specifically, we include:

  • Appendix A provides examples of global and local degeneracies that are characteristic of modern deep neural network architectures.

  • Appendix B provides a proof of Lemma 2.1, which shows that the negative log-likelihood of deep linear networks is a polynomial of degree $2M$.

  • Appendix C presents additional experimental results, focusing on the relative error, variance, and order preservation rates of various samplers across different deep linear network architectures.

  • Appendix D describes additional methodological details, including procedures for randomly generating deep linear network tasks, explicit hyperparameter settings for LLC estimation, and pseudocode for the implemented SGMCMC algorithms.

  • Appendix E summarizes the fundamental technical conditions required by singular learning theory (SLT), outlining the mathematical assumptions underlying our theoretical discussions.

Appendix A Examples of Degeneracy

Degenerate critical points in the loss landscape of neural networks can arise from symmetries in the parametrisation: continuous (or discrete) families of parameter settings that induce identical model outputs or leave the training loss unchanged. We distinguish global (or “generic”) symmetries, which hold throughout parameter space, from local and data-dependent degeneracies that arise only in particular regions in parameter space or only for particular data distributions. In this section we provide several examples (these are far from exhaustive).

A.1 Global degeneracies

Matrix–sandwich symmetry in deep linear networks.

For a DLN with composite weight $W = W_M \cdots W_1$, one can insert any invertible matrix $O$ between any neighbouring layers, e.g., $W = (W_M O)(O^{-1} W_{M-1}) W_{M-2} \cdots W_1$, without changing the implemented function. This produces a $\mathrm{GL}(H)$-orbit of equivalent parameters.
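As a minimal numerical illustration (ours, not drawn from the paper's experiments), the following sketch checks that sandwiching an invertible matrix between two layers of a small DLN leaves the end-to-end map unchanged:

```python
import numpy as np

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(5, 3)), rng.normal(size=(2, 5))  # two-layer DLN
O = rng.normal(size=(5, 5))                                # generic, hence invertible

x = rng.normal(size=3)
original = W2 @ W1 @ x
sandwiched = (W2 @ O) @ (np.linalg.inv(O) @ W1) @ x        # insert O and O^{-1}

assert np.allclose(original, sandwiched)                   # same function, different parameters
```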

ReLU scaling symmetries.

Because $\operatorname{ReLU}(ax) = a\operatorname{ReLU}(x)$ for all $a > 0$, scaling pre-activation weights by $a$ and post-activation weights by $a^{-1}$ leaves the network invariant. Such positively scale-invariant (PSI) directions invalidate naïve flatness measures [Yi et al., 2019, Tsuzuku et al., 2020].
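A corresponding sketch (again ours, purely illustrative) for a one-hidden-layer ReLU network:

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

rng = np.random.default_rng(1)
W1, W2 = rng.normal(size=(8, 4)), rng.normal(size=(3, 8))
x, a = rng.normal(size=4), 2.5  # any a > 0 works

# Scaling pre-activation weights by a and post-activation weights by 1/a
# leaves the network output unchanged.
assert np.allclose(W2 @ relu(W1 @ x), (W2 / a) @ relu((a * W1) @ x))
```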

Permutation and sign symmetries.

Exchanging hidden units and simultaneously flipping the signs of a unit's incoming and outgoing weights are examples of discrete changes that leave network outputs unchanged [Carroll, 2021].

Batch- and layer-normalization scaling symmetries.

BN and LN outputs do not change when their inputs pass through the same affine map [Laarhoven, 2017].

A.2 Local and data-dependent degeneracies

Low-rank DLNs.

When the end-to-end matrix $W$ is rank-deficient, any transformation restricted to the null space can be absorbed by the factors $W_{\ell}$.

Elimination singularities.

If incoming weights to a given layer are zero, the associated outgoing weights are free to take any value. The reverse also holds: if outgoing weights are zero, this frees incoming weights to take any value. Residual or skip connections can help to bypass these degeneracies [Orhan and Pitkow, 2018].

Dead (inactive) ReLUs.

If a bias is large enough that a ReLU never activates, or if the data distribution is such that the pre-activations never exceed the bias, then the outgoing weights become free parameters: they can be set to arbitrary values without affecting the loss, because they are always multiplied by a zero activation. The incoming weights also become free parameters (up to the point where they change the pre-activation distribution enough to activate the ReLU).
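A small sketch (ours) of this effect, using a hidden unit whose bias keeps it inactive on the whole dataset:

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

rng = np.random.default_rng(2)
X = rng.uniform(-1.0, 1.0, size=(100, 4))   # bounded inputs
w_in, bias = rng.normal(size=4), -100.0     # bias keeps the unit inactive on this data
h = relu(X @ w_in + bias)                   # all zeros: the ReLU is dead

# Any two choices of outgoing weight give identical contributions to the output.
assert np.allclose(h * 0.3, h * 123.0)
```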

Always-active ReLUs.

Conversely, ReLUs that are always on behave linearly. In this regime, the incoming and outgoing weight matrices act as a DLN with the associated matrix-sandwich and low-rank degeneracies discussed above.

Overlap singularities.

If two neurons share the same incoming weights, then the outgoing weights become non-identifiable: in this regime, only the sum matters to the model’s functional behaviour [Orhan and Pitkow, 2018].

Appendix B Proof of Lemma 2.1

In this section we prove Lemma 2.1, which shows that the global convergence results for SGLD discussed in Section 2.1 do not apply to the regression problem for deep linear networks described in Section 3.2.

Consider an $M$-layer deep linear network with layer sizes $H_0, \ldots, H_M$. Recall from Definition 3.3 that this is a family of functions $f(-;\boldsymbol{w}) : \mathbb{R}^N \to \mathbb{R}^{N'}$ parametrised by matrices $\boldsymbol{w} = (W_1, \ldots, W_M)$, where $W_l$ is $H_l \times H_{l-1}$, $N = H_0$ and $N' = H_M$. By definition we have

$$f(x;\boldsymbol{w}) = W_M W_{M-1} \cdots W_1 x.$$

We consider the regression task described in Section 3. Let $q(x)$ be an absolutely continuous input distribution on $\mathbb{R}^N$ and let $X_1, \ldots, X_n$ be an input dataset drawn independently from $q(x)$. For an identified parameter $\boldsymbol{w}_0 = (W_1^{(0)}, \ldots, W_M^{(0)})$ we consider the task of learning the function $f(x;\boldsymbol{w}_0)$. That is, we consider the statistical model

$$p(x, y \mid \boldsymbol{w}) = \frac{1}{(2\pi\sigma^2)^{N'/2}} \exp\left(-\frac{1}{2\sigma^2}\|y - f(x;\boldsymbol{w})\|^2\right) q(x)$$

from (3.5), where the true distribution is $q(x, y) = p(x, y \mid \boldsymbol{w}_0)$. As in (3.6) and (3.7) we consider

$$L(\boldsymbol{w}) = \int_{\mathbb{R}^N} \|f(x;\boldsymbol{w}) - f(x;\boldsymbol{w}_0)\|^2 \, q(x)\, dx$$

and

$$L_n(\boldsymbol{w}) = \frac{1}{n}\sum_{i=1}^{n} \|f(X_i;\boldsymbol{w}) - f(X_i;\boldsymbol{w}_0)\|^2,$$

which are, up to additive constants (irrelevant when computing gradients), the expected and empirical negative log-likelihood respectively. When sampling from the tempered posterior distribution (2.1) using a Langevin-diffusion-based sampler like SGLD, $\nabla L_n(\boldsymbol{w})$ is used to compute sampler steps, along with the prior and noise terms (see Section 2.1).
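For concreteness, here is a short numpy sketch (ours) of the empirical loss $L_n$ for a DLN, written for inputs stacked in an $n \times N$ array:

```python
import numpy as np

def dln_forward(X, weights):
    """Apply a deep linear network W_M ... W_1 to inputs X of shape (n, N).

    weights = [W_1, ..., W_M], where W_l has shape (H_l, H_{l-1}).
    """
    out = X
    for W in weights:
        out = out @ W.T
    return out

def empirical_loss(weights, weights_0, X):
    """L_n(w): average squared difference between the DLN at w and at the true parameter w_0."""
    diff = dln_forward(X, weights) - dln_forward(X, weights_0)
    return np.mean(np.sum(diff**2, axis=1))
```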

For convenience in the following proof, and in line with the experimental method in Section 3.3, we consider a prior which is an isotropic Gaussian distribution on $\boldsymbol{w}$,

$$\varphi(\boldsymbol{w}) = \left(\frac{\gamma}{2\pi}\right)^{d/2} \exp\left(-\frac{\gamma}{2}\|\boldsymbol{w} - \boldsymbol{w}_{\mu}\|^2\right) \qquad \text{(B.1)}$$

where $\gamma > 0$ and $\boldsymbol{w}_{\mu}$ is any fixed parameter. By $\|\boldsymbol{w} - \boldsymbol{w}_{\mu}\|^2$ we mean the sum of the squares of all matrix entries, treating the parameter $\boldsymbol{w}$ in this expression as a vector with $d = \sum_{l=1}^{M} H_l H_{l-1}$ entries.

Lemma B.1 (restating Lemma 2.1).

Consider the above situation of an $M$-layer deep linear network learning the described regression task. If $q(x)$ is absolutely continuous, then with probability one $L_n(\boldsymbol{w})$ is a degree-$2M$ polynomial in the matrix entries $\boldsymbol{w} = (W_1, \ldots, W_M)$.

Proof.

We treat each parameter of the model as a different polynomial variable; that is, for each $l = 1, \ldots, M$ we write $W_l = (w_{i,j,l})$, where the $w_{i,j,l}$ are distinct polynomial variables. As before, let $W = W_M W_{M-1} \cdots W_1$. For a matrix $U$ with polynomial entries we denote by $\deg(U)$ the maximum degree of its entries. Hence we have that $\deg(W) = M$, since each entry of $W$ is a sum of monomials of the form $\prod_{l=1}^{M} w_{i_l, j_l, l}$. Denote the entries of $W$ by $P_{ij}(\boldsymbol{w})$; each is a polynomial of degree $M$ in the variables $w_{i,j,l}$. The entries of $W^{(0)}$ are constants, which we denote by $a_{ij}$.

We now consider the degree of $L_n(\boldsymbol{w})$ as a polynomial in the variables $w_{i,j,l}$. Let $N = H_0$ denote the input dimension of the deep linear network and $N' = H_M$ the output dimension. First note that the square of the $i$-th coordinate of $W X_k - W^{(0)} X_k$ is

$$
\begin{aligned}
(W X_k - W^{(0)} X_k)_i^2 &= \left(\sum_{j=1}^{N} P_{ij}(\boldsymbol{w}) X_k^{(j)} - a_{ij} X_k^{(j)}\right)^2 \\
&= \sum_{j=1}^{N}\sum_{j'=1}^{N} X_k^{(j)} X_k^{(j')} \left(P_{ij}(\boldsymbol{w}) P_{ij'}(\boldsymbol{w}) - a_{ij} P_{ij'}(\boldsymbol{w}) - a_{ij'} P_{ij}(\boldsymbol{w}) + a_{ij} a_{ij'}\right)
\end{aligned}
$$

where $X_k^{(j)}$ is the $j$-th coordinate of $X_k$. Hence we have

$$
\begin{aligned}
L_n(\boldsymbol{w}) &= \frac{1}{n}\sum_{k=1}^{n}\sum_{i=1}^{N'} (W X_k - W^{(0)} X_k)_i^2 \\
&= \frac{1}{n}\sum_{k=1}^{n}\sum_{i=1}^{N'}\sum_{j=1}^{N}\sum_{j'=1}^{N} X_k^{(j)} X_k^{(j')} P_{ij}(\boldsymbol{w}) P_{ij'}(\boldsymbol{w}) + \left(\text{terms of at most degree } M\right) \\
&= \frac{1}{n}\sum_{i=1}^{N'}\sum_{j=1}^{N}\sum_{j'=1}^{N} \left(\sum_{k=1}^{n} X_k^{(j)} X_k^{(j')}\right) P_{ij}(\boldsymbol{w}) P_{ij'}(\boldsymbol{w}) + \left(\text{terms of at most degree } M\right).
\end{aligned}
$$

The polynomial $P_{ij}(\boldsymbol{w}) P_{ij'}(\boldsymbol{w})$ has degree $2M$. The only way $L_n(\boldsymbol{w})$ can have degree smaller than $2M$ is if the random coefficients $\sum_{k=1}^{n} X_k^{(j)} X_k^{(j')}$ result in cancellation between the terms of different $P_{ij}(\boldsymbol{w}) P_{ij'}(\boldsymbol{w})$. We now show that, with probability one, this cancellation does not occur. Consider a function $f : \mathbb{R}^{nN} \to \mathbb{R}$ given by

$$f(\boldsymbol{x}) = \sum_{(b, b') \in \Lambda} c_{bb'}\, x_b x_{b'}$$

where $\Lambda$ is any non-empty subset of $\{1, \ldots, nN\}^2$ and $c_{bb'} \in \mathbb{R} \setminus \{0\}$. The set $f^{-1}(0) \subseteq \mathbb{R}^{nN}$ has Lebesgue measure zero, and hence has measure zero with respect to the joint distribution of the dataset $D_n$, considered as a distribution on $\mathbb{R}^{nN}$, since $q(x)$ is assumed to be absolutely continuous. It follows that cancellation of the degree-$2M$ monomials in the above expression for $L_n(\boldsymbol{w})$ cannot occur with probability greater than zero, and thus $\deg(L_n(\boldsymbol{w})) = 2M$ with probability one.

Recall from Section 2.1 that proofs of the global convergence of SGLD assume that there exists a function $V : \mathbb{R}^d \to [1, \infty)$ with bounded second derivatives which satisfies

$$\|\nabla V(\boldsymbol{w})\|^2 + \|\nabla \log \pi(\boldsymbol{w})\|^2 \leq C\, V(\boldsymbol{w}) \qquad \text{for all } \boldsymbol{w} \in \mathbb{R}^d.$$

See Teh et al. [2015, Assumption 4]. In the setting of deep linear networks we have $\nabla \log \pi(\boldsymbol{w}) = -n\beta \nabla L_n(\boldsymbol{w})$. Since the second derivatives of $V$ are bounded and $L_n(\boldsymbol{w})$ is a degree-$2M$ polynomial, this condition can only hold when $M = 1$. This corresponds precisely to the case when all critical points of $L_n(\boldsymbol{w})$ are non-degenerate. Likewise, the global Lipschitz condition

$$\|\nabla \log \pi(\boldsymbol{w}_1) - \nabla \log \pi(\boldsymbol{w}_2)\| \leq \alpha \|\boldsymbol{w}_1 - \boldsymbol{w}_2\| \qquad \text{for all } \boldsymbol{w}_1, \boldsymbol{w}_2 \in \mathbb{R}^d$$

can also only be satisfied when $M = 1$.
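Lemma B.1 is also easy to verify symbolically for small networks. The following sketch (ours, using sympy; not part of the paper's experiments) builds $L_n$ for a randomly drawn dataset and checks that its total degree is $2M$:

```python
import numpy as np
import sympy as sp

def empirical_loss_degree(layer_sizes, n=3, seed=0):
    """Total degree of L_n(w) for a DLN with the given layer sizes [H_0, ..., H_M]."""
    rng = np.random.default_rng(seed)
    M = len(layer_sizes) - 1
    # Symbolic weights W_1, ..., W_M and a fixed numeric true parameter W_1^(0), ..., W_M^(0).
    Ws = [sp.Matrix(layer_sizes[l + 1], layer_sizes[l],
                    lambda i, j, l=l: sp.Symbol(f"w_{i}_{j}_{l}"))
          for l in range(M)]
    W0s = [sp.Matrix(rng.normal(size=(layer_sizes[l + 1], layer_sizes[l])).tolist())
           for l in range(M)]
    W, W0 = Ws[0], W0s[0]
    for A, A0 in zip(Ws[1:], W0s[1:]):
        W, W0 = A * W, A0 * W0                      # end-to-end products W_M ... W_1
    Ln = sp.Integer(0)
    for _ in range(n):
        x = sp.Matrix(rng.normal(size=(layer_sizes[0], 1)).tolist())
        diff = W * x - W0 * x
        Ln += (diff.T * diff)[0, 0]
    return sp.Poly(sp.expand(Ln / n)).total_degree()

# Lemma B.1 predicts degree 2M; e.g. a two-layer network should give degree 4.
assert empirical_loss_degree([2, 3, 2]) == 4
```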

Remark B.2.

In detailed treatments of singular models such as Watanabe [2009] it is often more convenient to analyse the distribution

$$\pi(\boldsymbol{w}) \propto \exp(-n\beta L(\boldsymbol{w}))\,\varphi(\boldsymbol{w}) \qquad \text{(B.2)}$$

in place of the tempered posterior distribution (2.1). The geometry of $L(\boldsymbol{w})$ determines much of the learning behaviour of singular statistical models. In SGMCMC algorithms, a stochastic estimate $g(\boldsymbol{w}, U)$ of the gradient of the log-posterior could equally be considered an estimate of $\nabla \log \pi(\boldsymbol{w})$, where $\pi(\boldsymbol{w})$ is as in (B.2). In the case of deep linear networks, a result similar to Lemma B.1 can be shown using (B.2) in place of the usual posterior distribution. In this case we have

$$L(\boldsymbol{w}) = \mathbf{E}\,\|W X - W^{(0)} X\|^2$$

where $X \sim q(x)$. From the proof of Lemma B.1 we have that

$$
\begin{aligned}
L(\boldsymbol{w}) &= \mathbf{E}\left[\sum_{i=1}^{N'}\sum_{j=1}^{N}\sum_{j'=1}^{N} X_j X_{j'} P_{ij}(\boldsymbol{w}) P_{ij'}(\boldsymbol{w})\right] + \left(\text{terms of at most degree } M\right) \\
&= \sum_{i=1}^{N'}\sum_{j=1}^{N}\sum_{j'=1}^{N} \mathbf{E}\left[X_j X_{j'}\right] P_{ij}(\boldsymbol{w}) P_{ij'}(\boldsymbol{w}) + \left(\text{terms of at most degree } M\right),
\end{aligned}
$$

where we now write $X_j$ for the $j$-th coordinate of $X$. If each coordinate of $X$ is independent and identically distributed (as in the experiments in Section 3.3) then $\mathbf{E}[X_j^2] > 0$ and $\mathbf{E}[X_j X_{j'}] \geq 0$ for $j \neq j'$. It follows that $L(\boldsymbol{w})$ has degree $2M$, since all monomial terms of degree $2M$ in the above expression for $L(\boldsymbol{w})$ appear with non-negative coefficients, and at least some are non-zero.

Appendix C Additional results

In this section we give additional results from the experiments described in Section 3.3. In Figures 5, 6 and 7 we present the relative error $(\hat{\lambda} - \lambda)/\lambda$ in the estimated local learning coefficient for the deep linear network model classes 10M, 1M and 100K (the results for the 100M model class are given in Figure 3 in the main text).

In Figure 8 we present the order preservation rate of each sampling algorithm, for each deep linear network model class. This assesses how well a sampling algorithm preserves the ordering of the true local learning coefficients in its estimates. These results are discussed in more detail in Section 4 in the main text.
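For reference, a sketch (ours) of the order preservation rate reported in Figure 8, computed over all pairs of learning problems with distinct true LLCs:

```python
import itertools
import numpy as np

def order_preservation_rate(true_llc, est_llc):
    """Fraction of pairs with lambda_1 < lambda_2 whose estimates are ordered the same way."""
    true_llc, est_llc = np.asarray(true_llc), np.asarray(est_llc)
    pairs = [(i, j) for i, j in itertools.combinations(range(len(true_llc)), 2)
             if true_llc[i] != true_llc[j]]
    agree = sum((true_llc[i] < true_llc[j]) == (est_llc[i] < est_llc[j]) for i, j in pairs)
    return agree / len(pairs)
```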

Figure 5: We assess the performance of samplers in estimating the local learning coefficient of 10M parameter deep linear networks. As for the 100M parameter models (Figure 3) we see that RMSPropSGLD and AdamSGLD are less sensitive to step size and achieve a superior mean-variance trade-off.
Figure 6: We assess the performance of samplers in estimating the local learning coefficient of 1M parameter deep linear networks. As for the 100M parameter models (Figure 3) we see that RMSPropSGLD and AdamSGLD are less sensitive to step size and achieve a superior mean-variance trade-off.
Figure 7: We assess the performance of samplers in estimating the local learning coefficient of 100K parameter deep linear networks. Here RMSPropSGLD and AdamSGLD still appear better, though the picture is less clear compared to the 1M, 10M or 100M parameter models. This may suggest that the best choice of sampler may depend on the scale of the sampling problem, though this requires further investigation.
(a) 100M   (b) 10M   (c) 1M   (d) 100K
Figure 8: We assess how well each sampler preserves the ordering of true LLCs in the estimated values, reporting the order preservation rate as the proportion of pairs of true LLC values $\lambda_1 < \lambda_2$ for which the estimates are correctly ordered as $\hat{\lambda}_1 < \hat{\lambda}_2$. We see that RMSPropSGLD and AdamSGLD preserve order across a wider range of step sizes than the other samplers for larger models. For RMSPropSGLD and AdamSGLD, good order preservation emerges at smaller step sizes than good accuracy does (see Figures 3, 5, 6, 7).

Appendix D Additional methodology details

In this section we give additional methodology details, supplemental to Section 3.3. In Section D.1 we describe the procedure for generating deep linear networks mentioned in Section 3.3. In Section D.2 we give compute details for the deep linear network experiments. In Section D.3 we give pseudocode for the samplers we benchmark: SGLD (Algorithm 1), AdamSGLD (Algorithm 2), RMSPropSGLD (Algorithm 3), SGHMC (Algorithm 4) and SGNHT (Algorithm 5). In Section D.4 we give details of the large language model experiments presented in Figure 4.

D.1 Deep linear network generation

With the notation introduced in Section 3.3, a deep linear network is generated as follows (a code sketch of this procedure is given after the list). The values of $M_{\text{min}}$, $M_{\text{max}}$, $H_{\text{min}}$ and $H_{\text{max}}$ for each model class are given in Table 1.

  1. Choose a number of layers $M$ uniformly from $\{M_{\text{min}}, M_{\text{min}}+1, \ldots, M_{\text{max}}\}$.

  2. Choose layer sizes $H_l$ uniformly from $\{H_{\text{min}}, H_{\text{min}}+1, \ldots, H_{\text{max}}\}$, for $l = 1, \ldots, M$.

  3. Generate a fixed parameter $\boldsymbol{w}_0 = (W_1^{(0)}, \ldots, W_M^{(0)})$, where $W_l^{(0)}$ is an $H_l \times H_{l-1}$ matrix generated according to the Xavier-normal distribution: each entry is drawn independently from $\mathcal{N}(0, \sigma^2)$ with $\sigma^2 = \frac{2}{H_l + H_{l-1}}$.

  4. To obtain lower-rank true parameters we modify $\boldsymbol{w}_0$ as follows. For each layer $l = 1, \ldots, M$ we choose, with probability $0.5$, whether or not to reduce the rank of $W_l^{(0)}$. If we choose to do so, we choose a new rank $r$ uniformly from $\{0, \ldots, \min(H_l, H_{l-1})\}$ and set some number of rows or columns of $W_l^{(0)}$ to zero to force it to have rank at most $r$.

The above process results in a deep linear network $f(x;\boldsymbol{w})$ with $M$ layers and layer sizes $H_0, \ldots, H_M$, along with an identified true parameter $\boldsymbol{w}_0$. We take the input distribution to be the uniform distribution on $[-10, 10]^N$ where $N = H_0$, and the output noise distribution to be $\mathcal{N}(\boldsymbol{0}, \sigma^2 I_{N'})$ where $N' = H_M$ and $\sigma^2 = 1/4$.
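A sketch (ours) of this generation procedure; we additionally assume the input size $H_0$ is drawn in the same way as the other layer sizes, and we reduce rank by zeroing rows, whereas the released code may make these choices differently:

```python
import numpy as np

def generate_dln_true_parameter(M_min, M_max, H_min, H_max, rng):
    """Randomly generate a DLN true parameter following steps 1-4 above (a sketch)."""
    M = rng.integers(M_min, M_max + 1)              # step 1: number of layers
    H = rng.integers(H_min, H_max + 1, size=M + 1)  # step 2: layer sizes H_0, ..., H_M
    weights = []
    for l in range(1, M + 1):
        sigma = np.sqrt(2.0 / (H[l] + H[l - 1]))    # step 3: Xavier-normal entries
        W = rng.normal(0.0, sigma, size=(H[l], H[l - 1]))
        if rng.random() < 0.5:                      # step 4: optionally reduce the rank
            r = rng.integers(0, min(H[l], H[l - 1]) + 1)
            W[r:, :] = 0.0                          # keep at most r non-zero rows
        weights.append(W)
    return weights
```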

                                           100K    1M      10M     100M
Minimum number of layers $M_{\text{min}}$  2       2       2       2
Maximum number of layers $M_{\text{max}}$  10      20      20      40
Minimum layer size $H_{\text{min}}$        50      100     500     500
Maximum layer size $H_{\text{max}}$        500     1000    2000    3000
Table 1: Deep linear network architecture hyperparameters.
Number of sampling steps $T$: $5 \times 10^4$
Number of burn-in steps $B$: $0.9T$
Dataset size $n$: $10^6$
Batch size $m$: $500$
Localisation parameter $\gamma$: $1$
Inverse-temperature parameter $\beta$: $1/\log(n)$
Number of learning problems: $100$
Table 2: Hyperparameters for local learning coefficient estimation in deep linear networks.

D.2 Deep linear network compute details

The DLN experiments were run on a cluster using NVIDIA H100 GPUs. A batch of 100 independent experiments took approximately 7 GPU hours for the 100M model class, 5 GPU hours for the 10M model class, 3.5 GPU hours for the 1M model class, and 0.9 GPU hours for the 100K model class. The total compute used for the DLN experiments was approximately 1000 GPU hours.

D.3 Samplers

Algorithm 1 SGLD
1: Inputs: $\boldsymbol{w}_0 \in \mathbb{R}^d$ (initial parameter), $D_n = \{\boldsymbol{z}_1, \ldots, \boldsymbol{z}_n\}$ (dataset), $\ell : \mathbb{R}^N \times \mathbb{R}^d \to \mathbb{R}$ (loss function)
2: Outputs: $\boldsymbol{w}_1, \ldots, \boldsymbol{w}_T \in \mathbb{R}^d$
3: Hyperparameters: $\epsilon > 0$ (step size), $\gamma > 0$ (localization), $\tilde{\beta} > 0$ (posterior temperature), $m \in \mathbb{N}$ (batch size)
4: for $t \leftarrow 0 : T-1$ do
5:   Draw a batch $\boldsymbol{z}_{u_1}, \ldots, \boldsymbol{z}_{u_m}$ from $D_n$.
6:   $\boldsymbol{g}_t \leftarrow \frac{1}{m}\sum_{k=1}^{m} \nabla_{\boldsymbol{w}} \ell(\boldsymbol{z}_{u_k}, \boldsymbol{w}_t)$ ▷ Gradient with respect to the parameter argument.
7:   Draw $\boldsymbol{\eta}_t$ from the standard normal distribution on $\mathbb{R}^d$.
8:   $\Delta\boldsymbol{w}_t \leftarrow -\frac{\epsilon}{2}\bigl(\gamma(\boldsymbol{w}_t - \boldsymbol{w}_0) + \tilde{\beta}\,\boldsymbol{g}_t\bigr) + \sqrt{\epsilon}\,\boldsymbol{\eta}_t$
9:   $\boldsymbol{w}_{t+1} \leftarrow \boldsymbol{w}_t + \Delta\boldsymbol{w}_t$
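For readers who prefer code, a minimal numpy sketch (ours, not the released implementation) of the localized SGLD update in Algorithm 1:

```python
import numpy as np

def sgld_chain(w0, data, grad_loss, T, eps, gamma, beta_tilde, m, rng):
    """Run T steps of localized SGLD (Algorithm 1) and return the sampled parameters.

    data is an array of examples; grad_loss(batch, w) returns the mini-batch
    average gradient of the loss at w.
    """
    w, samples = w0.copy(), []
    for _ in range(T):
        batch = data[rng.choice(len(data), size=m, replace=False)]
        g = grad_loss(batch, w)
        eta = rng.standard_normal(w.shape)
        w = w - (eps / 2) * (gamma * (w - w0) + beta_tilde * g) + np.sqrt(eps) * eta
        samples.append(w.copy())
    return samples
```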
Algorithm 2 AdamSGLD
1: Inputs: $\boldsymbol{w}_0 \in \mathbb{R}^d$ (initial parameter), $D_n = \{\boldsymbol{z}_1, \ldots, \boldsymbol{z}_n\}$ (dataset), $\ell : \mathbb{R}^N \times \mathbb{R}^d \to \mathbb{R}$ (loss function)
2: Outputs: $\boldsymbol{w}_1, \ldots, \boldsymbol{w}_T \in \mathbb{R}^d$
3: Hyperparameters: $\epsilon > 0$ (base step size), $\gamma > 0$ (localization), $\tilde{\beta} > 0$ (posterior temperature), $m \in \mathbb{N}$ (batch size), $a > 0$ (stability), $b_1, b_2 \in (0,1)$ (EMA decay rates)
4:
5: $\boldsymbol{m}_{-1} \leftarrow (0, 0, \ldots, 0) \in \mathbb{R}^d$
6: $\boldsymbol{v}_{-1} \leftarrow (1, 1, \ldots, 1) \in \mathbb{R}^d$; for $t \leftarrow 0 : T-1$ do
7:   Draw a batch $\boldsymbol{z}_{u_1}, \ldots, \boldsymbol{z}_{u_m}$ from $D_n$.
8:   $\boldsymbol{g}_t \leftarrow \frac{1}{m}\sum_{k=1}^{m} \nabla_{\boldsymbol{w}} \ell(\boldsymbol{z}_{u_k}, \boldsymbol{w}_t)$ ▷ Gradient with respect to the parameter argument.
9:   $\boldsymbol{m}_t \leftarrow b_1 \boldsymbol{m}_{t-1} + (1 - b_1)\boldsymbol{g}_t$
10:  Define $\boldsymbol{v}_t$ by $\boldsymbol{v}_t[i] \leftarrow b_2 \boldsymbol{v}_{t-1}[i] + (1 - b_2)\boldsymbol{g}_t[i]^2$ for $i = 1, \ldots, d$.
11:  $\hat{\boldsymbol{m}}_t \leftarrow \frac{1}{1 - b_1^t}\boldsymbol{m}_t$
12:  $\hat{\boldsymbol{v}}_t \leftarrow \frac{1}{1 - b_2^t}\boldsymbol{v}_t$
13:  Define $\boldsymbol{\epsilon}_t$ by $\boldsymbol{\epsilon}_t[i] \leftarrow \frac{\epsilon}{\sqrt{\hat{\boldsymbol{v}}_t[i]} + a}$ for $i = 1, \ldots, d$. ▷ Step size of each parameter.
14:  Draw $\boldsymbol{\eta}_t$ from the standard normal distribution on $\mathbb{R}^d$.
15:  Define $\Delta\boldsymbol{w}_t$ by $\Delta\boldsymbol{w}_t[i] \leftarrow -\frac{\boldsymbol{\epsilon}_t[i]}{2}\bigl(\gamma(\boldsymbol{w}_t[i] - \boldsymbol{w}_0[i]) + \tilde{\beta}\,\hat{\boldsymbol{m}}_t[i]\bigr) + \sqrt{\boldsymbol{\epsilon}_t[i]}\,\boldsymbol{\eta}_t[i]$ for $i = 1, \ldots, d$.
16:  $\boldsymbol{w}_{t+1} \leftarrow \boldsymbol{w}_t + \Delta\boldsymbol{w}_t$
Algorithm 3 RMSPropSGLD
1:Inputs: 𝒘0d\boldsymbol{w}_{0}\in\mathbb{R}^{d}bold_italic_w start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT ∈ blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT (initial parameter), Dn={𝒛1,,𝒛n}D_{n}=\{\boldsymbol{z}_{1},\ldots,\boldsymbol{z}_{n}\}italic_D start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT = { bold_italic_z start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , bold_italic_z start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT } (dataset), :N×d\ell:\mathbb{R}^{N}\times\mathbb{R}^{d}\to\mathbb{R}roman_ℓ : blackboard_R start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT × blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT → blackboard_R (loss function)
2:Outputs: 𝒘1,,𝒘Td\boldsymbol{w}_{1},\ldots,\boldsymbol{w}_{T}\in\mathbb{R}^{d}bold_italic_w start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , bold_italic_w start_POSTSUBSCRIPT italic_T end_POSTSUBSCRIPT ∈ blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT
3:Hyperparameters: ϵ>0\epsilon>0italic_ϵ > 0 (base step size), γ>0\gamma>0italic_γ > 0 (localization), β~>0\tilde{\beta}>0over~ start_ARG italic_β end_ARG > 0 (posterior temperature), mm\in\mathbb{N}italic_m ∈ blackboard_N (batch size), a>0a>0italic_a > 0 (stability), b(0,1)b\in(0,1)italic_b ∈ ( 0 , 1 ) (EMA decay rate)
4:
5:𝒗1(1,1,,1)d\boldsymbol{v}_{-1}\leftarrow(1,1,\ldots,1)\in\mathbb{R}^{d}bold_italic_v start_POSTSUBSCRIPT - 1 end_POSTSUBSCRIPT ← ( 1 , 1 , … , 1 ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT. t0:T1t\leftarrow 0:T-1italic_t ← 0 : italic_T - 1
6: Draw a batch $\boldsymbol{z}_{u_{1}},\ldots,\boldsymbol{z}_{u_{m}}$ from $D_{n}$.
7: $\boldsymbol{g}_{t}\leftarrow\frac{1}{m}\sum_{k=1}^{m}\nabla_{\boldsymbol{w}}\ell(\boldsymbol{z}_{u_{k}},\boldsymbol{w}_{t})$ ▷ Gradient with respect to the parameter argument.
8: Define $\boldsymbol{v}_{t}$ by $\boldsymbol{v}_{t}[i]\leftarrow b\,\boldsymbol{v}_{t-1}[i]+(1-b)\,\boldsymbol{g}_{t}[i]^{2}$ for $i=1,\ldots,d$.
9: $\hat{\boldsymbol{v}}_{t}\leftarrow\frac{1}{1-b^{t}}\boldsymbol{v}_{t}$
10: Define $\boldsymbol{\epsilon}_{t}$ by $\boldsymbol{\epsilon}_{t}[i]\leftarrow\frac{\epsilon}{\sqrt{\hat{\boldsymbol{v}}_{t}[i]}+a}$ for $i=1,\ldots,d$. ▷ Step size of each parameter.
11: Draw $\boldsymbol{\eta}_{t}$ from the standard normal distribution on $\mathbb{R}^{d}$.
12: Define $\Delta\boldsymbol{w}_{t}$ by $\Delta\boldsymbol{w}_{t}[i]\leftarrow\tfrac{-\boldsymbol{\epsilon}_{t}[i]}{2}\left(\gamma(\boldsymbol{w}_{t}[i]-\boldsymbol{w}_{0}[i])+\tilde{\beta}\boldsymbol{g}_{t}[i]\right)+\sqrt{\boldsymbol{\epsilon}_{t}[i]}\,\boldsymbol{\eta}_{t}[i]$ for $i=1,\ldots,d$.
13: $\boldsymbol{w}_{t+1}\leftarrow\boldsymbol{w}_{t}+\Delta\boldsymbol{w}_{t}$
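
For concreteness, the following NumPy sketch implements the preconditioned update above. It is not the authors' implementation: the grad_loss(batch, w) callable (returning the averaged minibatch gradient), the assumption that the dataset is an array indexable by integer indices, and the default hyperparameter values are illustrative.

import numpy as np

def rmsprop_sgld(w0, data, grad_loss, T=1000, eps=1e-5, gamma=100.0,
                 beta_tilde=1.0, m=32, b=0.9, a=1e-8, seed=0):
    """Sketch of RMSProp-preconditioned SGLD with the localising term
    gamma * (w - w0); returns the chain of sampled parameters."""
    rng = np.random.default_rng(seed)
    w, v = w0.copy(), np.zeros_like(w0)
    chain = []
    for t in range(1, T + 1):
        batch = data[rng.choice(len(data), size=m, replace=False)]
        g = grad_loss(batch, w)                    # averaged minibatch gradient
        v = b * v + (1 - b) * g**2                 # second-moment estimate
        v_hat = v / (1 - b**t)                     # bias correction
        step = eps / (np.sqrt(v_hat) + a)          # per-coordinate step size
        drift = -0.5 * step * (gamma * (w - w0) + beta_tilde * g)
        w = w + drift + np.sqrt(step) * rng.standard_normal(w.shape)
        chain.append(w.copy())
    return chain
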
Algorithm 4 SGHMC
1: Inputs: $\boldsymbol{w}_{0}\in\mathbb{R}^{d}$ (initial parameter), $D_{n}=\{\boldsymbol{z}_{1},\ldots,\boldsymbol{z}_{n}\}$ (dataset), $\ell:\mathbb{R}^{N}\times\mathbb{R}^{d}\to\mathbb{R}$ (loss function)
2: Outputs: $\boldsymbol{w}_{1},\ldots,\boldsymbol{w}_{T}\in\mathbb{R}^{d}$
3: Hyperparameters: $\epsilon>0$ (step size), $\gamma>0$ (localization), $\tilde{\beta}>0$ (posterior temperature), $m\in\mathbb{N}$ (batch size), $\alpha>0$ (friction)
4:
5: Draw $\boldsymbol{p}_{0}$ from a normal distribution on $\mathbb{R}^{d}$ with mean $\boldsymbol{0}$ and variance $\epsilon$.
for $t=0,\ldots,T-1$ do
6: Draw a batch $\boldsymbol{z}_{u_{1}},\ldots,\boldsymbol{z}_{u_{m}}$ from $D_{n}$.
7: $\boldsymbol{g}_{t}\leftarrow\frac{1}{m}\sum_{k=1}^{m}\nabla_{\boldsymbol{w}}\ell(\boldsymbol{z}_{u_{k}},\boldsymbol{w}_{t})$ ▷ Gradient with respect to the parameter argument.
8: Draw $\boldsymbol{\eta}_{t}$ from the standard normal distribution on $\mathbb{R}^{d}$.
9: $\Delta\boldsymbol{p}_{t}\leftarrow\tfrac{-\epsilon}{2}\left(\gamma(\boldsymbol{w}_{t}-\boldsymbol{w}_{0})+\tilde{\beta}\boldsymbol{g}_{t}\right)-\alpha\boldsymbol{p}_{t}+\sqrt{2\alpha\epsilon}\,\boldsymbol{\eta}_{t}$
10: $\boldsymbol{p}_{t+1}\leftarrow\boldsymbol{p}_{t}+\Delta\boldsymbol{p}_{t}$
11: $\boldsymbol{w}_{t+1}\leftarrow\boldsymbol{w}_{t}+\boldsymbol{p}_{t}$
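
A corresponding sketch of the SGHMC update in Algorithm 4, under the same assumptions about grad_loss and batching as the sketch above; as in the listing, the parameter update uses the current momentum $\boldsymbol{p}_{t}$ before the momentum itself is refreshed.

import numpy as np

def sghmc(w0, data, grad_loss, T=1000, eps=1e-5, gamma=100.0,
          beta_tilde=1.0, m=32, alpha=0.1, seed=0):
    """Sketch of SGHMC with the localising term gamma * (w - w0)."""
    rng = np.random.default_rng(seed)
    w = w0.copy()
    p = np.sqrt(eps) * rng.standard_normal(w.shape)    # p0 ~ N(0, eps I)
    chain = []
    for t in range(T):
        batch = data[rng.choice(len(data), size=m, replace=False)]
        g = grad_loss(batch, w)
        eta = rng.standard_normal(w.shape)
        dp = (-0.5 * eps * (gamma * (w - w0) + beta_tilde * g)
              - alpha * p + np.sqrt(2 * alpha * eps) * eta)
        w = w + p          # parameter step uses the current momentum p_t
        p = p + dp         # momentum refresh
        chain.append(w.copy())
    return chain
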
Algorithm 5 SGNHT
1: Inputs: $\boldsymbol{w}_{0}\in\mathbb{R}^{d}$ (initial parameter), $D_{n}=\{\boldsymbol{z}_{1},\ldots,\boldsymbol{z}_{n}\}$ (dataset), $\ell:\mathbb{R}^{N}\times\mathbb{R}^{d}\to\mathbb{R}$ (loss function)
2: Outputs: $\boldsymbol{w}_{1},\ldots,\boldsymbol{w}_{T}\in\mathbb{R}^{d}$
3: Hyperparameters: $\epsilon>0$ (step size), $\gamma>0$ (localization), $\tilde{\beta}>0$ (posterior temperature), $m\in\mathbb{N}$ (batch size), $\alpha_{0}>0$ (initial friction).
4:
5: Draw $\boldsymbol{p}_{0}$ from a normal distribution on $\mathbb{R}^{d}$ with mean $\boldsymbol{0}$ and variance $\epsilon$.
for $t=0,\ldots,T-1$ do
6: Draw a batch $\boldsymbol{z}_{u_{1}},\ldots,\boldsymbol{z}_{u_{m}}$ from $D_{n}$.
7: $\boldsymbol{g}_{t}\leftarrow\frac{1}{m}\sum_{k=1}^{m}\nabla_{\boldsymbol{w}}\ell(\boldsymbol{z}_{u_{k}},\boldsymbol{w}_{t})$ ▷ Gradient with respect to the parameter argument.
8: Draw $\boldsymbol{\eta}_{t}$ from the standard normal distribution on $\mathbb{R}^{d}$.
9: $\Delta\boldsymbol{p}_{t}\leftarrow\tfrac{-\epsilon}{2}\left(\gamma(\boldsymbol{w}_{t}-\boldsymbol{w}_{0})+\tilde{\beta}\boldsymbol{g}_{t}\right)-\alpha_{t}\boldsymbol{p}_{t}+\sqrt{2\alpha_{t}\epsilon}\,\boldsymbol{\eta}_{t}$
10: $\boldsymbol{p}_{t+1}\leftarrow\boldsymbol{p}_{t}+\Delta\boldsymbol{p}_{t}$
11: $\alpha_{t+1}\leftarrow\alpha_{t}+\|\boldsymbol{p}_{t}\|/d-\epsilon$
12: $\boldsymbol{w}_{t+1}\leftarrow\boldsymbol{w}_{t}+\boldsymbol{p}_{t}$
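
SGNHT differs from SGHMC only in replacing the fixed friction $\alpha$ with a thermostat variable $\alpha_{t}$. The sketch below implements the thermostat update exactly as stated in Algorithm 5, with the same illustrative assumptions as in the previous sketches.

import numpy as np

def sgnht(w0, data, grad_loss, T=1000, eps=1e-5, gamma=100.0,
          beta_tilde=1.0, m=32, alpha0=0.1, seed=0):
    """Sketch of SGNHT: SGHMC with an adaptive friction (thermostat)."""
    rng = np.random.default_rng(seed)
    d = w0.size
    w = w0.copy()
    p = np.sqrt(eps) * rng.standard_normal(w.shape)    # p0 ~ N(0, eps I)
    alpha = alpha0
    chain = []
    for t in range(T):
        batch = data[rng.choice(len(data), size=m, replace=False)]
        g = grad_loss(batch, w)
        eta = rng.standard_normal(w.shape)
        dp = (-0.5 * eps * (gamma * (w - w0) + beta_tilde * g)
              - alpha * p + np.sqrt(2 * alpha * eps) * eta)
        w = w + p                                      # uses current momentum p_t
        alpha = alpha + np.linalg.norm(p) / d - eps    # thermostat update, as stated
        p = p + dp
        chain.append(w.copy())
    return chain
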

D.4 Large language model experiments

To complement our analysis of deep linear networks, we also examined the performance of sampling algorithms for LLC estimation on a four-layer attention-only transformer trained on the DSIR-filtered Pile [Gao et al., 2020, Xie et al., 2023].

D.4.1 Model Architecture

We trained a four-layer attention-only transformer with the following specifications:

  • Number of layers: 4

  • Hidden dimension ($d_{\text{model}}$): 256

  • Number of attention heads per layer: 8

  • Head dimension ($d_{\text{head}}$): 32

  • Context length: 1024

  • Activation function: GELU [Hendrycks and Gimpel, 2023]

  • Vocabulary size: 5,000

  • Positional embedding: Learnable (Shortformer-style, Press et al. 2021)

The model was implemented using TransformerLens [Nanda and Bloom, 2022] and trained using AdamW with a learning rate of 0.001 and weight decay of 0.05 for 75,000 steps with a batch size of 32. Training took approximately 1 hour on a TPUv4.
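
A minimal sketch of how a model with the above specification might be constructed in TransformerLens. The paper does not give the exact configuration code, so the field names below (HookedTransformerConfig, attn_only, positional_embedding_type) are assumptions based on the library's public API, not the authors' training script; the tokenizer and training loop are omitted.

from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=4,             # four attention-only layers
    d_model=256,            # hidden dimension
    n_heads=8,              # attention heads per layer
    d_head=32,              # head dimension
    n_ctx=1024,             # context length
    d_vocab=5000,           # vocabulary size
    act_fn="gelu",          # listed in the spec above
    attn_only=True,         # no MLP blocks
    positional_embedding_type="shortformer",
)
model = HookedTransformer(cfg)
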

D.4.2 LLC Estimation

We applied both standard SGLD and RMSProp-preconditioned SGLD to estimate the local learning coefficient (LLC) of individual attention heads (“weight-refined LLCs,” Wang et al. 2024) at various checkpoints during training. As with the deep linear network experiments, we tested both algorithms across a range of step sizes $\epsilon\in\{10^{-4},3\times 10^{-4},10^{-3},3\times 10^{-3}\}$. LLC estimation was implemented using devinterp [van Wingerden et al., 2024]. Each LLC-over-training trajectory in Figure 4 took approximately 15 minutes on a TPUv4, for a total of roughly 2 hours.
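
Following Lau et al. [2024], the LLC estimate is obtained from the sampled losses as $n\tilde{\beta}$ times the gap between the chain-averaged loss and the loss at the initial parameter $w_{0}$. The sketch below shows this final step; it is not the devinterp implementation, and the burn-in rule is an assumption.

import numpy as np

def llc_estimate(chain_losses, loss_at_w0, n, beta_tilde, burn_in=0.5):
    """Turn a trace of L_n(w_t) values from an SGMCMC chain into an LLC
    estimate: n * beta_tilde * (mean sampled loss - loss at w_0)."""
    losses = np.asarray(chain_losses)
    keep = losses[int(burn_in * len(losses)):]   # discard warm-up samples
    return n * beta_tilde * (keep.mean() - loss_at_w0)
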

D.4.3 Results and Analysis

Figure 4 shows LLC estimates for the second head in the third layer. Here, RMSProp-preconditioned SGLD demonstrates several advantages over standard SGLD:

  1. Step size stability: RMSProp-SGLD produces more consistent LLC-over-time curves across a wider range of step sizes, enabling more reliable LLC estimation.

  2. Loss trace stability: The loss traces for RMSProp-SGLD show significantly fewer spikes compared to standard SGLD, resulting in more stable posterior sampling.

  3. Failure detection: When the step size becomes too large for stable sampling, RMSProp-SGLD fails catastrophically with NaN values, providing a clear signal that hyperparameters need adjustment. In contrast, standard SGLD might produce plausible but inaccurate results without obvious warning signs.

These results align with our findings in deep linear networks and suggest that the advantages of RMSProp-preconditioned SGLD generalize across model architectures.

Appendix E Technical conditions for singular learning theory

In this section we state the technical conditions for singular learning theory (SLT). These are required to state Theorem 3.2 and Theorem 3.4. The conditions are discussed in Lau et al. [2024, Appendix A] and Watanabe [2018].

Definition E.1 (see Watanabe, 2009, 2013, 2018).

Consider a true distribution $q(x)$, model $\{p(x|w)\}_{w\in\mathcal{W}}$ and prior $\varphi(w)$. Let $W\subseteq\mathcal{W}$ be the support of $\varphi(w)$ and $W_{0}\subseteq W$ be the set of global minima of $L(w)$ restricted to $W$. The fundamental conditions of SLT are as follows:

  1. For all $w\in W$ the support of $p(x|w)$ is equal to the support of $q(x)$.

  2. The prior’s support $W$ is compact with non-empty interior, and can be written as the intersection of finitely many analytic inequalities.

  3. The prior $\varphi(w)$ can be written as $\varphi(w)=\varphi_{1}(w)\varphi_{2}(w)$ where $\varphi_{1}(w)\geq 0$ is analytic and $\varphi_{2}(w)>0$ is smooth.

  4. For all $w_{0},w_{1}\in W_{0}$ we have $p(x|w_{0})=p(x|w_{1})$ almost everywhere.

  5. Given $w_{0}\in W_{0}$, the function $g(x,w)=\log\frac{p(x|w_{0})}{p(x|w)}$ satisfies:

    (a) For each fixed $w\in W$, $g(x,w)$ is in $L^{s}(q)$ for $s\geq 2$.

    (b) $g(x,w)$ is an analytic function of $w$ which can be analytically extended to a complex analytic function on an open subset of $\mathbb{C}^{d}$.

    (c) There exists $C>0$ such that for all $w\in W$ we have

      $\mathbf{E}[g(X,w)]\geq C\,\mathbf{E}[g(X,w)^{2}],\qquad X\sim q(x).$

Conditions 4 and 5c are together called the relatively finite variance condition.
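
As a simple illustration of the relatively finite variance condition (our example, not the paper's): take the one-dimensional model $p(x|w)=\mathcal{N}(x;w,1)$ with truth $q(x)=\mathcal{N}(x;0,1)$ and prior supported on $W=[-1,1]$, so that $W_{0}=\{0\}$ and condition 4 is immediate. Then

  $g(x,w)=\log\frac{p(x|0)}{p(x|w)}=\frac{(x-w)^{2}}{2}-\frac{x^{2}}{2}=-xw+\frac{w^{2}}{2},$

so for $X\sim q$ we have $\mathbf{E}[g(X,w)]=w^{2}/2$ and $\mathbf{E}[g(X,w)^{2}]=w^{2}+w^{4}/4\leq\tfrac{5}{4}w^{2}$ on $W$, and condition 5c holds with $C=2/5$.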

Cite as

@article{hitchcock2025from,
  title = {From Global to Local: A Scalable Benchmark for Local Posterior Sampling},
  author = {Rohan Hitchcock and Jesse Hoogland},
  year = {2025},
  abstract = {Degeneracy is an inherent feature of the loss landscape of neural networks, but it is not well understood how stochastic gradient MCMC (SGMCMC) algorithms interact with this degeneracy. In particular, current global convergence guarantees for common SGMCMC algorithms rely on assumptions which are likely incompatible with degenerate loss landscapes. In this paper, we argue that this gap requires a shift in focus from global to local posterior sampling, and, as a first step, we introduce a novel scalable benchmark for evaluating the local sampling performance of SGMCMC algorithms. We evaluate a number of common algorithms, and find that RMSProp-preconditioned SGLD is most effective at faithfully representing the local geometry of the posterior distribution. Although we lack theoretical guarantees about global sampler convergence, our empirical results show that we are able to extract non-trivial local information in models with up to O(100M) parameters.},
  eprint = {2507.21449},
  archivePrefix = {arXiv},
  url = {https://arxiv.org/abs/2507.21449}
}