Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] sincos and rope embeddings in torch instead of numpy #9654

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

a-r-r-o-w
Copy link
Member

What does this PR do?

Internal discussion: https://huggingface.slack.com/archives/C065E480NN9/p1727418894443269

Code
import numpy as np
import torch
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, aryan_get_1d_sincos_pos_embed_from_grid
from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid, aryan_get_2d_sincos_pos_embed_from_grid
from diffusers.models.embeddings import get_2d_sincos_pos_embed, aryan_get_2d_sincos_pos_embed
from diffusers.models.embeddings import get_3d_sincos_pos_embed, aryan_get_3d_sincos_pos_embed
from diffusers.models.embeddings import get_2d_rotary_pos_embed, aryan_get_2d_rotary_pos_embed
from diffusers.models.embeddings import get_3d_rotary_pos_embed, aryan_get_3d_rotary_pos_embed


@torch.no_grad()
def test__get_1d_sincos_pos_embed_from_grid():
    """Check the torch 1D sincos embedding against the numpy reference.

    For every (embed_dim, grid_size) pair a float32 position grid is built in
    numpy, fed to ``get_1d_sincos_pos_embed_from_grid``, and then fed as a
    tensor (on "cpu" and on "cuda") to
    ``aryan_get_1d_sincos_pos_embed_from_grid``. The max and summed absolute
    differences are printed and the two results must agree to 1e-12.
    """
    base_size = 16
    interpolation_scale = 1.0
    cases = [(dim, size) for dim in (128, 256, 1024) for size in (16, 32, 64, 128)]

    for embed_dim, grid_size in cases:
        # Positions are rescaled so that `grid_size` samples span `base_size` units.
        positions = np.arange(grid_size, dtype=np.float32)
        reference_grid = positions / (grid_size / base_size) / interpolation_scale
        expected = get_1d_sincos_pos_embed_from_grid(embed_dim, reference_grid)

        for device in ("cpu", "cuda"):
            candidate_grid = torch.from_numpy(reference_grid).to(device)
            actual = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim, candidate_grid).cpu().numpy()
            abs_error = np.abs(expected - actual)

            print(f"===== testing {test__get_1d_sincos_pos_embed_from_grid.__name__}({embed_dim=}, {grid_size=}, {device=}) =====")
            print(abs_error.max())
            print(abs_error.sum())
            print("==========")
            assert np.allclose(expected, actual, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_2d_sincos_pos_embed_from_grid():
    """Check the torch 2D sincos-from-grid embedding against the numpy reference.

    A (2, 1, H, W) float32 meshgrid is built in numpy for each
    (embed_dim, grid_size) pair, passed to ``get_2d_sincos_pos_embed_from_grid``,
    and passed as a tensor (on "cpu" and on "cuda") to
    ``aryan_get_2d_sincos_pos_embed_from_grid``. The max and summed absolute
    differences are printed and the two results must agree to 1e-12.
    """
    base_size = 16
    interpolation_scale = 1.0
    cases = [
        (dim, size)
        for dim in (128, 256, 1024)
        for size in ((32, 32), (64, 64), (128, 64), (64, 128))
    ]

    for embed_dim, grid_size in cases:
        height, width = grid_size
        rows = np.arange(height, dtype=np.float32) / (height / base_size) / interpolation_scale
        cols = np.arange(width, dtype=np.float32) / (width / base_size) / interpolation_scale
        # Width goes first in meshgrid, matching the original argument order.
        grid_numpy = np.stack(np.meshgrid(cols, rows)).reshape(2, 1, height, width)
        expected = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_numpy)

        for device in ("cpu", "cuda"):
            grid_tensor = torch.from_numpy(grid_numpy).to(device)
            actual = aryan_get_2d_sincos_pos_embed_from_grid(embed_dim, grid_tensor).cpu().numpy()
            abs_error = np.abs(expected - actual)

            print(f"===== testing {test__get_2d_sincos_pos_embed_from_grid.__name__}({embed_dim=}, {grid_size=}, {device=}) =====")
            print(abs_error.max())
            print(abs_error.sum())
            print("==========")
            assert np.allclose(expected, actual, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_2d_sincos_pos_embed():
    """Check the torch ``aryan_get_2d_sincos_pos_embed`` against the numpy reference.

    Sweeps base size, interpolation scale, embedding dim, grid size, the
    ``cls_token`` flag and the number of extra tokens. Each configuration is
    evaluated once with ``get_2d_sincos_pos_embed`` and once per device
    ("cpu", "cuda") with the torch variant. The max and summed absolute
    differences are printed and the results must agree to 1e-12.
    """
    # Flattened in the same order the equivalent nested loops would visit.
    configurations = [
        (base_size, interpolation_scale, embed_dim, grid_size, cls_token, extra_tokens)
        for base_size in (16, 32, 64)
        for interpolation_scale in (1.0, 2.0)
        for embed_dim in (256, 1024)
        for grid_size in ((64, 64), (128, 64), (64, 128))
        for cls_token in (False, True)
        for extra_tokens in (0, 16)
    ]

    for base_size, interpolation_scale, embed_dim, grid_size, cls_token, extra_tokens in configurations:
        shared_kwargs = {
            "embed_dim": embed_dim,
            "grid_size": grid_size,
            "cls_token": cls_token,
            "extra_tokens": extra_tokens,
            "interpolation_scale": interpolation_scale,
            "base_size": base_size,
        }
        expected = get_2d_sincos_pos_embed(**shared_kwargs)

        for device in ("cpu", "cuda"):
            actual = aryan_get_2d_sincos_pos_embed(**shared_kwargs, device=device).cpu().numpy()
            abs_error = np.abs(expected - actual)

            print(f"===== testing {test__get_2d_sincos_pos_embed.__name__}({base_size=}, {interpolation_scale=}, {embed_dim=}, {grid_size=}, {cls_token=}, {extra_tokens=}, {device=}) =====")
            print(abs_error.max())
            print(abs_error.sum())
            print("==========")
            assert np.allclose(expected, actual, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_3d_sincos_pos_embed():
    """Check the torch ``aryan_get_3d_sincos_pos_embed`` against the numpy reference.

    Sweeps embedding dim, spatial size, temporal size and both interpolation
    scales. Each configuration is evaluated once with
    ``get_3d_sincos_pos_embed`` and once per device ("cpu", "cuda") with the
    torch variant. The max and summed absolute differences are printed and the
    results must agree to 1e-12.
    """
    for embed_dim in [256, 1024]:
        for spatial_size in [(64, 64), (128, 64), (64, 128)]:
            for temporal_size in [8, 16]:
                for spatial_interpolation_scale in [1.0, 2.0]:
                    for temporal_interpolation_scale in [1.0, 2.0]:
                        shared_kwargs = {
                            "embed_dim": embed_dim,
                            "spatial_size": spatial_size,
                            "temporal_size": temporal_size,
                            "spatial_interpolation_scale": spatial_interpolation_scale,
                            "temporal_interpolation_scale": temporal_interpolation_scale,
                        }
                        sincos_pos_embed_numpy = get_3d_sincos_pos_embed(**shared_kwargs)

                        for device in ["cpu", "cuda"]:
                            sincos_pos_embed_torch = (
                                aryan_get_3d_sincos_pos_embed(**shared_kwargs, device=device).cpu().numpy()
                            )
                            abs_error = np.abs(sincos_pos_embed_numpy - sincos_pos_embed_torch)

                            # The banner names this test (it previously named the 2D
                            # sincos test) and includes the device, so a failing
                            # iteration can be identified from the log alone.
                            print(f"===== testing {test__get_3d_sincos_pos_embed.__name__}({embed_dim=}, {spatial_size=}, {temporal_size=}, {spatial_interpolation_scale=}, {temporal_interpolation_scale=}, {device=}) =====")
                            print(abs_error.max())
                            print(abs_error.sum())
                            print("==========")
                            assert np.allclose(sincos_pos_embed_numpy, sincos_pos_embed_torch, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_2d_rotary_pos_embed():
    """Check the torch ``aryan_get_2d_rotary_pos_embed`` against the reference.

    Sweeps embedding dim, crop coordinates and grid size. Both functions
    return a pair of arrays/tensors; each member of the pair is compared
    separately in the printed diagnostics, and the pairs must agree to 1e-12.
    """
    for embed_dim in [256, 1024]:
        for crops_coords in [[(0, 0), (8, 8)], [(0, 0), (16, 16)]]:
            for grid_size in [(64, 64), (128, 64), (64, 128)]:
                shared_kwargs = {
                    "embed_dim": embed_dim,
                    "crops_coords": crops_coords,
                    "grid_size": grid_size,
                }
                rope_numpy = get_2d_rotary_pos_embed(**shared_kwargs)

                for device in ["cpu", "cuda"]:
                    # NOTE(review): `device` is not forwarded here, unlike the 3D
                    # rotary test, so both iterations presumably run on the
                    # callee's default device. If aryan_get_2d_rotary_pos_embed
                    # accepts `device`, pass it -- TODO confirm. This may also
                    # explain why 2D matches at 1e-12 while 3D on CUDA does not.
                    rope_torch = aryan_get_2d_rotary_pos_embed(**shared_kwargs)
                    rope_torch = rope_torch[0].cpu().numpy(), rope_torch[1].cpu().numpy()
                    first_error = np.abs(rope_numpy[0] - rope_torch[0])
                    second_error = np.abs(rope_numpy[1] - rope_torch[1])

                    # The banner names this test (it previously named the 2D sincos test).
                    print(f"===== testing {test__get_2d_rotary_pos_embed.__name__}({embed_dim=}, {crops_coords=}, {grid_size=}, {device=}) =====")
                    print(first_error.max())
                    print(second_error.max())
                    print(first_error.sum())
                    print(second_error.sum())
                    print("==========")
                    assert np.allclose(rope_numpy, rope_torch, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_3d_rotary_pos_embed():
    """Check the torch ``aryan_get_3d_rotary_pos_embed`` against the reference.

    Sweeps embedding dim, crop coordinates, grid size and temporal size. Both
    functions return a pair of arrays/tensors; each member of the pair is
    compared separately in the printed diagnostics. The absolute tolerance is
    relaxed to 1e-6 here (see the note in the body).
    """
    for embed_dim in [256, 1024]:
        for crops_coords in [[(0, 0), (8, 8)], [(0, 0), (16, 16)]]:
            for grid_size in [(64, 64), (128, 64), (64, 128)]:
                for temporal_size in [8, 16]:
                    shared_kwargs = {
                        "embed_dim": embed_dim,
                        "crops_coords": crops_coords,
                        "grid_size": grid_size,
                        "temporal_size": temporal_size,
                    }
                    rope_numpy = get_3d_rotary_pos_embed(**shared_kwargs)

                    for device in ["cpu", "cuda"]:
                        rope_torch = aryan_get_3d_rotary_pos_embed(**shared_kwargs, device=device)

                        # NOTE/TODO: not sure why this needs a higher tolerance even
                        # though all the operations are similar-ish to the 2D case:
                        # - it is exactly zero on CPU
                        # - but on CUDA, it is slightly numerically different

                        rope_torch = rope_torch[0].cpu().numpy(), rope_torch[1].cpu().numpy()
                        first_error = np.abs(rope_numpy[0] - rope_torch[0])
                        second_error = np.abs(rope_numpy[1] - rope_torch[1])

                        # The banner names this test (it previously named the 3D sincos test).
                        print(f"===== testing {test__get_3d_rotary_pos_embed.__name__}({embed_dim=}, {crops_coords=}, {grid_size=}, {temporal_size=}, {device=}) =====")
                        print(first_error.max())
                        print(second_error.max())
                        print(first_error.sum())
                        print(second_error.sum())
                        print("==========")
                        assert np.allclose(rope_numpy, rope_torch, rtol=1e-12, atol=1e-6)


# Run every comparison in sequence; each call raises AssertionError on the
# first configuration whose numpy and torch results disagree.
test__get_1d_sincos_pos_embed_from_grid()
test__get_2d_sincos_pos_embed_from_grid()
test__get_2d_sincos_pos_embed()
test__get_3d_sincos_pos_embed()
test__get_2d_rotary_pos_embed()
test__get_3d_rotary_pos_embed()
# float32 machine epsilon, for reference when judging the printed differences.
import numpy as np
assert np.finfo(np.float32).eps == np.float32(1.1920929e-07)

Most of the numerical differences seen when comparing a numpy array with a tensor converted to a numpy array are acceptable: they are within the magnitude of float32.eps (in fact, on the order of 1e-12, so these changes should be safe to make). They are sometimes even exactly 0 if you do the comparison in tensors instead of numpy. The only peculiar case is the 3D rope embeddings, which seem to require a much higher tolerance (1e-6) instead of 1e-12 to 1e-15. I'm trying to figure out why :(

Once we finalize and decide to go forward with this, I'll rename the functions (currently prefixed), remove the older numpy based functions, and pass the tensor device and all usage occurrences so that all sincos/rope embedding creation occurs on device directly.


For those who don't have access to the internal link and wonder why we need this change: using torch.compile leads to a graph break and cudaMemSync on models that create sincos/rope positional embeddings on the fly. This is due to creating numpy arrays, converting them to CPU PyTorch tensors, and then moving them to the accelerator device. It's usually not a problem when the embeddings are prepared inside the pipeline, but for multiresolution training and specific models, you have to create them on the fly. These changes ensure that we remain in tensor land and don't have to deal with numpy arrays anywhere in the execution path.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu @sayakpaul

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Seems to be in the right direction.

@@ -78,6 +78,53 @@ def get_timestep_embedding(
return emb


def aryan_get_3d_sincos_pos_embed(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could compare this and the numpy implementation on the same inputs and see if there's any divergence. Usually a good start to try to localize if there are any problems.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to match for all cases except the 3D sincos ones, based on my test code above. I'll give it a look again soon


def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.Tensor, np.ndarray, int],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case of an np.array, the device will be a CPU, right? Would we incur any device placement penalty for that in case we're on a non-CPU device?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If an np.array is passed, what you said is right. I did not want to redo this function because it was already mostly torch land, and now it supports the extra torch.Tensor case (previously, it only supported np.ndarray and int)



def aryan_get_2d_rotary_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor, use_real: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
assert embed_dim % 4 == 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could move to raise ValueError() here and elsewhere applicable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
device: Optional[torch.device] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the code, where we're calling these functions, I think it would be useful to always pass the right device. That way, we won't have to incur any device placement costs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, of course, sounds good!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants