head hurt

This commit is contained in:
mrq 2024-06-06 20:51:31 -05:00
parent 516b0894d7
commit a5c90348d9
3 changed files with 28 additions and 40 deletions

View File

@ -66,16 +66,11 @@ class AR_NAR(Base):
return cfg.model.tones
@property
def recurrent_chunk_size(self) -> int:
return 0
"""
@property
def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base
return cfg.model.rotary_embedding_base
"""
def causal_size(self) -> int:
# 1 for the stop token
# governs how much to shift the logits by
# could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
return 1 if self.causal else 0
@property
def interleave(self) -> bool:
@ -241,7 +236,7 @@ class AR_NAR(Base):
max_steps *= self.n_prom_levels
# get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
for n in trange(max_steps // max(1, self.causal_size), desc="AR"):
# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
if max_resp_context > 0:
@ -463,9 +458,11 @@ def example_usage():
engine = Engine(model=model, optimizer=optimizer)
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@ -498,9 +495,11 @@ def example_usage():
tqdm.write(f"{stats}")
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
#sample("init", 5)
train()

View File

@ -202,13 +202,9 @@ class Base(nn.Module):
raise NotImplementedError
@property
def recurrent_chunk_size(self) -> int:
def causal_size(self) -> int:
raise NotImplementedError
@property
def rotary_embedding_base(self) -> float:
return 10000
@property
def interleave(self) -> bool:
return False
@ -271,7 +267,7 @@ class Base(nn.Module):
# +1 to include the stop token
n_prom_tokens = n_audio_tokens
n_resp_tokens = n_audio_tokens + 1 # (1 if self.causal else 0) interoperability
n_resp_tokens = n_audio_tokens + self.causal_size
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
@ -456,12 +452,12 @@ class Base(nn.Module):
use_biases=self.version < 3,
use_glu=self.version >= 3,
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
chunkwise_recurrent=self.causal and self.causal_size > 0,
recurrent_chunkwise_size=self.causal_size if self.causal else 0,
no_output_layer=True,
decoder_normalize_before=True,
rotary_embedding_base=self.rotary_embedding_base, # 10000
rotary_embedding_base=10000
)
if n_experts > 1:
@ -486,7 +482,7 @@ class Base(nn.Module):
activation_fn="gelu",
use_glu=False, # self.version >= 3,
recurrent_chunk_size=self.recurrent_chunk_size if self.causal else 0,
recurrent_chunk_size=self.causal_size if self.causal else 0,
decoder_normalize_before=True,
deepnorm=False,
@ -710,8 +706,9 @@ class Base(nn.Module):
if quant_levels is not None and quant_levels[i] > 0:
continue
logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
target_list[i] = target_list[i][..., 1:] # predicts token n + 1
l = self.causal_size
logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
target_list[i] = target_list[i][..., l:] # predicts token n + 1
# see comments for the split-loss calc cross_entropy call
if False:
@ -769,8 +766,9 @@ class Base(nn.Module):
# for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if quant_level is None or quant_level == 0:
logit = logit[..., :-1, :] # get all but the final logit
input = input[..., 1:] # shift sequence to the right by one
l = self.causal_size
logit = logit[..., :-l, :]
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
if name not in info:
info[name] = {
@ -802,15 +800,6 @@ class Base(nn.Module):
else:
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
# accuracy sometimes breaks for mamba
# to-do: compute loss per individual batch to scale per RVQ level
"""
rvq_loss_factor = self.loss_factor("quant")
if isinstance( rvq_loss_factor, list ):
...
"""
def forward(
self,
@ -898,12 +887,8 @@ class Base(nn.Module):
if quant_levels is not None:
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
# (AR chunkwise) return the last chunkwise piece
elif self.causal and self.recurrent_chunk_size > 0:
logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
# (AR) return just the last code
# Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
else:
logits = [ logit[-1:] for logit in logits ]
elif self.causal:
logits = [ logit[-self.causal_size:] for logit in logits ]
devices = [ logit.device for logit in logits ]
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]

View File

@ -315,9 +315,11 @@ def example_usage():
engine = Engine(model=model, optimizer=optimizer)
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@ -400,9 +402,11 @@ def example_usage():
tqdm.write(f"{stats}")
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
#sample("init", 5)
train()