forked from mrq/DL-Art-School
support generating cheaters from the new cheater network
This commit is contained in:
parent 27a9b1b750
commit 4ddd01a7fb
@@ -143,13 +143,17 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
         self.diff = TransformerDiffusion(**kwargs)
         self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)

-    def forward(self, x, timesteps, truth_mel, conditioning_free=False):
+    def forward(self, x, timesteps, truth_mel, conditioning_free=False, cheater=None):
         unused_parameters = []
         encoder_grad_enabled = self.freeze_encoder_until is not None and self.internal_step > self.freeze_encoder_until
         if not encoder_grad_enabled:
             unused_parameters.extend(list(self.encoder.parameters()))
-        with torch.set_grad_enabled(encoder_grad_enabled):
-            proj = self.encoder(truth_mel)
+        if cheater is None:
+            with torch.set_grad_enabled(encoder_grad_enabled):
+                proj = self.encoder(truth_mel)
+        else:
+            proj = cheater
         for p in unused_parameters:
             proj = proj + p.mean() * 0

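These hunks edit `codes/models/audio/music/transformer_diffusion14.py` (the path follows from the import added in the script below). The new `cheater` argument lets a caller pass a precomputed latent and skip the encoder entirely; with `cheater=None` the behavior is unchanged. The `proj = proj + p.mean() * 0` line that follows is the usual trick for keeping frozen parameters attached to the autograd graph so distributed wrappers do not flag them as unused. A minimal sketch of that trick, with a hypothetical `frozen` module standing in for the encoder:

```python
import torch
import torch.nn as nn

# `p.mean() * 0` adds nothing to the output value, but it folds every frozen
# parameter into the graph, so backward() still produces a (zero) gradient for
# each of them and DistributedDataParallel sees no "unused" parameters.
frozen = nn.Linear(4, 4)               # stands in for the frozen encoder
x = torch.randn(2, 4, requires_grad=True)

out = x * 2                            # the real computation
for p in frozen.parameters():
    out = out + p.mean() * 0           # zero-valued, graph-preserving

out.sum().backward()
print(all(p.grad is not None for p in frozen.parameters()))  # True
```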
@@ -177,6 +181,10 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
                 p.grad *= .2


+def get_cheater_encoder_v2():
+    return ResEncoder16x(256, 1024, 256, checkpointing_enabled=False)
+
+
 @register_model
 def register_transformer_diffusion14(opt_net, opt):
     return TransformerDiffusion(**opt_net['kwargs'])

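`get_cheater_encoder_v2` constructs the same `ResEncoder16x(256, 1024, 256)` the wrapper model builds internally, just with activation checkpointing off, which is the sensible default for inference. A quick smoke test, assuming from the `16x` name that the encoder reduces the time axis by a factor of 16 (the input/output shapes here are an assumption, not taken from the source):

```python
import torch
from codes.models.audio.music.transformer_diffusion14 import get_cheater_encoder_v2

enc = get_cheater_encoder_v2().eval()
with torch.no_grad():
    latent = enc(torch.randn(1, 256, 160))  # assumed (batch, mel_channels, time)
# If the 16x time-reduction assumption holds, this prints torch.Size([1, 256, 10]).
print(latent.shape)
```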
@@ -221,12 +229,12 @@ def extract_cheater_encoder(in_f, out_f):
     out = {}
     for k, v in p.items():
         if k.startswith('encoder.'):
-            out[k] = v
+            out[k[len('encoder.'):]] = v
     torch.save(out, out_f)


 if __name__ == '__main__':
     #test_local_attention_mask()
-    #extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
-    test_cheater_model()
+    extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth', 'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth')
+    #test_cheater_model()
     #extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)

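The fix in `extract_cheater_encoder` is the prefix strip: keys saved from the full wrapper carry an `encoder.` prefix, so removing it lets the saved state dict load into a bare encoder with no key remapping. A sketch of the round trip, using the output path from the diff (the key pattern in the comment is illustrative, not taken from the source):

```python
import torch
from codes.models.audio.music.transformer_diffusion14 import get_cheater_encoder_v2

# Keys saved from the wrapper look like 'encoder.<param>'; a bare ResEncoder16x
# expects '<param>'. With len('encoder.') characters stripped on save,
# load_state_dict() matches keys one-to-one.
enc = get_cheater_encoder_v2()
sd = torch.load('X:\\dlas\\experiments\\tfd14_cheater_encoder.pth', map_location='cpu')
enc.load_state_dict(sd)
```

The remaining hunks touch the second file in the commit, the script that precomputes cheater latents from `.npz` mel files; its new import of `get_cheater_encoder_v2` ties it to the function added above.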
@@ -16,13 +16,16 @@ from tqdm import tqdm

 from trainer.injectors.audio_injectors import MusicCheaterLatentInjector

+from codes.models.audio.music.transformer_diffusion14 import get_cheater_encoder_v2
+

 def report_progress(progress_file, file):
     with open(progress_file, 'a', encoding='utf-8') as f:
         f.write(f'{file}\n')


-cheater_inj = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {})
+model = get_cheater_encoder_v2().eval().cpu()
+model.load_state_dict(torch.load('../experiments/tfd14_cheater_encoder.pth', map_location=torch.device('cpu')))


 def process_folder(file, base_path, output_path, progress_file):

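Loading with `map_location=torch.device('cpu')` keeps the checkpoint off the GPU at import time; `process_folder` moves the model to CUDA on first use. Since `nn.Module.cuda()` moves parameters in place and returns `self`, calling it on every file is effectively free after the first transfer; a guarded variant (a hypothetical helper, not in the diff) would make that explicit:

```python
import torch

def ensure_cuda(m: torch.nn.Module) -> torch.nn.Module:
    # Move the module to the GPU only if its parameters are not already there.
    if next(m.parameters()).device.type != 'cuda':
        m = m.cuda()
    return m
```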
@@ -30,7 +33,10 @@ def process_folder(file, base_path, output_path, progress_file):
     os.makedirs(outdir, exist_ok=True)
     with np.load(file) as npz_file:
         mel = torch.tensor(npz_file['arr_0']).cuda().unsqueeze(0)
-    cheater = cheater_inj({'in': mel})['out']
+    global model
+    model = model.cuda()
+    with torch.no_grad():
+        cheater = model(mel)
     np.savez(os.path.join(outdir, os.path.basename(file)), cheater.cpu().numpy())
     report_progress(progress_file, file)
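A hypothetical driver for the script, assuming `process_folder` is in scope with the signature shown in the diff and that the input tree holds `.npz` mel files; the base and output paths are placeholders, not from the repo:

```python
import os
from glob import glob
from tqdm import tqdm

base_path = '/data/mels'            # placeholder input tree of .npz mels
output_path = '/data/cheaters'      # placeholder output tree
progress_file = os.path.join(output_path, 'progress.txt')
os.makedirs(output_path, exist_ok=True)

# Walk the tree and run every mel through the cheater encoder, logging each
# finished file so an interrupted run can be resumed by diffing the log.
for f in tqdm(glob(os.path.join(base_path, '**', '*.npz'), recursive=True)):
    process_folder(f, base_path, output_path, progress_file)
```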