allow more interesting code interpolation

pull/9/head
James Betker 2022-06-17 09:12:44 +07:00
parent 87a86ae6a8
commit f70b16214d
1 changed file with 3 additions and 1 deletion

@@ -99,6 +99,7 @@ class TransformerDiffusion(nn.Module):
                  dropout=0,
                  use_fp16=False,
                  ar_prior=False,
+                 code_expansion_mode='nearest',
                  # Parameters for regularization.
                  unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
                  # Parameters for re-training head
@@ -114,6 +115,7 @@ class TransformerDiffusion(nn.Module):
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
+        self.code_expansion_mode = code_expansion_mode
         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
@@ -217,7 +219,7 @@ class TransformerDiffusion(nn.Module):
         code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1),
                                code_emb)
-        expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
+        expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode=self.code_expansion_mode).permute(0,2,1)
         return expanded_code_emb

     def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
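
Context on what this change enables: code_expansion_mode selects the F.interpolate mode used to stretch the code embedding along the time axis to the diffusion sequence length, instead of hard-coding 'nearest'. The sketch below is illustrative only; the expand_code_emb helper and the tensor sizes are assumptions for demonstration, not code from this repository.

    import torch
    import torch.nn.functional as F

    # Hypothetical standalone helper mirroring the diffed logic.
    # F.interpolate operates on (batch, channels, seq) tensors, so the
    # (batch, seq, channels) embedding is permuted, resized, and permuted back.
    def expand_code_emb(code_emb, expected_seq_len, mode='nearest'):
        x = code_emb.permute(0, 2, 1)                   # (B, S, C) -> (B, C, S)
        # 'nearest' repeats each code; 'linear' blends neighboring codes.
        x = F.interpolate(x, size=expected_seq_len, mode=mode)
        return x.permute(0, 2, 1)                       # back to (B, S', C)

    code_emb = torch.randn(2, 50, 512)                  # 50 codes, 512-dim each
    nearest = expand_code_emb(code_emb, 400, mode='nearest')
    linear = expand_code_emb(code_emb, 400, mode='linear')
    print(nearest.shape, linear.shape)                  # both torch.Size([2, 400, 512])

With 'nearest', each code is block-repeated across the frames it covers; 'linear' instead interpolates smoothly between adjacent codes, which is presumably the "more interesting" interpolation the commit title refers to.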