forked from mrq/DL-Art-School
expand codes before the code converters for cheater latents
parent f70b16214d
commit 3efd64ed7a
@@ -99,7 +99,7 @@ class TransformerDiffusion(nn.Module):
                  dropout=0,
                  use_fp16=False,
                  ar_prior=False,
-                 code_expansion_mode='nearest',
+                 new_code_expansion=False,
                  # Parameters for regularization.
                  unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
                  # Parameters for re-training head
@@ -115,7 +115,7 @@ class TransformerDiffusion(nn.Module):
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
-        self.code_expansion_mode = code_expansion_mode
+        self.new_code_expansion = new_code_expansion

         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)

@@ -209,7 +209,9 @@ class TransformerDiffusion(nn.Module):
         return groups

     def timestep_independent(self, prior, expected_seq_len):
-        code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
+        if self.new_code_expansion:
+            code_emb = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
+        code_emb = self.ar_input(code_emb) if self.ar_prior else self.input_converter(code_emb)
         code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)

         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
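For reference (not part of the commit), a minimal runnable sketch of the interpolation pattern the new branch uses: F.interpolate resamples the last dimension of a (batch, channels, time) tensor, so the (batch, time, channels) codes are permuted, stretched to expected_seq_len, and permuted back. The tensor names and shapes below are illustrative assumptions.

import torch
import torch.nn.functional as F

prior = torch.randn(2, 100, 256)       # assumed (batch, time, channels) cheater latents
expected_seq_len = 400

# Permute to (batch, channels, time), resample the time axis, permute back.
expanded = F.interpolate(prior.permute(0, 2, 1), size=expected_seq_len, mode='linear')
expanded = expanded.permute(0, 2, 1)    # back to (batch, time, channels)
print(expanded.shape)                   # torch.Size([2, 400, 256])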
@@ -219,8 +221,9 @@ class TransformerDiffusion(nn.Module):
         code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1),
                                code_emb)

-        expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode=self.code_expansion_mode).permute(0,2,1)
-        return expanded_code_emb
+        if not self.new_code_expansion:
+            code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
+        return code_emb

     def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
         if precomputed_code_embeddings is not None:
@@ -722,7 +725,7 @@ def test_cheater_model():
                                model_channels=1024, contraction_dim=512,
                                prenet_channels=1024, num_heads=8,
                                input_vec_dim=256, num_layers=12, prenet_layers=6,
-                               dropout=.1,
+                               dropout=.1, new_code_expansion=True,
                                )
     diff_weights = torch.load('extracted_diff.pth')
     model.diff.load_state_dict(diff_weights, strict=False)
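Taken together, the change swaps the order of operations in timestep_independent: the old path ran the code converter first and then expanded the result with 'nearest' interpolation, while the new path (new_code_expansion=True) linearly expands the raw prior before it reaches the converters. A hedged sketch of the two orderings, using nn.Identity() as a hypothetical stand-in for the converter modules, which are not shown in these hunks:

import torch
import torch.nn as nn
import torch.nn.functional as F

converter = nn.Identity()               # hypothetical stand-in for input_converter/code_converter
prior = torch.randn(2, 100, 256)        # assumed (batch, time, channels)
expected_seq_len = 400

# Old path: convert first, then expand the converted embedding with 'nearest'.
old = converter(prior)
old = F.interpolate(old.permute(0, 2, 1), size=expected_seq_len, mode='nearest').permute(0, 2, 1)

# New path: expand the raw codes with 'linear', then convert.
new = F.interpolate(prior.permute(0, 2, 1), size=expected_seq_len, mode='linear').permute(0, 2, 1)
new = converter(new)

assert old.shape == new.shape == (2, expected_seq_len, 256)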