Training/config refactor plan

Scope

Keep changes tightly limited to the canonical model/training package boundary (dfode_kit/models, dfode_kit/training) and the train CLI surface needed to select registered components. Do not redesign data loading, labeling, or DeepFlame integration in this slice.

Current audit

  • dfode_kit/training/train.py hard-codes:
    • model architecture (MLP([2+n_species, 400, 400, 400, 400, n_species-1]))
    • optimizer (Adam)
    • loop hyperparameters (max_epochs, LR schedule, batch size)
    • physics loss composition
  • The top-level train() function mixes config construction, data prep, model creation, optimizer setup, and the epoch loop.
  • Adding a new architecture or training algorithm currently requires editing the monolithic train module.

Minimal target design

  1. Typed training config
    • Add small dataclass-based config objects for:
      • model selection + kwargs
      • optimizer hyperparameters
      • trainer selection + loop hyperparameters
    • Keep a default config equivalent to today’s behavior.
  2. Model registry
    • Add a simple in-process registry keyed by string name.
    • Default-register mlp.
    • Construction contract: factory(model_config, *, n_species, device).
  3. Trainer registry hook
    • Add the same registry pattern for trainers, but keep only one default trainer in this slice.
    • This creates the extension seam without rewriting the full training algorithm ecosystem.
  4. Refactor train entrypoint
    • Preserve the public train(mech_path, source_file, output_path, time_step=...) signature.
    • Add an optional config parameter and route model/trainer creation through the registries.
    • Move loop internals into a trainer implementation function/class.
  5. Tests for the lightweight harness
    • Add pure-Python tests for config defaults/overrides and registry behavior.
    • Avoid torch/cantera dependencies in harness tests.
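The config and registry pieces above can be sketched as follows. This is a minimal illustration, not the final dfode_kit API: class names, field names, and default values (e.g. the learning rate and hidden widths) are assumptions; only the MLP layer shape and the factory(model_config, *, n_species, device) contract come from the plan itself.

```python
# Sketch of the typed config + model registry seam (names are assumptions).
from dataclasses import dataclass, field
from typing import Any, Callable, Dict

@dataclass
class ModelConfig:
    name: str = "mlp"
    kwargs: Dict[str, Any] = field(default_factory=dict)

@dataclass
class OptimizerConfig:
    name: str = "adam"          # current hard-coded choice becomes a default
    lr: float = 1e-3            # placeholder value, not the real default

@dataclass
class TrainerConfig:
    name: str = "default"
    max_epochs: int = 500       # placeholder value
    batch_size: int = 1024      # placeholder value

@dataclass
class TrainConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)

# Simple in-process registry keyed by string name.
_MODEL_REGISTRY: Dict[str, Callable] = {}

def register_model(name: str):
    def decorator(factory: Callable):
        _MODEL_REGISTRY[name] = factory
        return factory
    return decorator

def get_model_factory(name: str) -> Callable:
    try:
        return _MODEL_REGISTRY[name]
    except KeyError:
        raise KeyError(
            f"unknown model {name!r}; registered: {sorted(_MODEL_REGISTRY)}"
        )

# Construction contract from the plan: factory(model_config, *, n_species, device).
@register_model("mlp")
def build_mlp(config: ModelConfig, *, n_species: int, device: str):
    hidden = config.kwargs.get("hidden", [400, 400, 400, 400])
    layers = [2 + n_species, *hidden, n_species - 1]
    # Real code would return MLP(layers).to(device); the layer list is
    # returned here so the sketch stays torch-free.
    return layers
```

A default-constructed TrainConfig is meant to reproduce today's behavior exactly, so existing callers see no change.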
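The entrypoint refactor could then take this shape: the public signature is preserved and an optional config routes model/trainer creation through the registries. The config class here is a self-contained stand-in, and the commented-out calls are hypothetical; the time_step default is intentionally left as ... because the plan elides it.

```python
# Sketch of the refactored train() entrypoint (hypothetical helper names).
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainConfig:              # stand-in for the full typed config
    model_name: str = "mlp"
    trainer_name: str = "default"

def train(mech_path, source_file, output_path, time_step=...,
          config: Optional[TrainConfig] = None):
    cfg = config or TrainConfig()   # defaults reproduce today's behavior
    # model = get_model_factory(cfg.model_name)(cfg, n_species=..., device=...)
    # trainer = get_trainer(cfg.trainer_name)
    # return trainer.fit(model, data, cfg)   <- loop internals live in the trainer
    return cfg.model_name, cfg.trainer_name  # placeholder for illustration
```

Existing callers pass no config and get identical behavior; new callers opt in by constructing a TrainConfig.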

First implementation slice in this branch

  • Add dataclass config module.
  • Add model + trainer registries.
  • Register current MLP architecture behind the registry.
  • Refactor train.py to use default config + registries while keeping behavior unchanged.
  • Add lightweight tests for config/registry contracts.

Follow-up slices

  • CLI flags or config-file loading (--model, --trainer, --config).
  • Multiple trainer implementations (baseline supervised vs physics-informed variants).
  • Synthetic smoke training test path that runs in the lightweight harness.
  • Dataset schema validation before training starts.
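For the CLI follow-up, one possible shape of the flag surface is sketched below with argparse. The program name, defaults, and help strings are assumptions; only the --model, --trainer, and --config flag names come from the plan.

```python
# Sketch of the follow-up CLI flags (names beyond the three flags are assumed).
import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="dfode-train")
    parser.add_argument("--model", default="mlp",
                        help="registered model name")
    parser.add_argument("--trainer", default="default",
                        help="registered trainer name")
    parser.add_argument("--config", default=None,
                        help="optional config file overriding defaults")
    return parser
```

Defaults mirror the default TrainConfig, so running with no flags stays equivalent to today's behavior, e.g. build_parser().parse_args([]) yields model "mlp" and trainer "default".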