This one was a bit tricky and requires `slice` (see this answer for more info about `slice`, but it should be intuitive). Also see this answer for the slice trick. Please see the comments for an explanation:
import torch


def expand(
    original: torch.nn.Module,
    *args,
    **kwargs
    # Add other arguments if needed, like a different stride
    # They won't change the weights' shape, but may change behaviour
):
    new = type(original)(*args, **kwargs)

    new_weight_shape = torch.tensor(new.weight.shape)
    new_bias_shape = torch.tensor(new.bias.shape)

    original_weight_shape = torch.tensor(original.weight.shape)
    original_bias_shape = torch.tensor(original.bias.shape)
    # I assume bias and weight exist; if not, do some checks
    # Also a quick check that the new layer is at least as large as the original
    assert torch.all(new_weight_shape >= original_weight_shape)
    assert torch.all(new_bias_shape >= original_bias_shape)
    # Weights are copied in starting at index 0 of every dimension; bias is assumed 1D
    new.bias.data[: int(original_bias_shape[0])] = original.bias.data

    # Create slices 0:n for each dimension
    slicer = tuple(slice(0, int(dim)) for dim in original_weight_shape)
    # And copy in the data
    new.weight.data[slicer] = original.weight.data

    return new


layer = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
new = expand(layer, in_channels=32, out_channels=64, kernel_size=3)
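To see what the slice tuple does, here is a minimal sketch of the same trick on plain tensors, followed by a sanity check on a convolution. The names `small`, `large`, `corner`, `old` and `padded` are made up for illustration, and the check assumes the extra input channels are fed zeros, so the randomly initialised part of the larger weight does not contribute to the first 32 output channels.

```python
import torch

# Slice trick in isolation: one slice(0, n) per dimension of the smaller
# tensor, used as a tuple index into the larger tensor.
small = torch.arange(6.0).reshape(2, 3)
large = torch.zeros(4, 5)
corner = tuple(slice(0, n) for n in small.shape)  # (slice(0, 2), slice(0, 3))
large[corner] = small  # same as large[0:2, 0:3] = small

# The same copy applied to a convolution's parameters
old = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
new = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
with torch.no_grad():
    new.weight[tuple(slice(0, n) for n in old.weight.shape)] = old.weight
    new.bias[: old.bias.shape[0]] = old.bias

# Original input goes into the first 16 channels, zeros into the rest
x = torch.randn(2, 16, 8, 8)
padded = torch.cat([x, torch.zeros(2, 16, 8, 8)], dim=1)
with torch.no_grad():
    # The first 32 output channels should match the old layer's output
    same = torch.allclose(new(padded)[:, :32], old(x), atol=1e-4)
```

The remaining output channels, and any non-zero values in the extra input channels, go through freshly initialised weights, so the expanded layer only matches the old one on this restricted case.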
This should work for any layer that has `weight` and `bias` (adjust if needed). Using this approach you can recreate your neural network, or use PyTorch's `apply` (docs here).
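As a rough sketch of recreating a network this way, the snippet below walks a model's children and swaps every `torch.nn.Linear` for a wider copy. The helper `widen_linears` and its `factor` parameter are hypothetical, `expand` is repeated in compact form so the block runs on its own, and the sketch uses `named_children` with `setattr` rather than `apply`. It assumes every replaced layer has a bias.

```python
import torch


def expand(original, *args, **kwargs):
    # Compact version of the function above: build the new layer, then
    # copy the old parameters into its leading corner.
    new = type(original)(*args, **kwargs)
    with torch.no_grad():
        corner = tuple(slice(0, n) for n in original.weight.shape)
        new.weight[corner] = original.weight
        new.bias[: original.bias.shape[0]] = original.bias
    return new


def widen_linears(model, factor=2):
    # Hypothetical helper: replace each Linear child with one that has
    # `factor` times as many input and output features.
    for name, child in list(model.named_children()):
        if isinstance(child, torch.nn.Linear):
            wider = expand(child, child.in_features * factor, child.out_features * factor)
            setattr(model, name, wider)
        else:
            widen_linears(child, factor)  # recurse into containers
    return model


net = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)
first = net[0].weight.detach().clone()
widen_linears(net)
out = net(torch.randn(3, 8))  # the input width doubled along with the layers
```

Note that this widens the network's input and output sizes as well; in practice you would probably pick the new sizes per layer instead of using a single factor.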
Also remember that you have to explicitly pass the constructor `*args` and `**kwargs` for the new layer, into which the trained connections are copied.
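A small illustration of why this matters: since the new layer is built only from the arguments you pass, settings such as stride or padding are not inherited from the original. The variable names `forgot` and `kept` are made up for this sketch, which builds the layers directly to show the effect.

```python
import torch

layer = torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)

# Nothing is read from `layer` here, so stride and padding fall back to
# Conv2d's defaults (stride 1, padding 0).
forgot = torch.nn.Conv2d(32, 64, kernel_size=3)

# Passing the original's settings explicitly keeps the behaviour aligned.
kept = torch.nn.Conv2d(
    32,
    64,
    kernel_size=layer.kernel_size,
    stride=layer.stride,
    padding=layer.padding,
)

x = torch.randn(1, 32, 8, 8)
print(forgot(x).shape)  # torch.Size([1, 64, 6, 6])
print(kept(x).shape)    # torch.Size([1, 64, 4, 4])
```

The weight shapes of `forgot` and `kept` are identical, so copying the trained weights would succeed in both cases; only the layer's behaviour differs.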