devinterp Project
Toy Models of LayerNorm
LayerNorm can have a large impact on learning dynamics. Can we characterize this in a simple toy model?
Background
Take the following toy model, which displays striking dynamical phase transitions during training, and analyze it using singular learning theory (SLT). Reach out to Jesse Hoogland for more context.
import math
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm


class ToyLNModel(nn.Module):
    """Two-layer linear bottleneck with a LayerNorm between the layers."""

    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.linear1 = nn.Linear(n_features, n_hidden)
        self.ln = nn.LayerNorm(n_hidden)
        self.linear2 = nn.Linear(n_hidden, n_features, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.ln(x)
        x = self.linear2(x)
        return x


num_steps = 100000
num_inputs = 4
num_hidden = 64
indices = [0]   # reconstruct only these input features (empty list = reconstruct all)
log_ivl = 50    # log the loss and a checkpoint every log_ivl steps
steps = list(range(num_steps))[::log_ivl]

model = ToyLNModel(num_inputs, num_hidden)
optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, weight_decay=0)

losses = []
weights = []

for step in tqdm.trange(num_steps):
    model.zero_grad()
    x = torch.randn(128, num_inputs) * 20   # fresh batch of scaled Gaussian inputs
    y = model(x)
    if indices:
        # reconstruction loss on the selected feature(s) only
        loss = F.mse_loss(y[:, indices], x[:, indices])
    else:
        loss = F.mse_loss(y, x)
    loss.backward()
    optimizer.step()
    if step % log_ivl == 0:
        losses.append(loss.item())
        weights.append(deepcopy(model.state_dict()))
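One natural way to connect this trajectory to SLT is to estimate the local learning coefficient (LLC) at each logged checkpoint and look for jumps that line up with the phase transitions in the loss curve. The sketch below is a minimal, from-scratch SGLD estimator in the style of Lau et al. (2023); the helper names (estimate_llc, make_batch) and all hyperparameters (chains, draws, step size eps, localization strength gamma, inverse temperature beta) are illustrative assumptions rather than part of the project code, and in practice you would likely use the devinterp library's samplers instead.

def make_batch(num_inputs, batch_size=128):
    # Hypothetical helper mirroring the training distribution used above.
    return torch.randn(batch_size, num_inputs) * 20


def estimate_llc(model, num_inputs, indices,
                 num_chains=4, num_draws=200, num_burnin=200,
                 eps=1e-4, gamma=100.0, beta=None, batch_size=128):
    """Rough SGLD-based estimate of the local learning coefficient.

    Sketch only: uses the nbeta * (E_w[L(w)] - L(w*)) estimator from
    Lau et al. (2023), with the batch size standing in for the dataset size
    since this toy setup samples data online.
    """
    if beta is None:
        beta = 1.0 / math.log(batch_size)   # crude default inverse temperature

    def batch_loss(m):
        x = make_batch(num_inputs, batch_size)
        y = m(x)
        if indices:
            return F.mse_loss(y[:, indices], x[:, indices])
        return F.mse_loss(y, x)

    with torch.no_grad():
        init_loss = batch_loss(model).item()   # loss at the checkpoint w*

    chain_means = []
    for _ in range(num_chains):
        sampler_model = deepcopy(model)
        params = list(sampler_model.parameters())
        anchors = [p.detach().clone() for p in params]
        draws = []
        for t in range(num_burnin + num_draws):
            sampler_model.zero_grad()
            loss = batch_loss(sampler_model)
            loss.backward()
            with torch.no_grad():
                for p, a in zip(params, anchors):
                    # SGLD step: tempered gradient + localization toward w* + Gaussian noise
                    drift = beta * batch_size * p.grad + gamma * (p - a)
                    p.add_(-0.5 * eps * drift + torch.randn_like(p) * (eps ** 0.5))
            if t >= num_burnin:
                draws.append(loss.item())
        chain_means.append(sum(draws) / len(draws))

    expected_loss = sum(chain_means) / len(chain_means)
    return batch_size * beta * (expected_loss - init_loss)


# Example usage (assumed workflow): track the LLC across the saved checkpoints.
llcs = []
for w in weights:
    ckpt = ToyLNModel(num_inputs, num_hidden)
    ckpt.load_state_dict(w)
    llcs.append(estimate_llc(ckpt, num_inputs, indices))

Plotting these LLC estimates against the logged losses is one way to check whether the dynamical transitions in this model coincide with changes in effective model complexity.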
Where to Begin
Before starting this project, we recommend familiarizing yourself with these resources:
1. Quantifying degeneracy (Lau et al., 2023)
2. Toy Models of Superposition
3. Dynamical versus Bayesian Phase Transitions in a Toy Model of Superposition
Ready to contribute? Let us know in our Discord community. We'll update this listing so that other people interested in this project can find you.