torch.utils.checkpoint
torch.utils.checkpoint.checkpoint(function, *args, **kwargs)[source]

Checkpoint a model or part of the model.

Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does not save intermediate activations, and instead recomputes them in the backward pass. It can be applied to any part of a model.

Specifically, in the forward pass, function will run in torch.no_grad() manner, i.e., not storing the intermediate activations. Instead, the forward pass saves the inputs tuple and the function parameter. In the backward pass, the saved inputs and function are retrieved, and the forward pass is computed on function again, now tracking the intermediate activations, and then the gradients are calculated using these activation values.

Warning
Checkpointing doesn’t work with torch.autograd.grad(), but only with torch.autograd.backward().

Warning

If the function invocation during backward does anything different from the one during forward, e.g., due to some global variable, the checkpointed version won’t be equivalent, and unfortunately it can’t be detected.

Warning

If the checkpointed segment contains tensors detached from the computational graph by detach() or torch.no_grad(), the backward pass will raise an error. This is because checkpoint makes all the outputs require gradients, which causes issues when a tensor is defined to have no gradient in the model. To circumvent this, detach the tensors outside of the checkpoint function, as in the sketch below.
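A minimal sketch of that workaround, assuming an illustrative model in which one sub-module (frozen below) must not receive gradients; the detach() happens before checkpoint() is called, so the checkpointed function only ever sees an already-detached tensor:

>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint
>>> frozen = nn.Linear(10, 10)        # illustrative sub-module whose output is detached
>>> trainable = nn.Linear(10, 10)
>>> x = torch.randn(2, 10, requires_grad=True)
>>> frozen_out = frozen(x).detach()   # detach outside of the checkpointed function
>>> def run_segment(inp, aux):
...     return trainable(inp) + aux   # aux arrives already detached
>>> out = checkpoint(run_segment, x, frozen_out)
>>> out.sum().backward()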
- Parameters

function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in LSTM, if a user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden.

preserve_rng_state (bool, optional, default=True) – if set to False, stashing and restoring the RNG state during each checkpoint is omitted.

args – tuple containing inputs to the function
- Returns

Output of running function on *args
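For example, checkpointing only the middle block of a model might look like the following sketch (the module structure and layer sizes are purely illustrative):

>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.head = nn.Linear(20, 20)
...         self.body = nn.Sequential(nn.Linear(20, 20), nn.ReLU(), nn.Linear(20, 20))
...         self.tail = nn.Linear(20, 2)
...     def forward(self, x):
...         x = self.head(x)
...         # activations inside self.body are not stored here; they are
...         # recomputed when backward reaches this segment
...         x = checkpoint(self.body, x)
...         return self.tail(x)
>>> net = Net()
>>> out = net(torch.randn(4, 20))
>>> out.sum().backward()   # use backward(), not torch.autograd.grad()

Because only the inputs to self.body are saved, its internal activations are recomputed once during the backward pass, trading extra compute for lower peak memory.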
torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, **kwargs)[source]

A helper function for checkpointing sequential models.

Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model into various segments and checkpoint each segment. All segments except the last will run in torch.no_grad() manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.

See checkpoint() on how checkpointing works.

Warning
Checkpointing doesn’t work with torch.autograd.grad(), but only with torch.autograd.backward().

- Parameters

functions – A torch.nn.Sequential or the list of modules or functions (comprising the model) to run sequentially.

segments – Number of chunks to create in the model

input – A Tensor that is the input to functions

preserve_rng_state (bool, optional, default=True) – if set to False, stashing and restoring the RNG state during each checkpoint is omitted.

- Returns

Output of running functions sequentially on *input
Example
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
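A slightly more complete sketch along the same lines (the layer sizes and the choice of two segments are arbitrary):

>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint_sequential
>>> model = nn.Sequential(
...     nn.Linear(100, 50), nn.ReLU(),
...     nn.Linear(50, 20), nn.ReLU(),
...     nn.Linear(20, 5))
>>> input_var = torch.randn(1, 100, requires_grad=True)
>>> # split the five modules into two chunks; all but the last chunk run
>>> # without storing activations and are re-run during the backward pass
>>> out = checkpoint_sequential(model, 2, input_var)
>>> out.sum().backward()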