Model & Ranking
mshale-model v0.1.0 — feature engineering, contrastive pair construction, XGBoost LambdaRank (mshale-1 v0), C-index evaluation, Geneformer state embeddings, and the cross-domain generalizability test harness. 73 tests passing.
Architecture
- mshale-1 v0 (XGBoost LambdaRank): pairwise ranking on contrastive protocol pairs. Objective: C-index ≥ 0.65 on held-out pairs. Tractable with 150 records.
- mshale-1 v1+ (PyTorch Transformer): 12-layer protocol encoder + Geneformer state encoder + cross-attention outcome predictor. Requires 10K+ records.
- Generalizability (cross-domain transfer): a model trained on fibroblast→neuron protocols is applied to microbial fermentation protocols. A C-index above 0.52 on that out-of-domain set (random = 0.50) validates the foundation model thesis.
Feature Engineering
Each ProtocolSpec is converted to a 223-dimensional feature vector:
| Feature Group | Dimensions | Description |
|---|---|---|
| TF gene vector | 200 | Multi-hot indicator over a fixed set of known transcription factor genes (HGNC-anchored) |
| Protocol structure | 7 | n_steps, total_duration_days, mean_step_interval, first_intervention_day, n_transgenes, n_small_molecules, n_protein_factors |
| Delivery method | 8 | Binary: lentiviral, retroviral, AAV, mRNA, plasmid, protein, CRISPR, electroporation |
| Dose summary stats | 3 | Mean MOI (mean_moi), mean small-molecule concentration in µM (mean_sm_conc_uM), mean protein-factor concentration in ng/mL (mean_pf_conc_ngmL) |
| State distance | 1 | Cosine distance between the initial and target cell-type embeddings (Geneformer) |
| Context flags | 3 | is_inducible, species_human, species_mouse |
| Domain ID | 1 | domain_id as a single categorical (integer-coded) column |
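The 223 dimensions are the seven feature groups from the table above, concatenated in order (200 + 7 + 8 + 3 + 1 + 3 + 1). The sketch below illustrates that layout under stated assumptions: a plain dict stands in for `ProtocolSpec`, `assemble_feature_vector` is a hypothetical helper (not the package's `extract_features`), missing values default to 0, and the 200-gene TF vocabulary is padded with placeholder names because the real HGNC-anchored list is not given here.

```python
# Hypothetical sketch of the 223-dim layout; not mshale_model.features.extract_features.
DELIVERY_METHODS = ["lentiviral", "retroviral", "AAV", "mRNA",
                    "plasmid", "protein", "CRISPR", "electroporation"]
STRUCTURE_KEYS = ["n_steps", "total_duration_days", "mean_step_interval",
                  "first_intervention_day", "n_transgenes",
                  "n_small_molecules", "n_protein_factors"]
DOSE_KEYS = ["mean_moi", "mean_sm_conc_uM", "mean_pf_conc_ngmL"]
CONTEXT_KEYS = ["is_inducible", "species_human", "species_mouse"]


def assemble_feature_vector(spec: dict, tf_vocab: list) -> list:
    """Concatenate the feature groups for one protocol (spec is a plain dict here)."""
    tf_set = set(spec.get("tf_genes", []))
    # Multi-hot over the fixed TF vocabulary (200 genes in mshale-model).
    tf_block = [1.0 if gene in tf_set else 0.0 for gene in tf_vocab]
    # Assumption: missing numeric fields default to 0.0.
    structure = [float(spec.get(k, 0.0)) for k in STRUCTURE_KEYS]
    methods = set(spec.get("delivery_methods", []))
    delivery = [1.0 if m in methods else 0.0 for m in DELIVERY_METHODS]
    dose = [float(spec.get(k, 0.0)) for k in DOSE_KEYS]
    state = [float(spec.get("state_distance", 0.0))]
    context = [float(bool(spec.get(k, False))) for k in CONTEXT_KEYS]
    domain = [float(spec.get("domain_id", 0))]  # single categorical column
    return tf_block + structure + delivery + dose + state + context + domain


# Placeholder vocabulary: 3 real TF symbols plus 197 dummy names (200 total).
tf_vocab = ["ASCL1", "POU3F2", "MYT1L"] + [f"TF_{i}" for i in range(197)]
spec = {
    "tf_genes": ["ASCL1", "POU3F2", "MYT1L"],
    "n_steps": 3,
    "total_duration_days": 14,
    "delivery_methods": ["lentiviral"],
    "mean_moi": 5.0,
    "state_distance": 0.42,
    "species_human": True,
    "domain_id": 0,
}
x = assemble_feature_vector(spec, tf_vocab)
print(len(x))  # 223
```

Filling missing doses with 0 is a simplification; XGBoost can also accept NaN and learn a default split direction for missing values.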
Quickstart
pip install mshale-model

# Train mshale-1 v0 on a corpus of ProtocolSpec records
mshale-model train \
  --corpus output/phase0/specs/ \
  --output models/mshale1_v0.ubj \
  --eval-split 0.2

# Evaluate: C-index vs. baselines
mshale-model eval \
  --model models/mshale1_v0.ubj \
  --corpus output/phase0/specs/ \
  --report eval_report.json

# Generalizability test: apply the neuron model to microbial protocols
mshale-model generalize \
  --model models/mshale1_v0.ubj \
  --corpus data/microbial_fermentation/ \
  --report generalize_report.json
Python API
from mshale_model.features import extract_features
from mshale_model.pairs import build_contrastive_pairs
from mshale_model.ranker import MshaleRanker
from mshale_model.evaluate import c_index
# Load corpus
specs = [...] # list[ProtocolSpec]
# Feature engineering
X = extract_features(specs) # (n_protocols, 223)
# Contrastive pairs: label=1 if protocol_i is more efficient than protocol_j
pairs = build_contrastive_pairs(specs, train_frac=0.8)
# Train XGBoost LambdaRank
model = MshaleRanker()
model.fit(pairs.X_train, pairs.y_train, pairs.qid_train)
# Evaluate
score = c_index(model.predict(pairs.X_test), pairs.y_test)
print(f"C-index: {score:.3f}")  # target ≥ 0.65
Contrastive Learning
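Rather than predicting absolute efficiency, which suffers from publication bias, mshale-1 learns pairwise preferences: for two protocols from the same paper, which was more efficient? Restricting comparisons to a single paper also holds the lab, starting cells, and efficiency assay roughly constant, so the model learns from differences between protocols rather than between studies.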
# For all (i, j) protocol pairs from the same paper:
#   label = 1 if efficiency_i > efficiency_j
#   label = 0 otherwise
# Exclude: pairs where either protocol has "not_reported" efficiency
# Split: 80% train / 20% held-out, stratified by domain
# Guarantee: pairs from the same paper are never split across train/test
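A minimal Python sketch of these rules, assuming each protocol is a dict with hypothetical `id`, `paper_id`, `domain`, and `efficiency` keys and that each paper belongs to a single domain. It emits both orderings of every within-paper pair and stratifies by splitting each domain's papers by `train_frac` (by paper count, not pair count), so all pairs from one paper land on the same side. `build_pairs` and `split_by_paper` are illustrative names, not the internals of `build_contrastive_pairs`.

```python
import itertools
import random
from collections import defaultdict


def build_pairs(protocols):
    """Ordered within-paper pairs: label 1 if protocol i is more efficient than j, else 0."""
    by_paper = defaultdict(list)
    for p in protocols:
        if p["efficiency"] != "not_reported":  # exclude unreported outcomes
            by_paper[p["paper_id"]].append(p)
    pairs = []
    for paper_id, group in by_paper.items():
        for a, b in itertools.permutations(group, 2):  # both (i, j) and (j, i)
            label = 1 if a["efficiency"] > b["efficiency"] else 0  # ties -> 0
            pairs.append((a["id"], b["id"], paper_id, label))
    return pairs


def split_by_paper(protocols, train_frac=0.8, seed=0):
    """Return the set of paper ids assigned to train, sampled separately per domain."""
    rng = random.Random(seed)
    papers_by_domain = defaultdict(set)
    for p in protocols:
        papers_by_domain[p["domain"]].add(p["paper_id"])
    train_papers = set()
    for domain in sorted(papers_by_domain):
        papers = sorted(papers_by_domain[domain])
        rng.shuffle(papers)
        train_papers.update(papers[: round(train_frac * len(papers))])
    return train_papers


protocols = [
    {"id": "p1", "paper_id": "A", "domain": "neuron", "efficiency": 0.30},
    {"id": "p2", "paper_id": "A", "domain": "neuron", "efficiency": 0.12},
    {"id": "p3", "paper_id": "A", "domain": "neuron", "efficiency": "not_reported"},
    {"id": "p4", "paper_id": "B", "domain": "neuron", "efficiency": 0.05},
    {"id": "p5", "paper_id": "B", "domain": "neuron", "efficiency": 0.20},
]
pairs = build_pairs(protocols)
train_papers = split_by_paper(protocols, train_frac=0.5)
train = [pr for pr in pairs if pr[2] in train_papers]
test = [pr for pr in pairs if pr[2] not in train_papers]
```

The toy call uses `train_frac=0.5` only so that a two-paper corpus yields non-empty train and test sides; splitting on whole papers is what enforces the no-leakage guarantee.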
mshale-1 v0 Parameters
params = {
"objective": "rank:pairwise",
"eval_metric": "ndcg",
"max_depth": 6,
"learning_rate": 0.05,
"n_estimators": 500,
"subsample": 0.8,
"colsample_bytree": 0.8,
"min_child_weight": 3,
}
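For pairwise labels, the C-index reduces to the fraction of held-out pairs whose predicted order agrees with the label, with tied scores counted as half. The sketch below assumes per-protocol scores keyed by protocol id and held-out pairs as `(id_i, id_j, label)` tuples, and it interprets the frequency-weighted baseline as "sum of each TF's corpus frequency"; the data layout and that baseline definition are assumptions, and `pairwise_c_index` and `tf_popularity_scores` are hypothetical names rather than `mshale_model.evaluate.c_index`.

```python
from collections import Counter


def pairwise_c_index(scores, pairs):
    """Fraction of pairs ordered correctly; label 1 means protocol i beat protocol j."""
    total = 0.0
    for i, j, label in pairs:
        if scores[i] == scores[j]:
            total += 0.5  # tied prediction counts as half-concordant
        elif (scores[i] > scores[j]) == (label == 1):
            total += 1.0
    return total / len(pairs)


def tf_popularity_scores(protocol_tfs):
    """Hypothetical baseline: score = summed corpus frequency of a protocol's TFs."""
    counts = Counter(tf for tfs in protocol_tfs.values() for tf in set(tfs))
    return {pid: sum(counts[tf] for tf in set(tfs)) for pid, tfs in protocol_tfs.items()}


held_out = [("p1", "p2", 1), ("p2", "p1", 0), ("p4", "p5", 0), ("p5", "p4", 1)]

model_scores = {"p1": 0.9, "p2": 0.1, "p4": 0.3, "p5": 0.3}
print(pairwise_c_index(model_scores, held_out))  # 0.75

constant = {pid: 0.0 for pid in model_scores}
print(pairwise_c_index(constant, held_out))  # 0.5, same as a random ordering

protocol_tfs = {"p1": ["ASCL1", "POU3F2", "MYT1L"], "p2": ["ASCL1"],
                "p4": ["NEUROG2"], "p5": ["ASCL1", "NEUROG2"]}
baseline = pairwise_c_index(tf_popularity_scores(protocol_tfs), held_out)
print(baseline)
```

Counting ties as 0.5 is what makes a constant predictor score exactly 0.50, which lines up with the random baseline.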
# Evaluation: C-index on held-out pairs (target ≥ 0.65)
# Baselines: random (0.50), frequency-weighted (TF popularity)
Geneformer Integration
The state_distance feature captures how far the initial cell type is from the target in Geneformer embedding space. This grounds the model in single-cell biology without requiring gene expression data at inference time.
from mshale_model.embeddings import GeneformerEmbedder
embedder = GeneformerEmbedder() # loads ctheodoris/Geneformer from HuggingFace
# Compute cosine distance between cell type embeddings
dist = embedder.state_distance(
initial_cell_type="fibroblast", # CL:0000057
target_cell_type="neuron", # CL:0000540
)
# → float: 0.0 (identical) to 1.0 (maximally distant)
# Added as a feature to the 223-dim vector
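How the cell-type embeddings are built is not specified here. One plausible approach, assumed in this sketch, is to average per-cell Geneformer embeddings for each cell type and take 1 − cosine similarity between the two means. The toy 4-dimensional vectors stand in for real Geneformer outputs, and `mean_embedding` and `cosine_distance` are hypothetical helpers, not `GeneformerEmbedder` methods. Note that 1 − cosine similarity can reach 2.0 for opposed vectors, so a 0 to 1 range implies non-negative embeddings or a rescaling step.

```python
import math


def mean_embedding(cell_embeddings):
    """Average equal-length per-cell vectors into one (hypothetical) cell-type embedding."""
    n = len(cell_embeddings)
    dim = len(cell_embeddings[0])
    return [sum(vec[k] for vec in cell_embeddings) / n for k in range(dim)]


def cosine_distance(u, v):
    """1 - cosine similarity: 0.0 same direction, 1.0 orthogonal, 2.0 opposed."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine distance is undefined for a zero vector")
    return 1.0 - dot / (norm_u * norm_v)


# Toy 4-dim vectors standing in for per-cell Geneformer embeddings.
fibroblast = mean_embedding([[1.0, 0.0, 0.0, 0.0], [0.8, 0.2, 0.0, 0.0]])
neuron = mean_embedding([[0.0, 1.0, 0.0, 0.0], [0.0, 0.8, 0.2, 0.0]])

print(round(cosine_distance(fibroblast, neuron), 4))  # 0.8902
```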