PyTorch Lightning baseline template
by Andrés Muñoz-Jaramillo
This notebook is meant to act as a template for training and using a simple regression model to define a baseline that can be compared with a downstream (DS) application.
It focuses on defining a PyTorch model, a PyTorch Lightning training loop, and performance metrics.
This notebook assumes familiarity with the datasets and dataloaders covered in 0_dataset_dataloader_template.ipynb.
Set your CUDA visible device
IMPORTANT: Since we are sharing resources, please make sure that the CUDA visible device you set here is the one assigned to your team and your machine.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
from torch.utils.data import DataLoader
import torch
import yaml
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, WandbLogger
# Append base path. May need to be modified if the folder structure changes.
# It gives the notebook access to the workshop_infrastructure folder.
sys.path.append("../../")
# Append Surya path. May need to be modified if the folder structure changes.
# It gives the notebook access to surya's release code.
from workshop_infrastructure.utils import build_scalers # Data scaling utilities for Surya stacks
torch.set_float32_matmul_precision('medium')
Download scalers
Surya input data must be scaled properly for the model to work; this cell downloads the scaling information.
If the cell below fails, try running the provided shell script directly in the terminal.
The download may occasionally fail due to network or server issues. If that happens, simply re-run the script until it completes successfully.
!sh download_scalers.sh
Load configuration
Surya was designed to read a configuration file that defines many aspects of the model, including the data it uses. We use this config file both to set default values that do not need to be modified and to define values specific to our downstream application.
# Configuration paths - modify these if your files are in different locations
config_path = "./configs/config_script.yaml"
# Load configuration
print("Loading configuration...")
try:
    # Use a context manager so the file handle is closed after reading
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    print("Configuration loaded successfully!")
except FileNotFoundError as e:
    print(f"Error: {e}")
    print(f"Make sure config_script.yaml exists at {config_path}")
    raise
# build_scalers accepts a file path directly
scalers = build_scalers(info=config["data"]["scalers_path"])
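A typo or a missing entry in the YAML file usually surfaces much later as a `KeyError` deep inside a dataset constructor. As a quick sanity check right after loading, a small helper like the hypothetical `missing_config_keys` below can list absent nested keys. The key list is a subset of the keys this notebook reads. The toy config and its placeholder values are made up for illustration.

```python
from typing import Any, Iterable, List


def missing_config_keys(config: dict, dotted_keys: Iterable[str]) -> List[str]:
    """Return the dotted keys (e.g. "data.channels") absent from a nested config dict."""
    missing = []
    for dotted in dotted_keys:
        node: Any = config
        for part in dotted.split("."):
            # Stop at the first level that is not a dict or lacks the key
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing


# A subset of the keys the notebook reads from config_script.yaml.
REQUIRED_KEYS = [
    "data.scalers_path",
    "data.train_data_path",
    "data.valid_data_path",
    "data.channels",
    "data.flare_index_path",
    "model.time_embedding.time_dim",
    "training.batch_size",
    "training.learning_rate",
    "logging.wandb_project",
]

# Hypothetical toy config with placeholder values (not the real file).
toy_config = {
    "data": {"scalers_path": "scalers.yaml", "channels": ["ch_a", "ch_b"]},
    "model": {"time_embedding": {"time_dim": 2}},
    "training": {"batch_size": 2, "learning_rate": 1e-3},
}

print(missing_config_keys(toy_config, REQUIRED_KEYS))
# ['data.train_data_path', 'data.valid_data_path', 'data.flare_index_path', 'logging.wandb_project']
```

In the notebook you would call `missing_config_keys(config, REQUIRED_KEYS)` on the loaded configuration and expect an empty list.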
Define Downstream (DS) datasets
This child class takes as input all expected HelioFM parameters, plus additional parameters relevant to the downstream application. Here we focus in particular on the DS index and the parameters needed to combine it with the HelioFM index.
Another important component of creating a dataset class for your DS is normalization. Here we apply a log normalization to the X-ray flux, which acts as the output target. The normalization makes the log10(xray_flux) target strictly positive, with 66% of its values between 0 and 1.
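The exact constants used by FlareDSDataset live in its module and are not shown here. The sketch below only illustrates the general idea of a shifted log10 target and its inverse. `LOG_FLUX_FLOOR` and `LOG_FLUX_SCALE` are hypothetical values chosen for illustration. With these values the output is non-negative, and strictly positive for any flux above the floor.

```python
import torch

# Hypothetical constants; FlareDSDataset defines its own offset and scale.
LOG_FLUX_FLOOR = -9.0  # log10 of an assumed minimum flux in W/m^2
LOG_FLUX_SCALE = 1.0   # assumed scale factor


def normalize_flux(flux: torch.Tensor) -> torch.Tensor:
    """Map X-ray flux (W/m^2) to a shifted, scaled log10 target."""
    # Clamp to the floor so log10 is defined and the result stays >= 0
    flux = flux.clamp_min(10.0 ** LOG_FLUX_FLOOR)
    return (torch.log10(flux) - LOG_FLUX_FLOOR) / LOG_FLUX_SCALE


def denormalize_flux(target: torch.Tensor) -> torch.Tensor:
    """Invert normalize_flux back to flux in W/m^2."""
    return 10.0 ** (target * LOG_FLUX_SCALE + LOG_FLUX_FLOOR)


flux = torch.tensor([1e-8, 1e-6, 1e-4], dtype=torch.float64)  # roughly A, C, X class
print(normalize_flux(flux))
```

Working in log space keeps the several orders of magnitude between flare classes on a comparable numerical scale.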
In this case we define both a training and a validation dataset, using the indices specified in the config.
Important: In this notebook we set max_number_of_samples=6 so that we do not have to go through the whole dataset while exploring it. Keep this in mind later if the dataset seems smaller than you expect.
from downstream_apps.template.datasets.template_dataset import FlareDSDataset
train_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["train_data_path"],
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["training"]["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["training"]["drop_hmi_probability"],
    use_latitude_in_learned_flow=config["training"]["use_latitude_in_learned_flow"],
    scalers=scalers,
    phase="train",
    s3_use_simplecache=False,
    s3_storage_options={"anon": config["data"]["s3_anon"]},
    s3_cache_dir=config["data"]["s3_cache_dir"],
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=True,
    max_number_of_samples=6,
    ds_flare_index_path=config["data"]["flare_index_path"],
    ds_time_column=config["data"]["ds_time_column"],
    ds_time_tolerance=config["data"]["ds_time_tolerance"],
    ds_match_direction=config["data"]["ds_match_direction"],
)
# The validation dataset changes the index we read
val_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["valid_data_path"],  #<---------------- different index path
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["training"]["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["training"]["drop_hmi_probability"],
    use_latitude_in_learned_flow=config["training"]["use_latitude_in_learned_flow"],
    scalers=scalers,
    phase="val",
    s3_use_simplecache=False,
    s3_storage_options={"anon": config["data"]["s3_anon"]},
    s3_cache_dir=config["data"]["s3_cache_dir"],
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=True,
    max_number_of_samples=6,
    ds_flare_index_path=config["data"]["flare_index_path"],
    ds_time_column=config["data"]["ds_time_column"],
    ds_time_tolerance=config["data"]["ds_time_tolerance"],
    ds_match_direction=config["data"]["ds_match_direction"],
)
We also initialize separate training and validation dataloaders. Since we are working in a shared environment, using multiprocessing_context="spawn" helps avoid lockups.
batch_size = config["training"]["batch_size"]
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    multiprocessing_context="spawn",
    persistent_workers=True,
    pin_memory=True,
)
val_data_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=4,
    multiprocessing_context="spawn",
    persistent_workers=True,
    pin_memory=True,
)
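With the "spawn" context, each worker is a fresh Python process, so the dataset object has to be pickled and sent to it. Classes defined directly inside a notebook are often hard for spawned workers to locate. That is plausibly one reason the dataset class is imported from a module here. The hypothetical helper `is_picklable` below is a cheap pre-flight check you could run on a dataset before building the dataloaders.

```python
import pickle


def is_picklable(obj) -> bool:
    """Return True if obj survives pickle.dumps, which spawn-based workers rely on."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        # PicklingError, TypeError, AttributeError, ... all mean "cannot be shipped to a worker"
        return False


print(is_picklable({"index_path": "train.csv"}))  # True
print(is_picklable(lambda x: x))                  # False: lambdas cannot be pickled by reference
```

In the notebook, `is_picklable(train_dataset)` returning False would point to the cause of worker start-up failures before any training begins.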
Define simple baseline model
Defining a simple baseline is important for understanding what value the AI model brings to the problem.
It is always good to have a very simple baseline model, ideally one that cannot overfit the data. This is an excellent way of measuring the value added by complex models. Classical machine learning excels here:
Regressions and logistic regressions.
Climatological averages.
Persistence.
Simple transformations.
Simple models prevent excessively optimistic assessments of the capabilities of complex models, and for many problems they are remarkably hard to beat.
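Two of the baselines listed above, persistence and climatology, need no training at all. The sketch below applies both to a made-up, already-normalized flux series. The numbers are illustrative only. Using the first three values as the "training" period for the climatological mean is also an assumption for the example.

```python
import torch


def persistence_forecast(current: torch.Tensor) -> torch.Tensor:
    """Persistence: predict that the next value equals the current one."""
    return current.clone()


def climatology_forecast(train_targets: torch.Tensor, n: int) -> torch.Tensor:
    """Climatology: predict the training-period mean for every sample."""
    return torch.full((n,), float(train_targets.mean()))


def mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return torch.mean((pred - target) ** 2)


# Made-up normalized flux series (not real data).
series = torch.tensor([0.2, 0.3, 0.3, 0.5, 0.4, 0.6])
current, future = series[:-1], series[1:]  # predict each next value from the current one

print("persistence MSE:", mse(persistence_forecast(current), future).item())
print("climatology MSE:", mse(climatology_forecast(series[:3], len(future)), future).item())
```

Any learned model, whether the regression below or Surya itself, should beat numbers like these on the same validation samples before it can be said to add value.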
In this example we define a simple regression acting on the intensity of each channel. Note that we invert the normalization so that the model works with strictly positive quantities. As with the dataset, we import the model from a module so that we can reuse it in training scripts later on.
from downstream_apps.template.models.simple_baseline import RegressionFlareModel
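The module's internals are not reproduced here. The class below is a hypothetical sketch of what a per-channel intensity regression with a `n_input_timestamps * n_channels` input could look like. It is not RegressionFlareModel itself. It assumes `batch["ts"]` has shape (batch, channels, timesteps, height, width) and reduces each channel/timestep image to its spatial mean before a single linear layer.

```python
import torch
from torch import nn


class ChannelMeanRegression(nn.Module):
    """Hypothetical sketch: linear regression on per-channel, per-timestep mean intensity."""

    def __init__(self, n_features: int):
        super().__init__()
        self.linear = nn.Linear(n_features, 1)

    def forward(self, batch: dict) -> torch.Tensor:
        ts = batch["ts"]                                   # assumed (B, C, T, H, W)
        features = ts.mean(dim=(-2, -1))                   # spatial mean -> (B, C, T)
        return self.linear(features.flatten(start_dim=1))  # (B, C*T) -> (B, 1)


sketch = ChannelMeanRegression(n_features=3 * 2)  # 3 channels x 2 timesteps (illustrative)
print(sketch({"ts": torch.randn(4, 3, 2, 8, 8)}).shape)  # torch.Size([4, 1])
```

Because the spatial dimensions are averaged away, this kind of model has only `n_features + 1` parameters, which is what makes it hard to overfit.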
We can now test that this model processes a batch as expected and returns an estimate of flare intensity.
Note that the simple regression model requires knowing the number of channels and timesteps, so we pull that information from the configuration when initializing the model.
n_input_timestamps = config["model"]["time_embedding"]["time_dim"]
n_channels = len(config["data"]["channels"])
# RegressionFlareModel only needs the flattened input dimension.
# It expects 'ts' in physical (log) space — use inverse_transform_channels() before calling forward().
model = RegressionFlareModel(n_input_timestamps * n_channels)
Now we can pass the input stack 'ts' to the model to transform it into our regression output. Since this model has not been trained and was initialized randomly, the output has no real meaning here. It only serves as a test that the model's forward pass has no dimension problems.
Dimension problems are the dominant source of errors in this kind of work.
Note that the output now has the size of our batch (one value per sample).
from downstream_apps.template.models.simple_baseline import inverse_transform_channels
batch = next(iter(train_data_loader))
# RegressionFlareModel works in physical (log) space, not normalized space.
# inverse_transform_channels undoes the per-channel z-score normalization applied by the dataset.
batch_physical = inverse_transform_channels(batch, channel_order=config["data"]["channels"], scalers=scalers)
output = model.forward(batch_physical)[:, 0] # Get rid of singleton dimension
output
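The `[:, 0]` indexing above is a deliberate choice. `.squeeze()` with no arguments removes every size-1 dimension, including the batch dimension when the last batch happens to contain a single sample. Indexing (or `squeeze(-1)`) removes only the trailing singleton. The shapes below use random tensors standing in for the model output.

```python
import torch

out_batch4 = torch.randn(4, 1)  # (batch, 1), like a regression head's output
out_batch1 = torch.randn(1, 1)  # a final batch that holds a single sample

print(out_batch4[:, 0].shape)      # torch.Size([4])
print(out_batch1[:, 0].shape)      # torch.Size([1]): batch dimension kept
print(out_batch1.squeeze().shape)  # torch.Size([]): batch dimension lost as well
```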
Define your metrics
Metrics are a very important part of training AI models. They quantify your model's error, which in turn drives the weight updates toward better-performing models. They also provide a way for you to monitor performance, identify overfitting, and quantify value added.
We now initialize the metrics class, which lets you control which metrics to use as the "loss" (i.e., the metrics that backpropagate through your model) and which ones to use for monitoring performance. As with other components, this is a module we import so that it can later be used in a training script.
from downstream_apps.template.metrics.template_metrics import FlareMetrics
train_loss_metrics = FlareMetrics("train_loss")
train_evaluation_metrics = FlareMetrics("train_metrics")
validation_evaluation_metrics = FlareMetrics("val_metrics")
Now they can be evaluated on our model's output and the ground truth. First comes the loss that will actually backpropagate, in this case the Mean Squared Error.
train_loss_metrics(output, batch["forecast"])
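For reference, the Mean Squared Error is the average of the squared differences between prediction and target. The sketch below computes it by hand on three made-up values and checks it against PyTorch's built-in `torch.nn.functional.mse_loss`. How FlareMetrics wraps this internally is defined in its module and is not assumed here.

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([0.5, 1.0, 2.0])    # made-up predictions
target = torch.tensor([1.0, 1.0, 1.0])  # made-up ground truth

manual_mse = ((pred - target) ** 2).mean()  # (0.25 + 0 + 1) / 3
print(manual_mse.item(), F.mse_loss(pred, target).item())
```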
Then a training evaluation that does not backpropagate or inform our model, but that we can keep an eye on. Note that reporting many metrics during training slows down the training process. I'm including it here as an example, but it is often better to put diagnostics only in the validation evaluation metrics.
Here we calculate the Root Relative Squared Error https://lightning.ai/docs/torchmetrics/stable/regression/rse.html
A value below one means the prediction is better than always predicting the average. This metric is unlikely to fall below one with a randomly initialized model.
train_evaluation_metrics(output, batch["forecast"])
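Written out, RRSE is sqrt(sum((y - y_hat)^2) / sum((y - y_mean)^2)). The squared error of the model is divided by the squared error of a constant prediction equal to the target mean. That is why a value of one marks "as good as predicting the average". The manual version below is a sketch for intuition. The linked torchmetrics page documents the library implementation.

```python
import torch


def rrse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Root relative squared error: model error relative to always predicting the target mean."""
    residual = torch.sum((target - pred) ** 2)
    baseline = torch.sum((target - target.mean()) ** 2)
    return torch.sqrt(residual / baseline)


target = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rrse(target.mean().expand_as(target), target))  # tensor(1.): predicting the mean
print(rrse(target, target))                           # tensor(0.): perfect prediction
```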
In the validation evaluation metrics we report both MSE and RRSE.
validation_evaluation_metrics(output, batch["forecast"])
Define your PyTorch Lightning module
In this workshop we use PyTorch Lightning to train our models. PyTorch Lightning reduces the amount of code needed to implement a training loop compared with plain PyTorch (at the expense of some control and versatility).
Opening FlareLightningModule shows a simple Lightning model implementation. It consists of:
An initialization of the class (metrics, model, and learning rate).
A forward method that runs the model.
Training and validation steps.
Optimizer configuration.
from downstream_apps.template.lightning_modules.pl_simple_baseline import FlareLightningModule
Set your global seeds
Training AI models involves randomness (e.g., random weight initialization, data shuffling, and stochastic gradient descent), so it is a good idea to fix your random seeds to make your training run reproducible.
L.seed_everything(42, workers=True)
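`L.seed_everything` seeds Python's `random`, NumPy, and PyTorch in one call. With `workers=True` it also arranges seeding for dataloader worker processes. The underlying idea can be seen with PyTorch alone: resetting the seed replays exactly the same random draws.

```python
import torch

torch.manual_seed(42)
first = torch.randn(3)   # e.g. an initial weight draw

torch.manual_seed(42)
second = torch.randn(3)  # same seed, same draw

print(torch.equal(first, second))  # True
```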
Initialize Lightning module
Now we properly initialize the Lightning module to enable training, including passing the dictionary of metrics.
from functools import partial
metrics = {
    'train_loss': train_loss_metrics,
    'train_metrics': train_evaluation_metrics,
    'val_metrics': validation_evaluation_metrics,
}
# Wire up the inverse transform so FlareLightningModule applies it before every model call.
preprocess_fn = partial(
    inverse_transform_channels,
    channel_order=config["data"]["channels"],
    scalers=scalers,
)
learning_rate = config["training"]["learning_rate"]
lit_model = FlareLightningModule(model, metrics, lr=learning_rate, batch_size=batch_size, preprocess_fn=preprocess_fn)
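`functools.partial` freezes the `channel_order` and `scalers` keyword arguments, so the Lightning module only has to call `preprocess_fn(batch)`. The toy function below is a hypothetical stand-in for `inverse_transform_channels` that shows the pattern. It assumes a per-channel z-score was applied, with scalers stored as `{channel: (mean, std)}`. The real scaler format and transform are defined in the workshop modules.

```python
from functools import partial

import torch


def toy_inverse_transform(batch: dict, channel_order: list, scalers: dict) -> dict:
    """Hypothetical per-channel z-score inversion: x * std + mean for each channel."""
    ts = batch["ts"].clone()  # assumed shape (B, C, ...); clone leaves the input untouched
    for i, name in enumerate(channel_order):
        mean, std = scalers[name]  # assumed (mean, std) per channel
        ts[:, i] = ts[:, i] * std + mean
    return {**batch, "ts": ts}


# Bind the fixed arguments once; the result takes only the batch.
preprocess = partial(
    toy_inverse_transform,
    channel_order=["a", "b"],
    scalers={"a": (1.0, 2.0), "b": (0.0, 10.0)},
)
out = preprocess({"ts": torch.ones(2, 2, 3)})
print(out["ts"][0, :, 0])  # channel "a": 1*2+1 = 3, channel "b": 1*10+0 = 10
```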
Logging
In order to properly compare experiments against each other, it is very useful to log evaluation metrics in a place where they can be compared against other training runs. In this workshop we will use Weights and Biases (WandB).
The first time you run WandB on a machine, it will ask you to log in to WandB. You should have received an invitation to our project. To log in, you must:
Select option 2 (existing account). In VS Code, the dialog opens a box at the top of your screen.
Click on get API Key (this will open a browser).
Generate an API key.
Paste it into the dialog box at the top of VS Code.
project_name = config["logging"]["wandb_project"]
run_name = "baseline_experiment_1" # give your run a descriptive name
wandb_logger = WandbLogger(
    entity=config["logging"]["wandb_entity"],  # set wandb_entity in config; null = personal account
    project=project_name,
    name=run_name,
    log_model=False,
    save_dir="./wandb/wandb_tmp",
)
csv_logger = CSVLogger("runs", name=project_name)
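The CSVLogger keeps a local copy of the logged metrics alongside WandB. With `CSVLogger("runs", name=project_name)` it should write to `runs/<project_name>/version_<n>/metrics.csv`, with a new `version_<n>` folder per run. The hypothetical helpers below locate the newest run and pull one metric column. Rows logged at different steps leave other columns blank, so empty cells are skipped. Which columns exist depends on what the Lightning module logs.

```python
import csv
from pathlib import Path


def latest_metrics_csv(save_dir: str, name: str) -> Path:
    """Return metrics.csv from the highest-numbered version_* folder under save_dir/name."""
    versions = sorted(
        Path(save_dir, name).glob("version_*"),
        key=lambda p: int(p.name.split("_")[-1]),  # numeric sort: version_10 after version_2
    )
    if not versions:
        raise FileNotFoundError(f"no version_* folders under {Path(save_dir, name)}")
    return versions[-1] / "metrics.csv"


def read_metric(csv_path: Path, column: str) -> list:
    """Collect the non-empty values of one metric column as floats."""
    with open(csv_path, newline="") as f:
        return [float(row[column]) for row in csv.DictReader(f) if row.get(column)]
```

After training, something like `read_metric(latest_metrics_csv("runs", project_name), "val_loss")` would return the validation loss history, assuming a `val_loss` column was logged.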
Initialize trainer
With the loggers set up, we can now define the trainer. The trainer defines several properties of your training run. Here we define:
The maximum number of epochs (one epoch means your model has seen the entire training dataset once).
Where the training run takes place ("auto" uses a GPU if one is available, otherwise the CPU).
The loggers.
The callbacks (here we save the model with the lowest validation loss).
The logging frequency (because we are working with a small dataset, it needs to be small).
max_epochs = 2
# -------------------------------------------------------------------------
# Trainer
# -------------------------------------------------------------------------
trainer = L.Trainer(
    max_epochs=max_epochs,
    accelerator="auto",
    devices="auto",
    logger=[wandb_logger, csv_logger],
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )
    ],
    log_every_n_steps=2,
)
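The reason the logging interval must be small is simple arithmetic. A DataLoader without `drop_last` yields ceil(n_samples / batch_size) batches per epoch. With only a handful of samples that is far fewer than Lightning's default `log_every_n_steps=50`, so metrics would rarely or never be logged. The batch size of 2 below is illustrative; the real value comes from the config.

```python
import math


def batches_per_epoch(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches (training steps) a DataLoader yields per epoch."""
    if drop_last:
        return n_samples // batch_size        # the incomplete final batch is discarded
    return math.ceil(n_samples / batch_size)  # the final batch may be smaller


steps = batches_per_epoch(6, 2)  # max_number_of_samples=6, illustrative batch_size=2
print(steps)      # 3 steps per epoch
print(steps * 2)  # 6 steps in total over max_epochs=2
```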
Fit the model
Finally, we fit the model, passing the Lightning module and our dataloaders.
trainer.fit(lit_model, train_data_loader, val_data_loader)
Conclusion
With this, we have integrated our dataset, dataloaders, metrics, and baseline into an end-to-end training loop. The next step is to replace the simple model with Surya.