head hurt

parent 516b0894d7
commit a5c90348d9
@@ -66,16 +66,11 @@ class AR_NAR(Base):
         return cfg.model.tones
 
     @property
-    def recurrent_chunk_size(self) -> int:
-        return 0
-
-    """
-    @property
-    def rotary_embedding_base(self) -> float:
-        if hasattr(self, "config") and self.config:
-            return self.config.rotary_embedding_base
-        return cfg.model.rotary_embedding_base
-    """
+    def causal_size(self) -> int:
+        # 1 for the stop token
+        # governs how much to shift the logits by
+        # could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
+        return 1 if self.causal else 0
 
     @property
     def interleave(self) -> bool:
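The renamed property folds the old recurrent_chunk_size and the AR's stop-token accounting into one number. A minimal sketch of the contract, using hypothetical stand-in classes rather than the real AR_NAR/Base hierarchy:

class TinyAR:
    causal = True  # autoregressive: emits a stop token

    @property
    def causal_size(self) -> int:
        # one extra vocab slot for the stop token, and shift logits by one
        return 1 if self.causal else 0

class TinyNAR(TinyAR):
    causal = False  # non-causal: no stop token, no shift

assert TinyAR().causal_size == 1
assert TinyNAR().causal_size == 0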
@@ -241,7 +236,7 @@ class AR_NAR(Base):
             max_steps *= self.n_prom_levels
 
         # get next in sequence
-        for n in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
+        for n in trange(max_steps // max(1, self.causal_size), desc="AR"):
             # experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
             # UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
             if max_resp_context > 0:
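Dividing the step budget by causal_size keeps the loop count correct if a causal model ever emits more than one token per iteration; with the default of 1 it matches the old behavior exactly, since recurrent_chunk_size returned 0 and max(1, 0) is also 1. Illustrative arithmetic only, with assumed numbers:

max_steps   = 450              # hypothetical cap after scaling by n_prom_levels
causal_size = 1                # AR default: one token per iteration
iterations  = max_steps // max(1, causal_size)  # max(1, ...) guards division by zero
assert iterations == 450       # identical to the old max(1, 0) == 1 path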
@@ -463,9 +458,11 @@ def example_usage():
 
     engine = Engine(model=model, optimizer=optimizer)
 
+    """
     torch.save( {
         'module': model.state_dict()
     }, f"./data/{cfg.model.arch_type}.pth" )
+    """
 
     print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
@@ -498,9 +495,11 @@ def example_usage():
 
         tqdm.write(f"{stats}")
 
+    """
     torch.save( {
         'module': model.state_dict()
     }, f"./data/{cfg.model.arch_type}.pth" )
+    """
 
     #sample("init", 5)
     train()

@@ -202,13 +202,9 @@ class Base(nn.Module):
         raise NotImplementedError
 
     @property
-    def recurrent_chunk_size(self) -> int:
+    def causal_size(self) -> int:
         raise NotImplementedError
 
-    @property
-    def rotary_embedding_base(self) -> float:
-        return 10000
-
     @property
     def interleave(self) -> bool:
         return False
@@ -271,7 +267,7 @@ class Base(nn.Module):
 
         # +1 to include the stop token
         n_prom_tokens = n_audio_tokens
-        n_resp_tokens = n_audio_tokens + 1 # (1 if self.causal else 0) interoperability
+        n_resp_tokens = n_audio_tokens + self.causal_size
 
         self.text_emb = Embedding(n_text_tokens, d_model)
         self.langs_emb = None
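With the property substituted in, a causal model still reserves one vocabulary slot for the stop token, while a purely non-causal model no longer pays for a slot it can never emit. A toy sketch of the sizing, assuming EnCodec-style 1024-code levels and an arbitrary width:

import torch.nn as nn

n_audio_tokens = 1024                  # assumed codes per RVQ level
d_model        = 512                   # assumed model width

for causal_size in (1, 0):             # causal AR vs. pure NAR
    n_resp_tokens = n_audio_tokens + causal_size
    resps_emb = nn.Embedding(n_resp_tokens, d_model)
    # when present, the stop token is the last id, n_audio_tokens
    assert resps_emb.num_embeddings == n_audio_tokens + causal_size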
@@ -456,12 +452,12 @@ class Base(nn.Module):
                 use_biases=self.version < 3,
                 use_glu=self.version >= 3,
 
-                chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
-                recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
+                chunkwise_recurrent=self.causal and self.causal_size > 0,
+                recurrent_chunkwise_size=self.causal_size if self.causal else 0,
                 no_output_layer=True,
                 decoder_normalize_before=True,
 
-                rotary_embedding_base=self.rotary_embedding_base, # 10000
+                rotary_embedding_base=10000
             )
 
             if n_experts > 1:
@@ -486,7 +482,7 @@ class Base(nn.Module):
                 activation_fn="gelu",
                 use_glu=False, # self.version >= 3,
 
-                recurrent_chunk_size=self.recurrent_chunk_size if self.causal else 0,
+                recurrent_chunk_size=self.causal_size if self.causal else 0,
                 decoder_normalize_before=True,
 
                 deepnorm=False,
@@ -710,8 +706,9 @@ class Base(nn.Module):
                 if quant_levels is not None and quant_levels[i] > 0:
                     continue
 
-                logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
-                target_list[i] = target_list[i][..., 1:] # predicts token n + 1
+                l = self.causal_size
+                logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
+                target_list[i] = target_list[i][..., l:] # predicts token n + 1
 
             # see comments for the split-loss calc cross_entropy call
             if False:
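The loss path now shifts by causal_size instead of a hard-coded 1, so the same code serves any future setup that emits more than one token per step. A toy check, with assumed shapes, that the alignment still holds, i.e. the logit at position n is scored against the token at position n + l:

import torch
import torch.nn.functional as F

l = 1                                    # self.causal_size for the AR
T, V = 8, 1025                           # assumed: sequence length, 1024 codes + stop
logits  = torch.randn(T, V)
targets = torch.arange(T)                # stand-in token ids 0..T-1

shifted_logits  = logits[..., :-l, :]    # predictions for positions 0 .. T-l-1
shifted_targets = targets[..., l:]       # tokens at positions l .. T-1

assert shifted_logits.shape[-2] == shifted_targets.shape[-1] == T - l
assert shifted_targets[0].item() == l    # position 0 is now scored against token l
loss = F.cross_entropy(shifted_logits, shifted_targets)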
@@ -769,8 +766,9 @@ class Base(nn.Module):
             # for the AR, shift sequence so that it predicts the next token
             # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
             if quant_level is None or quant_level == 0:
-                logit = logit[..., :-1, :] # get all but the final logit
-                input = input[..., 1:] # shift sequence to the right by one
+                l = self.causal_size
+                logit = logit[..., :-l, :]
+                input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
 
             if name not in info:
                 info[name] = {
@@ -803,15 +801,6 @@ class Base(nn.Module):
             self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
             self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
 
-        # accuracy sometimes breaks for mamba
-
-        # to-do: compute loss per individual batch to scale per RVQ level
-        """
-        rvq_loss_factor = self.loss_factor("quant")
-        if isinstance( rvq_loss_factor, list ):
-            ...
-        """
-
     def forward(
         self,
         inputs: list,
@@ -898,12 +887,8 @@ class Base(nn.Module):
         if quant_levels is not None:
             logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
         # (AR chunkwise) return the last chunkwise piece
-        elif self.causal and self.recurrent_chunk_size > 0:
-            logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
-        # (AR) return just the last code
-        # Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
-        else:
-            logits = [ logit[-1:] for logit in logits ]
+        elif self.causal:
+            logits = [ logit[-self.causal_size:] for logit in logits ]
 
         devices = [ logit.device for logit in logits ]
         logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
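Beyond the rename, this hunk fixes an actual bug: the old chunkwise branch zipped over self.recurrent_chunk_size, an int, which would have raised a TypeError had it ever been reached. The causal path now uniformly keeps the last causal_size positions, which reduces to the old take-the-last-logit behavior when causal_size == 1. A sketch with toy tensors and assumed shapes:

import torch

causal_size = 1                                           # AR default
logits = [ torch.randn(10, 1025), torch.randn(7, 1025) ]  # ragged batch of two

# keep only the positions that predict new tokens
sampled = [ logit[-causal_size:] for logit in logits ]
assert all( s.shape == (causal_size, 1025) for s in sampled )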
@@ -315,9 +315,11 @@ def example_usage():
 
     engine = Engine(model=model, optimizer=optimizer)
 
+    """
     torch.save( {
         'module': model.state_dict()
     }, f"./data/{cfg.model.arch_type}.pth" )
+    """
 
     print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
@@ -400,9 +402,11 @@ def example_usage():
 
         tqdm.write(f"{stats}")
 
+    """
     torch.save( {
         'module': model.state_dict()
     }, f"./data/{cfg.model.arch_type}.pth" )
+    """
 
     #sample("init", 5)
     train()