1 The difference between .to() and .cuda() in PyTorch

If the target device is a GPU, .to() and .cuda() are equivalent; if the device is the CPU, .cuda() cannot be used. In other words, .to() can move a model or tensor to either the CPU or a GPU, while .cuda() can only move it to a GPU.
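A minimal sketch of this difference (the GPU branch only runs when CUDA is actually available):

```python
import torch

x = torch.randn(2, 3)

# .to() accepts either device; calling .cuda() on a CPU-only machine raises an error.
x_cpu = x.to("cpu")        # always valid
print(x_cpu.device)        # cpu

if torch.cuda.is_available():
    x_gpu1 = x.to("cuda:0")  # equivalent to the next line...
    x_gpu2 = x.cuda()        # ...when the target is a GPU
    print(x_gpu1.device == x_gpu2.device)  # True
```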

1.1 .cuda()

1. Single GPU

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
model.cuda()

2. Multiple GPUs

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
device_ids = [0, 1, 2, 3]
model = torch.nn.DataParallel(model, device_ids=device_ids)
model = model.cuda()

1.2 .to()

1. CPU or single GPU

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
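One subtlety worth noting with this pattern: nn.Module.to() moves the module's parameters in place (reassignment is optional), while Tensor.to() is out of place and must be assigned. A small sketch, using a dtype conversion to show the out-of-place behavior, which follows the same semantics as a device move:

```python
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# nn.Module.to() mutates the module in place (and also returns it):
model = nn.Linear(3, 1)
model.to(device)
print(next(model.parameters()).device)

# Tensor.to() returns a new tensor and never mutates the original:
t = torch.zeros(2)
t2 = t.to(torch.float64)
print(t.dtype, t2.dtype)  # torch.float32 torch.float64
```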

2. Multiple GPUs

device_ids = [0, 1, 2, 3]
output_device = 0
model = nn.DataParallel(model, device_ids=device_ids)
model.to(output_device)

The logic above is encapsulated by torch.nn.DataParallel itself; here is an excerpt from its constructor:

class DataParallel(Module):
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
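As the excerpt shows, DataParallel degrades gracefully: with no visible GPU it keeps device_ids empty and simply runs the wrapped module on the CPU; with GPUs it defaults device_ids to all visible devices and output_device to the first. A device-agnostic sketch (the small Linear model here is just a stand-in):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
if torch.cuda.is_available():
    model = model.cuda()  # parameters must live on device_ids[0] before wrapping

# device_ids defaults to all visible GPUs, or [] on a CPU-only machine
dp = nn.DataParallel(model)

x = torch.randn(8, 4)
if torch.cuda.is_available():
    x = x.cuda()

out = dp(x)
print(out.shape)  # torch.Size([8, 2])
```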