Trainer details¶
This page dives into the PyTorch Lightning trainer configuration used in GAN-Engine, explaining how key hooks are implemented and how to customise them for specialised workloads.
Automatic vs. manual optimisation¶
The Lightning module detects the installed Lightning version and toggles between automatic and manual optimisation modes:
- Lightning ≥ 2.0 – Uses automatic optimisation; optimisers are returned from configure_optimizers and Lightning handles stepping.
- Lightning 1.x – Switches to manual optimisation to retain fine-grained control over generator/discriminator updates.
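A minimal sketch of how such a toggle can be expressed in a LightningModule constructor is shown below; the GANModule name and the exact version check are illustrative rather than GAN-Engine's actual implementation:

import pytorch_lightning as pl
from packaging import version


class GANModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # On Lightning 1.x, take over optimisation manually so generator and
        # discriminator steps can be ordered explicitly; 2.x keeps the default
        # automatic behaviour.
        if version.parse(pl.__version__) < version.parse("2.0.0"):
            self.automatic_optimization = False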
Training step flow¶
- Fetch batch containing LR/HR tensors and metadata (plus masks/prompts when tasks require them).
- Normalise using the configured normaliser (per-branch logic allowed). Conditioning tensors are normalised via their own pipelines.
- Generator forward to produce SR output (or reconstructed/synthesised imagery for other tasks).
- Compute reconstruction losses (L1, SSIM, etc.).
- If adversarial training is active:
  - Update the discriminator according to its cadence (d_update_interval).
  - Compute adversarial and feature-matching losses.
- Apply EMA update after generator step if enabled.
- Log scalar metrics, images, optional histograms, and conditioning artefacts (masks, prompts, embeddings).
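The sketch below walks through this flow as a manual-optimisation training_step. It is illustrative only: attribute names such as normaliser, generator, discriminator, adv_weight and the simple non-saturating losses stand in for whatever GAN-Engine actually configures.

import torch.nn.functional as F
import pytorch_lightning as pl


class GANModule(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        opt_g, opt_d = self.optimizers()          # manual-optimisation variant
        lr = self.normaliser(batch["lr"])         # normalise inputs
        hr = self.normaliser(batch["hr"])

        sr = self.generator(lr)                   # generator forward
        loss_rec = F.l1_loss(sr, hr)              # reconstruction loss

        # Discriminator update on its own cadence (d_update_interval).
        if batch_idx % self.d_update_interval == 0:
            loss_d = (F.softplus(-self.discriminator(hr))
                      + F.softplus(self.discriminator(sr.detach()))).mean()
            opt_d.zero_grad()
            self.manual_backward(loss_d)
            opt_d.step()
            self.log("train/loss_d", loss_d)

        # Generator update: reconstruction plus adversarial term.
        loss_adv = F.softplus(-self.discriminator(sr)).mean()
        loss_g = loss_rec + self.adv_weight * loss_adv
        opt_g.zero_grad()
        self.manual_backward(loss_g)
        opt_g.step()
        # An EMA update of the generator weights would follow here when enabled.

        self.log_dict({"train/loss_rec": loss_rec, "train/loss_g": loss_g})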
Callback suite¶
Checkpointing¶
ModelCheckpoint stores top-k checkpoints and monitors user-defined metrics. EMA checkpoints are saved automatically alongside
the raw generator weights.
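As a point of reference, the Lightning-level callback behind such a setup typically looks like the following; the monitored metric name is an assumption:

from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the three best checkpoints ranked by a validation metric; "val/psnr"
# is a placeholder for whatever metric the module actually logs.
checkpoint_cb = ModelCheckpoint(monitor="val/psnr", mode="max", save_top_k=3)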
Learning-rate monitoring¶
LearningRateMonitor records LR schedules for both optimisers. Enable it via Callbacks.lr_monitor: true.
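The config flag presumably maps onto Lightning's standard callback, roughly as follows; the logging interval is an assumption:

from pytorch_lightning.callbacks import LearningRateMonitor

# Record the current learning rate of every optimiser/scheduler each step.
lr_cb = LearningRateMonitor(logging_interval="step")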
EMA swapper¶
EMACallback swaps EMA weights in before validation/prediction and restores raw weights afterwards. This ensures validation uses
the smoothed generator.
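A simplified sketch of what such a swap can look like as a Lightning callback is shown below; the generator attribute and the get_ema_state accessor are assumptions, not GAN-Engine's actual EMACallback:

import copy
import pytorch_lightning as pl


class EMASwap(pl.Callback):
    def __init__(self, get_ema_state):
        # get_ema_state is a hypothetical callable returning the EMA weights.
        self.get_ema_state = get_ema_state
        self._backup = None

    def on_validation_start(self, trainer, pl_module):
        # Validate with the smoothed generator weights.
        self._backup = copy.deepcopy(pl_module.generator.state_dict())
        pl_module.generator.load_state_dict(self.get_ema_state())

    def on_validation_end(self, trainer, pl_module):
        # Restore the raw weights so training continues unaffected.
        pl_module.generator.load_state_dict(self._backup)
        self._backup = None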
Extra validation hooks¶
Add custom validation logic (e.g. running domain-specific metrics) by registering callbacks under Callbacks.extra_validation.
Each callback receives the Lightning module and current outputs.
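A hypothetical extra-validation callback might look like the sketch below; the (pl_module, outputs) signature follows the description above, and the output keys are assumptions:

import torch


def edge_fidelity(pl_module, outputs):
    # Domain-specific example: compare horizontal gradients of SR and HR
    # images and log the mean absolute difference.
    sr, hr = outputs["sr"], outputs["hr"]
    grad_diff = (torch.diff(sr, dim=-1) - torch.diff(hr, dim=-1)).abs().mean()
    pl_module.log("val/edge_l1", grad_diff)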
Mixed precision considerations¶
Lightning handles autocast and gradient scaling when Training.precision is 16 or bf16. If you introduce custom CUDA ops, ensure
they support the selected precision or guard them with torch.cuda.amp.autocast(enabled=...).
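One way to apply that guard is a small helper that runs a precision-sensitive op outside autocast; this is a generic sketch, not a GAN-Engine utility:

import torch


def run_in_fp32(op, *tensors):
    # Execute op outside autocast in full precision; useful for custom CUDA
    # kernels that have no fp16/bf16 implementation.
    with torch.cuda.amp.autocast(enabled=False):
        return op(*[t.float() for t in tensors])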
Distributed strategies¶
- DDP – Default for multi-GPU runs. Ensure find_unused_parameters is set appropriately for custom modules.
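At the Lightning level this corresponds to passing a DDPStrategy; how GAN-Engine surfaces the flag in its config is not shown here, so treat the wiring below as illustrative:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Set find_unused_parameters=True when some submodules (e.g. an optional
# conditioning branch) may be skipped in a step; keep it False otherwise to
# avoid the extra synchronisation cost.
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DDPStrategy(find_unused_parameters=True),
)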
Logging integrations¶
- Weights & Biases – Configured via
Logging.logger: wandb. The trainer logs metrics, media, and configuration snapshots. - CSV – Minimal logging for constrained environments.
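In plain Lightning terms the two options correspond roughly to the following loggers; the project and directory names are placeholders:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, WandbLogger

logger = WandbLogger(project="gan-engine")   # Logging.logger: wandb
# logger = CSVLogger(save_dir="logs")        # Logging.logger: csv
trainer = Trainer(logger=logger)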
Early stopping & alerts¶
Add Callbacks.early_stopping with monitor, mode, and patience fields in the config file. Combine it with gradient monitors or custom alert
hooks to notify operators when training stalls.
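Those fields map onto Lightning's EarlyStopping callback, roughly as below; the metric name is illustrative:

from pytorch_lightning.callbacks import EarlyStopping

# Stop when the monitored metric has not improved for `patience` validations.
early_stop_cb = EarlyStopping(monitor="val/lpips", mode="min", patience=10)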
Custom trainer arguments¶
Provide additional Trainer kwargs through the config:
Trainer:
  enable_progress_bar: true
  log_every_n_steps: 50
  gradient_clip_algorithm: norm
  detect_anomaly: false
Any argument supported by Lightning's Trainer can be set this way, enabling precise tuning for HPC clusters or research laptops.
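One plausible way such a section can be forwarded to Lightning is sketched below; the loading code is an assumption about the plumbing, not GAN-Engine's actual implementation:

import yaml
from pytorch_lightning import Trainer

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

# Unpack the Trainer section of the config directly into Lightning's Trainer.
trainer = Trainer(**cfg.get("Trainer", {}))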
Understanding these trainer details helps you reason about behaviour during long training runs and customise the pipeline for new domains without editing the core codebase.