local training backend should be a bit more aware of variable batch sizes, maybe
This commit is contained in:
parent
83075c1505
commit
1a392b69f6
|
@ -71,6 +71,7 @@ class Engine():
|
|||
self.max_nan_losses = 8
|
||||
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
|
||||
|
||||
self.current_batch_size = 0
|
||||
self._global_grad_norm = None
|
||||
|
||||
def freeze(self, freeze_all=True):
|
||||
|
@ -108,7 +109,7 @@ class Engine():
|
|||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return cfg.hyperparameters.batch_size
|
||||
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
|
||||
|
||||
@property
|
||||
def gradient_accumulation_steps(self):
|
||||
|
@ -176,7 +177,7 @@ class Engine():
|
|||
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
|
||||
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
|
||||
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
|
||||
self.module.load_state_dict(state['module'])
|
||||
self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
|
||||
|
||||
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
|
||||
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
|
||||
|
|
|
@ -61,6 +61,7 @@ class Engine(DeepSpeedEngine):
|
|||
self.tokens_processed = stats["tokens_processed"]
|
||||
|
||||
self.max_nan_losses = 8
|
||||
self.current_batch_size = 0
|
||||
|
||||
def freeze(self, freeze_all=True):
|
||||
# freeze non-LoRA params if requested
|
||||
|
@ -99,7 +100,7 @@ class Engine(DeepSpeedEngine):
|
|||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return cfg.hyperparameters.batch_size
|
||||
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
|
|
@ -28,8 +28,10 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
|
|||
|
||||
def train_feeder(engine, batch):
|
||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||
batch_size = len(batch["text"])
|
||||
engine.current_batch_size = batch_size
|
||||
|
||||
if engine.hyper_config.experimental:
|
||||
batch_size = len(batch["text"])
|
||||
if cfg.model.interleave:
|
||||
quant_levels = 0
|
||||
resps_list = [ resp for resp in batch["resps"] ]
|
||||
|
|
Loading…
Reference in New Issue
Block a user