ESM3: A simplified primer to the model's architecture
I found myself bedridden with a nasty cold (painful joints, fatigued body - the whole nine yards), so I decided this was a good time to summarize and simplify a research paper I read recently.
Context
A few weeks ago I attended a "Hidden Layers" event at FractalTech, a hub for founders and engineers in Brooklyn (check them out...seriously). Hidden Layers happens on Wednesdays: someone suggests a research paper in the AI space, you read it in silence alongside a group of engineers, ML researchers, ML engineers and so on, and then you discuss it together.
ESM3
The paper I am writing about is ESM3, a protein language model trained on over 2.7 billion protein sequences along with their structures and functions. The model simulates the evolutionary process to generate new, viable proteins.
I took a lot of notes to better understand how the model works and why it's necessary, and I collated those notes into this blog post.
Architecture
The underlying model of ESM3 is a transformer with an encoder and a decoder. The model starts off by converting its inputs (amino acid sequences, structure coordinates and so on) into dense vectors/tensors that capture the complexity of that data.
Embeddings
For the embedding layer, there are 7 unique input tracks to ESM3: a.) sequence (amino acid tokens), b.) structure coordinates, c.) structure tokens, d.) 8-class secondary structure labels (SS8), e.) quantized solvent-accessible surface area (SASA) values, f.) function keyword tokens and g.) residue (InterPro) annotation binary features.
Think of embeddings as numerical representations of data that machines understand.
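To make that concrete, here is a minimal sketch (in PyTorch, and not ESM3's actual code) of the general idea: each input track gets its own embedding table, and the per-track embeddings are fused into a single vector per position. All vocabulary sizes and dimensions below are made up for illustration.

```python
import torch
import torch.nn as nn

d_model = 64  # hypothetical embedding width

# One embedding table per discrete track (sizes are made up, not ESM3's real vocab sizes)
seq_embed = nn.Embedding(num_embeddings=32, embedding_dim=d_model)        # amino-acid tokens
struct_embed = nn.Embedding(num_embeddings=4096, embedding_dim=d_model)   # structure tokens
ss8_embed = nn.Embedding(num_embeddings=16, embedding_dim=d_model)        # secondary-structure (SS8) tokens

L = 10  # length of a toy protein
seq_tokens = torch.randint(0, 32, (1, L))
struct_tokens = torch.randint(0, 4096, (1, L))
ss8_tokens = torch.randint(0, 16, (1, L))

# Fuse the tracks by summing their embeddings position by position
x = seq_embed(seq_tokens) + struct_embed(struct_tokens) + ss8_embed(ss8_tokens)
print(x.shape)  # torch.Size([1, 10, 64])
```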
Positional encoding layer
Typically, this is the layer where a function is applied to the embeddings so the model can understand the order of the input tokens. For example, if this were text, you would need to know the order in which the words occur to predict what comes next. Unlike text, however, amino acids are also represented by 3D coordinates, meaning that on top of tracking position you have to track how far apart the amino acids are and at what angles they sit relative to each other in space. This matters because a protein's shape is crucial to its function.
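As a toy illustration (random numbers, nothing ESM3-specific): with text you only need a 1D position index, but with 3D coordinates you can compute pairwise distances, which is the kind of spatial information the model actually cares about.

```python
import torch

L = 5
coords = torch.randn(L, 3)  # stand-in (x, y, z) coordinates, one per amino acid

# Pairwise Euclidean distances: how far apart every pair of residues sits in space
dists = torch.cdist(coords, coords)  # shape (L, L)
print(dists)
```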
Geometric Attention Layer
This is where it gets interesting: ESM3 uses an "invariant geometric attention mechanism" to efficiently process the 3D structure of proteins. Standard multi-head attention applies several self-attention mechanisms in parallel; for each token, a weighted sum over the other tokens is computed to represent how they relate to one another. This is done to work out how much attention to pay to each token in the sequence when making a prediction (i.e. what comes after the word "it" depends on the context of the sentence it appears in). The multiple heads run this in parallel, each capturing different and complex relationships in the data.
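For reference, here is the textbook scaled dot-product attention that the paragraph above describes (this is the generic mechanism, not ESM3's geometric variant):

```python
import torch
import torch.nn.functional as F

def attention(q, k, v):
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5  # similarity between every pair of tokens
    weights = F.softmax(scores, dim=-1)            # how much attention each token pays to the others
    return weights @ v                             # weighted sum of value vectors

L, d = 6, 16
q, k, v = torch.randn(1, L, d), torch.randn(1, L, d), torch.randn(1, L, d)
out = attention(q, k, v)
print(out.shape)  # torch.Size([1, 6, 16])
```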
In ESM3, this is adapted to capture the 3D spatial relationships between the amino acids in the protein. This is necessary because the model needs to 'pay attention' to the distances and angles between amino acids, not just their proximity in the sequence.
The geometric information also has to be invariant to rotations, because proteins are 3D structures that can be oriented in any direction in space. A protein's structure and function are defined by the relative positions and interactions of its amino acids, not by their absolute positions or orientations. So a model that aims to generate protein structures needs to interpret them regardless of how they are positioned or oriented in 3D space. This also makes the model more robust: it produces consistent results without redundancy (i.e. imagine how much redundant learning would be needed if the model had to see every possible orientation and position).
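A quick way to see the invariance idea (just a toy check, not from the paper): rotate a set of coordinates arbitrarily and the pairwise distances, i.e. the relative geometry, stay exactly the same.

```python
import torch

coords = torch.randn(8, 3)

# Build a random orthogonal (rotation-like) matrix via QR decomposition
Q, _ = torch.linalg.qr(torch.randn(3, 3))
rotated = coords @ Q.T

# The relative geometry is unchanged by the rotation
print(torch.allclose(torch.cdist(coords, coords),
                     torch.cdist(rotated, rotated), atol=1e-5))  # True
```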
Local reference frames are defined by the bond geometry at each amino acid, and local frames interact globally through a transformation into the global frame.
Simplified: for each amino acid in the protein, a local reference frame is constructed that captures its geometric context. Each local frame is then encoded into tokens that represent the geometric relationships in a rotation- and translation-invariant manner. These tokens are then transformed to a global frame where interactions between the amino acids can be analyzed, and this global interaction is what helps the model understand the overall 3D shape of the protein segment. The attention mechanism then uses these encoded tokens to work out how much influence each amino acid has on the others, given their geometric relationships.
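Here is a heavily simplified sketch of the local-frame idea. This is my own illustration under a bunch of assumptions (made-up shapes, a simple distance-based attention bias) and not ESM3's actual geometric attention: build an orthonormal frame per residue from its backbone atoms, express neighbour displacements in that frame so they become rotation- and translation-invariant, and feed that into an attention score.

```python
import torch

def local_frame(n, ca, c):
    """Orthonormal frame (3x3) built from one residue's backbone atoms N, CA, C."""
    v1 = c - ca
    v2 = n - ca
    e1 = v1 / v1.norm()
    u2 = v2 - (v2 @ e1) * e1          # remove the component along e1
    e2 = u2 / u2.norm()
    e3 = torch.linalg.cross(e1, e2)   # right-handed third axis
    return torch.stack([e1, e2, e3], dim=0)

L = 4
# Random stand-in backbone coordinates: (residue, atom {N, CA, C}, xyz)
backbone = torch.randn(L, 3, 3)
frames = torch.stack([local_frame(*backbone[i]) for i in range(L)])  # (L, 3, 3)
ca = backbone[:, 1]                                                  # (L, 3) alpha carbons

# Displacement of every residue j as seen from residue i's local frame
rel = ca.unsqueeze(0) - ca.unsqueeze(1)                # (L, L, 3) global-frame displacements
rel_local = torch.einsum('iab,ijb->ija', frames, rel)  # rotate into each residue's frame

# These invariant features could then feed an attention score; here, a made-up
# distance-based bias so that closer residues get more weight
bias = -rel_local.norm(dim=-1)        # (L, L)
weights = torch.softmax(bias, dim=-1)
print(weights.shape)                  # torch.Size([4, 4])
```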
Decoder
While the encoder independently processes all local structures in parallel, the decoder attends over the entire set of L tokens to reconstruct the full structure. It is composed of a stack of bidirectional Transformer blocks with regular self-attention.
The decoder is customized to process 3D geometric information, with structure-specific refinement ensuring the final output adheres to known physical and biological constraints. In simple terms, the decoder starts off by making an initial guess at the 3D coordinates of the amino acids. The geometric attention layer ensures each amino acid attends within its local frame, and cross-attention integrates information from the encoder to refine the predictions. The geometric tokens are then converted into 3D coordinates, predicting the position of each amino acid, and these predictions are adjusted iteratively, refining the structure until it converges to a stable and accurate configuration.
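And here is a very rough sketch of the iterative-refinement loop described above. Again, this is illustrative only; the module sizes, the number of refinement steps and the use of nn.MultiheadAttention are my own stand-ins, not the paper's decoder.

```python
import torch
import torch.nn as nn

d_model, L = 64, 10

cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=4, batch_first=True)
coord_in = nn.Linear(3, d_model)    # lift the current coordinate guess into model space
coord_out = nn.Linear(d_model, 3)   # predict a per-residue coordinate correction

encoder_tokens = torch.randn(1, L, d_model)  # stand-in for the encoder's output
coords = torch.zeros(1, L, 3)                # initial guess: every residue at the origin

for _ in range(4):                           # a few refinement iterations
    h = coord_in(coords)
    h, _ = cross_attn(query=h, key=encoder_tokens, value=encoder_tokens)
    coords = coords + coord_out(h)           # nudge the structure toward consistency

print(coords.shape)  # torch.Size([1, 10, 3])
```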
This is a very simplified breakdown of the parts of the paper related to the architecture; here is the link to the full paper.