From 4ddd01a7fb50e01eac6d2a923c86807d52272850 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 29 Jul 2022 09:19:20 -0600
Subject: [PATCH] support generating cheaters from the new cheater network

---
 .../audio/music/transformer_diffusion14.py    | 20 +++++++++++++------
 .../prep_music/generate_long_cheaters.py      | 10 ++++++++--
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion14.py b/codes/models/audio/music/transformer_diffusion14.py
index 4d28c7c5..c318a800 100644
--- a/codes/models/audio/music/transformer_diffusion14.py
+++ b/codes/models/audio/music/transformer_diffusion14.py
@@ -143,13 +143,17 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
         self.diff = TransformerDiffusion(**kwargs)
         self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
 
-    def forward(self, x, timesteps, truth_mel, conditioning_free=False):
+    def forward(self, x, timesteps, truth_mel, conditioning_free=False, cheater=None):
         unused_parameters = []
         encoder_grad_enabled = self.freeze_encoder_until is not None and self.internal_step > self.freeze_encoder_until
         if not encoder_grad_enabled:
             unused_parameters.extend(list(self.encoder.parameters()))
-        with torch.set_grad_enabled(encoder_grad_enabled):
-            proj = self.encoder(truth_mel)
+
+        if cheater is None:
+            with torch.set_grad_enabled(encoder_grad_enabled):
+                proj = self.encoder(truth_mel)
+        else:
+            proj = cheater
 
         for p in unused_parameters:
             proj = proj + p.mean() * 0
@@ -177,6 +181,10 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
                 p.grad *= .2
 
 
+def get_cheater_encoder_v2():
+    return ResEncoder16x(256, 1024, 256, checkpointing_enabled=False)
+
+
 @register_model
 def register_transformer_diffusion14(opt_net, opt):
     return TransformerDiffusion(**opt_net['kwargs'])
@@ -221,12 +229,12 @@ def extract_cheater_encoder(in_f, out_f):
     out = {}
     for k, v in p.items():
         if k.startswith('encoder.'):
-            out[k] = v
+            out[k[len('encoder.'):]] = v
     torch.save(out, out_f)
 
 
 if __name__ == '__main__':
     #test_local_attention_mask()
-    #extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
-    test_cheater_model()
+    extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth', 'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth')
+    #test_cheater_model()
     #extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
diff --git a/codes/scripts/audio/prep_music/generate_long_cheaters.py b/codes/scripts/audio/prep_music/generate_long_cheaters.py
index 2ee11f6a..e59bb067 100644
--- a/codes/scripts/audio/prep_music/generate_long_cheaters.py
+++ b/codes/scripts/audio/prep_music/generate_long_cheaters.py
@@ -16,13 +16,16 @@ from tqdm import tqdm
 
 from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
 
+from codes.models.audio.music.transformer_diffusion14 import get_cheater_encoder_v2
+
 
 def report_progress(progress_file, file):
     with open(progress_file, 'a', encoding='utf-8') as f:
         f.write(f'{file}\n')
 
 
-cheater_inj = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {})
+model = get_cheater_encoder_v2().eval().cpu()
+model.load_state_dict(torch.load('../experiments/tfd14_cheater_encoder.pth', map_location=torch.device('cpu')))
 
 
 def process_folder(file, base_path, output_path, progress_file):
@@ -30,7 +33,10 @@ def process_folder(file, base_path, output_path, progress_file):
     os.makedirs(outdir, exist_ok=True)
     with np.load(file) as npz_file:
         mel = torch.tensor(npz_file['arr_0']).cuda().unsqueeze(0)
-        cheater = cheater_inj({'in': mel})['out']
+        global model
+        model = model.cuda()
+        with torch.no_grad():
+            cheater = model(mel)
         np.savez(os.path.join(outdir, os.path.basename(file)), cheater.cpu().numpy())
         report_progress(progress_file, file)
 
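
Usage note (not part of the patch to apply): a minimal sketch of the workflow this
commit enables, i.e. extracting the cheater encoder from a joint tfd14+cheater
checkpoint and using it to generate cheater latents. The checkpoint paths and the
mel shape are illustrative assumptions; `extract_cheater_encoder`,
`get_cheater_encoder_v2`, and the new `cheater=` kwarg are from the diff above.

    import torch
    from codes.models.audio.music.transformer_diffusion14 import (
        extract_cheater_encoder, get_cheater_encoder_v2)

    # Strip the 'encoder.' prefix from a joint checkpoint so the weights load
    # directly into a standalone ResEncoder16x (paths are hypothetical).
    extract_cheater_encoder('tfd14_and_cheater.pth', 'tfd14_cheater_encoder.pth')

    # Rebuild the encoder and load the extracted weights.
    encoder = get_cheater_encoder_v2().eval()
    encoder.load_state_dict(torch.load('tfd14_cheater_encoder.pth',
                                       map_location=torch.device('cpu')))

    # Produce a cheater latent from a mel spectrogram. The encoder takes 256
    # mel channels; a 16x time reduction is assumed from the ResEncoder16x name.
    mel = torch.randn(1, 256, 1600)  # [batch, channels, time]; shape assumed
    with torch.no_grad():
        cheater = encoder(mel)

    # The new `cheater=` kwarg on TransformerDiffusionWithCheaterLatent.forward
    # bypasses the internal encoder and conditions on the precomputed latent:
    #   model(x, timesteps, truth_mel, cheater=cheater)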