In torch.distributed, how to average gradients on different GPUs correctly?
Modified from https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py, the code below successfully makes use of both GPUs (this can be checked with nvidia-smi).
But one thing that is difficult to understand is whether 'average_gradients' below is indeed the correct way of averaging the gradients of the two models on the two GPUs. In the code below, each of the two processes creates its own 'model = Net()', so there are two models on two different GPUs, but the line 'average_gradients(model)' appears to just 'average' the gradients of the model on one GPU, not of the two models on the two GPUs.
The question is: is the code below indeed a correct way of averaging gradients across the two GPUs? If so, how should the code be read and understood? If not, what is the correct way of averaging the gradients of the two models below?
import os
from math import ceil
from random import Random

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.multiprocessing import Process
from torchvision import datasets, transforms

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Total number of samples processed per optimisation step, summed over all
# ranks; each rank gets an equal share of it.
GLOBAL_BATCH_SIZE = 256


class Partition:
    """Dataset-like view that exposes only a subset of an underlying dataset."""

    def __init__(self, data, index):
        self.data = data
        # Positions in ``data`` that belong to this partition.
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]


class DataPartitioner:
    """Split a dataset into disjoint, randomly shuffled chunks.

    Args:
        data: Any object supporting ``len()`` and integer indexing.
        sizes: Fraction of the dataset assigned to each chunk.
        seed: Seed for the shuffle. Every process that uses the same seed
            computes the same split.
    """

    def __init__(self, data, sizes=(0.7, 0.2, 0.1), seed=1234):
        self.data = data
        self.partitions = []
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = list(range(data_len))
        rng.shuffle(indexes)

        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[:part_len])
            indexes = indexes[part_len:]

    def use(self, partition):
        """Return the ``Partition`` for chunk number ``partition``."""
        return Partition(self.data, self.partitions[partition])


class Net(nn.Module):
    """Small convolutional classifier: two conv layers, then two linear layers."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # After the view above x is (N, 320) and fc2 maps it to (N, 10), so
        # the class scores live on dim 1. Passing dim explicitly avoids the
        # deprecated implicit-dimension behaviour of log_softmax.
        return F.log_softmax(x, dim=1)


def partition_dataset():
    """Build this rank's MNIST loader.

    The training set is split into ``world_size`` equal chunks and the chunk
    matching the calling process's rank is wrapped in a ``DataLoader``.

    Returns:
        A ``(loader, batch_size)`` tuple, where ``batch_size`` is the
        per-rank batch size.
    """
    dataset = datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ]))
    size = dist.get_world_size()
    bsz = GLOBAL_BATCH_SIZE // size
    partition_sizes = [1.0 / size for _ in range(size)]
    partition = DataPartitioner(dataset, partition_sizes)
    partition = partition.use(dist.get_rank())
    train_set = torch.utils.data.DataLoader(
        partition, batch_size=bsz, shuffle=True)
    return train_set, bsz


def average_gradients(model):
    """Average the gradients of ``model`` across every process in the group.

    Each process calls this function on its *own* model replica.
    ``dist.all_reduce`` is a collective operation: every rank contributes its
    local gradient tensor, and when the call returns, the tensor on every
    rank has been overwritten in place with the element-wise sum over all
    ranks. Dividing by the world size turns that sum into the mean, so after
    this function all replicas hold the same averaged gradients even though
    the code only ever names the local ``model``.
    """
    size = float(dist.get_world_size())
    for param in model.parameters():
        # Parameters that took no part in the backward pass have no gradient.
        # Every rank must skip the same parameters, otherwise the collective
        # calls would no longer line up across processes.
        if param.grad is None:
            continue
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size


def run(rank, size):
    """Distributed synchronous SGD on MNIST, one process per GPU.

    Args:
        rank: Rank of this process; also used as the CUDA device index.
        size: Number of participating processes (not read here; kept so the
            signature matches what ``init_processes`` calls).
    """
    # The same seed is set in every process before the model is built.
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset()
    device = torch.device("cuda:{}".format(rank))
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(10):
        epoch_loss = 0.0
        for data, target in train_set:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            # Swap the local gradients for the cross-rank mean before the
            # optimizer consumes them.
            average_gradients(model)
            optimizer.step()
        print('Rank ', dist.get_rank(), ', epoch ', epoch, ': ',
              epoch_loss / num_batches)


def init_processes(rank, size, fn, backend='gloo'):
    """Initialize the distributed environment, then call ``fn(rank, size)``.

    Args:
        rank: Rank of the calling process.
        size: Total number of processes in the group.
        fn: Callable invoked as ``fn(rank, size)`` once the group is ready.
        backend: ``torch.distributed`` backend name.
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()