From ab5acead0e249275a7aab55ff228f5e9e16b16dc Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 15 May 2022 21:50:38 -0600
Subject: [PATCH] add exp loss for diffusion models

---
 codes/scripts/audio/prep_music/phase_1_split_files.py  |  6 +++---
 codes/trainer/injectors/gaussian_diffusion_injector.py | 18 ++++++++++++++++++
 codes/trainer/steps.py                                 |  5 +++++
 3 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/codes/scripts/audio/prep_music/phase_1_split_files.py b/codes/scripts/audio/prep_music/phase_1_split_files.py
index f8065e74..5e85cbaf 100644
--- a/codes/scripts/audio/prep_music/phase_1_split_files.py
+++ b/codes/scripts/audio/prep_music/phase_1_split_files.py
@@ -46,15 +46,15 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
                 break
         if not passed_checks:
             continue
-        torchaudio.save(f'{outdir}/{i:05d}.wav', spl.unsqueeze(0), sampling_rate)
+        torchaudio.save(f'{outdir}/{i:05d}.wav', spl.unsqueeze(0), sampling_rate, encoding="PCM_S")
     report_progress(progress_file, file)
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\yt-music-1')
+    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music2')
     parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\yt-music-1\\already_processed.txt')
-    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\yt-music-1')
+    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\music\\bigdump')
     parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=8)
     parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30)
     args = parser.parse_args()
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 2c94846b..c3c1db17 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -28,6 +28,21 @@ class GaussianDiffusionInjector(Injector):
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
         self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
+        # 'mse' output of the most recent forward() call; None until forward() has run.
+        self.recent_loss = None
+
+    def extra_metrics(self):
+        """
+        Returns additional metrics to log for this injector, keyed by metric name.
+
+        Currently reports 'exp_diffusion_loss': exp() of the mean of the 'mse' diffusion output recorded by the most
+        recent forward() call. Returns an empty dict when forward() has not run yet, since there is nothing to report.
+        """
+        if self.recent_loss is None:
+            return {}
+        return {
+            'exp_diffusion_loss': torch.exp(self.recent_loss.mean()),
+        }
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
@@ -53,6 +68,9 @@ class GaussianDiffusionInjector(Injector):
                         self.output_variational_bounds_key: diffusion_outputs['vb'],
                         self.output_x_start_key: diffusion_outputs['x_start_predicted']})
 
+        # Only read by extra_metrics(); detached so the stored tensor does not reference the autograd graph.
+        self.recent_loss = diffusion_outputs['mse'].detach()
+
         return out
 
 
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 886c4d09..aae9b80d 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -237,6 +237,11 @@ class ConfigurableStep(Module):
             local_state.update(injected)
             new_state.update(injected)
 
+            if hasattr(inj, 'extra_metrics'):
+                for n, v in inj.extra_metrics().items():
+                    # Doesn't really work for training setups where multiple of the same injector are used.
+                    loss_accumulator.add_loss(n, v)
+
         if len(self.losses) > 0:
             # Finally, compute the losses.
             total_loss = 0