udmc update

This commit is contained in:
James Betker 2022-06-03 12:02:22 -06:00
parent a14274c845
commit 581bc7ac5c

View File

@ -714,7 +714,7 @@ class UNetMusicModelWithQuantizer(nn.Module):
self.m2v.min_gumbel_temperature,
)
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_free=False):
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
with torch.set_grad_enabled(quant_grad_enabled):
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
@ -749,9 +749,14 @@ if __name__ == '__main__':
clip = torch.randn(2, 256, 400)
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=512, num_res_blocks=3, input_vec_dim=1024,
model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=1024,
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
use_scale_shift_norm=True, dropout=.1, num_heads=8)
use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4)
print_network(model)
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth')
model.m2v.load_state_dict(quant_weights, strict=False)
torch.save(model.state_dict(), 'sample.pth')
model(clip, ts, cond)