3
votes

Feature visualization in TensorFlow or Keras is easy, and is described here: https://machinelearningmastery.com/how-to-visualize-filters-and-feature-maps-in-convolutional-neural-networks/ or in Convolutional Neural Network visualization - weights or activations?

How do I do this in PyTorch?

I am using PyTorch with a pretrained resnet18 model. All I need is to input an image and get the activations of a specific layer (e.g. Layer2.0.conv2). Layer2.0.conv2 is the name of the layer in the pretrained model.

In simple words: how do I convert the code from the first link to PyTorch? How do I get specific layers of resnet18 in PyTorch, and how do I get their activations for an input image? I tried this in TensorFlow and it worked, but I could not get it to work in PyTorch.

1
You can access layer weight tensors directly, e.g. layer2[0].conv2.weight. You can copy the code from the first link you posted almost as-is and just substitute the appropriate torch functions. - KDecker

1 Answer

4
votes

You would have to register PyTorch hooks on the specific layer. See this tutorial for an introduction to hooks.

Basically, a hook lets you capture the input/output of the forward/backward pass of a torch.nn.Module. The whole thing can be a bit complicated, so there is a library with a goal similar to yours (disclaimer: I'm the author) called torchfunc. In particular, torchfunc.hooks.recorders lets you do what you want; see the code snippet and comments below:

import torch
import torchvision
import torchfunc

my_network = torchvision.models.resnet18(pretrained=True)
# Recorder saving inputs to all submodules
recorder = torchfunc.hooks.recorders.ForwardPre()

# Will register hook for all submodules of resnet18
# You could specify some submodules by index or by layer type, see docs
recorder.modules(my_network)

# Push example image through network
my_network(torch.randn(1, 3, 224, 224))

You could also register the recorder only for some layers (submodules), specified by index or layer type. To retrieve the recorded data, run:

# First recorded input (index 0) going into the submodule at index 3 of this network
recorder.data[3][0]

# You can see all submodules and their positions by running this:
for i, submodule in enumerate(my_network.modules()):
    print(i, submodule)

# Or you can just print the network to get this info
print(my_network)