
Data Pipeline

Components responsible for turning configuration files into ready-to-use PyTorch Lightning data modules and applying reproducible normalization policies.

Dataset selection

select_dataset(config)

Build train/val datasets from config and wrap them into a LightningDataModule.

Expected config fields (OmegaConf/dict-like):
- config.Data.dataset_type : str
    One of {"CV", "SISR_WW", "ExampleDataset", "ChestXRay"} in this file.
- config.Generator.scaling_factor : int
    Super-resolution scale factor (e.g., 2, 4, 8). Passed to the CV and ChestXRay datasets as sr_factor.

Hard-coded choices below (kept as-is, not modified):
- dataset root paths : "/data3/GAN_datasets/CV", "/data3/SEN2NAIP_global", "example_dataset/"
- lr_size : 128 (CV dataset)

Returns

pl_datamodule : LightningDataModule
    A tiny DataModule that exposes train/val DataLoaders built from the selected datasets.
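
A minimal configuration sketch, assuming an OmegaConf config: only the fields read by select_dataset and by the DataLoader knobs documented further below are shown, and the bundled example_dataset/ data is assumed to be present.

from omegaconf import OmegaConf

from gan_engine.data.dataset_selector import select_dataset

# Only the fields accessed in dataset_selector.py are shown here.
config = OmegaConf.create(
    {
        "Data": {
            "dataset_type": "ExampleDataset",
            "batch_size": 8,
            "train_batch_size": 8,
            "val_batch_size": 8,
            "num_workers": 4,
            "prefetch_factor": 2,
        },
        "Generator": {"scaling_factor": 4},
    }
)

pl_datamodule = select_dataset(config)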

Source code in gan_engine/data/dataset_selector.py
def select_dataset(config):
    """
    Build train/val datasets from `config` and wrap them into a LightningDataModule.

    Expected `config` fields (OmegaConf/dict-like):
    - config.Data.dataset_type : str
        One of {"CV", "SISR_WW", "ExampleDataset", "ChestXRay"} in this file.
    - config.Generator.scaling_factor : int
        Super-resolution scale factor (e.g., 2, 4, 8). Passed to the CV and
        ChestXRay datasets as `sr_factor`.

    Hard-coded choices below (kept as-is, not modified):
    - dataset root paths : "/data3/GAN_datasets/CV", "/data3/SEN2NAIP_global",
                           "example_dataset/"
    - lr_size            : 128 (CV dataset)

    Returns
    -------
    pl_datamodule : LightningDataModule
        A tiny DataModule that exposes train/val DataLoaders built from the selected datasets.
    """
    dataset_selection = config.Data.dataset_type

    # Please note: the "S2_6b", "S2_4b", and "SISR_WW" settings are leftover from previous
    # versions; I don't want to delete them in case they are needed again.
    # Only "ExampleDataset" is actively used in the current version.

    if dataset_selection == "CV":
        from gan_engine.data.CV.cv_dataset import CV_dataset

        path = "/data3/GAN_datasets/CV"
        ds_train = CV_dataset(
            path=path,
            phase="train",
            sr_factor=config.Generator.scaling_factor,
            lr_size=128,
        )
        ds_val = CV_dataset(
            path=path,
            phase="val",
            sr_factor=config.Generator.scaling_factor,
            lr_size=128,
        )

    elif dataset_selection == "SISR_WW":
        from .SISR_WW.SISR_WW_dataset import SISRWorldWide

        path = "/data3/SEN2NAIP_global"
        ds_train = SISRWorldWide(path=path, split="train")
        ds_val = SISRWorldWide(path=path, split="val")

    elif dataset_selection == "ExampleDataset":
        from gan_engine.data.example_data.example_dataset import ExampleDataset

        path = "example_dataset/"
        ds_train = ExampleDataset(folder=path, phase="train")
        ds_val = ExampleDataset(folder=path, phase="val")

    elif dataset_selection == "ChestXRay":
        from gan_engine.data.chestxrays.XRay_dataset import XRayDataset

        ds_train = XRayDataset(phase="train", sr_factor=config.Generator.scaling_factor)
        ds_val = XRayDataset(phase="val", sr_factor=config.Generator.scaling_factor)

    else:
        # Centralized error so unsupported keys fail loudly & clearly.
        raise NotImplementedError(
            f"Dataset {dataset_selection} not implemented!"
            f"Add your dataset in data/dataset_selector.py to train on that."
        )

    # Wrap the two datasets into a LightningDataModule with config-driven loader knobs.
    pl_datamodule = datamodule_from_datasets(config, ds_train, ds_val)
    return pl_datamodule
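
Downstream use follows the standard Lightning pattern; a hedged sketch, where config is the configuration shown earlier and model stands in for whichever LightningModule the engine builds elsewhere:

import pytorch_lightning as pl

pl_datamodule = select_dataset(config)

# `model` is a placeholder for a LightningModule created elsewhere in the project.
trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
trainer.fit(model, datamodule=pl_datamodule)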

datamodule_from_datasets(config, ds_train, ds_val)

Convert a pair of prebuilt PyTorch Datasets into a minimal PyTorch Lightning DataModule.

Parameters

config : OmegaConf/dict-like
    Expected to contain:
    - Data.train_batch_size : int (fallback: Data.batch_size or 8)
    - Data.val_batch_size : int (fallback: Data.batch_size or 8)
    - Data.num_workers : int (default: 4)
    - Data.prefetch_factor : int (default: 2)
ds_train : torch.utils.data.Dataset
    Training dataset (already instantiated).
ds_val : torch.utils.data.Dataset
    Validation dataset (already instantiated).

Returns

LightningDataModule
    Exposes train_dataloader() and val_dataloader() using the settings above.
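
Because the function only needs two map-style datasets and a Data block, it can be reused outside the selector. A minimal sketch with synthetic TensorDatasets (shapes and field values are arbitrary, not taken from the project):

import torch
from omegaconf import OmegaConf
from torch.utils.data import TensorDataset

from gan_engine.data.dataset_selector import datamodule_from_datasets

cfg = OmegaConf.create(
    {
        "Data": {
            "dataset_type": "ExampleDataset",  # only used for the sanity print
            "batch_size": 4,
            "train_batch_size": 4,
            "val_batch_size": 4,
            "num_workers": 0,  # prefetch_factor is skipped when num_workers == 0
            "prefetch_factor": 2,
        }
    }
)

# Fake LR/HR pairs standing in for real datasets.
ds_train = TensorDataset(torch.randn(32, 3, 64, 64), torch.randn(32, 3, 256, 256))
ds_val = TensorDataset(torch.randn(8, 3, 64, 64), torch.randn(8, 3, 256, 256))

dm = datamodule_from_datasets(cfg, ds_train, ds_val)
lr_batch, hr_batch = next(iter(dm.train_dataloader()))  # each with batch size 4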

Source code in gan_engine/data/dataset_selector.py
def datamodule_from_datasets(config, ds_train, ds_val):
    """
    Convert a pair of prebuilt PyTorch Datasets into a minimal PyTorch Lightning DataModule.

    Parameters
    ----------
    config : OmegaConf/dict-like
        Expected to contain:
          - Data.train_batch_size : int (fallback: Data.batch_size or 8)
          - Data.val_batch_size   : int (fallback: Data.batch_size or 8)
          - Data.num_workers      : int (default: 4)
          - Data.prefetch_factor  : int (default: 2)
    ds_train : torch.utils.data.Dataset
        Training dataset (already instantiated).
    ds_val : torch.utils.data.Dataset
        Validation dataset (already instantiated).

    Returns
    -------
    LightningDataModule
        Exposes `train_dataloader()` and `val_dataloader()` using the settings above.
    """
    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader

    class CustomDataModule(LightningDataModule):
        """Tiny DataModule that forwards config-driven loader settings to DataLoader."""

        def __init__(self, ds_train, ds_val, config):
            super().__init__()
            self.ds_train = ds_train
            self.ds_val = ds_val

            # Pull loader settings from config with safe fallbacks.
            self.train_bs = getattr(
                config.Data, "train_batch_size", getattr(config.Data, "batch_size", 8)
            )
            self.val_bs = getattr(
                config.Data, "val_batch_size", getattr(config.Data, "batch_size", 8)
            )
            self.num_workers = getattr(config.Data, "num_workers", 4)
            self.prefetch_factor = getattr(config.Data, "prefetch_factor", 2)

            # print dataset sizes for sanity
            print(
                f"Created Dataset type {config.Data.dataset_type} with {len(self.ds_train)} training samples and {len(self.ds_val)} validation samples.\n"
            )

        def train_dataloader(self):
            """Return the training DataLoader with common performance flags."""
            kwargs = dict(
                batch_size=self.train_bs,
                shuffle=True,  # Shuffle only in training
                num_workers=self.num_workers,
                pin_memory=True,  # Speeds up host→GPU transfer on CUDA
                persistent_workers=self.num_workers
                > 0,  # Keep workers alive between epochs
            )
            # prefetch_factor is only valid when num_workers > 0
            if self.num_workers > 0:
                kwargs["prefetch_factor"] = self.prefetch_factor
            return DataLoader(self.ds_train, **kwargs)

        def val_dataloader(self):
            """Return the validation DataLoader (no shuffle)."""
            kwargs = dict(
                batch_size=self.val_bs,
                shuffle=True,  # shuffle ordering for validation - more diversity in batches
                num_workers=self.num_workers,
                pin_memory=True,
                persistent_workers=self.num_workers > 0,
            )
            if self.num_workers > 0:
                kwargs["prefetch_factor"] = self.prefetch_factor
            return DataLoader(self.ds_val, **kwargs)

    return CustomDataModule(ds_train, ds_val, config)

Normalization utilities

Utility helpers for configuring tensor normalization strategies.

Normalizer

Factory for applying configurable normalization/denormalization.

The normalizer inspects the provided configuration, determines the requested normalization scheme, and exposes normalize / denormalize helpers that downstream components can reuse without importing utils.spectral_helpers directly.

Supported methods include remote-sensing-focused helpers such as "normalise_10k" (0–10000 reflectance → [0, 1]), "normalise_10k_signed" (0–10000 reflectance → [-1, 1]), "normalise_s2" (Sentinel-2 symmetric stretch), "zero_one" (clamp to [0, 1]) and "zero_one_signed" ([0, 1] ↔ [-1, 1]). Custom strategies can be registered by providing a mapping with {"name": "custom", "normalize": "module:callable", ...} in the configuration.
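
A hedged usage sketch: the import path follows the "Source code in" note below, and Data.normalization is the field read in __init__.

import torch
from omegaconf import OmegaConf

from gan_engine.data.utils.normalizer import Normalizer

cfg = OmegaConf.create({"Data": {"normalization": "normalise_10k"}})
normalizer = Normalizer(cfg)

reflectance = torch.randint(0, 10_000, (4, 64, 64)).float()  # fake 0–10000 reflectance
scaled = normalizer.normalize(reflectance)       # mapped into [0, 1]
restored = normalizer.denormalize(scaled)        # back to the reflectance range

print(normalizer.method)                         # "normalise_10k"
print(Normalizer.available_methods())            # canonical built-in names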

Source code in gan_engine/data/utils/normalizer.py
class Normalizer:
    """Factory for applying configurable normalization/denormalization.

    The normalizer inspects the provided configuration, determines the
    requested normalization scheme, and exposes ``normalize`` / ``denormalize``
    helpers that downstream components can reuse without importing
    :mod:`utils.spectral_helpers` directly.

    Supported methods include remote-sensing-focused helpers such as
    ``"normalise_10k"`` (0–10000 reflectance → ``[0, 1]``),
    ``"normalise_10k_signed"`` (0–10000 reflectance → ``[-1, 1]``),
    ``"normalise_s2"`` (Sentinel-2 symmetric stretch), ``"zero_one"`` (clamp to
    ``[0, 1]``) and ``"zero_one_signed"`` (``[0, 1]`` ↔ ``[-1, 1]``). Custom
    strategies can be registered by providing a mapping with
    ``{"name": "custom", "normalize": "module:callable", ...}`` in the
    configuration.
    """

    _STANDARD_METHODS: Dict[str, NormalizationStrategy] = {
        "sen2_stretch": NormalizationStrategy(
            normalize=sen2_stretch,
            denormalize=lambda tensor: torch.clamp(tensor * (3.0 / 10.0), 0.0, 1.0),
        ),
        "normalise_10k": NormalizationStrategy(
            normalize=partial(normalise_10k, stage="norm"),
            denormalize=partial(normalise_10k, stage="denorm"),
        ),
        "normalise_10k_signed": NormalizationStrategy(
            normalize=partial(normalise_10k_signed, stage="norm"),
            denormalize=partial(normalise_10k_signed, stage="denorm"),
        ),
        "normalise_s2": NormalizationStrategy(
            normalize=partial(normalise_s2, stage="norm"),
            denormalize=partial(normalise_s2, stage="denorm"),
        ),
        "zero_one": NormalizationStrategy(
            normalize=lambda tensor: torch.clamp(tensor, 0.0, 1.0),
            denormalize=lambda tensor: torch.clamp(tensor, 0.0, 1.0),
        ),
        "zero_one_signed": NormalizationStrategy(
            normalize=partial(zero_one_signed, stage="norm"),
            denormalize=partial(zero_one_signed, stage="denorm"),
        ),
        "identity": NormalizationStrategy(
            normalize=lambda tensor: tensor,
            denormalize=lambda tensor: tensor,
        ),
    }

    _ALIASES: Dict[str, str] = {
        "normalize_10k": "normalise_10k",
        "reflectance": "normalise_10k",
        "reflectance_0_1": "normalise_10k",
        "reflectance_signed": "normalise_10k_signed",
        "normalize_10k_signed": "normalise_10k_signed",
        "sentinel2": "normalise_10k",
        "sentinel2_signed": "normalise_10k_signed",
        "s2": "normalise_10k",
        "s2_signed": "normalise_10k_signed",
        "normalize_s2": "normalise_s2",
        "zero_to_one": "zero_one",
        "zero_one_range": "zero_one",
        "signed_zero_one": "zero_one_signed",
        "minusone_one": "zero_one_signed",
        "none": "identity",
    }

    def __init__(self, config: Any):
        data_cfg = getattr(config, "Data", None)
        raw_cfg: Any = None
        if data_cfg is not None:
            raw_cfg = getattr(data_cfg, "normalization", None)
            if raw_cfg is None and isinstance(data_cfg, dict):
                raw_cfg = data_cfg.get("normalization")
        if raw_cfg is None:
            raw_cfg = "sen2_stretch"

        method, strategy = self._resolve_strategy(raw_cfg)
        self._cfg = _NormalizerConfig(method=method)
        self._strategy = strategy

    @property
    def method(self) -> str:
        """Return the normalization method configured for this instance."""

        return self._cfg.method

    def normalize(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize ``tensor`` according to the configured method."""

        return self._strategy.normalize(tensor)

    def denormalize(self, tensor: torch.Tensor) -> torch.Tensor:
        """Invert the normalization previously applied by :meth:`normalize`."""

        return self._strategy.denormalize(tensor)

    @classmethod
    def available_methods(cls) -> Tuple[str, ...]:
        """Return the canonical names of built-in normalization strategies."""

        return tuple(sorted(cls._STANDARD_METHODS.keys()))

    def _resolve_strategy(
        self, raw_cfg: Any
    ) -> Tuple[str, NormalizationStrategy]:
        """Resolve ``raw_cfg`` into a normalisation strategy.

        Parameters
        ----------
        raw_cfg : Any
            Configuration value extracted from ``Data.normalization``. Can be a
            string alias, a mapping with ``name``/``method`` keys, or a mapping
            describing custom callables.
        """

        if isinstance(raw_cfg, Mapping):
            name = raw_cfg.get("name") or raw_cfg.get("method") or "custom"
            name = str(name).strip().lower()
            name = name.replace("normalize", "normalise")
            if name == "custom":
                strategy = self._build_custom_strategy(raw_cfg)
                return "custom", strategy
            raw_cfg = name

        if not isinstance(raw_cfg, str):
            raise TypeError(
                "Normalization config must be a string or mapping, "
                f"received: {type(raw_cfg)!r}"
            )

        method = raw_cfg.strip().lower()
        method = method.replace("normalize", "normalise")
        method = self._ALIASES.get(method, method)

        if method == "custom":
            raise ValueError(
                "Use a mapping with 'name: custom' and callable paths to configure custom normalization."
            )

        try:
            strategy = self._STANDARD_METHODS[method]
        except KeyError as exc:
            supported = ", ".join(sorted(self._STANDARD_METHODS))
            raise ValueError(
                f"Unsupported normalization '{raw_cfg}'. Supported methods: {supported}."
            ) from exc

        return method, strategy

    def _build_custom_strategy(
        self, cfg: Mapping[str, Any]
    ) -> NormalizationStrategy:
        """Instantiate a strategy from user-supplied callables."""

        if "normalize" not in cfg:
            raise ValueError(
                "Custom normalization requires a 'normalize' callable path."
            )

        normalize_path = cfg["normalize"]
        denormalize_path = cfg.get("denormalize")

        shared_kwargs = dict(cfg.get("kwargs", {}))
        norm_kwargs = {**shared_kwargs, **cfg.get("normalize_kwargs", {})}
        denorm_kwargs = {**shared_kwargs, **cfg.get("denormalize_kwargs", {})}

        normalize_fn = _load_callable(normalize_path, norm_kwargs)
        if denormalize_path is None:
            if denorm_kwargs:
                raise ValueError(
                    "'denormalize_kwargs' provided without a 'denormalize' callable."
                )
            denormalize_fn = lambda tensor: tensor
        else:
            denormalize_fn = _load_callable(denormalize_path, denorm_kwargs)

        return NormalizationStrategy(
            normalize=normalize_fn,
            denormalize=denormalize_fn,
        )
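
For the custom path, the mapping form from the class docstring might look as follows. torch:log1p / torch:expm1 are just stand-in tensor-to-tensor callables, and the sketch assumes _load_callable resolves "module:callable" strings as the docstring suggests; optional kwargs, normalize_kwargs, and denormalize_kwargs mappings are forwarded to the callables.

from omegaconf import OmegaConf

from gan_engine.data.utils.normalizer import Normalizer

cfg = OmegaConf.create(
    {
        "Data": {
            "normalization": {
                "name": "custom",
                "normalize": "torch:log1p",    # stand-in callable, tensor -> tensor
                "denormalize": "torch:expm1",  # its inverse
            }
        }
    }
)

normalizer = Normalizer(cfg)  # normalizer.method == "custom"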

method property

Return the normalization method configured for this instance.

normalize(tensor)

Normalize tensor according to the configured method.

Source code in gan_engine/data/utils/normalizer.py
def normalize(self, tensor: torch.Tensor) -> torch.Tensor:
    """Normalize ``tensor`` according to the configured method."""

    return self._strategy.normalize(tensor)

denormalize(tensor)

Invert the normalization previously applied by normalize().

Source code in gan_engine/data/utils/normalizer.py
def denormalize(self, tensor: torch.Tensor) -> torch.Tensor:
    """Invert the normalization previously applied by :meth:`normalize`."""

    return self._strategy.denormalize(tensor)

available_methods() classmethod

Return the canonical names of built-in normalization strategies.

Source code in gan_engine/data/utils/normalizer.py
@classmethod
def available_methods(cls) -> Tuple[str, ...]:
    """Return the canonical names of built-in normalization strategies."""

    return tuple(sorted(cls._STANDARD_METHODS.keys()))