udmc update
This commit is contained in:
parent
a14274c845
commit
581bc7ac5c
|
@ -714,7 +714,7 @@ class UNetMusicModelWithQuantizer(nn.Module):
|
|||
self.m2v.min_gumbel_temperature,
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_free=False):
|
||||
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
|
||||
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
|
||||
with torch.set_grad_enabled(quant_grad_enabled):
|
||||
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
|
||||
|
@ -749,9 +749,14 @@ if __name__ == '__main__':
|
|||
clip = torch.randn(2, 256, 400)
|
||||
cond = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=512, num_res_blocks=3, input_vec_dim=1024,
|
||||
model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=1024,
|
||||
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
|
||||
use_scale_shift_norm=True, dropout=.1, num_heads=8)
|
||||
use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4)
|
||||
print_network(model)
|
||||
|
||||
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth')
|
||||
model.m2v.load_state_dict(quant_weights, strict=False)
|
||||
torch.save(model.state_dict(), 'sample.pth')
|
||||
|
||||
model(clip, ts, cond)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user