I trained a small transformer model and wanted to understand how it produced its results. Many papers and tutorials focus on multi-head self-attention, but they don't explain what happens after attention or how it leads to accurate predictions. After a six-month investigation, I have a working theory: each transformer block learns weights that associate a prompt with a class of strings from the training data, and the tokens that follow those strings in the training data become the block's predictions for the next token. I implemented code to approximate the transformer's output using these associations and found that it closely matches the actual output, so this approach may provide a reasonable approximation of what the transformer is doing.
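To make the idea concrete, here is a minimal sketch of that association-based approximation under some simplifying assumptions: training-data contexts are embedded, the prompt is matched against them by cosine similarity, and the tokens that followed the nearest contexts become the prediction. Every name and parameter here (embed_context, build_index, K_NEIGHBORS, the toy corpus) is a hypothetical stand-in, not the author's actual implementation, and the bag-of-tokens embedding is a placeholder for whatever representation the transformer block actually produces.

```python
# Sketch: approximate next-token prediction by associating a prompt with
# similar training-data contexts and tallying the tokens that followed them.
# This is an illustrative toy, not the implementation from the linked post.
from collections import Counter

import numpy as np

K_NEIGHBORS = 5      # how many nearest training contexts to keep (assumed)
CONTEXT_LEN = 4      # context window length, in tokens (assumed)
RNG = np.random.default_rng(0)


def embed_context(tokens: list[int], vocab_size: int) -> np.ndarray:
    """Stand-in for a block's representation of a context.

    Here it is just a unit-normalized bag-of-tokens vector; in the real
    setting this would be the block's learned embedding of the prompt.
    """
    v = np.zeros(vocab_size)
    for t in tokens:
        v[t] += 1.0
    return v / (np.linalg.norm(v) + 1e-8)


def build_index(corpus: list[int], vocab_size: int):
    """Embed every CONTEXT_LEN-token window of the training corpus and
    record the token that follows it."""
    contexts, next_tokens = [], []
    for i in range(len(corpus) - CONTEXT_LEN):
        window = corpus[i : i + CONTEXT_LEN]
        contexts.append(embed_context(window, vocab_size))
        next_tokens.append(corpus[i + CONTEXT_LEN])
    return np.stack(contexts), next_tokens


def approximate_next_token(prompt: list[int], contexts, next_tokens, vocab_size: int):
    """Predict the next token as the most common successor among the
    training contexts closest (by cosine similarity) to the prompt."""
    q = embed_context(prompt[-CONTEXT_LEN:], vocab_size)
    sims = contexts @ q                          # rows are unit-norm, so this is cosine similarity
    nearest = np.argsort(sims)[-K_NEIGHBORS:]    # indices of the K most similar contexts
    counts = Counter(next_tokens[i] for i in nearest)
    return counts.most_common(1)[0][0], counts


if __name__ == "__main__":
    vocab_size = 10
    corpus = list(RNG.integers(0, vocab_size, size=500))  # toy "training data"
    contexts, next_tokens = build_index(corpus, vocab_size)
    prompt = corpus[100 : 100 + CONTEXT_LEN]               # reuse a training window as the prompt
    pred, counts = approximate_next_token(prompt, contexts, next_tokens, vocab_size)
    print(f"prompt={prompt} predicted next token={pred} neighbor counts={counts}")
```

In the post's setting, the context representation would presumably come from the transformer block's learned weights rather than this bag-of-tokens placeholder; the point of the sketch is only the lookup-and-tally structure of the approximation.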
https://shyam.blog/posts/beyond-self-attention/