In torch.distributed, how do I correctly average gradients across different GPUs?
The code below, modified from https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py, successfully makes use of both GPUs (which can be verified with nvidia-smi).
What I find hard to understand is whether average_gradients below is really the correct way to average the gradients of the two models on the two GPUs. In the code below, the two model = Net() instances created by the two processes are two models on two different GPUs, yet the line average_gradients(model) seems to only 'average' the gradients of the model on one GPU, not of the two models on the two GPUs.
So the question is: is the code below actually a correct way to average gradients across the two GPUs? If it is, how should the code be read and understood? If it is not, what is the correct way to average the gradients of the two models below?
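To make explicit what I mean by 'averaging the gradients of the two models': if both models lived in a single process, I would naively write something like the following (just an illustration; the function name and variables are made up by me):

def naive_average_two_models(model_a, model_b):
    # toy illustration only: average gradients parameter-by-parameter
    # between two models that are both visible in the same process
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        avg = (p_a.grad + p_b.grad) / 2.0
        p_a.grad.copy_(avg)
        p_b.grad.copy_(avg)

I don't see how average_gradients in the script, which only ever receives the single local model, achieves the same thing. Here is the full script: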
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from math import ceil
from random import Random
from torch.multiprocessing import Process
from torchvision import datasets, transforms
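# make both GPUs visible; each spawned process later moves its model to cuda:<rank>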
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
class Partition(object):
    """ Dataset-like object, but only accesses a subset of it. """

    def __init__(self, data, index):
        self.data = data
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]
class DataPartitioner(object):
    """ Partitions a dataset into different chunks. """

    def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
        self.data = data
        self.partitions = []
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = [x for x in range(0, data_len)]
        rng.shuffle(indexes)
        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[0:part_len])
            indexes = indexes[part_len:]

    def use(self, partition):
        return Partition(self.data, self.partitions[partition])
class Net(nn.Module):
    """ Network architecture. """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
def partition_dataset():
    """ Partitioning MNIST """
    dataset = datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]))
    size = dist.get_world_size()
    bsz = int(256 / float(size))
    partition_sizes = [1.0 / size for _ in range(size)]
    partition = DataPartitioner(dataset, partition_sizes)
    partition = partition.use(dist.get_rank())
    train_set = torch.utils.data.DataLoader(
        partition, batch_size=bsz, shuffle=True)
    return train_set, bsz
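# This is the function my question is about: for every parameter it calls
# dist.all_reduce(SUM) on param.grad and then divides by the world size.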
def average_gradients(model):
    """ Gradient averaging. """
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
def run(rank, size):
    """ Distributed Synchronous SGD Example """
    # print("107 size = ", size)
    # print("dist.get_world_size() = ", dist.get_world_size())  ## 2
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset()
    device = torch.device("cuda:{}".format(rank))
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(10):
        epoch_loss = 0.0
        for data, target in train_set:
            # data, target = Variable(data), Variable(target)
            # data, target = Variable(data.cuda(rank)), Variable(target.cuda(rank))
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            average_gradients(model)
            optimizer.step()
        print('Rank ',
              dist.get_rank(), ', epoch ', epoch, ': ',
              epoch_loss / num_batches)
        # if epoch == 4:
        #     from utils import module_utils
        #     module_utils.save_model()
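# Both processes join the same process group (gloo backend, rendezvous at 127.0.0.1:29500) before run() is called.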
def init_processes(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()