if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Parameter-Server RPC based training")
    parser.add_argument("--world_size", type=int, default=4, help="Total number of participating processes. Should be the sum of the master node and all training nodes; add 1 if also creating a training node on the master.")
    parser.add_argument("--rank", type=int, default=None, help="Global rank of this process. Pass in 0 for the master. Note that ranks must be unique across all nodes participating in training.")
    parser.add_argument("--master_addr", type=str, default="localhost", help="Address of the master; defaults to localhost if not provided. The master must be able to accept network traffic on this address and port.")
    parser.add_argument("--master_port", type=str, default="29500", help="Port the master is listening on; defaults to 29500 if not provided. The master must be able to accept network traffic on this host and port.")
    args = parser.parse_args()

    assert args.rank is not None, "must provide rank argument."
    os.environ["MASTER_ADDR"] = args.master_addr
    os.environ["MASTER_PORT"] = args.master_port
    # Get data to train on
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST("../data", train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=32, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST("../data", train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=32, shuffle=True)
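
    # Note: the data loaders are only handed to the trainer processes below;
    # the parameter server (rank 0) does not consume any training data itself.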
    processes = []
    world_size = args.world_size
    if args.rank == 0:
        # start the parameter server on the master node
        p = mp.Process(target=run_parameter_server, args=(0, world_size))
        p.start()
        processes.append(p)
    else:
        # start a training worker on this node
        p = mp.Process(target=run_worker, args=(args.rank, world_size, train_loader, test_loader))
        p.start()
        processes.append(p)

    # Block until the spawned process finishes.
    for p in processes:
        p.join()
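
    # Example invocation (assuming this script is saved as rpc_parameter_server.py
    # and run with world_size=4, i.e. one parameter server plus three trainers):
    #   on the master node:   python rpc_parameter_server.py --world_size=4 --rank=0
    #   on each trainer node: python rpc_parameter_server.py --world_size=4 --rank=<1..3> \
    #                                --master_addr=<master_ip> --master_port=29500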