commit e78c4b422c
parent d98b895307
Author: James Betker
Date:   2022-06-10 09:24:41 -06:00

2 changed files with 5 additions and 4 deletions


@@ -149,6 +149,7 @@ class TransformerDiffusion(nn.Module):
     def timestep_independent(self, prior, expected_seq_len):
         code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
         code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
+        # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
@@ -156,7 +157,6 @@ class TransformerDiffusion(nn.Module):
             unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                                device=code_emb.device) < self.unconditioned_percentage
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1),
                                    code_emb)
-        code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
         expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
         return expanded_code_emb
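
The masking this hunk documents (and de-duplicates) is the standard conditioning-dropout trick behind classifier-free guidance: during training, whole batch elements have their conditioning replaced by a learned "unconditioned" embedding. A minimal, self-contained sketch of the same pattern — the module and its argument names are illustrative, not code from this repo:

    import torch
    import torch.nn as nn

    class ConditioningDropout(nn.Module):
        """Randomly swaps the conditioning embedding of whole batch elements
        for a learned 'unconditioned' embedding during training, so the model
        can later be sampled both with and without conditioning."""
        def __init__(self, dim, unconditioned_percentage=0.1):
            super().__init__()
            self.unconditioned_percentage = unconditioned_percentage
            self.unconditioned_embedding = nn.Parameter(torch.randn(1, 1, dim))

        def forward(self, code_emb):  # code_emb: (batch, seq, dim)
            if self.training and self.unconditioned_percentage > 0:
                unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                                   device=code_emb.device) < self.unconditioned_percentage
                # Broadcast over the sequence dim: chosen batch elements lose
                # their conditioning entirely, matching the diff above.
                code_emb = torch.where(unconditioned_batches,
                                       self.unconditioned_embedding.repeat(code_emb.shape[0], 1, 1),
                                       code_emb)
            return code_emb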
@@ -215,7 +215,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
                 self.quantizer.min_gumbel_temperature,
             )

-    def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
+    def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
             proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True)
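
The `torch.set_grad_enabled` context above gates quantizer training on the step counter. A hedged sketch of that pattern in isolation — `internal_step`, `freeze_quantizer_until`, and the quantizer call mirror the diff, while the wrapper function itself is hypothetical:

    import torch

    def run_quantizer(model, truth_mel):
        # Gradients reach the quantizer only after the freeze window has passed;
        # before that it runs without autograd and its weights stay fixed.
        quant_grad_enabled = model.internal_step > model.freeze_quantizer_until
        with torch.set_grad_enabled(quant_grad_enabled):
            proj, diversity_loss = model.quantizer(truth_mel, return_decoder_latent=True)
        return proj, diversity_loss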
@@ -336,7 +336,8 @@ def test_ar_model():
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
     model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024,
-                                            input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True)
+                                            input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True,
+                                            unconditioned_percentage=.4)
     model.get_grad_norm_parameter_groups()
     ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
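
At inference, a model trained with `unconditioned_percentage > 0` is typically sampled with classifier-free guidance: run the denoiser once with conditioning and once in `conditioning_free` mode, then blend the two predictions. A generic sketch of the blend — not code from this repo; the names and guidance scale are illustrative:

    def cfg_blend(eps_cond, eps_uncond, guidance_scale=2.0):
        # Classifier-free guidance: extrapolate from the unconditional
        # prediction toward (and past) the conditional one.
        return eps_uncond + guidance_scale * (eps_cond - eps_uncond)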


@@ -339,7 +339,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt_upper.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_diffusion_tfd.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
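
With this default swapped, a bare launch now picks up the TFD diffusion config; any other YAML can still be passed explicitly. A usage sketch, assuming the trainer script is named train.py:

    python train.py -opt ../options/train_music_diffusion_tfd.yml --launcher none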