fixes

parent 186b93a77e
commit c93d5863fd
@@ -44,7 +44,7 @@ def fold_inputs(
     text_tokens = 256,
     audio_tokens = 1024,
-    audio_rvq_levels = cfg.model.prom_levels
+    audio_rvq_levels = cfg.model.max_levels
 ):
 def _create_mask(l, device):
     seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
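Note: `fold_inputs` flattens text and audio codes into one shared vocabulary, which is why the RVQ level count feeds the signature. A toy sketch of the offset arithmetic (assumed values and a hypothetical helper, not the repo's exact code):

text_tokens, audio_tokens = 256, 1024

def fold_audio_token(code, rvq_level):
    # Audio ids start after the 256 text ids; each RVQ level gets its
    # own 1024-wide slice of the flat vocabulary.
    return text_tokens + (rvq_level * audio_tokens) + code

print(fold_audio_token(code=5, rvq_level=0)) # 261
print(fold_audio_token(code=5, rvq_level=2)) # 2309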
@@ -107,7 +107,7 @@ def unfold_outputs(
     text_tokens = 256,
     audio_tokens = 1024,
-    audio_rvq_levels = cfg.model.prom_levels
+    audio_rvq_levels = cfg.model.max_levels
 ):
     device = output_ids.device
     batch_size = output_ids.shape[0]
@@ -139,7 +139,7 @@ def unfold_outputs(
                 bins[rvq].append( prom_list[i][pos] )
             nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
             bins = bins[:nearest]
-            prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+            prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
 
 
         resp_len = len(resp_list[i])
@@ -152,9 +152,9 @@ def unfold_outputs(
                 bins[rvq].append( resp_list[i][pos] )
             nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
             bins = bins[:nearest]
-            resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+            resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
 
-        text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
+        text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)
 
     return dict(
         text_list=text_list,
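Note: the `.to(device=device, ...)` changes above matter because `torch.Tensor(bins)` always allocates on the CPU, so on a CUDA run the unfolded tensors landed on a different device than `output_ids`. A minimal repro of the pattern, with a toy `bins`:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
bins = [[1, 2, 3], [4, 5, 6]] # toy (levels, timesteps) bins

before = torch.Tensor(bins).t().to(dtype=torch.int64) # stays on CPU
after = torch.Tensor(bins).t().to(device=device, dtype=torch.int64) # moved and cast in one call
print(before.device, after.device)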
@@ -963,6 +963,7 @@ class RetNetModel(RetNetPreTrainedModel):
                 retention_mask,
                 forward_impl,
                 past_key_value,
+                use_reentrant=True,
             )
         else:
             layer_outputs = layer(
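Note: recent PyTorch releases warn when `torch.utils.checkpoint.checkpoint` is called without an explicit `use_reentrant`, which is presumably what the added kwarg addresses. A minimal sketch of the call, assuming a stand-in layer:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8) # stand-in for a decoder layer
hidden = torch.randn(2, 8, requires_grad=True)

# Explicit use_reentrant pins the implementation (True = legacy reentrant autograd).
out = checkpoint(layer, hidden, use_reentrant=True)
out.sum().backward()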
@ -1,10 +1,9 @@
|
||||||
from .ar_nar import AR_NAR
|
|
||||||
from .experimental import Model as Experimental
|
|
||||||
|
|
||||||
def get_model(cfg, training=True):
|
def get_model(cfg, training=True):
|
||||||
name = cfg.name
|
name = cfg.name
|
||||||
|
|
||||||
if not cfg.experimental:
|
if not cfg.experimental:
|
||||||
|
from .ar_nar import AR_NAR
|
||||||
model = AR_NAR(
|
model = AR_NAR(
|
||||||
n_tokens=cfg.tokens,
|
n_tokens=cfg.tokens,
|
||||||
d_model=cfg.dim,
|
d_model=cfg.dim,
|
||||||
|
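Note: hoisting the imports into the branches (here and in the next hunk) means only the selected architecture's module gets imported, so a broken optional dependency behind one model can no longer break loading the other. The same deferred-import pattern, sketched with stdlib stand-ins:

def load_parser(kind):
    # The module is imported only when its branch runs, so a missing
    # dependency on the other path never triggers.
    if kind == "json":
        import json
        return json.loads
    import csv
    return csv.reader

print(load_parser("json")('{"ok": true}')) # {'ok': True}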
@@ -21,6 +20,7 @@ def get_model(cfg, training=True):
         )
         model._cfg = cfg
     else:
+        from .experimental import Model as Experimental
         model = Experimental(
             d_model=cfg.dim,
             n_layers=cfg.layers,
@@ -386,7 +386,7 @@ def example_usage():
     """
 
     model = AR_NAR(**kwargs).to(device)
-    steps = 200
+    steps = 50
 
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@@ -448,7 +448,7 @@ def example_usage():
 
     torch.save( {
         'module': model.state_dict()
-    }, "./data/test.pth" )
+    }, f"./data/{cfg.model.arch_type}.pth" )
 
     print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
@@ -459,16 +459,11 @@ def example_usage():
 
         engine.eval()
         resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
-
-        if cfg.audio_backend != "dac":
-            for i, o in enumerate(resps_list):
-                _ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
-
         resps_list = [r.unsqueeze(-1) for r in resps_list]
         resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
 
         for i, o in enumerate(resps_list):
-            _ = decode_to_file(o, f"data/ar+nar.{i}.{name}.wav", device=device)
+            _ = decode_to_file(o, f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
 
         unload_model()
 
@@ -484,7 +479,7 @@ def example_usage():
 
         torch.save( {
             'module': model.state_dict()
-        }, "./data/test.pth" )
+        }, f"./data/{cfg.model.arch_type}.pth" )
 
     sample("init", 5)
     train()
@@ -31,6 +31,15 @@ except Exception as e:
     print("Error importing `llama` arch:", e)
     pass
 
+try:
+    from .retnet_hf import RetNetConfig
+    from ..ext.retnet_hf.modeling_retnet import RetNetForCausalLM
+
+    AVAILABLE_ARCHES.append("retnet")
+except Exception as e:
+    print("Error importing `retnet` arch:", e)
+    pass
+
 try:
     from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
 
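Note: the new `retnet` block follows the file's existing guarded-import convention: attempt the import, append to `AVAILABLE_ARCHES` on success, log and continue on failure. A self-contained sketch of the pattern, with stand-in module names:

AVAILABLE_ARCHES = []

try:
    import json # stand-in for a real optional dependency
    AVAILABLE_ARCHES.append("json")
except Exception as e:
    print("Error importing `json` arch:", e)

try:
    import some_missing_pkg # hypothetical name, will fail
    AVAILABLE_ARCHES.append("missing")
except Exception as e:
    print("Error importing `missing` arch:", e)

print(AVAILABLE_ARCHES) # ['json']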
@@ -75,6 +84,8 @@ if SELECTED_ARCH == "mamba":
     LlmArchClass = MambaLMHeadModel
 elif SELECTED_ARCH == "llama":
     LlmArchClass = LlamaForCausalLM
+elif SELECTED_ARCH == "retnet":
+    LlmArchClass = RetNetForCausalLM
 else:
     raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
 
@@ -92,18 +103,19 @@ class Model(LlmArchClass):
 
         hf_attention = config.attention if config is not None else None
         gradient_checkpointing = config.gradient_checkpointing if config is not None else True
+        vocab_size = 256 + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
 
         if SELECTED_ARCH == "llama":
             super().__init__(config=LlamaConfig(
-                vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+                vocab_size=vocab_size,
                 hidden_size=d_model,
-                max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.prom_levels * 60, # max-length of 60 seconds
+                max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.max_levels * 60, # max-length of 60 seconds
                 intermediate_size=d_model*4,
                 num_hidden_layers=n_layers,
                 num_attention_heads=n_heads,
                 attention_dropout=p_dropout,
                 num_key_value_heads=n_heads,
-                sliding_window=cfg.dataset.frames_per_second * cfg.model.prom_levels * 12,
+                sliding_window=cfg.dataset.frames_per_second * cfg.model.max_levels * 12,
                 hidden_act="gelu",
                 is_encoder_decoder=False,
                 is_decoder=True,
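Note: hoisting `vocab_size` makes the folded-vocabulary layout explicit: 256 text tokens, one 1024-entry codebook per RVQ level for prompts, the same again for responses, plus a stop token. A quick check with an assumed `max_levels = 8`:

text, codebook, max_levels = 256, 1024, 8 # 8 is an example value
vocab_size = text + (codebook * max_levels) + (codebook * max_levels) + 1
print(vocab_size) # 16641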
@@ -114,9 +126,31 @@ class Model(LlmArchClass):
             self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                 use_reentrant=False
             ))
+        elif SELECTED_ARCH == "retnet":
+            super().__init__(config=RetNetConfig(
+                vocab_size=vocab_size,
+                decoder_embed_dim=d_model,
+                decoder_value_embed_dim =d_model * 2,
+                decoder_retention_heads=n_heads,
+                decoder_ffn_embed_dim=d_model * 4,
+                decoder_layers=n_layers,
+                dropout=p_dropout,
+                checkpoint_activations=gradient_checkpointing,
+                activation_fn="gelu",
+                use_layernorm=False,
+                use_biases=False,
+                use_glu=True,
+
+                #chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
+                #recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
+                #no_output_layer=True,
+                #rotary_embedding_base=self.rotary_embedding_base, # 10000
+
+                decoder_normalize_before=True,
+            ))
         elif SELECTED_ARCH == "mamba":
             super().__init__(config=MambaConfig(
-                vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+                vocab_size=vocab_size,
                 d_model=d_model,
                 n_layer=n_layers*2,
                 #ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
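Note: the RetNet sizing above gives the value/gate path twice the model width and the FFN four times it; for an assumed d_model = 1024 that works out to:

d_model = 1024 # example width
print(d_model * 2, d_model * 4) # 2048 (decoder_value_embed_dim), 4096 (decoder_ffn_embed_dim)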
@@ -132,15 +166,15 @@ class Model(LlmArchClass):
     ):
         output = super().forward(*args, **kwargs)
 
-        if SELECTED_ARCH == "llama":
+        if SELECTED_ARCH in ["llama", "retnet"]:
             if output.loss is not None:
                 self.loss = dict(
                     nll = output.loss,
                 )
         elif SELECTED_ARCH == "mamba":
             if "labels" in kwargs:
-                logits = output.logits
                 labels = kwargs.pop("labels")
+                logits = output.logits
 
                 # Shift so that tokens < n predict n
                 shift_logits = logits[..., :-1, :].contiguous()
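Note: moving `logits = output.logits` after the `labels` pop appears cosmetic; the mamba branch still computes the causal-LM loss by hand with the usual shift-by-one, since MambaLMHeadModel does not take labels itself. A self-contained sketch of that loss, assuming (batch, time, vocab) logits and int64 labels:

import torch
import torch.nn.functional as F

b, t, v = 2, 16, 32
logits = torch.randn(b, t, v)
labels = torch.randint(0, v, (b, t))

# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, v), shift_labels.view(-1))
print(loss.item())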
@@ -183,7 +217,7 @@ def example_usage():
 
     def _load_quants(path) -> Tensor:
         qnt = np.load(path, allow_pickle=True)[()]
-        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
+        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.max_levels, :].t().to(torch.int16)
 
     qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
@@ -278,7 +312,7 @@ def example_usage():
     print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
     @torch.inference_mode()
-    def sample( name, steps=cfg.model.prom_levels*cfg.dataset.frames_per_second*60 ):
+    def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*60 ):
         engine.eval()
         if SELECTED_ARCH == "mamba":
             output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3)
@@ -1,6 +1,6 @@
 # https://github.com/syncdoth/RetNet/
 from ..ext.retnet_hf.configuration_retnet import RetNetConfig
-from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder
+from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder, RetNetForCausalLM
 
 # things we're overriding or required to override
 from ..ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
@@ -32,13 +32,12 @@ def FeedForwardNetwork_init(
 
 FeedForwardNetwork.__init__ = FeedForwardNetwork_init
 
-# removes embed_tokens
 def RetNetModel_init(
     self,
     config: RetNetConfig,
     embed_tokens: torch.nn.Embedding = None,
     tensor_parallel: bool = False,
 ):
     super(RetNetDecoder, self).__init__(config)
     self.config = config
 
@@ -49,13 +48,11 @@ def RetNetModel_init(
         1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
     )
 
-    """
-    if embed_tokens is None:
+    if embed_tokens is None and config.vocab_size:
         embed_tokens = torch.nn.Embedding(
             config.vocab_size, config.decoder_embed_dim, config.pad_token_id
         )
-    """
-    self.embed_tokens = None
+    self.embed_tokens = embed_tokens
 
     if config.layernorm_embedding:
         self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
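Note: the patched `RetNetModel_init` now honors a caller-supplied embedding and only builds its own when none is given and the config declares a vocabulary, instead of always forcing `embed_tokens = None`. The guard, as a runnable sketch:

import torch

def build_embeddings(vocab_size, embed_dim, pad_id=0, embed_tokens=None):
    # Only construct a table when none was injected and the config
    # actually declares a vocabulary (vocab_size > 0).
    if embed_tokens is None and vocab_size:
        embed_tokens = torch.nn.Embedding(vocab_size, embed_dim, pad_id)
    return embed_tokens

print(build_embeddings(100, 16)) # Embedding(100, 16, padding_idx=0)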
@@ -37,7 +37,7 @@ FeedForwardNetwork.__init__ = FeedForwardNetwork_init
 # removes embed_tokens
 def RetNetModel_init(
     self, config, embed_tokens=None, output_projection=None, **kwargs
 ):
     super(RetNetDecoder, self).__init__(**kwargs)
     self.config = config
 
@@ -48,13 +48,11 @@ def RetNetModel_init(
         1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
     )
 
-    """
-    if embed_tokens is None:
+    if embed_tokens is None and config.vocab_size:
         embed_tokens = torch.nn.Embedding(
             config.vocab_size, config.decoder_embed_dim, config.pad_token_id
         )
-    """
-    self.embed_tokens = None
+    self.embed_tokens = embed_tokens
 
     if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
         self.output_projection = self.build_output_projection(config)
@@ -109,9 +109,9 @@ def run_eval(engines, eval_name, dl):
         if engine.hyper_config.experimental:
             input_ids, attention_mask = fold_inputs(
                 text_list=batch["text"],
-                proms_list=batch["proms"],
+                prom_list=batch["proms"],
             )
-            output = engine.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
+            output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
             resps_list = unfold_outputs( output )["resp_list"]
         else:
             resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
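Note: two fixes in this eval hunk: the keyword now matches `fold_inputs`' actual parameter name (`prom_list`, not `proms_list`), and generation goes through `engine.module`, since DeepSpeed-style engine wrappers expose the underlying model as `.module` rather than `.model`. A defensive unwrap sketch (hypothetical helper):

def unwrap(engine):
    # Wrapped engines keep the real model at .module; fall back to the
    # object itself for bare nn.Modules.
    return getattr(engine, "module", engine)

# usage: unwrap(engine).generate(input_ids=input_ids, attention_mask=attention_mask, ...)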