permute codes

James Betker 2022-06-19 18:00:30 -06:00
parent 90b232f965
commit f425afc965


@@ -100,6 +100,7 @@ class TransformerDiffusion(nn.Module):
             use_fp16=False,
             ar_prior=False,
             new_code_expansion=False,
+            permute_codes=False,
             # Parameters for regularization.
             unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
             # Parameters for re-training head
@@ -116,6 +117,7 @@ class TransformerDiffusion(nn.Module):
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
         self.new_code_expansion = new_code_expansion
+        self.permute_codes = permute_codes
         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
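
Since registration helpers in this file pass opt_net['kwargs'] straight into the constructor (see register_transformer_diffusion_12_with_cheater_latent below), the new flag can presumably be switched on from a config dict as well. A hypothetical sketch, with argument values borrowed from test_tfd() further down; the exact kwargs set is an assumption:

    # Hypothetical opt_net entry enabling the new flag via kwargs.
    opt_net = {'kwargs': {'in_channels': 256, 'model_channels': 1024, 'contraction_dim': 512,
                          'prenet_channels': 1024, 'num_heads': 3, 'permute_codes': True,
                          'input_vec_dim': 256, 'num_layers': 12, 'prenet_layers': 4,
                          'dropout': .1}}
    model = TransformerDiffusion(**opt_net['kwargs'])  # permute_codes now active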
@@ -228,6 +230,8 @@ class TransformerDiffusion(nn.Module):
     def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
         if precomputed_code_embeddings is not None:
             assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
+        if self.permute_codes:
+            codes = codes.permute(0,2,1)
         unused_params = []
         if conditioning_free:
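
A minimal sketch (not part of the commit; shapes taken from test_tfd() below) of what the new branch does: codes supplied channel-first as (batch, input_vec_dim, sequence) are reordered to the (batch, sequence, input_vec_dim) layout the code embedding path consumes.

    import torch

    codes = torch.randn(2, 256, 400)   # (batch, input_vec_dim, sequence)
    permuted = codes.permute(0, 2, 1)  # -> (batch, sequence, input_vec_dim)
    assert permuted.shape == (2, 400, 256)

This is what lets test_tfd() below pass its channel-first clip tensor directly as codes.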
@@ -605,6 +609,15 @@ def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
     return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
+def test_tfd():
+    clip = torch.randn(2,256,400)
+    ts = torch.LongTensor([600, 600])
+    model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
+                                 prenet_channels=1024, num_heads=3, permute_codes=True,
+                                 input_vec_dim=256, num_layers=12, prenet_layers=4,
+                                 dropout=.1)
+    model(clip, ts, clip)
 def test_quant_model():
     clip = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
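
For comparison, a hedged sketch of the same smoke test with permute_codes=False and the permutation done by the caller instead; assuming the forward path above, this should exercise the identical layout:

    clip = torch.randn(2, 256, 400)
    ts = torch.LongTensor([600, 600])
    model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
                                 prenet_channels=1024, num_heads=3, permute_codes=False,
                                 input_vec_dim=256, num_layers=12, prenet_layers=4,
                                 dropout=.1)
    model(clip, ts, clip.permute(0, 2, 1))  # caller-side permute replaces the flag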
@@ -767,4 +780,4 @@ def extract_diff(in_f, out_f, remove_head=False):
 if __name__ == '__main__':
     #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
-    test_cheater_model()
+    test_tfd()