local training backend should be a bit more aware of variable batch sizes, maybe
This commit is contained in:
parent
83075c1505
commit
1a392b69f6
|
@ -71,6 +71,7 @@ class Engine():
|
||||||
self.max_nan_losses = 8
|
self.max_nan_losses = 8
|
||||||
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
|
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
|
||||||
|
|
||||||
|
self.current_batch_size = 0
|
||||||
self._global_grad_norm = None
|
self._global_grad_norm = None
|
||||||
|
|
||||||
def freeze(self, freeze_all=True):
|
def freeze(self, freeze_all=True):
|
||||||
|
@ -108,7 +109,7 @@ class Engine():
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return cfg.hyperparameters.batch_size
|
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def gradient_accumulation_steps(self):
|
def gradient_accumulation_steps(self):
|
||||||
|
@ -176,7 +177,7 @@ class Engine():
|
||||||
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
|
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
|
||||||
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
|
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
|
||||||
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
|
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
|
||||||
self.module.load_state_dict(state['module'])
|
self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
|
||||||
|
|
||||||
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
|
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
|
||||||
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
|
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
|
||||||
|
|
|
@ -61,6 +61,7 @@ class Engine(DeepSpeedEngine):
|
||||||
self.tokens_processed = stats["tokens_processed"]
|
self.tokens_processed = stats["tokens_processed"]
|
||||||
|
|
||||||
self.max_nan_losses = 8
|
self.max_nan_losses = 8
|
||||||
|
self.current_batch_size = 0
|
||||||
|
|
||||||
def freeze(self, freeze_all=True):
|
def freeze(self, freeze_all=True):
|
||||||
# freeze non-LoRA params if requested
|
# freeze non-LoRA params if requested
|
||||||
|
@ -99,7 +100,7 @@ class Engine(DeepSpeedEngine):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return cfg.hyperparameters.batch_size
|
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
|
||||||
|
|
||||||
def gather_attribute(self, *args, **kwargs):
|
def gather_attribute(self, *args, **kwargs):
|
||||||
return gather_attribute(self.module, *args, **kwargs)
|
return gather_attribute(self.module, *args, **kwargs)
|
||||||
|
|
|
@ -28,8 +28,10 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
|
||||||
|
|
||||||
def train_feeder(engine, batch):
|
def train_feeder(engine, batch):
|
||||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||||
if engine.hyper_config.experimental:
|
|
||||||
batch_size = len(batch["text"])
|
batch_size = len(batch["text"])
|
||||||
|
engine.current_batch_size = batch_size
|
||||||
|
|
||||||
|
if engine.hyper_config.experimental:
|
||||||
if cfg.model.interleave:
|
if cfg.model.interleave:
|
||||||
quant_levels = 0
|
quant_levels = 0
|
||||||
resps_list = [ resp for resp in batch["resps"] ]
|
resps_list = [ resp for resp in batch["resps"] ]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user