Inference API
Utilities for loading pretrained SRGAN checkpoints and running tiled inference on Sentinel-2 imagery.
Core helpers
load_model(config_path=None, ckpt_path=None, device=None)
Build an SRGAN model and optionally load weights from a checkpoint. Safe to call from tests.
Source code in gan_engine/inference.py
def load_model(config_path=None, ckpt_path=None, device=None):
    """Build SRGAN model and (optionally) load weights. Safe to call from tests.

    Parameters
    ----------
    config_path:
        Config path forwarded to ``SRGAN_model`` as ``config_file_path``.
        Used when no checkpoint is given, and for the raw ``state_dict``
        fallback when the Lightning loader raises ``TypeError``.
    ckpt_path:
        Optional checkpoint path. When falsy, an untrained model built from
        ``config_path`` is returned.
    device:
        Target device. Defaults to ``"cuda"`` when available, else ``"cpu"``.

    Returns
    -------
    tuple
        ``(model, device)``: the model in eval mode on ``device``, and the
        device string/object that was used.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    def _build_from_config():
        # Fresh, untrained model from the config, in eval mode on ``device``.
        return SRGAN_model(config_file_path=config_path).eval().to(device)

    if not ckpt_path:
        return _build_from_config(), device

    # Try the Lightning API first (without 'strict'). ``load_from_checkpoint``
    # is a classmethod that returns its own instance, so a config-built model
    # is only constructed when we must fall back to loading a raw state_dict.
    try:
        model = (
            SRGAN_model.load_from_checkpoint(ckpt_path, map_location=device)
            .eval()
            .to(device)
        )
    except TypeError:
        model = _build_from_config()
        state = torch.load(ckpt_path, map_location=device)
        # Lightning checkpoints nest weights under "state_dict"; plain
        # state_dict files are used as-is.
        state = state.get("state_dict", state)
        # strict=False: missing/unexpected keys are ignored instead of raising.
        model.load_state_dict(state, strict=False)
    return model, device
run_sen2_inference(sen2_path=None, config_path=None, ckpt_path=None, gpus=None, window_size=(128, 128), factor=4, overlap=12, eliminate_border_px=2, save_preview=False, debug=False)
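Run Sentinel-2 super-resolution (SR) inference. The work runs only when this function is called, not at module import time, so importing the module in CI does not trigger it.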
Source code in gan_engine/inference.py
def run_sen2_inference(
    sen2_path=None,
    config_path=None,
    ckpt_path=None,
    gpus=None,
    window_size=(128, 128),
    factor=4,
    overlap=12,
    eliminate_border_px=2,
    save_preview=False,
    debug=False,
):
    """Run Sentinel-2 SR inference. Kept out of import-time for CI.

    The large-file processing package is imported inside this function rather
    than at module import time.

    Parameters
    ----------
    sen2_path:
        Sentinel-2 input, forwarded as ``root`` to ``large_file_processing``.
    config_path, ckpt_path:
        Forwarded to ``load_model``.
    gpus:
        Optional sequence of GPU indices. When non-empty, it seeds
        ``CUDA_VISIBLE_DEVICES`` (only if that variable is not already set).
        When ``None``, ``[0]`` is used on CUDA and ``[]`` otherwise.
    window_size, factor, overlap, eliminate_border_px, save_preview, debug:
        Forwarded unchanged to ``large_file_processing``.

    Returns
    -------
    The ``large_file_processing`` object, after ``start_super_resolution()``
    has been called on it.
    """
    requested_gpus = gpus is not None and len(gpus) > 0
    if requested_gpus:
        # NOTE(review): this only has an effect if CUDA is not yet initialised
        # in this process. When it does apply, the visible devices are
        # renumbered from 0, yet the original ``gpus`` indices are still
        # passed on below - confirm how ``large_file_processing`` reads them.
        visible_ids = ",".join(str(gpu_id) for gpu_id in gpus)
        os.environ.setdefault("CUDA_VISIBLE_DEVICES", visible_ids)

    model, device = load_model(config_path=config_path, ckpt_path=ckpt_path)

    try:  # Prefer the rebranded utility package when available
        import gan_engine_utils as processing_pkg
    except ImportError:  # pragma: no cover - legacy dependency support
        import opensr_utils as processing_pkg

    if gpus is None:
        target_gpus = [0] if device == "cuda" else []
    else:
        target_gpus = gpus

    processor = processing_pkg.large_file_processing(
        root=sen2_path,
        model=model,
        window_size=window_size,
        factor=factor,
        overlap=overlap,
        eliminate_border_px=eliminate_border_px,
        device=device,
        gpus=target_gpus,
        save_preview=save_preview,
        debug=debug,
    )
    processor.start_super_resolution()
    return processor
main()
Example entry point: runs `run_sen2_inference` with placeholder paths.
Source code in gan_engine/inference.py
def main():
    """Example entry point: run Sentinel-2 inference on placeholder inputs.

    The sample product and the config are resolved relative to this module's
    directory. The checkpoint path is a plain relative string, so it is
    resolved against the current working directory.
    """
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    # ---- Define placeholders ----
    module_dir = Path(__file__).resolve().parent
    run_sen2_inference(
        sen2_path=str(module_dir / "data" / "S2A_MSIL2A_EXAMPLE.SAFE"),
        config_path=str(module_dir / "configs" / "config_20m.yaml"),
        ckpt_path="checkpoints/srgan-20m-6band/last.ckpt",
        gpus=[0],
    )
Pretrained model factory
Utility helpers to instantiate pretrained SRGAN models.
load_from_config(config_path, checkpoint_uri=None, *, map_location=None, mode='train')
Instantiate SRGAN_model from a YAML config and optional checkpoint.
Parameters
config_path:
Filesystem path to the YAML configuration that describes the generator
and discriminator architecture. This should match the configuration the
checkpoint was trained with.
checkpoint_uri:
Optional path or HTTP(S) URL pointing to a Lightning checkpoint. When
omitted, the factory returns an untrained model initialised from the
supplied config.
map_location:
Optional argument forwarded to `torch.load` during checkpoint
deserialisation.
mode:
Mode in which to instantiate the model. Either "train" or "eval".
Source code in gan_engine/_factory.py
def load_from_config(
    config_path: Union[str, Path],
    checkpoint_uri: Optional[Union[str, Path]] = None,
    *,
    map_location: Optional[Union[str, torch.device]] = None,
    mode: str = "train",
) -> LightningModule:
    """Instantiate ``SRGAN_model`` from a YAML config and optional checkpoint.

    Parameters
    ----------
    config_path:
        Filesystem path to the YAML configuration that describes the generator
        and discriminator architecture. This should match the configuration the
        checkpoint was trained with.
    checkpoint_uri:
        Optional path or HTTP(S) URL pointing to a Lightning checkpoint. When
        omitted, the factory returns an untrained model initialised from the
        supplied config.
    map_location:
        Optional argument forwarded to :func:`torch.load` during checkpoint
        deserialisation.
    mode:
        Mode in which to instantiate the model. Either "train" or "eval".
        This is forwarded to ``SRGAN_model``. Independently of it, the
        returned module is always put into PyTorch evaluation mode via
        ``model.eval()``.

    Returns
    -------
    LightningModule
        The instantiated model, in evaluation mode.

    Raises
    ------
    FileNotFoundError
        If ``config_path`` does not point to an existing file.
    """
    config_path = Path(config_path)
    if not config_path.is_file():
        raise FileNotFoundError(f"Config file '{config_path}' could not be located.")

    model = SRGAN_model(config=config_path, mode=mode)

    if checkpoint_uri is not None:
        with _maybe_download(checkpoint_uri) as resolved_path:
            # SECURITY NOTE(review): torch.load may unpickle arbitrary objects
            # (depending on the installed PyTorch's ``weights_only`` default),
            # and ``checkpoint_uri`` may be a remote URL. Only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(str(resolved_path), map_location=map_location)

        # Lightning checkpoints nest weights under "state_dict"; plain
        # state_dict files are used as-is.
        state_dict = checkpoint.get("state_dict", checkpoint)
        # strict=False: missing/unexpected keys are ignored instead of raising.
        model.load_state_dict(state_dict, strict=False)

        # Restore EMA weights only when the model has an EMA helper and the
        # checkpoint carries them. getattr keeps this working for model
        # variants that do not define an ``ema`` attribute at all.
        ema = getattr(model, "ema", None)
        if ema is not None and "ema_state" in checkpoint:
            ema.load_state_dict(checkpoint["ema_state"])

    model.eval()
    return model
load_inference_model(preset, *, cache_dir=None, map_location=None)
Instantiate an off-the-shelf pretrained SRGAN.
The function downloads a known-good configuration and checkpoint pair from the Hugging Face Hub (unless it is already cached) and restores the packaged Lightning module.
Source code in gan_engine/_factory.py
def load_inference_model(
    preset: str,
    *,
    cache_dir: Optional[Union[str, Path]] = None,
    map_location: Optional[Union[str, torch.device]] = None,
) -> LightningModule:
    """Instantiate an off-the-shelf pretrained SRGAN.

    The configuration and checkpoint files registered for ``preset`` are
    fetched with ``huggingface_hub.hf_hub_download`` (which reuses its local
    cache when the files are already present). They are then passed to
    ``load_from_config`` with ``mode="eval"``.

    Parameters
    ----------
    preset:
        Preset name. Before lookup, surrounding whitespace is stripped,
        underscores become hyphens and the name is upper-cased.
    cache_dir:
        Optional cache directory forwarded (as a string) to ``hf_hub_download``.
    map_location:
        Forwarded to ``load_from_config``.

    Raises
    ------
    ValueError
        If the normalised preset name is not a known preset.
    ImportError
        If ``huggingface_hub`` is not installed.
    """
    lookup_key = preset.strip().replace("_", "-").upper()
    try:
        meta = _PRESETS[lookup_key]
    except KeyError as missing:
        options = ", ".join(sorted(_PRESETS))
        raise ValueError(
            f"Unknown preset '{preset}'. Available options: {options}."
        ) from missing

    try:  # pragma: no cover - import guard only used at runtime
        from huggingface_hub import hf_hub_download
    except ImportError as import_err:  # pragma: no cover - dependency guard
        raise ImportError(
            "huggingface_hub is required for load_inference_model. "
            "Install the project extras or run 'pip install huggingface-hub'."
        ) from import_err

    hub_cache = None if cache_dir is None else str(cache_dir)

    def _fetch(filename):
        # Download (or reuse the cached copy of) one file from the preset repo.
        return hf_hub_download(
            repo_id=meta.repo_id,
            filename=filename,
            cache_dir=hub_cache,
        )

    config_file = _fetch(meta.config_filename)
    checkpoint_file = _fetch(meta.checkpoint_filename)
    return load_from_config(
        config_file,
        checkpoint_file,
        map_location=map_location,
        mode="eval",
    )