head hurt

commit a5c90348d9
parent 516b0894d7
@@ -66,16 +66,11 @@ class AR_NAR(Base):
 		return cfg.model.tones
 
 	@property
-	def recurrent_chunk_size(self) -> int:
-		return 0
-
-	"""
-	@property
-	def rotary_embedding_base(self) -> float:
-		if hasattr(self, "config") and self.config:
-			return self.config.rotary_embedding_base
-		return cfg.model.rotary_embedding_base
-	"""
+	def causal_size(self) -> int:
+		# 1 for the stop token
+		# governs how much to shift the logits by
+		# could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
+		return 1 if self.causal else 0
 
 	@property
 	def interleave(self) -> bool:
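Note on the replacement above: `causal_size` is how many positions the model predicts ahead of the current one per step, 1 for the AR (which also reserves that slot for the stop token) and 0 for the NAR, which predicts in place. A minimal sketch of the shift it governs; the function name and shapes here are hypothetical, not from the diff:

    import torch

    def shift_for_next_token(logits: torch.Tensor, targets: torch.Tensor, causal_size: int):
        # causal_size == 1 (AR): logits at position n are scored against token n + 1
        # causal_size == 0 (NAR): logits are scored against the targets in place
        if causal_size > 0:
            return logits[..., :-causal_size, :], targets[..., causal_size:]
        return logits, targets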
@@ -241,7 +236,7 @@ class AR_NAR(Base):
 			max_steps *= self.n_prom_levels
 
 		# get next in sequence
-		for n in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
+		for n in trange(max_steps // max(1, self.causal_size), desc="AR"):
 			# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
 			# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
 			if max_resp_context > 0:
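The `max(1, ...)` guard keeps the divisor nonzero now that the property can return 0 for a non-causal model; for the plain AR (`causal_size == 1`) the step budget is unchanged. A toy check with made-up numbers:

    max_steps = 450
    assert max_steps // max(1, 1) == 450  # AR: one emitted token per step
    assert max_steps // max(1, 0) == 450  # causal_size == 0 no longer divides by zero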
@@ -463,9 +458,11 @@ def example_usage():
 
 	engine = Engine(model=model, optimizer=optimizer)
 
+	"""
 	torch.save( {
 		'module': model.state_dict()
 	}, f"./data/{cfg.model.arch_type}.pth" )
+	"""
 
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
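The save that gets commented out writes a bare {'module': state_dict} payload. If it were re-enabled, loading it back would presumably look like this sketch, with the path and key assumed from the save call above:

    import torch

    state = torch.load(f"./data/{cfg.model.arch_type}.pth", map_location="cpu")
    model.load_state_dict(state["module"])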
@@ -498,9 +495,11 @@ def example_usage():
 
 		tqdm.write(f"{stats}")
 
+	"""
 	torch.save( {
 		'module': model.state_dict()
 	}, f"./data/{cfg.model.arch_type}.pth" )
+	"""
 
 	#sample("init", 5)
 	train()
@@ -202,13 +202,9 @@ class Base(nn.Module):
 		raise NotImplementedError
 
 	@property
-	def recurrent_chunk_size(self) -> int:
+	def causal_size(self) -> int:
 		raise NotImplementedError
 
-	@property
-	def rotary_embedding_base(self) -> float:
-		return 10000
-
 	@property
 	def interleave(self) -> bool:
 		return False
@@ -271,7 +267,7 @@ class Base(nn.Module):
 
 		# +1 to include the stop token
 		n_prom_tokens = n_audio_tokens
-		n_resp_tokens = n_audio_tokens + 1 # (1 if self.causal else 0) interoperability
+		n_resp_tokens = n_audio_tokens + self.causal_size
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
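Worked out with a typical 1024-entry audio codebook, the new expression resolves to the same head size the hard-coded `+ 1` gave for the AR, while a purely non-causal model would drop the stop-token slot; the numbers below are illustrative:

    n_audio_tokens = 1024                # assumed codebook size, not from the diff
    n_resp_tokens = n_audio_tokens + 1   # AR: causal_size == 1 reserves the stop token
    assert n_resp_tokens == 1025         # a NAR-only model (causal_size == 0) would keep 1024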
@@ -456,12 +452,12 @@ class Base(nn.Module):
 				use_biases=self.version < 3,
 				use_glu=self.version >= 3,
 
-				chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
-				recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
+				chunkwise_recurrent=self.causal and self.causal_size > 0,
+				recurrent_chunkwise_size=self.causal_size if self.causal else 0,
 				no_output_layer=True,
 				decoder_normalize_before=True,
 
-				rotary_embedding_base=self.rotary_embedding_base, # 10000
+				rotary_embedding_base=10000
 			)
 
 			if n_experts > 1:
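One behavioral wrinkle worth flagging: the old property returned 0 by default, so `chunkwise_recurrent` only switched on when a dedicated chunk size was configured, whereas `causal_size` is 1 for any causal model. A plain-Python sketch of the two evaluations (not the RetNet API itself):

    causal = True
    recurrent_chunk_size = 0                         # old property defaulted to 0
    causal_size = 1                                  # new property returns 1 for the AR

    old_flag = causal and recurrent_chunk_size > 0   # False: chunkwise recurrence off
    new_flag = causal and causal_size > 0            # True: chunkwise recurrence on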
@@ -486,7 +482,7 @@ class Base(nn.Module):
 				activation_fn="gelu",
 				use_glu=False, # self.version >= 3,
 
-				recurrent_chunk_size=self.recurrent_chunk_size if self.causal else 0,
+				recurrent_chunk_size=self.causal_size if self.causal else 0,
 				decoder_normalize_before=True,
 
 				deepnorm=False,
@@ -710,8 +706,9 @@ class Base(nn.Module):
 				if quant_levels is not None and quant_levels[i] > 0:
 					continue
 
-				logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
-				target_list[i] = target_list[i][..., 1:] # predicts token n + 1
+				l = self.causal_size
+				logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
+				target_list[i] = target_list[i][..., l:] # predicts token n + 1
 
 			# see comments for the split-loss calc cross_entropy call
 			if False:
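A self-contained sketch of the shifted cross-entropy this hunk generalizes, with toy shapes (T positions, V classes) standing in for the real tensors:

    import torch
    import torch.nn.functional as F

    T, V, l = 8, 5, 1                                    # l plays the role of causal_size
    logits  = torch.randn(T, V)
    targets = torch.randint(0, V, (T,))
    loss = F.cross_entropy(logits[:-l, :], targets[l:])  # position n is scored against token n + 1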
@@ -769,8 +766,9 @@ class Base(nn.Module):
 			# for the AR, shift sequence so that it predicts the next token
 			# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
 			if quant_level is None or quant_level == 0:
-				logit = logit[..., :-1, :] # get all but the final logit
-				input = input[..., 1:] # shift sequence to the right by one
+				l = self.causal_size
+				logit = logit[..., :-l, :]
+				input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
 
 			if name not in info:
 				info[name] = {
@@ -802,15 +800,6 @@ class Base(nn.Module):
 			else:
 				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
 				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
-
-			# accuracy sometimes breaks for mamba
-
-			# to-do: compute loss per individual batch to scale per RVQ level
-			"""
-			rvq_loss_factor = self.loss_factor("quant")
-			if isinstance( rvq_loss_factor, list ):
-				...
-			"""
 
 	def forward(
 		self,
@@ -898,12 +887,8 @@ class Base(nn.Module):
 		if quant_levels is not None:
 			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
-		# (AR chunkwise) return the last chunkwise piece
-		elif self.causal and self.recurrent_chunk_size > 0:
-			logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
 		# (AR) return just the last code
-		# Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
-		else:
-			logits = [ logit[-1:] for logit in logits ]
+		elif self.causal:
+			logits = [ logit[-self.causal_size:] for logit in logits ]
 
 		devices = [ logit.device for logit in logits ]
 		logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
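After the forward pass, sampling only needs the positions that predict new tokens, so the rewritten branch keeps the last `causal_size` logits per sequence. An illustrative sketch with toy tensors:

    import torch

    causal_size = 1
    logits = [torch.randn(120, 1025) for _ in range(4)]  # toy per-sequence logits
    logits = [logit[-causal_size:] for logit in logits]  # each becomes (1, 1025): just the next-token logits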
@@ -315,9 +315,11 @@ def example_usage():
 
 	engine = Engine(model=model, optimizer=optimizer)
 
+	"""
 	torch.save( {
 		'module': model.state_dict()
 	}, f"./data/{cfg.model.arch_type}.pth" )
+	"""
 
 	print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
@@ -400,9 +402,11 @@ def example_usage():
 
 		tqdm.write(f"{stats}")
 
+	"""
 	torch.save( {
 		'module': model.state_dict()
 	}, f"./data/{cfg.model.arch_type}.pth" )
+	"""
 
 	#sample("init", 5)
 	train()