ProteinModelWrapper
Overview
ProteinModelWrapper is the base class for all protein prediction models in AIDE. It inherits from scikit-learn’s BaseEstimator and TransformerMixin, providing a familiar API while adding protein-specific validation and functionality. The framework uses a powerful mixin architecture to define model requirements and capabilities.
Core Behavior
The wrapper handles all protein-specific validation through its public methods:
# Public methods do all validation
def fit(self, X, y=None):
    # Validates sequences
    # Checks requirements based on mixins
    # Runs pre-fit hooks from mixins
    # Then calls _fit()
    # Runs post-fit hooks from mixins
    ...

def transform(self, X):
    # Validates sequences
    # Checks requirements based on mixins
    # Runs pre-transform hooks from mixins
    # Then calls _transform()
    # Runs post-transform hooks from mixins
    ...
You never need to call the private methods directly - they implement just the core model logic while the public methods handle validation, hooks, and other processing.
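The split between public and private methods can be illustrated with a minimal sketch. ToyWrapper, its _validate helper and LengthModel are hypothetical stand-ins written for this example, not AIDE's actual classes; the real base class also checks mixin requirements and runs hooks.

```python
class ToyWrapper:
    """Hypothetical stand-in for the base class; not AIDE's actual code."""

    VALID = set("ACDEFGHIKLMNPQRSTVWY")

    def _validate(self, X):
        # Reject empty sequences and any character outside the 20 amino acids.
        for seq in X:
            if not seq or not set(seq) <= self.VALID:
                raise ValueError(f"invalid protein sequence: {seq!r}")
        return list(X)

    def fit(self, X, y=None):
        X = self._validate(X)   # the public method validates first
        self._fit(X, y)         # then delegates to the core logic
        return self

    def transform(self, X):
        X = self._validate(X)
        return self._transform(X)

    def _fit(self, X, y=None):
        raise NotImplementedError

    def _transform(self, X):
        raise NotImplementedError


class LengthModel(ToyWrapper):
    """Subclass that implements only the private methods."""

    def _fit(self, X, y=None):
        self.fitted_ = True

    def _transform(self, X):
        return [len(seq) for seq in X]


model = LengthModel().fit(["MKT", "MKTA"])
lengths = model.transform(["MKV"])
```

Because validation lives in the public methods, LengthModel never has to repeat it.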
Key Attributes
Data Paths and References
metadata_folder: Directory for model files (weights, checkpoints, etc.), randomly generated if not given. This gives each model a dedicated place to store necessary files, for example when a FASTA file needs to be passed to an external program. Some models do not use it. If the model supports caching (see CacheMixin), the cache is stored here. Currently this location may break when saving and loading models across machines.
wt: Optional wild-type sequence for comparative predictions
Automated Properties
These are set based on which mixins you use:
# Model requirements
model.expects_no_fit # Whether the model expects no fit
model.requires_msa_for_fit # Whether the model requires an MSA as input for fitting
model.requires_wt_msa # Whether the model requires a wild type MSA
model.requires_msa_per_sequence # Whether the model requires an MSA for each sequence
model.requires_fixed_length # Whether the model requires a fixed length input
model.requires_wt_to_function # Whether the model requires the wild type sequence
model.requires_wt_during_inference # Whether the wild type is used during inference, see note below
model.requires_structure # Whether the model requires structure information
# Model capabilities
model.per_position_capable # Whether the model can output per position scores
model.can_regress # Whether the model outputs are estimates of activity
model.can_handle_aligned_sequences # Whether the model can handle aligned sequences
model.accepts_lower_case # Whether the model can accept lowercase sequences
model.should_refit_on_sequences # Whether model should refit on new sequences
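As a sketch of how calling code might use these properties, the helper below lists which requirement flags are set on a model. The flag names come from the listing above; active_requirements and the SimpleNamespace stand-in are hypothetical and exist only for this example.

```python
from types import SimpleNamespace

# Flag names copied from the listing above.
REQUIREMENT_FLAGS = [
    "requires_msa_for_fit",
    "requires_wt_msa",
    "requires_msa_per_sequence",
    "requires_fixed_length",
    "requires_wt_to_function",
    "requires_structure",
]


def active_requirements(model):
    """Return the names of the requirement flags that are set on `model`."""
    return [name for name in REQUIREMENT_FLAGS if getattr(model, name, False)]


# Hypothetical stand-in for a wrapped model; a real model would be a
# ProteinModelWrapper subclass whose flags come from its mixins.
toy_model = SimpleNamespace(
    requires_msa_for_fit=True,
    requires_fixed_length=True,
    requires_structure=False,
)
needed = active_requirements(toy_model)
```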
Mixin Hooks System
AIDE uses a system of hooks to allow mixins to modify the behavior of fitting and transformation. Hooks are registered during subclass creation and executed in order:
_mixin_init_handlers: Runs during initialization to extract mixin parameters
_pre_fit_hooks: Runs before fitting to prepare data
_post_fit_hooks: Runs after fitting to process results
_pre_transform_hooks: Runs before transformation to prepare data
_post_transform_hooks: Runs after transformation to process results
For example, the CacheMixin uses _pre_transform_hook to check if results are already cached and _post_transform_hook to store new results in the cache.
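One way such a hook system could be built is with Python's __init_subclass__, which runs whenever a subclass is created. The sketch below is an assumption about the general technique, not AIDE's implementation: ToyHookBase, ToyCacheMixin and CountingModel are hypothetical, and only the hook names mirror the ones mentioned above.

```python
class ToyHookBase:
    """Hypothetical base class that collects hooks when subclasses are created."""

    _pre_transform_hooks = ()
    _post_transform_hooks = ()

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._pre_transform_hooks = []
        cls._post_transform_hooks = []
        # Walk the MRO base-most first so hooks run in a predictable order.
        for klass in reversed(cls.__mro__):
            if "_pre_transform_hook" in vars(klass):
                cls._pre_transform_hooks.append(vars(klass)["_pre_transform_hook"])
            if "_post_transform_hook" in vars(klass):
                cls._post_transform_hooks.append(vars(klass)["_post_transform_hook"])

    def transform(self, X):
        for hook in self._pre_transform_hooks:
            hit = hook(self, X)
            if hit is not None:  # a pre-hook may short-circuit with a stored result
                return hit
        result = self._transform(X)
        for hook in self._post_transform_hooks:
            hook(self, X, result)
        return result


class ToyCacheMixin:
    """Hypothetical in-memory cache expressed as a pair of hooks."""

    def _cache(self):
        return self.__dict__.setdefault("_cache_store", {})

    def _pre_transform_hook(self, X):
        return self._cache().get(tuple(X))

    def _post_transform_hook(self, X, result):
        self._cache()[tuple(X)] = result


class CountingModel(ToyCacheMixin, ToyHookBase):
    calls = 0

    def _transform(self, X):
        self.calls += 1  # counts how often the core logic actually runs
        return [len(seq) for seq in X]


model = CountingModel()
first = model.transform(["MKT", "MKTA"])
second = model.transform(["MKT", "MKTA"])  # served by the pre-transform hook
```

The second call returns the stored result, so the core _transform logic runs only once.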
Wrapping a New Model
To wrap a new model:
Inherit from ProteinModelWrapper and any needed mixins
Implement only the core logic in private methods
Example:
class MyModel(CanRegressMixin, ProteinModelWrapper):
    def __init__(self, metadata_folder=None, my_param=1.0):
        super().__init__(metadata_folder=metadata_folder)
        self.my_param = my_param

    def _fit(self, X, y=None):
        # Just core training logic
        # X is guaranteed to be valid
        # Set a fitted attribute so sklearn can check if it's been fit
        self.fitted_ = True
        return self

    def _transform(self, X):
        # Just core transformation
        # X is guaranteed to be valid
        features = ...  # placeholder: compute features from X here
        return features
Availability Checking
Models can specify their availability based on installed dependencies:
try:
    import some_required_package
    AVAILABLE = MessageBool(True, "Model is available")
except ImportError:
    AVAILABLE = MessageBool(False, "Requires some_required_package")

class MyModel(ProteinModelWrapper):
    _available = AVAILABLE
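The idea behind an availability flag that carries a message can be sketched as follows. ToyMessageBool and availability_of are hypothetical stand-ins that only mimic the MessageBool(value, message) usage shown above; AIDE's own class may behave differently.

```python
import importlib.util


class ToyMessageBool:
    """Hypothetical stand-in: a truth value that carries an explanation."""

    def __init__(self, value, message):
        self.value = bool(value)
        self.message = message

    def __bool__(self):
        return self.value

    def __repr__(self):
        return f"ToyMessageBool({self.value}, {self.message!r})"


def availability_of(package):
    """Report whether a top-level `package` can be imported, with a reason."""
    # find_spec returns None when a top-level package cannot be found.
    if importlib.util.find_spec(package) is not None:
        return ToyMessageBool(True, "Model is available")
    return ToyMessageBool(False, f"Requires {package}")


json_ok = availability_of("json")                        # stdlib, always present
missing = availability_of("surely_not_installed_pkg_0")  # assumed to be absent

if not missing:
    reason = missing.message  # explains why the model cannot be used
```

Callers can test the flag like a plain bool and still show the user why a model is unavailable.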
Mixins
Mixins define model requirements and capabilities. When combined with ProteinModelWrapper, they automatically enable appropriate validation and behavior. Mixins are grouped by their general purpose:
Training Behavior
ExpectsNoFitMixin
For models that don’t require training:
Sets expects_no_fit = True
Used by pretrained models like ESM2, SaProt
Allows fit() to be called with no data
ShouldRefitOnSequencesMixin
For models that need retraining on new sequences:
Sets should_refit_on_sequences = True
By default, sklearn refits whenever fit is called; this is undesirable for some models (e.g., those that fit using an MSA)
That default behavior is disabled in ProteinModelWrapper but can be re-enabled by using this mixin
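A rough sketch of this behavior, assuming that "disabled" means a second fit call on an already fitted model is skipped unless the flag is set. ToyRefitAware and ToyRefitting are hypothetical classes; how AIDE actually detects a fitted model is not described here.

```python
class ToyRefitAware:
    """Hypothetical wrapper that leaves an already fitted model alone."""

    should_refit_on_sequences = False

    def __init__(self):
        self.fit_calls = 0

    def fit(self, X, y=None):
        already_fitted = getattr(self, "fitted_", False)
        if already_fitted and not self.should_refit_on_sequences:
            return self  # skip the (possibly expensive) refit
        self._fit(X, y)
        return self

    def _fit(self, X, y=None):
        self.fit_calls += 1
        self.fitted_ = True


class ToyRefitting(ToyRefitAware):
    # Mirrors what the mixin is described as doing: setting the flag to True.
    should_refit_on_sequences = True


frozen = ToyRefitAware().fit(["MKT"]).fit(["MKV"])
refit = ToyRefitting().fit(["MKT"]).fit(["MKV"])
```

Here frozen runs its core fitting logic once, while refit runs it on both calls.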
Input Requirements
RequiresMSAForFitMixin
For models trained on multiple sequence alignments:
Sets requires_msa_for_fit = True
Base class ensures input is aligned during fit
Model receives guaranteed aligned sequences in _fit
Used by evolutionary models like EVMutation, HMM
RequiresWTMSAMixin
For models that need the wild-type MSA during training:
Sets requires_wt_msa = True
Ensures wild-type sequence has an associated MSA
Used by models like EVE that need WT context
RequiresMSAPerSequenceMixin
For models requiring MSA information for each sequence:
Sets requires_msa_per_sequence = True
Ensures each sequence has its own MSA or falls back to WT MSA
Used by MSA Transformer for embedding generation
RequiresFixedLengthMixin
For models requiring uniform sequence length:
Sets requires_fixed_length = True
Base class validates all sequences are same length
Common in neural networks and position-specific models
Validates wild-type length matches if present
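The kind of check described here can be sketched in a few lines. check_fixed_length is a hypothetical helper written for illustration; the base class's actual validation code and error messages may differ.

```python
def check_fixed_length(sequences, wt=None):
    """Raise ValueError unless all sequences (and the wild type) share one length."""
    lengths = {len(seq) for seq in sequences}
    if len(lengths) > 1:
        raise ValueError(f"sequences have differing lengths: {sorted(lengths)}")
    if wt is not None and lengths and len(wt) not in lengths:
        raise ValueError("wild-type length does not match the sequences")
    # Return the common length, or None when no sequences were given.
    return next(iter(lengths)) if lengths else None


common_length = check_fixed_length(["MKTA", "MKVA"], wt="MKLA")
```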
RequiresStructureMixin
For models using structural information:
Sets requires_structure = True
Ensures structures available or falls back to wild-type structure
Used by structure-aware models like SaProt
RequiresWTToFunctionMixin
For models that need a reference sequence:
Sets requires_wt_to_function = True
Ensures wild-type sequence provided at initialization
Used by models computing mutation effects
RequiresWTDuringInferenceMixin
For models that handle their own use of any wild-type sequence:
Sets requires_wt_during_inference = True
If wt is present during inference, AIDE does no further processing. Without this mixin, AIDE by default calls the model on any WT present and normalizes the output to that sequence
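The difference between the two code paths can be sketched as below. This assumes that "normalize" means subtracting the wild-type score from each output, which is a guess made for illustration only; score_with_wt and its parameters are hypothetical names.

```python
def score_with_wt(score_fn, sequences, wt=None, handles_wt_itself=False):
    """Score sequences, optionally normalizing against the wild type.

    ASSUMPTION: normalization is modeled as subtracting the wild-type score.
    """
    scores = [score_fn(seq) for seq in sequences]
    if wt is None or handles_wt_itself:
        # The model is trusted to have used the wild type itself (or none exists).
        return scores
    wt_score = score_fn(wt)
    return [score - wt_score for score in scores]


# Toy scoring function: sequence length stands in for a model output.
default_path = score_with_wt(len, ["MKT", "MKTAA"], wt="MKTA")
opt_out_path = score_with_wt(len, ["MKT", "MKTAA"], wt="MKTA", handles_wt_itself=True)
```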
Output Capabilities
CanRegressMixin
For models producing numeric predictions:
Sets can_regress = True
Enables the predict() method
Adds Spearman correlation scoring
Common in variant effect predictors
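Spearman correlation compares the rank order of predictions with the rank order of measured values, which is why it suits variant effect scores that are monotonic but not linear in activity. The pure-Python sketch below shows the computation; it is not AIDE's scoring code, and the ranks and spearman helpers are written only for this example.

```python
def ranks(values):
    """Return 1-based ranks, giving tied values their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        average = (i + j) / 2 + 1  # mean rank of the tied group
        for k in range(i, j + 1):
            result[order[k]] = average
        i = j + 1
    return result


def spearman(a, b):
    """Spearman correlation: Pearson correlation of the two rank vectors."""
    ra, rb = ranks(a), ranks(b)
    n = len(ra)
    mean_a, mean_b = sum(ra) / n, sum(rb) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(ra, rb))
    var_a = sum((x - mean_a) ** 2 for x in ra)
    var_b = sum((y - mean_b) ** 2 for y in rb)
    if var_a == 0 or var_b == 0:
        return float("nan")  # undefined when one input is constant
    return cov / (var_a * var_b) ** 0.5


predicted = [0.1, 0.4, 0.35, 0.8]   # hypothetical model scores
measured = [1.0, 30.0, 2.0, 400.0]  # hypothetical activities, same ordering
rho = spearman(predicted, measured)
```

The predictions are far from linear in the measured values, yet the ordering agrees perfectly, so rho comes out at 1.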
PositionSpecificMixin
For models with per-position outputs:
Sets per_position_capable = True
Adds position selection with the positions parameter
Controls output format with the pool and flatten parameters
Handles dimension validation and reshaping
Common in language models and conservation analysis
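To make the three parameters concrete, here is a sketch on a single sequence's per-position output, stored as a list of per-position vectors. The semantics are assumptions for illustration: positions selects rows, pool averages over the selected positions, and flatten concatenates them. select_and_shape is a hypothetical helper, and AIDE's actual behavior and defaults may differ.

```python
def select_and_shape(per_position, positions=None, pool=False, flatten=False):
    """Select positions from an (L x D) nested list, then pool or flatten."""
    if positions is None:
        rows = list(per_position)
    else:
        rows = [per_position[p] for p in positions]
    if pool:
        # ASSUMPTION: pooling is a mean over the selected positions.
        return [sum(column) / len(rows) for column in zip(*rows)]
    if flatten:
        # ASSUMPTION: flattening concatenates the per-position vectors.
        return [value for row in rows for value in row]
    return rows


# Hypothetical output for a 3-residue sequence with 2 features per position.
embedding = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

pooled = select_and_shape(embedding, positions=[0, 2], pool=True)
flat = select_and_shape(embedding, positions=[0, 2], flatten=True)
full = select_and_shape(embedding)
```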
Data Processing
CanHandleAlignedSequencesMixin
For models working with gapped sequences:
Sets can_handle_aligned_sequences = True
Prevents automatic gap removal by base class
Essential for MSA-based models
Ensures gap characters preserved during processing
AcceptsLowerCaseMixin
For case-sensitive models:
Sets accepts_lower_case = True
Disables automatic uppercase conversion
Useful when case represents conservation or focus columns
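The two data-processing flags can be pictured as switches on a preprocessing step. The sketch below assumes "-" is the gap character and that the default processing strips gaps and uppercases the sequence; preprocess is a hypothetical function, not part of AIDE.

```python
def preprocess(sequence, can_handle_aligned_sequences=False, accepts_lower_case=False):
    """Apply default cleanup unless the model has opted out via its flags."""
    if not can_handle_aligned_sequences:
        sequence = sequence.replace("-", "")  # ASSUMPTION: "-" marks a gap
    if not accepts_lower_case:
        sequence = sequence.upper()
    return sequence


default_view = preprocess("mK-Ta")
aligned_view = preprocess("mK-Ta", can_handle_aligned_sequences=True)
raw_view = preprocess(
    "mK-Ta", can_handle_aligned_sequences=True, accepts_lower_case=True
)
```

A model that declares both capabilities receives the sequence untouched, gaps and case included.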
Computational Behavior
CacheMixin
For caching model outputs:
Adds disk-based caching system using SQLite and HDF5
Thread-safe for parallel processing
Automatic cache invalidation on parameter changes
Particularly useful for computationally expensive models
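Automatic invalidation on parameter changes can be achieved by making the model parameters part of the cache key. The sketch below shows that idea with an in-memory dict standing in for the SQLite and HDF5 storage; cache_key and the parameter names are hypothetical, and AIDE's actual keying scheme is not described here.

```python
import hashlib
import json


def cache_key(params, sequence):
    """Hash the model parameters together with the sequence."""
    payload = json.dumps({"params": params, "sequence": sequence}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


cache = {}  # stand-in for the on-disk store
params_v1 = {"layer": -1, "pool": True}
cache[cache_key(params_v1, "MKT")] = [0.1, 0.2]

# Same parameters and sequence: the stored entry is found again,
# regardless of the order in which the parameters are listed.
hit = cache.get(cache_key({"pool": True, "layer": -1}, "MKT"))

# Changed parameter: the key differs, so the stale entry is never returned.
miss = cache.get(cache_key({"layer": -2, "pool": True}, "MKT"))
```

With this design no explicit invalidation step is needed, at the cost of old entries lingering until they are cleaned up.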
Using Multiple Mixins
Mixins can be combined to define complex model requirements. Always put ProteinModelWrapper last in the inheritance chain:
class ComplexModel(
    RequiresMSAForFitMixin,      # Needs MSA for training
    RequiresWTToFunctionMixin,   # Needs wild-type reference
    CanRegressMixin,             # Makes predictions
    PositionSpecificMixin,       # Per-position outputs
    CacheMixin,                  # Caches results
    ProteinModelWrapper          # Always last
):
    def _fit(self, X: ProteinSequences, y=None):
        # X is guaranteed to be aligned
        # Implementation focuses on core logic
        pass

    def _transform(self, X: ProteinSequences):
        # All requirements validated by base class
        # Implementation focuses on core logic
        pass
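One likely reason for the ordering rule is Python's method resolution order: attributes are looked up left to right through the listed bases, so a base class placed first would shadow anything a mixin sets. The toy classes below are hypothetical and only demonstrate the language behavior, not AIDE's internals.

```python
class ToyWrapper:
    can_regress = False  # base-class default


class ToyRegressMixin:
    can_regress = True   # the mixin overrides the default


class MixinFirst(ToyRegressMixin, ToyWrapper):
    """Recommended order: mixins first, wrapper last."""


class WrapperFirst(ToyWrapper, ToyRegressMixin):
    """Wrong order: the wrapper's default is found before the mixin's value."""


mixin_first_flag = MixinFirst.can_regress
wrapper_first_flag = WrapperFirst.can_regress
mro_names = [klass.__name__ for klass in MixinFirst.__mro__]
```

With the wrapper listed first, the mixin's value is silently ignored and no error is raised, which makes this mistake easy to miss.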
Complete List of Mixins
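Mixin | Purpose | Sets |
---|---|---|
ExpectsNoFitMixin | For models without fitting | expects_no_fit = True |
RequiresMSAForFitMixin | For MSA-based training | requires_msa_for_fit = True |
RequiresWTMSAMixin | For models needing WT MSA | requires_wt_msa = True |
RequiresMSAPerSequenceMixin | For per-sequence MSA | requires_msa_per_sequence = True |
RequiresFixedLengthMixin | For fixed-length inputs | requires_fixed_length = True |
RequiresStructureMixin | For structure-aware models | requires_structure = True |
RequiresWTToFunctionMixin | For models needing WT | requires_wt_to_function = True |
RequiresWTDuringInferenceMixin | For WT during inference | requires_wt_during_inference = True |
CanRegressMixin | For prediction models | can_regress = True |
PositionSpecificMixin | For position outputs | per_position_capable = True |
CanHandleAlignedSequencesMixin | For gapped sequences | can_handle_aligned_sequences = True |
AcceptsLowerCaseMixin | For case sensitivity | accepts_lower_case = True |
CacheMixin | For result caching | Adds caching hooks |
ShouldRefitOnSequencesMixin | For retraining | should_refit_on_sequences = True |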