forked from mrq/DL-Art-School
allow more interesting code interpolation
This commit is contained in:
parent
87a86ae6a8
commit
f70b16214d
|
@ -99,6 +99,7 @@ class TransformerDiffusion(nn.Module):
|
|||
dropout=0,
|
||||
use_fp16=False,
|
||||
ar_prior=False,
|
||||
code_expansion_mode='nearest',
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# Parameters for re-training head
|
||||
|
@ -114,6 +115,7 @@ class TransformerDiffusion(nn.Module):
|
|||
self.dropout = dropout
|
||||
self.unconditioned_percentage = unconditioned_percentage
|
||||
self.enable_fp16 = use_fp16
|
||||
self.code_expansion_mode = code_expansion_mode
|
||||
|
||||
self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
|
||||
|
||||
|
@ -217,7 +219,7 @@ class TransformerDiffusion(nn.Module):
|
|||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1),
|
||||
code_emb)
|
||||
|
||||
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
|
||||
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode=self.code_expansion_mode).permute(0,2,1)
|
||||
return expanded_code_emb
|
||||
|
||||
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
|
||||
|
|
Loading…
Reference in New Issue
Block a user