local training backend should be a bit more aware of variable batch sizes, maybe

This commit is contained in:
mrq 2024-06-28 22:39:05 -05:00
parent 83075c1505
commit 1a392b69f6
3 changed files with 8 additions and 4 deletions

View File

@@ -71,6 +71,7 @@ class Engine():
self.max_nan_losses = 8 self.max_nan_losses = 8
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
self.current_batch_size = 0
self._global_grad_norm = None self._global_grad_norm = None
def freeze(self, freeze_all=True): def freeze(self, freeze_all=True):
@@ -108,7 +109,7 @@ class Engine():
@property @property
def batch_size(self): def batch_size(self):
return cfg.hyperparameters.batch_size return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
@property @property
def gradient_accumulation_steps(self): def gradient_accumulation_steps(self):
@@ -176,7 +177,7 @@ class Engine():
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step'] self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples'] self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed'] self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
self.module.load_state_dict(state['module']) self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state

View File

@@ -61,6 +61,7 @@ class Engine(DeepSpeedEngine):
self.tokens_processed = stats["tokens_processed"] self.tokens_processed = stats["tokens_processed"]
self.max_nan_losses = 8 self.max_nan_losses = 8
self.current_batch_size = 0
def freeze(self, freeze_all=True): def freeze(self, freeze_all=True):
# freeze non-LoRA params if requested # freeze non-LoRA params if requested
@@ -99,7 +100,7 @@ class Engine(DeepSpeedEngine):
@property @property
def batch_size(self): def batch_size(self):
return cfg.hyperparameters.batch_size return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
def gather_attribute(self, *args, **kwargs): def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs) return gather_attribute(self.module, *args, **kwargs)

View File

@@ -28,8 +28,10 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch): def train_feeder(engine, batch):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
if engine.hyper_config.experimental:
batch_size = len(batch["text"]) batch_size = len(batch["text"])
engine.current_batch_size = batch_size
if engine.hyper_config.experimental:
if cfg.model.interleave: if cfg.model.interleave:
quant_levels = 0 quant_levels = 0
resps_list = [ resp for resp in batch["resps"] ] resps_list = [ resp for resp in batch["resps"] ]