changed torch.Tensor().to(device, dtype) to just torch.tensor(..., device=device, dtype=dtype), because it's been bothering my autism that I'm creating tensors and then converting them rather than creating them with the right device/dtype in the first place; also added some 'optimization' to compile the model, but it doesn't seem to do anything useful
This commit is contained in:
parent ab673e0426
commit 6a733eb2ed
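
For context, a minimal standalone sketch (not code from this repository) of the two patterns the commit message contrasts: torch.Tensor(data) allocates a tensor with the default float32 dtype on the CPU and .to(...) then copies/casts it, while torch.tensor(data, device=..., dtype=...) constructs the tensor with the target device and dtype in one step.

    import torch

    # hypothetical device/dtype, mirroring how the repo threads them through
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.int16

    # old pattern: torch.Tensor builds a float32 CPU tensor first, then .to() copies/casts it
    seq_old = torch.Tensor([1, 2, 3]).to(device=device, dtype=dtype)

    # new pattern: torch.tensor constructs directly with the desired device/dtype
    seq_new = torch.tensor([1, 2, 3], device=device, dtype=dtype)

    assert torch.equal(seq_old, seq_new)  # same values, dtype, and device
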
@@ -674,6 +674,7 @@ class Inference:

 class Optimizations:
 injects: bool = False # overwrites default torch classes (not recommended)
 replace: bool = False # replaces modules in place with the optimized version (recommended)
+compile: bool | str = False # runs torch.compile on the model

 linear: bool = True # inject/replace linear for BnB
 embedding: bool = True # inject/replace embedding for BnB

@@ -689,6 +690,8 @@ class Optimizations:
 # example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
 # | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2

+tensorrt: bool = False
+
 @dataclass()
 class Config(BaseConfig):
 device: str = "cuda" # target device
@@ -71,7 +71,7 @@ def fold_inputs(

 if isinstance(prom, str):
 task = get_task_symmap()[f'<{input}>']
-seq = torch.Tensor([task_start + task]).to(device=device, dtype=dtype)
+seq = torch.tensor([task_start + task], device=device, dtype=dtype)

 input_ids[i].append( seq )
 input_ids[i].append( sep )

@@ -81,7 +81,7 @@ def fold_inputs(
 if quant_levels is not None:
 quant_level = quant_levels[i]
 if ignore_index is not None:
-seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to(device=device, dtype=dtype)
+seq = torch.tensor( [ ignore_index for _ in range( prom.shape[0] ) ], device=device, dtype=dtype)
 else:
 seq = prom[:, quant_level].to(device=device, dtype=dtype).clone()
 for idx, token in enumerate( seq ):

@@ -89,7 +89,7 @@ def fold_inputs(
 # interleaved
 else:
 if ignore_index is not None:
-seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to(device=device, dtype=dtype)
+seq = torch.tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ], device=device, dtype=dtype)
 else:
 seq = prom.flatten().to(device=device, dtype=dtype)
 for idx, token in enumerate( seq ):

@@ -111,8 +111,8 @@ def fold_inputs(

 offset = 0

-sep = torch.Tensor([ sep ]).to(device=device, dtype=dtype)
-stop = torch.Tensor([ stop ]).to(device=device, dtype=dtype)
+sep = torch.tensor([ sep ], device=device, dtype=dtype)
+stop = torch.tensor([ stop ], device=device, dtype=dtype)

 text_start = 0
 text_end = text_start + config.text_tokens

@@ -140,7 +140,7 @@ def fold_inputs(
 if isinstance(text, torch.Tensor):
 seq = text + text_start
 else:
-seq = torch.Tensor([text_start + text]).to(device=device, dtype=dtype)
+seq = torch.tensor([text_start + text], device=device, dtype=dtype)
 input_ids[i].append( seq )
 input_ids[i].append( sep )

@@ -149,7 +149,7 @@ def fold_inputs(
 if isinstance(lang, torch.Tensor):
 seq = lang + lang_start
 else:
-seq = torch.Tensor([lang_start + lang]).to(device=device, dtype=dtype)
+seq = torch.tensor([lang_start + lang], device=device, dtype=dtype)
 input_ids[i].append( seq )
 input_ids[i].append( sep )

@@ -159,7 +159,7 @@ def fold_inputs(
 if isinstance(rvq, torch.Tensor):
 seq = rvq + rvq_start
 else:
-seq = torch.Tensor([rvq_start + rvq]).to(device=device, dtype=dtype)
+seq = torch.tensor([rvq_start + rvq], device=device, dtype=dtype)
 input_ids[i].append( seq )
 input_ids[i].append( sep )

@@ -178,7 +178,7 @@ def fold_inputs(
 if isinstance(tone, torch.Tensor):
 seq = tone + tone_start
 else:
-seq = torch.Tensor([tone_start + tone]).to(device=device, dtype=dtype)
+seq = torch.tensor([tone_start + tone], device=device, dtype=dtype)
 input_ids[i].append( seq )
 input_ids[i].append( sep )

@@ -253,7 +253,7 @@ def unfold_outputs(
 length = len(tokens)
 """
 if length % config.resp_levels == 0:
-tokens = torch.Tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t()
+tokens = torch.tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t()
 """
 bins = [ [] for _ in range(config.resp_levels) ]
 for pos in range( length ):

@@ -261,7 +261,7 @@ def unfold_outputs(
 bins[rvq].append( tokens[pos] )
 nearest = ( len(bins) // config.resp_levels ) * config.resp_levels
 bins = bins[:nearest]
-return torch.Tensor(bins).t().to(device=device, dtype=dtype)
+return torch.tensor(bins, device=device, dtype=dtype).t()

 if config is None:
 config = cfg.model

@@ -341,16 +341,16 @@ def unfold_outputs(
 should_flush = True

 if quant_levels is not None:
-prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=dtype)
-resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=dtype)
+prom_list[i] = torch.tensor(prom_list[i], device=device, dtype=dtype).t()
+resp_list[i] = torch.tensor(resp_list[i], device=device, dtype=dtype).t()
 else:
 prom_list[i] = bin_to_rvqs( prom_list[i] )
 resp_list[i] = bin_to_rvqs( resp_list[i] )

-text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=dtype)
-task_list[i] = torch.Tensor( task_list[i] ).to(device=device, dtype=dtype)
-lang_list[i] = torch.Tensor( lang_list[i] ).to(device=device, dtype=dtype)
-tone_list[i] = torch.Tensor( tone_list[i] ).to(device=device, dtype=dtype)
+text_list[i] = torch.tensor( text_list[i], device=device, dtype=dtype )
+task_list[i] = torch.tensor( task_list[i], device=device, dtype=dtype )
+lang_list[i] = torch.tensor( lang_list[i], device=device, dtype=dtype )
+tone_list[i] = torch.tensor( tone_list[i], device=device, dtype=dtype )

 return dict(
 text_list=text_list,
@@ -1089,13 +1089,13 @@ class Dataset(_Dataset):

 # create new text
 text = concat_audio(
-torch.Tensor( [ bos_id ] ).to(dtype=self.text_dtype), # <s>
+torch.tensor( [ bos_id ] ).to(dtype=self.text_dtype), # <s>
 pre_text,
-None if pre_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
+None if pre_text is None else torch.tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
 edit_text,
-None if post_text is None else torch.Tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
+None if post_text is None else torch.tensor( [ space_id ] ).to(dtype=self.text_dtype), # " "
 post_text,
-torch.Tensor( [ eos_id ] ).to(dtype=self.text_dtype), # </s>
+torch.tensor( [ eos_id ] ).to(dtype=self.text_dtype), # </s>

 reencode=False,
 )

@@ -207,7 +207,9 @@ def load_engines(training=True):
 # wrap if DDP is requested
 if ddp:
 model = ddp_model(model)
+# wrap optimization class
+elif cfg.optimizations.compile:
+model = ml.compile_model(model, backend=cfg.optimizations.compile)
 # deepspeed inferencing
 elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
 engine_class = LocalEngine

@@ -102,7 +102,7 @@ class AR_NAR(Base):
 # trim resps to only contain all levels below the target level
 resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 # tensor to cat for RVQ level 0
-stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
+stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
 # I hate python's value/reference semantics so much
 for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
 # cap quant_level if it exceeds its corresponding resp/prom

@@ -206,7 +206,7 @@ class AR_NAR(Base):
 #mirostat=mirostat,
 )

-prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
+prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]

 if cfg.lora is not None:
 enable_lora( self )

@@ -515,7 +515,7 @@ def example_usage():

 # set the text prompt to empty to train without a guided text prompt
 if random.random() < 0.5:
-text = torch.tensor([bos_id, eos_id]).to(device=device, dtype=torch.uint8)
+text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)

 texts.append( text.to(device) )
 proms.append( prom.to(device) )
@@ -561,6 +561,11 @@ def example_usage():
 #sample("init", 5)
 train()

+"""
+if cfg.optimizations.compile:
+model = ml.compile_model(model, backend=cfg.optimizations.compile)
+"""
+
 for task in tasks:
 sample("final", task=task)

@@ -134,7 +134,7 @@ class AudioEmbedding_Old(nn.Module):
 # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
 self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
 # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
-self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
+self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None

 def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
 # prom

@@ -208,7 +208,7 @@ class AudioEmbedding(nn.Module):

 # reintroduce stop token
 if has_stop_token:
-stop_token = self.internal_forward( torch.Tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
+stop_token = self.internal_forward( torch.tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
 embedding = torch.concat( [ embedding, stop_token ] )

 return embedding

@@ -257,7 +257,7 @@ class AudioClassifier(nn.Module):
 xi = [
 #x if l == 0 else
 x if x.shape[-1] == max_size else
-torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
+torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
 for x, l in zip(xi, levels)
 ]
 return torch.stack( xi )

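The AudioClassifier hunk above pads each level's logits out to a common vocabulary width with -inf columns so the per-level outputs can be stacked (softmax then assigns those columns zero probability). A standalone sketch of that padding idea, with made-up shapes and torch.full standing in for the repository's list-multiplication construct:

    import torch

    # hypothetical per-level logits with different vocab sizes: (seq_len, n_tokens)
    xi = [torch.randn(4, 1025), torch.randn(4, 1024)]
    max_size = max(x.shape[-1] for x in xi)

    padded = [
        x if x.shape[-1] == max_size else
        # fill the missing columns with -inf so softmax gives them zero probability
        torch.cat([x, torch.full((x.shape[0], max_size - x.shape[-1]), -float("inf"))], dim=-1)
        for x in xi
    ]
    stacked = torch.stack(padded)  # (levels, seq_len, max_size)
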
@@ -894,7 +894,7 @@ class Base(nn.Module):
 inputs[i].append( ( "lang", lang_list[i] ) )
 # insert RVQ level guidance token if the model is versioned for it
 if self.rvq_l_emb is not None:
-inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
+inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
 # insert input audio prompt
 if proms_list is not None and proms_list[i] is not None:
 inputs[i].append( ( "prom", proms_list[i] ) )

@@ -922,7 +922,7 @@ class Base(nn.Module):
 if self.rvq_l_emb is not None:
 # override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
 quant_levels[i] = 0
-inputs[i].append( ( "quant_level", torch.Tensor([ self.n_resp_levels ]).to(device=device, dtype=torch.int16) ) )
+inputs[i].append( ( "quant_level", torch.tensor([ self.n_resp_levels ], device=device, dtype=torch.int16) ) )
 # insert input audio prompt
 if proms_list is not None and proms_list[i] is not None:
 inputs[i].append( ( "prom", proms_list[i] ) )

@@ -936,7 +936,7 @@ class Base(nn.Module):
 # "encode" length to tokens for 0-9 + stop
 elif resps_list is not None and resps_list[i] is not None:
 # yes this could be encoded better
-inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
+inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
 else:
 raise Exception(f'Unrecognized task: {task_type}')

@@ -950,7 +950,7 @@ class Base(nn.Module):
 # handles tasks where the prompt has task tokens injected in the middle
 def prompt_input_to_embedding( input, quant_level ):
 if isinstance(input, str):
-return self.tasks_emb( torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(device=device, dtype=torch.int16) )
+return self.tasks_emb( torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16) )

 # get RVQ level 0, or up to targetted RVQ level inference
 if self.version <= 4:

@@ -1068,6 +1068,8 @@
 inputs: list,
 mask: Tensor,
 ):
+device = mask.device

 # shamelessly grabbed from modeling_llama.py
 ids = mask.long().cumsum(-1) - 1
 ids.masked_fill_( mask == 0, 1 )

@@ -1083,26 +1085,26 @@

 # list of tokens
 if not isinstance(input, torch.Tensor):
-return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
+return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1

 # ending input will not have a separator later
 return input.shape[0] + (0 if name in ["resp", "len"] else 1)

 for batch_index, batch_input in enumerate(inputs):
 batch = torch.cat( [
-torch.Tensor([*range(get_input_token_length(name, input))]).to(dtype=torch.int32)
+torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32)
 for name, input in batch_input if name != "task"
 ] )

 delta = ids[batch_index].shape[0] - batch.shape[0]
 if delta > 0:
-batch = torch.cat( [ batch, torch.Tensor([1] * delta) ] )
+batch = torch.cat( [ batch, torch.tensor([1] * delta) ] )

 x_list.append( batch )

 ids = torch.stack( x_list )

-return ids.to(device=mask.device, dtype=torch.int32)
+return ids.to(device=device, dtype=torch.int32)

 def calc_loss(
 self,

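For reference, a tiny worked example of the position-id trick in the hunks above (cumulative sum over the attention mask, with padded positions forced to 1), assuming a right-padded mask:

    import torch

    mask = torch.tensor([[1, 1, 1, 0, 0]])  # one sequence of length 3, padded to 5
    ids = mask.long().cumsum(-1) - 1        # [[0, 1, 2, 2, 2]]
    ids.masked_fill_(mask == 0, 1)          # [[0, 1, 2, 1, 1]]
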
@@ -1117,7 +1119,7 @@ class Base(nn.Module):
 # handles tasks where the prompt has task tokens injected in the middle
 def prompt_input_to_token( input, quant_level ):
 if isinstance(input, str):
-return torch.Tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
+return torch.tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)

 # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
 if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):

@@ -150,7 +150,7 @@ class NAR(Base):
 max_levels = self.n_resp_levels

 # fill with mock tokens
-prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ]
+prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]

 start = True
 for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):

@@ -202,7 +202,7 @@ class NAR(Base):
 return prev_list

 # is AR
-sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ]
+sequence_list = [ torch.tensor([0], device=device, dtype=torch.int16) for _ in range(batch_size) ]
 stopped = torch.zeros(batch_size, device=device).bool()

 stop_token = 10

@@ -82,6 +82,31 @@ if cfg.optimizations.injects:
 torch.optim.AdamW = AdamW
 torch.optim.SGD = SGD

+AVAILABLE_COMPILE_BACKENDS = []
+
+try:
+    AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
+except Exception as e:
+    pass
+
+if cfg.optimizations.tensorrt:
+    try:
+        import torch_tensorrt
+        AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
+    except Exception as e:
+        print('Error while importing TensorRT:', str(e))
+        pass
+
+def compile_model(model, backend="auto"):
+    if not backend or backend == "auto":
+        backend = AVAILABLE_COMPILE_BACKENDS[0]
+
+    if backend not in AVAILABLE_COMPILE_BACKENDS:
+        return torch.compile(model)
+
+    return torch.compile(model, backend=backend)
+
 # https://github.com/konstmish/prodigy
 try:
 from prodigyopt import Prodigy

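Taken together with the load_engines hunk earlier, cfg.optimizations.compile is passed straight through as the backend, so a bare True ends up as a plain torch.compile(model), while a string naming a backend reported by torch._dynamo.list_backends() selects that backend. A standalone sketch of that behaviour (not the repository's compile_model verbatim):

    import torch

    available = list(torch._dynamo.list_backends())  # e.g. ["cudagraphs", "inductor", ...]

    def compile_with(model, backend):
        # True, "auto", or an unknown name falls back to the default torch.compile backend
        if backend is True or backend == "auto" or backend not in available:
            return torch.compile(model)
        return torch.compile(model, backend=backend)

    model = compile_with(torch.nn.Linear(8, 8), "inductor")  # "inductor" ships with PyTorch 2.x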