support generating cheaters from the new cheater network
This commit is contained in:
parent
27a9b1b750
commit
4ddd01a7fb
|
@ -143,13 +143,17 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
|
|||
self.diff = TransformerDiffusion(**kwargs)
|
||||
self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
|
||||
|
||||
def forward(self, x, timesteps, truth_mel, conditioning_free=False):
|
||||
def forward(self, x, timesteps, truth_mel, conditioning_free=False, cheater=None):
|
||||
unused_parameters = []
|
||||
encoder_grad_enabled = self.freeze_encoder_until is not None and self.internal_step > self.freeze_encoder_until
|
||||
if not encoder_grad_enabled:
|
||||
unused_parameters.extend(list(self.encoder.parameters()))
|
||||
with torch.set_grad_enabled(encoder_grad_enabled):
|
||||
proj = self.encoder(truth_mel)
|
||||
|
||||
if cheater is None:
|
||||
with torch.set_grad_enabled(encoder_grad_enabled):
|
||||
proj = self.encoder(truth_mel)
|
||||
else:
|
||||
proj = cheater
|
||||
|
||||
for p in unused_parameters:
|
||||
proj = proj + p.mean() * 0
|
||||
|
@ -177,6 +181,10 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
|
|||
p.grad *= .2
|
||||
|
||||
|
||||
def get_cheater_encoder_v2():
|
||||
return ResEncoder16x(256, 1024, 256, checkpointing_enabled=False)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_transformer_diffusion14(opt_net, opt):
|
||||
return TransformerDiffusion(**opt_net['kwargs'])
|
||||
|
@ -221,12 +229,12 @@ def extract_cheater_encoder(in_f, out_f):
|
|||
out = {}
|
||||
for k, v in p.items():
|
||||
if k.startswith('encoder.'):
|
||||
out[k] = v
|
||||
out[k[len('encoder.'):]] = v
|
||||
torch.save(out, out_f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#test_local_attention_mask()
|
||||
#extract_cheater_encoder('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\104500_generator_ema.pth', 'X:\\dlas\\experiments\\tfd12_self_learned_cheater_enc.pth', True)
|
||||
test_cheater_model()
|
||||
extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth', 'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth')
|
||||
#test_cheater_model()
|
||||
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
|
||||
|
|
|
@ -16,13 +16,16 @@ from tqdm import tqdm
|
|||
|
||||
from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
|
||||
|
||||
from codes.models.audio.music.transformer_diffusion14 import get_cheater_encoder_v2
|
||||
|
||||
|
||||
def report_progress(progress_file, file):
|
||||
with open(progress_file, 'a', encoding='utf-8') as f:
|
||||
f.write(f'{file}\n')
|
||||
|
||||
|
||||
cheater_inj = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {})
|
||||
model = get_cheater_encoder_v2().eval().cpu()
|
||||
model.load_state_dict(torch.load('../experiments/tfd14_cheater_encoder.pth', map_location=torch.device('cpu')))
|
||||
|
||||
|
||||
def process_folder(file, base_path, output_path, progress_file):
|
||||
|
@ -30,7 +33,10 @@ def process_folder(file, base_path, output_path, progress_file):
|
|||
os.makedirs(outdir, exist_ok=True)
|
||||
with np.load(file) as npz_file:
|
||||
mel = torch.tensor(npz_file['arr_0']).cuda().unsqueeze(0)
|
||||
cheater = cheater_inj({'in': mel})['out']
|
||||
global model
|
||||
model = model.cuda()
|
||||
with torch.no_grad():
|
||||
cheater = model(mel)
|
||||
np.savez(os.path.join(outdir, os.path.basename(file)), cheater.cpu().numpy())
|
||||
report_progress(progress_file, file)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user