forked from mrq/DL-Art-School
tfd8
This commit is contained in:
parent
d98b895307
commit
e78c4b422c
|
@ -149,6 +149,7 @@ class TransformerDiffusion(nn.Module):
|
|||
|
||||
def timestep_independent(self, prior, expected_seq_len):
|
||||
code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
|
||||
+ code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
|
||||
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
|
@ -156,7 +157,6 @@ class TransformerDiffusion(nn.Module):
|
|||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1),
|
||||
code_emb)
|
||||
- code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
|
||||
|
||||
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
|
||||
return expanded_code_emb
|
||||
|
@ -215,7 +215,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
|||
self.quantizer.min_gumbel_temperature,
|
||||
)
|
||||
|
||||
- def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
|
||||
+ def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
|
||||
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
|
||||
with torch.set_grad_enabled(quant_grad_enabled):
|
||||
proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True)
|
||||
|
@ -336,7 +336,8 @@ def test_ar_model():
|
|||
cond = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024,
|
||||
- input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True)
|
||||
+ input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True,
|
||||
+ unconditioned_percentage=.4)
|
||||
model.get_grad_norm_parameter_groups()
|
||||
|
||||
ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
|
||||
|
|
|
@ -339,7 +339,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
- parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt_upper.yml')
|
||||
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_diffusion_tfd.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user