
Training API

Functions for launching PyTorch Lightning training runs and configuring trainer parameters from OmegaConf-based experiment files.

High-level entry points

train(config)

Run a full training job from a config file path, dict, or OmegaConf object: load the config, build the model and datamodule, configure logging and callbacks, and launch the Lightning Trainer.
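
Example (illustrative; the config path is a placeholder, and train also accepts a plain dict or an OmegaConf object, as the loader below shows):

from gan_engine.train import train

# launch directly from an experiment file on disk
train("configs/experiment.yaml")

# or tweak an OmegaConf config in memory before launching
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/experiment.yaml")
cfg.Training.max_epochs = 10
train(cfg)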

Source code in gan_engine/train.py
import datetime
import os
from pathlib import Path

import pytorch_lightning as pl
import wandb
from omegaconf import OmegaConf


def train(config):

    #############################################################################################################
    """LOAD CONFIG"""
    # accept a path to a config file, a plain dict, or an OmegaConf object
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)
    elif isinstance(config, dict):
        config = OmegaConf.create(config)
    elif OmegaConf.is_config(config):
        pass
    else:
        raise TypeError(
            "Config must be a filepath (str or Path), dict, or OmegaConf object."
        )
    #############################################################################################################

    # Device settings from the config (not used directly below; build_lightning_kwargs
    # derives the accelerator and strategy itself)
    cuda_devices = config.Training.gpus
    cuda_strategy = "ddp" if len(cuda_devices) > 1 else None

    #############################################################################################################
    " LOAD MODEL "
    #############################################################################################################
    # load a pretrained checkpoint or instantiate a new model
    from gan_engine.model.SRGAN import SRGAN_model

    if config.Model.load_checkpoint:
        model = SRGAN_model.load_from_checkpoint(
            config.Model.load_checkpoint, strict=False
        )
    else:
        model = SRGAN_model(config=config)
    if config.Model.continue_training:
        resume_from_checkpoint_variable = config.Model.continue_training
    else:
        resume_from_checkpoint_variable = None

    #############################################################################################################
    """ GET DATA """
    #############################################################################################################
    # build the LightningDataModule: dataset_selector reads the config, picks the dataset class, and wraps it as a pl datamodule
    from gan_engine.data.dataset_selector import select_dataset

    pl_datamodule = select_dataset(config)

    #############################################################################################################
    """ Configure Trainer """
    #############################################################################################################

    # Configure Logger
    if config.Logging.wandb.enabled:
        # set up logging
        from pytorch_lightning.loggers import WandbLogger

        wandb_project = config.Logging.wandb.project  # project name shown in the W&B dashboard
        wandb_logger = WandbLogger(
            project=wandb_project, entity=config.Logging.wandb.entity, log_model=False
        )
    else:
        print("Weights & Biases logging disabled; writing reduced CSV logs locally.")
        from pytorch_lightning.loggers import CSVLogger

        # reuse the same variable name so the Trainer wiring below stays unchanged
        wandb_logger = CSVLogger(
            save_dir="logs/",
        )

    # Configure Saving Checkpoints
    from pytorch_lightning.callbacks import ModelCheckpoint

    dir_save_checkpoints = os.path.join(
        os.path.normpath("logs/"),
        config.Logging.wandb.project,
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
    )
    from gan_engine.utils.gpu_rank import (
        _is_global_zero,
    )  # make dir only on main process

    if _is_global_zero():  # only on main process
        os.makedirs(dir_save_checkpoints, exist_ok=True)
        print("Experiment Path:", dir_save_checkpoints)
        with open(
            os.path.join(dir_save_checkpoints, "config.yaml"), "w"
        ) as f:  # save config to experiment folder
            OmegaConf.save(config, f)
    checkpoint_callback = ModelCheckpoint(
        dirpath=dir_save_checkpoints,
        monitor=config.Schedulers.metric_g,
        mode="min",
        save_last=True,
        save_top_k=2,
    )

    # callback to set up early stopping
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    early_stop_callback = EarlyStopping(
        monitor=config.Schedulers.metric_g,
        min_delta=0.00,
        patience=250,
        verbose=True,
        mode="min",
        check_finite=True,
    )  # patience in epochs

    #############################################################################################################
    """ Set Args for Training and Start Training """
    """ make it robust for both PL<2.0 and PL>=2.0 """
    #############################################################################################################
    from gan_engine.utils.build_trainer_kwargs import build_lightning_kwargs

    trainer_kwargs, fit_kwargs = (
        build_lightning_kwargs(  # get kwargs depending on PL version
            config=config,
            logger=wandb_logger,
            checkpoint_callback=checkpoint_callback,
            early_stop_callback=early_stop_callback,
            resume_ckpt=resume_from_checkpoint_variable,
        )
    )

    # Start training
    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, datamodule=pl_datamodule, **fit_kwargs)
    wandb.finish()

Trainer configuration helpers

build_lightning_kwargs(config, logger, checkpoint_callback, early_stop_callback, resume_ckpt=None)

Return Trainer/fit keyword arguments compatible with Lightning < 2 and ≥ 2.

Builds two dictionaries:

1) trainer_kwargs — safe, version-aware arguments for pytorch_lightning.Trainer.
2) fit_kwargs — arguments for Trainer.fit (e.g., ckpt_path on PL ≥ 2).

The helper normalizes device configuration (CPU/GPU, DDP when multiple GPUs), removes deprecated/None entries, and maps the legacy resume API:

- PL < 2: uses resume_from_checkpoint in trainer_kwargs.
- PL ≥ 2: uses ckpt_path in fit_kwargs (if supported by the signature).

It also clears the legacy environment variable PL_TRAINER_RESUME_FROM_CHECKPOINT to avoid non-deterministic resume behavior.

Parameters:

config (required)
    OmegaConf-like config with Training fields:
    - Training.device (str): "auto" | "cpu" | "cuda" / "gpu".
    - Training.gpus (int | Sequence[int] | None): device count or IDs.
    - Training.val_check_interval (int | float).
    - Training.limit_val_batches (int | float).
    - Training.max_epochs (int).

logger (required)
    A Lightning-compatible logger instance.

checkpoint_callback (required)
    Model checkpoint callback instance.

early_stop_callback (required)
    Early stopping callback instance.

resume_ckpt (str | None, default: None)
    Path to a checkpoint to resume from.
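
Example (illustrative sketch of the Training block this helper reads; field names follow the description above, all values are placeholders):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "Training": {
            "device": "auto",          # "auto" | "cpu" | "cuda" / "gpu"
            "gpus": [0, 1],            # two devices -> DDP strategy
            "val_check_interval": 1.0,
            "limit_val_batches": 1.0,
            "max_epochs": 500,
        },
        # also read when assembling the trainer kwargs
        "Optimizers": {"gradient_clip_val": 1.0},
    }
)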

Returns:

Tuple[Dict[str, Any], Dict[str, Any]]:
    - trainer_kwargs: Dict for pl.Trainer(**trainer_kwargs).
    - fit_kwargs: Dict for trainer.fit(..., **fit_kwargs) (may be empty).

Raises:

ValueError
    If Training.device is not one of {"auto", "cpu", "cuda", "gpu"}.

Notes
  • CPU runs force devices=1 and no strategy.
  • GPU runs honor Training.gpus; DDP is enabled when requesting >1 device.
  • All None values are pruned; kwargs are filtered to match the current Trainer.__init__ and Trainer.fit signatures to stay future-proof.
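
Example (hedged usage sketch reusing the illustrative cfg from the Parameters example above; the monitored metric name and checkpoint directory are placeholders, not taken from the source):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

from gan_engine.utils.build_trainer_kwargs import build_lightning_kwargs

trainer_kwargs, fit_kwargs = build_lightning_kwargs(
    config=cfg,  # illustrative OmegaConf config from the Parameters example
    logger=CSVLogger(save_dir="logs/"),
    checkpoint_callback=ModelCheckpoint(dirpath="logs/example", monitor="val_loss", mode="min"),
    early_stop_callback=EarlyStopping(monitor="val_loss", mode="min"),
    resume_ckpt=None,
)

trainer = pl.Trainer(**trainer_kwargs)  # only kwargs accepted by this Lightning version survive
# trainer.fit(model, datamodule=datamodule, **fit_kwargs)  # on PL >= 2 the resume path lands here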
Source code in gan_engine/utils/build_trainer_kwargs.py
from __future__ import annotations

import inspect
import os
from collections.abc import Sequence

import pytorch_lightning as pl
import torch
from packaging.version import Version


def build_lightning_kwargs(
    config,
    logger,
    checkpoint_callback,
    early_stop_callback,
    resume_ckpt: str | None = None,
):
    """Return Trainer/fit keyword arguments compatible with Lightning < 2 and ≥ 2.

    Builds two dictionaries:
    1) ``trainer_kwargs`` — safe, version-aware arguments for ``pytorch_lightning.Trainer``.
    2) ``fit_kwargs`` — arguments for ``Trainer.fit`` (e.g., ``ckpt_path`` on PL ≥ 2).

    The helper normalizes device configuration (CPU/GPU, DDP when multiple GPUs),
    removes deprecated/None entries, and maps the legacy resume API:
    - PL < 2: uses ``resume_from_checkpoint`` in ``trainer_kwargs``.
    - PL ≥ 2: uses ``ckpt_path`` in ``fit_kwargs`` (if supported by signature).

    It also clears the legacy environment variable
    ``PL_TRAINER_RESUME_FROM_CHECKPOINT`` to avoid non-deterministic resume behavior.

    Args:
        config: OmegaConf-like config with ``Training`` fields:
            - ``Training.device`` (str): "auto"|"cpu"|"cuda"/"gpu".
            - ``Training.gpus`` (int|Sequence[int]|None): device count/IDs.
            - ``Training.val_check_interval`` (int|float).
            - ``Training.limit_val_batches`` (int|float).
            - ``Training.max_epochs`` (int).
        logger: A Lightning-compatible logger instance.
        checkpoint_callback: Model checkpoint callback instance.
        early_stop_callback: Early stopping callback instance.
        resume_ckpt (str | None): Path to checkpoint to resume from.

    Returns:
        Tuple[Dict[str, Any], Dict[str, Any]]:
            - trainer_kwargs: Dict for ``pl.Trainer(**trainer_kwargs)``.
            - fit_kwargs: Dict for ``trainer.fit(..., **fit_kwargs)`` (may be empty).

    Raises:
        ValueError: If ``Training.device`` is not one of {"auto","cpu","cuda","gpu"}.

    Notes:
        - CPU runs force ``devices=1`` and no strategy.
        - GPU runs honor ``Training.gpus``; DDP is enabled when requesting >1 device.
        - All ``None`` values are pruned; kwargs are filtered to match the current
        ``Trainer.__init__`` and ``Trainer.fit`` signatures to stay future-proof.
    """

    # ---------------------------------------------------------------------
    # 1) Version detection and environment cleanup
    # ---------------------------------------------------------------------
    # Determine whether the installed Lightning version is 2.x or newer.
    # The behaviour of ``resume_from_checkpoint`` changed between major
    # versions, so we compute this once and use the flag later when assembling
    # the kwargs.
    is_v2 = Version(pl.__version__) >= Version("2.0.0")

    # Lightning < 2 used an environment variable to infer the checkpoint path
    # when resuming.  The variable is ignored (and in some cases triggers
    # warnings) on newer versions, so we proactively remove it to provide a
    # deterministic behaviour across environments.
    os.environ.pop("PL_TRAINER_RESUME_FROM_CHECKPOINT", None)

    # ---------------------------------------------------------------------
    # 2) Parse device configuration from the OmegaConf config
    # ---------------------------------------------------------------------
    # ``Training.gpus`` may be specified either as an integer (e.g. ``2``) or a
    # sequence (e.g. ``[0, 1]``).  We keep the raw object so it can be passed to
    # the Trainer later if required, but we also count how many devices are
    # requested to decide on the parallelisation strategy.
    devices_cfg = getattr(config.Training, "gpus", None)

    # ``Training.device`` is the user-facing string that selects the backend.
    # Valid values are ``"cuda"`` / ``"gpu"`` (equivalent), ``"cpu"`` or
    # ``"auto"`` to defer to ``torch.cuda.is_available``.
    device_cfg = str(getattr(config.Training, "device", "auto")).lower()

    def _count_devices(devices):
        """Return how many explicit device identifiers were supplied."""

        # ``Trainer(devices=N)`` accepts both integers and sequences.  When the
        # user specifies an integer we can return it directly.  For sequences we
        # only count non-string iterables because strings are technically
        # sequences too but do not represent a collection of device identifiers.
        if isinstance(devices, int):
            return devices
        if isinstance(devices, Sequence) and not isinstance(devices, (str, bytes)):
            return len(devices)
        return 0

    ndev = _count_devices(devices_cfg)

    # Map the high-level ``device`` selector to the Lightning ``accelerator``
    # option.  ``auto`` chooses GPU when available and CPU otherwise so CLI
    # overrides are not required when moving between machines.
    if device_cfg in {"cuda", "gpu"}:
        accelerator = "gpu"
    elif device_cfg == "cpu":
        accelerator = "cpu"
    elif device_cfg in {"auto", ""}:
        accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    else:
        raise ValueError(f"Unsupported Training.device '{device_cfg}'")

    # When operating on CPU we force Lightning to a single device.  Allowing the
    # caller to pass the GPU list would be misleading because PyTorch does not
    # support multiple CPUs in the same way as GPUs.  On GPU we honour the user
    # supplied configuration and enable DistributedDataParallel only when more
    # than one device is requested.
    if accelerator == "cpu":
        devices = 1
        strategy = None
    else:
        devices = devices_cfg if ndev else 1
        strategy = "ddp" if ndev > 1 else None

    # ---------------------------------------------------------------------
    # 3) Assemble the base Trainer kwargs shared across Lightning versions
    # ---------------------------------------------------------------------
    trainer_kwargs = dict(
        accelerator=accelerator,
        strategy=strategy,  # removed in the next step when ``None``
        devices=devices,
        val_check_interval=config.Training.val_check_interval,
        limit_val_batches=config.Training.limit_val_batches,
        max_epochs=config.Training.max_epochs,
        log_every_n_steps=100,
        logger=[logger],
        callbacks=[checkpoint_callback, early_stop_callback],
        gradient_clip_val=config.Optimizers.gradient_clip_val,
    )

    # ``strategy`` defaults to ``None`` on CPU runs.  Lightning does not accept
    # explicit ``None`` values in its constructor, therefore we prune every
    # key/value pair whose value evaluates to ``None`` before forwarding the
    # kwargs.
    trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if v is not None}

    # ---------------------------------------------------------------------
    # 4) Add compatibility shims for pre-Lightning 2 releases
    # ---------------------------------------------------------------------
    if not is_v2 and resume_ckpt:
        trainer_kwargs["resume_from_checkpoint"] = resume_ckpt

    # Some Lightning releases occasionally deprecate constructor arguments.  To
    # ensure we do not pass stale options we filter the dictionary so it only
    # contains parameters that are still accepted by ``Trainer.__init__``.
    init_sig = inspect.signature(pl.Trainer.__init__).parameters
    trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if k in init_sig}

    # ---------------------------------------------------------------------
    # 5) ``Trainer.fit`` keyword arguments (Lightning >= 2)
    # ---------------------------------------------------------------------
    fit_kwargs = {}
    if is_v2 and resume_ckpt:
        # ``ckpt_path`` is the new name for ``resume_from_checkpoint``.
        fit_sig = inspect.signature(pl.Trainer.fit).parameters
        if "ckpt_path" in fit_sig:
            fit_kwargs["ckpt_path"] = resume_ckpt

    return trainer_kwargs, fit_kwargs