forked from mrq/DL-Art-School
permute codes
This commit is contained in:
parent
90b232f965
commit
f425afc965
|
@ -100,6 +100,7 @@ class TransformerDiffusion(nn.Module):
|
|||
use_fp16=False,
|
||||
ar_prior=False,
|
||||
new_code_expansion=False,
|
||||
permute_codes=False,
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# Parameters for re-training head
|
||||
|
@ -116,6 +117,7 @@ class TransformerDiffusion(nn.Module):
|
|||
self.unconditioned_percentage = unconditioned_percentage
|
||||
self.enable_fp16 = use_fp16
|
||||
self.new_code_expansion = new_code_expansion
|
||||
self.permute_codes = permute_codes
|
||||
|
||||
self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
|
||||
|
||||
|
@ -228,6 +230,8 @@ class TransformerDiffusion(nn.Module):
|
|||
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
|
||||
if precomputed_code_embeddings is not None:
|
||||
assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
|
||||
if self.permute_codes:
|
||||
codes = codes.permute(0,2,1)
|
||||
|
||||
unused_params = []
|
||||
if conditioning_free:
|
||||
|
@ -605,6 +609,15 @@ def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
|
|||
return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
|
||||
|
||||
|
||||
def test_tfd():
|
||||
clip = torch.randn(2,256,400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
||||
prenet_channels=1024, num_heads=3, permute_codes=True,
|
||||
input_vec_dim=256, num_layers=12, prenet_layers=4,
|
||||
dropout=.1)
|
||||
model(clip, ts, clip)
|
||||
|
||||
def test_quant_model():
|
||||
clip = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
|
@ -767,4 +780,4 @@ def extract_diff(in_f, out_f, remove_head=False):
|
|||
|
||||
if __name__ == '__main__':
|
||||
#extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
|
||||
test_cheater_model()
|
||||
test_tfd()
|
||||
|
|
Loading…
Reference in New Issue
Block a user