changed torch.Tensor().to(device, dtype) to just torch.tensor(..., device, dtype) because it's been bothering my autism that I'm creating tensors then converting rather than creating with the right device/dtype, some 'optimization' to compile the model but it doesnt seem to do anything useful

This commit is contained in:
mrq 2024-08-03 22:10:21 -05:00
parent ab673e0426
commit 6a733eb2ed
7 changed files with 76 additions and 39 deletions

View File

@ -674,6 +674,7 @@ class Inference:
class Optimizations: class Optimizations:
injects: bool = False # overwrites default torch classes (not recommended) injects: bool = False # overwrites default torch classes (not recommended)
replace: bool = False # replaces modules in place with the optimized version (recommended) replace: bool = False # replaces modules in place with the optimized version (recommended)
compile: bool | str = False # runs torch.compile on the model
linear: bool = True # inject/replace linear for BnB linear: bool = True # inject/replace linear for BnB
embedding: bool = True # inject/replace embedding for BnB embedding: bool = True # inject/replace embedding for BnB
@ -689,6 +690,8 @@ class Optimizations:
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50% # example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2 # | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
tensorrt: bool = False
@dataclass() @dataclass()
class Config(BaseConfig): class Config(BaseConfig):
device: str = "cuda" # target device device: str = "cuda" # target device

View File

@ -71,7 +71,7 @@ def fold_inputs(
if isinstance(prom, str): if isinstance(prom, str):
task = get_task_symmap()[f'<{input}>'] task = get_task_symmap()[f'<{input}>']
seq = torch.Tensor([task_start + task]).to(device=device, dtype=dtype) seq = torch.tensor([task_start + task], device=device, dtype=dtype)
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( sep ) input_ids[i].append( sep )
@ -81,7 +81,7 @@ def fold_inputs(
if quant_levels is not None: if quant_levels is not None:
quant_level = quant_levels[i] quant_level = quant_levels[i]
if ignore_index is not None: if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to(device=device, dtype=dtype) seq = torch.tensor( [ ignore_index for _ in range( prom.shape[0] ) ], device=device, dtype=dtype)
else: else:
seq = prom[:, quant_level].to(device=device, dtype=dtype).clone() seq = prom[:, quant_level].to(device=device, dtype=dtype).clone()
for idx, token in enumerate( seq ): for idx, token in enumerate( seq ):
@ -89,7 +89,7 @@ def fold_inputs(
# interleaved # interleaved
else: else:
if ignore_index is not None: if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to(device=device, dtype=dtype) seq = torch.tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ], device=device, dtype=dtype)
else: else:
seq = prom.flatten().to(device=device, dtype=dtype) seq = prom.flatten().to(device=device, dtype=dtype)
for idx, token in enumerate( seq ): for idx, token in enumerate( seq ):
@ -111,8 +111,8 @@ def fold_inputs(
offset = 0 offset = 0
sep = torch.Tensor([ sep ]).to(device=device, dtype=dtype) sep = torch.tensor([ sep ], device=device, dtype=dtype)
stop = torch.Tensor([ stop ]).to(device=device, dtype=dtype) stop = torch.tensor([ stop ], device=device, dtype=dtype)
text_start = 0 text_start = 0
text_end = text_start + config.text_tokens text_end = text_start + config.text_tokens
@ -140,7 +140,7 @@ def fold_inputs(
if isinstance(text, torch.Tensor): if isinstance(text, torch.Tensor):
seq = text + text_start seq = text + text_start
else: else:
seq = torch.Tensor([text_start + text]).to(device=device, dtype=dtype) seq = torch.tensor([text_start + text], device=device, dtype=dtype)
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( sep ) input_ids[i].append( sep )
@ -149,7 +149,7 @@ def fold_inputs(
if isinstance(lang, torch.Tensor): if isinstance(lang, torch.Tensor):
seq = lang + lang_start seq = lang + lang_start
else: else:
seq = torch.Tensor([lang_start + lang]).to(device=device, dtype=dtype) seq = torch.tensor([lang_start + lang], device=device, dtype=dtype)
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( sep ) input_ids[i].append( sep )
@ -159,7 +159,7 @@ def fold_inputs(
if isinstance(rvq, torch.Tensor): if isinstance(rvq, torch.Tensor):
seq = rvq + rvq_start seq = rvq + rvq_start
else: else:
seq = torch.Tensor([rvq_start + rvq]).to(device=device, dtype=dtype) seq = torch.tensor([rvq_start + rvq], device=device, dtype=dtype)
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( sep ) input_ids[i].append( sep )
@ -178,7 +178,7 @@ def fold_inputs(
if isinstance(tone, torch.Tensor): if isinstance(tone, torch.Tensor):
seq = tone + tone_start seq = tone + tone_start
else: else:
seq = torch.Tensor([tone_start + tone]).to(device=device, dtype=dtype) seq = torch.tensor([tone_start + tone], device=device, dtype=dtype)
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( sep ) input_ids[i].append( sep )
@ -253,7 +253,7 @@ def unfold_outputs(
length = len(tokens) length = len(tokens)
""" """
if length % config.resp_levels == 0: if length % config.resp_levels == 0:
tokens = torch.Tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t() tokens = torch.tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t()
""" """
bins = [ [] for _ in range(config.resp_levels) ] bins = [ [] for _ in range(config.resp_levels) ]
for pos in range( length ): for pos in range( length ):
@ -261,7 +261,7 @@ def unfold_outputs(
bins[rvq].append( tokens[pos] ) bins[rvq].append( tokens[pos] )
nearest = ( len(bins) // config.resp_levels ) * config.resp_levels nearest = ( len(bins) // config.resp_levels ) * config.resp_levels
bins = bins[:nearest] bins = bins[:nearest]
return torch.Tensor(bins).t().to(device=device, dtype=dtype) return torch.tensor(bins, device=device, dtype=dtype).t()
if config is None: if config is None:
config = cfg.model config = cfg.model
@ -341,16 +341,16 @@ def unfold_outputs(
should_flush = True should_flush = True
if quant_levels is not None: if quant_levels is not None:
prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=dtype) prom_list[i] = torch.tensor(prom_list[i], device=device, dtype=dtype).t()
resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=dtype) resp_list[i] = torch.tensor(resp_list[i], device=device, dtype=dtype).t()
else: else:
prom_list[i] = bin_to_rvqs( prom_list[i] ) prom_list[i] = bin_to_rvqs( prom_list[i] )
resp_list[i] = bin_to_rvqs( resp_list[i] ) resp_list[i] = bin_to_rvqs( resp_list[i] )
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=dtype) text_list[i] = torch.tensor( text_list[i], device=device, dtype=dtype )
task_list[i] = torch.Tensor( task_list[i] ).to(device=device, dtype=dtype) task_list[i] = torch.tensor( task_list[i], device=device, dtype=dtype )
lang_list[i] = torch.Tensor( lang_list[i] ).to(device=device, dtype=dtype) lang_list[i] = torch.tensor( lang_list[i], device=device, dtype=dtype )
tone_list[i] = torch.Tensor( tone_list[i] ).to(device=device, dtype=dtype) tone_list[i] = torch.tensor( tone_list[i], device=device, dtype=dtype )
return dict( return dict(
text_list=text_list, text_list=text_list,
@ -1089,13 +1089,13 @@ class Dataset(_Dataset):
# create new text # create new text
text = concat_audio( text = concat_audio(
torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype), # <s> torch.tensor( [ bos_id ] ).to(dtype=self.text_dtype), # <s>
pre_text, pre_text,
None if pre_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " " None if pre_text is None else torch.tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
edit_text, edit_text,
None if post_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " " None if post_text is None else torch.tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
post_text, post_text,
torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype), # </s> torch.tensor( [ eos_id ] ).to(dtype=self.text_dtype), # </s>
reencode=False, reencode=False,
) )

View File

@ -207,7 +207,9 @@ def load_engines(training=True):
# wrap if DDP is requested # wrap if DDP is requested
if ddp: if ddp:
model = ddp_model(model) model = ddp_model(model)
# wrap optimization class
elif cfg.optimizations.compile:
model = ml.compile_model(model, backend=cfg.optimizations.compile)
# deepspeed inferencing # deepspeed inferencing
elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"): elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
engine_class = LocalEngine engine_class = LocalEngine

View File

@ -102,7 +102,7 @@ class AR_NAR(Base):
# trim resps to only contain all levels below the target level # trim resps to only contain all levels below the target level
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)] resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
# tensor to cat for RVQ level 0 # tensor to cat for RVQ level 0
stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16) stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
# I hate python's value/reference semantics so much # I hate python's value/reference semantics so much
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list): for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
# cap quant_level if it exceeds its corresponding resp/prom # cap quant_level if it exceeds its corresponding resp/prom
@ -206,7 +206,7 @@ class AR_NAR(Base):
#mirostat=mirostat, #mirostat=mirostat,
) )
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
if cfg.lora is not None: if cfg.lora is not None:
enable_lora( self ) enable_lora( self )
@ -515,7 +515,7 @@ def example_usage():
# set the text prompt to empty to train without a guided text prompt # set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5: if random.random() < 0.5:
text = torch.tensor([bos_id, eos_id]).to(device=device, dtype=torch.uint8) text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
texts.append( text.to(device) ) texts.append( text.to(device) )
proms.append( prom.to(device) ) proms.append( prom.to(device) )
@ -561,6 +561,11 @@ def example_usage():
#sample("init", 5) #sample("init", 5)
train() train()
"""
if cfg.optimizations.compile:
model = ml.compile_model(model, backend=cfg.optimizations.compile)
"""
for task in tasks: for task in tasks:
sample("final", task=task) sample("final", task=task)

View File

@ -134,7 +134,7 @@ class AudioEmbedding_Old(nn.Module):
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]) self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this) # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor: def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
# prom # prom
@ -208,7 +208,7 @@ class AudioEmbedding(nn.Module):
# reintroduce stop token # reintroduce stop token
if has_stop_token: if has_stop_token:
stop_token = self.internal_forward( torch.Tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 ) stop_token = self.internal_forward( torch.tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
embedding = torch.concat( [ embedding, stop_token ] ) embedding = torch.concat( [ embedding, stop_token ] )
return embedding return embedding
@ -257,7 +257,7 @@ class AudioClassifier(nn.Module):
xi = [ xi = [
#x if l == 0 else #x if l == 0 else
x if x.shape[-1] == max_size else x if x.shape[-1] == max_size else
torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 ) torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
for x, l in zip(xi, levels) for x, l in zip(xi, levels)
] ]
return torch.stack( xi ) return torch.stack( xi )
@ -894,7 +894,7 @@ class Base(nn.Module):
inputs[i].append( ( "lang", lang_list[i] ) ) inputs[i].append( ( "lang", lang_list[i] ) )
# insert RVQ level guidance token if the model is versioned for it # insert RVQ level guidance token if the model is versioned for it
if self.rvq_l_emb is not None: if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) ) inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert input audio prompt # insert input audio prompt
if proms_list is not None and proms_list[i] is not None: if proms_list is not None and proms_list[i] is not None:
inputs[i].append( ( "prom", proms_list[i] ) ) inputs[i].append( ( "prom", proms_list[i] ) )
@ -922,7 +922,7 @@ class Base(nn.Module):
if self.rvq_l_emb is not None: if self.rvq_l_emb is not None:
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference) # override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
quant_levels[i] = 0 quant_levels[i] = 0
inputs[i].append( ( "quant_level", torch.Tensor([ self.n_resp_levels ]).to(device=device, dtype=torch.int16) ) ) inputs[i].append( ( "quant_level", torch.tensor([ self.n_resp_levels ], device=device, dtype=torch.int16) ) )
# insert input audio prompt # insert input audio prompt
if proms_list is not None and proms_list[i] is not None: if proms_list is not None and proms_list[i] is not None:
inputs[i].append( ( "prom", proms_list[i] ) ) inputs[i].append( ( "prom", proms_list[i] ) )
@ -936,7 +936,7 @@ class Base(nn.Module):
# "encode" length to tokens for 0-9 + stop # "encode" length to tokens for 0-9 + stop
elif resps_list is not None and resps_list[i] is not None: elif resps_list is not None and resps_list[i] is not None:
# yes this could be encoded better # yes this could be encoded better
inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) ) inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
else: else:
raise Exception(f'Unrecognized task: {task_type}') raise Exception(f'Unrecognized task: {task_type}')
@ -950,7 +950,7 @@ class Base(nn.Module):
# handles tasks where the prompt has task tokens injected in the middle # handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_embedding( input, quant_level ): def prompt_input_to_embedding( input, quant_level ):
if isinstance(input, str): if isinstance(input, str):
return self.tasks_emb( torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(device=device, dtype=torch.int16) ) return self.tasks_emb( torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16) )
# get RVQ level 0, or up to targetted RVQ level inference # get RVQ level 0, or up to targetted RVQ level inference
if self.version <= 4: if self.version <= 4:
@ -1068,6 +1068,8 @@ class Base(nn.Module):
inputs: list, inputs: list,
mask: Tensor, mask: Tensor,
): ):
device = mask.device
# shamelessly grabbed from modeling_llama.py # shamelessly grabbed from modeling_llama.py
ids = mask.long().cumsum(-1) - 1 ids = mask.long().cumsum(-1) - 1
ids.masked_fill_( mask == 0, 1 ) ids.masked_fill_( mask == 0, 1 )
@ -1083,26 +1085,26 @@ class Base(nn.Module):
# list of tokens # list of tokens
if not isinstance(input, torch.Tensor): if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1 return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1
# ending input will not have a separator later # ending input will not have a separator later
return input.shape[0] + (0 if name in ["resp", "len"] else 1) return input.shape[0] + (0 if name in ["resp", "len"] else 1)
for batch_index, batch_input in enumerate(inputs): for batch_index, batch_input in enumerate(inputs):
batch = torch.cat( [ batch = torch.cat( [
torch.Tensor([*range(get_input_token_length(name, input))]).to(dtype=torch.int32) torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32)
for name, input in batch_input if name != "task" for name, input in batch_input if name != "task"
] ) ] )
delta = ids[batch_index].shape[0] - batch.shape[0] delta = ids[batch_index].shape[0] - batch.shape[0]
if delta > 0: if delta > 0:
batch = torch.cat( [ batch, torch.Tensor([1] * delta) ] ) batch = torch.cat( [ batch, torch.tensor([1] * delta) ] )
x_list.append( batch ) x_list.append( batch )
ids = torch.stack( x_list ) ids = torch.stack( x_list )
return ids.to(device=mask.device, dtype=torch.int32) return ids.to(device=device, dtype=torch.int32)
def calc_loss( def calc_loss(
self, self,
@ -1117,7 +1119,7 @@ class Base(nn.Module):
# handles tasks where the prompt has task tokens injected in the middle # handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ): def prompt_input_to_token( input, quant_level ):
if isinstance(input, str): if isinstance(input, str):
return torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device) return torch.tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums): if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):

View File

@ -150,7 +150,7 @@ class NAR(Base):
max_levels = self.n_resp_levels max_levels = self.n_resp_levels
# fill with mock tokens # fill with mock tokens
prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ] prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
start = True start = True
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ): for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
@ -202,7 +202,7 @@ class NAR(Base):
return prev_list return prev_list
# is AR # is AR
sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ] sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool() stopped = torch.zeros(batch_size, device=device).bool()
stop_token = 10 stop_token = 10

View File

@ -82,6 +82,31 @@ if cfg.optimizations.injects:
torch.optim.AdamW = AdamW torch.optim.AdamW = AdamW
torch.optim.SGD = SGD torch.optim.SGD = SGD
AVAILABLE_COMPILE_BACKENDS = []
try:
AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
except Exception as e:
pass
if cfg.optimizations.tensorrt:
try:
import torch_tensorrt
AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
except Exception as e:
print('Error while importing TensorRT:', str(e))
pass
def compile_model(model, backend="auto"):
if not backend or backend == "auto":
backend = AVAILABLE_COMPILE_BACKENDS[0]
if backend not in AVAILABLE_COMPILE_BACKENDS:
return torch.compile(model)
return torch.compile(model, backend=backend)
# https://github.com/konstmish/prodigy # https://github.com/konstmish/prodigy
try: try:
from prodigyopt import Prodigy from prodigyopt import Prodigy