more backporting

This commit is contained in:
mrq 2024-06-28 22:44:42 -05:00
parent 43d85d97aa
commit 90ecf3da7d
4 changed files with 9 additions and 6 deletions

View File

@@ -335,7 +335,7 @@ class Dataset(_Dataset):
if self.sampler_type == "path":
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size )
else:
self.sampler = OrderedSampler( len(self) )
self.samplers = {}

View File

@@ -71,6 +71,7 @@ class Engine():
self.max_nan_losses = 8
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
self.current_batch_size = 0
self._global_grad_norm = None
def freeze(self, freeze_all=True):
@@ -108,7 +109,7 @@ class Engine():
@property
def batch_size(self):
return cfg.hyperparameters.batch_size
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
@property
def gradient_accumulation_steps(self):

View File

@@ -61,11 +61,9 @@ class Engine(DeepSpeedEngine):
self.tokens_processed = stats["tokens_processed"]
self.max_nan_losses = 8
self.current_batch_size = 0
def freeze(self, freeze_all=True):
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
for name, param in self.module.named_parameters():
@@ -75,6 +73,9 @@ class Engine(DeepSpeedEngine):
self._frozen_params.add(param)
return
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
@@ -99,7 +100,7 @@ class Engine(DeepSpeedEngine):
@property
def batch_size(self):
return cfg.hyperparameters.batch_size
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)

View File

@@ -47,6 +47,7 @@ def train_feeder(engine, batch):
engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)
engine.current_batch_size = batch_size
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")