diff --git a/setup.py b/setup.py index 2b48f40..b5ece1a 100755 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ setup( "omegaconf==2.0.6", "tqdm>=4.64.1", "humanize>=4.4.0", - "transformer>4.36.0", + "transformers>4.37.0", "pandas>=1.5.0", "torch>=1.13.0", diff --git a/vall_e/models/base.py b/vall_e/models/base.py index f678a2d..e872687 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -36,7 +36,47 @@ except Exception as e: try: from transformers import MixtralModel, MixtralConfig - from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func + from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock + + # This is required because batch sizes > 1 throws errors + def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_dim) # was view() + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward + MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward + except Exception as e: print("Error importing `mixtral` arch:", e) @@ -463,7 +503,7 @@ class Base(nn.Module): ) self.stats = dict( acc = self.accuracy_metric( inputs, target ), - precision = self.precision_metric( inputs, target ), + # precision = self.precision_metric( inputs, target ), ) if aux_loss is not None: @@ -518,7 +558,7 @@ class Base(nn.Module): logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ] # trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature - # epsilon float comparison because I don't trust Python + # epsilon float comparison because I don't trust Python if abs(temperature - min_temperature) >= 0.001: logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ] else: diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index d1afd84..100c7a5 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -30,7 +30,7 @@ from .distributed import ( from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines -from .utils import to_device, do_gc +from .utils import to_device, do_gc, truncate_json from ..utils import wrapper as ml from ..data import get_phone_symmap # should decouple from this trainer script @@ -174,7 +174,8 @@ def train( elapsed_time = stats.get("elapsed_time", 0) metrics = json.dumps(stats) - _logger.info(f"Training Metrics: {metrics}.") + + _logger.info(f"Training Metrics: {truncate_json(metrics)}.") command = _non_blocking_input() diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index 988f595..e92239a 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -19,6 +19,13 @@ from typing import Callable, TypeVar, overload T = TypeVar("T") +def truncate_json( str ): + + def fun( match ): + return "{:.4f}".format(float(match.group())) + + return re.sub(r"\d+\.\d{8,}", fun, str) + def do_gc(): gc.collect() torch.cuda.empty_cache()