Skip to content

Inference API

Utilities for loading pretrained SRGAN checkpoints and running tiled inference on Sentinel-2 imagery.

Core helpers

load_model(config_path=None, ckpt_path=None, device=None)

Build SRGAN model and (optionally) load weights. Safe to call from tests.

Source code in gan_engine/inference.py
def load_model(config_path=None, ckpt_path=None, device=None):
    """Build SRGAN model and (optionally) load weights. Safe to call from tests.

    Parameters
    ----------
    config_path:
        Forwarded to ``SRGAN_model`` as ``config_file_path`` when building the
        base model.
    ckpt_path:
        Optional checkpoint path. When falsy, the freshly built model is
        returned without loading any weights.
    device:
        Target device. When ``None``, ``"cuda"`` is used if
        ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.

    Returns
    -------
    tuple
        ``(model, device)``: the model in eval mode on the resolved device,
        and the resolved device itself.
    """
    target = device
    if target is None:
        target = "cuda" if torch.cuda.is_available() else "cpu"

    net = SRGAN_model(config_file_path=config_path)
    net = net.eval().to(target)

    if not ckpt_path:
        return net, target

    # Preferred path: Lightning's classmethod loader (called without 'strict').
    # NOTE: ``config_path`` is not forwarded here, so the model is rebuilt from
    # whatever the checkpoint itself stores and ``net`` above is discarded.
    try:
        restored = (
            SRGAN_model.load_from_checkpoint(ckpt_path, map_location=target)
            .eval()
            .to(target)
        )
    except TypeError:
        # Fallback: read the file as a plain (or Lightning-wrapped) state dict
        # and load it non-strictly into the config-built model.
        payload = torch.load(ckpt_path, map_location=target)
        weights = payload.get("state_dict", payload)
        net.load_state_dict(weights, strict=False)
        return net, target

    return restored, target

run_sen2_inference(sen2_path=None, config_path=None, ckpt_path=None, gpus=None, window_size=(128, 128), factor=4, overlap=12, eliminate_border_px=2, save_preview=False, debug=False)

Run Sentinel-2 SR inference. Kept out of import-time for CI.

Source code in gan_engine/inference.py
def run_sen2_inference(
    sen2_path=None,
    config_path=None,
    ckpt_path=None,
    gpus=None,
    window_size=(128, 128),
    factor=4,
    overlap=12,
    eliminate_border_px=2,
    save_preview=False,
    debug=False,
):
    """Run Sentinel-2 SR inference. Kept out of import-time for CI.

    Loads the model via ``load_model`` and hands it to the utility package's
    ``large_file_processing``. Super-resolution is then started immediately.

    Parameters
    ----------
    sen2_path:
        Forwarded as ``root`` to ``large_file_processing``.
    config_path, ckpt_path:
        Forwarded to ``load_model``.
    gpus:
        Optional sequence of GPU indices. When non-empty, it is also used to
        set ``CUDA_VISIBLE_DEVICES``, but only if that variable is not already
        set. When ``None``, ``[0]`` is forwarded on CUDA and ``[]`` otherwise.
    window_size, factor, overlap, eliminate_border_px, save_preview, debug:
        Forwarded unchanged to ``large_file_processing``. Presumably these are
        the tile size, upscaling factor, tile overlap and border crop; confirm
        against that package.

    Returns
    -------
    The ``large_file_processing`` object, after ``start_super_resolution()``
    has been called on it.
    """
    has_gpus = gpus is not None and len(gpus) > 0
    if has_gpus:
        # setdefault: an externally configured CUDA_VISIBLE_DEVICES takes priority.
        # NOTE(review): when this line does set the variable (e.g. to "2,3"),
        # CUDA renumbers the visible devices from 0, yet ``gpus`` is forwarded
        # below with the original indices. Confirm how the utility package
        # interprets them.
        os.environ.setdefault(
            "CUDA_VISIBLE_DEVICES", ",".join(str(gpu) for gpu in gpus)
        )

    model, device = load_model(config_path=config_path, ckpt_path=ckpt_path)

    try:  # Prefer the rebranded utility package when available
        import gan_engine_utils as tiling
    except ImportError:  # pragma: no cover - legacy dependency support
        import opensr_utils as tiling

    if gpus is None:
        gpu_ids = [0] if device == "cuda" else []
    else:
        gpu_ids = gpus

    processor = tiling.large_file_processing(
        root=sen2_path,
        model=model,
        window_size=window_size,
        factor=factor,
        overlap=overlap,
        eliminate_border_px=eliminate_border_px,
        device=device,
        gpus=gpu_ids,
        save_preview=save_preview,
        debug=debug,
    )
    processor.start_super_resolution()
    return processor

main()

Source code in gan_engine/inference.py
def main():
    """Example entry point: super-resolve a bundled Sentinel-2 scene on GPU 0.

    The paths are placeholders. The data and config paths are resolved relative
    to this module's directory. The checkpoint path is relative to the current
    working directory.
    """
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    module_dir = Path(__file__).resolve().parent
    scene = module_dir / "data" / "S2A_MSIL2A_EXAMPLE.SAFE"
    config = module_dir / "configs" / "config_20m.yaml"
    # NOTE(review): unlike the two paths above, this one depends on the CWD.
    checkpoint = "checkpoints/srgan-20m-6band/last.ckpt"

    run_sen2_inference(
        sen2_path=str(scene),
        config_path=str(config),
        ckpt_path=checkpoint,
        gpus=[0],
    )

Pretrained model factory

Utility helpers to instantiate pretrained SRGAN models.

load_from_config(config_path, checkpoint_uri=None, *, map_location=None, mode='train')

Instantiate SRGAN_model from a YAML config and optional checkpoint.

Parameters

- config_path: Filesystem path to the YAML configuration that describes the generator and discriminator architecture. This should match the configuration the checkpoint was trained with.
- checkpoint_uri: Optional path or HTTP(S) URL pointing to a Lightning checkpoint. When omitted, the factory returns an untrained model initialised from the supplied config.
- map_location: Optional argument forwarded to `torch.load` during checkpoint deserialisation.
- mode: Mode in which to instantiate the model. Either "train" or "eval".

Source code in gan_engine/_factory.py
def load_from_config(
    config_path: Union[str, Path],
    checkpoint_uri: Optional[Union[str, Path]] = None,
    *,
    map_location: Optional[Union[str, torch.device]] = None,
    mode: str = "train",
) -> LightningModule:
    """Instantiate ``SRGAN_model`` from a YAML config and optional checkpoint.

    Parameters
    ----------
    config_path:
        Filesystem path to the YAML configuration that describes the generator
        and discriminator architecture. This should match the configuration the
        checkpoint was trained with.
    checkpoint_uri:
        Optional path or HTTP(S) URL pointing to a Lightning checkpoint. When
        omitted, the factory returns an untrained model initialised from the
        supplied config.
    map_location:
        Optional argument forwarded to :func:`torch.load` during checkpoint
        deserialisation.
    mode:
        Mode in which to instantiate the model. Either "train" or "eval".
        This is forwarded to the ``SRGAN_model`` constructor. Independently of
        it, the returned module is always switched to ``eval()``.

    Returns
    -------
    LightningModule
        The constructed model, with weights (and EMA state, if present in both
        the model and the checkpoint) restored when a checkpoint was given.

    Raises
    ------
    FileNotFoundError
        If ``config_path`` does not point to an existing file.

    Warns
    -----
    UserWarning
        If none of the checkpoint's state-dict keys match the model. Loading
        is non-strict, so without this warning such a mismatch would silently
        return a model whose weights were not restored.
    """
    import warnings

    config_path = Path(config_path)
    if not config_path.is_file():
        raise FileNotFoundError(f"Config file '{config_path}' could not be located.")

    model = SRGAN_model(config=config_path, mode=mode)

    if checkpoint_uri is not None:
        with _maybe_download(checkpoint_uri) as resolved_path:
            # Deserialise while the resolved (possibly downloaded) file is in scope.
            checkpoint = torch.load(str(resolved_path), map_location=map_location)

        # Lightning checkpoints nest weights under "state_dict"; a bare state
        # dict is accepted as well.
        state_dict = checkpoint.get("state_dict", checkpoint)

        # strict=False tolerates partial checkpoints. However, if no key matches
        # at all (e.g. a prefix mismatch), nothing is restored, so surface that.
        if set(model.state_dict()).isdisjoint(state_dict):
            warnings.warn(
                f"No parameters from checkpoint '{checkpoint_uri}' matched the "
                f"model built from '{config_path}'; its weights were not restored.",
                stacklevel=2,
            )
        model.load_state_dict(state_dict, strict=False)

        # getattr: tolerate models that do not define an ``ema`` attribute at all.
        ema = getattr(model, "ema", None)
        if ema is not None and "ema_state" in checkpoint:
            ema.load_state_dict(checkpoint["ema_state"])

    model.eval()
    return model

load_inference_model(preset, *, cache_dir=None, map_location=None)

Instantiate an off-the-shelf pretrained SRGAN.

The function downloads a known-good configuration + checkpoint pair from the Hugging Face Hub (unless it is already cached) and restores the packaged Lightning module.

Source code in gan_engine/_factory.py
def load_inference_model(
    preset: str,
    *,
    cache_dir: Optional[Union[str, Path]] = None,
    map_location: Optional[Union[str, torch.device]] = None,
) -> LightningModule:
    """Instantiate an off-the-shelf pretrained SRGAN.

    The function downloads a known-good configuration + checkpoint pair from
    the Hugging Face Hub (unless it is already cached) and restores the
    packaged Lightning module.

    Parameters
    ----------
    preset:
        Preset name. It is normalised before lookup: surrounding whitespace is
        stripped, underscores become hyphens, and the result is upper-cased.
    cache_dir:
        Optional Hub cache directory, forwarded to ``hf_hub_download``.
    map_location:
        Forwarded to ``load_from_config``.

    Returns
    -------
    LightningModule
        The result of ``load_from_config`` with ``mode="eval"``.

    Raises
    ------
    ValueError
        If the normalised preset is not a key of ``_PRESETS``.
    ImportError
        If ``huggingface_hub`` is not installed.
    """
    lookup_key = preset.strip().replace("_", "-").upper()
    try:
        meta = _PRESETS[lookup_key]
    except KeyError as err:
        choices = ", ".join(sorted(_PRESETS))
        raise ValueError(
            f"Unknown preset '{preset}'. Available options: {choices}."
        ) from err

    try:  # pragma: no cover - import guard only used at runtime
        from huggingface_hub import hf_hub_download
    except ImportError as exc:  # pragma: no cover - dependency guard
        raise ImportError(
            "huggingface_hub is required for load_inference_model. "
            "Install the project extras or run 'pip install huggingface-hub'."
        ) from exc

    hub_cache = None if cache_dir is None else str(cache_dir)

    def _fetch(filename):
        # One Hub download (or cache hit) from the preset's repository.
        return hf_hub_download(
            repo_id=meta.repo_id,
            filename=filename,
            cache_dir=hub_cache,
        )

    config_file = _fetch(meta.config_filename)
    checkpoint_file = _fetch(meta.checkpoint_filename)

    return load_from_config(
        config_file,
        checkpoint_file,
        map_location=map_location,
        mode="eval",
    )