Add initial oneAPI support

This commit is contained in:
a-One-Fan 2023-04-30 23:05:24 +03:00
parent b6a213bbbd
commit 44d2dcbb19
4 changed files with 44 additions and 4 deletions

View File

@ -680,9 +680,17 @@ class TextToSpeech:
auto_conditioning = migrate_to_device( auto_conditioning, self.device )
text_tokens = migrate_to_device( text_tokens, self.device )
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
if self.device.type == 'xpu': # The following autocasts were hardcoded for CUDA
_device_type = 'xpu'
_dtype = torch.bfloat16 # float16 support for oneAPI was missing / is worse?
else:
_device_type = 'cuda' # Should these be changed to just use the device directly? Unsure how this will do for dml/rocm
_dtype = torch.float16
with torch.autocast(device_type=_device_type, dtype=_dtype, enabled=half_p):
for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
check_for_kill_signal()
do_gc() # oneAPI VRAM
codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
top_p=top_p,
@ -710,7 +718,7 @@ class TextToSpeech:
if auto_conds is not None:
auto_conditioning = migrate_to_device( auto_conditioning, self.device )
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
with torch.autocast(device_type=_device_type, dtype=_dtype, enabled=half_p):
if not self.preloaded_tensors:
self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
self.clvp = migrate_to_device( self.clvp, self.device )

View File

@ -309,7 +309,11 @@ class DiffusionTts(nn.Module):
else:
# First and last blocks will have autocast disabled for improved precision.
# x.device.type
with autocast(device_type='cuda', enabled=self.enable_fp16 and i != 0):
if self.device.type == 'xpu': # The following autocast was hardcoded for CUDA
_device_type = 'xpu'
else:
_device_type = 'cuda' # Should these be changed to just use the device directly? Unsure how this will do for dml/rocm
with autocast(device_type=_device_type, enabled=self.enable_fp16 and i != 0):
x = lyr(x, time_emb)
x = x.float()

View File

@ -8,6 +8,10 @@ DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
from inspect import currentframe, getframeinfo
import gc
def xpu_get_mem(device=0):
total_memory = ipex.xpu.get_device_properties(device).total_memory
return total_memory, total_memory - torch.xpu.memory_allocated(device)
def do_gc():
gc.collect()
try:
@ -15,6 +19,11 @@ def do_gc():
except Exception as e:
pass
try:
torch.xpu.empty_cache()
except Exception as e:
pass
def print_stats(collect=False):
cf = currentframe().f_back
msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
@ -36,6 +45,16 @@ def has_dml():
import torch_directml
return torch_directml.is_available()
def has_ipex():
loader = importlib.find_loader('intel_extension_for_pytorch')
if loader is None:
return False
import intel_extension_for_pytorch
global ipex
ipex = intel_extension_for_pytorch # Could doing this over and over be an issue?
return torch.xpu.is_available()
def set_device_name(name):
global DEVICE_OVERRIDE
DEVICE_OVERRIDE = name
@ -51,6 +70,10 @@ def get_device_name(attempt_gc=True):
name = 'cuda'
if attempt_gc:
torch.cuda.empty_cache() # may have performance implications
elif has_ipex():
name = 'xpu'
if attempt_gc:
torch.xpu.empty_cache()
elif has_dml():
name = 'dml'
@ -76,6 +99,8 @@ def get_device_vram( name=get_device_name() ):
if name == "cuda":
_, available = torch.cuda.mem_get_info()
elif name == "xpu":
_, available = xpu_get_mem()
elif name == "cpu":
available = psutil.virtual_memory()[4]

View File

@ -1271,7 +1271,10 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
if timesteps.device.type == 'xpu': # TODO: Arc currently does not support FP64 broadly, and this will change eventually. Remove when this happens?
res = th.from_numpy(arr).float().to(device=timesteps.device)[timesteps]
else:
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)