Toy Models of LayerNorm

LayerNorm can have a large impact on learning dynamics. Can we characterize this in a simple toy model?

Type: Applied
Difficulty: Medium
Status: Unstarted

Background

Take the following toy model, which displays striking dynamical phase transitions, and analyze it using singular learning theory (SLT). Reach out to Jesse Hoogland for more context.

from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

class ToyLNModel(nn.Module):
    """Two linear maps with a LayerNorm bottleneck in between."""
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.linear1 = nn.Linear(n_features, n_hidden)
        self.ln = nn.LayerNorm(n_hidden)
        self.linear2 = nn.Linear(n_hidden, n_features, bias=False)

    def forward(self, x):
        x = self.linear1(x)  # embed inputs into the hidden space
        x = self.ln(x)       # normalize each hidden vector (with learned scale/shift)
        x = self.linear2(x)  # project back down to the input space
        return x
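
For reference, nn.LayerNorm(n_hidden) normalizes each hidden vector to zero mean and unit variance across the hidden dimension, then applies a learned elementwise scale and shift. A minimal manual equivalent, matching PyTorch's defaults (eps=1e-5, biased variance):

def manual_layernorm(h, gamma, beta, eps=1e-5):
    # Per-row normalization across the last (hidden) dimension,
    # followed by the learned elementwise affine map.
    mu = h.mean(dim=-1, keepdim=True)
    var = h.var(dim=-1, keepdim=True, unbiased=False)
    return gamma * (h - mu) / torch.sqrt(var + eps) + beta

# Sanity check against the module (gamma = ln.weight, beta = ln.bias):
# torch.allclose(manual_layernorm(h, model.ln.weight, model.ln.bias), model.ln(h))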

num_steps = 100000  # total SGD steps
num_inputs = 4      # input/output dimension
num_hidden = 64     # width of the LayerNorm bottleneck
indices = [0]       # reconstruct only feature 0 ([] reconstructs all features)
log_ivl = 50        # logging interval, in steps

steps = list(range(num_steps))[::log_ivl]  # step indices at which we log

model = ToyLNModel(num_inputs, num_hidden)
optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, weight_decay=0)

losses = []
weights = []

for step in tqdm.trange(num_steps):
    model.zero_grad()

    # Draw a fresh batch of large-magnitude Gaussian inputs each step
    x = torch.randn(128, num_inputs) * 20
    y = model(x)

    # Reconstruction (autoencoder-style) loss, restricted to the chosen features
    if indices:
        loss = F.mse_loss(y[:, indices], x[:, indices])
    else:
        loss = F.mse_loss(y, x)
    loss.backward()

    optimizer.step()

    # Periodically log the loss and snapshot the weights for later analysis
    if step % log_ivl == 0:
        losses.append(loss.item())
        weights.append(deepcopy(model.state_dict()))
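
To see the drops in the figure below, plot the logged losses against the logged step indices. A minimal sketch, assuming matplotlib is available (the log scale is a choice that makes sudden drops easier to spot):

import matplotlib.pyplot as plt

# `steps` and `losses` line up one-to-one: both are sampled every log_ivl steps.
plt.plot(steps, losses)
plt.xlabel("SGD step")
plt.ylabel("MSE loss")
plt.yscale("log")
plt.title("Training loss for the toy LayerNorm model")
plt.show()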

[Figure: example of a sudden drop in the training loss]

Where to begin

If you have decided to start working on this, please let us know in the Discord. We'll update this listing so that other people who are interested in this project can find you.
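
As a concrete first step, one could estimate the local learning coefficient (LLC) at each saved checkpoint and check whether the loss drops line up with jumps in the LLC. The sketch below is a crude SGLD-based estimator in plain PyTorch, not a prescribed method: the localization strength gamma, step size eps, chain counts, and the nominal sample size n (notional here, since batches are drawn online) are all illustrative and would need tuning. The devinterp library provides a maintained implementation of this kind of estimator.

import math
import torch

def estimate_llc(make_model, checkpoint, sample_batch, compute_loss,
                 n_chains=4, n_steps=500, eps=1e-5, gamma=100.0, n=10000):
    # Estimate lambda_hat = n * beta * (E_w[L(w)] - L(w*)) with beta = 1/log(n),
    # sampling w by SGLD from a posterior localized at w* (the checkpoint).
    beta = 1.0 / math.log(n)

    # Loss at the checkpoint itself
    ref_model = make_model()
    ref_model.load_state_dict(checkpoint)
    with torch.no_grad():
        ref_loss = compute_loss(ref_model, sample_batch()).item()

    chain_means = []
    for _ in range(n_chains):
        model = make_model()
        model.load_state_dict(checkpoint)
        anchor = [p.detach().clone() for p in model.parameters()]
        chain_losses = []
        for _ in range(n_steps):
            model.zero_grad()
            loss = compute_loss(model, sample_batch())
            loss.backward()
            with torch.no_grad():
                for p, a in zip(model.parameters(), anchor):
                    # SGLD step: gradient of the localized, tempered log posterior,
                    # plus Gaussian noise with variance eps
                    drift = n * beta * p.grad + gamma * (p - a)
                    p -= (eps / 2) * drift
                    p += math.sqrt(eps) * torch.randn_like(p)
            chain_losses.append(loss.item())
        # Discard the first half of each chain as burn-in
        burned_in = chain_losses[n_steps // 2:]
        chain_means.append(sum(burned_in) / len(burned_in))

    avg_loss = sum(chain_means) / n_chains
    return n * beta * (avg_loss - ref_loss)

# Hypothetical usage with the toy model and logged checkpoints above:
# llc = estimate_llc(
#     make_model=lambda: ToyLNModel(num_inputs, num_hidden),
#     checkpoint=weights[-1],
#     sample_batch=lambda: torch.randn(128, num_inputs) * 20,
#     compute_loss=lambda m, x: F.mse_loss(m(x)[:, indices], x[:, indices]),
# )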