Training Custom Spatial Foundation Models: A Practical Python Pipeline
Training a custom spatial foundation model requires adapting large-scale, pre-trained neural architectures to recognize geographic patterns that standard computer vision models routinely miss. While foundation models excel at extracting hierarchical features from natural images, geospatial data introduces unique constraints: coordinate reference systems, spatial autocorrelation, and the physical scale of real-world objects. Fine-tuning these models for tasks like land cover classification or infrastructure mapping is not a simple layer swap. You must teach the network to respect geographic continuity, handle multi-spectral band relationships, and maintain spatial context across large, contiguous landscapes.
The following pipeline provides a reference implementation for loading geospatial tiles, normalizing multi-band raster data, adapting a pre-trained vision backbone, and executing a stable training loop. Treat it as a starting point and validate it against your own imagery before relying on it in production.
The architecture and training flow are summarized below.
flowchart LR
A["Geospatial tiles<br/>(multi-band, per-band norm)"] --> B["Pretrained backbone<br/>(ResNet50, head stripped)"]
B --> C["Spatial feature maps<br/>(2048 channels)"]
C --> D["Decoder<br/>(conv + upsample)"]
D --> E["Per-pixel logits"]
E --> F["CrossEntropyLoss<br/>(ignore_index=-1)"]
F --> G["Backward + grad clip<br/>+ AdamW step"]
G --> B
1. Structuring the Geospatial Dataset
Geographic imagery rarely fits into GPU memory as a single array. It must be partitioned into tiles that preserve spatial alignment. Arbitrary cropping breaks the statistical dependencies between neighboring pixels, causing the model to learn fragmented patterns. The dataset class below loads pre-cut tiles, pairs each tile with the label file of the same name, reads bands in file order, and applies per-band normalization to stabilize gradient descent.
import os

import numpy as np
import rasterio
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models


class SpatialTileDataset(Dataset):
    def __init__(self, tile_dir, label_dir, transform=None):
        # Sort for a deterministic order; labels are matched to tiles by file name
        self.tile_paths = sorted([
            os.path.join(tile_dir, f) for f in os.listdir(tile_dir) if f.endswith('.tif')
        ])
        self.label_paths = [
            os.path.join(label_dir, os.path.basename(f)) for f in self.tile_paths
        ]
        self.transform = transform

    def __len__(self):
        return len(self.tile_paths)

    def __getitem__(self, idx):
        # Load multi-band raster as (bands, height, width)
        with rasterio.open(self.tile_paths[idx]) as src:
            image = src.read().astype(np.float32)
        # Load single-channel label mask as (height, width)
        with rasterio.open(self.label_paths[idx]) as src:
            label = src.read(1).astype(np.int64)
        # Per-band, per-tile standardization (zero mean, unit variance per band)
        means = image.mean(axis=(1, 2), keepdims=True)
        stds = image.std(axis=(1, 2), keepdims=True)
        image = (image - means) / (stds + 1e-8)
        if self.transform:
            image = self.transform(image)
        # as_tensor accepts both NumPy arrays and tensors returned by a transform
        return torch.as_tensor(image), torch.as_tensor(label).long()
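The dataset casts the label mask to int64 as-is, so any nodata value stored in the label raster reaches the loss unchanged. A minimal sketch of one way to clean the mask before it is converted to a tensor is shown here. The helper name `remap_labels`, the nodata value -9999, and the class count of 5 are assumptions for illustration; substitute the values your label rasters actually use.

```python
import numpy as np

IGNORE_INDEX = -1  # assumed to match ignore_index in the loss function


def remap_labels(label, nodata=-9999, num_classes=5, ignore_index=IGNORE_INDEX):
    """Return an int64 mask where nodata and out-of-range values become ignore_index."""
    label = np.asarray(label).astype(np.int64)
    # Anything that is nodata, negative, or beyond the last class id is invalid
    invalid = (label == nodata) | (label < 0) | (label >= num_classes)
    return np.where(invalid, ignore_index, label)


# Toy mask: one nodata pixel and one class id (7) outside the 0..4 range
mask = np.array([[0, 4, -9999], [2, 7, 1]])
cleaned = remap_labels(mask)
print(cleaned)
```

Applied inside `__getitem__` right after the label is read, a helper like this keeps invalid pixels out of the loss without dropping whole tiles.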
2. Adapting the Vision Backbone for Pixel-Wise Prediction
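Pre-trained classification backbones like ResNet50 end in global pooling and a fully connected layer, so they output a single vector per image. For semantic segmentation, you must preserve spatial dimensions and attach a lightweight decoder that upsamples features back to the original tile resolution. Note that the stock ResNet50 stem expects three input channels, so tiles with a different band count need either a three-band subset or a replaced first convolution.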
# Strip classification head (avgpool + fc) to retain spatial feature maps
backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
backbone = nn.Sequential(*list(backbone.children())[:-2])


class SpatialDecoder(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        # ResNet50 downsamples by a factor of 32, so upsample by 32 to recover
        # the input tile dimensions (exact when the tile size is a multiple of 32)
        self.upsample = nn.ConvTranspose2d(256, num_classes, kernel_size=32, stride=32)

    def forward(self, x):
        x = self.conv_block(x)
        return self.upsample(x)


# Combine backbone and decoder
class SpatialFoundationModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = backbone
        self.decoder = SpatialDecoder(in_channels=2048, num_classes=num_classes)

    def forward(self, x):
        features = self.backbone(x)
        return self.decoder(features)
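Whether the decoder output lines up with the label tile comes down to arithmetic. ResNet50 with its head stripped halves the spatial size five times (rounding up each time), and a transposed convolution with no padding produces `(size - 1) * stride + kernel_size` pixels. The sketch below works through that arithmetic for a few tile sizes and two upsampling factors. The helper names are hypothetical and the code only computes shapes; it does not build a model.

```python
def backbone_feature_size(tile_size, num_halvings=5):
    """Spatial size after ResNet50's five stride-2 stages (each rounds up)."""
    size = tile_size
    for _ in range(num_halvings):
        size = (size + 1) // 2
    return size


def conv_transpose_out(size, kernel_size, stride, padding=0):
    """Output size of a transposed convolution with dilation=1, output_padding=0."""
    return (size - 1) * stride - 2 * padding + kernel_size


for tile in (256, 512, 500):
    feat = backbone_feature_size(tile)
    up16 = conv_transpose_out(feat, kernel_size=16, stride=16)
    up32 = conv_transpose_out(feat, kernel_size=32, stride=32)
    print(f"tile={tile} features={feat} x16={up16} x32={up32}")
```

A factor of 32 recovers the tile size exactly for 256 and 512, while a factor of 16 yields half the resolution. A 500-pixel tile comes back as 512 even at factor 32, which is the case where resizing the logits before the loss is still needed.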
3. Training Loop & Optimization
Spatial models require careful gradient management. Large tiles and high-band inputs often cause exploding gradients or memory pressure. The loop below implements gradient clipping and explicit device placement, and it resizes the logits whenever the decoder output does not match the label dimensions.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SpatialFoundationModel(num_classes=5).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-1)  # -1 for masked/invalid pixels


def train_epoch(dataloader, model, optimizer, criterion, device):
    model.train()
    epoch_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # Resize outputs if decoder upsampling doesn't perfectly match label dimensions
        if outputs.shape[-2:] != labels.shape[-2:]:
            outputs = torch.nn.functional.interpolate(
                outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False
            )
        loss = criterion(outputs, labels)
        loss.backward()
        # Prevent gradient explosion common in spatial segmentation
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)
4. Fast Debugging & Problem Resolution
When training spatial foundation models, failures typically manifest in four predictable patterns. Use these steps to isolate and resolve them quickly.
| Symptom | Root Cause | Exact Fix |
|---|---|---|
| RuntimeError: size mismatch | Decoder output dimensions do not match label tile size | Use torch.nn.functional.interpolate to force alignment before loss calculation, as shown in the training loop. |
| Loss = NaN after epoch 1 | Unnormalized multi-spectral bands or invalid label values | Verify band statistics. Remap nodata label values (e.g., -9999) to the integer sentinel -1 and pass it as ignore_index in CrossEntropyLoss. |
| CUDA out of memory | Oversized tiles or excessive batch size | Reduce batch_size to 2 or 4, or cut smaller tiles. Use torch.autocast for mixed precision to lower activation memory. Setting torch.backends.cudnn.benchmark = True can speed up fixed-size inputs but does not reduce memory use. |
| Model predicts uniform class | Spatial autocorrelation ignored during tiling | Ensure tiles are generated with a stride ≤ 50% of tile width. Overlap preserves boundary context and prevents edge artifacts. |
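The overlap rule in the last row can be made concrete with a small sketch that computes tile origins for a given tile size and stride. The function name `tile_origins` and the default values (256-pixel tiles, 128-pixel stride, which is 50% overlap) are assumptions for illustration, not part of the pipeline above.

```python
def tile_origins(height, width, tile=256, stride=128):
    """Return (row, col) top-left corners of tiles covering a height x width raster."""

    def axis_starts(length):
        if length <= tile:
            return [0]
        starts = list(range(0, length - tile + 1, stride))
        if starts[-1] != length - tile:
            # Snap a final tile to the edge so no pixels are left uncovered
            starts.append(length - tile)
        return starts

    return [(r, c) for r in axis_starts(height) for c in axis_starts(width)]


origins = tile_origins(600, 512)
print(len(origins), origins[0], origins[-1])
```

Each origin can then be turned into a windowed read, for example with `rasterio.windows.Window(col_off, row_off, width, height)`, so that imagery and labels are cut with identical windows.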
For persistent memory bottlenecks or hyperparameter tuning across large geospatial datasets, consult Advanced Geospatial AI Optimization to implement gradient accumulation and distributed data loading. Always validate CRS alignment before training: mismatched projections between imagery and labels will silently corrupt spatial relationships. Reference the official Rasterio Quickstart for coordinate validation workflows, and review PyTorch Data Loading Best Practices to tune num_workers and pin_memory for your hardware.