Week 9: Capturing Intermediate Layer Outputs From PyTorch Models
Welcome to Week 9 of my Senior Project Blog! Progress has been a bit slower this week as I’ve been under the weather, but I’ve identified a viable approach to extract the intermediate output of any layer of a PyTorch model.
In PyTorch, it is somewhat difficult to cleanly extract intermediate activations from the layers of a model, whether for visualization, debugging, or running further algorithms (classification, feature extraction, dimensionality reduction, etc.) on those intermediate outputs.
The best approach provided natively in PyTorch is hooks. Hooks are ubiquitous in software and are generally functions that execute after a particular event happens: for example, a website displays an ad after you visit a specific page, or a banking app sends a notification when funds are credited to or debited from an account.
In PyTorch, a hook is a callable with a predefined signature that you define and register on any neural network module (nn.Module); it is then called every time that module runs. During the forward or backward pass, the module itself, its inputs, and (where applicable) its outputs are passed to the hook, which executes before the next module runs. To capture each intermediate layer’s inputs and outputs for our Autoencoder models, the forward hook lets us intercept the forward pass at a specific module and process its inputs and outputs. The forward hook function has the following signature:
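A minimal sketch of that signature in action (the layer and hook names here are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

# A forward hook receives the module, a tuple of its inputs, and its output.
def forward_hook(module, inputs, output):
    print(f"{module.__class__.__name__}: output shape {tuple(output.shape)}")

# Illustration on a single linear layer:
layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(forward_hook)
layer(torch.randn(1, 4))   # the hook fires during this forward pass
handle.remove()            # detach the hook when done
```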
To understand how the forward hook approach works for a generic PyTorch model, and to show the results more visually, I experimented with a model commonly used in image classification called ResNet34, a 34-layer convolutional neural network pre-trained on the ImageNet dataset of over a million labeled images across 1,000 classes. My goal was to take an image, run it through the pre-trained ResNet34 model, and visualize what each intermediate layer of the model sees by using forward hooks. Since the hook approach is generic, if it works for the ResNet model with an image input, I should be able to extract the output tensors of any layer of our lab’s autoencoder models.
First, I defined a generic object to save the outputs of any module:
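The original snippet isn’t reproduced here; a minimal sketch of such an object (the name SaveOutput follows a common convention from the forward-hook tutorials cited below, and is an assumption) might look like:

```python
import torch
import torch.nn as nn

class SaveOutput:
    """Forward-hook callable that stores the output of every module it is attached to."""
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # detach() so the stored tensors don't keep the autograd graph alive
        self.outputs.append(module_out.detach())

    def clear(self):
        self.outputs = []
```

Because the object is a plain callable, the same instance can be registered on many layers, collecting their outputs in order of execution.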
Then I registered the forward hook for each of the layers as shown below.
(Note: with this generic approach, I will only need to change one line — the instance-type check — to match the layer types used in our Autoencoder models!)
To see how each layer of ResNet34 responds to human-created art with its layers and textures, I chose my favorite painting, Van Gogh’s “Starry Night”.
After running this image through the ResNet34 model with the forward hooks registered above, I was able to visualize the output of every intermediate layer.
Output of each ResNet layer with Van Gogh’s “Starry Night” as input
From the above, we can observe how the outputs of different layers capture different aspects of the original painting: some layers capture the finer details, some highlight the objects in the sky and foreground, and some capture the textures of the background.
Next week, I’ll use forward hooks to extract the latent representations and embeddings of each CAD tile model using the input VCF files that I identified in past weeks.
Thank you for reading, and see you next week!
- Bhaskhar, Nandita. “Intermediate Activations – the Forward Hook.” Nandita Bhaskhar / Stanford University, 17 Aug. 2020, https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/.
- “Forward and Backward Function Hooks.” nn Package – PyTorch Tutorials Documentation, PyTorch, https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks.
- He, Kaiming, et al. “Deep Residual Learning for Image Recognition.” ArXiv.org, 10 Dec. 2015, https://doi.org/10.48550/arXiv.1512.03385.
- Kathuria, Ayoosh. “Debugging and Visualisation in Pytorch Using Hooks.” Paperspace Blog, 9 Apr. 2021, https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/.
- “The Starry Night.” Encyclopædia Britannica, Inc., https://www.britannica.com/topic/The-Starry-Night.
- Danka, Tivadar. “Hooks: The One PyTorch Trick You Must Know.” Mathematics of Machine Learning, 27 Apr. 2022, https://tivadardanka.com/blog/hooks-the-one-pytorch-trick-you-must-know.