forked from mrq/DL-Art-School
permute codes
This commit is contained in:
parent
90b232f965
commit
f425afc965
|
@ -100,6 +100,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
ar_prior=False,
|
ar_prior=False,
|
||||||
new_code_expansion=False,
|
new_code_expansion=False,
|
||||||
|
permute_codes=False,
|
||||||
# Parameters for regularization.
|
# Parameters for regularization.
|
||||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||||
# Parameters for re-training head
|
# Parameters for re-training head
|
||||||
|
@ -116,6 +117,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
self.unconditioned_percentage = unconditioned_percentage
|
self.unconditioned_percentage = unconditioned_percentage
|
||||||
self.enable_fp16 = use_fp16
|
self.enable_fp16 = use_fp16
|
||||||
self.new_code_expansion = new_code_expansion
|
self.new_code_expansion = new_code_expansion
|
||||||
|
self.permute_codes = permute_codes
|
||||||
|
|
||||||
self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
|
self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
|
||||||
|
|
||||||
|
@ -228,6 +230,8 @@ class TransformerDiffusion(nn.Module):
|
||||||
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
|
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
|
||||||
if precomputed_code_embeddings is not None:
|
if precomputed_code_embeddings is not None:
|
||||||
assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
|
assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
|
||||||
|
if self.permute_codes:
|
||||||
|
codes = codes.permute(0,2,1)
|
||||||
|
|
||||||
unused_params = []
|
unused_params = []
|
||||||
if conditioning_free:
|
if conditioning_free:
|
||||||
|
@ -605,6 +609,15 @@ def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
|
||||||
return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
|
return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
|
||||||
|
|
||||||
|
|
||||||
|
def test_tfd():
|
||||||
|
clip = torch.randn(2,256,400)
|
||||||
|
ts = torch.LongTensor([600, 600])
|
||||||
|
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
||||||
|
prenet_channels=1024, num_heads=3, permute_codes=True,
|
||||||
|
input_vec_dim=256, num_layers=12, prenet_layers=4,
|
||||||
|
dropout=.1)
|
||||||
|
model(clip, ts, clip)
|
||||||
|
|
||||||
def test_quant_model():
|
def test_quant_model():
|
||||||
clip = torch.randn(2, 256, 400)
|
clip = torch.randn(2, 256, 400)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
|
@ -767,4 +780,4 @@ def extract_diff(in_f, out_f, remove_head=False):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
#extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
|
#extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
|
||||||
test_cheater_model()
|
test_tfd()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user