forked from mrq/DL-Art-School
udmc update
This commit is contained in:
parent
a14274c845
commit
581bc7ac5c
|
@ -714,7 +714,7 @@ class UNetMusicModelWithQuantizer(nn.Module):
|
||||||
self.m2v.min_gumbel_temperature,
|
self.m2v.min_gumbel_temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_free=False):
|
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
|
||||||
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
|
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
|
||||||
with torch.set_grad_enabled(quant_grad_enabled):
|
with torch.set_grad_enabled(quant_grad_enabled):
|
||||||
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
|
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
|
||||||
|
@ -749,9 +749,14 @@ if __name__ == '__main__':
|
||||||
clip = torch.randn(2, 256, 400)
|
clip = torch.randn(2, 256, 400)
|
||||||
cond = torch.randn(2, 256, 400)
|
cond = torch.randn(2, 256, 400)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=512, num_res_blocks=3, input_vec_dim=1024,
|
model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=1024,
|
||||||
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
|
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
|
||||||
use_scale_shift_norm=True, dropout=.1, num_heads=8)
|
use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4)
|
||||||
print_network(model)
|
print_network(model)
|
||||||
|
|
||||||
|
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth')
|
||||||
|
model.m2v.load_state_dict(quant_weights, strict=False)
|
||||||
|
torch.save(model.state_dict(), 'sample.pth')
|
||||||
|
|
||||||
model(clip, ts, cond)
|
model(clip, ts, cond)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user