changed torch.Tensor().to(device, dtype) to just torch.tensor(..., device, dtype) because it's been bothering my autism that I'm creating tensors then converting rather than creating with the right device/dtype, some 'optimization' to compile the model but it doesnt seem to do anything useful
This commit is contained in:
parent
ab673e0426
commit
6a733eb2ed
|
@ -674,6 +674,7 @@ class Inference:
|
|||
class Optimizations:
|
||||
injects: bool = False # overwrites default torch classes (not recommended)
|
||||
replace: bool = False # replaces modules in place with the optimized version (recommended)
|
||||
compile: bool | str = False # runs torch.compile on the model
|
||||
|
||||
linear: bool = True # inject/replace linear for BnB
|
||||
embedding: bool = True # inject/replace embedding for BnB
|
||||
|
@ -689,6 +690,8 @@ class Optimizations:
|
|||
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
|
||||
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
|
||||
|
||||
tensorrt: bool = False
|
||||
|
||||
@dataclass()
|
||||
class Config(BaseConfig):
|
||||
device: str = "cuda" # target device
|
||||
|
|
|
@ -71,7 +71,7 @@ def fold_inputs(
|
|||
|
||||
if isinstance(prom, str):
|
||||
task = get_task_symmap()[f'<{input}>']
|
||||
seq = torch.Tensor([task_start + task]).to(device=device, dtype=dtype)
|
||||
seq = torch.tensor([task_start + task], device=device, dtype=dtype)
|
||||
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
@ -81,7 +81,7 @@ def fold_inputs(
|
|||
if quant_levels is not None:
|
||||
quant_level = quant_levels[i]
|
||||
if ignore_index is not None:
|
||||
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to(device=device, dtype=dtype)
|
||||
seq = torch.tensor( [ ignore_index for _ in range( prom.shape[0] ) ], device=device, dtype=dtype)
|
||||
else:
|
||||
seq = prom[:, quant_level].to(device=device, dtype=dtype).clone()
|
||||
for idx, token in enumerate( seq ):
|
||||
|
@ -89,7 +89,7 @@ def fold_inputs(
|
|||
# interleaved
|
||||
else:
|
||||
if ignore_index is not None:
|
||||
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to(device=device, dtype=dtype)
|
||||
seq = torch.tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ], device=device, dtype=dtype)
|
||||
else:
|
||||
seq = prom.flatten().to(device=device, dtype=dtype)
|
||||
for idx, token in enumerate( seq ):
|
||||
|
@ -111,8 +111,8 @@ def fold_inputs(
|
|||
|
||||
offset = 0
|
||||
|
||||
sep = torch.Tensor([ sep ]).to(device=device, dtype=dtype)
|
||||
stop = torch.Tensor([ stop ]).to(device=device, dtype=dtype)
|
||||
sep = torch.tensor([ sep ], device=device, dtype=dtype)
|
||||
stop = torch.tensor([ stop ], device=device, dtype=dtype)
|
||||
|
||||
text_start = 0
|
||||
text_end = text_start + config.text_tokens
|
||||
|
@ -140,7 +140,7 @@ def fold_inputs(
|
|||
if isinstance(text, torch.Tensor):
|
||||
seq = text + text_start
|
||||
else:
|
||||
seq = torch.Tensor([text_start + text]).to(device=device, dtype=dtype)
|
||||
seq = torch.tensor([text_start + text], device=device, dtype=dtype)
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
||||
|
@ -149,7 +149,7 @@ def fold_inputs(
|
|||
if isinstance(lang, torch.Tensor):
|
||||
seq = lang + lang_start
|
||||
else:
|
||||
seq = torch.Tensor([lang_start + lang]).to(device=device, dtype=dtype)
|
||||
seq = torch.tensor([lang_start + lang], device=device, dtype=dtype)
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
||||
|
@ -159,7 +159,7 @@ def fold_inputs(
|
|||
if isinstance(rvq, torch.Tensor):
|
||||
seq = rvq + rvq_start
|
||||
else:
|
||||
seq = torch.Tensor([rvq_start + rvq]).to(device=device, dtype=dtype)
|
||||
seq = torch.tensor([rvq_start + rvq], device=device, dtype=dtype)
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
||||
|
@ -178,7 +178,7 @@ def fold_inputs(
|
|||
if isinstance(tone, torch.Tensor):
|
||||
seq = tone + tone_start
|
||||
else:
|
||||
seq = torch.Tensor([tone_start + tone]).to(device=device, dtype=dtype)
|
||||
seq = torch.tensor([tone_start + tone], device=device, dtype=dtype)
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
||||
|
@ -253,7 +253,7 @@ def unfold_outputs(
|
|||
length = len(tokens)
|
||||
"""
|
||||
if length % config.resp_levels == 0:
|
||||
tokens = torch.Tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t()
|
||||
tokens = torch.tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t()
|
||||
"""
|
||||
bins = [ [] for _ in range(config.resp_levels) ]
|
||||
for pos in range( length ):
|
||||
|
@ -261,7 +261,7 @@ def unfold_outputs(
|
|||
bins[rvq].append( tokens[pos] )
|
||||
nearest = ( len(bins) // config.resp_levels ) * config.resp_levels
|
||||
bins = bins[:nearest]
|
||||
return torch.Tensor(bins).t().to(device=device, dtype=dtype)
|
||||
return torch.tensor(bins, device=device, dtype=dtype).t()
|
||||
|
||||
if config is None:
|
||||
config = cfg.model
|
||||
|
@ -341,16 +341,16 @@ def unfold_outputs(
|
|||
should_flush = True
|
||||
|
||||
if quant_levels is not None:
|
||||
prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=dtype)
|
||||
resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=dtype)
|
||||
prom_list[i] = torch.tensor(prom_list[i], device=device, dtype=dtype).t()
|
||||
resp_list[i] = torch.tensor(resp_list[i], device=device, dtype=dtype).t()
|
||||
else:
|
||||
prom_list[i] = bin_to_rvqs( prom_list[i] )
|
||||
resp_list[i] = bin_to_rvqs( resp_list[i] )
|
||||
|
||||
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=dtype)
|
||||
task_list[i] = torch.Tensor( task_list[i] ).to(device=device, dtype=dtype)
|
||||
lang_list[i] = torch.Tensor( lang_list[i] ).to(device=device, dtype=dtype)
|
||||
tone_list[i] = torch.Tensor( tone_list[i] ).to(device=device, dtype=dtype)
|
||||
text_list[i] = torch.tensor( text_list[i], device=device, dtype=dtype )
|
||||
task_list[i] = torch.tensor( task_list[i], device=device, dtype=dtype )
|
||||
lang_list[i] = torch.tensor( lang_list[i], device=device, dtype=dtype )
|
||||
tone_list[i] = torch.tensor( tone_list[i], device=device, dtype=dtype )
|
||||
|
||||
return dict(
|
||||
text_list=text_list,
|
||||
|
@ -1089,13 +1089,13 @@ class Dataset(_Dataset):
|
|||
|
||||
# create new text
|
||||
text = concat_audio(
|
||||
torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype), # <s>
|
||||
torch.tensor( [ bos_id ] ).to(dtype=self.text_dtype), # <s>
|
||||
pre_text,
|
||||
None if pre_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
|
||||
None if pre_text is None else torch.tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
|
||||
edit_text,
|
||||
None if post_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
|
||||
None if post_text is None else torch.tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
|
||||
post_text,
|
||||
torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype), # </s>
|
||||
torch.tensor( [ eos_id ] ).to(dtype=self.text_dtype), # </s>
|
||||
|
||||
reencode=False,
|
||||
)
|
||||
|
|
|
@ -207,7 +207,9 @@ def load_engines(training=True):
|
|||
# wrap if DDP is requested
|
||||
if ddp:
|
||||
model = ddp_model(model)
|
||||
|
||||
# wrap optimization class
|
||||
elif cfg.optimizations.compile:
|
||||
model = ml.compile_model(model, backend=cfg.optimizations.compile)
|
||||
# deepspeed inferencing
|
||||
elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
|
||||
engine_class = LocalEngine
|
||||
|
|
|
@ -102,7 +102,7 @@ class AR_NAR(Base):
|
|||
# trim resps to only contain all levels below the target level
|
||||
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
# tensor to cat for RVQ level 0
|
||||
stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
|
||||
stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
|
||||
# I hate python's value/reference semantics so much
|
||||
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
|
||||
# cap quant_level if it exceeds its corresponding resp/prom
|
||||
|
@ -206,7 +206,7 @@ class AR_NAR(Base):
|
|||
#mirostat=mirostat,
|
||||
)
|
||||
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
|
||||
if cfg.lora is not None:
|
||||
enable_lora( self )
|
||||
|
@ -515,7 +515,7 @@ def example_usage():
|
|||
|
||||
# set the text prompt to empty to train without a guided text prompt
|
||||
if random.random() < 0.5:
|
||||
text = torch.tensor([bos_id, eos_id]).to(device=device, dtype=torch.uint8)
|
||||
text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
|
||||
|
||||
texts.append( text.to(device) )
|
||||
proms.append( prom.to(device) )
|
||||
|
@ -560,6 +560,11 @@ def example_usage():
|
|||
|
||||
#sample("init", 5)
|
||||
train()
|
||||
|
||||
"""
|
||||
if cfg.optimizations.compile:
|
||||
model = ml.compile_model(model, backend=cfg.optimizations.compile)
|
||||
"""
|
||||
|
||||
for task in tasks:
|
||||
sample("final", task=task)
|
||||
|
|
|
@ -134,7 +134,7 @@ class AudioEmbedding_Old(nn.Module):
|
|||
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
|
||||
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
||||
self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
||||
|
||||
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
|
||||
# prom
|
||||
|
@ -208,7 +208,7 @@ class AudioEmbedding(nn.Module):
|
|||
|
||||
# reintroduce stop token
|
||||
if has_stop_token:
|
||||
stop_token = self.internal_forward( torch.Tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
|
||||
stop_token = self.internal_forward( torch.tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
|
||||
embedding = torch.concat( [ embedding, stop_token ] )
|
||||
|
||||
return embedding
|
||||
|
@ -257,7 +257,7 @@ class AudioClassifier(nn.Module):
|
|||
xi = [
|
||||
#x if l == 0 else
|
||||
x if x.shape[-1] == max_size else
|
||||
torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
|
||||
torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
|
||||
for x, l in zip(xi, levels)
|
||||
]
|
||||
return torch.stack( xi )
|
||||
|
@ -894,7 +894,7 @@ class Base(nn.Module):
|
|||
inputs[i].append( ( "lang", lang_list[i] ) )
|
||||
# insert RVQ level guidance token if the model is versioned for it
|
||||
if self.rvq_l_emb is not None:
|
||||
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
|
||||
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
|
||||
# insert input audio prompt
|
||||
if proms_list is not None and proms_list[i] is not None:
|
||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||
|
@ -922,7 +922,7 @@ class Base(nn.Module):
|
|||
if self.rvq_l_emb is not None:
|
||||
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
|
||||
quant_levels[i] = 0
|
||||
inputs[i].append( ( "quant_level", torch.Tensor([ self.n_resp_levels ]).to(device=device, dtype=torch.int16) ) )
|
||||
inputs[i].append( ( "quant_level", torch.tensor([ self.n_resp_levels ], device=device, dtype=torch.int16) ) )
|
||||
# insert input audio prompt
|
||||
if proms_list is not None and proms_list[i] is not None:
|
||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||
|
@ -936,7 +936,7 @@ class Base(nn.Module):
|
|||
# "encode" length to tokens for 0-9 + stop
|
||||
elif resps_list is not None and resps_list[i] is not None:
|
||||
# yes this could be encoded better
|
||||
inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
|
||||
inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
|
||||
else:
|
||||
raise Exception(f'Unrecognized task: {task_type}')
|
||||
|
||||
|
@ -950,7 +950,7 @@ class Base(nn.Module):
|
|||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_embedding( input, quant_level ):
|
||||
if isinstance(input, str):
|
||||
return self.tasks_emb( torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(device=device, dtype=torch.int16) )
|
||||
return self.tasks_emb( torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16) )
|
||||
|
||||
# get RVQ level 0, or up to targetted RVQ level inference
|
||||
if self.version <= 4:
|
||||
|
@ -1068,6 +1068,8 @@ class Base(nn.Module):
|
|||
inputs: list,
|
||||
mask: Tensor,
|
||||
):
|
||||
device = mask.device
|
||||
|
||||
# shamelessly grabbed from modeling_llama.py
|
||||
ids = mask.long().cumsum(-1) - 1
|
||||
ids.masked_fill_( mask == 0, 1 )
|
||||
|
@ -1083,26 +1085,26 @@ class Base(nn.Module):
|
|||
|
||||
# list of tokens
|
||||
if not isinstance(input, torch.Tensor):
|
||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
|
||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1
|
||||
|
||||
# ending input will not have a separator later
|
||||
return input.shape[0] + (0 if name in ["resp", "len"] else 1)
|
||||
|
||||
for batch_index, batch_input in enumerate(inputs):
|
||||
batch = torch.cat( [
|
||||
torch.Tensor([*range(get_input_token_length(name, input))]).to(dtype=torch.int32)
|
||||
torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32)
|
||||
for name, input in batch_input if name != "task"
|
||||
] )
|
||||
|
||||
delta = ids[batch_index].shape[0] - batch.shape[0]
|
||||
if delta > 0:
|
||||
batch = torch.cat( [ batch, torch.Tensor([1] * delta) ] )
|
||||
batch = torch.cat( [ batch, torch.tensor([1] * delta) ] )
|
||||
|
||||
x_list.append( batch )
|
||||
|
||||
ids = torch.stack( x_list )
|
||||
|
||||
return ids.to(device=mask.device, dtype=torch.int32)
|
||||
return ids.to(device=device, dtype=torch.int32)
|
||||
|
||||
def calc_loss(
|
||||
self,
|
||||
|
@ -1117,7 +1119,7 @@ class Base(nn.Module):
|
|||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_token( input, quant_level ):
|
||||
if isinstance(input, str):
|
||||
return torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
|
||||
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
|
||||
|
||||
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
||||
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
|
||||
|
|
|
@ -150,7 +150,7 @@ class NAR(Base):
|
|||
max_levels = self.n_resp_levels
|
||||
|
||||
# fill with mock tokens
|
||||
prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ]
|
||||
prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
|
||||
|
||||
start = True
|
||||
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
|
||||
|
@ -202,7 +202,7 @@ class NAR(Base):
|
|||
return prev_list
|
||||
|
||||
# is AR
|
||||
sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ]
|
||||
sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ]
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
stop_token = 10
|
||||
|
|
|
@ -82,6 +82,31 @@ if cfg.optimizations.injects:
|
|||
torch.optim.AdamW = AdamW
|
||||
torch.optim.SGD = SGD
|
||||
|
||||
AVAILABLE_COMPILE_BACKENDS = []
|
||||
|
||||
try:
|
||||
AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
if cfg.optimizations.tensorrt:
|
||||
try:
|
||||
import torch_tensorrt
|
||||
AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
|
||||
except Exception as e:
|
||||
print('Error while importing TensorRT:', str(e))
|
||||
pass
|
||||
|
||||
def compile_model(model, backend="auto"):
|
||||
if not backend or backend == "auto":
|
||||
backend = AVAILABLE_COMPILE_BACKENDS[0]
|
||||
|
||||
if backend not in AVAILABLE_COMPILE_BACKENDS:
|
||||
return torch.compile(model)
|
||||
|
||||
return torch.compile(model, backend=backend)
|
||||
|
||||
# https://github.com/konstmish/prodigy
|
||||
try:
|
||||
from prodigyopt import Prodigy
|
||||
|
|
Loading…
Reference in New Issue
Block a user