diff --git a/vall_e/config.py b/vall_e/config.py index 194fd76..d830183 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -674,6 +674,7 @@ class Inference: class Optimizations: injects: bool = False # overwrites default torch classes (not recommended) replace: bool = False # replaces modules in place with the optimized version (recommended) + compile: bool | str = False # runs torch.compile on the model linear: bool = True # inject/replace linear for BnB embedding: bool = True # inject/replace embedding for BnB @@ -689,6 +690,8 @@ class Optimizations: # example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50% # | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2 + tensorrt: bool = False + @dataclass() class Config(BaseConfig): device: str = "cuda" # target device diff --git a/vall_e/data.py b/vall_e/data.py index 16e23d3..4610884 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -71,7 +71,7 @@ def fold_inputs( if isinstance(prom, str): task = get_task_symmap()[f'<{input}>'] - seq = torch.Tensor([task_start + task]).to(device=device, dtype=dtype) + seq = torch.tensor([task_start + task], device=device, dtype=dtype) input_ids[i].append( seq ) input_ids[i].append( sep ) @@ -81,7 +81,7 @@ def fold_inputs( if quant_levels is not None: quant_level = quant_levels[i] if ignore_index is not None: - seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to(device=device, dtype=dtype) + seq = torch.tensor( [ ignore_index for _ in range( prom.shape[0] ) ], device=device, dtype=dtype) else: seq = prom[:, quant_level].to(device=device, dtype=dtype).clone() for idx, token in enumerate( seq ): @@ -89,7 +89,7 @@ def fold_inputs( # interleaved else: if ignore_index is not None: - seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * 
prom.shape[1] ) ] ).to(device=device, dtype=dtype) + seq = torch.tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ], device=device, dtype=dtype) else: seq = prom.flatten().to(device=device, dtype=dtype) for idx, token in enumerate( seq ): @@ -111,8 +111,8 @@ def fold_inputs( offset = 0 - sep = torch.Tensor([ sep ]).to(device=device, dtype=dtype) - stop = torch.Tensor([ stop ]).to(device=device, dtype=dtype) + sep = torch.tensor([ sep ], device=device, dtype=dtype) + stop = torch.tensor([ stop ], device=device, dtype=dtype) text_start = 0 text_end = text_start + config.text_tokens @@ -140,7 +140,7 @@ def fold_inputs( if isinstance(text, torch.Tensor): seq = text + text_start else: - seq = torch.Tensor([text_start + text]).to(device=device, dtype=dtype) + seq = torch.tensor([text_start + text], device=device, dtype=dtype) input_ids[i].append( seq ) input_ids[i].append( sep ) @@ -149,7 +149,7 @@ def fold_inputs( if isinstance(lang, torch.Tensor): seq = lang + lang_start else: - seq = torch.Tensor([lang_start + lang]).to(device=device, dtype=dtype) + seq = torch.tensor([lang_start + lang], device=device, dtype=dtype) input_ids[i].append( seq ) input_ids[i].append( sep ) @@ -159,7 +159,7 @@ def fold_inputs( if isinstance(rvq, torch.Tensor): seq = rvq + rvq_start else: - seq = torch.Tensor([rvq_start + rvq]).to(device=device, dtype=dtype) + seq = torch.tensor([rvq_start + rvq], device=device, dtype=dtype) input_ids[i].append( seq ) input_ids[i].append( sep ) @@ -178,7 +178,7 @@ def fold_inputs( if isinstance(tone, torch.Tensor): seq = tone + tone_start else: - seq = torch.Tensor([tone_start + tone]).to(device=device, dtype=dtype) + seq = torch.tensor([tone_start + tone], device=device, dtype=dtype) input_ids[i].append( seq ) input_ids[i].append( sep ) @@ -253,7 +253,7 @@ def unfold_outputs( length = len(tokens) """ if length % config.resp_levels == 0: - tokens = torch.Tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t() + 
tokens = torch.tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t() """ bins = [ [] for _ in range(config.resp_levels) ] for pos in range( length ): @@ -261,7 +261,7 @@ def unfold_outputs( bins[rvq].append( tokens[pos] ) nearest = ( len(bins) // config.resp_levels ) * config.resp_levels bins = bins[:nearest] - return torch.Tensor(bins).t().to(device=device, dtype=dtype) + return torch.tensor(bins, device=device, dtype=dtype).t() if config is None: config = cfg.model @@ -341,16 +341,16 @@ def unfold_outputs( should_flush = True if quant_levels is not None: - prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=dtype) - resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=dtype) + prom_list[i] = torch.tensor(prom_list[i], device=device, dtype=dtype).t() + resp_list[i] = torch.tensor(resp_list[i], device=device, dtype=dtype).t() else: prom_list[i] = bin_to_rvqs( prom_list[i] ) resp_list[i] = bin_to_rvqs( resp_list[i] ) - text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=dtype) - task_list[i] = torch.Tensor( task_list[i] ).to(device=device, dtype=dtype) - lang_list[i] = torch.Tensor( lang_list[i] ).to(device=device, dtype=dtype) - tone_list[i] = torch.Tensor( tone_list[i] ).to(device=device, dtype=dtype) + text_list[i] = torch.tensor( text_list[i], device=device, dtype=dtype ) + task_list[i] = torch.tensor( task_list[i], device=device, dtype=dtype ) + lang_list[i] = torch.tensor( lang_list[i], device=device, dtype=dtype ) + tone_list[i] = torch.tensor( tone_list[i], device=device, dtype=dtype ) return dict( text_list=text_list, @@ -1089,13 +1089,13 @@ class Dataset(_Dataset): # create new text text = concat_audio( - torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype), # + torch.tensor( [ bos_id ] ).to(dtype=self.text_dtype), # pre_text, - None if pre_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " " + None if pre_text is None else torch.tensor( [ space_id ] 
).to(dtype=self.text_dtype), # " " edit_text, - None if post_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " " + None if post_text is None else torch.tensor( [ space_id ] ).to(dtype=self.text_dtype), # " " post_text, - torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype), # + torch.tensor( [ eos_id ] ).to(dtype=self.text_dtype), # reencode=False, ) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 2e3b38a..3046aaf 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -207,7 +207,9 @@ def load_engines(training=True): # wrap if DDP is requested if ddp: model = ddp_model(model) - + # wrap optimization class + elif cfg.optimizations.compile: + model = ml.compile_model(model, backend=cfg.optimizations.compile) # deepspeed inferencing elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"): engine_class = LocalEngine diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index d01a708..cf84924 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -102,7 +102,7 @@ class AR_NAR(Base): # trim resps to only contain all levels below the target level resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)] # tensor to cat for RVQ level 0 - stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16) + stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16) # I hate python's value/reference semantics so much for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list): # cap quant_level if it exceeds its corresponding resp/prom @@ -206,7 +206,7 @@ class AR_NAR(Base): #mirostat=mirostat, ) - prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] + prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in 
zip(prev_list, resps_list) ] if cfg.lora is not None: enable_lora( self ) @@ -515,7 +515,7 @@ def example_usage(): # set the text prompt to empty to train without a guided text prompt if random.random() < 0.5: - text = torch.tensor([bos_id, eos_id]).to(device=device, dtype=torch.uint8) + text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8) texts.append( text.to(device) ) proms.append( prom.to(device) ) @@ -560,6 +560,11 @@ def example_usage(): #sample("init", 5) train() + + """ + if cfg.optimizations.compile: + model = ml.compile_model(model, backend=cfg.optimizations.compile) + """ for task in tasks: sample("final", task=task) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 4340f14..1771c32 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -134,7 +134,7 @@ class AudioEmbedding_Old(nn.Module): # resp are split to where [0] is for the AR, and [1:] are reserved for NAR self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]) # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this) - self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None + self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1.0]) ) for i in range(levels)]) if levels is not None else None def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor: # prom @@ -208,7 +208,7 @@ class AudioEmbedding(nn.Module): # reintroduce stop token if has_stop_token: - stop_token = self.internal_forward( torch.Tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 ) + stop_token = self.internal_forward( torch.tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 ) embedding = torch.concat( [ embedding, stop_token ] ) return embedding @@ -257,7 +257,7 @@ class AudioClassifier(nn.Module): xi = [ #x if l == 0 else x if 
x.shape[-1] == max_size else - torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 ) + torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 ) for x, l in zip(xi, levels) ] return torch.stack( xi ) @@ -894,7 +894,7 @@ class Base(nn.Module): inputs[i].append( ( "lang", lang_list[i] ) ) # insert RVQ level guidance token if the model is versioned for it if self.rvq_l_emb is not None: - inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) ) + inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert input audio prompt if proms_list is not None and proms_list[i] is not None: inputs[i].append( ( "prom", proms_list[i] ) ) @@ -922,7 +922,7 @@ class Base(nn.Module): if self.rvq_l_emb is not None: # override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference) quant_levels[i] = 0 - inputs[i].append( ( "quant_level", torch.Tensor([ self.n_resp_levels ]).to(device=device, dtype=torch.int16) ) ) + inputs[i].append( ( "quant_level", torch.tensor([ self.n_resp_levels ], device=device, dtype=torch.int16) ) ) # insert input audio prompt if proms_list is not None and proms_list[i] is not None: inputs[i].append( ( "prom", proms_list[i] ) ) @@ -936,7 +936,7 @@ class Base(nn.Module): # "encode" length to tokens for 0-9 + stop elif resps_list is not None and resps_list[i] is not None: # yes this could be encoded better - inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) ) + inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) ) else: raise 
Exception(f'Unrecognized task: {task_type}') @@ -950,7 +950,7 @@ class Base(nn.Module): # handles tasks where the prompt has task tokens injected in the middle def prompt_input_to_embedding( input, quant_level ): if isinstance(input, str): - return self.tasks_emb( torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(device=device, dtype=torch.int16) ) + return self.tasks_emb( torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16) ) # get RVQ level 0, or up to targetted RVQ level inference if self.version <= 4: @@ -1068,6 +1068,8 @@ class Base(nn.Module): inputs: list, mask: Tensor, ): + device = mask.device + # shamelessly grabbed from modeling_llama.py ids = mask.long().cumsum(-1) - 1 ids.masked_fill_( mask == 0, 1 ) @@ -1083,26 +1085,26 @@ class Base(nn.Module): # list of tokens if not isinstance(input, torch.Tensor): - return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1 + return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1 # ending input will not have a separator later return input.shape[0] + (0 if name in ["resp", "len"] else 1) for batch_index, batch_input in enumerate(inputs): batch = torch.cat( [ - torch.Tensor([*range(get_input_token_length(name, input))]).to(dtype=torch.int32) + torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32) for name, input in batch_input if name != "task" ] ) delta = ids[batch_index].shape[0] - batch.shape[0] if delta > 0: - batch = torch.cat( [ batch, torch.Tensor([1] * delta) ] ) + batch = torch.cat( [ batch, torch.tensor([1] * delta, device=device, dtype=torch.int32) ] ) x_list.append( batch ) ids = torch.stack( x_list ) - return ids.to(device=mask.device, dtype=torch.int32) + return ids.to(device=device, dtype=torch.int32) def calc_loss( self, @@ -1117,7 +1119,7 @@ class Base(nn.Module): # handles tasks where the prompt has task tokens injected in the middle def prompt_input_to_token( input, quant_level ): if isinstance(input, str): - 
return torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device) + return torch.tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device) # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums): diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 67efc03..b2ccc65 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -150,7 +150,7 @@ class NAR(Base): max_levels = self.n_resp_levels # fill with mock tokens - prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ] + prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ] start = True for n in trange( max_levels, desc="NAR", disable=disable_tqdm ): @@ -202,7 +202,7 @@ class NAR(Base): return prev_list # is AR - sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ] + sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ] stopped = torch.zeros(batch_size, device=device).bool() stop_token = 10 diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py index 7d986e4..d4fb780 100755 --- a/vall_e/utils/wrapper.py +++ b/vall_e/utils/wrapper.py @@ -82,6 +82,31 @@ if cfg.optimizations.injects: torch.optim.AdamW = AdamW torch.optim.SGD = SGD +AVAILABLE_COMPILE_BACKENDS = [] + +try: + AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends() +except Exception as e: + pass + + +if cfg.optimizations.tensorrt: + try: + import torch_tensorrt + AVAILABLE_COMPILE_BACKENDS.append("tensorrt") + except Exception as e: + print('Error while importing TensorRT:', str(e)) + pass + +def compile_model(model, backend="auto"): + if not backend or backend == "auto": + backend 
= AVAILABLE_COMPILE_BACKENDS[0] if AVAILABLE_COMPILE_BACKENDS else None + + if backend not in AVAILABLE_COMPILE_BACKENDS: + return torch.compile(model) + + return torch.compile(model, backend=backend) + # https://github.com/konstmish/prodigy try: from prodigyopt import Prodigy