From c93d5863fd799920166bed85c09e648c250b0b51 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 4 Jun 2024 00:07:00 -0500
Subject: [PATCH] fixes

---
 vall_e/data.py                          |  10 +-
 vall_e/ext/retnet_hf/modeling_retnet.py |   1 +
 vall_e/models/__init__.py               |   4 +-
 vall_e/models/ar_nar.py                 |  13 +--
 vall_e/models/experimental.py           |  50 +++++++--
 vall_e/models/retnet_hf.py              | 121 +++++++++++-----------
 vall_e/models/retnet_ts.py              | 128 ++++++++++++------------
 vall_e/train.py                         |   4 +-
 8 files changed, 178 insertions(+), 153 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 734d178..aa6e1af 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -44,7 +44,7 @@ def fold_inputs(
 	text_tokens = 256,
 	audio_tokens = 1024,
 
-	audio_rvq_levels = cfg.model.prom_levels
+	audio_rvq_levels = cfg.model.max_levels
 ):
 	def _create_mask(l, device):
 		seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -107,7 +107,7 @@ def unfold_outputs(
 	text_tokens = 256,
 	audio_tokens = 1024,
 
-	audio_rvq_levels = cfg.model.prom_levels
+	audio_rvq_levels = cfg.model.max_levels
 ):
 	device = output_ids.device
 	batch_size = output_ids.shape[0]
@@ -139,7 +139,7 @@ def unfold_outputs(
 					bins[rvq].append( prom_list[i][pos] )
 			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
 			bins = bins[:nearest]
-			prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+			prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
 
 		resp_len = len(resp_list[i])
 
@@ -152,9 +152,9 @@ def unfold_outputs(
 					bins[rvq].append( resp_list[i][pos] )
 			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
 			bins = bins[:nearest]
-			resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+			resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
 
-		text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
+		text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)
 
 	return dict(
 		text_list=text_list,
diff --git a/vall_e/ext/retnet_hf/modeling_retnet.py b/vall_e/ext/retnet_hf/modeling_retnet.py
index 4a7fcdc..1f9730f 100644
--- a/vall_e/ext/retnet_hf/modeling_retnet.py
+++ b/vall_e/ext/retnet_hf/modeling_retnet.py
@@ -963,6 +963,7 @@ class RetNetModel(RetNetPreTrainedModel):
                     retention_mask,
                     forward_impl,
                     past_key_value,
+                    use_reentrant=True,
                 )
             else:
                 layer_outputs = layer(
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index 97309ae..085bfbf 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -1,10 +1,9 @@
-from .ar_nar import AR_NAR
-from .experimental import Model as Experimental
 
 def get_model(cfg, training=True):
 	name = cfg.name
 
 	if not cfg.experimental:
+		from .ar_nar import AR_NAR
 		model = AR_NAR(
 			n_tokens=cfg.tokens,
 			d_model=cfg.dim,
@@ -21,6 +20,7 @@ def get_model(cfg, training=True):
 		)
 		model._cfg = cfg
 	else:
+		from .experimental import Model as Experimental
 		model = Experimental(
 			d_model=cfg.dim,
 			n_layers=cfg.layers,
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 57e5af1..f98bd4f 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -386,7 +386,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 200
+	steps = 50
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@@ -448,7 +448,7 @@ def example_usage():
 
 	torch.save( {
 		'module': model.state_dict()
-	}, "./data/test.pth" )
+	}, f"./data/{cfg.model.arch_type}.pth" )
 
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
@@ -459,16 +459,11 @@ def example_usage():
 
 		engine.eval()
 		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
-
-		if cfg.audio_backend != "dac":
-			for i, o in enumerate(resps_list):
-				_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
-
 		resps_list = [r.unsqueeze(-1) for r in resps_list]
 		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
 
 		for i, o in enumerate(resps_list):
-			_ = decode_to_file(o, f"data/ar+nar.{i}.{name}.wav", device=device)
+			_ = decode_to_file(o, f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
 
 		unload_model()
 
@@ -484,7 +479,7 @@ def example_usage():
 
 		torch.save( {
 			'module': model.state_dict()
-		}, "./data/test.pth" )
+		}, f"./data/{cfg.model.arch_type}.pth" )
 
 	sample("init", 5)
 	train()
diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
index 2875ed0..98dc744 100644
--- a/vall_e/models/experimental.py
+++ b/vall_e/models/experimental.py
@@ -31,6 +31,15 @@ except Exception as e:
 	print("Error importing `llama` arch:", e)
 	pass
 
+try:
+	from .retnet_hf import RetNetConfig
+	from ..ext.retnet_hf.modeling_retnet import RetNetForCausalLM
+
+	AVAILABLE_ARCHES.append("retnet")
+except Exception as e:
+	print("Error importing `retnet` arch:", e)
+	pass
+
 try:
 	from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
 
@@ -75,6 +84,8 @@ if SELECTED_ARCH == "mamba":
 	LlmArchClass = MambaLMHeadModel
 elif SELECTED_ARCH == "llama":
 	LlmArchClass = LlamaForCausalLM
+elif SELECTED_ARCH == "retnet":
+	LlmArchClass = RetNetForCausalLM
 else:
 	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
 
@@ -92,18 +103,19 @@ class Model(LlmArchClass):
 
 		hf_attention = config.attention if config is not None else None
 		gradient_checkpointing = config.gradient_checkpointing if config is not None else True
 
+		vocab_size = 256 + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
 		if SELECTED_ARCH == "llama":
 			super().__init__(config=LlamaConfig(
-				vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+				vocab_size=vocab_size,
 				hidden_size=d_model,
-				max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.prom_levels * 60, # max-length of 60 seconds
+				max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.max_levels * 60, # max-length of 60 seconds
 				intermediate_size=d_model*4,
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
 				attention_dropout=p_dropout,
 				num_key_value_heads=n_heads,
-				sliding_window=cfg.dataset.frames_per_second * cfg.model.prom_levels * 12,
+				sliding_window=cfg.dataset.frames_per_second * cfg.model.max_levels * 12,
 				hidden_act="gelu",
 				is_encoder_decoder=False,
 				is_decoder=True,
@@ -114,9 +126,31 @@ class Model(LlmArchClass):
 			self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
 				use_reentrant=False
 			))
+		elif SELECTED_ARCH == "retnet":
+			super().__init__(config=RetNetConfig(
+				vocab_size=vocab_size,
+				decoder_embed_dim=d_model,
+				decoder_value_embed_dim =d_model * 2,
+				decoder_retention_heads=n_heads,
+				decoder_ffn_embed_dim=d_model * 4,
+				decoder_layers=n_layers,
+				dropout=p_dropout,
+				checkpoint_activations=gradient_checkpointing,
+				activation_fn="gelu",
+				use_layernorm=False,
+				use_biases=False,
+				use_glu=True,
+
+				#chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
+				#recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
+				#no_output_layer=True,
+				#rotary_embedding_base=self.rotary_embedding_base, # 10000
+
+				decoder_normalize_before=True,
+			))
 		elif SELECTED_ARCH == "mamba":
 			super().__init__(config=MambaConfig(
-				vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+				vocab_size=vocab_size,
 				d_model=d_model,
 				n_layer=n_layers*2,
 				#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
@@ -132,15 +166,15 @@ class Model(LlmArchClass):
 	):
 		output = super().forward(*args, **kwargs)
 
-		if SELECTED_ARCH == "llama":
+		if SELECTED_ARCH in ["llama", "retnet"]:
 			if output.loss is not None:
 				self.loss = dict(
 					nll = output.loss,
 				)
 		elif SELECTED_ARCH == "mamba":
 			if "labels" in kwargs:
-				logits = output.logits
 				labels = kwargs.pop("labels")
+				logits = output.logits
 
 				# Shift so that tokens < n predict n
 				shift_logits = logits[..., :-1, :].contiguous()
@@ -183,7 +217,7 @@ def example_usage():
 
 	def _load_quants(path) -> Tensor:
 		qnt = np.load(path, allow_pickle=True)[()]
-		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
+		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.max_levels, :].t().to(torch.int16)
 
 	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
@@ -278,7 +312,7 @@ def example_usage():
 	print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
 	@torch.inference_mode()
-	def sample( name, steps=cfg.model.prom_levels*cfg.dataset.frames_per_second*60 ):
+	def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*60 ):
 		engine.eval()
 		if SELECTED_ARCH == "mamba":
 			output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3)
diff --git a/vall_e/models/retnet_hf.py b/vall_e/models/retnet_hf.py
index 12e0589..91ef4a9 100644
--- a/vall_e/models/retnet_hf.py
+++ b/vall_e/models/retnet_hf.py
@@ -1,6 +1,6 @@
 # https://github.com/syncdoth/RetNet/
 from ..ext.retnet_hf.configuration_retnet import RetNetConfig
-from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder
+from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder, RetNetForCausalLM
 
 # things we're overriding or required to override
 from ..ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
@@ -32,77 +32,74 @@ def FeedForwardNetwork_init(
 
 FeedForwardNetwork.__init__ = FeedForwardNetwork_init
 
-# removes embed_tokens
 def RetNetModel_init(
-		self,
-		config: RetNetConfig,
-		embed_tokens: torch.nn.Embedding = None,
-		tensor_parallel: bool = False,
-	):
-		super(RetNetDecoder, self).__init__(config)
-		self.config = config
+	self,
+	config: RetNetConfig,
+	embed_tokens: torch.nn.Embedding = None,
+	tensor_parallel: bool = False,
+):
+	super(RetNetDecoder, self).__init__(config)
+	self.config = config
 
-		self.dropout_module = torch.nn.Dropout(config.dropout)
+	self.dropout_module = torch.nn.Dropout(config.dropout)
 
-		self.embed_dim = config.decoder_embed_dim
-		self.embed_scale = (
-			1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
+	self.embed_dim = config.decoder_embed_dim
+	self.embed_scale = (
+		1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
+	)
+
+	if embed_tokens is None and config.vocab_size:
+		embed_tokens = torch.nn.Embedding(
+			config.vocab_size, config.decoder_embed_dim, config.pad_token_id
+		)
+	self.embed_tokens = embed_tokens
+
+	if config.layernorm_embedding:
+		self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
+	else:
+		self.layernorm_embedding = None
+
+	self.layers = torch.nn.ModuleList([])
+
+	for i in range(config.decoder_layers):
+		self.layers.append(
+			RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
 		)
-		"""
-		if embed_tokens is None:
-			embed_tokens = torch.nn.Embedding(
-				config.vocab_size, config.decoder_embed_dim, config.pad_token_id
-			)
-		"""
-		self.embed_tokens = None
+	self.decoder_layers = len(self.layers)
 
-		if config.layernorm_embedding:
-			self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
-		else:
-			self.layernorm_embedding = None
+	if config.decoder_normalize_before:
+		self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
+	else:
+		self.layer_norm = None
 
-		self.layers = torch.nn.ModuleList([])
+	self.retnet_rel_pos = RetNetRelPos(config)
+	self.recurrent_chunk_size = config.recurrent_chunk_size
 
-		for i in range(config.decoder_layers):
-			self.layers.append(
-				RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
-			)
+	if config.deepnorm:
+		init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
+		for name, p in self.named_parameters():
+			if (
+				"fc1" in name
+				or "fc2" in name
+				or "out_proj" in name
+				or "v_proj" in name
+			):
+				p.data.div_(init_scale)
 
-		self.decoder_layers = len(self.layers)
+	if config.subln and not config.use_glu:
+		init_scale = math.sqrt(math.log(config.decoder_layers * 2))
+		for name, p in self.named_parameters():
+			if (
+				"fc1" in name
+				or "fc2" in name
+				or "out_proj" in name
+				or "v_proj" in name
+			):
+				p.data.mul_(init_scale)
 
-		if config.decoder_normalize_before:
-			self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
-		else:
-			self.layer_norm = None
-
-		self.retnet_rel_pos = RetNetRelPos(config)
-		self.recurrent_chunk_size = config.recurrent_chunk_size
-
-		if config.deepnorm:
-			init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
-			for name, p in self.named_parameters():
-				if (
-					"fc1" in name
-					or "fc2" in name
-					or "out_proj" in name
-					or "v_proj" in name
-				):
-					p.data.div_(init_scale)
-
-		if config.subln and not config.use_glu:
-			init_scale = math.sqrt(math.log(config.decoder_layers * 2))
-			for name, p in self.named_parameters():
-				if (
-					"fc1" in name
-					or "fc2" in name
-					or "out_proj" in name
-					or "v_proj" in name
-				):
-					p.data.mul_(init_scale)
-
-		self.gradient_checkpointing = True
-		self.post_init()
+	self.gradient_checkpointing = True
+	self.post_init()
 
 RetNetDecoder.__init__ = RetNetModel_init
 
diff --git a/vall_e/models/retnet_ts.py b/vall_e/models/retnet_ts.py
index 76a2aea..cacc367 100644
--- a/vall_e/models/retnet_ts.py
+++ b/vall_e/models/retnet_ts.py
@@ -36,83 +36,81 @@ FeedForwardNetwork.__init__ = FeedForwardNetwork_init
 
 # removes embed_tokens
 def RetNetModel_init(
-		self, config, embed_tokens=None, output_projection=None, **kwargs
-	):
-		super(RetNetDecoder, self).__init__(**kwargs)
-		self.config = config
+	self, config, embed_tokens=None, output_projection=None, **kwargs
+):
+	super(RetNetDecoder, self).__init__(**kwargs)
+	self.config = config
 
-		self.dropout_module = torch.nn.Dropout(config.dropout)
+	self.dropout_module = torch.nn.Dropout(config.dropout)
 
-		self.embed_dim = config.decoder_embed_dim
-		self.embed_scale = (
-			1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
+	self.embed_dim = config.decoder_embed_dim
+	self.embed_scale = (
+		1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
+	)
+
+	if embed_tokens is None and config.vocab_size:
+		embed_tokens = torch.nn.Embedding(
+			config.vocab_size, config.decoder_embed_dim, config.pad_token_id
 		)
+	self.embed_tokens = embed_tokens
+	if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
+		self.output_projection = self.build_output_projection(config)
+	else:
+		self.output_projection = output_projection
+
+	if config.layernorm_embedding:
+		self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
+	else:
+		self.layernorm_embedding = None
+
+	self.layers = torch.nn.ModuleList([])
+
+	for i in range(config.decoder_layers):
+		layer = self.build_decoder_layer(
+			config,
+			depth=i,
+		)
 		"""
-		if embed_tokens is None:
-			embed_tokens = torch.nn.Embedding(
-				config.vocab_size, config.decoder_embed_dim, config.pad_token_id
-			)
+		if config.checkpoint_activations:
+			layer = checkpoint_wrapper(layer)
 		"""
-		self.embed_tokens = None
+		self.layers.append(layer)
 
-		if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
-			self.output_projection = self.build_output_projection(config)
-		else:
-			self.output_projection = output_projection
+	self.num_layers = len(self.layers)
 
-		if config.layernorm_embedding:
-			self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
-		else:
-			self.layernorm_embedding = None
+	if config.decoder_normalize_before:
+		self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
+	else:
+		self.layer_norm = None
 
-		self.layers = torch.nn.ModuleList([])
+	self.retnet_rel_pos = RetNetRelPos(config)
+	self.chunkwise_recurrent = config.chunkwise_recurrent
+	self.recurrent_chunk_size = config.recurrent_chunk_size
 
-		for i in range(config.decoder_layers):
-			layer = self.build_decoder_layer(
-				config,
-				depth=i,
-			)
-			"""
-			if config.checkpoint_activations:
-				layer = checkpoint_wrapper(layer)
-			"""
-			self.layers.append(layer)
+	if config.deepnorm:
+		init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
+		for name, p in self.named_parameters():
+			if (
+				"fc1" in name
+				or "fc2" in name
+				or "out_proj" in name
+				or "v_proj" in name
+			):
+				p.data.div_(init_scale)
 
-		self.num_layers = len(self.layers)
+	if config.subln and not config.use_glu:
+		init_scale = math.sqrt(math.log(config.decoder_layers * 2))
+		for name, p in self.named_parameters():
+			if (
+				"fc1" in name
+				or "fc2" in name
+				or "out_proj" in name
+				or "v_proj" in name
+			):
+				p.data.mul_(init_scale)
 
-		if config.decoder_normalize_before:
-			self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
-		else:
-			self.layer_norm = None
-
-		self.retnet_rel_pos = RetNetRelPos(config)
-		self.chunkwise_recurrent = config.chunkwise_recurrent
-		self.recurrent_chunk_size = config.recurrent_chunk_size
-
-		if config.deepnorm:
-			init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
-			for name, p in self.named_parameters():
-				if (
-					"fc1" in name
-					or "fc2" in name
-					or "out_proj" in name
-					or "v_proj" in name
-				):
-					p.data.div_(init_scale)
-
-		if config.subln and not config.use_glu:
-			init_scale = math.sqrt(math.log(config.decoder_layers * 2))
-			for name, p in self.named_parameters():
-				if (
-					"fc1" in name
-					or "fc2" in name
-					or "out_proj" in name
-					or "v_proj" in name
-				):
-					p.data.mul_(init_scale)
-
-		self.gradient_checkpointing = True
+	self.gradient_checkpointing = True
 
 RetNetDecoder.__init__ = RetNetModel_init
 
diff --git a/vall_e/train.py b/vall_e/train.py
index 2e9a1c7..d03d87d 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -109,9 +109,9 @@ def run_eval(engines, eval_name, dl):
 			if engine.hyper_config.experimental:
 				input_ids, attention_mask = fold_inputs(
 					text_list=batch["text"],
-					proms_list=batch["proms"],
+					prom_list=batch["proms"],
 				)
-				output = engine.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
+				output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
 				resps_list = unfold_outputs( output )["resp_list"]
 			else:
 				resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)