diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 7ec6c8c..bb2f532 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -71,6 +71,7 @@ class Engine():
 		self.max_nan_losses = 8
 		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
 
+		self.current_batch_size = 0
 		self._global_grad_norm = None
 
 	def freeze(self, freeze_all=True):
@@ -108,7 +109,7 @@ class Engine():
 
 	@property
 	def batch_size(self):
-		return cfg.hyperparameters.batch_size
+		return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
 
 	@property
 	def gradient_accumulation_steps(self):
@@ -176,7 +177,7 @@ class Engine():
 		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
 		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
 		self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
-		self.module.load_state_dict(state['module'])
+		self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
 
 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
 		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index e6d3640..5d3e2b1 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -61,6 +61,7 @@ class Engine(DeepSpeedEngine):
 		self.tokens_processed = stats["tokens_processed"]
 
 		self.max_nan_losses = 8
+		self.current_batch_size = 0
 
 	def freeze(self, freeze_all=True):
 		# freeze non-LoRA params if requested
@@ -99,7 +100,7 @@ class Engine(DeepSpeedEngine):
 
 	@property
 	def batch_size(self):
-		return cfg.hyperparameters.batch_size
+		return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
 
 	def gather_attribute(self, *args, **kwargs):
 		return gather_attribute(self.module, *args, **kwargs)
diff --git a/vall_e/train.py b/vall_e/train.py
index 6958066..5a21be2 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -28,8 +28,10 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
 
 def train_feeder(engine, batch):
 	with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+		batch_size = len(batch["text"])
+		engine.current_batch_size = batch_size
+
 		if engine.hyper_config.experimental:
-			batch_size = len(batch["text"])
 			if cfg.model.interleave:
 				quant_levels = 0
 				resps_list = [ resp for resp in batch["resps"] ]
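
The core of this patch is the batch_size fallback: the train feeder records the size of the batch it actually received, and both engine classes prefer that observed value over the static config value once it is available. Below is a minimal standalone sketch of that pattern, not repo code: HyperParams and the usage at the bottom are illustrative stand-ins for cfg.hyperparameters and the training loop.

    from dataclasses import dataclass

    @dataclass
    class HyperParams:
        batch_size: int = 16  # static value from the training config

    class Engine:
        def __init__(self, hyperparameters: HyperParams):
            self.hyperparameters = hyperparameters
            # 0 means "no batch observed yet"; updated by the feeder each step.
            self.current_batch_size = 0

        @property
        def batch_size(self) -> int:
            # Prefer the size of the batch actually fed this step (it can differ
            # from the configured size, e.g. a short trailing batch or a sampler
            # that packs by duration); fall back to the config before step one.
            return self.current_batch_size if self.current_batch_size > 0 else self.hyperparameters.batch_size

    def train_feeder(engine: Engine, batch: dict):
        # Record the real per-step batch size, mirroring train.py above.
        engine.current_batch_size = len(batch["text"])

    engine = Engine(HyperParams())
    assert engine.batch_size == 16          # falls back to the configured size
    train_feeder(engine, {"text": ["a", "b", "c"]})
    assert engine.batch_size == 3           # now reflects the observed batch

Initializing current_batch_size to 0 (a sentinel meaning "unset") keeps anything that reads engine.batch_size before the first training step, such as startup logging, on the configured value, while per-step statistics use the true size afterward.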