Clay v1 base #136

Open · wants to merge 13 commits into main
Conversation

@myscon commented Aug 27, 2024

No description provided.

@Joao-L-S-Almeida (Member)

The tests are failing due to an import of vit_pytorch, and I understand it should come from here. Could it be replaced by some built-in module, or should we add it to our requirements?

@myscon (Author) commented Sep 5, 2024

Okay, I removed vit-pytorch. I seem to have omitted the sign-off, and that was the reason for the force push.

@Joao-L-S-Almeida (Member)

Some of the automatic tests are still failing. It might be worth running pytest -s tests locally to check whether anything is missing.

@myscon (Author) commented Sep 6, 2024

Made some small adjustments to the test file. They pass locally.

Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
@Joao-L-S-Almeida (Member)

@myscon I saw that the problem was related to the device. I modified it to "cpu" (the GitHub runner is not configured with a GPU) and the tests are passing. Please check PR #160, try to update your fork, and we can proceed with the review.

img_size=256,
num_frames=1,
ckpt_path=None,
device="cuda",
Collaborator:

The device shouldn't need to be passed in. In PyTorch Lightning, .to(device) statements should not be used; responsibility for placing modules on the correct device is handed to Lightning.
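
For illustration, a minimal sketch of the pattern being suggested here, assuming the backbone ends up wrapped in a LightningModule (the class and parameter names below are hypothetical and not from this PR):

import torch
import lightning.pytorch as pl  # lightning >= 2.0

class ClayTask(pl.LightningModule):  # hypothetical wrapper, for illustration only
    def __init__(self, backbone: torch.nn.Module):
        super().__init__()
        # Registered sub-modules are moved to the correct device by Lightning,
        # so no explicit .to(device) is needed here or inside the backbone.
        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tensors created on the fly should follow self.device
        # instead of a hard-coded "cuda" string.
        scale = torch.tensor(1.0, device=self.device)
        return self.backbone(x) * scale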

heads=12,
dim_head=64,
mlp_ratio=4.0,
).to(device)
Collaborator:

Same as above; such statements should be removed.

Author:

Makes sense. Those have been removed.

Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
@CarlosGomes98 (Collaborator) left a comment

Added some comments. This is quite a big task, so we may want to think about how we can split work somehow :)

img_size=256,
num_frames=1,
ckpt_path=None,
bands=["blue", "green", "red", "nir", "swir16", "swir22"],
Collaborator:

Mutable data structures should not be used for argument defaults. Consider defaulting to None and, in the method body, explicitly checking for None and replacing it with the desired default. https://docs.astral.sh/ruff/rules/mutable-argument-default/

Also, if compatible, consider using the HLSBands Enum instead of strings.
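
For illustration, a minimal sketch of the None-default pattern the ruff rule describes; the class name and DEFAULT_BANDS constant are hypothetical and only stand in for the real constructor:

class ClayBackboneWrapper:  # hypothetical class, for illustration only
    DEFAULT_BANDS = ("blue", "green", "red", "nir", "swir16", "swir22")  # immutable tuple

    def __init__(self, bands: list[str] | None = None):
        # A list literal as default would be shared across all calls;
        # defaulting to None and filling it in here avoids that.
        self.bands = list(bands) if bands is not None else list(self.DEFAULT_BANDS)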


# for use in features list. Single layer feature for simplicity
self.feature_info.append(
{"num_chs": 768, "reduction": 1, "module": f"clay_encoder"})
Collaborator:

No need for an f-string in module; the string has no placeholders.

@@ -190,12 +160,14 @@ def build_model(
task, backbone, decoder, head_kwargs, prepare_features_for_image_model, rescale=rescale
)

to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = [
]
Collaborator:

closing bracket can be on the same line

@@ -57,7 +59,7 @@ def build_model(
backbone: str | nn.Module,
decoder: str | nn.Module,
in_channels: int,
bands: list[HLSBands | int],
bands: list[int] = [],
Collaborator:

A mutable data structure should not be used here. See the other comment on a similar issue for details.

Collaborator:

Also, consider allowing the HLSBands Enum to be used
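
Taken together with the mutable-default comment above, the signature could look roughly like the sketch below; the HLSBands import path is an assumption and should be checked against the package:

from torch import nn
from terratorch.datasets import HLSBands  # assumed import path, please verify

def build_model(
    backbone: str | nn.Module,
    decoder: str | nn.Module,
    in_channels: int,
    bands: list[HLSBands | int] | None = None,  # enum or raw indices, no mutable default
    **kwargs,  # remaining parameters elided in this sketch
):
    bands = list(bands) if bands is not None else []
    ...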

# When the model is not on HF, it needs to be restored locally.
print("This model is not available on HuggingFace. Trying to instantiate locally ...")
except Exception as e:
print(e, "Error loading from HF. Trying to instantiate locally ...")

assert checkpoint_path, "A checkpoint must be provided to restore the model."
Collaborator:

Prefer an explicit check that raises an exception if the condition is not met. Assertions can be removed when Python is run with optimizations.
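
A minimal sketch of the equivalent explicit check; the message is taken from the existing assertion:

# Unlike assert, this check is not stripped when Python runs with -O.
if not checkpoint_path:
    raise ValueError("A checkpoint must be provided to restore the model.")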



@pytest.mark.parametrize("backbone", ["clay_v1_base"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
Collaborator:

same here



@pytest.mark.parametrize("backbone", ["clay_v1_base"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
Collaborator:

FCN and UperNet probably don't belong here.

os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1"

# central wavelengths of pretrained model
WAVELENGTHS = {
Collaborator:

could this use HLSBands for the keys?
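
If the enum is adopted, the dictionary could be keyed on it roughly as sketched below; the member names and import path are assumptions to be checked against HLSBands, and the wavelength values themselves stay exactly as they are in the PR:

from terratorch.datasets import HLSBands  # assumed import path, please verify

# central wavelengths of pretrained model, keyed on the enum instead of strings
WAVELENGTHS = {
    HLSBands.BLUE: ...,   # keep the existing wavelength value from the PR
    HLSBands.GREEN: ...,  # likewise for the remaining bands
}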

qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(
t, 'b n (h d) -> b h n d', h=self.heads), qkv)

Collaborator:

Suggested change:

(rearrange(t, 'b n (h d) -> b h n d', h=self.heads) for t in qkv)

may be simpler
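
Written out in full against the two lines quoted above, the change would read roughly as follows (rearrange is the einops function already used in this file):

qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=self.heads) for t in qkv)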

patches, waves_encoded = self.to_patch_embed(
cube, waves
) # [B L D] - patchify & create embeddings per patch
# TODO: Add time & latlon as encoding to patches
Collaborator:

is this still todo?
