Problem caching instances of torch modules and datasets #2339

Open · gavril0 opened this issue Apr 17, 2024 · 8 comments

gavril0 commented Apr 17, 2024

Caching chunks that create an instance of a torch module or a torch dataset yields an `external pointer is not valid` error when the instance is used in another chunk.

Example with torch module:

```{r, cache=TRUE}
library(torch)
lin <- nn_linear(2, 3)
# torch_save(lin, "lin.pt")
```

```{r}
# lin <- torch_load("lin.pt")
x <- torch_randn(2)
lin$forward(x)
```

Example with torch dataset:

```{r, cache=TRUE}
ds_gen <- dataset(
  initialize = function() {
    self$x <- torch_tensor(1:10, dtype = torch_long())
  },
  .getitem = function(index) {
    self$x[index]
  },
  .length = function() {
    length(self$x)
  }
)

ds <- ds_gen()
```

```{r}
ds[1:3]
```

If there is no cache, the chunks execute without problems. However, once a cache exists, an error occurs when trying to access the cached instance of the module or dataset:

    Error in cpp_tensor_dim(x$ptr) : external pointer is not valid

This might be due to the fact that the R torch package relies on reference classes (R6 and/or R7) and could be related to issue #2176. In any case, caching would be useful for trained instances of a module and for datasets whose initialization involves a lot of processing.
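For what it's worth, I believe the error can be reproduced outside knitr with plain R serialization, which is what the knitr cache relies on under the hood. A minimal sketch, assuming the torch package is installed:

```r
library(torch)

x <- torch_tensor(1:3)
saveRDS(x, "x.rds")    # serializes the R wrapper; the underlying C++ pointer is not portable
y <- readRDS("x.rds")  # the restored object holds an invalid external pointer
y + 1                  # Error: external pointer is not valid
```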

At the moment, the only alternative is to save the torch model in the cached chunk with torch_save and load it in the uncached chunk with torch_load (see the comments in the chunks above). However, as far as I know, there is no method to save and load torch datasets.
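Spelled out, the workaround looks like this (a sketch assuming the working directory is writable; the file name lin.pt is arbitrary):

```{r, cache=TRUE}
# expensive chunk: build the module and persist it with torch's own serializer
lin <- nn_linear(2, 3)
torch_save(lin, "lin.pt")
```

```{r}
# uncached chunk: re-create the module from the file instead of the knitr cache
lin <- torch_load("lin.pt")
x <- torch_randn(2)
lin$forward(x)
```

When the first chunk is skipped on a later knit because of the cache, lin.pt still exists on disk from the first knit, so the second chunk keeps working.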

@cderv

This comment was marked as off-topic.

@gavril0

This comment was marked as off-topic.

cderv (Collaborator) commented Oct 16, 2024

Oh thanks for correcting me! 😓

I still wonder if #2340 could help here 🤔 @atusy, what is your take on this?

atusy (Collaborator) commented Oct 17, 2024

I'll take a look this weekend.

gavril0 (Author) commented Oct 17, 2024

@atusy Thanks.

Caching a simple torch tensor can also yield an error. For example, an error occurs if the second chunk below is added to the script after the cache has been created by knitting the script without it. The error also occurs if the second chunk is added without caching. However, the error does not occur if both chunks are present when the cache is created.

```{r, cache=TRUE, eval=TRUE}
x <- torch_tensor(1:10, dtype=torch_long())
# torch_save(x, "x.pt")
```

```{r, cache=TRUE, eval=FALSE}
# x <- torch_load("x.pt")
x[1] <- 2
x
```

For information, I think that torch modules, torch datasets, and torch dataloaders are implemented with R6, while torch tensors appear to be implemented with R7 (recently renamed S7), which creates tensors through the PyTorch C++ library via Rcpp.

One way to address the problem of caching objects implemented with reference classes is to explicitly load them in every chunk. The torch package offers functions (torch_save and torch_load) to serialize torch tensors and torch modules. However, as far as I know, these functions do not work for torch datasets and torch dataloaders. I have opened issue mlverse/torch#1199 in the torch repository.
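In the meantime, one possibility is to persist a dataset's tensor fields individually and rebuild the R6 object from its generator. A sketch, assuming the ds_gen generator from my first post (the field name x and the file name ds_x.pt are specific to that example):

```r
# in the cached chunk: save the dataset's serializable state
torch_save(ds$x, "ds_x.pt")

# in a later chunk: recreate the dataset, then restore the saved state
ds <- ds_gen()
ds$x <- torch_load("ds_x.pt")
```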

A question I have about the cache mechanism is whether data cached in a chunk stays in memory or must be reloaded in the next chunk. Apparently, the answer is not simple. The pragmatic approach might be to reload the cached object in every chunk (assuming there is a function for it), but that does not seem ideal since the objects that need to be saved can be quite big.

Update

I have proposed a strategy for caching torch objects in mlverse/torch#1199 (comment). It would not require any work on the knitr side: it relies on the existing cache mechanism only to skip a time-consuming chunk, and it requires the user to explicitly save and reload the torch object (I think reloading the cached object only needs to be done once).
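Concretely, the strategy might look like this (a sketch; fit.pt is an arbitrary file name, and nn_linear stands in for an expensive training step):

```{r, cache=TRUE}
# time-consuming chunk; skipped on later knits once the cache exists
fit <- nn_linear(2, 3)
torch_save(fit, "fit.pt")
```

```{r}
# single uncached chunk that restores a usable copy from the file;
# later chunks can then use `fit` directly
fit <- torch_load("fit.pt")
```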

atusy (Collaborator) commented Oct 17, 2024

@gavril0 @cderv

I confirmed #2340 solves this issue. Here is a reproducible example.

```{r}
library(knitr)

# torch package should implement this
registerS3method(
  "knit_cache_hook",
  "nn_module",
  function(x, nm, path) {
    # Cache the object
    d <- paste0(path, "__extra")
    dir.create(d, showWarnings = FALSE, recursive = TRUE)
    f <- file.path(d, paste0(nm, ".pt"))
    torch::torch_save(x, f)

    # Return loader function
    structure(function(...) torch::torch_load(f), class = "knit_cache_loader")
  },
  envir = asNamespace("knitr")
)
```
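If I read the PR right, knitr dispatches knit_cache_hook on the class of each object being cached; when the method returns a function of class knit_cache_loader, knitr stores that loader and calls it to re-create the object on cache restore, so the torch module is rebuilt from the .pt file rather than from R's serialization.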

```{r, cache=TRUE}
lin <- torch::nn_linear(2, 3)
```

```{r}
x <- torch::torch_randn(2)
lin$forward(x)
```

gavril0 (Author) commented Oct 17, 2024

@atusy

That is great! I will refer to your comment in the issue that I raised in the torch repo to prompt the maintainers to include these hooks in their package. Thank you.

atusy (Collaborator) commented Oct 17, 2024

@gavril0
You're welcome!
Just to let you know, PR #2340 is not yet merged.
The torch developers will have to wait for the merge before implementing the S3 method in their package.
