Data Pipeline¶
Components responsible for turning configuration files into ready-to-use PyTorch Lightning data modules and applying reproducible normalization policies.
Dataset selection¶
select_dataset(config)¶
Build train/val datasets from config and wrap them into a LightningDataModule.
Expected config fields (OmegaConf/dict-like):
- config.Data.dataset_type : str
One of {"CV", "SISR_WW", "ExampleDataset", "ChestXRay"} in this file.
- config.Generator.scaling_factor : int
Super-resolution scale factor (e.g., 2, 4, 8). Passed to the dataset as sr_factor.
Hard-coded choices below (kept as-is, not modified):
- manifest_json : Path to a prebuilt SAFE window manifest.
- band orders : Fixed lists for each selection.
- hr_size : (512, 512)
- group_by : "granule"
- group_regex : r".*?/GRANULE/([^/]+)/IMG_DATA/.*"
Returns¶
pl_datamodule : LightningDataModule
A tiny DataModule that exposes train/val DataLoaders built from the selected datasets.
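For orientation, a minimal call sketch (not part of the source): it assumes the "ExampleDataset" selection and that the bundled example_dataset/ folder is reachable from the working directory; the loader values are placeholders.
from omegaconf import OmegaConf
from gan_engine.data.dataset_selector import select_dataset

config = OmegaConf.create({
    "Data": {
        "dataset_type": "ExampleDataset",
        "train_batch_size": 8,
        "val_batch_size": 8,
        "num_workers": 4,
        "prefetch_factor": 2,
    },
    "Generator": {"scaling_factor": 4},  # forwarded as sr_factor where a dataset needs it
})

datamodule = select_dataset(config)
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()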
Source code in gan_engine/data/dataset_selector.py
def select_dataset(config):
"""
Build train/val datasets from `config` and wrap them into a LightningDataModule.
Expected `config` fields (OmegaConf/dict-like):
    - config.Data.dataset_type : str
        One of {"CV", "SISR_WW", "ExampleDataset", "ChestXRay"} in this file.
- config.Generator.scaling_factor : int
Super-resolution scale factor (e.g., 2, 4, 8). Passed to dataset as `sr_factor`.
Hard-coded choices below (kept as-is, not modified):
- manifest_json : Path to prebuilt SAFE window manifest.
- band orders : Fixed lists for each selection.
- hr_size : (512, 512)
- group_by : "granule"
- group_regex : r".*?/GRANULE/([^/]+)/IMG_DATA/.*"
Returns
-------
pl_datamodule : LightningDataModule
A tiny DataModule that exposes train/val DataLoaders built from the selected datasets.
"""
dataset_selection = config.Data.dataset_type
    # Please note: the "S2_6b", "S2_4b", and "SISR_WW" settings are leftover from previous versions.
    # I don't want to delete them in case they are needed again.
    # Only "ExampleDataset" is actively used in the current version.
if dataset_selection == "CV":
from gan_engine.data.CV.cv_dataset import CV_dataset
path = "/data3/GAN_datasets/CV"
ds_train = CV_dataset(
path=path,
phase="train",
sr_factor=config.Generator.scaling_factor,
lr_size=128,
)
ds_val = CV_dataset(
path=path,
phase="val",
sr_factor=config.Generator.scaling_factor,
lr_size=128,
)
elif dataset_selection == "SISR_WW":
from .SISR_WW.SISR_WW_dataset import SISRWorldWide
path = "/data3/SEN2NAIP_global"
ds_train = SISRWorldWide(path=path, split="train")
ds_val = SISRWorldWide(path=path, split="val")
elif dataset_selection == "ExampleDataset":
from gan_engine.data.example_data.example_dataset import ExampleDataset
path = "example_dataset/"
ds_train = ExampleDataset(folder=path, phase="train")
ds_val = ExampleDataset(folder=path, phase="val")
elif dataset_selection == "ChestXRay":
from gan_engine.data.chestxrays.XRay_dataset import XRayDataset
ds_train = XRayDataset(phase="train", sr_factor=config.Generator.scaling_factor)
ds_val = XRayDataset(phase="val", sr_factor=config.Generator.scaling_factor)
else:
# Centralized error so unsupported keys fail loudly & clearly.
        raise NotImplementedError(
            f"Dataset {dataset_selection} not implemented! "
            "Add your dataset in data/dataset_selector.py to train on it."
        )
# Wrap the two datasets into a LightningDataModule with config-driven loader knobs.
pl_datamodule = datamodule_from_datasets(config, ds_train, ds_val)
return pl_datamodule
datamodule_from_datasets(config, ds_train, ds_val)¶
Convert a pair of prebuilt PyTorch Datasets into a minimal PyTorch Lightning DataModule.
Parameters¶
config : OmegaConf/dict-like
Expected to contain:
- Data.train_batch_size : int (fallback: Data.batch_size or 8)
- Data.val_batch_size : int (fallback: Data.batch_size or 8)
- Data.num_workers : int (default: 4)
- Data.prefetch_factor : int (default: 2)
ds_train : torch.utils.data.Dataset
Training dataset (already instantiated).
ds_val : torch.utils.data.Dataset
Validation dataset (already instantiated).
Returns¶
LightningDataModule
Exposes train_dataloader() and val_dataloader() using the settings above.
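A hedged sketch of wrapping two already-instantiated datasets (throwaway TensorDatasets here); every setting is illustrative, and num_workers is 0 so the snippet stays single-process.
import torch
from omegaconf import OmegaConf
from torch.utils.data import TensorDataset
from gan_engine.data.dataset_selector import datamodule_from_datasets

# Stand-in datasets; any torch Dataset with __len__ and __getitem__ works.
ds_train = TensorDataset(torch.randn(64, 3, 32, 32))
ds_val = TensorDataset(torch.randn(16, 3, 32, 32))

config = OmegaConf.create({
    "Data": {
        "dataset_type": "TensorDataset",  # only used by the sanity print
        "train_batch_size": 8,
        "val_batch_size": 4,
        "num_workers": 0,                 # keeps the example single-process
        "prefetch_factor": 2,             # ignored while num_workers == 0
    }
})

dm = datamodule_from_datasets(config, ds_train, ds_val)
batch = next(iter(dm.train_dataloader()))  # list holding one tensor of shape (8, 3, 32, 32)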
Source code in gan_engine/data/dataset_selector.py
def datamodule_from_datasets(config, ds_train, ds_val):
"""
Convert a pair of prebuilt PyTorch Datasets into a minimal PyTorch Lightning DataModule.
Parameters
----------
config : OmegaConf/dict-like
Expected to contain:
- Data.train_batch_size : int (fallback: Data.batch_size or 8)
- Data.val_batch_size : int (fallback: Data.batch_size or 8)
- Data.num_workers : int (default: 4)
- Data.prefetch_factor : int (default: 2)
ds_train : torch.utils.data.Dataset
Training dataset (already instantiated).
ds_val : torch.utils.data.Dataset
Validation dataset (already instantiated).
Returns
-------
LightningDataModule
Exposes `train_dataloader()` and `val_dataloader()` using the settings above.
"""
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
class CustomDataModule(LightningDataModule):
"""Tiny DataModule that forwards config-driven loader settings to DataLoader."""
def __init__(self, ds_train, ds_val, config):
super().__init__()
self.ds_train = ds_train
self.ds_val = ds_val
# Pull loader settings from config with safe fallbacks.
self.train_bs = getattr(
config.Data, "train_batch_size", getattr(config.Data, "batch_size", 8)
)
self.val_bs = getattr(
config.Data, "val_batch_size", getattr(config.Data, "batch_size", 8)
)
self.num_workers = getattr(config.Data, "num_workers", 4)
self.prefetch_factor = getattr(config.Data, "prefetch_factor", 2)
# print dataset sizes for sanity
print(
f"Created Dataset type {config.Data.dataset_type} with {len(self.ds_train)} training samples and {len(self.ds_val)} validation samples.\n"
)
def train_dataloader(self):
"""Return the training DataLoader with common performance flags."""
kwargs = dict(
batch_size=self.train_bs,
                shuffle=True,  # Reshuffle training samples each epoch
num_workers=self.num_workers,
pin_memory=True, # Speeds up host→GPU transfer on CUDA
persistent_workers=self.num_workers
> 0, # Keep workers alive between epochs
)
# prefetch_factor is only valid when num_workers > 0
if self.num_workers > 0:
kwargs["prefetch_factor"] = self.prefetch_factor
return DataLoader(self.ds_train, **kwargs)
def val_dataloader(self):
"""Return the validation DataLoader (no shuffle)."""
kwargs = dict(
batch_size=self.val_bs,
shuffle=True, # shuffle ordering for validation - more diversity in batches
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=self.num_workers > 0,
)
if self.num_workers > 0:
kwargs["prefetch_factor"] = self.prefetch_factor
return DataLoader(self.ds_val, **kwargs)
return CustomDataModule(ds_train, ds_val, config)
Normalization utilities¶
Utility helpers for configuring tensor normalization strategies.
Normalizer¶
Factory for applying configurable normalization/denormalization.
The normalizer inspects the provided configuration, determines the
requested normalization scheme, and exposes normalize / denormalize
helpers that downstream components can reuse without importing
utils.spectral_helpers directly.
Supported methods include remote-sensing-focused helpers such as
"normalise_10k" (0–10000 reflectance → [0, 1]),
"normalise_10k_signed" (0–10000 reflectance → [-1, 1]),
"normalise_s2" (Sentinel-2 symmetric stretch), "zero_one" (clamp to
[0, 1]) and "zero_one_signed" ([0, 1] ↔ [-1, 1]). Custom
strategies can be registered by providing a mapping with
{"name": "custom", "normalize": "module:callable", ...} in the
configuration.
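For the plain string form, a small round-trip sketch; the import path follows the source location below, and the reflectance values are illustrative.
import torch
from omegaconf import OmegaConf
from gan_engine.data.utils.normalizer import Normalizer

config = OmegaConf.create({"Data": {"normalization": "normalise_10k"}})
normalizer = Normalizer(config)

x = torch.tensor([0.0, 2500.0, 10000.0])   # 0–10000 reflectance values
y = normalizer.normalize(x)                # mapped into [0, 1]
x_back = normalizer.denormalize(y)         # back on the 0–10000 scale

print(normalizer.method)                   # "normalise_10k"
print(Normalizer.available_methods())      # canonical built-in names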
Source code in gan_engine/data/utils/normalizer.py
class Normalizer:
"""Factory for applying configurable normalization/denormalization.
The normalizer inspects the provided configuration, determines the
requested normalization scheme, and exposes ``normalize`` / ``denormalize``
helpers that downstream components can reuse without importing
:mod:`utils.spectral_helpers` directly.
Supported methods include remote-sensing-focused helpers such as
``"normalise_10k"`` (0–10000 reflectance → ``[0, 1]``),
``"normalise_10k_signed"`` (0–10000 reflectance → ``[-1, 1]``),
``"normalise_s2"`` (Sentinel-2 symmetric stretch), ``"zero_one"`` (clamp to
``[0, 1]``) and ``"zero_one_signed"`` (``[0, 1]`` ↔ ``[-1, 1]``). Custom
strategies can be registered by providing a mapping with
``{"name": "custom", "normalize": "module:callable", ...}`` in the
configuration.
"""
_STANDARD_METHODS: Dict[str, NormalizationStrategy] = {
"sen2_stretch": NormalizationStrategy(
normalize=sen2_stretch,
denormalize=lambda tensor: torch.clamp(tensor * (3.0 / 10.0), 0.0, 1.0),
),
"normalise_10k": NormalizationStrategy(
normalize=partial(normalise_10k, stage="norm"),
denormalize=partial(normalise_10k, stage="denorm"),
),
"normalise_10k_signed": NormalizationStrategy(
normalize=partial(normalise_10k_signed, stage="norm"),
denormalize=partial(normalise_10k_signed, stage="denorm"),
),
"normalise_s2": NormalizationStrategy(
normalize=partial(normalise_s2, stage="norm"),
denormalize=partial(normalise_s2, stage="denorm"),
),
"zero_one": NormalizationStrategy(
normalize=lambda tensor: torch.clamp(tensor, 0.0, 1.0),
denormalize=lambda tensor: torch.clamp(tensor, 0.0, 1.0),
),
"zero_one_signed": NormalizationStrategy(
normalize=partial(zero_one_signed, stage="norm"),
denormalize=partial(zero_one_signed, stage="denorm"),
),
"identity": NormalizationStrategy(
normalize=lambda tensor: tensor,
denormalize=lambda tensor: tensor,
),
}
_ALIASES: Dict[str, str] = {
"normalize_10k": "normalise_10k",
"reflectance": "normalise_10k",
"reflectance_0_1": "normalise_10k",
"reflectance_signed": "normalise_10k_signed",
"normalize_10k_signed": "normalise_10k_signed",
"sentinel2": "normalise_10k",
"sentinel2_signed": "normalise_10k_signed",
"s2": "normalise_10k",
"s2_signed": "normalise_10k_signed",
"normalize_s2": "normalise_s2",
"zero_to_one": "zero_one",
"zero_one_range": "zero_one",
"signed_zero_one": "zero_one_signed",
"minusone_one": "zero_one_signed",
"none": "identity",
}
def __init__(self, config: Any):
data_cfg = getattr(config, "Data", None)
raw_cfg: Any = None
if data_cfg is not None:
raw_cfg = getattr(data_cfg, "normalization", None)
if raw_cfg is None and isinstance(data_cfg, dict):
raw_cfg = data_cfg.get("normalization")
if raw_cfg is None:
raw_cfg = "sen2_stretch"
method, strategy = self._resolve_strategy(raw_cfg)
self._cfg = _NormalizerConfig(method=method)
self._strategy = strategy
@property
def method(self) -> str:
"""Return the normalization method configured for this instance."""
return self._cfg.method
def normalize(self, tensor: torch.Tensor) -> torch.Tensor:
"""Normalize ``tensor`` according to the configured method."""
return self._strategy.normalize(tensor)
def denormalize(self, tensor: torch.Tensor) -> torch.Tensor:
"""Invert the normalization previously applied by :meth:`normalize`."""
return self._strategy.denormalize(tensor)
@classmethod
def available_methods(cls) -> Tuple[str, ...]:
"""Return the canonical names of built-in normalization strategies."""
return tuple(sorted(cls._STANDARD_METHODS.keys()))
def _resolve_strategy(
self, raw_cfg: Any
) -> Tuple[str, NormalizationStrategy]:
"""Resolve ``raw_cfg`` into a normalisation strategy.
Parameters
----------
raw_cfg : Any
Configuration value extracted from ``Data.normalization``. Can be a
string alias, a mapping with ``name``/``method`` keys, or a mapping
describing custom callables.
"""
if isinstance(raw_cfg, Mapping):
name = raw_cfg.get("name") or raw_cfg.get("method") or "custom"
name = str(name).strip().lower()
name = name.replace("normalize", "normalise")
if name == "custom":
strategy = self._build_custom_strategy(raw_cfg)
return "custom", strategy
raw_cfg = name
if not isinstance(raw_cfg, str):
raise TypeError(
"Normalization config must be a string or mapping, "
f"received: {type(raw_cfg)!r}"
)
method = raw_cfg.strip().lower()
method = method.replace("normalize", "normalise")
method = self._ALIASES.get(method, method)
if method == "custom":
raise ValueError(
"Use a mapping with 'name: custom' and callable paths to configure custom normalization."
)
try:
strategy = self._STANDARD_METHODS[method]
except KeyError as exc:
supported = ", ".join(sorted(self._STANDARD_METHODS))
raise ValueError(
f"Unsupported normalization '{raw_cfg}'. Supported methods: {supported}."
) from exc
return method, strategy
def _build_custom_strategy(
self, cfg: Mapping[str, Any]
) -> NormalizationStrategy:
"""Instantiate a strategy from user-supplied callables."""
if "normalize" not in cfg:
raise ValueError(
"Custom normalization requires a 'normalize' callable path."
)
normalize_path = cfg["normalize"]
denormalize_path = cfg.get("denormalize")
shared_kwargs = dict(cfg.get("kwargs", {}))
norm_kwargs = {**shared_kwargs, **cfg.get("normalize_kwargs", {})}
denorm_kwargs = {**shared_kwargs, **cfg.get("denormalize_kwargs", {})}
normalize_fn = _load_callable(normalize_path, norm_kwargs)
if denormalize_path is None:
if denorm_kwargs:
raise ValueError(
"'denormalize_kwargs' provided without a 'denormalize' callable."
)
denormalize_fn = lambda tensor: tensor
else:
denormalize_fn = _load_callable(denormalize_path, denorm_kwargs)
return NormalizationStrategy(
normalize=normalize_fn,
denormalize=denormalize_fn,
)
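A sketch of the custom-strategy mapping described in the class docstring; my_project.norms and its two callables are hypothetical placeholders and must exist on the import path for the snippet to run.
from types import SimpleNamespace
from gan_engine.data.utils.normalizer import Normalizer

# Hypothetical user module my_project/norms.py with plain tensor -> tensor callables:
#   def to_unit(tensor, scale=10000.0):   return (tensor / scale).clamp(0.0, 1.0)
#   def from_unit(tensor, scale=10000.0): return tensor * scale

config = SimpleNamespace(
    Data={
        "normalization": {
            "name": "custom",
            "normalize": "my_project.norms:to_unit",    # "module:callable" path
            "denormalize": "my_project.norms:from_unit",
            "kwargs": {"scale": 10000.0},               # shared keyword arguments
        }
    }
)
normalizer = Normalizer(config)  # normalizer.method == "custom"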