support generating cheaters from the new cheater network

This commit is contained in:
James Betker 2022-07-29 09:19:20 -06:00
parent 27a9b1b750
commit 4ddd01a7fb
2 changed files with 22 additions and 8 deletions

View File

@ -143,13 +143,17 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
self.diff = TransformerDiffusion(**kwargs) self.diff = TransformerDiffusion(**kwargs)
self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder) self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
def forward(self, x, timesteps, truth_mel, conditioning_free=False): def forward(self, x, timesteps, truth_mel, conditioning_free=False, cheater=None):
unused_parameters = [] unused_parameters = []
encoder_grad_enabled = self.freeze_encoder_until is not None and self.internal_step > self.freeze_encoder_until encoder_grad_enabled = self.freeze_encoder_until is not None and self.internal_step > self.freeze_encoder_until
if not encoder_grad_enabled: if not encoder_grad_enabled:
unused_parameters.extend(list(self.encoder.parameters())) unused_parameters.extend(list(self.encoder.parameters()))
with torch.set_grad_enabled(encoder_grad_enabled):
proj = self.encoder(truth_mel) if cheater is None:
with torch.set_grad_enabled(encoder_grad_enabled):
proj = self.encoder(truth_mel)
else:
proj = cheater
for p in unused_parameters: for p in unused_parameters:
proj = proj + p.mean() * 0 proj = proj + p.mean() * 0
@ -177,6 +181,10 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
p.grad *= .2 p.grad *= .2
def get_cheater_encoder_v2():
return ResEncoder16x(256, 1024, 256, checkpointing_enabled=False)
@register_model @register_model
def register_transformer_diffusion14(opt_net, opt): def register_transformer_diffusion14(opt_net, opt):
return TransformerDiffusion(**opt_net['kwargs']) return TransformerDiffusion(**opt_net['kwargs'])
@ -221,12 +229,12 @@ def extract_cheater_encoder(in_f, out_f):
out = {} out = {}
for k, v in p.items(): for k, v in p.items():
if k.startswith('encoder.'): if k.startswith('encoder.'):
out[k] = v out[k[len('encoder.'):]] = v
torch.save(out, out_f) torch.save(out, out_f)
if __name__ == '__main__': if __name__ == '__main__':
#test_local_attention_mask() #test_local_attention_mask()
#extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True) extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth', 'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth')
test_cheater_model() #test_cheater_model()
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True) #extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)

View File

@ -16,13 +16,16 @@ from tqdm import tqdm
from trainer.injectors.audio_injectors import MusicCheaterLatentInjector from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
from codes.models.audio.music.transformer_diffusion14 import get_cheater_encoder_v2
def report_progress(progress_file, file): def report_progress(progress_file, file):
with open(progress_file, 'a', encoding='utf-8') as f: with open(progress_file, 'a', encoding='utf-8') as f:
f.write(f'{file}\n') f.write(f'{file}\n')
cheater_inj = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {}) model = get_cheater_encoder_v2().eval().cpu()
model.load_state_dict(torch.load('../experiments/tfd14_cheater_encoder.pth', map_location=torch.device('cpu')))
def process_folder(file, base_path, output_path, progress_file): def process_folder(file, base_path, output_path, progress_file):
@ -30,7 +33,10 @@ def process_folder(file, base_path, output_path, progress_file):
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
with np.load(file) as npz_file: with np.load(file) as npz_file:
mel = torch.tensor(npz_file['arr_0']).cuda().unsqueeze(0) mel = torch.tensor(npz_file['arr_0']).cuda().unsqueeze(0)
cheater = cheater_inj({'in': mel})['out'] global model
model = model.cuda()
with torch.no_grad():
cheater = model(mel)
np.savez(os.path.join(outdir, os.path.basename(file)), cheater.cpu().numpy()) np.savez(os.path.join(outdir, os.path.basename(file)), cheater.cpu().numpy())
report_progress(progress_file, file) report_progress(progress_file, file)