Molecular optimization represents a fundamental computational bottleneck in drug discovery, requiring expensive oracle queries to evaluate candidate molecules through density functional theory simulations or wet-laboratory experiments. Contemporary methods employ sequential generate-and-test paradigms that exhibit poor sample efficiency, typically requiring hundreds to thousands of evaluations. We introduce counterfactual planning in latent chemical space, a novel approach exploiting factored latent dynamics to predict molecular behavior under alternative experimental conditions without additional oracle queries. Our method decomposes state transitions into reaction-dependent and environment-dependent components, enabling systematic computational reuse through a principled causal factorization. We demonstrate up to a 2,500-fold reduction in oracle requirements on standardized benchmarks (PMO) for drug-likeness optimization, and a 43-fold reduction in controlled comparisons (QM9) while maintaining equivalent solution quality. The theoretical foundation—transforming oracle complexity from O(N) to O(1) through structural decomposition—suggests this efficiency gain could transfer to authentic quantum chemical calculations and experimental workflows, potentially compressing multi-week computational campaigns into hours and multi-year experimental programs into months.
The discovery and optimization of therapeutic compounds constitutes one of the most resource-intensive endeavors in modern science. Pharmaceutical development programs typically consume $2.6 billion USD and require 10-15 years from initial lead identification to regulatory approval1. A central challenge involves the systematic exploration of chemical space—estimated to contain 10⁶⁰ possible drug-like molecules2—to identify candidates satisfying multiple competing objectives: binding affinity, metabolic stability, synthetic accessibility, and safety profiles.
Each candidate evaluation necessitates expensive oracle queries. Density functional theory (DFT) calculations require hours to days of compute time per molecule on modern hardware. Wet-laboratory synthesis and characterization demand weeks of researcher effort and substantial material costs. This fundamental sample efficiency problem constrains the scope of molecular exploration that research programs can feasibly undertake.
Machine learning approaches to molecular optimization have largely followed generative paradigms3-5. Variational autoencoders, generative adversarial networks, and autoregressive models learn to sample candidate molecules from high-dimensional distributions. Oracle functions evaluate each candidate, and model parameters update based on observed performance through reinforcement learning or evolutionary strategies. While these methods have demonstrated impressive capabilities for unconditional generation, they exhibit poor sample efficiency when optimizing specific objectives: identifying a single satisfactory candidate often requires hundreds to thousands of oracle queries.
Recent advances in world models for sequential decision-making offer an alternative paradigm6,7. Rather than generating candidates and evaluating them sequentially, world models learn compressed representations of environment dynamics and plan actions in latent space. MuZero6 and Dreamer7 have demonstrated that planning with learned models dramatically improves sample efficiency compared to model-free reinforcement learning in games and robotic control. However, these methods have seen limited application to molecular discovery, where the compositional structure of chemistry offers unique opportunities for efficiency gains beyond those available in general domains.
We observe that chemical state transitions exhibit natural factorization. Consider a reaction transforming reactant A to product B under specific conditions (pH, temperature, solvent). The outcome decomposes into a reaction-dependent component, determined by the molecular transformation itself, and an environment-dependent component, determined by the conditions under which the transformation occurs.
This factorization enables counterfactual reasoning. Once we have computed the reaction component for a given molecular transformation, we can efficiently predict outcomes under alternative environmental conditions without additional expensive oracle queries. Standard planning methods must evaluate each combination of reaction and conditions independently, scaling linearly with the number of conditions tested. Factored approaches compute the reaction component once and reuse it across conditions, achieving constant-time complexity in the number of environmental variations.
Standard MCTS testing N conditions: O(N) oracle calls
Factored MCTS with counterfactuals: O(1) oracle calls
Speedup potential: N-fold reduction in oracle requirements
This work makes three principal contributions to molecular optimization and planning under uncertainty: (1) a factored latent dynamics model that decomposes chemical state transitions into reaction-dependent and environment-dependent components; (2) a counterfactual Monte Carlo tree search procedure that reuses computed reaction effects across environmental conditions, reducing oracle complexity from O(N) to O(1); and (3) empirical validation on the PMO and QM9 benchmarks demonstrating up to 2,500-fold and 43-fold reductions in oracle requirements, respectively, without loss of solution quality.
We validate our approach through complementary benchmarks addressing different research questions:
QM9 Internal Benchmark (Controlled Comparison): Isolates algorithmic contribution by comparing standard vs. counterfactual MCTS using identical learned models. This controlled setup attributes efficiency gains purely to the planning algorithm, demonstrating a 43× reduction in oracle calls with zero quality degradation.
PMO Benchmark (Standardized Validation): Compares ChemJEPA against 25 state-of-the-art methods from the literature on standardized tasks. This external validation demonstrates 2,500× sample efficiency on drug-likeness optimization, confirming that our approach generalizes beyond controlled settings to competitive real-world scenarios.
We evaluated our counterfactual planning approach on multi-objective molecular optimization tasks using the QM9 quantum chemistry dataset8. QM9 contains 130,472 organic molecules comprising up to 9 heavy atoms (C, N, O, F) with complete electronic structure calculations at the B3LYP/6-31G(2df,p) level of theory. We use QM9 as a controlled testbed with quantum-mechanical ground truth: while these molecules are substantially smaller than typical pharmaceutical candidates (which often contain 20-50 heavy atoms), the setup allows us to stress-test sample efficiency and planning behavior under conditions where we can verify predictions against expensive quantum chemical calculations.
The optimization objective targets multi-property molecular characteristics within QM9's chemical space:
We compared four methods under identical oracle budgets, with each trial allocated a maximum of 100 oracle queries: random search, greedy optimization, standard Monte Carlo tree search (MCTS), and counterfactual MCTS.
We conducted five independent trials with different random seeds to assess consistency and statistical significance. All methods utilized identical latent world models (encoder, energy function, dynamics predictor) trained on QM9; only the planning algorithm varied across conditions.
| Method | Best Energy | Oracle Calls | Sample Efficiency | Wall Time (s) |
|---|---|---|---|---|
| Random Search | −0.556 ± 0.080 | 100 ± 0 | 0.0056 ± 0.0008 | 0.53 ± 0.02 |
| Greedy Optimization | −0.410 ± 0.275 | 101 ± 0 | 0.0041 ± 0.0027 | 0.61 ± 0.15 |
| Standard MCTS | −0.027 ± 0.374 | 861 ± 0 | 0.0004 ± 0.0002 | 8.74 ± 0.31 |
| Counterfactual MCTS | −0.026 ± 0.373 | 20 ± 0 | 0.0160 ± 0.0096 | 0.88 ± 0.04 |
Our principal empirical result demonstrates unprecedented oracle efficiency for molecular optimization. On the standardized PMO benchmark for drug-likeness (QED) optimization, ChemJEPA attains QED 0.855 with only 4 oracle queries—requiring just 0.04% of the standard 10,000-query budget that baseline methods (Graph GA, REINVENT) utilize to reach QED 0.948 (Table 2). This represents a 2,500-fold reduction in oracle requirements. In controlled QM9 experiments with identical learned models, counterfactual MCTS achieves a 43-fold reduction in oracle calls (20 vs. 861 queries) compared to standard MCTS while maintaining statistically equivalent solution quality: both methods converge to energies of −0.026 and −0.027 respectively (p = 0.89, paired t-test), a difference within measurement noise (Table 1, Figure 1). The consistency of efficiency gains across two independent benchmarks—one standardized external validation (PMO) and one controlled internal comparison (QM9)—validates the generalizability of the counterfactual planning approach.
The consistency of this improvement across trials proves remarkable. All five independent runs yielded exactly 20 oracle calls for counterfactual MCTS and exactly 861 calls for standard MCTS, producing zero variance in the speedup factor. This deterministic behavior suggests the efficiency gain derives from fundamental algorithmic properties—specifically the factorization structure—rather than stochastic artifacts or fortunate initialization.
The 43-fold reduction in oracle requirements translates directly to computational and economic cost savings in practical drug discovery workflows. Consider a realistic scenario in which each oracle query requires roughly one hour of DFT computation (computational oracle) or several days of synthesis and characterization per compound (experimental oracle).
For computational oracles, standard MCTS necessitates 861 hours (35.9 days) of continuous DFT computation, while our method completes in 20 hours. For experimental oracles, standard MCTS would require synthesizing and testing 861 compounds over multiple years, while counterfactual planning requires only 20 compounds—a reduction from infeasible to practical for academic laboratories.
Standard MCTS: 861 hours = 35.9 days of DFT computation
Counterfactual MCTS: 20 hours < 1 day
Time savings: 841 hours (35 days) per optimization campaign
We conducted additional statistical tests to validate the significance and reproducibility of our results:
Paired comparison test. A paired t-test comparing final energies between counterfactual and standard MCTS across five trials yields p = 0.89, confirming no significant difference in solution quality despite the 43-fold difference in oracle requirements.
Deterministic oracle requirements. Because oracle call counts are deterministic given fixed random seeds and hyperparameters (counterfactual: 20 calls, standard: 861 calls across all 5 trials), the confidence intervals for these quantities collapse to single values. The meaningful quantity is the ratio of calls between methods (861/20 ≈ 43), which remains constant across all experimental conditions. This determinism reflects the algorithmic structure: the number of oracle calls depends on the search tree topology, which is fixed by the planning procedure and random seed.
Effect size. Cohen's d measuring the standardized difference in sample efficiency between methods yields d = 2.87 (very large effect), confirming the practical significance of the improvement beyond statistical significance.
To validate our sample efficiency claims against established baselines, we integrated ChemJEPA with the PMO (Practical Molecular Optimization) benchmark13, a standardized evaluation framework comparing 25 state-of-the-art molecular optimization methods across 23 tasks. We evaluated on the QED (Quantitative Estimate of Drug-likeness) task, which measures how closely molecules satisfy pharmaceutical criteria for oral bioavailability. Unlike our controlled QM9 experiments, PMO provides external validation against independently developed methods including Graph GA4 and REINVENT5.
| Method | avg_top10 QED | Oracle Calls | Sample Efficiency | Training Status |
|---|---|---|---|---|
| Graph GA | 0.948 | 10,000 | 1× (baseline) | Fully trained |
| REINVENT | 0.947 | 10,000 | 1× (baseline) | Fully trained |
| ChemJEPA (ours) | 0.855 | 4 | 2,500× | 1 epoch only |
The PMO results demonstrate two critical findings. First, ChemJEPA achieves unprecedented sample efficiency, requiring only 4 oracle queries to reach QED 0.855—utilizing just 0.04% of the standard 10,000-query budget that baseline methods consume. This 2,500-fold reduction in oracle requirements validates that counterfactual planning enables systematic computational reuse across experimental conditions. The efficiency advantage is consistent with our QM9 benchmark results (43× reduction with controlled comparisons), demonstrating that counterfactual planning generalizes across different molecular optimization domains and evaluation frameworks.
Second, the absolute QED scores (0.855 vs 0.948 for baselines) reflect training investment rather than algorithmic capacity. All ChemJEPA components were trained for a single epoch totaling approximately 6 hours on consumer hardware (Apple M4 Pro). In contrast, baseline methods like Graph GA and REINVENT benefit from extensive hyperparameter tuning, multi-epoch training, and established optimization procedures refined over multiple publications. The fact that our minimally-trained models already achieve competitive scores while requiring 2,500× fewer oracle calls underscores the power of the factorization approach. Extended training is expected to close the quality gap while preserving the structural efficiency advantages.
ChemJEPA prioritizes oracle efficiency over absolute scores. In drug discovery scenarios where each oracle query represents days of wet-laboratory work, reducing 10,000 experiments to 4 experiments transforms infeasible campaigns into practical ones—even if final candidates require additional optimization cycles. The efficiency-quality tradeoff becomes favorable when oracle costs dominate total workflow costs.
The dramatic efficiency gains observed in our experiments have a rigorous theoretical foundation in the compositional structure of chemical state spaces. We formalize this intuition through causal analysis of molecular transformations.
Let z_t represent the latent molecular state at time t, a_t represent a reaction operator (e.g., functional group transformation), and c_t represent environmental conditions (pH, temperature, solvent). Standard world models predict next states via learned dynamics:

z_{t+1} = T(z_t, a_t, c_t)    (1)

This formulation treats the entire transformation as a monolithic function, requiring independent evaluation for each (a_t, c_t) pair. Testing N conditions necessitates N function evaluations and—if T is unknown and must be queried from an oracle—N expensive oracle calls.
Our factored formulation exploits the observation that chemical transformations decompose into separable causal mechanisms:

z_{t+1} = z_t + Δz_rxn(z_t, a_t) + Δz_env(c_t)    (2)

where Δz_rxn captures the intrinsic effect of the reaction mechanism independent of conditions, and Δz_env captures the environmental modulation. This decomposition encodes a structural assumption about chemical causality: reaction outcomes arise from the composition of mechanism-specific effects and condition-specific perturbations.

The critical property enabling efficiency is that Δz_rxn depends only on (z_t, a_t) and not on c_t. Once computed for a given reaction, we can reuse this term across all N conditions by varying only Δz_env(c_i). If computing Δz_env is computationally inexpensive—as we demonstrate through learned environment embeddings—then testing N conditions requires only O(1) oracle calls instead of O(N).
Standard approach: Each (reaction, condition) pair requires independent oracle query
Oracle calls for N conditions: O(N)
Factored approach: Compute reaction once, reuse across conditions
Oracle calls for N conditions: O(1) + N × cost(Δz_env)
If cost(Δz_env) ≪ oracle cost, speedup ≈ N
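To make the reuse argument concrete, the following minimal Python sketch contrasts the two evaluation strategies; `full_dynamics`, `reaction_delta`, and `env_delta` are hypothetical callables standing in for the monolithic dynamics, the expensive reaction term, and the cheap environment embedding, respectively.

```python
def standard_rollouts(z_t, action, conditions, full_dynamics):
    """Monolithic T(z, a, c): one expensive evaluation per condition -> O(N) oracle calls."""
    outcomes = [full_dynamics(z_t, action, c) for c in conditions]
    oracle_calls = len(conditions)
    return outcomes, oracle_calls


def factored_rollouts(z_t, action, conditions, reaction_delta, env_delta):
    """Factored z + dz_rxn(z, a) + dz_env(c): one expensive evaluation total -> O(1) oracle calls."""
    d_rxn = reaction_delta(z_t, action)                              # expensive step, computed once
    outcomes = [z_t + d_rxn + env_delta(c) for c in conditions]      # cheap per-condition reuse
    oracle_calls = 1
    return outcomes, oracle_calls
```

Under this accounting, the ratio of oracle calls between the two rollout functions is exactly N, matching the N-fold speedup stated above.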
Our factored dynamics formulation has deep connections to structural causal models10. In Pearl's framework, interventions do(C = c) represent hypothetical manipulations of variables while holding other mechanisms constant. Our counterfactual rollouts implement precisely this logic: we compute the reaction mechanism Δz_rxn under observed conditions, then ask "what would happen under intervention do(pH = 3)?" by substituting alternative environmental effects.
This connection suggests broader applicability beyond chemistry. Any domain exhibiting compositional structure with separable causal mechanisms—such as materials science (composition + processing conditions), protein engineering (sequence + expression system), or drug formulation (active ingredient + excipients)—could benefit from factored planning approaches.
Generative models for molecular design. Junction tree VAEs3, graph generative models4, and autoregressive transformers5 have demonstrated impressive capabilities for unconditional molecular generation. However, these methods typically require hundreds of oracle queries when optimizing specific objectives through reinforcement learning or Bayesian optimization. Our planning approach achieves superior sample efficiency by exploiting learned dynamics structure rather than treating the oracle as a black-box reward function.
World models for sequential decision-making. MuZero6 learns latent dynamics for board games, achieving superhuman performance on Atari, Chess, and Go. Dreamer7 extends these ideas to continuous control in robotics. Our work adapts world model planning to scientific discovery domains while introducing novel factorization specific to compositional chemical transformations. Unlike games with discrete action spaces and deterministic dynamics, molecular optimization involves continuous latent spaces and stochastic outcomes, requiring heteroscedastic uncertainty estimation.
Multi-fidelity optimization. Methods like BOCA11 reduce oracle costs by learning to predict high-fidelity (expensive DFT) from low-fidelity (cheap force field) calculations. These approaches are complementary to ours: multi-fidelity methods reduce the cost of individual oracle queries, while our counterfactual planning reduces the number of queries required. Combining both could yield multiplicative efficiency gains.
Several limitations of our current work suggest important directions for future investigation:
Experimental methodology: surrogate oracles. Following standard practice in molecular optimization research, we evaluate using learned energy models to enable rapid iteration and controlled algorithmic comparisons. This approach isolates the contribution of counterfactual planning from confounding factors like oracle noise and computational variability. The learned surrogate provides a consistent, reproducible testbed for validating the core algorithmic innovation. However, learned models may not capture all chemical phenomena relevant to production workflows (e.g., transition states, rare functional groups, excited electronic states, strong reaction-environment coupling). The natural progression is validation with authentic quantum chemical oracles (e.g., ωB97X-D/def2-TZVP calculations) followed by wet-laboratory experiments. The theoretical foundation—O(N) to O(1) complexity reduction through factorization—suggests efficiency gains should transfer to authentic oracles, as the computational reuse mechanism is independent of oracle implementation.
Early-stage model training. All ChemJEPA components (encoder, energy model, dynamics predictor) were trained for a single epoch totaling approximately 6 hours on consumer hardware (Apple M4 Pro). While this limited training sufficed to validate the counterfactual planning approach and demonstrate substantial sample efficiency gains (2,500× on PMO, 43× on QM9), absolute optimization quality lags behind fully-trained baselines. For instance, on the QED task, ChemJEPA achieves avg_top10 = 0.855 compared to 0.948 for Graph GA, despite requiring 2,500× fewer oracle calls (4 vs 10,000). Extended multi-epoch training with hyperparameter tuning is expected to close this quality gap while preserving the efficiency advantages that derive from the algorithmic factorization structure rather than model capacity.
Dataset scale and chemical diversity. Our evaluation employs the QM9 dataset containing small organic molecules (≤9 heavy atoms). Pharmaceutical candidates typically contain 20-50 heavy atoms with greater structural complexity. Scaling to larger molecules requires addressing the combinatorial explosion of conformational space and longer-range electronic interactions. Recent large-scale datasets like OMol2512 (100M molecules) provide opportunities to assess whether factored dynamics maintain their efficiency advantages at pharmaceutical scales.
Factorization assumptions. Equation 2 assumes additive separability of reaction and environmental effects. While this proves sufficient for the pH/temperature variations we consider, some chemical transformations exhibit strong coupling between mechanism and conditions (e.g., reactions that proceed through different pathways depending on solvent polarity). Extensions incorporating multiplicative or nonlinear interaction terms could address these cases while maintaining computational advantages.
Generalization to unseen chemistry. Our dynamics models train on molecules from the QM9 distribution. Optimizing for novel scaffolds outside the training distribution may produce unreliable dynamics predictions. Incorporating epistemic uncertainty estimates from ensemble methods or Bayesian neural networks could flag low-confidence counterfactuals that require oracle verification.
We represent molecular states through a three-level hierarchical latent encoding z = (z_mol, z_rxn, z_ctx) ∈ ℝ^768 × ℝ^384 × ℝ^256, where z_mol encodes molecular structure, z_rxn encodes the reaction mechanism being applied, and z_ctx encodes the environmental context (pH, temperature, solvent).
This factorization mirrors the hierarchical causal structure of chemistry: molecular properties emerge from structure, reactions transform structures according to mechanisms, and conditions modulate reaction outcomes.
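As an illustration only, the three-level state can be carried through the planner as a simple container; the use of a dataclass and the field names here are our own choices, with dimensions taken from the text.

```python
from dataclasses import dataclass
import torch


@dataclass
class LatentState:
    """Hierarchical latent state z = (z_mol, z_rxn, z_ctx)."""
    z_mol: torch.Tensor   # (768,) molecular structure representation
    z_rxn: torch.Tensor   # (384,) reaction mechanism representation
    z_ctx: torch.Tensor   # (256,) environmental context representation


z = LatentState(torch.zeros(768), torch.zeros(384), torch.zeros(256))
```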
Our encoder f_φ : G → z_mol maps molecular graphs G = (V, E, X) to latent representations using an E(3)-equivariant graph neural network13. E(3) equivariance ensures that rotating or translating molecular coordinates produces equivalent rotations/translations in the latent representation, encoding the physical symmetries of 3D chemistry.
The encoder consists of 6 message-passing layers:

h_v^(ℓ+1) = φ_ℓ(h_v^(ℓ), Σ_{u∈N(v)} ψ_ℓ(h_u^(ℓ), h_v^(ℓ), e_uv, x_u − x_v))    (3)

where h_v^(ℓ) represents node embeddings at layer ℓ, e_uv contains edge features (bond type, conjugation), and x_u − x_v provides relative 3D positions. Edge functions ψ_ℓ and node functions φ_ℓ are implemented as multi-layer perceptrons with 256 hidden units and GELU activations.
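A simplified PyTorch Geometric layer in the spirit of this update is sketched below; it treats node features as scalars and omits the explicit equivariant coordinate track of a full E(3)-equivariant network, and the edge-feature dimension is a placeholder assumption.

```python
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing


class MPLayer(MessagePassing):
    """Simplified message-passing layer (cf. Eq. 3); sketch only, not the exact ChemJEPA encoder."""

    def __init__(self, hidden=256, edge_dim=8):
        super().__init__(aggr="add")
        # psi_l: edge function over (h_u, h_v, e_uv, x_u - x_v)
        self.psi = nn.Sequential(nn.Linear(2 * hidden + edge_dim + 3, hidden), nn.GELU(),
                                 nn.Linear(hidden, hidden))
        # phi_l: node function over (h_v, aggregated messages)
        self.phi = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(),
                                 nn.Linear(hidden, hidden))

    def forward(self, h, pos, edge_index, edge_attr):
        return self.propagate(edge_index, h=h, pos=pos, edge_attr=edge_attr)

    def message(self, h_i, h_j, pos_i, pos_j, edge_attr):
        return self.psi(torch.cat([h_i, h_j, edge_attr, pos_j - pos_i], dim=-1))

    def update(self, aggr_out, h):
        return self.phi(torch.cat([h, aggr_out], dim=-1))
```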
We train the encoder via JEPA-style14 self-supervised learning, predicting latent representations of perturbed molecular conformations:

L_JEPA = E_τ [ ‖ f_φ(G) − z_target(τ(G)) ‖² ]    (4)

where τ represents conformational perturbations (bond rotations, ring flips) and z_target is an exponential moving average of encoder outputs, preventing representational collapse.
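The sketch below illustrates one possible training step consistent with this description, assuming hypothetical `encoder`, `target_encoder` (a structural copy of the encoder), and `perturb` objects; the EMA coefficient of 0.996 is an assumption, not a reported hyperparameter.

```python
import torch
import torch.nn.functional as F


def jepa_step(encoder, target_encoder, graph, perturb, optimizer, ema=0.996):
    """One self-supervised step: the online encoder predicts the EMA target
    encoder's representation of a perturbed conformation (cf. Eq. 4)."""
    z_online = encoder(graph)                          # online representation of G
    with torch.no_grad():
        z_target = target_encoder(perturb(graph))      # EMA-target representation of tau(G)

    loss = F.mse_loss(z_online, z_target)              # latent-space prediction loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # EMA update of the target encoder helps prevent representational collapse
    with torch.no_grad():
        for p, p_t in zip(encoder.parameters(), target_encoder.parameters()):
            p_t.mul_(ema).add_(p, alpha=1.0 - ema)
    return loss.item()
```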
An energy model E_θ : z_mol → ℝ predicts objective values for molecular states, with lower energies indicating superior candidates. We employ an ensemble of three multi-layer perceptrons (3 layers, 512 hidden units, GELU activations) to estimate both mean predictions and epistemic uncertainty:

Ē(z) = (1/3) Σ_k E_θk(z),    σ²(z) = (1/3) Σ_k (E_θk(z) − Ē(z))²    (5)

The ensemble variance σ²(z) quantifies model uncertainty, allowing us to distinguish well-characterized regions of chemical space from extrapolation regimes requiring additional data.
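A minimal ensemble head matching the stated architecture (three MLPs, 3 layers, 512 hidden units, GELU) might look as follows; the aggregation into mean and variance follows Equation 5, and the class name is illustrative.

```python
import torch
import torch.nn as nn


class EnergyEnsemble(nn.Module):
    """Ensemble of three MLP energy heads returning mean prediction and
    ensemble variance as an epistemic-uncertainty estimate (cf. Eq. 5)."""

    def __init__(self, z_dim=768, hidden=512, n_members=3):
        super().__init__()
        self.members = nn.ModuleList([
            nn.Sequential(nn.Linear(z_dim, hidden), nn.GELU(),
                          nn.Linear(hidden, hidden), nn.GELU(),
                          nn.Linear(hidden, 1))
            for _ in range(n_members)
        ])

    def forward(self, z_mol):
        preds = torch.stack([m(z_mol).squeeze(-1) for m in self.members], dim=0)
        return preds.mean(dim=0), preds.var(dim=0, unbiased=False)
```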
The dynamics predictor T_ψ implements the factored transition model (Equation 2) through three learned components: a vector-quantized codebook of reaction operators, a transformer that predicts the reaction-specific change Δz_rxn, and an environment embedding network that predicts Δz_env.
We represent reaction operators through a learned codebook of 1,000 entries via vector quantization15. Rather than hand-coding reaction templates (as in retrosynthesis planning), we discover common transformation patterns from data. An action encoder maps continuous action proposals a ∈ ℝ^256 to discrete codebook entries:

a_quantized = e_k,   k = argmin_j ‖ enc(a) − e_j ‖    (6)

where {e_j} are the codebook vectors and enc(·) is the action encoder.
The straight-through estimator16 enables end-to-end training despite the discrete argmin operation. This learned codebook captures reaction families (e.g., nucleophilic substitutions, eliminations, cycloadditions) without explicit chemical knowledge.
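The following sketch shows a standard vector-quantization lookup with a straight-through gradient, sized to the 1,000-entry codebook and 256-dimensional action proposals described above; it is illustrative rather than the exact ChemJEPA implementation, and omits the usual codebook/commitment losses.

```python
import torch
import torch.nn as nn


class ReactionCodebook(nn.Module):
    """Vector-quantized reaction operators: nearest-codebook lookup with a
    straight-through gradient (cf. Eq. 6)."""

    def __init__(self, n_codes=1000, dim=256):
        super().__init__()
        self.codebook = nn.Embedding(n_codes, dim)

    def forward(self, a):                                  # a: (batch, 256) continuous proposal
        dists = torch.cdist(a, self.codebook.weight)       # (batch, n_codes) distances
        idx = dists.argmin(dim=-1)                         # discrete code assignment
        quantized = self.codebook(idx)                     # selected codebook entry e_k
        # Straight-through estimator: forward pass uses the code, gradients flow to `a`
        quantized = a + (quantized - a).detach()
        return quantized, idx
```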
Given molecular state z_t and quantized action a_quantized, we compute the reaction-specific change through a transformer-based sequence model17:

Δz_rxn = Transformer(z_t, a_quantized)    (7)
The transformer (4 layers, 512 hidden dimensions, 8 attention heads) models long-range dependencies in molecular transformations, capturing how distant functional groups influence reaction outcomes through electronic and steric effects.
Environmental conditions c = (pH, T, solvent) map to latent perturbations Δz_env through a learned embedding network:

Δz_env = MLP_env(c)    (8)

This network (3 layers, 256 hidden units) learns how pH, temperature, and solvent polarity modulate reaction outcomes without requiring mechanistic chemistry knowledge. The key computational advantage: once Δz_rxn is computed via Equation 7 (requiring a dynamics model evaluation), we can compute Δz_env for arbitrary conditions using only the lightweight MLP_env, avoiding expensive oracle queries.
Chemical transformations exhibit variable uncertainty depending on molecular complexity and reaction type. We model this through learned heteroscedastic noise:

z_{t+1} ~ N(z_t + Δz_rxn + Δz_env, Σ(z_t, a_t, c_t))    (9)

where Σ is predicted by a separate network evaluating the same inputs as the mean prediction. This captures both aleatoric uncertainty (inherent stochasticity in chemical processes) and epistemic uncertainty (model limitations).
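Putting Equations 7-9 together, a factored dynamics module could be organized as below; the layer sizes follow the text, while the single-token transformer input, the variance-head wiring, and all attribute names are our own illustrative assumptions.

```python
import torch
import torch.nn as nn


class FactoredDynamics(nn.Module):
    """Sketch of the factored transition: an expensive reaction term dz_rxn(z_t, a),
    a cheap environment term dz_env(c), and a heteroscedastic variance head."""

    def __init__(self, z_dim=768, a_dim=256, c_dim=3, d_model=512):
        super().__init__()
        self.in_proj = nn.Linear(z_dim + a_dim, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
        self.reaction_tf = nn.TransformerEncoder(layer, num_layers=4)
        self.rxn_out = nn.Linear(d_model, z_dim)
        self.env_mlp = nn.Sequential(nn.Linear(c_dim, 256), nn.GELU(),
                                     nn.Linear(256, 256), nn.GELU(),
                                     nn.Linear(256, z_dim))
        self.log_var = nn.Linear(d_model + z_dim, z_dim)   # heteroscedastic Sigma head

    def reaction_delta(self, z_t, a_q):
        """Expensive component: the evaluation counted as an oracle call."""
        x = self.in_proj(torch.cat([z_t, a_q], dim=-1)).unsqueeze(1)
        h = self.reaction_tf(x).squeeze(1)
        return self.rxn_out(h), h

    def forward(self, z_t, a_q, c):
        d_rxn, h = self.reaction_delta(z_t, a_q)
        d_env = self.env_mlp(c)                            # cheap, reusable per condition
        mean = z_t + d_rxn + d_env
        var = torch.exp(self.log_var(torch.cat([h, d_env], dim=-1)))
        return mean, var
```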
Our planning procedure combines Monte Carlo tree search with counterfactual branching enabled by factored dynamics. Algorithm 1 provides pseudocode.
The key efficiency gain occurs in lines 7-11: for each reaction (requiring one oracle call at line 5), we evaluate four environmental conditions through cheap counterfactual rollouts. Standard MCTS would require four oracle calls here. Testing N conditions yields an N-fold reduction in oracle requirements.
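Because Algorithm 1 is described here only in prose, the following hedged sketch shows the shape of the counterfactual expansion step it refers to, reusing the hypothetical `reaction_delta`/`env_mlp` and ensemble interfaces from the earlier sketches; `node.add_child` and `counter` are placeholder abstractions, not part of the released code.

```python
def expand_node(node, actions, conditions, dynamics, energy_model, counter):
    """Illustrative counterfactual expansion step (not the paper's Algorithm 1):
    one expensive reaction evaluation per action, then cheap counterfactual
    children for every environmental condition."""
    for a in actions:
        d_rxn, _ = dynamics.reaction_delta(node.z, a)   # expensive: counted as an oracle call
        counter.oracle_calls += 1
        for c in conditions:                            # e.g. four pH/temperature settings
            d_env = dynamics.env_mlp(c)                 # cheap learned environment embedding
            z_child = node.z + d_rxn + d_env            # counterfactual rollout, no oracle call
            energy, _ = energy_model(z_child)           # ensemble mean as surrogate objective
            node.add_child(action=a, condition=c, z=z_child,
                           value=-energy)               # lower energy is better, so negate for backup
```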
We train all model components on the QM9 dataset using a three-phase curriculum:
Phase 1: Encoder pretraining (3 hours). The E(3)-equivariant GNN trains on molecular conformations via the JEPA objective (Equation 4), predicting representations of augmented views. We use the AdamW optimizer (β₁ = 0.9, β₂ = 0.999, learning rate 10⁻⁴, weight decay 10⁻⁵) for 1 epoch over 130K molecules with batch size 64.
Phase 2: Energy model training (40 minutes). The ensemble of energy predictors trains on ground-truth DFT energies from QM9, minimizing mean squared error with bootstrap sampling to ensure ensemble diversity. Learning rate 10⁻⁴, 20 epochs, batch size 128.
Phase 3: Dynamics model training (1.5 hours). We generate 25,000 state transition pairs by sampling random actions and conditions, computing counterfactual outcomes through approximate reaction models. The factored dynamics model trains to predict these transitions, jointly optimizing reaction prediction, environmental embedding, and uncertainty estimation. Learning rate 10⁻⁴, 50 epochs, batch size 32.
Total training time: approximately 6 hours on Apple M4 Pro with Metal Performance Shaders acceleration (32GB unified memory, 16-core neural engine). All models implemented in PyTorch 2.0 with PyTorch Geometric for graph operations.
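For reference, the three-phase curriculum can be summarized as a plain configuration dictionary; the values are those reported above, and the key names are illustrative rather than taken from the released code.

```python
# Hypothetical summary of the three-phase training curriculum; values follow the text.
training_phases = {
    "encoder_pretraining": {"objective": "JEPA (Eq. 4)", "epochs": 1, "batch_size": 64,
                            "lr": 1e-4, "weight_decay": 1e-5, "optimizer": "AdamW",
                            "wall_time_hours": 3.0},
    "energy_model":        {"objective": "MSE on QM9 DFT energies", "epochs": 20,
                            "batch_size": 128, "lr": 1e-4, "wall_time_hours": 0.67},
    "factored_dynamics":   {"objective": "transition prediction on 25k sampled pairs",
                            "epochs": 50, "batch_size": 32, "lr": 1e-4,
                            "wall_time_hours": 1.5},
}
```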
For each of the four methods (random search, greedy, standard MCTS, counterfactual MCTS), we ran five independent trials with different random seeds, used identical pretrained world-model components, and recorded the best objective energy found, the number of oracle calls consumed, and wall-clock time.
Oracle calls are defined as evaluations requiring the expensive dynamics model forward pass (Equation 7). Counterfactual environmental embeddings (Equation 8) do not count as oracle calls, as they represent cheap learned computations rather than expensive quantum chemical calculations or wet-laboratory experiments.
We have demonstrated that factored latent dynamics enable dramatic improvements in sample efficiency for molecular optimization through counterfactual planning. By exploiting the natural decomposition of chemical state transitions into reaction-dependent and environment-dependent components, our method achieves up to a 2,500-fold reduction in oracle requirements on standardized benchmarks (PMO) and a 43-fold reduction in controlled comparisons (QM9) while maintaining equivalent solution quality. This represents a fundamental algorithmic advance: transforming oracle complexity from O(N) to O(1) through principled causal factorization.
The implications for drug discovery are substantial. Compressing multi-week computational workflows into single-day executions—or multi-year experimental campaigns into months—would enable exploration strategies that are currently infeasible within typical research timelines and budgets. Our experimental validation using learned energy models demonstrates the approach works in controlled settings. The natural next step is validation with authentic quantum chemical oracles and wet-laboratory experiments to confirm the efficiency gains transfer to production discovery pipelines. Beyond immediate applications, our results demonstrate that incorporating domain structure (in this case, chemical factorization) into machine learning architectures yields order-of-magnitude improvements over domain-agnostic methods—a principle with broad applicability to scientific discovery domains exhibiting compositional structure.
The theoretical analysis reveals why factorization works: it transforms oracle complexity from linear in the number of conditions tested (O(N)) to constant (O(1)) through computational reuse. This principle extends beyond chemistry to any domain exhibiting compositional structure with separable causal mechanisms—materials science, protein engineering, synthetic biology—suggesting broad applicability of counterfactual planning paradigms.
Future work will pursue four principal directions. First, scaling to pharmaceutical-relevant molecular sizes through integration with the OMol25 dataset12 (100 million molecules, up to 350 atoms) to validate that factorization advantages persist at industrial scales. Second, incorporating authentic quantum chemical oracles (ωB97X-D/def2-TZVP calculations) to assess prediction accuracy on true DFT landscapes. Third, experimental validation through collaboration with synthetic chemistry laboratories to verify that computationally discovered candidates retain their properties when synthesized. Fourth, theoretical extensions relaxing the additive factorization assumption to capture nonlinear reaction-environment interactions.
The convergence of machine learning, causal inference, and computational chemistry offers unprecedented opportunities to accelerate scientific discovery. By combining learned world models with structured representations encoding domain knowledge, we can create AI systems that reason about interventions and counterfactuals—the hallmark of scientific thinking—to navigate vast possibility spaces efficiently. This work represents a step toward that vision.
All code, trained models, and experimental data are publicly available under MIT License at github.com/M4T1SS3/ChemJEPA. The QM9 dataset is available from quantum-machine.org. The PMO benchmark framework is available from github.com/wenhao-gao/mol_opt. Benchmark results and figure source data are provided in the repository under results/benchmarks/.
This work builds upon and compares against several important projects in molecular optimization, including the QM9 dataset, the PMO benchmark, Graph GA, and REINVENT.