DSLT 4. Phase Transitions in Neural Networks

Authors

Liam Carroll

Published

Jun 24, 2023
Read on LessWrong

TL;DR: This is the fourth main post of Distilling Singular Learning Theory, which is introduced in DSLT0. I explain how to relate SLT to thermodynamics, and therefore how to think about phases and phase transitions in the posterior in statistical learning. I then provide intuitive examples of first and second order phase transitions in a simple \(K(w)\) loss function. Finally, I experimentally demonstrate phase transitions in two layer ReLU neural networks associated to the node-degeneracy and orientation-reversing phases established in DSLT3, which we can understand precisely through the lens of SLT.

In deep learning, the terms “phase” and “phase transition” are often used in an informal manner to refer to a steep change in a metric we care about, like the training or test loss, as a function of SGD steps, or alternatively some hyperparameter like the number of samples from the truth \(n\).

A classic example of a phase transition in deep learning literature is Anthropic’s Induction Heads paper, where the “bump” in loss (in-context learning score) is referred to as a “phase change”.

But what exactly are the phases? And why do phase transitions even occur? SLT provides us with a solid theoretical framework for understanding phases and phase transitions in deep learning. In this post, we will argue that in the Bayesian setting,

A phase of the learning process corresponds to a singularity of \(K(w)\), and a phase transition corresponds to a drastic change in the posterior as a function of a hyperparameter \(\theta\).

The hyperparameter \(\theta \) could be the number of samples from the truth \(n\), some way of varying the model function \(f(x,w) = f(x,w;\theta)\) or something about the true distribution \(D_n = D_n(\theta)\), amongst other things. At some critical value \(\theta=\theta_c\), we recognise a phase transition as being a discontinuous change in the free energy or one of its derivatives, for example the generalisation error \(G_n = \mathbb{E}[F_{n+1}] - \mathbb{E}[F_n]\).

In this post, we will present experiments that observe precise phase transitions in the toy neural network models we studied in DSLT3, for which we understand the set of true parameters \(W_0\) and therefore the phases. By the end of this post, you will have a framework for thinking about phase transitions in singular models and an intuition for why SLT predicts them to occur in learning.

Phases Correspond to Singularities

The Story Starts in Physics

This subsection is modelled on [Callen, Ch. 9], but it is only intended to be a high-level discussion of the concepts grounded in some basic physics, so don’t get too bogged down in the details of the thermodynamics.

Fundamentally, a phase describes an aggregate state of a complex system of many interacting components, where the state retains particular qualities with variations in some hyperparameter. To explain the concept in detail, it is natural to start in physics (thermodynamics in particular), where these ideas originally arose. But there is a deeper reason to build from here: every human has an intuitive understanding of the phases of water and how they change with temperature¹, which serves as the base mental model for what a phase is.

One of the main goals of thermodynamics is to study how the equilibrium state of a system changes as a function of macroscopic parameters. In the case of a vessel of water at 1 atm of pressure in constant contact with a thermal and pressure reservoir, the equilibrium state of the system is the state that minimises the Gibbs free energy \(F\)². The phases, then, are the equilibrium states, which describe qualitative physical properties of the system. The states of matter (solid, liquid, and gas) are all phases of water, characterised by variables like their volume and crystal structure. As anybody who has boiled water before knows, these phases undergo transitions as a function of temperature. Let’s make this more precise.

The Thermodynamic Setup

Consider a system of \(K\) water molecules moving in a 2D container, each with equal mass \(m\). To each particle \(i \in [K]=\{1,\dots,K\}\) we can associate a set of microstates describing its physical properties at a point in time, for example its position \(x_i\) and its velocity \(v_i\). In our discussion we will simply focus on the position, which we will relabel \(w=(w_1, \dots, w_K)\) (for reasons that will become clear), so our configuration space \(W \subseteq \mathbb{R}^{2K}\) of possible microstates is 

\[W = \{ w \,|\, (w_{i,x}, w_{i,y}) \in \mathbb{R}^2 \, \text{for each} \, i \in [K]\} \,.\]

Since it is physically infeasible to know or model the positions of all molecules, we instead reason about the dynamics of the system by calculating macroscopic variables associated to a microstate, for example the temperature or total volume of the molecules. We will focus on the volume \(V(w)\) of a microstate \(w\). Importantly, a macroscopic state is an aggregate over the system (for example, temperature being related to average squared velocity), meaning there are many possible configurations of microstates that result in the same macrostate. To this end, we can define regions of our configuration space according to their volume \(v\), 

\[\mathcal{W}_{v} = \{w \in W \,|\, V(w) = v\} \subseteq W \,.\]

In our toy example, we want to study how the system changes as a function of temperature, which we will denote by \(\theta\). In a Gibbs ensemble, we can associate an energy functional, the Hamiltonian \(H(w;\theta)\), to any given microstate \(w\) at temperature \(\theta\). The fundamental postulate of such a Gibbs ensemble is that the probability of the system being in a particular microstate \(w\) is given by a Gibbs distribution³

\[p(w;\theta) = \frac{e^{-H(w;\theta)}}{Z} \quad \text{where} \,\, Z=\int_W e^{- H(w;\theta)} dw \,.\]

This should look pretty familiar from our statistical learning setup! Indeed, we can then calculate the free energy of the ensemble for different volumes \(v\) at temperature \(\theta\), 

\[F_{\theta}(v) = - \log \left( \int_{\mathcal{W}_v} e^{-H(w;\theta)} dw \right) ,.\]
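The link between free energy and probability can be checked on a toy system. The following is a minimal sketch, assuming a made-up discrete system in which sums replace the integrals; the energies `H` and macro labels `V` are random placeholders rather than anything physical. It shows that the Gibbs probability of a region \(\mathcal{W}_v\) equals \(e^{-F_\theta(v)}/Z\), so the region with the lowest free energy is the one carrying the most probability.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical discrete toy system: 10,000 microstates w, each with an energy
# H(w; theta) at some fixed theta, and a macroscopic "volume" label V(w) in {0, 1, 2}.
H = rng.exponential(scale=1.0, size=10_000)
V = rng.integers(0, 3, size=10_000)

Z = np.exp(-H).sum()          # partition function (a sum replaces the integral)
p = np.exp(-H) / Z            # Gibbs distribution over microstates

# Free energy of each region W_v = {w : V(w) = v}
F = np.array([-np.log(np.exp(-H[V == v]).sum()) for v in range(3)])

# Gibbs probability mass of each region W_v
mass = np.array([p[V == v].sum() for v in range(3)])

print("free energies:", F)
print("region masses:", mass)   # mass[v] == exp(-F[v]) / Z
```

Because \(P(\mathcal{W}_v) = e^{-F_\theta(v)}/Z\) is monotone decreasing in \(F_\theta(v)\), "minimises the free energy" and "is the most probable macrostate" are the same statement.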

For a Gibbs ensemble, the equilibrium state of a given system is that state which minimises the free energy. In the context of bringing water to a boiling point, there are two minima of the free energy characterised by the liquid and gaseous states, which for ease we will characterise by their volumes \(v_{\text{liquid}}\) and \(v_{\text{gas}}\). Then the equilibrium state changes at the critical temperature \(\theta_c = 100°\mathrm{C}\),

\[\begin{cases} \mathcal{W}_{v_{\text{liquid}}} & 0°\mathrm{C} < \theta < 100°\mathrm{C} \\ \mathcal{W}_{v_{\text{gas}}} & \theta > 100°\mathrm{C} \end{cases} \,.\]

Importantly, while small variations in the temperature away from \(\theta_c\) will change the free energy of each state, it will not change the configuration of these minima with respect to the free energy. In other words, the system will still be a liquid for any \(\theta \in (0, 100)\) - its qualitative properties are stable. This is the content of a phase. 

What is a phase?

A phase of a system is a region of configuration space \(\mathcal{W} \subset W\) that minimises the free energy, and is invariant to small perturbations in a relevant hyperparameter \(\theta\). Typically, phases are distinguished by some macroscopic variable, in our case the volume \(V(w)\) distinguishing subsets \(\mathcal{W}_v\). More generally though, a phase describes some qualitative aggregate state of a system - like, as we’ve discussed in our example, the states of matter.

A phase is a path in the space of critical points

Phases are minima of the free energy that remain minima with small perturbations in \(\theta\).

In some sense, you can define a phase to be any region that induces an equilibrium state with qualities you care about. But what makes phases a powerful concept is their relation to phase transitions - when there is a sudden jump in which state is preferred by the system.

What is a phase transition?

Phase transitions are changes in the structure of the global minima of the free energy, and often arise as non-analyticities of \(F_n\). This is a fancy way of saying they correspond to discontinuities in the free energy or one of its derivatives.⁴

A first order phase transition at a critical temperature \(\theta_c\) corresponds to a reconfiguration of which phase is the global minimum of the free energy.

First order phase transition exchanges minima

A first order phase transition exchanges global and local minima, changing the configuration of which phase is preferred according to which has the lowest free energy. These curves depict the phase transition as water reaches boiling point \(\theta = 100°C\).

As we discussed above, heating water to boiling point \(\theta_c = 100°\mathrm{C}\) is a classic example of a first order phase transition. 

Two examples of second order phase transitions are where: 

  • A merge transition occurs at \(\theta_c\) when two phases that are disjoint for \(\theta < \theta_c\) merge to become the same state for \(\theta \geq \theta_c\), or
  • A creation transition occurs at \(\theta_c\) when a local minimum exists for \(\theta \geq \theta_c\) but does not exist for \(\theta < \theta_c\). (If the directions are reversed, we call this a destruction transition.)

Second order phase transitions

A second order merge phase transition (left), and a second order creation phase transition (right).

(Note that we have not given a full classification of phase transitions here, because to do so one needs to study the possible types of catastrophes that can occur, as presented in [Gilmore]).

Phases in Statistical Learning

The notation and concepts in the previous section were not presented without reason. For starters, the Gibbs ensemble view of statistical learning is actually quite a rich analogy because, when the prior is uniform, the (random) Hamiltonian is equal, up to an additive constant, to \(n\) times the empirical KL divergence⁵

\[H_n(w) = n K_n(w) ,.\]

The configuration space of microstates of the physical system then corresponds to parameter space \(W\) with microstates given by different parameters \(w \in W\). This means the posterior is equivalent to the Gibbs probability distribution of the system being in a certain microstate, meaning the definition of free energy is identical. So, what exactly are the phases then?
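As a minimal sketch of this correspondence, assuming a hypothetical one-parameter realisable model \(y = wx + \varepsilon\) with standard Gaussian noise, a true parameter \(w_0 = 0.5\) and a uniform prior on \([-2,2]\) (none of which come from this post), the posterior computed on a grid coincides with the Gibbs distribution whose Hamiltonian is \(H_n = nK_n\):

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical realisable toy model: y = w * x + N(0, 1) noise, true parameter w0
n, w0 = 100, 0.5
x = rng.uniform(-1.0, 1.0, n)
y = w0 * x + rng.normal(size=n)

def L_n(w):
    """Empirical negative log likelihood (1/n) sum_i -log p(y_i | x_i, w)."""
    r = y[None, :] - w[:, None] * x[None, :]
    return np.mean(0.5 * r**2 + 0.5 * np.log(2 * np.pi), axis=1)

w = np.linspace(-2.0, 2.0, 4001)
dw = w[1] - w[0]
phi = np.full_like(w, 1 / 4)      # uniform prior density on [-2, 2]

# Posterior p(w | D_n) = phi(w) exp(-n L_n(w)) / Z_n, normalised on the grid
post = phi * np.exp(-n * L_n(w))
post /= post.sum() * dw

# Gibbs form: H_n(w) = n K_n(w) with K_n = L_n - S_n and S_n = L_n(w0)
S_n = L_n(np.array([w0]))[0]
gibbs = np.exp(-n * (L_n(w) - S_n))
gibbs /= gibbs.sum() * dw

print(np.allclose(post, gibbs))   # the constants phi and S_n cancel on normalising
```

The uniform prior and the entropy term \(S_n\) are constant in \(w\), which is why they disappear from the normalised distribution.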

In statistical learning then, 

A phase corresponds to a local neighbourhood \(\mathcal{W} \subset W\) containing a singularity \(w_{\mathcal{W}}^{(0)}\) of interest.

To say that \(\mathcal{W}\) minimises the free energy is equivalent to saying that it has non-negligible posterior mass. The reason for this, as we explored in DSLT2, is that the singularity structure of a most singular optimal point \(w^{(0)}_{\mathcal{W}} \in \mathcal{W}_{\text{opt}}\) dominates the behaviour of the free energy, because it minimises the loss \(L(w)\) and has the smallest RLCT \(\lambda_{\mathcal{W}}\). 

You can, in principle, define a phase to be any region of \(W\). But the analysis of phases in the posterior only gets interesting when you have a set of phases that have fundamentally different geometric properties. The free energy formula tells us that these geometric properties correspond to different accuracy-complexity tradeoffs.

Consequently, in statistical learning, Watanabe states in [Wat18, \(\S9.4\)] that

A phase transition is a drastic change in the geometry of the posterior as a function of a hyperparameter \(\theta\). 

Our definitions of first and second order phase transitions carry over perfectly from the physics discussion above. 

It’s important to clarify here that phase transitions in deep learning have many flavours. If one believes that SGD is effectively just “sampling from the posterior”, then the conception that phase transitions are related to changes in the geometry of the posterior carries over. There is, however, one fundamentally different kind of “phase transition” that we cannot explain easily with SLT: a phase transition of SGD in time, i.e. in the number of gradient descent steps. The Bayesian framework of SLT does not really allow one to speak of time; the closest quantity is the number of datapoints \(n\), but these are not equivalent. We leave this gap as one of the fundamental open questions of relating SLT to current deep learning practice.⁶

The hyperparameter \(\theta \) can affect any number of objects involved in the posterior. Remembering that the posterior is 

\[p(w|D_n) = \frac{\varphi(w) e^{-nL_n(w)}}{Z_n}\,,\]

we could include hyperparameter \(\theta\) dependence in any of:

  • The model function \(f(x,w)=f(x,w;\theta)\) (i.e. the neural network defining \(p(y|x,w)\)).
  • The true distribution \(D_n = D_n(\theta)\), meaning \(L_n(w) = L_n(w;\theta)\). (This could in principle be dependence in the input prior \(q(x)\) or in the actual dataset generated by \(q(y|x)\).)
  • The number of datapoints \(n\) (inducing a first order phase transition due to the change in accuracy-complexity tradeoff).
  • The prior \(\varphi(w) = \varphi(w;\theta)\).

Intuitive Examples to Interpret Phase Transitions

In DSLT2 we studied an example of a very simple one-dimensional \(K(w)\) curve and got a feel for how the accuracy and complexity of a singularity affect the free energy of different neighbourhoods. Having now learned about phase transitions, we can cast new light on this example.

Example 1: First Order Phase Transition in \(n\) 

Example 4.1: Consider again a KL divergence given by

\[K(w) = (w+1)^2 ((w-(1+h_C))^4 - k_C)\]

where \(w_{-1}^{(0)}=-1\) and \(w^{(0)}_{1}=1\) are the singularities, but the accuracy of \(w_1^{(0)}\) is worse, \(K(w^{(0)}_1) = C >0\). Then we can identify two phases corresponding to the two singularities, 

\[\mathcal{W}_{-1} = B(w_{-1}^{(0)}, \delta) \quad \text{and} \quad \mathcal{W}_{1} = B(w_{1}^{(0)}, \delta)\]

for some radius \(\delta>0\), such that the accuracy of \(\mathcal{W}_{-1}\) is better, but the complexity of \(\mathcal{W}_1\) is smaller, 

\[L(w^{(0)}_{-1}) < L(w^{(0)}_{1})\,, \quad \text{but} \quad \lambda_{\mathcal{W}_{-1}} > \lambda_{\mathcal{W}_{1}} \,.\]

As the hyperparameter \(\theta=n\)⁷ varies, we see a first order phase transition at the critical value \(n_c\approx 17\) where the two free energy curves intersect, causing an exchange of which phase is the global minimum of the free energy. As we argued in DSLT2, this is largely due to the accuracy-complexity tradeoff of the free energy. Notice also how the free energy of the global minimum is non-differentiable at \(n_c\), an example of the “non-analyticity” of \(F_n\) that we mentioned above. 

Free energy crossover

A first order phase transition occurs at \(n_c\) when the more accurate phase \(\mathcal{W}_{-1}\) becomes the new global minimum of the free energy.
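The crossover can be reproduced numerically with a sketch. Since the constants \(h_C, k_C\) of Example 4.1 are not restated here, the code below swaps in a hypothetical \(K(w)\) with the same qualitative features: a quadratic zero at \(w=-1\) (RLCT \(\frac{1}{2}\)) and a quartic minimum of height \(C>0\) at \(w=1\) (RLCT \(\frac{1}{4}\)). It computes the local free energies \(F_n(\mathcal{W}_{\pm 1}) = -\log \int_{\mathcal{W}_{\pm 1}} e^{-nK(w)}\,dw\) with a uniform prior and reports the first \(n\) at which \(\mathcal{W}_{-1}\) takes over. The values of `C` and `DELTA` are illustrative, so the crossing need not land at \(n_c \approx 17\).

```python
import numpy as np

C = 0.05       # hypothetical inaccuracy K(w_1) = C of the less accurate phase
DELTA = 0.5    # hypothetical ball radius delta

def K(w):
    """Hypothetical stand-in for Example 4.1: a quadratic zero at w = -1
    (RLCT 1/2) and a quartic minimum of height C at w = 1 (RLCT 1/4)."""
    return np.where(w < 0, (w + 1.0) ** 2, (w - 1.0) ** 4 + C)

def local_free_energy(centre, n, delta=DELTA, num=20_000):
    """F_n(W) = -log int_{B(centre, delta)} exp(-n K(w)) dw with a flat prior,
    approximated by the midpoint rule on `num` cells."""
    edges = np.linspace(centre - delta, centre + delta, num + 1)
    mids = 0.5 * (edges[:-1] + edges[1:])
    dw = edges[1] - edges[0]
    return -np.log(np.exp(-n * K(mids)).sum() * dw)

ns = np.arange(1, 201)
F_minus = np.array([local_free_energy(-1.0, n) for n in ns])   # phase W_{-1}
F_plus = np.array([local_free_energy(1.0, n) for n in ns])     # phase W_1
gap = F_minus - F_plus          # gap > 0 means W_1 has the lower free energy
n_c = ns[np.argmax(gap < 0)]    # first n at which W_{-1} has the lower free energy
print("critical n:", n_c)
```

The prior's normalising constant is dropped because it shifts both free energies equally, leaving the crossing unchanged; the kink of \(\min(F_{-1}, F_1)\) at the crossing is the non-analyticity described above.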

Example 2: Second Order Merge Phase Transition

Example 4.2: We can modify our example slightly to observe a second order phase transition. Let’s consider 

\[K(w;\theta) = (w+(1-\theta))^2 (w-(1-\theta))^4\]

where \(\theta \in [0,1]\) is a hyperparameter that shifts the two singularities \(w_{-1}^{(0)} = -1+\theta\) and \(w_{1}^{(0)} = 1-\theta\) towards the origin. We will continue to label these phases \(\mathcal{W}_{-1}\) and \(\mathcal{W}_1\), noting their \(\theta\) dependence.⁸

Thus, at \(\theta_c=1\) the two phases will merge and the KL divergence will be 

\[K(w;1) = w^6 ,.\]

Therefore, at \(\theta=1\) the singularity \(w^{(0)}_0=0\) will have an RLCT of 

\[\lambda_0 = \frac{1}{6} ,.\]

There is a new most singular point caused by the merging of two phases! Again, we can visually depict this phase transition: 

2nd order phase transition

A second order phase transition as two phases merge into one with a lower RLCT.
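These RLCTs can be sanity-checked numerically. One standard characterisation, used here as an illustrative sketch rather than the method of this sequence, is that near a singularity the volume \(V(\varepsilon)\) of \(\{w : K(w) < \varepsilon\}\) scales like \(c\,\varepsilon^{\lambda}\) as \(\varepsilon \to 0\), up to a \(\log\) factor that cannot arise in one dimension. The sketch fits \(\lambda\) as the slope of \(\log V\) against \(\log \varepsilon\) on a fine grid; the ball radius, grid size and \(\varepsilon\) range are arbitrary choices.

```python
import numpy as np

def K(w, theta):
    """KL divergence of Example 4.2: (w + (1 - theta))^2 (w - (1 - theta))^4."""
    a = 1.0 - theta
    return (w + a) ** 2 * (w - a) ** 4

def local_rlct(theta, centre, delta=0.2, eps_range=(1e-10, 1e-8), num=4_000_001):
    """Estimate the local RLCT at `centre` from the scaling
    V(eps) = vol{w in B(centre, delta) : K(w) < eps} ~ c * eps^lambda."""
    w = np.linspace(centre - delta, centre + delta, num)
    dw = w[1] - w[0]
    Kw = K(w, theta)
    eps = np.geomspace(eps_range[0], eps_range[1], 9)
    vols = np.array([np.count_nonzero(Kw < e) * dw for e in eps])
    slope, _ = np.polyfit(np.log(eps), np.log(vols), 1)   # slope of log V vs log eps
    return slope

print(local_rlct(0.0, -1.0))   # quadratic zero before merging
print(local_rlct(0.0, 1.0))    # quartic zero before merging
print(local_rlct(1.0, 0.0))    # merged singularity K(w; 1) = w^6
```

At \(\theta = 0\) this should recover \(\lambda \approx \frac{1}{2}\) at \(w=-1\) and \(\lambda \approx \frac{1}{4}\) at \(w=1\), while at \(\theta = 1\) it should give \(\lambda_0 \approx \frac{1}{6}\), smaller than either phase had before the merge.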

Now that we have the basic intuitions of SLT and phase transitions down pat, let’s apply these concepts to the case of two layer feedforward ReLU neural networks.

Phase Transitions in Two Layer ReLU Neural Networks

The main claim of this sequence is that Singular Learning Theory is a solid theoretical framework for understanding phases and phase transitions in neural networks. It’s now time to make good on that promise and bring all of the pieces together to understand an actual example of phase transitions in neural networks. The full details of these experiments are explained in my thesis, [Carroll, \(\S5.2\)], but I will briefly outline some points here for the interested reader. All notation and terminology is explained in detail in DSLT3, so use that section as a reference. 

If you are uninterested, just skip to the next subsection to see the results. 

Experimental Setup

We will consider a (model, truth) pair defined by the simple two layer feedforward ReLU neural network models we studied in DSLT3. Phase transitions will be induced by varying the true distribution with a hyperparameter \(\theta\), meaning \(D_n = D_n(\theta)\). Since we have a full classification of \(W_0\) from DSLT3, we understand the phases of the system, and therefore we want to study how their differing geometries affect the posterior. As we explained in that post, the scaling and permutation symmetries are generic (they occur for all parameters \(w \in W\)), but the node-degeneracy and orientation-reversing symmetries only occur under precise configurations of the truth. Thus, we are interested in studying how the posterior changes as we vary the truth to induce these alternative true parameters, which are the phases of our setup. 

The posterior sampling procedure uses an MCMC variant called HMC NUTS, which is brilliantly explained and interpreted here. Estimating precise nominal free energy values, and particularly those of the RLCT \(\lambda\), using sampling methods is currently very challenging (as explained in [Wei22]). So, for these experiments, our inference about phases and phase transitions will be based on visualising the posterior and observing the posterior concentrations of different phases. With this in mind, the posteriors below are averaged over four trials, 20,000 samples each, for each fixed true distribution defined by \(\theta\). (Bayesian sampling is very computationally expensive, even in simple settings.)

To isolate the phases we care about, we can use the fact that the scaling and permutation symmetries of our networks are generic. To this end, we normalise the weights by defining the effective weight \(\hat{w}_i=|q_i| w_i\)⁹, which preserves functional equivalence \(f(x,w)= f(x,\hat{w})\)¹⁰. We say a node is degenerate if \(\hat{w}_i = 0\). We also project the different node indices onto the same \((\hat{w}_{i,1}, \hat{w}_{i,2})\) axes as follows:

Unnormalised vs normalised scatterplot of samples

Scatterplots of \(20{,}000\) posterior samples \(w^{(k)}\) for a two layer network: non-normalised weights \((w_{i,1}, w_{i,2})\) (left), and normalised effective weights \((\hat{w}_{i,1}, \hat{w}_{i,2})\) (right). For a given sample \(w^{(k)}\), each of the two nodes \(\hat{w}_1, \hat{w}_2\) is projected onto the same plot (i.e. there are 40,000 dots on each plot). 

The prior on inputs \(q(x)\) is uniform on the square \([-1,1]^2\), and the prior on parameters \(\varphi(w)\) is the standard multidimensional normal \(\mathcal{N}(0,1)\). 
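A minimal sketch of the normalisation, assuming the DSLT3 parameterisation \(f(x,w) = c + \sum_i q_i\,\mathrm{ReLU}(\langle w_i, x\rangle + b_i)\) in which the bias is rescaled together with the weight (\(\hat{b}_i = |q_i| b_i\)). Because \(\mathrm{ReLU}\) is positively homogeneous, \(|q|\,\mathrm{ReLU}(z) = \mathrm{ReLU}(|q| z)\), so moving \(|q_i|\) inside the node leaves only the sign of \(q_i\) outside and the function is unchanged. The random parameters and the helper names `f` and `effective_weights` are placeholders for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(z, 0.0)

def f(x, w, b, q, c):
    """Two layer ReLU network c + sum_i q_i * ReLU(<w_i, x> + b_i).
    Shapes: x (N, 2), w (d, 2), b (d,), q (d,), c scalar."""
    return c + relu(x @ w.T + b) @ q

def effective_weights(w, b, q):
    """Absorb |q_i| into node i: w_hat_i = |q_i| w_i, b_hat_i = |q_i| b_i.
    The output weights reduce to sign(q_i)."""
    scale = np.abs(q)
    return w * scale[:, None], b * scale, np.sign(q)

# Inputs drawn from the uniform input prior q(x) on [-1, 1]^2
x = rng.uniform(-1.0, 1.0, size=(1_000, 2))
w, b, q, c = rng.normal(size=(2, 2)), rng.normal(size=2), rng.normal(size=2), 0.3
w_hat, b_hat, s = effective_weights(w, b, q)

# Functional equivalence f(x, w) = f(x, w_hat)
print(np.allclose(f(x, w, b, q, c), f(x, w_hat, b_hat, s, c)))

# Project both nodes onto one (w_hat_{i,1}, w_hat_{i,2}) plane: one row per node
points = w_hat.reshape(-1, 2)
```

Applied to an array of posterior samples of shape `(S, 2, 2)`, the same `reshape(-1, 2)` gives the \(2S\) dots per plot described in the caption above.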

Phase Transition 1 - Deforming to Degeneracy

In this experiment we will see a first order phase transition induced by deforming a true network from having no degenerate nodes to having one (possibility of a) degenerate node, as discussed in DSLT3 - Node Degeneracy. This example will reinforce the key messages of Watanabe’s free energy formula: true parameters are preferred according to their RLCT, and at finite \(n\) non-true parameters can be preferred due to the accuracy-complexity tradeoff.

Defining the Model, Truth, and Phases

We are going to consider a model network with \(d=2\) nodes, 

\[f(x,w) =\mathrm{ReLU}(\langle \hat{w}_1, x \rangle + \hat{b}_1) + \mathrm{ReLU}(\langle \hat{w}_2, x \rangle + \hat{b}_2) + c\]

and a realisable true network \(f(x,w^{(0)})\) with \(m=2\) nodes, which we will denote by \(f_2(x,\theta):=f(x,w^{(0)}) \) to signify its hyperparameter \(\theta\) dependence (and distinguish it from the next experiment),

\[f_2(x,\theta) =\mathrm{ReLU}\left(\langle \hat{w}_1^{(0)}, x \rangle -\frac{1}{3}\right) + \mathrm{ReLU}\left(\langle \hat{w}_2^{(0)}, x \rangle -\frac{1}{3}\right) ,.\]

The true weights rotate towards one another as a hyperparameter \(\theta \in [0, \frac{\pi}{2}]\) increases, so¹¹

\[w_1^{(0)} = (\cos\theta, \sin\theta), \quad w_2^{(0)} = (-\cos\theta, \sin\theta) ,.\]

As we explained in DSLT3, we can depict the function and its activation boundaries pictorially:

PT1 Activation Boundaries

Contour plot of \(f_2(x,\theta)\) and its activation boundaries at \(\theta =0\) (left), \(\theta = \frac{\pi}{4}\) (middle) and \(\theta = \frac{\pi}{2}\) (right).

At \(\theta=\frac{\pi}{2}\), the truth could be expressed by a network with only one node, \(m=1\),

\[f_2\left(x,\frac{\pi}{2}\right) = \mathrm{ReLU}\left( x_2 - \frac{1}{3} \right) + \mathrm{ReLU}\left( x_2 - \frac{1}{3} \right) = 2\mathrm{ReLU}\left( x_2 - \frac{1}{3} \right) ,.\]

This degeneracy is what we are interested in studying. The WBIC tells us to expect the posterior to prefer the one-degenerate-node configuration since it has fewer effective parameters.¹²

To identify our phases, at \(\theta = \frac{\pi}{2}\) there are two possible configurations of the effective model weights that are true parameters:

  • Both non-degenerate but share the same activation boundary: Both \(\hat{w}_1, \hat{w}_2 \neq 0\) such that \(\hat{w}_1 + \hat{w}_2 = (0,2)\).
  • One degenerate, one non-degenerate: Either \(\hat{w}_1 = (0,0)\) and \(\hat{w}_2=(0,2)\), or vice versa by permutation symmetry.

To study these configurations we thus define phases based on annuli in the plane, centred on the circle of radius \(r\) with half-width \(\varepsilon\), 

\[\mathcal{A}(r, \varepsilon) = \{ (\hat{w}_{\cdot,1}, \hat{w}_{\cdot,2}) \in \mathbb{R}^2 \;|\; r-\varepsilon \leq \|(\hat{w}_{\cdot,1}, \hat{w}_{\cdot,2})\| \leq r+\varepsilon \} \,.\]

Then we define the two phases containing the singularities of interest to be 

\[\begin{aligned}\mathcal{A}_{\text{NonDegen}} &= \mathcal{A}(1, \varepsilon) \times \mathcal{A}(1, \varepsilon) \\ \mathcal{A}_{\text{Degen}} &= \left(\mathcal{A}(0, \varepsilon) \times \mathcal{A}(2, \varepsilon) \right) \cup \left(\mathcal{A}(2, \varepsilon) \times \mathcal{A}(0, \varepsilon) \right) \end{aligned} \,.\]

The union is due to the permutation symmetry: which precise node is degenerate doesn’t matter. We will let \(\mathcal{A}^c = (\mathbb{R}^2 \times \mathbb{R}^2) \setminus (\mathcal{A}_{\text{NonDegen}} \cup \mathcal{A}_{\text{Degen}})\) denote everything else.

PT1 Annuli in the \((\hat{w}_{i,1}, \hat{w}_{i,2})\) plane

The phases of interest in the \((\hat{w}_{i,1}, \hat{w}_{i,2})\) plane.
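Relative frequencies of the form \(n_{\mathcal{A}}/n\), as reported in the results, can be estimated by counting which annuli each sample's two effective weights land in. This is a sketch under stated assumptions: the value of \(\varepsilon\) is not given in this post, so `eps=0.25` is a hypothetical choice (any \(\varepsilon < \frac{1}{2}\) keeps the annuli at radii 0, 1 and 2 disjoint), and `phase_frequencies` is an illustrative helper rather than code from [Carroll].

```python
import numpy as np

def in_annulus(norms, r, eps):
    """Membership in A(r, eps) = {v : r - eps <= ||v|| <= r + eps}."""
    return (norms >= r - eps) & (norms <= r + eps)

def phase_frequencies(w_hat, eps=0.25):
    """w_hat has shape (S, 2, 2): (sample, node, coordinate) effective weights.
    Returns the relative frequencies n_A / n of A_NonDegen, A_Degen and A^c."""
    norms = np.linalg.norm(w_hat, axis=-1)      # (S, 2): one norm per node
    r1, r2 = norms[:, 0], norms[:, 1]
    non_degen = in_annulus(r1, 1, eps) & in_annulus(r2, 1, eps)
    # Either node may be the degenerate one (permutation symmetry)
    degen = ((in_annulus(r1, 0, eps) & in_annulus(r2, 2, eps))
             | (in_annulus(r1, 2, eps) & in_annulus(r2, 0, eps)))
    other = ~(non_degen | degen)
    S = len(w_hat)
    return {"NonDegen": non_degen.sum() / S,
            "Degen": degen.sum() / S,
            "complement": other.sum() / S}

# Four hand-made samples (not experimental data), one per row
samples = np.array([
    [[0.0, 1.0], [0.0, 1.0]],    # both nodes on the unit circle: NonDegen
    [[0.0, 0.0], [0.0, 2.0]],    # node 1 degenerate: Degen
    [[0.0, 2.1], [0.05, 0.0]],   # node 2 (nearly) degenerate: Degen
    [[0.0, 1.5], [0.0, 0.5]],    # neither: complement
])
print(phase_frequencies(samples))
```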

There are two questions we seek to answer:

  1. At \(\theta=\frac{\pi}{2}\), which phase is preferred by the posterior, since both contain true parameters?
  2. Is there a first order phase transition at some \(\theta_c<\frac{\pi}{2}\) where \(\mathcal{A}_{\mathrm{Degen}}\) becomes preferred, even though it doesn’t contain a true parameter?

Which phase is preferred?

The weight configuration of each phase at \(\frac{\pi}{2}\). Our experiments will test which of these is preferred.

Results

PT1 Posterior density animation

TOP - The posterior samples for each \(\theta \in [1,\frac{\pi}{2}]\), where the red dots indicate the true parameters \(w_1^{(0)}, w_2^{(0)}\) being rotated inwards. 

BOTTOM - Estimates of the relative frequency of each phase, i.e. \(n_{\mathcal{A}}/n\). Do not be confused - these are not free energy curves like in previous figures of the sequence, but they are instead measurements of posterior concentration of each phase at each \(\theta\). In other words, the direction is reversed - higher is better! 

Notice how there is a first order phase transition at \(\theta_c = 1.26\). Note also that defining the phases according to fixed annuli is not a perfect clustering mechanism, resulting in \(\mathcal{A}^c\) increasing in relative frequency for middle \(\theta\) values. It is merely designed to paint a numerical picture of what we see visually in the changing posterior.

There is also a static facet grid of the key frames if you want a closer inspection.

The results of our experiments show:

  1. At \(\theta=\frac{\pi}{2}\), the degenerate phase \(\mathcal{A}_{\mathrm{Degen}}\) is preferred.
  2. There is a first order phase transition at \(\theta_c \approx 1.26\) where \(\mathcal{A}_{\mathrm{Degen}}\) becomes preferred, despite not containing a true parameter for \(\theta < \frac{\pi}{2}\).

It is unsurprising (yet satisfying) that the degenerate phase \(\mathcal{A}_{\mathrm{Degen}}\) is preferred at \(\theta=\frac{\pi}{2}\), in line with what the WBIC tells us to expect. What might be more surprising, though, is that \(\mathcal{A}_{\mathrm{NonDegen}}\) has extremely little posterior density at this \(\theta\) value.¹³

As we have argued throughout the sequence, the free energy formula suggests that first order phase transitions happen when there is a change in the accuracy-complexity tradeoff such that the posterior newly prefers one phase over the other. Here, the first order phase transition at \(\theta_c \approx 1.26\) can be understood in these terms with the following graph, which depicts how the accuracy of \(\mathcal{A}_{\mathrm{Degen}}\) improves with \(\theta\).

PT1 Accuracy of phases

The inaccuracy \(\min_{w \in \mathcal{W}} L_n(w)\) of each phase \(\mathcal{W}\) changes with \(\theta\). The variations in the inaccuracy are closely correlated to the accuracy of the truth \(L_n(w^{(0)}) = S_n\), which varies randomly because the dataset is itself random: it is generated from \(f_2(x,\theta)\) plus Gaussian noise. Increasing \(n\) and averaging over more trials would hopefully smooth out these curves. Nonetheless, it is clear that the phase transition occurs at approximately the point where the inaccuracies of the two phases are equal. Since \(\mathcal{A}_{\mathrm{Degen}}\) has higher posterior density for \(\theta>\theta_c\), this implies that its complexity is lower.

A Complexity Measure for Non-Analytic ReLU Networks

One last thing to point out here is that since \(K(w)\) is not analytic for ReLU neural networks, the RLCT is not a well defined object. Nonetheless, Watanabe has recently proven in this paper that there is a bound on the free energy, 

\[F_n \leq n S_n + \lambda_{\mathrm{ReLU}} \log n\]

where the complexity \(\lambda_{\mathrm{ReLU}} \in \mathbb{Q}_{>0}\), a kind of ‘pseudo’-RLCT, is half the number of parameters in the smallest compressed network that can represent the function. In our case the complexity is 

\[2\lambda_{\mathrm{ReLU}} = \begin{cases} 9 & \text{for } 0<\theta< \frac{\pi}{2} \\ 5 & \text{for } \theta = \frac{\pi}{2} \end{cases}\]

since five parameters are required in the degenerate phase and nine in the non-degenerate phase¹⁴. In this way, Watanabe’s work predicts the results we see. This also suggests that SLT may generalise to the non-analytic setting and still give approximately the same essential insights into singular models.
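As a quick check of these counts, assuming the DSLT3 parameterisation in which each node carries two input weights \(w_i\), a bias \(b_i\) and an output weight \(q_i\), with one shared constant \(c\), a compressed network with \(m\) nodes has \(4m+1\) parameters. The complexity term of the bound then drops as follows when the truth compresses from two nodes to one:

\[\begin{aligned} 2\lambda_{\mathrm{ReLU}} &= 4m + 1 = \begin{cases} 4 \cdot 2 + 1 = 9 & m = 2 \\ 4 \cdot 1 + 1 = 5 & m = 1 \end{cases} \\ \left(\tfrac{9}{2} - \tfrac{5}{2}\right)\log n &= 2\log n \,. \end{aligned}\]

Since both cases share the accuracy term \(nS_n\), the bound at \(\theta = \frac{\pi}{2}\) is lower by \(2\log n\), a gap that grows with \(n\).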

Phase Transition 2 - Orientation Reversing Symmetry

Defining the Model, Truth, and Phases

This time we are going to consider a model network with \(d=3\) nodes, 

\[f(x,w) = c+ \sum_{i=1}^3 \mathrm{ReLU}(\langle \hat{w}_i, x \rangle + \hat{b}_i) ,,\]

and a realisable true network \(f_3(x,\vartheta)\) with \(m=3\) nodes,

\[f_3(x,\vartheta) = \sum_{i=1}^3 \mathrm{ReLU}\left(\langle \hat{w}_i^{(0)}, x \rangle -\frac{1}{3}\right) ,,\]

where the weights are defined by an order parameter \(\vartheta \in [1,3]\) that scales one gradient,

\[w_1^{(0)} = \left(\cos\frac{\pi}{3}, \sin\frac{\pi}{3}\right), \quad w_2^{(0)}(\vartheta) = \vartheta\left(\cos\pi, \sin\pi\right),, \quad  w_3^{(0)} = \left(\cos\frac{5\pi}{3}, \sin\frac{5\pi}{3}\right) ,.\]

True network changes by scaling left node

The network \(f_3(x,\vartheta)\) changes with \(\vartheta\) by increasing the gradient of the \(\hat{w}_2^{(0)}\) vector.

At \(\vartheta=1\), the weights satisfy the weight annihilation property, 

\[w_1^{(0)} + w_2^{(0)} + w_3^{(0)} = (0,0)\,,\]

meaning that reversing the orientation of the weights, \(w_i^{(0)} \mapsto -w_i^{(0)}\) (which is equal to a rotation by \(\pi\)), will preserve the function, as discussed in DSLT3 - Orientation Reversal. We will use the label weight annihilation phase to refer to the configuration of nodes in which the weights all point into the centre region and annihilate one another.¹⁵ Our key question thus becomes: does the posterior prefer the weight annihilation phase, or the non-weight annihilation phase, at \(\vartheta=1\)?

Weight annihilation symmetry

Since \(\sum_{i=1}^3 w_i^{(0)}=0\), the orientation reversing symmetry is present, meaning we want to analyse whether the posterior prefers the weight annihilation phase or the non-weight annihilation phase.
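The role of \(\sum_i w_i^{(0)} = 0\) can be checked directly. The sketch below relies only on the identity \(\mathrm{ReLU}(z) = \mathrm{ReLU}(-z) + z\): reversing every node (\(w_i \mapsto -w_i\), \(b_i \mapsto -b_i\)) and adding the constant \(\sum_i b_i\) reproduces \(f_3(x,\vartheta)\) up to a leftover linear term \(\langle \sum_i w_i, x\rangle\), which vanishes only when the weights annihilate. The helper names and the input sample are illustrative assumptions, not code from [Carroll].

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def R(angle):
    return np.array([np.cos(angle), np.sin(angle)])

def true_weights(vartheta):
    """Rows are w_1^(0), w_2^(0)(vartheta), w_3^(0)."""
    return np.stack([R(np.pi / 3), vartheta * R(np.pi), R(5 * np.pi / 3)])

def net(x, w, b, c=0.0):
    """c + sum_i ReLU(<w_i, x> + b_i)."""
    return c + relu(x @ w.T + b).sum(axis=1)

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(2_000, 2))   # inputs on [-1, 1]^2
b = np.full(3, -1.0 / 3.0)                    # biases of f_3

def reversal_preserves_function(vartheta):
    w = true_weights(vartheta)
    original = net(x, w, b)                       # f_3(x, vartheta)
    # Reversed network: weights -w_i, biases -b_i, constant sum_i b_i.
    # It differs from the original by <sum_i w_i, x>.
    reversed_net = net(x, -w, -b, c=b.sum())
    return np.allclose(original, reversed_net)

print(reversal_preserves_function(1.0))   # weights annihilate at vartheta = 1
print(reversal_preserves_function(2.0))   # sum_i w_i = (-1, 0), so the functions differ
```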

To depict the phases on the \((w_{i,1}, w_{i,2})\) plane, let \(R(\theta) = (\cos\theta, \sin\theta)\), let \(\mathcal{B}(x, \varepsilon)\) be the closed ball of radius \(\varepsilon\) centred at \(x \in \mathbb{R}^2\), and let \(S_3\) denote the symmetric group of permutations of \(\{0,1,2\}\). Then the two phases of interest are 

\[\begin{aligned}\mathcal{E}_{\text{NonWA}} &= \bigcup_{\sigma \in S_3} \prod_{k=0}^2 \mathcal{B}\left( R \left(\frac{\pi}{3} + \frac{2\sigma(k)\pi}{3} \right) , \varepsilon \right) \\ \mathcal{E}_{\text{WA}} &= \bigcup_{\sigma \in S_3} \prod_{k=0}^2 \mathcal{B}\left( R \left(\frac{2\sigma(k)\pi}{3} \right), \varepsilon\right) \end{aligned} \,.\]

Since \(w_2^{(0)}\) is being scaled by \(\vartheta\), we will understand the centre of the ball corresponding to \(\sigma(k)=1\) in \(\mathcal{E}_{\text{NonWA}}\) as being multiplied by the scalar \(\vartheta\). (It is easier to state in words than to write down in gory notation.)

Phases of orientation reversing symmetry

The phases \(\mathcal{E}_{\text{NonWA}}\) and \(\mathcal{E}_{\text{WA}}\) in the \((w_{i,1}, w_{i,2})\) plane at \(\vartheta=1\). As \(\vartheta\) increases, the leftmost ball of \(\mathcal{E}_{\text{NonWA}}\) will shift leftward.
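Membership in \(\mathcal{E}_{\text{NonWA}}\) or \(\mathcal{E}_{\text{WA}}\) amounts to asking whether some permutation \(\sigma \in S_3\) places every node inside its assigned ball. A minimal sketch, with a hypothetical radius `eps=0.25` since \(\varepsilon\) is not specified here; `ball_centres` and `in_phase` are illustrative helpers:

```python
import itertools
import numpy as np

def R(angle):
    return np.array([np.cos(angle), np.sin(angle)])

def ball_centres(phase, vartheta=1.0):
    """Centres of the three balls of E_NonWA or E_WA, indexed by j = sigma(k)."""
    if phase == "NonWA":
        centres = np.stack([R(np.pi / 3 + 2 * j * np.pi / 3) for j in range(3)])
        centres[1] *= vartheta        # the ball with sigma(k) = 1 moves with vartheta
    elif phase == "WA":
        centres = np.stack([R(2 * j * np.pi / 3) for j in range(3)])
    else:
        raise ValueError(f"unknown phase {phase!r}")
    return centres

def in_phase(w_hat, phase, vartheta=1.0, eps=0.25):
    """True if some sigma in S_3 puts node k inside B(centre_{sigma(k)}, eps)
    for every k. w_hat has shape (3, 2): one weight vector per node."""
    centres = ball_centres(phase, vartheta)
    return any(
        all(np.linalg.norm(w_hat[k] - centres[sigma[k]]) <= eps for k in range(3))
        for sigma in itertools.permutations(range(3))
    )

w_true = np.stack([R(np.pi / 3), R(np.pi), R(5 * np.pi / 3)])   # truth at vartheta = 1
print(in_phase(w_true, "NonWA"), in_phase(-w_true, "WA"))        # reversed weights lie in E_WA
```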

In this experiment our two questions are:

  1. At \(\vartheta=1\), which phase is preferred?
  2. Is there a first or second order phase transition at some \(\vartheta_c \in (1,3]\)?

Results

Posterior samples for each \(\vartheta \in [1,3]\), where the red dots are true parameters \(w_1^{(0)}, w_2^{(0)}(\vartheta), w_3^{(0)}\).
Notice how the \(\mathcal{E}_{\text{WA}}\) phase has less posterior concentration at \(\vartheta=1\) despite containing a true parameter. There is a second order destruction transition at \(\vartheta_c \approx 2\) where \(\mathcal{E}_{\text{WA}}\) ceases to have any mass since its inaccuracy is too high.

The results of this experiment show that:

  1. At \(\vartheta=1\) the non-weight annihilation phase \(\mathcal{E}_{\text{NonWA}}\) is preferred by the posterior.
  2. The weight annihilation phase \(\mathcal{E}_{\text{WA}}\) is never preferred by the posterior, so there is no first order phase transition. But there is a second order phase transition at \(\vartheta_c \approx 2\) where \(\mathcal{E}_{\text{WA}}\) is destroyed.

In [Carroll, \(\S\)5.4.3], I perform a calculation on an even simpler orientation-reversing example which shows that the relative error of the inner cancellation region strongly dictates which of the two phases is preferred. This relative error can be made smaller by increasing the size of the prior \(q(x)\). That result suggests that the two phases may have the same RLCT, but differing lower order geometry. This is speculative though, and it would be interesting to better understand the RLCT of both phases.

The second order phase transition is unsurprising since we specifically deform the network so that \(\mathcal{E}_{\text{WA}}\) doesn’t contain a true parameter for \(\vartheta \in (1,3]\). At \(\vartheta_c\), its inaccuracy is too highly penalised and the posterior contains no samples from the region.
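To see why a growing inaccuracy eventually empties a phase, one can use a leading-order local free energy for each phase, \(F_n(\mathcal{E}_\alpha) \approx nK_\alpha + \lambda_\alpha \log n - (m_\alpha - 1)\log\log n\), where \(K_\alpha\) is the minimum of \(K(w)\) on the phase, \(\lambda_\alpha\) its local RLCT and \(m_\alpha\) its multiplicity, and take the posterior mass of each phase to be proportional to \(e^{-F_n(\mathcal{E}_\alpha)}\). The sketch below assumes this approximation, dropping terms shared by all phases and \(O(1)\) terms. The function name `phase_probabilities` is hypothetical, and the values of \(K_\alpha\), \(\lambda_\alpha\) and \(m_\alpha\) in the demo are made up for illustration, not measured from the experiment.

```python
import numpy as np


def phase_probabilities(n, K, lam, m=None):
    """Approximate posterior mass of each phase from its local free energy.

    Assumed leading-order form (terms shared by all phases, e.g. n*S_n,
    and O(1) terms are dropped):
        F_alpha = n*K_alpha + lam_alpha*log(n) - (m_alpha - 1)*log(log(n)),
    with posterior mass proportional to exp(-F_alpha). Requires n > e.
    """
    K = np.asarray(K, dtype=float)
    lam = np.asarray(lam, dtype=float)
    m = np.ones_like(lam) if m is None else np.asarray(m, dtype=float)
    F = n * K + lam * np.log(n) - (m - 1.0) * np.log(np.log(n))
    # Shift by the smallest free energy before exponentiating to avoid underflow.
    weights = np.exp(-(F - F.min()))
    return weights / weights.sum()


if __name__ == "__main__":
    n = 1000
    # Made-up inaccuracy of E_WA that grows with the deformation vartheta;
    # E_NonWA is taken to contain a true parameter (K = 0) throughout.
    for vartheta in (1.0, 1.5, 2.0, 2.5):
        K = [0.0, 0.01 * (vartheta - 1.0) ** 2]  # [E_NonWA, E_WA]
        p_nonwa, p_wa = phase_probabilities(n, K, lam=[1.0, 1.0])
        print(f"vartheta={vartheta}: P(NonWA)={p_nonwa:.4f}, P(WA)={p_wa:.2e}")
```

Once \(nK_{\text{WA}}(\vartheta)\) grows beyond a few units, the weight on \(\mathcal{E}_{\text{WA}}\) becomes exponentially small, which is the penalisation described above. With equal \(K_\alpha\), \(\lambda_\alpha\) and \(m_\alpha\) this approximation expresses no preference at all, so in that case any preference must come from the terms it drops, in line with the speculation about lower order geometry.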

References

[Callen] - H. Callen, Thermodynamics and an Introduction to Thermostatistics, 1991

[Gilmore] - R. Gilmore, Catastrophe Theory for Scientists and Engineers, 1981

[Wat18] - S. Watanabe, Mathematical Theory of Bayesian Statistics, 2018

[Carroll] - L. Carroll, Phase Transitions in Neural Networks, 2021

[Wei22] - S. Wei, D. Murfet, et al., Deep learning is singular, and that’s good, 2022


Footnotes

  1. At constant atmospheric pressure, that is.

  2. Yes, in any physics or chemistry textbook you will see the Gibbs free energy denoted by \(G\). I am writing \(F\) to keep it consistent with our later statistical learning discussion.

  3. At this point, this is a slight abuse of the physics notions. Typically the probability distribution is proportional to \(e^{-\beta H(w)}\), where \(\beta\) is the inverse temperature. In this case we are going to absorb the \(\beta\) into the \(H(w;\theta)\) term and not get too caught up in the actual physics; we’re just painting a conceptual picture to apply later on.

  4. Which often correspond to the moments (mean, variance, etc.) of quantities like \(H(w)\).

  5. More precisely, considering the tempered posterior at inverse temperature \(\beta > 0\), the Hamiltonian has the form \[H_n(w) = n \beta L_n(w) - \log \varphi(w)\,.\] (Since \(K_n(w) = L_n(w) - S_n\) and \(S_n\) is constant in \(w\), it is irrelevant.)

  6. Note here that a phase transition of a dynamical system (i.e. SGD, which we can imagine as a particle moving in a potential well) is a slightly more subtle concept. One imagines the loss landscape to be fixed, and the “phase transition” as corresponding to the particle moving from one particular phase in \(W\) to another. In this sense, there isn’t exactly a phase transition in the general sense, but there is a change in which phase the system finds itself in.

  7. Which altered the posterior geometry, but not that of \(K(w)\), since \(p(w|D_n) \approx e^{-nK(w)}\) (up to a normalisation factor).

  8. It is a little bit disingenuous to continue to call these phases when \(\delta\) is very close to 1, as the singularity \(w_1^{(0)}\) has a non-negligible effect on \(\mathcal{W}_2\), and vice versa, meaning the phases lose their individual identities. Alternatively, one could define \(\mathcal{W}_0\) to be centred on \(w_0^{(0)}=0\) and observe how the free energy changes with \(\delta\). But I have kept the two “phases” \(\mathcal{W}_1\) and \(\mathcal{W}_2\) in the animation below to illustrate the general idea with minimum fuss.

  9. You might wonder why we still endow the model with the \(q_i\) parameters in the first place if we just normalise them out after the fact. We decided it was more important to let the sampling procedure take place on an earnest neural network model without restricting its parameter space, thus keeping it in line with neural networks actually used in practice. But it is likely that these results would hold otherwise, too.

  10. The astute observer will notice that this is a white lie: the functional equivalence is true as long as each \(q_i \geq 0\). However, in our experiments the true outgoing weights are \(q_i=1\), meaning a good sample will only ever have positive weights, i.e. any sample with a negative \(q_i^{(k)}\) will be removed by the outlier validation.

  11. Explicitly, the truth is defined by \[f_2(x,\theta) = \mathrm{ReLU}\left( \cos(\theta) x_1 + \sin(\theta) x_2 - \frac{1}{3} \right) + \mathrm{ReLU}\left(-\cos(\theta) x_1 + \sin(\theta) x_2 - \frac{1}{3} \right)\,.\]

  12. Relatedly, the plot of the KL divergence in Example 3.3 tells us to expect that the degenerate phase may be preferred.

  13. It is worth briefly mentioning the effect of the prior here. The free energy formula tells us that as \(n \to \infty\), the effects of the prior on learning become negligible. But of course, we are only ever in the finite \(n\) regime, at which point the prior does have effects on the posterior. In our case, since the prior is a Gaussian centred at \(w=(0,0)\) with standard deviation \(1\), it is reasonable to say that it has some bearing on the degenerate phase being preferred. However, further experiments showed that this behaviour is still retained for a flatter prior with increased standard deviation. The catch is that the Markov chains can become very unstable on these priors, producing posterior samples with very high loss, which indicates that the chains aren’t converging to the correct long-term distribution. In the interest of time, I decided not to continue fine-tuning the experiments on non-converging chains for a flatter prior, but it would be interesting to see to what extent the prior does affect these results.

  14. In other words, the degenerate phase requires a truth with five parameters, \[q_1^{(0)}\mathrm{ReLU}(w^{(0)}_{1,1}x_1 + w^{(0)}_{1,2}x_2 + b_1^{(0)}) + c\,,\] whereas the non-degenerate phase requires nine, \[q_1^{(0)}\mathrm{ReLU}(w^{(0)}_{1,1}x_1 + w^{(0)}_{1,2}x_2 + b_1^{(0)}) + q_2^{(0)}\mathrm{ReLU}(w^{(0)}_{2,1}x_1 + w^{(0)}_{2,2}x_2 + b_2^{(0)}) + c\,.\]

  15. @Leon Lang correctly pointed out that this is slightly weird terminology. These phases should really be called weight-cancellation rather than weight-annihilation phases, since both initial configurations obey the weight-annihilation property as I defined it, whereas what I am really referring to is the fact that in one configuration all weights are active and cancel in a region. It’s too late to change the terminology throughout, but do keep this in mind.
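Footnotes 9 and 10 rely on the positive homogeneity of the ReLU: for \(q \geq 0\), \(q\,\mathrm{ReLU}(z) = \mathrm{ReLU}(qz)\), so a node's outgoing weight can be absorbed into its incoming weights and bias without changing the function, whereas for \(q < 0\) this fails. Below is a minimal NumPy check of that identity, assuming a single two-input node \(q\,\mathrm{ReLU}(w \cdot x + b)\). The names `node` and `absorbed_node` and the parameter values are hypothetical, and this is not necessarily the exact normalisation used in the experiments.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def node(x, q, w, b):
    """One hidden node with outgoing weight q: q * ReLU(<w, x> + b) for each row of x."""
    return q * relu(x @ w + b)


def absorbed_node(x, q, w, b):
    """The same node with q absorbed into the incoming weights and bias: ReLU(<q w, x> + q b)."""
    return relu(x @ (q * w) + q * b)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(500, 2))      # arbitrary test inputs in R^2
    w, b = np.array([0.7, -1.2]), 0.3  # arbitrary illustrative node parameters
    for q in (2.5, 0.0, -1.0):
        agree = np.allclose(node(x, q, w, b), absorbed_node(x, q, w, b))
        print(f"q={q}: functions agree = {agree}")
```

Running it should report agreement for \(q = 2.5\) and \(q = 0\) and disagreement for \(q = -1\), since \(-\mathrm{ReLU}(z) \leq 0 \leq \mathrm{ReLU}(-z)\) with equality only at \(z = 0\).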
