Skip to content

DTACSNet: onboard cloud detection and atmospheric correction with end-to-end deep learning emulators

This repo contains an open implementation to run inference with DTACSNet models for atmospheric correction and for cloud detection. These two models are independent (you can run either one, or both). The trained models provided here are customized to the band configuration that will be available in Phi-Sat-II. Trained models are released under a Creative Commons non-commercial licence.

# pip install dtacs

Load data

The image that we will use for this tutorial is in the tutorials/examples folder. It is a small patch of this Sentinel-2 tile: S2A_MSIL1C_20191111T132241_N0208_R038_T23LLK_20191111T145626. The image has been taken from the CloudSEN12 dataset. It corresponds to ROI ROI_0312 in CloudSEN12.

import rasterio
import os


# Band subset used by the Phi-Sat-II models. The order given here is the
# channel order of `data` below (index 0 = B2 ... index 6 = B8).
phisat2_bands = ["B2","B3","B4","B5","B6","B7","B8"]

# Folder that holds the example GeoTIFFs used in this tutorial.
folder_examples = "examples"
# In Google Earth Engine the id of this image is:
# `var img = ee.Image("COPERNICUS/S2_HARMONIZED/20191111T132241_20191111T132235_T23LLK");`


with rasterio.open(os.path.join(folder_examples,"S2L1C.tif")) as rst:
    # Locate each band by its description in the file; the `+ 1` turns the
    # 0-based position in `rst.descriptions` into the 1-based band index
    # that `rst.read` expects.
    indexes_read = [rst.descriptions.index(b) + 1 for b in phisat2_bands]
    data = rst.read(indexes_read)
# Notebook cell output: (7, 509, 509) for this example, presumably
# (bands, height, width).
data.shape
(7, 509, 509)
data
array([[[ 991.,  982.,  966., ..., 1089., 1096., 1059.],
        [ 996.,  992.,  992., ..., 1062., 1080., 1076.],
        [1010.,  962.,  976., ..., 1064., 1072., 1094.],
        ...,
        [1053., 1550., 1870., ...,  952.,  979.,  988.],
        [1404., 2125., 2356., ...,  979.,  980., 1004.],
        [1608., 2245., 2476., ...,  960.,  973.,  983.]],

       [[ 954.,  932.,  922., ..., 1044., 1063., 1045.],
        [ 969.,  957.,  958., ..., 1006., 1046., 1055.],
        [ 997.,  956.,  944., ..., 1004., 1062., 1077.],
        ...,
        [1268., 1925., 2120., ...,  998., 1013.,  961.],
        [1614., 2223., 2387., ..., 1001., 1001.,  975.],
        [1667., 2116., 2384., ..., 1004., 1013., 1015.]],

       [[ 791.,  785.,  692., ..., 1142., 1115., 1039.],
        [ 817.,  802.,  768., ..., 1013., 1093., 1099.],
        [ 856.,  745.,  750., ...,  976., 1014., 1062.],
        ...,
        [1605., 2188., 2344., ...,  662.,  617.,  657.],
        [1740., 2223., 2446., ...,  649.,  627.,  650.],
        [1571., 2046., 2487., ...,  603.,  625.,  627.]],

       ...,

       [[2310., 2337., 2337., ..., 2163., 2099., 2099.],
        [2310., 2337., 2337., ..., 2163., 2099., 2099.],
        [2318., 2387., 2387., ..., 2291., 2223., 2223.],
        ...,
        [2020., 3246., 3246., ..., 2773., 2788., 2788.],
        [2020., 3246., 3246., ..., 2773., 2788., 2788.],
        [2433., 3233., 3233., ..., 2976., 2858., 2858.]],

       [[2806., 2838., 2838., ..., 2542., 2590., 2590.],
        [2806., 2838., 2838., ..., 2542., 2590., 2590.],
        [2741., 2987., 2987., ..., 2732., 2687., 2687.],
        ...,
        [2419., 3470., 3470., ..., 3398., 3337., 3337.],
        [2419., 3470., 3470., ..., 3398., 3337., 3337.],
        [3041., 3503., 3503., ..., 3621., 3467., 3467.]],

       [[2635., 2549., 2696., ..., 2285., 2320., 2524.],
        [2605., 2594., 2544., ..., 2283., 2406., 2445.],
        [2630., 2827., 2712., ..., 2454., 2637., 2547.],
        ...,
        [1641., 2585., 2960., ..., 3219., 3330., 3077.],
        [2197., 3122., 3377., ..., 3104., 3239., 3210.],
        [2460., 3113., 3395., ..., 3388., 3400., 3253.]]], dtype=float32)

Atmospheric correction model

Load model

from dtacs.model_wrapper import ACModel

# Build the atmospheric correction wrapper with the Phi-Sat-II model variant.
model_atmospheric_correction = ACModel(model_name="CNN_corrector_phisat2")
# Load the trained weights. The progress bar in the cell output suggests the
# weights file is downloaded on first use - TODO confirm in dtacs.model_wrapper.
model_atmospheric_correction.load_weights()
100%|██████████| 4.13k/4.13k [00:00<00:00, 3.09MiB/s]
/home/gonzalo/git/DTACSNet/dtacs/model_wrapper.py:92: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  weights =  torch.load(fh, map_location=self.device)

Run inference

# Apply the atmospheric correction to the L1C array read above. The cell
# output shows that the result is a uint16 array.
ac_output = model_atmospheric_correction.predict(data)
ac_output
array([[[ 417,  396,  388, ...,  526,  537,  506],
        [ 423,  416,  412, ...,  489,  522,  521],
        [ 447,  390,  397, ...,  496,  530,  552],
        ...,
        [ 451, 1153, 1614, ...,  394,  445,  427],
        [ 989, 1968, 2301, ...,  419,  437,  461],
        [1237, 2103, 2449, ...,  408,  436,  436]],

       [[ 710,  678,  668, ...,  825,  852,  833],
        [ 728,  713,  710, ...,  770,  833,  846],
        [ 762,  718,  700, ...,  772,  857,  875],
        ...,
        [1059, 1913, 2236, ...,  766,  794,  721],
        [1575, 2413, 2677, ...,  768,  775,  749],
        [1661, 2303, 2699, ...,  775,  795,  790]],

       [[ 646,  629,  532, ..., 1044, 1027,  948],
        [ 674,  651,  612, ...,  900, 1004, 1012],
        [ 722,  592,  594, ...,  859,  918,  970],
        ...,
        [1530, 2249, 2493, ...,  494,  452,  483],
        [1768, 2415, 2710, ...,  479,  458,  485],
        [1618, 2247, 2780, ...,  424,  459,  456]],

       ...,

       [[2503, 2541, 2535, ..., 2323, 2266, 2254],
        [2504, 2539, 2542, ..., 2323, 2261, 2258],
        [2496, 2603, 2610, ..., 2477, 2393, 2398],
        ...,
        [2147, 3515, 3552, ..., 3036, 3044, 3056],
        [2178, 3597, 3620, ..., 3046, 3048, 3054],
        [2715, 3620, 3633, ..., 3273, 3131, 3138]],

       [[2933, 2971, 2971, ..., 2636, 2691, 2685],
        [2932, 2969, 2973, ..., 2642, 2688, 2686],
        [2857, 3128, 3134, ..., 2852, 2794, 2796],
        ...,
        [2461, 3624, 3663, ..., 3576, 3514, 3522],
        [2497, 3718, 3738, ..., 3586, 3516, 3522],
        [3233, 3798, 3796, ..., 3826, 3655, 3661]],

       [[2912, 2811, 2980, ..., 2505, 2569, 2799],
        [2877, 2865, 2807, ..., 2499, 2666, 2710],
        [2892, 3144, 3013, ..., 2700, 2921, 2820],
        ...,
        [1732, 2764, 3277, ..., 3569, 3694, 3403],
        [2457, 3532, 3882, ..., 3442, 3588, 3562],
        [2811, 3569, 3941, ..., 3757, 3778, 3608]]], dtype=uint16)

Plot

import matplotlib.pyplot as plt
from dtacs import plot
import rasterio.plot as rstplt

# Load the L2A image to show next to the DTACSNet output, reading the same
# bands in the same order as the L1C input.
with rasterio.open(os.path.join(folder_examples,"S2L2A.tif")) as rst:
    indexes_read = [rst.descriptions.index(b) + 1 for b in phisat2_bands]
    data_sen2cor = rst.read(indexes_read)

# 2 rows x 3 columns: top row RGB composites, bottom row NIR/red/green
# composites; columns are L1C input, DTACSNet output and L2A.
fig, ax = plt.subplots(2,3,figsize=(18,12),tight_layout=True)

# `2::-1` picks channels 2, 1, 0 -> B4, B3, B2 (red, green, blue).
# Dividing by 4_000 scales the values for display; anything above 1 is
# clipped by imshow (hence the "Clipping input data" messages in the output).
rstplt.show(data[2::-1,...]/4_000,ax=ax[0,0])
ax[0,0].set_title("RGB L1C")
rstplt.show(ac_output[2::-1,...]/4_000,ax=ax[0,1])
ax[0,1].set_title("RGB DTACSNet")
rstplt.show(data_sen2cor[2::-1,...]/4_000,ax=ax[0,2])
ax[0,2].set_title("RGB L2A")

# Channels -1, 2, 1 -> B8, B4, B3 (NIR, red, green) false-colour composite,
# scaled by 10_000 instead of 4_000.
nirredgreen = [-1, 2, 1]
rstplt.show(data[nirredgreen,...]/10_000,ax=ax[1,0])
ax[1,0].set_title("NIRRG L1C")
rstplt.show(ac_output[nirredgreen,...]/10_000,ax=ax[1,1])
ax[1,1].set_title("NIRRG DTACSNet")
rstplt.show(data_sen2cor[nirredgreen,...]/10_000,ax=ax[1,2])
ax[1,2].set_title("NIRRG L2A")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.08225..1.67075].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..2.004].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.00025..2.44].

Text(0.5, 1.0, 'NIRRG L2A')
No description has been provided for this image

Cloud detection model

The cloud detection model of DTACS is based on the CloudSEN12 dataset. For more and better models, use cloudsen12_models.

Load model

from dtacs.model_wrapper import CDModel, DIR_MODELS_LOCAL
import torch
from dtacs.download_weights import download_weights
import os

import re

# Check the installed torch version numerically. Where `torch.__version__` is
# a plain `str`, `>=` compares character by character, so e.g. "1.9.0" would
# wrongly be accepted as newer than "1.13". Only (major, minor) are compared;
# an unparseable version string is treated as too old.
_torch_version_match = re.match(r"(\d+)\.(\d+)", torch.__version__)
_torch_major_minor = (
    tuple(int(part) for part in _torch_version_match.groups())
    if _torch_version_match
    else (0, 0)
)
assert _torch_major_minor >= (1, 13), f"Requires torch version >=1.13 current version {torch.__version__ }"

# Fetch the TorchScript cloud detection model and load it on the CPU. Note
# that the local file name ("cloud4bands.pt") differs from the remote one.
# https://huggingface.co/isp-uv-es/cloudsen12_models/resolve/main/dtacs4bands.pt
weights_path = os.path.join(DIR_MODELS_LOCAL,"cloud4bands.pt")
download_weights(weights_path, "https://huggingface.co/isp-uv-es/cloudsen12_models/resolve/main/dtacs4bands.pt")
model_cd_torchscript = torch.jit.load(weights_path, map_location='cpu')
# Wrap the TorchScript module in CDModel, whose `predict` method is used in
# the inference step.
model_cloud_detection = CDModel(model=model_cd_torchscript)
100%|██████████| 41.3M/41.3M [00:01<00:00, 40.9MiB/s]

Run inference

# The cloud detection model takes 4 bands. With the channel order of `data`
# (B2, B3, B4, B5, B6, B7, B8), indices -1, 2, 1, 0 select B8, B4, B3, B2,
# i.e. NIR, red, green, blue.
data_cloud_detection = data[[-1,2,1,0],...]

# The cell output shows a 2-D uint8 array of per-pixel class labels.
cd_output = model_cloud_detection.predict(data_cloud_detection)
cd_output
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [0, 1, 1, ..., 0, 0, 0],
       [0, 1, 1, ..., 0, 0, 0]], dtype=uint8)
# Side-by-side figure: RGB view of the L1C input (left) and the predicted
# mask (right).
fig, ax = plt.subplots(1,2,figsize=(12,6),tight_layout=True)

# `2::-1` picks channels 2, 1, 0 -> B4, B3, B2 (red, green, blue). Values
# above 1 after dividing by 4_000 are clipped by imshow.
rstplt.show(data[2::-1,...]/4_000,ax=ax[0])
ax[0].set_title("RGB L1C")

# Draw the predicted classes with the plotting helper shipped with dtacs.
plot.plot_cloudSEN12mask(cd_output,ax=ax[1])
ax[1].set_title("Cloud shadow mask")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.08225..1.67075].

Text(0.5, 1.0, 'Cloud shadow mask')
No description has been provided for this image