from argparse import ArgumentParser

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
SEED = 42
BATCH_SIZE = 8
NUM_EPOCHS = 3
class YourDataset(Dataset):
    def __init__(self):
        pass
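    # Illustrative additions (assumptions, not part of the original example): a real
    # Dataset also implements __len__ and __getitem__ over whatever __init__ loads.
    def __len__(self):
        return 1024  # placeholder size

    def __getitem__(self, idx):
        # placeholder sample; replace with real data loading/preprocessing
        return torch.tensor([float(idx)])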
def main():
    parser = ArgumentParser('DDP usage example')
    # you need this argument in your scripts for DDP to work
    parser.add_argument('--local_rank', type=int, default=-1, metavar='N',
                        help='Local process rank.')
    args = parser.parse_args()
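    # Between argument parsing and the training loop, the process group, device,
    # DistributedSampler-backed DataLoader, and DDP wrapper still need to be set up.
    # A minimal sketch (YourModel is a placeholder; with torch.distributed.launch the
    # rank and world size are read from environment variables):
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(args.local_rank)
    args.device = torch.device("cuda", args.local_rank)

    # each process gets its own shard of the data via DistributedSampler
    dataset = YourDataset()
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # wrap the model so gradients are synchronized across processes
    model = YourModel().to(args.device)
    model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)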
    for step, batch in enumerate(dataloader):
        # send the data to the corresponding device
        batch = tuple(t.to(args.device) for t in batch)
        # normal forward pass
        outputs = model(*batch)
        # compute the loss; assuming a Transformers-based model, the loss is returned as the first element
        loss = outputs[0]
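        # The rest of the step is unchanged from single-GPU training; DDP averages
        # gradients across processes inside backward(). A sketch, assuming an
        # optimizer has already been created:
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()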
if __name__ == '__main__':
    main()
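Because the script reads --local_rank, it is meant to be started by the distributed launcher, which spawns one process per GPU and passes that argument automatically. For example (the script name and GPU count are placeholders):

python -m torch.distributed.launch --nproc_per_node=4 train.py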
Now let's walk through adapting single-GPU training code to DDP.
First, import three additional packages:
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
Next, define a setup function that initializes the process group:

def setup(rank, world_size):
    """Sets up the distributed process group.

    Args:
        rank (int): within the process group, each process is identified by its rank, from 0 to world_size - 1
        world_size (int): the number of processes in the group
    """
    # Initialize the process group:
    # world_size processes form a group backed by the nccl backend,
    # with rank 0 acting as the master node.
    # The master node is the main GPU responsible for synchronization,
    # making copies, loading models, and writing logs.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
Also define a cleanup function:
def cleanup():
    """Cleans up the distributed environment."""
    dist.destroy_process_group()
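Each spawned process then runs a worker function that calls setup(), builds a DistributedSampler-backed DataLoader, wraps the model in DDP, and finally calls cleanup(). The sketch below is only illustrative; main_worker, YourDataset, YourModel, and train_args are assumed names:

def main_worker(rank, world_size, train_args):
    setup(rank, world_size)

    dataset = YourDataset()
    # DistributedSampler gives each process its own shard of the data
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # move the model to this process's GPU and wrap it in DDP
    model = DDP(YourModel().to(rank), device_ids=[rank], output_device=rank)

    # ... training loop as shown below ...

    cleanup()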
Then modify the script's entry point:
if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, train_args.gpus))

    # Sets up the process group and configuration for PyTorch Distributed Data Parallelism
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    world_size = min(torch.cuda.device_count(), len(train_args.gpus))
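    # Finally, launch one training process per visible GPU. A minimal sketch,
    # assuming the worker function defined above is main_worker:
    mp.spawn(
        main_worker,
        args=(world_size, train_args),
        nprocs=world_size,
        join=True,
    )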
if is_main_process:
    if train_args.use_wandb:
        wandb.log(
            {
                "train_loss": train_loss,
                "valid_bleu_score": valid_bleu_score,
                "valid_loss": valid_loss,
            }
        )
        wandb.save("result-dev.txt")
if train_args.calc_bleu_during_train:
    if metric_score > best_score:
        best_score = metric_score
print(f"Save model with best bleu score :{metric_score:.2f}") # 保存验证集上bleu得分最好的模型 torch.save(module.state_dict(), train_args.model_save_path) else: if metric_score best_score = metric_score print(f"Save model with best valid loss :{metric_score:.4f}") torch.save(module.state_dict(), train_args.model_save_path) # 早停 if early_stopper.step(metric_score): print(f"stop from early stopping.") break
# let all processes sync up before starting with a new epoch of training
dist.barrier()
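# When the DataLoader is backed by a DistributedSampler, it is also common to
# re-seed it at the start of each epoch so the shuffle order differs between
# epochs (a sketch; assumes the sampler is reachable as data_loader.sampler and
# `epoch` is the current epoch index):
data_loader.sampler.set_epoch(epoch)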
total_loss = 0.0
tqdm_iter = tqdm(data_loader)
for step, batch in enumerate(tqdm_iter, start=1):
    # send the tensors to the assigned device (each process uses the GPU indexed by its rank)
    source, target, labels = [
        x.to(rank) for x in (batch.source, batch.target, batch.labels)
    ]

    logits = model(source, target)
    # loss calculation
    loss = criterion(logits, labels)
    loss.backward()

    # gradient accumulation support
    if step % gradient_accumulation_steps == 0:
        if clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()
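    # `total_loss`, initialized before the loop, would typically be updated here so
    # the progress bar can show a running average (the format below is an assumption):
    total_loss += loss.item()
    tqdm_iter.set_description(f"avg loss: {total_loss / step:.4f}")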