# Training reference
This chapter explains how the training loop in GAN-Engine is structured and which switches you can flip to adapt it to your hardware and modality.
## Workflow overview
- Config parsing – YAML files are loaded and validated.
- Module creation – Generator, discriminator, losses, and normalisation pipelines are instantiated.
- Trainer setup – The PyTorch Lightning `Trainer` is configured with devices, precision, callbacks, and logging.
- Training loop – Generator-only pretraining (optional), followed by adversarial updates with warm-ups and schedules.
- Checkpointing & logging – Images, metrics, and model weights are saved throughout the run.
When `Project.task` is set to `inpainting`, `conditional_gan`, or `text_to_image`, the workflow automatically adds mask/prompt loaders and task-specific validation hooks.
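As an illustration, here is a minimal sketch of the task switch in YAML. Only `Project.task` and `Project.output_dir` are referenced in this chapter, so the value spellings below are assumptions rather than the authoritative schema:

```yaml
# Hypothetical sketch – only Project.task and Project.output_dir are documented
# in this chapter; the values shown are illustrative.
Project:
  task: inpainting                    # or conditional_gan / text_to_image
  output_dir: runs/inpainting_baseline
```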
## Command-line arguments
`python -m gan_engine.train` accepts only the YAML path as input; all of the important settings are handled from there.
| Flag | Description |
|---|---|
| `config PATH` | Path to the YAML configuration file. |
| `resume` | Resume from the latest checkpoint in `Project.output_dir`. |
| `checkpoint PATH` | Start training from a specific checkpoint. |
| `devices N` | Number of GPUs, or `cpu`. |
| `accelerator` | Lightning accelerator (`gpu`, `cpu`, `mps`). |
| `strategy` | Parallelisation strategy (`ddp`, `fsdp`, `deepspeed`, etc.). |
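For example, a two-GPU run might look like `python -m gan_engine.train --config configs/superres.yaml --devices 2 --accelerator gpu` (illustrative only: the config path is made up, and the exact flag spelling, including any leading `--`, depends on the argument parser).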
## Optimisers & schedulers
- Default optimisers – Adam for both networks with config-defined learning rates/betas.
- Warm-ups – Configure linear or cosine warm-up per optimiser (`Schedulers.*.warmup_steps`).
- Plateau schedulers – Reduce LR on validation plateau with patience and factor controls.
- Cosine annealing / OneCycle – Enable via `Schedulers.generator.name` or `Schedulers.discriminator.name` (see the sketch below).
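A hedged sketch of the scheduler block: only `Schedulers.*.warmup_steps` and `Schedulers.<generator|discriminator>.name` appear in this chapter, so the scheduler name spellings and values are assumptions:

```yaml
# Sketch only – scheduler name spellings and values are assumed, not authoritative.
Schedulers:
  generator:
    name: cosine          # cosine annealing for the generator (spelling assumed)
    warmup_steps: 1000    # linear or cosine warm-up before the main schedule
  discriminator:
    name: plateau         # reduce-on-plateau with patience/factor controls (spelling assumed)
    warmup_steps: 500
```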
## Gradient clipping & penalties
- `Training.gradient_clip_val` – Clip the global norm of gradients.
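In YAML this is a single key under `Training`; the value below is only an example:

```yaml
Training:
  gradient_clip_val: 1.0   # clip the global gradient norm at 1.0 (value illustrative)
```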
## Logging
The module logs via the configured logger:
- Scalars – Individual loss terms, learning rates, gradient norms.
- Images – LR/HR/SR triplets, masked reconstructions, or unconditional samples depending on the task.
- Histograms – Output distributions and discriminator logits.
- System info – GPU utilisation, memory footprint.
Adjust logging cadence through `Logging.log_every_n_steps` and `Logging.log_images_every_n_steps`.
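A sketch of the two cadence keys, with illustrative values:

```yaml
Logging:
  log_every_n_steps: 50           # scalars, histograms, system info
  log_images_every_n_steps: 500   # image panels (LR/HR/SR triplets, masks, samples)
```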
## Checkpoints
- Top-K checkpoints – Controlled by `Callbacks.checkpoint` (metric, mode, count); see the sketch below.
- EMA checkpoints – Saved alongside raw weights when EMA is enabled.
- Periodic snapshots – Add `Callbacks.periodic` with `every_n_steps` for archival.
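A hedged sketch of the callback block: the `metric`/`mode`/`count` and `every_n_steps` fields follow the wording above, but the metric name and values are assumptions:

```yaml
# Sketch – metric name and values are illustrative.
Callbacks:
  checkpoint:
    metric: val/psnr      # validation metric used to rank checkpoints (name assumed)
    mode: max             # keep the best-scoring checkpoints
    count: 3              # number of top-K checkpoints to retain
  periodic:
    every_n_steps: 10000  # archival snapshot cadence
```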
## Validation & testing
- Validation runs automatically according to `Trainer.check_val_every_n_epoch` (see the sketch below).
- You can trigger additional evaluation loops by enabling `Callbacks.extra_validation` with custom hooks.
- For final testing, reuse the inference CLI or add a custom evaluation callback to compute metrics on the test split (a dedicated validation command is on the roadmap).
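A sketch of the validation-related keys; whether `Callbacks.extra_validation` takes a boolean or a list of hooks is an assumption:

```yaml
Trainer:
  check_val_every_n_epoch: 1   # run validation after every epoch
Callbacks:
  extra_validation: true       # enable additional evaluation loops (hook wiring assumed)
```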
## Multi-node training
- Configure `strategy: ddp` (or `fsdp` for large models) and set `num_nodes`/`devices` accordingly (see the sketch below).
- Ensure shared storage for checkpoints and dataset access.
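A sketch of a two-node run; the chapter names `strategy`, `num_nodes`, and `devices`, but their exact placement in the YAML (under `Trainer` here) is an assumption:

```yaml
Trainer:
  strategy: ddp    # or fsdp for very large models
  num_nodes: 2     # machines participating in the run
  devices: 8       # GPUs per node
```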
The training stack is designed to be transparent: every stabilisation trick is optional and controlled through configuration so you can tailor runs to healthcare, geospatial, microscopy, or consumer imaging workloads.