This commit is contained in:
James Betker 2022-06-25 21:22:08 -06:00
parent f12f0200d6
commit d0f2560396
2 changed files with 4 additions and 4 deletions

View File

@ -286,7 +286,7 @@ def inference_tfdpc3_with_cheater():
model = TransformerDiffusionWithConditioningEncoder(in_channels=256, out_channels=512, model_channels=1024,
contraction_dim=512, num_heads=8, num_layers=12, dropout=0,
use_fp16=False, unconditioned_percentage=0).eval().cuda()
model.load_state_dict(torch.load('x:/dlas/experiments/train_music_cheater_gen_v3/models/59000_generator_ema.pth'))
model.load_state_dict(torch.load('x:/dlas/experiments/train_music_cheater_gen_v3/models/61000_generator_ema.pth'))
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'true_normalization': True,

View File

@ -229,8 +229,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
return out
def before_step(self, step):
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) + \
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers]))
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
# directly fiddling with the gradients.
@ -251,7 +251,7 @@ def test_cheater_model():
# For music:
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
contraction_dim=384, num_heads=6, num_layers=18, dropout=0,
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
unconditioned_percentage=.4)
print_network(model)
o = model(clip, ts, cl)