more backporting
This commit is contained in:
parent
43d85d97aa
commit
90ecf3da7d
|
@ -335,7 +335,7 @@ class Dataset(_Dataset):
|
|||
|
||||
if self.sampler_type == "path":
|
||||
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
|
||||
self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
|
||||
self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size )
|
||||
else:
|
||||
self.sampler = OrderedSampler( len(self) )
|
||||
self.samplers = {}
|
||||
|
|
|
@ -71,6 +71,7 @@ class Engine():
|
|||
self.max_nan_losses = 8
|
||||
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
|
||||
|
||||
self.current_batch_size = 0
|
||||
self._global_grad_norm = None
|
||||
|
||||
def freeze(self, freeze_all=True):
|
||||
|
@ -108,7 +109,7 @@ class Engine():
|
|||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return cfg.hyperparameters.batch_size
|
||||
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
|
||||
|
||||
@property
|
||||
def gradient_accumulation_steps(self):
|
||||
|
|
|
@ -61,11 +61,9 @@ class Engine(DeepSpeedEngine):
|
|||
self.tokens_processed = stats["tokens_processed"]
|
||||
|
||||
self.max_nan_losses = 8
|
||||
self.current_batch_size = 0
|
||||
|
||||
def freeze(self, freeze_all=True):
|
||||
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
|
||||
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
|
||||
|
||||
# freeze non-LoRA params if requested
|
||||
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
|
||||
for name, param in self.module.named_parameters():
|
||||
|
@ -75,6 +73,9 @@ class Engine(DeepSpeedEngine):
|
|||
self._frozen_params.add(param)
|
||||
return
|
||||
|
||||
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
|
||||
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
|
||||
|
||||
for name, param in self.module.named_parameters():
|
||||
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
|
||||
param.requires_grad_(False)
|
||||
|
@ -99,7 +100,7 @@ class Engine(DeepSpeedEngine):
|
|||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return cfg.hyperparameters.batch_size
|
||||
return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
|
|
@ -47,6 +47,7 @@ def train_feeder(engine, batch):
|
|||
|
||||
engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)
|
||||
|
||||
engine.current_batch_size = batch_size
|
||||
losses = engine.gather_attribute("loss")
|
||||
stat = engine.gather_attribute("stats")
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user