From 90ecf3da7dc1db5f394f3a07bddba89513baf1d9 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 28 Jun 2024 22:44:42 -0500 Subject: [PATCH] more backporting --- tortoise_tts/data.py | 2 +- tortoise_tts/engines/base.py | 3 ++- tortoise_tts/engines/deepspeed.py | 9 +++++---- tortoise_tts/train.py | 1 + 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tortoise_tts/data.py b/tortoise_tts/data.py index 1e58c65..0deca4a 100755 --- a/tortoise_tts/data.py +++ b/tortoise_tts/data.py @@ -335,7 +335,7 @@ class Dataset(_Dataset): if self.sampler_type == "path": if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0: - self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size ) + self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size ) else: self.sampler = OrderedSampler( len(self) ) self.samplers = {} diff --git a/tortoise_tts/engines/base.py b/tortoise_tts/engines/base.py index 09bc8f9..bb2f532 100755 --- a/tortoise_tts/engines/base.py +++ b/tortoise_tts/engines/base.py @@ -71,6 +71,7 @@ class Engine(): self.max_nan_losses = 8 self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None + self.current_batch_size = 0 self._global_grad_norm = None def freeze(self, freeze_all=True): @@ -108,7 +109,7 @@ class Engine(): @property def batch_size(self): - return cfg.hyperparameters.batch_size + return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size @property def gradient_accumulation_steps(self): diff --git a/tortoise_tts/engines/deepspeed.py b/tortoise_tts/engines/deepspeed.py index 8196a8f..5d3e2b1 100755 --- a/tortoise_tts/engines/deepspeed.py +++ b/tortoise_tts/engines/deepspeed.py @@ -61,11 +61,9 @@ class Engine(DeepSpeedEngine): self.tokens_processed = stats["tokens_processed"] self.max_nan_losses = 8 + self.current_batch_size = 0 def freeze(self, freeze_all=True): - if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"): - raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None") - # freeze non-LoRA params if requested if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None: for name, param in self.module.named_parameters(): @@ -75,6 +73,9 @@ class Engine(DeepSpeedEngine): self._frozen_params.add(param) return + if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"): + raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None") + for name, param in self.module.named_parameters(): if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params): param.requires_grad_(False) @@ -99,7 +100,7 @@ class Engine(DeepSpeedEngine): @property def batch_size(self): - return cfg.hyperparameters.batch_size + return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size def gather_attribute(self, *args, **kwargs): return gather_attribute(self.module, *args, **kwargs) diff --git a/tortoise_tts/train.py b/tortoise_tts/train.py index 444cfd8..5ea5c95 100755 --- a/tortoise_tts/train.py +++ b/tortoise_tts/train.py @@ -47,6 +47,7 @@ def train_feeder(engine, batch): engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths) + engine.current_batch_size = batch_size losses = engine.gather_attribute("loss") stat = engine.gather_attribute("stats")