Toy Models of LayerNorm
LayerNorm can have a large impact on learning dynamics. Can we characterize this in a simple toy model?
Type: Applied
Difficulty: Medium
Status: Unstarted
Background
Take the following toy model, which displays striking dynamical phase transitions, and analyze it using SLT (singular learning theory). Reach out to Jesse Hoogland for more context.
```python
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm


class ToyLNModel(nn.Module):
    """Two-layer linear map with a LayerNorm between the layers."""

    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.linear1 = nn.Linear(n_features, n_hidden)
        self.ln = nn.LayerNorm(n_hidden)
        self.linear2 = nn.Linear(n_hidden, n_features, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.ln(x)
        x = self.linear2(x)
        return x


num_steps = 100000
num_inputs = 4
num_hidden = 64
indices = [0]   # if non-empty, only these input features are reconstructed
log_ivl = 50    # log the loss and a checkpoint every log_ivl steps
steps = list(range(num_steps))[::log_ivl]

model = ToyLNModel(num_inputs, num_hidden)
optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, weight_decay=0)

losses = []
weights = []
for step in tqdm.trange(num_steps):
    model.zero_grad()
    x = torch.randn(128, num_inputs) * 20   # fresh batch of Gaussian inputs each step
    y = model(x)
    if indices:
        loss = F.mse_loss(y[:, indices], x[:, indices])
    else:
        loss = F.mse_loss(y, x)
    loss.backward()
    optimizer.step()
    if step % log_ivl == 0:
        losses.append(loss.item())
        weights.append(deepcopy(model.state_dict()))
```
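If the model shows the dynamical phase transitions mentioned above, they should appear as plateaus and sudden drops in the training loss. A quick sanity check (assuming matplotlib is available; the plot details are illustrative, not part of the original setup):

```python
# Minimal sketch: plot the logged loss curve to look for plateaus and drops.
import matplotlib.pyplot as plt

plt.semilogy(steps, losses)
plt.xlabel("SGD step")
plt.ylabel("MSE loss")
plt.show()
```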
Where to begin:
- Quantifying degeneracy (Lau et al., 2023); a minimal from-scratch version of their learning coefficient estimator is sketched after this list
- Toy Models of Superposition
- Dynamical versus Bayesian Phase Transitions in a Toy Model of Superposition
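One concrete starting point is to track the local learning coefficient (LLC) across the saved checkpoints. The sketch below is a minimal, from-scratch SGLD implementation of the estimator in Lau et al. (2023), not a prescribed method for this project: the helper name `estimate_llc` and all hyperparameters (the nominal sample size `n`, step size, localization strength, chain length) are illustrative assumptions that will need tuning. It reuses `indices`, `weights`, and the data distribution from the training code above.

```python
import math


def estimate_llc(model, checkpoint, sample_batch,
                 n=10_000, num_draws=500, burn_in=200,
                 epsilon=1e-4, gamma=100.0):
    """Hypothetical helper: estimate the local learning coefficient at `checkpoint`
    by sampling a localized, tempered posterior with SGLD (Lau et al., 2023).

    lambda_hat = n * beta * (E_posterior[L(w)] - L(w*)),  with beta = 1 / log(n).
    """
    beta = 1.0 / math.log(n)
    model.load_state_dict(deepcopy(checkpoint))
    w_star = [p.detach().clone() for p in model.parameters()]

    def batch_loss():
        x = sample_batch()
        y = model(x)
        if indices:  # same masking convention as the training loop above
            return F.mse_loss(y[:, indices], x[:, indices])
        return F.mse_loss(y, x)

    with torch.no_grad():
        # Noisy one-batch estimate of L(w*); average over more batches in practice.
        loss_star = batch_loss().item()

    draws = []
    for t in range(burn_in + num_draws):
        model.zero_grad()
        loss = batch_loss()
        loss.backward()
        with torch.no_grad():
            for p, p0 in zip(model.parameters(), w_star):
                # Gradient of n*beta*L(w) + (gamma/2)*||w - w*||^2, then one SGLD step.
                drift = n * beta * p.grad + gamma * (p - p0)
                p.add_(-0.5 * epsilon * drift + math.sqrt(epsilon) * torch.randn_like(p))
        if t >= burn_in:
            draws.append(loss.item())

    return n * beta * (sum(draws) / len(draws) - loss_star)


# Example sweep over the logged checkpoints (subsample `weights` for speed).
def sample_batch():
    return torch.randn(128, num_inputs) * 20

llc_curve = [
    estimate_llc(ToyLNModel(num_inputs, num_hidden), ckpt, sample_batch)
    for ckpt in weights
]
```

Changes in this LLC curve around the loss plateaus would be the kind of SLT signature that the references above look for when comparing dynamical and Bayesian phase transitions.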
If you have decided to start working on this, please let us know in the Discord. We'll update this listing so that other people who are interested in this project can find you.