Architecture overview

GAN-Engine is structured around a set of modular building blocks that can be remixed for any imaging domain. The project is implemented in PyTorch Lightning; every component is configurable through YAML and can be swapped without touching training loops.

Lightning module

gan_engine.model.SRGAN.SRGAN_model (and the upcoming task-specific Lightning modules) orchestrates the training, validation, and inference lifecycle:

  • Instantiates generators, discriminators, and loss functions via registry-driven factories.
  • Normalises inputs and denormalises outputs according to the dataset configuration.
  • Computes content, perceptual, and adversarial losses while managing warm-up schedules and ramps.
  • Logs metrics, images, and histograms to the configured logger (Weights & Biases or TensorBoard).
  • Exposes predict_step and tiling helpers for inference on large rasters, volumes, or long videos.
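The tiling helpers split oversized inputs into overlapping windows before running predict_step on each piece. A minimal sketch of the coordinate computation, using a hypothetical tile_coords helper (not the project's actual API), applied once per spatial axis:

```python
# Hypothetical tiling helper: split a 1-D extent into overlapping tiles so a
# large raster or volume can be processed piecewise. Apply once per axis.
def tile_coords(length: int, tile: int, overlap: int) -> list[tuple[int, int]]:
    """Return (start, end) spans of width `tile` covering [0, length),
    with consecutive tiles sharing `overlap` pixels; the final tile is
    clamped to the boundary so no pixels are missed."""
    step = tile - overlap
    starts = range(0, max(length - overlap, 1), step)
    return [(min(s, max(length - tile, 0)), min(s + tile, length)) for s in starts]
```

Overlap lets the caller discard border pixels of each tile at stitch time, hiding seam artefacts in the reconstructed output.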

Dual-optimiser setup

The module supports both Lightning's automatic optimisation (Lightning 2.x) and manual optimisation (Lightning 1.x) to maintain backwards compatibility. Each step updates the generator and discriminator separately while respecting the cadence controls (e.g. updating the discriminator only every n steps) and gradient-clipping thresholds specified in the config.
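The cadence rule can be pictured as a simple predicate over the global step; the helper below is an illustrative sketch, not the module's actual code:

```python
# Illustrative cadence predicate (not GAN-Engine's actual implementation):
# the generator is updated every step, while the discriminator is stepped
# only once every `disc_every_n` optimisation steps.
def plan_updates(total_steps: int, disc_every_n: int) -> list[tuple[bool, bool]]:
    """Return per-step (update_generator, update_discriminator) flags."""
    return [(True, step % disc_every_n == 0) for step in range(total_steps)]
```

With disc_every_n = 2, the discriminator trains on half the batches the generator sees, which is a common way to keep an over-strong discriminator from starving the generator of gradient signal.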

Generators

Generators live under gan_engine.model.generators. The registry includes:

  • SRResNet / residual blocks – Strong baselines for RGB and grayscale imagery.
  • RCAB / channel attention – Residual channel attention blocks for hyperspectral or medical volumes where channel interactions matter.
  • RRDB (ESRGAN) – Residual-in-residual blocks with dense connections and adjustable growth channels.
  • Large-kernel attention (LKA) – Convolution + attention hybrids for detail-rich microscopy or satellite data.
  • Stochastic generators – Latent-conditioned branches for perceptual diversity and hallucinated detail.
  • UNet inpainting heads (roadmap) – Mask-aware decoders that fuse context from surrounding pixels.
  • Text-conditioned decoders (roadmap) – CLIP/transformer driven adapters that inject prompt embeddings.

Custom generators can be registered via gan_engine.model.registry.register_generator. As long as they expose the same constructor and forward signature as the built-in generators, they can be configured through YAML like any built-in model.

Discriminators

Discriminators live in gan_engine.model.discriminator and share the same registry pattern:

  • Standard SRGAN discriminator – Convolutional classifier operating on whole images.
  • PatchGAN variants – Local adversaries ideal for texture-heavy microscopy or photographic enhancement.
  • ESRGAN discriminator – Deeper architecture with spectral-normalised layers and feature matching heads.
  • 3D-ready options – 3D convolutions for volumetric data (enable by setting Data.dimensions: 3d).
  • Mask-aware discriminators (roadmap) – Fuse mask channels to validate inpainting consistency.
  • Text/image fusion discriminators (roadmap) – Cross-attend to prompt embeddings for conditional synthesis.

Cadence and learning-rate scheduling can be tuned per discriminator via configuration keys.

Loss suite

gan_engine.model.losses provides a palette of loss functions that can be blended together:

  • Pixel/structural: L1, L2, SSIM, total variation.
  • Spectral/geometric: Spectral angle mapper, histogram matching penalties, gradient-domain losses.
  • Perceptual: VGG19, LPIPS, and custom feature extractors with channel-selection masks so you can target only the relevant bands.
  • Adversarial: BCE-with-logits, relativistic GAN, and feature-matching losses for discriminator stabilisation.
  • Prompt alignment (roadmap): CLIP/Text losses to align generated content with prompts or class labels.

Loss weights, warm-up phases, and perceptual channel masks are all defined in configuration files.
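The blending and warm-up behaviour amounts to a weighted sum whose weights ramp over training; the helper below is a hedged sketch (linear ramp, made-up key names), not the library's actual loss code:

```python
# Hedged sketch of loss blending with a linear warm-up ramp. Function and
# config-key names are assumptions, not GAN-Engine's real API.
def ramp_weight(step: int, warmup_steps: int, target: float) -> float:
    """Ramp a loss weight linearly from 0 to `target` over `warmup_steps`."""
    if warmup_steps <= 0:
        return target
    return target * min(1.0, step / warmup_steps)

def total_loss(pixel: float, perceptual: float, adversarial: float,
               step: int, cfg: dict) -> float:
    """Blend the three terms; perceptual/adversarial weights warm up so the
    generator first learns pixel fidelity before chasing realism."""
    return (cfg["pixel_weight"] * pixel
            + ramp_weight(step, cfg["warmup_steps"], cfg["perceptual_weight"]) * perceptual
            + ramp_weight(step, cfg["warmup_steps"], cfg["adversarial_weight"]) * adversarial)
```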

Normalisation & statistics

The gan_engine.data.normalizers package converts raw sensor/scanner values into network-friendly ranges. It supports per-channel z-score, percentile, min-max, histogram matching, and custom loaders. Normalisers can operate on 2D or 3D data and have separate behaviour for low-resolution (LR) and high-resolution (HR) branches.
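The per-channel z-score mode and its inverse can be sketched with plain lists for clarity (the real normalisers operate on tensors and keep separate statistics for the LR and HR branches):

```python
# Plain-list sketch of per-channel z-score normalisation; illustrative only.
def zscore(channels, mean, std):
    """Normalise each channel value: (x - mean) / std."""
    return [(x - m) / s for x, m, s in zip(channels, mean, std)]

def inverse_zscore(channels, mean, std):
    """Map network outputs back to sensor/scanner units: x * std + mean."""
    return [x * s + m for x, m, s in zip(channels, mean, std)]
```

Keeping the inverse alongside the forward transform is what lets the Lightning module denormalise predictions back into physically meaningful units before metrics are computed.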

Data pipeline

Datasets are defined in gan_engine.data. The key abstractions are:

  • Dataset selectors – YAML-driven wrappers that map modality keywords (e.g. CV, ChestXRay), mask-aware datasets, or prompt corpora to dataset classes.
  • Paired datasets – Return aligned LR/HR pairs with optional augmentation pipelines.
  • Conditional datasets – Provide (image, mask/prompt, metadata) tuples for upcoming inpainting and text tasks.

You can create custom dataset classes and register them via entry points or Python hooks referenced in configuration files.
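A paired dataset ultimately boils down to returning aligned samples by index; the skeleton below is illustrative (real implementations load image files and run augmentation pipelines before returning the pair):

```python
# Illustrative skeleton of a paired LR/HR dataset; not the project's actual
# base class. Real datasets read files and apply augmentations here.
class PairedDataset:
    def __init__(self, lr_items, hr_items):
        assert len(lr_items) == len(hr_items), "LR/HR lists must be aligned"
        self.lr, self.hr = lr_items, hr_items

    def __len__(self):
        return len(self.lr)

    def __getitem__(self, index):
        # Return an aligned (low-resolution, high-resolution) pair.
        return self.lr[index], self.hr[index]
```

Conditional datasets extend the same shape by returning (image, mask/prompt, metadata) tuples instead of a two-element pair.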

Utilities

  • EMA manager (gan_engine.utils.ema) – Maintains exponential moving averages of generator weights.
  • Scheduler helpers (gan_engine.utils.schedulers) – Implement warm-ups, cosine ramps, and plateau detectors for both optimisers.
  • Logging utilities (gan_engine.utils.loggers) – Standardise image grids, scalar tracking, and histogram logging across loggers, including prompt/mask previews.
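The EMA manager follows the standard update shadow = decay * shadow + (1 - decay) * param; the class below is a plain-Python sketch of that rule, not gan_engine.utils.ema itself (which operates on model state dicts):

```python
# Plain-Python sketch of exponential-moving-average weight tracking;
# illustrative only, not gan_engine.utils.ema.
class EMA:
    def __init__(self, decay: float = 0.999):
        self.decay = decay
        self.shadow: dict = {}

    def update(self, params: dict) -> None:
        """Apply shadow = decay * shadow + (1 - decay) * param, per parameter."""
        for name, value in params.items():
            if name not in self.shadow:
                self.shadow[name] = value  # initialise on first sight
            else:
                self.shadow[name] = (self.decay * self.shadow[name]
                                     + (1 - self.decay) * value)
```

Evaluating with the smoothed shadow weights rather than the raw generator weights is a common trick for stabilising GAN sample quality late in training.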

Understanding how these pieces interact will help you design new models, integrate domain-specific metrics, or extend the toolkit for novel modalities.