Train models

  • Last Modified: 07-04-2021

  • Authors: Sam Budd, Gonzalo Mateo-García


Tutorial: Train a Flood Extent segmentation model using the WorldFloods dataset

[1] Mateo-Garcia, G. et al. Towards global flood mapping onboard low cost satellites with machine learning. Scientific Reports 11, 7249 (2021).

This tutorial has been adapted for the Artificial Intelligence for Earth Monitoring online course, which is available on the FutureLearn platform.

Note: If you run this notebook in Google Colab, change the runtime type to use a GPU (Runtime > Change runtime type).

import os

Step 0: Download the training data

In order to run this tutorial, you need (at least a subset of) the WorldFloods dataset. Here we get it from our public Google Drive folder; for other options, see the WorldFloods download documentation.

Step 0a: Mount the public folder if you are in Google Colab

If you’re running this tutorial in Google Colab you need to ‘add a shortcut to your Google Drive’ from the public Google Drive folder.

[Figure: adding a shortcut to the public folder in Google Drive]

Then, mount that directory with the following code:

try:
    from google.colab import drive
    drive.mount('/content/drive')
    public_folder = '/content/drive/My Drive/Public WorldFloods Dataset'
    assert os.path.exists(public_folder), "Add a shortcut to the public Google Drive folder: https://drive.google.com/drive/u/0/folders/1dqFYWetX614r49kuVE3CbZwVO6qHvRVH"
    google_colab = True
except ImportError as e:
    print(e)
    print("Setting google colab to false, it will need to install the gdown package!")
    public_folder = '.'
    google_colab = False
No module named 'google.colab'

Step 0b: Unzip the WorldFloods sample folder

If the folder could not be mounted, the code below tries to download the data with the gdown package (if it is not installed, run: pip install gdown).

from ml4floods.models import dataset_setup
import zipfile

# Unzip the data
path_to_dataset_folder = "."
dataset_folder = os.path.join(path_to_dataset_folder, "worldfloods_v1_0_sample")

try:
    dataset_setup.validate_worldfloods_data(dataset_folder)
except FileNotFoundError as e:
    print(e)
    zip_file_name = os.path.join(public_folder, "worldfloods_v1_0_sample.zip") # this file is 12.7 GB

    print("We need to unzip the data")
    # Download the zip file
    if not os.path.exists(zip_file_name):
        print("Download the data from Google Drive")
        import gdown
        # https://drive.google.com/file/d/11O6aKZk4R6DERIx32o4mMTJ5dtzRRKgV/view?usp=sharing
        gdown.download(id="11O6aKZk4R6DERIx32o4mMTJ5dtzRRKgV", output=zip_file_name)

    print("Unzipping the data")
    # The with-block closes the zip file automatically
    with zipfile.ZipFile(zip_file_name, "r") as zip_ref:
        zip_ref.extractall(path_to_dataset_folder)
Data downloaded follows the expected format
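The checks performed by `dataset_setup.validate_worldfloods_data` are not shown in this tutorial. As a rough sketch, assuming the sample folder contains `train`, `val` and `test` splits, each with an `S2` image folder and a `gt` mask folder (the `input_folder` and `target_folder` names in the configuration below) holding `.tif` files with matching names, a hypothetical layout checker could look like this:

```python
import os
from typing import Dict, List, Sequence

def check_worldfloods_layout(dataset_folder: str,
                             splits: Sequence[str] = ("train", "val", "test"),
                             input_folder: str = "S2",
                             target_folder: str = "gt") -> Dict[str, int]:
    """Hypothetical checker (not the ml4floods implementation).

    Assumes every image in <split>/S2 has a mask with the same file name in <split>/gt.
    Returns the number of image/mask pairs per split, or raises FileNotFoundError.
    """
    counts = {}
    for split in splits:
        image_dir = os.path.join(dataset_folder, split, input_folder)
        mask_dir = os.path.join(dataset_folder, split, target_folder)
        for folder in (image_dir, mask_dir):
            if not os.path.isdir(folder):
                raise FileNotFoundError(f"Missing folder: {folder}")
        images = sorted(f for f in os.listdir(image_dir) if f.endswith(".tif"))
        # Images without a mask of the same name cannot be used for training
        missing: List[str] = [f for f in images if not os.path.exists(os.path.join(mask_dir, f))]
        if missing:
            raise FileNotFoundError(f"{len(missing)} masks missing in {mask_dir}, e.g. {missing[0]}")
        counts[split] = len(images)
    return counts
```

For example, `check_worldfloods_layout(dataset_folder)` would return the number of tiles per split for a well-formed folder.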

Step 1: Set up the configuration file

First, we load the configuration file models/configurations/worldfloods_template.json. This file specifies the hyperparameters used to train the model; you can either use it as is or make a copy and modify the hyperparameters you want to try out.

from ml4floods.models.config_setup import get_default_config
import pkg_resources

# Set filepath to configuration files
# config_fp = 'path/to/worldfloods_template.json'
config_fp = pkg_resources.resource_filename("ml4floods","models/configurations/worldfloods_template.json")

config = get_default_config(config_fp)
Loaded Config for experiment:  worldfloods_demo_test
{   'data_params': {   'batch_size': 32,
                       'bucket_id': 'ml4cc_data_lake',
                       'channel_configuration': 'all',
                       'download': {'test': True, 'train': True, 'val': True},
                       'filter_windows': {   'apply': False,
                                             'threshold_clouds': 0.5,
                                             'version': 'v1'},
                       'input_folder': 'S2',
                       'loader_type': 'local',
                       'num_workers': 4,
                       'path_to_splits': 'worldfloods',
                       'target_folder': 'gt',
                       'test_transformation': {'normalize': True},
                       'train_test_split_file': '2_PROD/2_Mart/worldfloods_v1_0/train_test_split.json',
                       'train_transformation': {'normalize': True},
                       'window_size': [256, 256]},
    'deploy': False,
    'experiment_name': 'worldfloods_demo_test',
    'gpus': '0',
    'model_params': {   'hyperparameters': {   'channel_configuration': 'all',
                                               'early_stopping_patience': 4,
                                               'label_names': [   'land',
                                                                  'water',
                                                                  'cloud'],
                                               'lr': 0.0001,
                                               'lr_decay': 0.5,
                                               'lr_patience': 2,
                                               'max_epochs': 10,
                                               'max_tile_size': 256,
                                               'metric_monitor': 'val_dice_loss',
                                               'model_type': 'linear',
                                               'num_channels': 13,
                                               'num_classes': 3,
                                               'val_every': 1,
                                               'weight_per_class': [   1.93445299,
                                                                       36.60054169,
                                                                       2.19400729]},
                        'model_folder': 'gs://ml4cc_data_lake/0_DEV/2_Mart/2_MLModelMart',
                        'model_version': 'v1',
                        'test': True,
                        'train': True},
    'resume_from_checkpoint': False,
    'seed': 12,
    'test': False,
    'train': False}
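To try other hyperparameters without editing the packaged template, one option is to copy the JSON file and override nested keys. The stdlib sketch below is an assumption of how that could be done: the helpers `override_config` and `copy_template` and the dotted-key convention are hypothetical and not part of ml4floods. Within this notebook the same effect is obtained by assigning attributes on `config` directly, as the later steps do.

```python
import copy
import json
from typing import Any, Dict

def override_config(config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of `config` with dotted keys such as
    'model_params.hyperparameters.lr' replaced by new values (hypothetical helper)."""
    new_config = copy.deepcopy(config)
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = new_config
        for key in parents:
            node = node[key]  # KeyError if the path does not exist in the template
        if leaf not in node:
            raise KeyError(f"Unknown config key: {dotted_key}")  # catch typos early
        node[leaf] = value
    return new_config

def copy_template(template_path: str, output_path: str, overrides: Dict[str, Any]) -> None:
    """Read a JSON template, apply the overrides and write the result to a new file."""
    with open(template_path) as f:
        template = json.load(f)
    with open(output_path, "w") as f:
        json.dump(override_config(template, overrides), f, indent=2)
```

A possible usage would be `copy_template(config_fp, "my_experiment.json", {"model_params.hyperparameters.lr": 1e-3})` followed by `get_default_config("my_experiment.json")`. Rejecting unknown keys avoids silently adding a misspelled hyperparameter that the training code never reads.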

Step 1.a: Seed everything for reproducibility

from pytorch_lightning import seed_everything
# Seed
seed_everything(config.seed)
Global seed set to 12
12
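`seed_everything` seeds the pseudo-random generators of Python's `random` module, NumPy and PyTorch, and returns the seed (hence the `12` printed above). The stdlib-only sketch below covers just Python's `random` module to show why seeding makes runs repeatable; `seed_python` and `sample_shuffled_indices` are hypothetical names used for illustration.

```python
import random
from typing import List

def seed_python(seed: int) -> int:
    """Hypothetical stand-in for seed_everything that only seeds Python's `random` module.
    (PyTorch Lightning additionally seeds NumPy and PyTorch.)"""
    random.seed(seed)
    return seed

def sample_shuffled_indices(n: int) -> List[int]:
    """Shuffle tile indices, similar in spirit to shuffling the training set each epoch."""
    indices = list(range(n))
    random.shuffle(indices)
    return indices

seed_python(12)
first_run = sample_shuffled_indices(10)
seed_python(12)
second_run = sample_shuffled_indices(10)
assert first_run == second_run  # same seed -> same order
```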

Step 1.b: Make it a unique experiment

The ‘experiment_name’ determines the name of the folder where models and associated files are saved.

config.experiment_name = 'training_demo'
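Later cells build every output path from `config.model_params.model_folder` and `config.experiment_name`: checkpoints go to `<model_folder>/<experiment_name>/checkpoint`, and the final weights and configuration to `model.pt` and `config.json` in the same folder. The helper below is hypothetical (not part of ml4floods) and simply collects those paths in one place:

```python
from typing import Dict

def experiment_paths(model_folder: str, experiment_name: str) -> Dict[str, str]:
    """Collect the output locations used in this tutorial (hypothetical helper).

    Joining with "/" rather than os.path.join keeps forward slashes in gs:// bucket
    paths on every operating system.
    """
    root = f"{model_folder}/{experiment_name}"
    return {
        "root": root,
        "checkpoint_dir": f"{root}/checkpoint",
        "weights": f"{root}/model.pt",
        "config": f"{root}/config.json",
    }

paths = experiment_paths("models", "training_demo")
print(paths["checkpoint_dir"])  # models/training_demo/checkpoint
```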

Step 2: Set up the dataloader

  • ‘loader_type’ can be either ‘local’, which assumes the images are already saved locally, or ‘bucket’, which loads images directly from the bucket specified in ‘bucket_id’. To load images from the bucket, the GOOGLE_APPLICATION_CREDENTIALS and GS_USER_PROJECT environment variables must be set. If it is set to ‘local’ and the dataset is not found in config.data_params.path_to_splits, the data will be downloaded.

  • The WorldFloods dataset contains 264.29 GB of data. We can load a subset of it with a custom train_test_split_sample.json, which downloads only a subset of the training set together with the validation and test sets.
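The loading rules described above can be summarised as a pre-flight check. The sketch below is hypothetical (it is not the check ml4floods performs internally) and only inspects the configuration values and environment variables named above:

```python
import os
from typing import Mapping, Optional

def check_loader_config(loader_type: str, path_to_splits: str, bucket_id: Optional[str],
                        env: Mapping[str, str] = os.environ) -> str:
    """Hypothetical pre-flight check for the data loading options described above."""
    if loader_type == "bucket":
        required = ("GOOGLE_APPLICATION_CREDENTIALS", "GS_USER_PROJECT")
        missing = [name for name in required if not env.get(name)]
        if missing:
            raise EnvironmentError(f"loader_type='bucket' requires: {', '.join(missing)}")
        if not bucket_id:
            raise ValueError("loader_type='bucket' requires a bucket_id")
        return f"reading images from gs://{bucket_id}"
    if loader_type == "local":
        if os.path.isdir(path_to_splits):
            return f"reading images from {path_to_splits}"
        return f"{path_to_splits} not found: the data would be downloaded"
    raise ValueError(f"Unknown loader_type: {loader_type!r}")
```

For this tutorial, `check_loader_config("local", dataset_folder, None)` would report that images are read from the unzipped sample folder.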

%%time

from ml4floods.models.dataset_setup import get_dataset

config.data_params.batch_size = 16 # control this depending on the space on your GPU!
config.data_params.loader_type = 'local'
config.data_params.path_to_splits = dataset_folder # local folder to download the data
config.data_params.train_test_split_file = None

# If files are not in config.data_params.path_to_splits this will trigger the download of the products.
dataset = get_dataset(config.data_params)
train_test_split_file not provided. We will use the content in the folder ./worldfloods_v1_0_sample
train 6298  tiles
val 1284  tiles
test 11  tiles
CPU times: user 98.1 ms, sys: 4.63 ms, total: 103 ms
Wall time: 169 ms

Show some images from the dataloader

The dataset object is a PyTorch Lightning DataModule. It holds the WorldFloods train, validation and test datasets as attributes (dataset.train_dataset, dataset.val_dataset and dataset.test_dataset), and its methods train_dataloader(), val_dataloader() and test_dataloader() create PyTorch DataLoaders from them.

train_dl = dataset.train_dataloader()

train_dl_iter = iter(train_dl)
batch = next(train_dl_iter)

batch["image"].shape, batch["mask"].shape
(torch.Size([16, 13, 256, 256]), torch.Size([16, 1, 256, 256]))
from ml4floods.models import worldfloods_model
import matplotlib.pyplot as plt

n_images=6
fig, axs = plt.subplots(3,n_images, figsize=(18,10),tight_layout=True)
worldfloods_model.plot_batch(batch["image"][:n_images],axs=axs[0],max_clip_val=3500.)
worldfloods_model.plot_batch(batch["image"][:n_images],bands_show=["B11","B8", "B4"],
                             axs=axs[1],max_clip_val=4500.)
worldfloods_model.plot_batch_output_v1(batch["mask"][:n_images, 0],axs=axs[2], show_axis=True)
[Figure: six training tiles; top row: default band composite, middle row: B11/B8/B4 composite, bottom row: ground-truth mask]
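Each image in the batch has shape (13, 256, 256): 13 Sentinel-2 channels of a 256×256 tile. The composites above pick three channels by band name and scale values with `max_clip_val`. Assuming the 13 channels follow the standard Sentinel-2 order (B1 to B12, with B8A after B8) and that `plot_batch` divides by `max_clip_val` and clips to [0, 1] for display (both are assumptions about its internals), the lookup amounts to:

```python
from typing import List, Sequence

# Assumed channel order of the 13-band Sentinel-2 input ('all' channel configuration)
S2_BANDS = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12"]

def band_indices(bands_show: Sequence[str]) -> List[int]:
    """Channel indices of the requested bands, in display (R, G, B) order."""
    return [S2_BANDS.index(b) for b in bands_show]

def to_display(value: float, max_clip_val: float) -> float:
    """Scale a raw value to [0, 1] for plotting, clipping very bright pixels."""
    return min(max(value / max_clip_val, 0.0), 1.0)

print(band_indices(["B4", "B3", "B2"]))   # natural colour: [3, 2, 1]
print(band_indices(["B11", "B8", "B4"]))  # SWIR/NIR/Red composite: [11, 7, 3]
print(to_display(1750.0, 3500.0))         # 0.5
```

In the SWIR/NIR/Red composite water tends to appear dark, because water absorbs strongly in the NIR and SWIR bands.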

Step 3: Set up the model

 - config.model_params.train = True specifies that we are training a new model from scratch.
 - get_model(config.model_params) constructs a PyTorch Lightning model using the configuration in config.model_params.
 - config.model_params.model_folder is the folder where the trained model is stored (a subfolder named after the experiment is created inside it).
config.model_params
{'model_folder': 'gs://ml4cc_data_lake/0_DEV/2_Mart/2_MLModelMart',
 'model_version': 'v1',
 'hyperparameters': {'max_tile_size': 256,
  'metric_monitor': 'val_dice_loss',
  'channel_configuration': 'all',
  'label_names': ['land', 'water', 'cloud'],
  'weight_per_class': [1.93445299, 36.60054169, 2.19400729],
  'model_type': 'linear',
  'num_classes': 3,
  'max_epochs': 10,
  'val_every': 1,
  'lr': 0.0001,
  'lr_decay': 0.5,
  'lr_patience': 2,
  'early_stopping_patience': 4,
  'num_channels': 13},
 'train': True,
 'test': True}
from ml4floods.models.model_setup import get_model

config.model_params.model_folder = "models" 
os.makedirs("models", exist_ok=True)
config.model_params.test = False
config.model_params.train = True
config.model_params.hyperparameters.model_type = "simplecnn" # Currently implemented: simplecnn, unet, linear
model = get_model(config.model_params)
model
WorldFloodsModel(
  (network): SimpleCNN(
    (conv): Sequential(
      (0): Sequential(
        (0): Conv2d(13, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
      (2): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)
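As a sanity check on the architecture, the parameter count can be worked out by hand from the printed layers: each `Conv2d` has in × out × k × k weights plus one bias per output channel, and the ReLUs have no parameters. The total matches the 266 K parameters and the 1.065 MB model size reported when training starts (float32, 4 bytes per parameter):

```python
def conv2d_params(in_channels: int, out_channels: int, kernel_size: int) -> int:
    """Weights (in * out * k * k) plus one bias per output channel."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels

# (in_channels, out_channels, kernel_size) of every Conv2d in the SimpleCNN printed above
layers = [(13, 64, 3), (64, 64, 3), (64, 128, 3), (128, 128, 3), (128, 3, 1)]
total = sum(conv2d_params(*layer) for layer in layers)
size_mb = total * 4 / 1e6  # float32 parameters take 4 bytes each
print(total, round(size_mb, 3))  # 266307 1.065
```

Most parameters sit in the 128→128 convolution (147,584), so widening the last block is the main driver of model size.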

Step 4: (Optional) Set up a Weights & Biases logger for the experiment

We pass this logger to the trainer in a later cell to automatically log relevant metrics to Weights & Biases (wandb).

setup_weights_and_biases = False
if setup_weights_and_biases:
    import wandb
    from pytorch_lightning.loggers import WandbLogger

    # UNCOMMENT ON FIRST RUN TO LOGIN TO Weights and Biases (only needs to be done once)
    # wandb.login()
    # run = wandb.init()

    # Specifies who is logging the experiment to wandb
    config['wandb_entity'] = 'ml4floods'
    # Specifies which wandb project to log to, multiple runs can exist in the same project
    config['wandb_project'] = 'worldfloods-notebook-demo-project'

    wandb_logger = WandbLogger(
        name=config.experiment_name,
        project=config.wandb_project, 
        entity=config.wandb_entity
    )
else:
    wandb_logger = None

Step 5: Set up Lightning callbacks

  • We implement checkpointing with the ModelCheckpoint callback, which saves the best-performing checkpoint (lowest val_dice_loss) to local or GCS storage.

  • We implement early stopping with the EarlyStopping callback, which stops training if val_dice_loss has not improved for 10 consecutive validation checks.

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

experiment_path = f"{config.model_params.model_folder}/{config.experiment_name}"

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{experiment_path}/checkpoint",
    save_top_k=True,
    verbose=True,
    monitor='val_dice_loss',
    mode='min'
)

early_stop_callback = EarlyStopping(
    monitor='val_dice_loss',
    patience=10,
    strict=False,
    verbose=False,
    mode='min'
)

callbacks = [checkpoint_callback, early_stop_callback]

print(f"The trained model will be stored in {config.model_params.model_folder}/{config.experiment_name}")
The trained model will be stored in models/training_demo
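Both callbacks monitor `val_dice_loss` with `mode='min'`. A simplified sketch of their bookkeeping follows; `MinMonitor` is a hypothetical class, not Lightning's implementation, which also handles `min_delta`, writing checkpoint files and `save_top_k > 1`:

```python
import math
from typing import Optional

class MinMonitor:
    """Simplified stand-in for ModelCheckpoint(save_top_k=1, mode='min') plus EarlyStopping(mode='min')."""

    def __init__(self, patience: int):
        self.patience = patience
        self.best = math.inf
        self.best_epoch: Optional[int] = None
        self.wait = 0  # validation checks since the last improvement

    def update(self, epoch: int, value: float) -> bool:
        """Record a validation value; return True if training should stop."""
        if value < self.best:
            # New best value: this is the checkpoint that would be kept
            self.best, self.best_epoch, self.wait = value, epoch, 0
        else:
            self.wait += 1
        return self.wait >= self.patience

monitor = MinMonitor(patience=10)
stop = False
for epoch, loss in enumerate([0.60017, 0.59220, 0.56052, 0.55334]):  # values from the training log
    stop = monitor.update(epoch, loss)
    if stop:
        break
print(monitor.best_epoch, monitor.best, stop)  # 3 0.55334 False
```

Because the validation loss improved at every epoch of the demo run, a new best checkpoint is saved each epoch and early stopping never triggers within the 4 epochs.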

Step 6: Set up the Lightning Trainer

  • The PyTorch Lightning Trainer handles the rest of the model training for us.

  • Further Trainer flags are described at https://pytorch-lightning.readthedocs.io/en/0.7.5/trainer.html
from pytorch_lightning import Trainer

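config.gpus = '0'  # NOTE: Lightning parses the string '0' as zero GPUs, so this run trains on the CPU (see the warning below); use 1 or '0,' to train on GPU 0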
# config.gpus = None # to not use GPU

config.model_params.hyperparameters.max_epochs = 4 # train for maximum 4 epochs

trainer = Trainer(
    fast_dev_run=False,
    logger=wandb_logger,
    callbacks=callbacks,
    default_root_dir=f"{config.model_params.model_folder}/{config.experiment_name}",
    accumulate_grad_batches=1,
    gradient_clip_val=0.0,
    auto_lr_find=False,
    benchmark=False,
    gpus=config.gpus,
    max_epochs=config.model_params.hyperparameters.max_epochs,
    check_val_every_n_epoch=config.model_params.hyperparameters.val_every,
    log_gpu_memory=None,
    resume_from_checkpoint=None
)
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
/home/gonzalo/miniconda3/envs/ml4floods/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1580: UserWarning: GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.
  rank_zero_warn(
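The warning above ("GPU available but not used") comes from how PyTorch Lightning 1.x interprets the `gpus` argument: an integer is a number of GPUs, while a string is parsed as an integer unless it contains a comma, in which case it is a list of device indices, and `"-1"` means all GPUs. So `'0'` means zero GPUs. The hypothetical `describe_gpus` below mirrors those documented rules to make the difference explicit; it is not Lightning's own parser.

```python
from typing import List, Union

def describe_gpus(gpus: Union[None, int, str, List[int]]) -> str:
    """Sketch of the PyTorch Lightning 1.x rules for the `gpus` Trainer argument."""
    if isinstance(gpus, str):
        gpus = gpus.strip()
        if gpus == "-1":
            gpus = -1
        elif "," in gpus:
            gpus = [int(x) for x in gpus.split(",") if x.strip()]  # list of device indices
        else:
            gpus = int(gpus)  # a plain number is a GPU count
    if gpus is None or gpus == 0 or gpus == []:
        return "CPU"
    if gpus == -1:
        return "all available GPUs"
    if isinstance(gpus, int):
        return f"first {gpus} GPU(s)"
    return f"GPU indices {gpus}"

for value in ["0", "0,", 1, "1", [0], None]:
    print(repr(value), "->", describe_gpus(value))
```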

Start training!

trainer.fit(model, dataset)
wandb: Currently logged in as: ipl_uv (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.25 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
Tracking run with wandb version 0.10.21
Syncing run training_demo to Weights & Biases (Documentation).
Project page: https://wandb.ai/ml4floods/worldfloods-notebook-demo-project
Run page: https://wandb.ai/ml4floods/worldfloods-notebook-demo-project/runs/2b206f4a
Run data is saved locally in /home/gonzalo/ml4floods/jupyterbook/content/ml4ops/wandb/run-20210408_084312-2b206f4a

  | Name    | Type      | Params
--------------------------------------
0 | network | SimpleCNN | 266 K 
--------------------------------------
266 K     Trainable params
0         Non-trainable params
266 K     Total params
1.065     Total estimated model params size (MB)
/opt/miniconda3/envs/ml4floods/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:50: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 0, global step 393: val_dice_loss reached 0.60017 (best 0.60017), saving model to "/home/gonzalo/ml4floods/jupyterbook/content/ml4ops/models/training_demo/checkpoint/epoch=0-step=393.ckpt" as top True
Epoch 1, global step 787: val_dice_loss reached 0.59220 (best 0.59220), saving model to "/home/gonzalo/ml4floods/jupyterbook/content/ml4ops/models/training_demo/checkpoint/epoch=1-step=787.ckpt" as top True
Epoch 2, global step 1181: val_dice_loss reached 0.56052 (best 0.56052), saving model to "/home/gonzalo/ml4floods/jupyterbook/content/ml4ops/models/training_demo/checkpoint/epoch=2-step=1181-v1.ckpt" as top True
Epoch 3, global step 1575: val_dice_loss reached 0.55334 (best 0.55334), saving model to "/home/gonzalo/ml4floods/jupyterbook/content/ml4ops/models/training_demo/checkpoint/epoch=3-step=1575.ckpt" as top True
1
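The step numbers in the checkpoint log follow from the dataset size: with 6298 training tiles and `batch_size = 16` there are ceil(6298 / 16) = 394 optimizer steps per epoch (the last batch is smaller, and `accumulate_grad_batches=1`), so the last step of epoch e has the zero-based index 394 × (e + 1) − 1:

```python
import math

n_train_tiles = 6298  # "train 6298 tiles" in the dataset summary
batch_size = 16       # config.data_params.batch_size
steps_per_epoch = math.ceil(n_train_tiles / batch_size)  # assumes the last, partial batch is kept
last_step_per_epoch = [steps_per_epoch * (epoch + 1) - 1 for epoch in range(4)]
print(steps_per_epoch, last_step_per_epoch)  # 394 [393, 787, 1181, 1575]
```

These match the global steps 393, 787, 1181 and 1575 at which checkpoints were saved.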

Step 7: Evaluate the model

Plot some images and predictions

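# Run inference on the images shown before
import torch

model.eval()  # inference mode (matters for layers such as dropout or batch norm)
with torch.no_grad():  # gradients are not needed for inference
    logits = model(batch["image"].to(model.device))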
print(f"Shape of logits: {logits.shape}")
probs = torch.softmax(logits, dim=1)
print(f"Shape of probs: {probs.shape}")
prediction = torch.argmax(probs, dim=1).long().cpu()
print(f"Shape of prediction: {prediction.shape}")
Shape of logits: torch.Size([16, 3, 256, 256])
Shape of probs: torch.Size([16, 3, 256, 256])
Shape of prediction: torch.Size([16, 256, 256])
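For every pixel, the three logits (land, water, cloud) are turned into probabilities with a softmax over the class dimension (`dim=1`), and the predicted class is the index of the largest probability; since softmax is monotonic, this is also the argmax of the logits. A stdlib-only sketch for a single pixel, with made-up logits for illustration:

```python
import math
from typing import List

LABELS = ["land", "water", "cloud"]  # class order from config label_names

def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict_pixel(logits: List[float]) -> int:
    """Index of the most probable class (what torch.argmax(probs, dim=1) gives per pixel)."""
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)

pixel_logits = [0.2, 2.5, -1.0]  # illustrative values, not model output
probs = softmax(pixel_logits)
cls = predict_pixel(pixel_logits)
print(LABELS[cls], round(sum(probs), 6))  # water 1.0
```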
n_images=6
fig, axs = plt.subplots(4, n_images, figsize=(18,14),tight_layout=True)
worldfloods_model.plot_batch(batch["image"][:n_images],axs=axs[0],max_clip_val=3500.)
worldfloods_model.plot_batch(batch["image"][:n_images],bands_show=["B11","B8", "B4"],
                             axs=axs[1],max_clip_val=4500.)
worldfloods_model.plot_batch_output_v1(batch["mask"][:n_images, 0],axs=axs[2], show_axis=True)
worldfloods_model.plot_batch_output_v1(prediction[:n_images] + 1,axs=axs[3], show_axis=True)  # +1 maps class indices 0-2 to the 1-3 codes used by the ground-truth masks

for ax in axs.ravel():
    ax.grid(False)
[Figure: the same six tiles with the model predictions added as a fourth row]

Evaluate on the validation dataset

import torch
import numpy as np
from ml4floods.models.utils import metrics
from ml4floods.models.model_setup import get_model_inference_function
import pandas as pd

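# Note: the log below still reports "Max tile size: 256"; the inference function apparently
# reads config.model_params.hyperparameters.max_tile_size, so this assignment has no effect
config.model_params.max_tile_size = 1024

# Fall back to the CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_function = get_model_inference_function(model, config, apply_normalization=False,
                                                  activation="softmax",
                                                  device=device)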

dl = dataset.val_dataloader() # pytorch Dataloader

# Otherwise fails when reading test dataset from remote bucket
# torch.set_num_threads(1)

thresholds_water = [0,1e-3,1e-2]+np.arange(0.5,.96,.05).tolist() + [.99,.995,.999]

mets = metrics.compute_metrics(
    dl,
    inference_function, 
    thresholds_water=thresholds_water, 
    convert_targets=False,
    plot=False)

label_names = ["land", "water", "cloud"]
metrics.plot_metrics(mets, label_names)
Getting model inference function
Max tile size: 256
81it [00:52,  1.55it/s]                        
[Figures: validation metric plots produced by metrics.plot_metrics]
Per Class IOU {
    "cloud": 0.8116431733608086,
    "land": 0.9123927497732395,
    "water": 0.605203573769534
}
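Per-class IoU and recall are both derived from a confusion matrix accumulated over pixels: IoU = TP / (TP + FP + FN) and recall = TP / (TP + FN). The sketch below assumes rows are ground-truth classes and columns are predictions (the orientation used inside `metrics.compute_metrics` is not shown here) and uses a toy matrix with illustrative counts:

```python
from typing import List

def per_class_iou(confusion: List[List[int]]) -> List[float]:
    """IoU_c = TP / (TP + FP + FN); rows = ground truth, columns = prediction."""
    n = len(confusion)
    ious = []
    for c in range(n):
        tp = confusion[c][c]
        fn = sum(confusion[c]) - tp                       # true class c, predicted something else
        fp = sum(confusion[r][c] for r in range(n)) - tp  # predicted c, true class is something else
        denom = tp + fp + fn
        ious.append(tp / denom if denom else float("nan"))
    return ious

def per_class_recall(confusion: List[List[int]]) -> List[float]:
    """Recall_c = TP / (TP + FN)."""
    return [row[c] / sum(row) if sum(row) else float("nan") for c, row in enumerate(confusion)]

# Toy 3x3 matrix (land, water, cloud); counts are illustrative only
cm = [[90, 5, 5],
      [10, 40, 0],
      [5, 0, 45]]
print([round(v, 3) for v in per_class_iou(cm)])     # [0.783, 0.727, 0.818]
print([round(v, 3) for v in per_class_recall(cm)])  # [0.9, 0.8, 0.9]
```

IoU penalises both missed water (FN) and false alarms (FP), which is why water IoU is typically well below water recall in the results above.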

Show results for each flood event in the validation dataset

if hasattr(dl.dataset, "image_files"):
    cems_code = [os.path.basename(f).split("_")[0] for f in dl.dataset.image_files]
else:
    cems_code = [os.path.basename(f.file_name).split("_")[0] for f in dl.dataset.list_of_windows]

iou_per_code = pd.DataFrame(metrics.group_confusion(mets["confusions"], cems_code, metrics.calculate_iou,
                                                    label_names=[f"IoU_{l}" for l in label_names]))

recall_per_code = pd.DataFrame(metrics.group_confusion(mets["confusions"], cems_code, metrics.calculate_recall,
                                                       label_names=[f"Recall_{l}" for l in label_names]))

join_data_per_code = pd.merge(recall_per_code,iou_per_code,on="code")
join_data_per_code = join_data_per_code.set_index("code")
join_data_per_code = join_data_per_code*100
print(f"Mean values across flood events: {join_data_per_code.mean(axis=0).to_dict()}")
join_data_per_code
Mean values across flood events: {'Recall_land': 93.08255820002643, 'Recall_water': 81.45670619858558, 'Recall_cloud': 76.05831858902059, 'IoU_land': 90.55111450826736, 'IoU_water': 53.68072319568316, 'IoU_cloud': 64.91956034880197}
         Recall_land  Recall_water  Recall_cloud   IoU_land  IoU_water  IoU_cloud
code
EMSR271    75.652827     97.856891     93.792703  75.435118  24.420104  70.859015
EMSR279    89.994390     78.400188     85.984087  83.700250  32.200438  81.255267
EMSR280    99.209005     91.197996     61.264703  97.885050  86.735377  49.392122
EMSR287    99.548484     73.487523     35.077289  98.816743  64.139818  17.221249
RS2        95.479942     84.557994     87.650254  91.563623  56.458526  82.770675
ST1        98.610701     63.239645     92.580875  95.905902  58.130076  88.019034
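The table groups tiles by the file-name prefix before the first underscore (the flood event code, e.g. EMSR271). Assuming `metrics.group_confusion` sums the confusion matrices of all tiles of an event and computes the metric once per event, the aggregation can be sketched as below (hypothetical helpers, toy two-class data). The sketch also suggests why the "Mean values across flood events" line differs from the per-class IoU printed earlier (53.7 vs 60.5 for water): the former is an unweighted mean over events, while the latter presumably pools all pixels.

```python
import os
from typing import Dict, List

Matrix = List[List[int]]

def event_code(path: str) -> str:
    """Flood event code: file name prefix before the first underscore, as in the cell above."""
    return os.path.basename(path).split("_")[0]

def add_matrices(a: Matrix, b: Matrix) -> Matrix:
    """Element-wise sum of two confusion matrices of the same shape."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]

def sum_confusions_by_code(paths: List[str], confusions: List[Matrix]) -> Dict[str, Matrix]:
    """Hypothetical grouping: add up the per-tile confusion matrices of each flood event."""
    grouped: Dict[str, Matrix] = {}
    for path, cm in zip(paths, confusions):
        code = event_code(path)
        grouped[code] = add_matrices(grouped[code], cm) if code in grouped else [row[:] for row in cm]
    return grouped

def class_iou(cm: Matrix, c: int) -> float:
    """IoU of class c; rows = ground truth, columns = prediction."""
    tp = cm[c][c]
    fn = sum(cm[c]) - tp
    fp = sum(row[c] for row in cm) - tp
    return tp / (tp + fn + fp)

# Toy example with two classes (0 = land, 1 = water); file names and counts are illustrative
paths = ["val/S2/EMSR271_01.tif", "val/S2/EMSR271_02.tif", "val/S2/ST1_01.tif"]
confusions = [[[50, 10], [0, 10]], [[40, 0], [10, 10]], [[10, 0], [0, 90]]]

per_event = {code: class_iou(cm, 1) for code, cm in sum_confusions_by_code(paths, confusions).items()}
mean_over_events = sum(per_event.values()) / len(per_event)

pooled = confusions[0]
for cm in confusions[1:]:
    pooled = add_matrices(pooled, cm)
print(per_event, mean_over_events, round(class_iou(pooled, 1), 3))
# {'EMSR271': 0.5, 'ST1': 1.0} 0.75 0.846
```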

Step 8: Save the trained model

Save the model to local or GCS storage, together with the configuration file used for training.

import torch
from pytorch_lightning.utilities.cloud_io import atomic_save
from ml4floods.models.config_setup import save_json

# Save the model weights in experiment_path (a local folder or a gs:// bucket path)
atomic_save(model.state_dict(), f"{experiment_path}/model.pt")
# Save the config file in experiment_path
config_file_path = f"{experiment_path}/config.json"
save_json(config_file_path, config)

Optional: Save the model to Weights & Biases and close the connection

if setup_weights_and_biases:
    torch.save(model.state_dict(), os.path.join(wandb_logger.save_dir, 'model.pt'))
    wandb.save(os.path.join(wandb_logger.save_dir, 'model.pt')) # Copy weights to weights and biases server
    wandb.finish()
wandb: WARNING Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save("/mnt/folder/file.h5", base_path="/mnt")

Waiting for W&B process to finish, PID 3265
Program ended successfully.
Find user logs for this run at: /home/gonzalo/ml4floods/jupyterbook/content/ml4ops/wandb/run-20210408_084312-2b206f4a/logs/debug.log
Find internal logs for this run at: /home/gonzalo/ml4floods/jupyterbook/content/ml4ops/wandb/run-20210408_084312-2b206f4a/logs/debug-internal.log

Run summary:

_runtime          619
_timestamp        1617872011
_step             1575
loss              0.76066
epoch             3
val_bce_loss      1.32021
val_dice_loss     0.55334
val_recall land   0.93818
val_recall water  0.84217
val_recall cloud  0.6011
val_iou land      0.89992
val_iou water     0.57418
val_iou cloud     0.31644

Run history:

_runtime          ▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇█
_timestamp        ▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇█
_step             ▁▁▁▁▁▁▁▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇████
loss              ██▄▄▂▂▁▄▄▄▄▁▁▂▂▂▂▂▂▂▂▃▃▃▃▁▁▃▃▂▂
epoch             ▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▆▆▆▆▆▆▆▆▆█████████
val_bce_loss      ▅▁█▃
val_dice_loss     █▇▂▁
val_recall land   ▄▁█▆
val_recall water  ██▁▃
val_recall cloud  ▁▄▇█
val_iou land      ▁▁█▇
val_iou water     ▂▁█▆
val_iou cloud     ▂▁▇█

Synced 5 W&B file(s), 864 media file(s), 0 artifact file(s) and 1 other file(s)

All done! Now head to the Model Inference tutorial to see how your model performed.