devinterp Project

Toy Models of LayerNorm

LayerNorm can have a large impact on learning dynamics. Can we characterize this in a simple toy model?

Project Details

Status: Unstarted
Difficulty: Medium
Type: Applied

Tags

devinterp

Background

Take the following toy model, which displays striking dynamical phase transitions, and analyze it using singular learning theory (SLT). Reach out to Jesse Hoogland for more context.

from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

class ToyLNModel(nn.Module):
    """Linear encoder, LayerNorm on the hidden activations, linear decoder."""

    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.linear1 = nn.Linear(n_features, n_hidden)
        self.ln = nn.LayerNorm(n_hidden)
        self.linear2 = nn.Linear(n_hidden, n_features, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.ln(x)
        x = self.linear2(x)
        return x

num_steps = 100000  # total SGD steps
num_inputs = 4      # input/output dimension of the model
num_hidden = 64     # hidden (LayerNorm) dimension
indices = [0]       # the loss is computed only on these coordinates; empty means all
log_ivl = 50        # log losses/weights every log_ivl steps

steps = list(range(0, num_steps, log_ivl))  # steps at which logging happens

model = ToyLNModel(num_inputs, num_hidden)
optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, weight_decay=0)

losses = []   # loss logged every log_ivl steps
weights = []  # state_dict snapshots logged every log_ivl steps

for step in tqdm.trange(num_steps):
    model.zero_grad()

    # Fresh batch of isotropic Gaussian inputs with large scale (std 20)
    x = torch.randn(128, num_inputs) * 20
    y = model(x)

    # Autoencoding loss, optionally restricted to a subset of coordinates
    if indices:
        loss = F.mse_loss(y[:, indices], x[:, indices])
    else:
        loss = F.mse_loss(y, x)
    loss.backward()

    optimizer.step()

    # Log the loss and a full copy of the weights every log_ivl steps
    if step % log_ivl == 0:
        losses.append(loss.item())
        weights.append(deepcopy(model.state_dict()))
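
To visualize the transitions, plot the logged losses against the logged steps. A minimal sketch using matplotlib, assuming the steps and losses lists from the script above:

import matplotlib.pyplot as plt

plt.plot(steps, losses)
plt.yscale("log")  # sudden drops are easier to spot on a log scale
plt.xlabel("training step")
plt.ylabel("MSE loss")
plt.show()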

[Figure: example of a sudden loss drop during training]
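
On the SLT side, a natural first experiment is to estimate the local learning coefficient (LLC) at each saved checkpoint and check whether it changes across the loss drop. The sketch below is a minimal hand-rolled SGLD-based estimator, not the devinterp library's API; the hyperparameters (eps, gamma, nbeta, num_draws) are illustrative placeholders and will need tuning for this model.

def estimate_llc(model, checkpoint, num_draws=500, eps=1e-4,
                 gamma=100.0, nbeta=30.0, batch_size=128):
    """Crude SGLD-based LLC estimate at a checkpoint:
    nbeta * (mean loss along the chain - loss at the checkpoint)."""
    model.load_state_dict(checkpoint)
    anchor = [p.detach().clone() for p in model.parameters()]

    def batch_loss():
        x = torch.randn(batch_size, num_inputs) * 20
        y = model(x)
        if indices:
            return F.mse_loss(y[:, indices], x[:, indices])
        return F.mse_loss(y, x)

    # Average a few batches to reduce noise in the baseline loss
    with torch.no_grad():
        init_loss = sum(batch_loss().item() for _ in range(10)) / 10

    chain_losses = []
    for _ in range(num_draws):
        model.zero_grad()
        loss = batch_loss()
        loss.backward()
        with torch.no_grad():
            for p, p0 in zip(model.parameters(), anchor):
                # SGLD: gradient drift + localization toward the checkpoint + noise
                p -= (eps / 2) * (nbeta * p.grad + gamma * (p - p0))
                p += eps ** 0.5 * torch.randn_like(p)
        chain_losses.append(loss.item())

    return nbeta * (sum(chain_losses) / len(chain_losses) - init_loss)

llcs = [estimate_llc(model, w) for w in weights[::40]]  # subsample checkpoints to keep this cheap

Plotting llcs against steps[::40] next to the loss curve shows whether the drop coincides with a change in the estimated LLC.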

Where to Begin

Before starting this project, we recommend familiarizing yourself with these resources:

Ready to contribute? Let us know in our Discord community. We'll update this listing so that other people interested in this project can find you.