I want to train both a CNN and a simple linear feed-forward model in PyTorch and then, after training, enlarge them: add more filters to the CNN layers, more neurons to the linear layers, and more outputs to both (e.g. going from binary to multiclass classification). By "adding" I specifically mean keeping the trained weights unchanged and giving the new, incoming weights a random initialization.

There's an example of a CNN model here, and an example of a simple linear feed-forward model here.

1 Answer

This one is a bit tricky and relies on slice objects (see this answer for more about slice, though it should be intuitive, and this answer for the slicing trick). The comments in the code explain each step:

import torch
    
def expand(
    original: torch.nn.Module,
    *args,
    **kwargs
    # Add other arguments if needed, like different stride
    # They won't change weights shape, but may change behaviour
):
    new = type(original)(*args, **kwargs)

    # Shapes as plain tuples of ints. Weight and a 1D bias are assumed to
    # exist; add checks if your layer was created with bias=False
    new_weight_shape = tuple(new.weight.shape)
    original_weight_shape = tuple(original.weight.shape)
    original_bias_size = original.bias.shape[0]

    # Quick check that the new layer is at least as large in every dimension
    assert len(new_weight_shape) == len(original_weight_shape)
    assert all(n >= o for n, o in zip(new_weight_shape, original_weight_shape))
    assert new.bias.shape[0] >= original_bias_size

    # Copy without tracking gradients. Trained values fill the leading block
    # of each tensor; the remaining entries keep their random initialization
    with torch.no_grad():
        new.bias[:original_bias_size] = original.bias

        # Create slices 0:n for each dimension
        slicer = tuple(slice(0, dim) for dim in original_weight_shape)
        # And copy the trained weights in
        new.weight[slicer] = original.weight

    return new


layer = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)

new = expand(layer, in_channels=32, out_channels=64, kernel_size=3)

This should work for any layer that has a weight and a bias (adjust if needed). With this approach you can rebuild your neural network layer by layer, or use PyTorch's apply (docs here).
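
To make the "rebuild your network" idea concrete, here is a minimal sketch for a small feed-forward model. The toy model, the `new_sizes` mapping and the layer sizes are all assumptions made for illustration, and `expand` is restated in compact form only so the snippet runs on its own:

```python
import torch
from torch import nn


def expand(original: nn.Module, *args, **kwargs) -> nn.Module:
    """Compact restatement of expand: copy trained values into a larger layer."""
    new = type(original)(*args, **kwargs)
    with torch.no_grad():
        slicer = tuple(slice(0, dim) for dim in original.weight.shape)
        new.weight[slicer] = original.weight
        new.bias[: original.bias.shape[0]] = original.bias
    return new


# Hypothetical "trained" binary classifier: 10 features -> 8 hidden -> 2 classes
old_model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))

# Target sizes for each Linear layer, keyed by its position in the Sequential:
# widen the hidden layer to 16 neurons and the output to 5 classes
new_sizes = {
    0: dict(in_features=10, out_features=16),
    2: dict(in_features=16, out_features=5),
}

# Expand the listed layers, reuse everything else (here: the ReLU) unchanged
new_model = nn.Sequential(*[
    expand(layer, **new_sizes[i]) if i in new_sizes else layer
    for i, layer in enumerate(old_model)
])

# The trained values sit in the leading block of each expanded layer
assert torch.equal(new_model[0].weight[:8, :10], old_model[0].weight)
assert torch.equal(new_model[2].weight[:2, :8], old_model[2].weight)
assert torch.equal(new_model[2].bias[:2], old_model[2].bias)

out = new_model(torch.randn(4, 10))  # batch of 4 samples
print(out.shape)
```

Note that when a layer gets more outputs, the next layer needs correspondingly more inputs (16 in both places above), and because the new connections are randomly initialized the expanded model will generally not reproduce the old model's predictions until it is trained further.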

Also remember that you have to explicitly pass the construction *args and **kwargs for the new layer, which will then have the trained connections copied into it.
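
If repeating the construction arguments by hand feels error-prone, one option is to read them back from the original layer. Below is a rough sketch for `Conv2d` only; `widen_conv2d` is a hypothetical helper name, not part of PyTorch, and it assumes `groups == 1` so that the weight has shape `(out_channels, in_channels, kH, kW)`:

```python
import torch
from torch import nn


def widen_conv2d(original: nn.Conv2d, in_channels: int, out_channels: int) -> nn.Conv2d:
    """Hypothetical helper: build a wider Conv2d that reuses the original's settings."""
    assert original.groups == 1, "sketch only covers ungrouped convolutions"
    assert in_channels >= original.in_channels
    assert out_channels >= original.out_channels

    # Hyperparameters are read from attributes that nn.Conv2d stores on itself
    new = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=original.kernel_size,
        stride=original.stride,
        padding=original.padding,
        dilation=original.dilation,
        bias=original.bias is not None,
        padding_mode=original.padding_mode,
    )

    with torch.no_grad():
        # Trained filters go into the leading block; the rest stays random
        new.weight[: original.out_channels, : original.in_channels] = original.weight
        if original.bias is not None:
            new.bias[: original.out_channels] = original.bias
    return new


layer = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
wider = widen_conv2d(layer, in_channels=32, out_channels=64)
print(wider.stride, wider.padding, tuple(wider.weight.shape))
```

Only the channel counts have to be supplied here; stride, padding and the rest follow the trained layer, so the behaviour-changing arguments cannot drift by accident.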