Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clay v1 base #136

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
6 changes: 5 additions & 1 deletion terratorch/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from terratorch.datamodules.sen1floods11 import Sen1Floods11NonGeoDataModule
from terratorch.datamodules.torchgeo_data_module import TorchGeoDataModule, TorchNonGeoDataModule

# miscellaneous datamodules
from terratorch.datamodules.openearthmap import OpenEarthMapNonGeoDataModule

__all__ = (
"GenericNonGeoSegmentationDataModule",
"GenericNonGeoPixelwiseRegressionDataModule",
Expand All @@ -50,5 +53,6 @@
"MChesapeakeLandcoverNonGeoDataModule",
"MPv4gerSegNonGeoDataModule",
"MSACropTypeNonGeoDataModule",
"MNeonTreeNonGeoDataModule"
"MNeonTreeNonGeoDataModule",
"OpenEarthMapModule"
)
58 changes: 58 additions & 0 deletions terratorch/datamodules/openearthmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any
import torch

import albumentations as A
import kornia.augmentation as K
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.transforms import AugmentationSequential
from terratorch.datasets import OpenEarthMapNonGeo
from terratorch.datamodules.utils import wrap_in_compose_is_list

MEANS = {
"BLUE": 116.628328,
"GREEN": 119.65935,
"RED": 113.385309
}

STDS = {
"BLUE": 44.668890717415586,
"GREEN": 48.282311849967364,
"RED": 54.19692448815262,
}

class OpenEarthMapNonGeoDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
**kwargs: Any
) -> None:
super().__init__(OpenEarthMapNonGeo, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", OpenEarthMapNonGeo.all_band_names)
self.means = torch.tensor([MEANS[b] for b in bands])
self.stds = torch.tensor([STDS[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
)
if stage in ["fit", "validate"]:
self.val_dataset = self.dataset_class(
split="val", data_root=self.data_root, transform=self.val_transform, **self.kwargs
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",data_root=self.data_root, transform=self.test_transform, **self.kwargs
)
4 changes: 4 additions & 0 deletions terratorch/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
# TorchGeo RasterDatasets
from terratorch.datasets.wsf import WSF2019, WSFEvolution

# miscellaneous datasets
from terratorch.datasets.openearthmap import OpenEarthMapNonGeo

__all__ = (
"GenericNonGeoSegmentationDataset",
"GenericNonGeoPixelwiseRegressionDataset",
Expand All @@ -59,4 +62,5 @@
"WSFEvolution",
"HLSL30",
"HLSS30",
"OpenEarthMapNonGeo"
)
114 changes: 114 additions & 0 deletions terratorch/datasets/openearthmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
from collections.abc import Sequence
import matplotlib.pyplot as plt
import torch
import rasterio
from pathlib import Path

import albumentations as A

from torchgeo.datasets import NonGeoDataset
from terratorch.datasets.utils import to_tensor



class OpenEarthMapNonGeo(NonGeoDataset):

all_band_names = ("BLUE","GREEN","RED")

rgb_bands = ("RED","GREEN","BLUE")

BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

def __init__(self, data_root: str,
bands: Sequence[str] = BAND_SETS["all"],
transform: A.Compose | None = None,
split="train",
crop_size: int = 256,
random_crop: bool = True) -> None:
super().__init__()
if split not in ["train", "test", "val"]:
msg = "Split must be one of train, test, val."
raise Exception(msg)

self.transform = transform if transform else lambda **batch: to_tensor(batch, transpose=False)
self.split = split
self.data_root = data_root

# images in openearthmap are not all 1024x1024 and must be cropped
self.crop_size = crop_size
self.random_crop = random_crop

assert self.crop_size > 0, "Crop size must be greater than 0"

self.image_files = self._get_file_paths(Path(self.data_root, f"{split}.txt"))

def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
image_path, label_path = self.image_files[index]

with rasterio.open(image_path) as src:
image = src.read()
with rasterio.open(label_path) as src:
mask = src.read()

# some images in the dataset are not perfect squares
# cropping to fit to the prepare_features_for_image_model call
if self.random_crop:
image, mask = self._random_crop(image, mask)
else:
image, mask = self._center_crop(image, mask)

output = {
"image": image.astype(np.float32),
"mask": mask
}

output = self.transform(**output)
output['mask'] = output['mask'].long()

return output

def _parse_file_name(self, file_name: str):
underscore_pos = file_name.rfind('_')
folder_name = file_name[:underscore_pos]
region_path = Path(self.data_root, folder_name)
image_path = Path(region_path, "images", file_name)
label_path = Path(region_path, "labels", file_name)
return image_path, label_path

def _get_file_paths(self, text_file_path: str):
with open(text_file_path, 'r') as file:
lines = file.readlines()
file_paths = [self._parse_file_name(line.strip()) for line in lines]
return file_paths

def __len__(self):
return len(self.image_files)

def _random_crop(self, image, mask):
h, w = image.shape[1:]
top = np.random.randint(0, h - self.crop_size)
left = np.random.randint(0, w - self.crop_size)

image = image[:, top: top + self.crop_size, left: left + self.crop_size]
mask = mask[:, top: top + self.crop_size, left: left + self.crop_size]

return image, mask

def _center_crop(self, image, mask):
h, w = image.shape[1:]
top = (h - self.crop_size) // 2
left = (w - self.crop_size) // 2

image = image[:, top: top + self.crop_size, left: left + self.crop_size]
mask = mask[:, top: top + self.crop_size, left: left + self.crop_size]

return image, mask

def plot(self, arg, suptitle: str | None = None) -> None:
pass

def plot_sample(self, sample, prediction=None, suptitle: str | None = None, class_names=None):
pass


4 changes: 2 additions & 2 deletions terratorch/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ def _split_filter_function(file_name, valid_files: list[str], ignore_extensions=
return False


def to_tensor(d):
def to_tensor(d, transpose=True):
new_dict = {}
for k, v in d.items():
if not isinstance(v, np.ndarray):
new_dict[k] = v
else:
if k == "image":
if k == "image" and transpose:
v = np.moveaxis(v, -1, 0)
new_dict[k] = torch.from_numpy(v)
return new_dict
1 change: 1 addition & 0 deletions terratorch/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright contributors to the Terratorch project

from terratorch.models.clay_model_factory import ClayModelFactory
from terratorch.models.prithvi_model_factory import PrithviModelFactory
from terratorch.models.satmae_model_factory import SatMAEModelFactory
from terratorch.models.scalemae_model_factory import ScaleMAEModelFactory
Expand Down
1 change: 1 addition & 0 deletions terratorch/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
# import so they get registered
import terratorch.models.backbones.prithvi_swin
import terratorch.models.backbones.prithvi_vit
import terratorch.models.backbones.clay_v1
3 changes: 3 additions & 0 deletions terratorch/models/backbones/clay_v1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import terratorch.models.backbones.clay_v1.embedder
import terratorch.models.backbones.clay_v1.modules
import terratorch.models.backbones.clay_v1.utils
Loading
Loading