Trainer details

This page dives into the PyTorch Lightning trainer configuration used in GAN-Engine, explaining how key hooks are implemented and how to customise them for specialised workloads.

Automatic vs. manual optimisation

The Lightning module detects the installed Lightning version and toggles between automatic and manual optimisation modes:

  • Lightning ≥ 2.0 – Switches to manual optimisation to retain fine-grained control over generator/discriminator updates; Lightning 2.0 no longer supports multiple optimisers under automatic optimisation.
  • Lightning 1.x – Uses automatic optimisation; optimisers are returned from configure_optimizers and Lightning handles stepping.
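
A minimal sketch of the version switch, assuming the module inspects the version string exposed by the Lightning package. The helper names are hypothetical and not taken from GAN-Engine. The one fact relied on is that Lightning 2.0 dropped support for multiple optimisers under automatic optimisation, so a two-optimiser GAN module has to step its optimisers manually there.

```python
def lightning_major(version: str) -> int:
    """Extract the major version from a string such as '2.1.3' or '1.9.5'."""
    return int(version.split(".")[0])


def use_manual_optimization(version: str) -> bool:
    """Two-optimiser GAN modules need manual optimisation on Lightning 2.x."""
    return lightning_major(version) >= 2


# Inside a LightningModule.__init__ this could be wired up as (assumption):
#   import lightning.pytorch as pl
#   self.automatic_optimization = not use_manual_optimization(pl.__version__)
manual_on_2x = use_manual_optimization("2.1.3")
manual_on_1x = use_manual_optimization("1.9.5")
```

Setting automatic_optimization to False is what tells Lightning to leave optimiser stepping to the training step itself.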

Training step flow

  1. Fetch batch containing LR/HR tensors and metadata (plus masks/prompts when tasks require them).
  2. Normalise using the configured normaliser (per-branch logic allowed). Conditioning tensors are normalised via their own pipelines.
  3. Generator forward to produce SR output (or reconstructed/synthesised imagery for other tasks).
  4. Compute reconstruction losses (L1, SSIM, etc.).
  5. If the adversarial loss is active:
       • Update the discriminator according to its cadence (d_update_interval).
       • Compute adversarial and feature-matching losses.
  6. Apply the EMA update after the generator step, if enabled.
  7. Log scalar metrics, images, optional histograms, and conditioning artefacts (masks, prompts, embeddings).
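
The control flow above can be sketched as a plain-Python schedule that lists which phases run on a given step. This is an illustration only: the function names and phase labels are hypothetical, and it assumes d_update_interval means "update the discriminator every N-th step", which the source does not spell out.

```python
def discriminator_due(global_step: int, d_update_interval: int) -> bool:
    """True on steps where the discriminator should be updated (assumed cadence)."""
    if d_update_interval < 1:
        raise ValueError("d_update_interval must be >= 1")
    return global_step % d_update_interval == 0


def plan_step(global_step, adversarial_active, d_update_interval, ema_enabled):
    """Return the ordered phases for one training step."""
    phases = ["fetch_batch", "normalise", "generator_forward", "reconstruction_loss"]
    if adversarial_active:
        if discriminator_due(global_step, d_update_interval):
            phases.append("discriminator_update")
        phases.append("adversarial_and_feature_matching_loss")
    phases.append("generator_update")
    if ema_enabled:
        # EMA tracks the generator, so it runs only after the generator step.
        phases.append("ema_update")
    phases.append("log")
    return phases


step0 = plan_step(0, adversarial_active=True, d_update_interval=2, ema_enabled=True)
step1 = plan_step(1, adversarial_active=True, d_update_interval=2, ema_enabled=True)
```

With an interval of 2, step0 contains a discriminator update and step1 does not; the generator is updated on both.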

Callback suite

Checkpointing

ModelCheckpoint stores top-k checkpoints and monitors user-defined metrics. EMA checkpoints are saved automatically alongside the raw generator weights.

Learning-rate monitoring

LearningRateMonitor records LR schedules for both optimisers. Enable it via Callbacks.lr_monitor: true.

EMA swapper

EMACallback swaps EMA weights in before validation/prediction and restores raw weights afterwards. This ensures validation uses the smoothed generator.
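
The swap-and-restore pattern can be illustrated without any framework. The sketch below uses a dict of floats in place of a model state dict; the class and method names are hypothetical and do not mirror EMACallback's actual interface.

```python
from contextlib import contextmanager


class EMAWeights:
    """Keeps an exponential moving average of a set of named parameters."""

    def __init__(self, params: dict, decay: float = 0.999):
        self.decay = decay
        self.shadow = dict(params)  # smoothed copy, updated after each step

    def update(self, params: dict) -> None:
        for name, value in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * value

    @contextmanager
    def swapped_into(self, params: dict):
        """Temporarily overwrite params with the EMA values, then restore."""
        backup = dict(params)
        params.update(self.shadow)
        try:
            yield params
        finally:
            params.update(backup)  # raw weights return even if validation fails


params = {"w": 1.0}
ema = EMAWeights(params, decay=0.5)
params["w"] = 3.0          # pretend an optimiser step changed the weight
ema.update(params)         # shadow becomes 0.5 * 1.0 + 0.5 * 3.0
with ema.swapped_into(params):
    value_during_validation = params["w"]
value_after_validation = params["w"]
```

The try/finally is the important design choice: training always resumes from the raw weights, whatever happens during validation.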

Extra validation hooks

Add custom validation logic (e.g. running domain-specific metrics) by registering callbacks under Callbacks.extra_validation. Each callback receives the Lightning module and current outputs.

Mixed precision considerations

Lightning enables autocast when Training.precision is 16 or bf16, and adds gradient scaling for 16 (bf16 does not need it). If you introduce custom CUDA ops, ensure they support the selected precision or guard them with torch.cuda.amp.autocast(enabled=...).
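
A possible guard for an op that only supports float32 is sketched below. It uses torch.autocast, the device-agnostic counterpart of torch.cuda.amp.autocast, so it also runs on CPU; run_in_fp32 is a hypothetical helper and torch.matmul merely stands in for a custom op.

```python
import torch


def run_in_fp32(op, *tensors):
    """Call `op` with autocast disabled and all inputs cast to float32."""
    device_type = tensors[0].device.type  # "cuda" or "cpu"
    with torch.autocast(device_type=device_type, enabled=False):
        return op(*(t.float() for t in tensors))


x = torch.ones(2, 2, dtype=torch.bfloat16)
y = run_in_fp32(torch.matmul, x, x)  # computed and returned in float32
```

Casting the inputs matters as much as disabling autocast: tensors produced earlier inside an autocast region may already be in reduced precision.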

Distributed strategies

  • DDP – Default for multi-GPU runs. Ensure find_unused_parameters is set appropriately for custom modules.
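
One way to decide on the flag, assuming you can do a single-process dry run first: perform one forward and backward pass, then list the trainable parameters whose gradient is still None. The helper name is hypothetical; it relies only on named_parameters(), requires_grad and grad, which every torch.nn.Module parameter exposes.

```python
def unused_parameter_names(module) -> list:
    """Names of trainable parameters that received no gradient in the last backward pass."""
    return [
        name
        for name, param in module.named_parameters()
        if param.requires_grad and param.grad is None
    ]


# Typical use after one dry-run step (assumption about your setup):
#   loss.backward()
#   if unused_parameter_names(model):
#       from lightning.pytorch.strategies import DDPStrategy
#       strategy = DDPStrategy(find_unused_parameters=True)
```

If the list is empty, leaving find_unused_parameters disabled avoids the extra graph traversal DDP would otherwise perform on every step.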

Logging integrations

  • Weights & Biases – Configured via Logging.logger: wandb. The trainer logs metrics, media, and configuration snapshots.
  • CSV – Minimal logging for constrained environments.

Early stopping & alerts

Add Callbacks.early_stopping with monitor, mode, and patience fields in the config file. Combine it with gradient monitors or custom alert hooks to notify operators when training stalls.
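
A sketch of how such a config section could be translated into arguments for Lightning's EarlyStopping callback, whose monitor, mode and patience parameters match the field names. The helper and the metric name val/psnr are hypothetical; the fallback values mirror EarlyStopping's own defaults.

```python
def early_stopping_kwargs(cfg: dict):
    """Return EarlyStopping keyword arguments, or None when the section is absent."""
    section = (cfg.get("Callbacks") or {}).get("early_stopping")
    if not section:
        return None
    mode = section.get("mode", "min")
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    return {
        "monitor": section["monitor"],  # required: the logged metric to watch
        "mode": mode,
        "patience": int(section.get("patience", 3)),
    }


cfg = {"Callbacks": {"early_stopping": {"monitor": "val/psnr", "mode": "max", "patience": 10}}}
kwargs = early_stopping_kwargs(cfg)
# from lightning.pytorch.callbacks import EarlyStopping
# callback = EarlyStopping(**kwargs)
```

Validating mode up front surfaces a typo at start-up rather than at the first validation epoch.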

Custom trainer arguments

Provide additional Trainer kwargs through the config:

Trainer:
  enable_progress_bar: true
  log_every_n_steps: 50
  gradient_clip_algorithm: norm
  detect_anomaly: false

Any argument supported by Lightning's Trainer can be set this way, enabling precise tuning for anything from HPC clusters to research laptops.
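
How the Trainer section might be forwarded is sketched below, assuming the loaded config is a plain dict. build_trainer_kwargs and the max_epochs default are hypothetical; the optional allowed set lets you reject misspelled keys before Lightning sees them.

```python
def build_trainer_kwargs(cfg: dict, defaults=None, allowed=None) -> dict:
    """Merge project defaults with the user's Trainer section; user values win."""
    kwargs = dict(defaults or {})
    kwargs.update(cfg.get("Trainer") or {})
    if allowed is not None:
        unknown = sorted(set(kwargs) - set(allowed))
        if unknown:
            raise TypeError(f"Unsupported Trainer arguments: {unknown}")
    return kwargs


cfg = {
    "Trainer": {
        "enable_progress_bar": True,
        "log_every_n_steps": 50,
        "gradient_clip_algorithm": "norm",
        "detect_anomaly": False,
    }
}
kwargs = build_trainer_kwargs(cfg, defaults={"max_epochs": 100})
# from lightning.pytorch import Trainer
# import inspect
# allowed = set(inspect.signature(Trainer.__init__).parameters) - {"self"}
# trainer = Trainer(**build_trainer_kwargs(cfg, allowed=allowed))
```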


Understanding these trainer details helps you reason about behaviour during long training runs and customise the pipeline for new domains without editing the core codebase.