diff --git a/dlas/models/image_generation/stylegan/stylegan2_lucidrains.py b/dlas/models/image_generation/stylegan/stylegan2_lucidrains.py
index 32b3ecc6..d9d923b1 100644
--- a/dlas/models/image_generation/stylegan/stylegan2_lucidrains.py
+++ b/dlas/models/image_generation/stylegan/stylegan2_lucidrains.py
@@ -19,6 +19,8 @@
 import dlas.trainer.losses as L
 from dlas.trainer.networks import register_model
 from dlas.utils.util import checkpoint, opt_get
+import os
+
 
 try:
     from apex import amp
@@ -26,7 +28,11 @@ try:
 except:
     APEX_AVAILABLE = False
 
-assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
+
+if os.environ.get("AIVC_TRAIN_ONEAPI"):
+    assert torch.xpu.is_available(), 'You have chosen to train with oneAPI, but no XPU is available.'
+else:
+    assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed. Alternatively, you may train with oneAPI.'
 
 
 num_cores = multiprocessing.cpu_count()
diff --git a/dlas/train.py b/dlas/train.py
index 4150ab18..fa616dd3 100644
--- a/dlas/train.py
+++ b/dlas/train.py
@@ -55,8 +55,14 @@ def init_dist(backend, **kwargs):
     import torch.distributed as dist
 
     rank = int(os.environ['LOCAL_RANK'])
-    assert rank < torch.cuda.device_count()
-    torch.cuda.set_device(rank)
+    if os.environ.get("AIVC_TRAIN_ONEAPI"):
+        import intel_extension_for_pytorch
+        import oneccl_bindings_for_pytorch
+        assert rank < torch.xpu.device_count()
+        torch.xpu.set_device(rank)
+    else:
+        assert rank < torch.cuda.device_count()
+        torch.cuda.set_device(rank)
 
     dist.init_process_group(backend=backend, **kwargs)
diff --git a/dlas/trainer/ExtensibleTrainer.py b/dlas/trainer/ExtensibleTrainer.py
index 43b01dde..12c0b159 100644
--- a/dlas/trainer/ExtensibleTrainer.py
+++ b/dlas/trainer/ExtensibleTrainer.py
@@ -70,6 +70,9 @@ class ExtensibleTrainer(BaseModel):
         self.auto_scale_basis = opt_get(
             opt, ['automatically_scale_base_layer_size'], 1024)
 
+        self.tdevice = "xpu:" + str(self.device) if os.environ.get("AIVC_TRAIN_ONEAPI") else "cuda:" + str(self.device)
+        self.tdevice = torch.device(self.tdevice)
+
         self.netsG = {}
         self.netsD = {}
         for name, net in opt['networks'].items():
@@ -84,12 +87,12 @@
             if net['type'] == 'generator':
                 if new_net is None:
                     new_net = networks.create_model(
-                        opt, net, self.netsG).to(self.device)
+                        opt, net, self.netsG).to(self.tdevice)
                 self.netsG[name] = new_net
             elif net['type'] == 'discriminator':
                 if new_net is None:
                     new_net = networks.create_model(
-                        opt, net, self.netsD).to(self.device)
+                        opt, net, self.netsD).to(self.tdevice)
                 self.netsD[name] = new_net
             else:
                 raise NotImplementedError(
@@ -155,8 +158,9 @@
 
             # Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
             # used and in a few other cases. But you can try it if you really want.
-            dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()],
-                                           output_device=torch.cuda.current_device(),
+
+            dev_id = torch.xpu.current_device() if os.environ.get("AIVC_TRAIN_ONEAPI") else torch.cuda.current_device()
+            dnet = DistributedDataParallel(anet, device_ids=[dev_id], output_device=dev_id,
                                            find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
             # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
             # which does not work with this trainer (as stated above). However, if the graph is not subject
@@ -241,7 +245,7 @@
                 else:
                     v = v[sort_indices]
                 if isinstance(v, torch.Tensor):
-                    self.dstate[k] = [t.to(self.device) for t in torch.chunk(
+                    self.dstate[k] = [t.to(self.tdevice) for t in torch.chunk(
                         v, chunks=batch_factor, dim=0)]
 
         if opt_get(self.opt, ['train', 'auto_collate'], False):
diff --git a/dlas/trainer/base_model.py b/dlas/trainer/base_model.py
index 71ed4930..ec3815aa 100644
--- a/dlas/trainer/base_model.py
+++ b/dlas/trainer/base_model.py
@@ -17,8 +17,11 @@ class BaseModel():
             self.rank = torch.distributed.get_rank()
         else:
             self.rank = -1  # non dist training
-        self.device = torch.cuda.current_device(
-        ) if opt['gpu_ids'] else torch.device('cpu')
+
+        if os.environ.get("AIVC_TRAIN_ONEAPI"):
+            self.device = torch.xpu.current_device()
+        else:
+            self.device = torch.cuda.current_device() if opt['gpu_ids'] else torch.device('cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
         self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
diff --git a/dlas/trainer/injectors/audio_injectors.py b/dlas/trainer/injectors/audio_injectors.py
index 943a97a2..61874f61 100644
--- a/dlas/trainer/injectors/audio_injectors.py
+++ b/dlas/trainer/injectors/audio_injectors.py
@@ -9,6 +9,8 @@
 from dlas.trainer.inject import Injector
 from dlas.utils.music_utils import get_music_codegen
 from dlas.utils.util import load_model_from_config, opt_get, pad_or_truncate
+import os
+
 MEL_MIN = -11.512925148010254
 TACOTRON_MEL_MAX = 2.3143386840820312
 TORCH_MEL_MAX = 4.82 # FYI: this STILL isn't assertive enough...
@@ -185,8 +187,9 @@
         cfg = opt_get(
             opt, ['dvae_config'], "../experiments/train_diffusion_vocoder_22k_level.yml")
         dvae_name = opt_get(opt, ['dvae_name'], 'dvae')
+        devstr = "xpu:" if os.environ.get("AIVC_TRAIN_ONEAPI") else "cuda:"
         self.dvae = load_model_from_config(
-            cfg, dvae_name, device=f'cuda:{env["device"]}').eval()
+            cfg, dvae_name, device=devstr + str(env["device"])).eval()
 
     def forward(self, state):
         inp = state[self.input]
diff --git a/dlas/utils/util.py b/dlas/utils/util.py
index 74c4d85d..823c8a49 100644
--- a/dlas/utils/util.py
+++ b/dlas/utils/util.py
@@ -533,10 +533,16 @@ def load_model_from_config(cfg_file=None, model_name=None, also_load_savepoint=T
 
 # Mapper for torch.load() that maps cuda devices to the correct CUDA device, but leaves CPU devices alone.
 def map_cuda_to_correct_device(storage, loc):
-    if str(loc).startswith('cuda'):
-        return storage.cuda(torch.cuda.current_device())
+    if os.environ.get("AIVC_TRAIN_ONEAPI"):
+        if str(loc).startswith('xpu'):
+            return storage.xpu(torch.xpu.current_device())
+        else:
+            return storage.cpu()
     else:
-        return storage.cpu()
+        if str(loc).startswith('cuda'):
+            return storage.cuda(torch.cuda.current_device())
+        else:
+            return storage.cpu()
 
 
 def list_to_device(l, dev):
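
The patch inlines the same os.environ.get("AIVC_TRAIN_ONEAPI") check at every call site rather than centralizing it. A minimal standalone sketch of that dispatch pattern, assuming a hypothetical pick_device helper (the name is not part of the patch):

import os

import torch


def pick_device(index=0):
    # Hypothetical helper illustrating the env-var dispatch the patch inlines.
    if os.environ.get("AIVC_TRAIN_ONEAPI"):
        # torch.xpu becomes usable once intel_extension_for_pytorch is imported.
        import intel_extension_for_pytorch  # noqa: F401
        assert torch.xpu.is_available(), "AIVC_TRAIN_ONEAPI is set, but no XPU is available"
        return torch.device("xpu:" + str(index))
    assert torch.cuda.is_available(), "no CUDA device is available"
    return torch.device("cuda:" + str(index))


# e.g. net = net.to(pick_device(rank))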
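
The init_dist hunk only pins the local device; the process-group backend is still whatever the caller passes in. Importing oneccl_bindings_for_pytorch registers a "ccl" backend with torch.distributed, so a oneAPI launch would presumably select it. A hedged sketch of that selection (the backend choice is an assumption, not something this patch shows):

import os

import torch.distributed as dist

if os.environ.get("AIVC_TRAIN_ONEAPI"):
    import oneccl_bindings_for_pytorch  # noqa: F401  (registers the "ccl" backend)
    backend = "ccl"
else:
    backend = "nccl"
dist.init_process_group(backend=backend)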
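
The updated map_cuda_to_correct_device still works as a callable map_location for torch.load: PyTorch hands it each storage together with its serialized location string, and the function re-homes the storage on the local accelerator or the CPU. Note that in the oneAPI branch, storages tagged 'cuda' fall through to CPU rather than being moved to the XPU. A usage sketch, where "checkpoint.pth" is a placeholder path:

import torch

from dlas.utils.util import map_cuda_to_correct_device

state = torch.load("checkpoint.pth", map_location=map_cuda_to_correct_device)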