nasty hotfix for transformer's Mixtral throwing an error when batch sizes > 1
This commit is contained in:
parent
e799665759
commit
cce929e136
2
setup.py
2
setup.py
|
@ -48,7 +48,7 @@ setup(
|
|||
"omegaconf==2.0.6",
|
||||
"tqdm>=4.64.1",
|
||||
"humanize>=4.4.0",
|
||||
"transformer>4.36.0",
|
||||
"transformers>4.37.0",
|
||||
|
||||
"pandas>=1.5.0",
|
||||
"torch>=1.13.0",
|
||||
|
|
|
@ -36,7 +36,47 @@ except Exception as e:
|
|||
|
||||
try:
|
||||
from transformers import MixtralModel, MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
|
||||
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
|
||||
|
||||
# This is required because batch sizes > 1 throws errors
|
||||
def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
""" """
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
|
||||
for expert_idx in range(self.num_experts):
|
||||
expert_layer = self.experts[expert_idx]
|
||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
||||
|
||||
if top_x.shape[0] == 0:
|
||||
continue
|
||||
top_x_list = top_x.tolist()
|
||||
idx_list = idx.tolist()
|
||||
|
||||
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
|
||||
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
|
||||
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
|
||||
MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward
|
||||
|
||||
except Exception as e:
|
||||
print("Error importing `mixtral` arch:", e)
|
||||
|
||||
|
@ -463,7 +503,7 @@ class Base(nn.Module):
|
|||
)
|
||||
self.stats = dict(
|
||||
acc = self.accuracy_metric( inputs, target ),
|
||||
precision = self.precision_metric( inputs, target ),
|
||||
# precision = self.precision_metric( inputs, target ),
|
||||
)
|
||||
|
||||
if aux_loss is not None:
|
||||
|
@ -518,7 +558,7 @@ class Base(nn.Module):
|
|||
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
|
||||
|
||||
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
|
||||
# epsilon float comparison because I don't trust Python
|
||||
# epsilon float comparison because I don't trust Python
|
||||
if abs(temperature - min_temperature) >= 0.001:
|
||||
logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
|
||||
else:
|
||||
|
|
|
@ -30,7 +30,7 @@ from .distributed import (
|
|||
|
||||
from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines
|
||||
|
||||
from .utils import to_device, do_gc
|
||||
from .utils import to_device, do_gc, truncate_json
|
||||
from ..utils import wrapper as ml
|
||||
from ..data import get_phone_symmap # should decouple from this trainer script
|
||||
|
||||
|
@ -174,7 +174,8 @@ def train(
|
|||
|
||||
elapsed_time = stats.get("elapsed_time", 0)
|
||||
metrics = json.dumps(stats)
|
||||
_logger.info(f"Training Metrics: {metrics}.")
|
||||
|
||||
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
|
||||
|
||||
command = _non_blocking_input()
|
||||
|
||||
|
|
|
@ -19,6 +19,13 @@ from typing import Callable, TypeVar, overload
|
|||
|
||||
T = TypeVar("T")
|
||||
|
||||
def truncate_json( str ):
|
||||
|
||||
def fun( match ):
|
||||
return "{:.4f}".format(float(match.group()))
|
||||
|
||||
return re.sub(r"\d+\.\d{8,}", fun, str)
|
||||
|
||||
def do_gc():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
|
Loading…
Reference in New Issue
Block a user