fix
This commit is contained in:
parent
f12f0200d6
commit
d0f2560396
|
@ -286,7 +286,7 @@ def inference_tfdpc3_with_cheater():
|
|||
model = TransformerDiffusionWithConditioningEncoder(in_channels=256, out_channels=512, model_channels=1024,
|
||||
contraction_dim=512, num_heads=8, num_layers=12, dropout=0,
|
||||
use_fp16=False, unconditioned_percentage=0).eval().cuda()
|
||||
model.load_state_dict(torch.load('x:/dlas/experiments/train_music_cheater_gen_v3/models/59000_generator_ema.pth'))
|
||||
model.load_state_dict(torch.load('x:/dlas/experiments/train_music_cheater_gen_v3/models/61000_generator_ema.pth'))
|
||||
|
||||
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
|
||||
spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'true_normalization': True,
|
||||
|
|
|
@ -229,8 +229,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
return out
|
||||
|
||||
def before_step(self, step):
|
||||
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
|
||||
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
|
||||
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) + \
|
||||
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers]))
|
||||
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
|
||||
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
|
||||
# directly fiddling with the gradients.
|
||||
|
@ -251,7 +251,7 @@ def test_cheater_model():
|
|||
|
||||
# For music:
|
||||
model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
|
||||
contraction_dim=384, num_heads=6, num_layers=18, dropout=0,
|
||||
contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
|
||||
unconditioned_percentage=.4)
|
||||
print_network(model)
|
||||
o = model(clip, ts, cl)
|
||||
|
|
Loading…
Reference in New Issue
Block a user