PyTorch Lightning fine-tuning template

by Andrés Muñoz-Jaramillo

This notebook is meant to act as a template for training and using a Surya model in a downstream (DS) application.

It focuses on defining a modified Surya model, loading its pretrained weights, and training it with a PyTorch Lightning training loop.

This notebook assumes familiarity with the dataset and dataloader concepts covered in 0_dataset_dataloader_template.ipynb.

It does not require having gone through the baseline template, but the two are meant to complement each other. In fact, they are intentionally almost identical.

Set your CUDA visible device

IMPORTANT: Since we are sharing resources, make sure the CUDA visible device you set here is the one assigned to your team and your machine. The variable must be set before PyTorch initializes CUDA, which is why it comes first in this cell.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import sys
from torch.utils.data import DataLoader

import torch
import yaml

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, WandbLogger

# Append base path. May need to be modified if the folder structure changes.
# It gives the notebook access to the workshop_infrastructure folder.
sys.path.append("../../")

# Surya's release code must also be importable. If it is not already on the
# Python path, append its location here (may change with the folder structure).

from workshop_infrastructure.utils import build_scalers  # Data scaling utilities for Surya stacks
from workshop_infrastructure.utils import apply_peft_lora
torch.set_float32_matmul_precision('medium')

Download scalers and weights

Surya input data must be scaled properly for the model to work, and this cell downloads the scaling information. In this notebook we also download the pretrained model weights used for fine-tuning.

  • If the cell below fails, try running the provided shell script directly in the terminal.

  • Sometimes the download may fail due to network or server issues. If that happens, simply re-run the script a few times until it completes successfully.

!sh download_scalers_and_weights.sh

Load configuration

Surya was designed to read a configuration file that defines many aspects of the model, including the data it uses. We use this config file both to set default values that do not need to be modified and to define values specific to our downstream application.

# Configuration paths - modify these if your files are in different locations
config_path = "./configs/config_script.yaml"

# Load configuration (the context manager closes the file after reading)
print("Loading configuration...")
try:
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    print("Configuration loaded successfully!")
except FileNotFoundError as e:
    print(f"Error: {e}")
    print(f"Make sure {config_path} exists relative to the notebook's working directory")
    raise

# build_scalers accepts a file path directly
scalers = build_scalers(info=config["data"]["scalers_path"])

From notebook to script: raw dict vs. typed config

This notebook accesses configuration as a raw Python dict (config["section"]["key"]). The companion training script 3_finetune_template_1D.py calls load_config() instead, which parses the same YAML and returns a fully typed TrainingConfig. Both read exactly the same config_script.yaml; the difference is only in how Python exposes the values.

Here is how the notebook dict keys map to cfg.* attributes in the script:

| Notebook (config[...]) | Script (cfg.*) | Dataclass |
|---|---|---|
| config["data"]["train_data_path"] | cfg.data.train_data_path | DataConfig |
| config["data"]["channels"] | cfg.data.channels | DataConfig |
| config["data"]["s3_anon"] | cfg.data.s3_anon | DataConfig |
| config["data"]["s3_cache_dir"] | cfg.data.s3_cache_dir | DataConfig |
| config["model"]["pretrained_path"] | cfg.model.pretrained_path | ModelConfig |
| config["model"]["time_embedding"]["time_dim"] | cfg.model.time_embedding.time_dim | TimeEmbeddingConfig |
| config["model"]["use_lora"] | cfg.model.use_lora | ModelConfig |
| config["model"]["lora_config"] | cfg.model.lora_config | LoraAdapterConfig |
| config["training"]["learning_rate"] | cfg.learning_rate | TrainingConfig |
| config["training"]["batch_size"] | cfg.batch_size | TrainingConfig |
| config["training"]["dtype"] | cfg.dtype | TrainingConfig |
| config["logging"]["wandb_project"] | cfg.wandb_project | TrainingConfig |
| config["output"]["ckpt_dir"] | cfg.output.ckpt_dir | OutputConfig |

The typed config catches typos and missing keys at startup (not mid-training) and gives IDE autocompletion. When adapting the template, add new fields to the appropriate dataclass in configs.py and to the YAML — load_config() picks them up automatically.
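The following is a minimal sketch, using Python dataclasses, of why a typed config fails early. The classes below are hypothetical, trimmed-down stand-ins and are not the actual contents of configs.py. build_config is also not the real load_config(). The sketch only shows the idea that a misspelled or missing key raises an error as soon as the config is built. The paths and channel names in raw are illustrative values.

```python
from dataclasses import dataclass


# Hypothetical, trimmed-down stand-ins for the dataclasses in configs.py
@dataclass
class DataConfig:
    train_data_path: str
    channels: list


@dataclass
class TrainingConfig:
    data: DataConfig
    learning_rate: float
    batch_size: int


def build_config(raw: dict) -> TrainingConfig:
    """Convert the raw YAML dict into typed objects (hypothetical helper).

    A misspelled key raises TypeError (unexpected keyword argument) and a
    missing key raises TypeError (missing required argument), so both
    problems surface before training starts.
    """
    return TrainingConfig(
        data=DataConfig(**raw["data"]),
        # Keys under "training" become top-level attributes, e.g. cfg.learning_rate
        **raw["training"],
    )


raw = {
    "data": {"train_data_path": "train_index.csv", "channels": ["aia171", "aia193"]},
    "training": {"learning_rate": 1e-4, "batch_size": 2},
}
cfg = build_config(raw)
print(cfg.data.train_data_path, cfg.learning_rate)
```

Note that the flattening of "training" keys mirrors the mapping table above, where config["training"]["learning_rate"] becomes cfg.learning_rate.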

Define downstream (DS) datasets

This child class (FlareDSDataset) takes as input all the parameters expected by the parent HelioFM dataset, plus additional parameters relevant to the downstream application. In particular, we focus on the DS index and the parameters needed to combine it with the HelioFM index.

Another important component of creating a dataset class for your DS is normalization. Here we apply a log normalization to the X-ray flux, which acts as the output target. The normalization makes log10(xray_flux) strictly positive, with 66% of its values between 0 and 1.
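The exact normalization lives in FlareDSDataset. The sketch below only illustrates the general idea of a shifted and scaled log10 transform. It also includes the inverse, which you need to turn predictions back into flux. FLUX_FLOOR and LOG_SCALE are hypothetical values chosen for illustration; they are not the constants used by the template.

```python
import math

# Hypothetical constants, NOT the values used by FlareDSDataset
FLUX_FLOOR = 1e-9  # assumed lower bound for X-ray flux (W/m^2); smaller values are clipped
LOG_SCALE = 4.0    # assumed number of decades that map onto one unit of the target


def normalize_flux(flux: float) -> float:
    """Map X-ray flux to a non-negative, log-scaled regression target."""
    flux = max(flux, FLUX_FLOOR)  # clip so log10 is defined and the result is >= 0
    return (math.log10(flux) - math.log10(FLUX_FLOOR)) / LOG_SCALE


def denormalize_flux(target: float) -> float:
    """Invert normalize_flux to recover flux in W/m^2."""
    return 10.0 ** (target * LOG_SCALE + math.log10(FLUX_FLOOR))


for flux in (1e-7, 1e-6, 1e-5, 1e-4):  # GOES B1, C1, M1 and X1 class fluxes
    print(f"{flux:.0e} W/m^2 -> {normalize_flux(flux):.2f}")
```

With these made-up constants, fluxes up to M1 land in [0, 1] and stronger flares exceed 1. Working in log space keeps the loss from being dominated by the few very large flares.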

Here we define both a training and a validation dataset, using the indices specified in the config.

Important: In this notebook we set max_number_of_samples=10 to avoid going through the whole dataset while we explore it. Keep this in mind later, in case the dataset seems smaller than you expect.

from downstream_apps.template.datasets.template_dataset import FlareDSDataset
train_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["train_data_path"],
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["training"]["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["training"]["drop_hmi_probability"],
    use_latitude_in_learned_flow=config["training"]["use_latitude_in_learned_flow"],
    scalers=scalers,
    phase="train",
    s3_use_simplecache=False,
    s3_storage_options={"anon": config["data"]["s3_anon"]},
    s3_cache_dir=config["data"]["s3_cache_dir"],
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=True,
    max_number_of_samples=10,
    ds_flare_index_path=config["data"]["flare_index_path"],
    ds_time_column=config["data"]["ds_time_column"],
    ds_time_tolerance=config["data"]["ds_time_tolerance"],
    ds_match_direction=config["data"]["ds_match_direction"],
)

# The Validation dataset changes the index we read
val_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["valid_data_path"],  #<---------------- different index path
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["training"]["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["training"]["drop_hmi_probability"],
    use_latitude_in_learned_flow=config["training"]["use_latitude_in_learned_flow"],
    scalers=scalers,
    phase="val",
    s3_use_simplecache=False,
    s3_storage_options={"anon": config["data"]["s3_anon"]},
    s3_cache_dir=config["data"]["s3_cache_dir"],
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=True,
    max_number_of_samples=10,
    ds_flare_index_path=config["data"]["flare_index_path"],
    ds_time_column=config["data"]["ds_time_column"],
    ds_time_tolerance=config["data"]["ds_time_tolerance"],
    ds_match_direction=config["data"]["ds_match_direction"],
)
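The two calls above differ only in index_path and phase. One way to keep the two copies from drifting apart is to build the shared keyword arguments once. dataset_kwargs below is a hypothetical helper, not part of the template, and it simply mirrors the arguments used above.

```python
def dataset_kwargs(config: dict, scalers, split: str) -> dict:
    """Keyword arguments for FlareDSDataset; only index_path and phase depend on split.

    Hypothetical helper that mirrors the explicit dataset calls above.
    """
    index_key = {"train": "train_data_path", "val": "valid_data_path"}[split]
    data, training = config["data"], config["training"]
    return dict(
        # Arguments required by the parent HelioNetCDFDataset class
        index_path=data[index_key],
        time_delta_input_minutes=data["time_delta_input_minutes"],
        time_delta_target_minutes=data["time_delta_target_minutes"],
        n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
        rollout_steps=training["rollout_steps"],
        channels=data["channels"],
        drop_hmi_probability=training["drop_hmi_probability"],
        use_latitude_in_learned_flow=training["use_latitude_in_learned_flow"],
        scalers=scalers,
        phase=split,
        s3_use_simplecache=False,
        s3_storage_options={"anon": data["s3_anon"]},
        s3_cache_dir=data["s3_cache_dir"],
        # Downstream (DS) specific arguments
        return_surya_stack=True,
        max_number_of_samples=10,
        ds_flare_index_path=data["flare_index_path"],
        ds_time_column=data["ds_time_column"],
        ds_time_tolerance=data["ds_time_tolerance"],
        ds_match_direction=data["ds_match_direction"],
    )
```

Usage would be train_dataset = FlareDSDataset(**dataset_kwargs(config, scalers, "train")), and likewise with "val" for the validation dataset.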

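We also initialize separate training and validation dataloaders. Since we are working in a shared environment, using multiprocessing_context="spawn" helps avoid lockups. It starts each worker in a fresh Python interpreter instead of forking the notebook process, and forking can deadlock when the parent already has running threads or an initialized CUDA context.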

batch_size = config["training"]["batch_size"]

train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    multiprocessing_context="spawn",
    persistent_workers=True,
    pin_memory=True,
)

val_data_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=4,
    multiprocessing_context="spawn",
    persistent_workers=True,
    pin_memory=True,
)

Initialize the HelioSpectformer model

This is the main difference between the notebook that trains the simple baseline model and the one that fine-tunes Surya.

When fine-tuning, one of the main differences between DS applications is the dimensionality of the output. In this notebook we use a modified HelioSpectformer that projects onto a 1D output.

IMPORTANT: If your DS application is 2D you need to use the HelioSpectformer2D

from workshop_infrastructure.models.finetune_models import HelioSpectformer1D

This is where the config file really comes into play: the Spectformer has a large number of hyperparameters.

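# Map the dtype string from config to a torch.dtype (torch is imported in the first cell).
# An unknown string raises an error instead of silently falling back to float32.
dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16}
dtype_name = config["training"]["dtype"]
if dtype_name not in dtype_map:
    raise ValueError(f"Unsupported dtype {dtype_name!r}; expected one of {list(dtype_map)}")
dtype = dtype_map[dtype_name]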

model = HelioSpectformer1D(
    #### Backbone (Surya 366M defaults — must match the pretrained checkpoint)
    img_size=config["model"]["img_size"],
    patch_size=config["model"]["patch_size"],
    in_chans=config["model"]["in_channels"],
    embed_dim=config["model"]["embed_dim"],
    time_embedding=config["model"]["time_embedding"],
    depth=config["model"]["depth"],
    n_spectral_blocks=config["model"]["spectral_blocks"],
    num_heads=config["model"]["num_heads"],
    mlp_ratio=config["model"]["mlp_ratio"],
    drop_rate=config["model"]["drop_rate"],
    dtype=dtype,
    window_size=config["model"]["window_size"],
    dp_rank=config["model"]["dp_rank"],
    learned_flow=config["model"]["learned_flow"],
    use_latitude_in_learned_flow=config["training"]["use_latitude_in_learned_flow"],
    init_weights=config["model"]["init_weights"],
    checkpoint_layers=config["model"]["checkpoint_layers"],
    rpe=config["model"]["rpe"],
    ensemble=config["model"]["ensemble"],
    nglo=config["model"]["nglo"],
    #### Fine-tuning head parameters
    dropout=config["model"]["dropout"],
    pooling=config["model"]["pooling"],
    penultimate_linear_layer=config["model"]["penultimate_linear_layer"],
    num_outputs=1,
)

Load model weights

Here we load the pretrained checkpoint and transfer its weights to our model. The goal is to reuse as many of the pretrained weights as possible. The cell below builds a filtered dictionary (remapped) that keeps a checkpoint tensor only if two conditions hold. First, its name must match a layer in the fine-tuning architecture, either directly or with a backbone. prefix. Second, its shape must match that layer's shape. Any parameter without a match is left at its random initialization.

model_state = model.state_dict()
checkpoint_state = torch.load(
    config["model"]["pretrained_path"], weights_only=True, map_location="cpu"
)

# The checkpoint was saved from HelioSpectFormer directly, so its keys are flat
# (e.g. "embedding.proj.weight"). The fine-tuning model nests the backbone under
# "backbone.*", so we try both the original key and the backbone.-prefixed key.
remapped = {}
for k, v in checkpoint_state.items():
    for candidate in (k, f"backbone.{k}"):
        if candidate in model_state and hasattr(v, "shape") and v.shape == model_state[candidate].shape:
            remapped[candidate] = v
            break

model_state.update(remapped)
model.load_state_dict(model_state, strict=True)
print(f"Loaded {len(remapped)} / {len(checkpoint_state)} pretrained weights.")
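To see which weights were actually transferred, you can run the same matching rule on shapes alone. The sketch below reimplements the rule from the cell above on plain dictionaries of shapes. It also reports the model parameters that were not loaded and therefore keep their random initialization, such as a new regression head. The keys and shapes in the example are made up.

```python
def match_checkpoint(ckpt_shapes: dict, model_shapes: dict, prefixes=("", "backbone.")):
    """Apply the name-and-shape matching rule used above to dicts of shapes.

    Returns (loaded, skipped, not_loaded):
      loaded     -- model key -> checkpoint key it was filled from
      skipped    -- checkpoint keys with no name/shape match in the model
      not_loaded -- model keys left at their random initialization
    """
    loaded, skipped = {}, []
    for key, shape in ckpt_shapes.items():
        for candidate in (prefix + key for prefix in prefixes):
            if model_shapes.get(candidate) == shape:
                loaded[candidate] = key
                break
        else:  # no candidate name matched with the same shape
            skipped.append(key)
    not_loaded = sorted(set(model_shapes) - set(loaded))
    return loaded, skipped, not_loaded


# Made-up keys and shapes, for illustration only
toy_ckpt = {"embedding.proj.weight": (64, 13, 16, 16), "decoder.weight": (13, 64)}
toy_model = {"backbone.embedding.proj.weight": (64, 13, 16, 16), "head.weight": (1, 64)}
loaded, skipped, not_loaded = match_checkpoint(toy_ckpt, toy_model)
print(loaded)      # {'backbone.embedding.proj.weight': 'embedding.proj.weight'}
print(skipped)     # ['decoder.weight']
print(not_loaded)  # ['head.weight']
```

With the real objects, you could pass {k: tuple(v.shape) for k, v in checkpoint_state.items() if hasattr(v, "shape")} and the equivalent dict built from model.state_dict().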

To LoRA or not to LoRA

This cell gives you two options, selected by config["model"]["use_lora"]. The first is the classic approach of freezing the backbone (the pretrained Surya layers of the model). The second is applying LoRA.

LoRAs (low-rank adaptations) have been a remarkable addition to our fine-tuning arsenal. They keep the pretrained weights intact and instead learn small, low-rank updates to selected weight matrices, so only a small fraction of the parameters is trained.
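To make the idea concrete, here is a from-scratch sketch of a LoRA-wrapped linear layer in PyTorch. This is not how apply_peft_lora is implemented; that helper belongs to workshop_infrastructure, and its name suggests it wraps the PEFT library. The rank r, the scaling alpha, and the 1280-wide layer are arbitrary illustrative choices. The frozen layer is used unchanged, and a trainable low-rank update (alpha / r) * B @ A is added to it. Because B starts at zero, the wrapped layer initially reproduces the pretrained output exactly.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank update (illustrative sketch)."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # pretrained weights stay untouched
        # A projects down to rank r, B projects back up; B = 0 so the update starts at zero
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling


base = nn.Linear(1280, 1280)
layer = LoRALinear(base, r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable: {trainable:,} of {total:,}")  # prints "trainable: 20,480 of 1,660,160"
```

Only about 1.2% of this layer's parameters are trained. The rank r controls that trade-off between capacity and cost.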

from workshop_infrastructure.configs import LoraAdapterConfig

if config["model"]["use_lora"]:
    lora_cfg = LoraAdapterConfig(**config["model"]["lora_config"])
    model = apply_peft_lora(model, lora_cfg)
else:
    for name, param in model.named_parameters():
        if "backbone" in name:
            param.requires_grad = False
    parameters_with_grads = [name for name, param in model.named_parameters() if param.requires_grad]
    print(
        f"{len(parameters_with_grads)} parameters require gradients: {', '.join(parameters_with_grads)}."
    )
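A rough back-of-the-envelope estimate shows why training fewer parameters saves memory. The sketch makes three assumptions. Weights are stored in fp32. The optimizer is Adam-style, keeping one gradient plus two moment buffers per trainable parameter. The fine-tuning head has a hypothetical size of 1M parameters. Activations are ignored. The 366M figure comes from the "Surya 366M" comment in the model cell.

```python
BYTES_FP32 = 4


def training_memory_gb(total_params: float, trainable_params: float) -> float:
    """Weights for all parameters + gradient and two Adam moments for trainable ones."""
    weights = total_params * BYTES_FP32
    grads_and_moments = trainable_params * 3 * BYTES_FP32
    return (weights + grads_and_moments) / 1e9


backbone = 366e6  # Surya backbone size, from the comment in the model cell
head = 1e6        # hypothetical size of the fine-tuning head

full = training_memory_gb(backbone + head, backbone + head)
frozen = training_memory_gb(backbone + head, head)
print(f"full fine-tuning: {full:.2f} GB, frozen backbone: {frozen:.2f} GB")
```

Under these assumptions, full fine-tuning needs roughly 5.9 GB versus about 1.5 GB with a frozen backbone, before counting activations. LoRA lands close to the frozen case because its adapters are small.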

We can now check that this model processes a batch as expected and returns an estimate of flare intensity, as we did for the simple baseline.

We pass the batch, which contains the input stack 'ts', to the model to transform it into our regression output. Since the pretrained model was trained for a different task (and the new head is still randomly initialized), it is unlikely to perform well yet. As with the simple baseline, this only tests that the model's forward pass has no dimension problems.

Dimension problems are the dominant source of errors in this kind of work.

Note that the output now has one value per sample in the batch.

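batch = next(iter(train_data_loader))
with torch.no_grad():  # inference only; no gradients needed for this check
    output = model(batch)
print(output.shape)  # expected: one value per sample in the batch
output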
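One common form of dimension problem is silent broadcasting. If the model returns shape (batch, 1) while the target has shape (batch,), an element-wise difference broadcasts to (batch, batch) instead of raising an error. The sketch below demonstrates this with random tensors. It also adds a small, hypothetical check_shapes guard that you could call before computing a loss.

```python
import torch

n_samples = 4
target = torch.rand(n_samples)    # shape (4,)
pred_ok = torch.rand(n_samples)   # shape (4,)
pred_bad = pred_ok.unsqueeze(-1)  # shape (4, 1): an extra singleton dimension

print((pred_ok - target).shape)   # torch.Size([4])
print((pred_bad - target).shape)  # torch.Size([4, 4]), no error raised


def check_shapes(output: torch.Tensor, target: torch.Tensor) -> None:
    """Fail loudly instead of letting broadcasting hide a shape mismatch."""
    if output.shape != target.shape:
        raise ValueError(
            f"output shape {tuple(output.shape)} != target shape {tuple(target.shape)}"
        )


check_shapes(pred_ok, target)  # passes silently
```

In this notebook, check_shapes(output, batch["forecast"]) would confirm that the model output lines up with the target before any metric is computed.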

Define your metrics

Metrics are a very important part of training AI models. The loss quantifies the error, and its gradients shift the weights toward better-performing models. Metrics also give you a way to monitor performance, identify overfitting, and quantify the value added.

We now initialize the metrics class. It lets you control which metrics are used as the “loss” (i.e. the metrics that backpropagate through your model) and which ones are only used for monitoring performance. As with other components, it lives in a module that can later be reused in a training script.

from downstream_apps.template.metrics.template_metrics import FlareMetrics
train_loss_metrics = FlareMetrics("train_loss")
train_evaluation_metrics = FlareMetrics("train_metrics")
validation_evaluation_metrics = FlareMetrics("val_metrics")

Now the metrics can be evaluated on our model's output and our ground truth. First comes the loss that actually backpropagates, in this case the Mean Squared Error:

train_loss_metrics(output, batch["forecast"])

Next is a training evaluation that does not backpropagate or inform the model, but that we can keep an eye on. Note that reporting many metrics during training slows down the training process. It is included here as an example, but it is often better to put diagnostics only in the validation evaluation metrics.

Here we calculate the Root Relative Squared Error (RRSE): https://lightning.ai/docs/torchmetrics/stable/regression/rse.html

A value below one means the prediction is better than always predicting the average. With a randomly initialized head, this metric is unlikely to be below one.

train_evaluation_metrics(output, batch["forecast"])
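For reference, RRSE compares the squared error of the predictions with the squared error of always predicting the mean of the targets: RRSE = sqrt(sum_i (y_i - ŷ_i)^2 / sum_i (y_i - ȳ)^2). The plain-Python sketch below illustrates that formula on made-up numbers. It is not the FlareMetrics implementation, which the template documents via the torchmetrics link above.

```python
import math


def rrse(preds, targets):
    """Root Relative Squared Error: 0 is perfect, 1 matches always predicting the mean."""
    mean = sum(targets) / len(targets)
    sse_model = sum((t - p) ** 2 for p, t in zip(preds, targets))
    sse_mean = sum((t - mean) ** 2 for t in targets)
    return math.sqrt(sse_model / sse_mean)


targets = [1.0, 2.0, 3.0, 4.0]
print(rrse(targets, targets))                # 0.0 (perfect predictions)
print(rrse([2.5] * 4, targets))              # 1.0 (always predicting the mean, 2.5)
print(rrse([1.5, 2.0, 3.0, 3.5], targets))   # about 0.32, better than the mean
```

Values above one mean the model does worse than the trivial mean predictor, which is typical for an untrained head.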

In the validation evaluation metrics we report both MSE and RRSE.

validation_evaluation_metrics(output, batch["forecast"])

Define your PyTorch Lightning module

In this workshop we use PyTorch Lightning to train our models. Lightning reduces the amount of code needed to implement a training loop compared to plain PyTorch, at the expense of some control and versatility.

Opening the FlareLightningModule shows a simple Lightning model implementation. It consists of:

  • An initialization of the class (metrics, model, and learning rate).

  • The forward code that runs evaluation of the model.

  • Training and validation steps.

  • Configuration of optimizers.

Note that this is the same Lightning module we used for the baseline.

from downstream_apps.template.lightning_modules.pl_simple_baseline import FlareLightningModule

Set your global seeds

Training AI models involves several sources of randomness, such as weight initialization, data shuffling, dropout, and stochastic gradient descent. It is therefore a good idea to fix your random seeds so that your training runs are reproducible.

L.seed_everything(42, workers=True)

Initialize Lightning module

Now we initialize the Lightning module properly to enable training, including passing the dictionary of metrics.

metrics = {
    'train_loss': train_loss_metrics,
    'train_metrics': train_evaluation_metrics,
    'val_metrics': validation_evaluation_metrics,
}

learning_rate = config["training"]["learning_rate"]
lit_model = FlareLightningModule(model, metrics, lr=learning_rate, batch_size=batch_size)

Logging

In order to properly compare experiments against each other, it is very useful to log evaluation metrics in a place where they can be compared against other training runs. In this workshop we will use Weights and Biases (WandB).

The first time you run WandB on a machine, it will ask you to log in to WandB. You should have received an invitation to our project. To log in:

  • Select option 2 (existing account). In VS Code, the dialog opens a box at the top of your screen.

  • Click on "Get API key" (this will open a browser).

  • Generate an API key.

  • Paste it into the dialog box at the top of VS Code.

project_name = config["logging"]["wandb_project"]
run_name = "finetune_experiment_1"  # give your run a descriptive name

wandb_logger = WandbLogger(
    entity=config["logging"]["wandb_entity"],  # set wandb_entity in config; null = personal account
    project=project_name,
    name=run_name,
    log_model=False,
    save_dir="./wandb/wandb_tmp",
)

csv_logger = CSVLogger("runs", name=project_name)

Initialize trainer

With the loggers done, now the trainer needs to be defined. The trainer defines several properties of your training run. Here we define:

  • The maximum number of epochs (one epoch means your model has seen your entire training dataset once).

  • Where the training run takes place ("auto" uses the GPU if available, otherwise the CPU).

  • The loggers.

  • The callbacks (here we save the model with the lowest validation loss).

  • The logging frequency (because we are working with a small dataset, it needs to be small).

Note that in this notebook we also use bf16 mixed precision ("bf16-mixed") to reduce the model's memory footprint.

max_epochs = 2

# -------------------------------------------------------------------------
# Trainer
# -------------------------------------------------------------------------
trainer = L.Trainer(
    max_epochs=max_epochs,
    accelerator="auto",
    devices="auto",
    precision="bf16-mixed", 
    logger=[wandb_logger, csv_logger],
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )
    ],
    log_every_n_steps=2,
)

Fit the model

Finally, we fit the model by passing the Lightning module and our dataloaders.

trainer.fit(lit_model, train_data_loader, val_data_loader)

Conclusion

With this, we have integrated our dataset, dataloaders, metrics, and DS model into an end-to-end training loop, and we are ready to experiment!