DTACSNet: onboard cloud detection and atmospheric correction with end-to-end deep learning emulators
This repo contains an open implementation to run inference with DTACSNet models for atmospheric correction and for cloud detection. The two models are independent: you can run either one, or both. The trained models provided here are customized to the band configuration that will be available in Phi-Sat-II. Trained models are released under a Creative Commons non-commercial licence.
# pip install dtacs
Load data
The image used in this tutorial is in the tutorials/examples folder. It is a small patch of the Sentinel-2 tile S2A_MSIL1C_20191111T132241_N0208_R038_T23LLK_20191111T145626. The image was taken from the CloudSEN12 dataset, where it corresponds to ROI ROI_0312.
import rasterio
import os

# Bands to read from the example raster, in this order (index 0 = B2 ... index 6 = B8).
phisat2_bands = ["B2", "B3", "B4", "B5", "B6", "B7", "B8"]
folder_examples = "examples"

# Google Earth Engine equivalent of this patch's source image:
#   var img = ee.Image("COPERNICUS/S2_HARMONIZED/20191111T132241_20191111T132235_T23LLK");
with rasterio.open(os.path.join(folder_examples, "S2L1C.tif")) as rst:
    # Look up each band by its description. rasterio band indexes are 1-based,
    # so the 0-based tuple position is shifted by one.
    indexes_read = [rst.descriptions.index(b) + 1 for b in phisat2_bands]
    data = rst.read(indexes_read)

# Notebook display of the loaded array and its shape.
data.shape
data
Atmospheric correction model
Load model
from dtacs.model_wrapper import ACModel

# Atmospheric-correction emulator for the Phi-Sat-II band set: build it by name,
# then load its trained weights.
model_atmospheric_correction = ACModel(model_name="CNN_corrector_phisat2")
model_atmospheric_correction.load_weights()
Run inference
# Apply DTACSNet atmospheric correction to the L1C bands read above,
# then display the result in the notebook.
ac_output = model_atmospheric_correction.predict(data)
ac_output
Plot
import matplotlib.pyplot as plt
from dtacs import plot
import rasterio.plot as rstplt

# Reference L2A product for the same patch, restricted to the same bands
# (1-based rasterio indexes looked up by band description).
with rasterio.open(os.path.join(folder_examples, "S2L2A.tif")) as rst:
    indexes_read = [rst.descriptions.index(b) + 1 for b in phisat2_bands]
    data_sen2cor = rst.read(indexes_read)

# False-colour composite positions in phisat2_bands order: B8 (last), B4, B3.
nirredgreen = [-1, 2, 1]

# 2x3 grid: one column per product, true colour on top, NIR/red/green below.
fig, ax = plt.subplots(2, 3, figsize=(18, 12), tight_layout=True)
for col, (image, source) in enumerate(
    [(data, "L1C"), (ac_output, "DTACSNet"), (data_sen2cor, "L2A")]
):
    # Positions 2, 1, 0 -> B4, B3, B2; divided by 4000 for display.
    rstplt.show(image[2::-1, ...] / 4_000, ax=ax[0, col])
    ax[0, col].set_title(f"RGB {source}")
    # NIR/red/green composite; divided by 10000 for display.
    rstplt.show(image[nirredgreen, ...] / 10_000, ax=ax[1, col])
    ax[1, col].set_title(f"NIRRG {source}")
Cloud detection model
The cloud detection model of DTACS is based on the CloudSEN12 dataset. For more (and better) models, use cloudsen12_models.
Load model
from dtacs.model_wrapper import CDModel, DIR_MODELS_LOCAL
import torch
from dtacs.download_weights import download_weights
import os
import re

# The TorchScript cloud-detection model requires torch >= 1.13.
# Compare numeric (major, minor) parts rather than raw strings: on older torch
# releases __version__ is a plain str, and "1.9.0" >= "1.13" is True
# lexicographically. Raise explicitly so the check also survives `python -O`.
torch_major_minor = tuple(int(part) for part in re.findall(r"\d+", str(torch.__version__))[:2])
if torch_major_minor < (1, 13):
    raise RuntimeError(f"Requires torch version >=1.13 current version {torch.__version__ }")

# Remote weights are published as dtacs4bands.pt; they are stored locally
# under DIR_MODELS_LOCAL with the name cloud4bands.pt.
weights_url = "https://huggingface.co/isp-uv-es/cloudsen12_models/resolve/main/dtacs4bands.pt"
weights_path = os.path.join(DIR_MODELS_LOCAL, "cloud4bands.pt")
download_weights(weights_path, weights_url)

# Load the TorchScript module onto the CPU and wrap it for inference.
model_cd_torchscript = torch.jit.load(weights_path, map_location='cpu')
model_cloud_detection = CDModel(model=model_cd_torchscript)
Run inference
# The cloud detector is fed four bands taken from `data` in the order
# B8, B4, B3, B2 (positions -1, 2, 1, 0 of phisat2_bands).
data_cloud_detection = data[[-1, 2, 1, 0], ...]

cd_output = model_cloud_detection.predict(data_cloud_detection)
cd_output  # notebook display
# Side by side: the L1C true-colour composite and the predicted CloudSEN12-style mask.
# (The unused `nirredgreen` re-assignment was dropped; this cell never uses it.)
fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)

# Positions 2, 1, 0 -> B4, B3, B2; divided by 4000 for display.
rstplt.show(data[2::-1, ...] / 4_000, ax=ax[0])
ax[0].set_title("RGB L1C")

plot.plot_cloudSEN12mask(cd_output, ax=ax[1])
ax[1].set_title("Cloud shadow mask")