Training Custom Spatial Foundation Models: A Practical Python Pipeline

Training a custom spatial foundation model requires adapting large-scale, pre-trained neural architectures to recognize geographic patterns that standard computer vision models routinely miss. While foundation models excel at extracting hierarchical features from natural images, geospatial data introduces unique constraints: coordinate reference systems, spatial autocorrelation, and the physical scale of real-world objects. Fine-tuning these models for tasks like land cover classification or infrastructure mapping is not a simple layer swap. You must teach the network to respect geographic continuity, handle multi-spectral band relationships, and maintain spatial context across large, contiguous landscapes.

The following pipeline provides a reference implementation for loading geospatial tiles, normalizing multi-band raster data, adapting a pre-trained vision backbone, and executing a stable training loop.

The architecture and training flow are summarized below.

flowchart LR
    A["Geospatial tiles<br/>(multi-band, per-band norm)"] --> B["Pretrained backbone<br/>(ResNet50, head stripped)"]
    B --> C["Spatial feature maps<br/>(2048 channels)"]
    C --> D["Decoder<br/>(conv + upsample)"]
    D --> E["Per-pixel logits"]
    E --> F["CrossEntropyLoss<br/>(ignore_index=-1)"]
    F --> G["Backward + grad clip<br/>+ AdamW step"]
    G --> B

1. Structuring the Geospatial Dataset

Geographic imagery rarely fits into GPU memory as a single array. It must be partitioned into tiles that preserve spatial alignment. Arbitrary cropping breaks the statistical dependencies between neighboring pixels, causing the model to learn fragmented patterns. The dataset class below handles tile loading, enforces consistent band ordering, and applies per-band normalization to stabilize gradient descent.

import os
import numpy as np
import rasterio
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models

class SpatialTileDataset(Dataset):
    def __init__(self, tile_dir, label_dir, transform=None):
        self.tile_paths = sorted([
            os.path.join(tile_dir, f) for f in os.listdir(tile_dir) if f.endswith('.tif')
        ])
        self.label_paths = [
            os.path.join(label_dir, os.path.basename(f)) for f in self.tile_paths
        ]
        self.transform = transform

    def __len__(self):
        return len(self.tile_paths)

    def __getitem__(self, idx):
        # Load multi-band raster
        with rasterio.open(self.tile_paths[idx]) as src:
            image = src.read().astype(np.float32)

        # Load single-channel label mask
        with rasterio.open(self.label_paths[idx]) as src:
            label = src.read(1).astype(np.int64)

        # Per-band standardization (zero mean, unit variance) to stabilize training.
        # Note: this uses each tile's own statistics, not the backbone's pre-training statistics.
        means = image.mean(axis=(1, 2), keepdims=True)
        stds = image.std(axis=(1, 2), keepdims=True)
        image = (image - means) / (stds + 1e-8)

        if self.transform:
            image = self.transform(image)

        # as_tensor accepts NumPy arrays or tensors (e.g. returned by a transform)
        return torch.as_tensor(image), torch.as_tensor(label, dtype=torch.long)
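
The tiling step that feeds this dataset is not shown above. As a minimal sketch, assuming square tiles cut on a regular grid with a fixed stride, the upper-left corners of overlapping tiles could be computed as follows. The helper name `tile_origins` and its parameters are hypothetical, not part of the original pipeline.

```python
def tile_origins(height, width, tile_size, stride):
    """Return (row, col) upper-left corners of tiles covering a raster.

    Hypothetical helper: a stride smaller than tile_size produces overlap.
    """
    def axis_offsets(length):
        # A raster no larger than one tile gets a single tile at offset 0
        if length <= tile_size:
            return [0]
        offsets = list(range(0, length - tile_size + 1, stride))
        # Add a final tile flush with the far edge so no pixels are dropped
        if offsets[-1] != length - tile_size:
            offsets.append(length - tile_size)
        return offsets

    return [(r, c) for r in axis_offsets(height) for c in axis_offsets(width)]


# 512x512 raster, 256-pixel tiles, 50% overlap
origins = tile_origins(512, 512, tile_size=256, stride=128)
print(len(origins), origins[:3])
```

Each `(row, col)` pair can then be turned into a windowed read, for example `src.read(window=rasterio.windows.Window(col, row, tile_size, tile_size))`, so that tiles stay aligned to the source raster grid.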

2. Adapting the Vision Backbone for Pixel-Wise Prediction

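Pre-trained classification backbones like ResNet50 end in global average pooling and a fully connected layer, so they output a single logit vector per image. For semantic segmentation, you must keep the spatial feature maps and attach a lightweight decoder that upsamples them back to the original tile resolution. Note that the stock ResNet50 stem accepts exactly three input bands; tiles with more bands need a widened first convolution or a three-band subset.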
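
How much upsampling the decoder needs follows from the backbone's total stride. The arithmetic below is a sketch, assuming torchvision's ResNet50 layout (five stride-2 stages: conv1, maxpool, and the first block of layer2 through layer4) and a transposed convolution with no padding. The function names are illustrative only.

```python
import math


def resnet50_feature_size(input_size: int) -> int:
    """Spatial size after ResNet50's five stride-2 stages."""
    size = input_size
    for _ in range(5):
        # Each stage maps H to floor((H - 1) / 2) + 1, which equals ceil(H / 2)
        size = math.ceil(size / 2)
    return size


def transposed_conv_size(input_size: int, kernel_size: int, stride: int) -> int:
    """Output size of ConvTranspose2d with padding=0, output_padding=0, dilation=1."""
    return (input_size - 1) * stride + kernel_size


for tile in (224, 256, 250):
    feat = resnet50_feature_size(tile)
    up16 = transposed_conv_size(feat, kernel_size=16, stride=16)
    up32 = transposed_conv_size(feat, kernel_size=32, stride=32)
    print(f"tile={tile} features={feat} x16={up16} x32={up32}")
```

For a 256-pixel tile the feature map is 8x8, so a 16x transposed convolution yields 128 pixels and a 32x one yields 256. Tile sizes that are not multiples of 32 (250 in the loop above) cannot be restored exactly by either, so a final interpolation to the label size is still needed in that case.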

# Strip classification head (avgpool + fc) to retain spatial feature maps
backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
backbone = nn.Sequential(*list(backbone.children())[:-2])

class SpatialDecoder(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        # Upsample by the backbone's total stride: ResNet50 reduces H and W by 32x,
        # so a 32x transposed convolution restores tile sizes that are multiples of 32
        self.upsample = nn.ConvTranspose2d(256, num_classes, kernel_size=32, stride=32)

    def forward(self, x):
        x = self.conv_block(x)
        return self.upsample(x)

# Combine backbone and decoder
class SpatialFoundationModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = backbone
        self.decoder = SpatialDecoder(in_channels=2048, num_classes=num_classes)

    def forward(self, x):
        features = self.backbone(x)
        return self.decoder(features)
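
The pretrained ResNet50 stem expects three input channels, so tiles with more bands will fail at the first convolution. One common workaround, sketched here under the assumption that the first three bands are the ones the pretrained filters should apply to, is to widen that layer and initialize the extra bands from the mean of the existing filters. `inflate_first_conv` is a hypothetical helper, and the mean initialization is a heuristic rather than a prescribed method.

```python
import torch
import torch.nn as nn


def inflate_first_conv(conv: nn.Conv2d, in_bands: int) -> nn.Conv2d:
    """Return a copy of `conv` that accepts `in_bands` input channels."""
    new_conv = nn.Conv2d(
        in_bands,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        # Reuse the pretrained filters for the bands that already exist
        n_copy = min(in_bands, conv.in_channels)
        new_conv.weight[:, :n_copy] = conv.weight[:, :n_copy]
        # Heuristic: start extra bands from the mean of the pretrained filters
        if in_bands > conv.in_channels:
            mean_filter = conv.weight.mean(dim=1, keepdim=True)
            new_conv.weight[:, conv.in_channels:] = mean_filter
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Stand-in for the ResNet50 stem: 3 -> 64 channels, 7x7 kernel, stride 2
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
four_band_conv = inflate_first_conv(rgb_conv, in_bands=4)
print(four_band_conv(torch.randn(1, 4, 64, 64)).shape)
```

With the model defined above, the stem is the first child of the stripped `nn.Sequential`, so it could be swapped with `model.backbone[0] = inflate_first_conv(model.backbone[0], in_bands=4)` before the model is moved to the device and the optimizer is created.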

3. Training Loop & Optimization

Spatial models require careful gradient management. Large tiles and many-band inputs can cause exploding gradients and heavy memory pressure. The loop below implements gradient clipping and explicit device placement. It runs in full precision; mixed-precision training can be added by wrapping the forward pass in torch.autocast.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SpatialFoundationModel(num_classes=5).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-1)  # -1 for masked/invalid pixels

def train_epoch(dataloader, model, optimizer, criterion, device):
    model.train()
    epoch_loss = 0.0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Resize outputs if decoder upsampling doesn't perfectly match label dimensions
        if outputs.shape[-2:] != labels.shape[-2:]:
            outputs = torch.nn.functional.interpolate(
                outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False
            )

        loss = criterion(outputs, labels)
        loss.backward()

        # Prevent gradient explosion common in spatial segmentation
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(dataloader)
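
Mixed precision can be added to a step like the one in this loop. The sketch below assumes AMP is only enabled on CUDA devices and falls back to full precision on CPU; `amp_train_step` and the toy model are hypothetical stand-ins used to keep the example self-contained. Gradients are unscaled before clipping so that `max_norm` applies to their true magnitudes.

```python
import torch
import torch.nn as nn


def amp_train_step(model, images, labels, optimizer, criterion, scaler, device):
    """One optimization step with autocast, loss scaling, and gradient clipping."""
    use_amp = device.type == "cuda"  # assumption: AMP only on CUDA
    amp_dtype = torch.float16 if use_amp else torch.bfloat16
    optimizer.zero_grad(set_to_none=True)

    # Forward pass and loss run in reduced precision where it is safe to do so
    with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
        outputs = model(images)
        loss = criterion(outputs, labels)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # restore true gradient scale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
toy_model = nn.Conv2d(3, 5, kernel_size=1).to(device)  # stand-in for the real model
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-4)
toy_criterion = nn.CrossEntropyLoss(ignore_index=-1)
# Newer PyTorch releases prefer torch.amp.GradScaler("cuda", ...)
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

images = torch.randn(2, 3, 8, 8, device=device)
labels = torch.randint(0, 5, (2, 8, 8), device=device)
step_loss = amp_train_step(
    toy_model, images, labels, toy_optimizer, toy_criterion, scaler, device
)
print(step_loss)
```

When the scaler is disabled, its `scale`, `unscale_`, `step`, and `update` calls reduce to a plain backward pass and optimizer step, so the same function runs unchanged on CPU.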

4. Fast Debugging & Problem Resolution

When training spatial foundation models, failures typically manifest in four predictable patterns. Use these steps to isolate and resolve them immediately.

| Symptom | Root Cause | Exact Fix |
| --- | --- | --- |
| RuntimeError: size mismatch | Decoder output dimensions do not match label tile size | Use torch.nn.functional.interpolate to force alignment before loss calculation, as shown in the training loop. |
| Loss = NaN after epoch 1 | Unnormalized multi-spectral bands or invalid label values | Verify band statistics. Remap nodata label values (e.g., -9999) to the integer sentinel -1 and pass it as ignore_index in CrossEntropyLoss. |
| CUDA out of memory | Oversized tiles or excessive batch size | Reduce batch_size to 2 or 4, or use smaller tiles. Wrap the forward pass in torch.autocast(device_type="cuda") for mixed precision. Setting torch.backends.cudnn.benchmark = True can speed up fixed-size tiles but does not reduce memory use. |
| Model predicts uniform class | Spatial autocorrelation ignored during tiling | Ensure tiles are generated with a stride ≤ 50% of tile width. Overlap preserves boundary context and prevents edge artifacts. |
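
The label remapping in the NaN row can be sketched with NumPy. This example assumes class IDs run from 0 to num_classes - 1 and treats anything outside that range as invalid, in addition to the nodata value; `remap_labels` is a hypothetical helper name.

```python
import numpy as np


def remap_labels(label, num_classes, nodata=-9999, ignore_index=-1):
    """Map nodata and out-of-range class IDs to the ignore sentinel."""
    label = label.astype(np.int64, copy=True)
    invalid = (label == nodata) | (label < 0) | (label >= num_classes)
    label[invalid] = ignore_index
    return label


raw = np.array([[0, 4, -9999],
                [7, 2, 255]])
clean = remap_labels(raw, num_classes=5)
print(clean)
```

Applying this in `__getitem__` right after the label mask is read keeps invalid pixels out of the loss, since CrossEntropyLoss skips targets equal to `ignore_index`.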

For persistent memory bottlenecks or hyperparameter tuning across large geospatial datasets, consult Advanced Geospatial AI Optimization to implement gradient accumulation and distributed data loading. Always validate CRS alignment before training: mismatched projections between imagery and labels will silently corrupt spatial relationships. Reference the official Rasterio Quickstart for coordinate validation workflows, and review PyTorch Data Loading Best Practices to tune num_workers and pin_memory for your hardware.
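
A pre-training alignment check might look like the sketch below. It assumes both arguments behave like opened Rasterio datasets, exposing `crs`, `transform`, `height`, and `width`; the function name `alignment_problems` is hypothetical.

```python
def alignment_problems(image_src, label_src):
    """Return a list of human-readable mismatches between a tile and its label."""
    problems = []
    if image_src.crs != label_src.crs:
        problems.append(f"CRS differs: {image_src.crs} vs {label_src.crs}")
    if image_src.transform != label_src.transform:
        problems.append("Affine transform differs (origin or pixel size)")
    image_shape = (image_src.height, image_src.width)
    label_shape = (label_src.height, label_src.width)
    if image_shape != label_shape:
        problems.append(f"Shape differs: {image_shape} vs {label_shape}")
    return problems
```

In practice it could be called inside nested `with rasterio.open(tile_path) as img, rasterio.open(label_path) as lbl:` blocks for every tile pair before training starts, failing fast if the returned list is non-empty.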