Errors when caching torch objects in rmarkdown #1199

Open
gavril0 opened this issue Oct 17, 2024 · 2 comments
gavril0 commented Oct 17, 2024

Caching torch objects in an R Markdown document gives errors (see the issue in the knitr repository for more details). The underlying problem is that torch objects are implemented with classes (R6/Rcpp) that use reference semantics.

A possible solution is to save and reload the object in each cached chunk. However, `torch_save()` and `torch_load()` work only with torch tensors and torch modules; as far as I know, they don't work with torch datasets and dataloaders. Moreover, torch objects can be very large, which makes this solution not ideal.

Given the complexity of knitr's caching mechanism, it would be great to have some guidelines for torch users (see also here and here).


gavril0 commented Oct 17, 2024

A possible strategy for caching objects with torch might be:

  1. cache only the chunk(s) with time-consuming operations, such as training a torch model or creating a torch dataset (which might involve quite a bit of preprocessing). Save the module or dataset at the end of the chunk.

  2. create an uncached chunk that loads the saved module or dataset.

  3. the following chunks do not need to load the object and do not need to be cached (assuming they are not time-consuming).

    ```{r, cache=TRUE}
    # create model (inherits from nn_module)
    model <- my_torch_model()
    # train model
    
    # save model
    torch_save(model, "my_model.pt")
    ```
    
    ```{r}
    # load trained model
    model <- torch_load("my_model.pt")
    ```
    
    ```{r}
    # use the model
    model$forward(x)
    ```
    

This scheme automatically skips training once the chunk is cached (as long as the chunk is not changed). Dependencies on other cached chunks (e.g. chunks that create the dataset used to train the model) can be added explicitly with the `dependson` option.
It could also be used with torch datasets if there were functions to save and reload torch datasets.
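For the common case of a dataset that simply wraps tensors, one workaround is to cache the underlying tensors rather than the dataset object itself. Here is a minimal sketch, assuming the dataset is built with `tensor_dataset()`; the file names and the variables `x` and `y` are illustrative, not part of any API:

```r
library(torch)

# Cached chunk: build the dataset once (possibly after expensive
# preprocessing) and save its underlying tensors, which torch_save()
# does support.
x <- torch_randn(100, 3)   # placeholder for real, preprocessed data
y <- torch_randn(100, 1)
torch_save(x, "ds_x.pt")
torch_save(y, "ds_y.pt")

# Uncached chunk: rebuild the dataset from the saved tensors, so later
# chunks get a fresh object with valid external pointers.
ds <- tensor_dataset(x = torch_load("ds_x.pt"),
                     y = torch_load("ds_y.pt"))
dl <- dataloader(ds, batch_size = 32)
```

This sidesteps the missing save/load functions for datasets, but only works when the dataset's state reduces to tensors; a custom `dataset()` with other state would need its own serialization.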


gavril0 commented Oct 17, 2024

Thanks to @atusy, the knitr package will offer a new mechanism to define hook functions to cache objects (yihui/knitr#2340). This mechanism can be used to cache torch modules (see yihui/knitr#2339 (comment)). Note that the proposed hook function for caching relies on `torch_save()` and `torch_load()`, which work only with torch tensors and torch modules, but similar hook functions can be defined for datasets provided there are functions to save and load them.

My understanding is that these functions are best included in the package (see yihui/knitr#2340 (comment)) so that caching works transparently for the end user.
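Whatever the final hook API looks like (it is defined in yihui/knitr#2340 and not reproduced here), the torch side reduces to a save/load pair that the hook can delegate to. A hypothetical sketch, with illustrative function and file names:

```r
library(torch)

# Hypothetical save/load pair for a knitr cache hook to call.
# `path` is a cache file prefix supplied by knitr; the ".pt" suffix
# is an arbitrary choice here.
save_torch_module <- function(x, path) {
  torch_save(x, paste0(path, ".pt"))
}

load_torch_module <- function(path) {
  torch_load(paste0(path, ".pt"))
}
```

If torch shipped such a pair for modules (and, once serialization exists, for datasets), knitr's hook mechanism could pick them up and caching would work transparently for the end user, as suggested above.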
