get rid of nil tokens in <2>
This commit is contained in:
parent 0152174c0e
commit 935a4e853e
@@ -135,7 +135,6 @@ class DiffusionTts(nn.Module):
             scale_factor=2,
             conditioning_inputs_provided=True,
             time_embed_dim_multiplier=4,
-            nil_guidance_fwd_proportion=.3,
     ):
         super().__init__()
 
@@ -154,8 +153,6 @@ class DiffusionTts(nn.Module):
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.dims = dims
-        self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
-        self.mask_token_id = num_tokens
 
         padding = 1 if kernel_size == 3 else 2
 
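Note: the two deleted attributes were the only state behind the nil-guidance scheme, a masking probability and a reserved token id sitting one past the real vocabulary. A minimal sketch of why that reserved id forced the embedding table in the next hunk to carry an extra row (toy sizes; the `num_tokens` and `ch` values here are made up):

import torch
import torch.nn as nn

num_tokens, ch = 256, 64  # toy vocabulary size and channel width

# Under the old scheme the table held num_tokens + 1 rows so that the id
# `num_tokens`, the first id past the real vocabulary, could act as the
# "nil" token standing in for masked-out guidance.
emb = nn.Embedding(num_tokens + 1, ch)
mask_token_id = num_tokens
nil_vector = emb(torch.tensor([mask_token_id]))  # shape: (1, 64)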
@@ -186,7 +183,7 @@ class DiffusionTts(nn.Module):
 
         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in token_conditioning_resolutions:
-                token_conditioning_block = nn.Embedding(num_tokens+1, ch)
+                token_conditioning_block = nn.Embedding(num_tokens, ch)
                 token_conditioning_block.weight.data.normal_(mean=0.0, std=.02)
                 self.input_blocks.append(token_conditioning_block)
                 token_conditioning_blocks.append(token_conditioning_block)
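With the spare row gone, the table matches the real vocabulary again and its checkpoint tensor shrinks from (num_tokens+1, ch) to (num_tokens, ch); any surviving lookup of the old mask id would now be out of range. A quick sketch of the failure mode this makes possible, which is why the remaining hunks delete every producer and consumer of that id (toy sizes again):

import torch
import torch.nn as nn

num_tokens, ch = 256, 64
emb = nn.Embedding(num_tokens, ch)  # post-commit shape: no spare row

print(emb.weight.shape)             # torch.Size([256, 64])
try:
    emb(torch.tensor([num_tokens]))  # the old mask id is now invalid
except IndexError as e:
    print("out-of-range lookup:", e)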
@@ -289,23 +286,6 @@ class DiffusionTts(nn.Module):
             zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )
 
-    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
-                        strict: bool = True):
-        # Temporary hack to allow the addition of nil-guidance token embeddings to the existing guidance embeddings.
-        lsd = self.state_dict()
-        revised = 0
-        for i, blk in enumerate(self.input_blocks):
-            if isinstance(blk, nn.Embedding):
-                key = f'input_blocks.{i}.weight'
-                if state_dict[key].shape[0] != lsd[key].shape[0]:
-                    t = torch.randn_like(lsd[key]) * .02
-                    t[:state_dict[key].shape[0]] = state_dict[key]
-                    state_dict[key] = t
-                    revised += 1
-        print(f"Loaded experimental unet_diffusion_net with {revised} modifications.")
-        return super().load_state_dict(state_dict, strict)
-
-
     def forward(self, x, timesteps, tokens, conditioning_input=None):
         """
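The deleted `load_state_dict` override was a one-off migration shim: when a checkpoint saved before the nil-guidance row existed had fewer embedding rows than the current module, it padded the loaded tensor with small random rows. The general pattern as a standalone sketch (`pad_embedding_rows` is an illustrative helper, not a name from the repo):

import torch

def pad_embedding_rows(old_weight, new_rows, std=0.02):
    # Grow an embedding weight to `new_rows` rows, keeping the old rows
    # and filling the new ones with small random values, as the removed
    # override did with torch.randn_like(...) * .02.
    rows, ch = old_weight.shape
    if rows >= new_rows:
        return old_weight
    grown = torch.randn(new_rows, ch) * std
    grown[:rows] = old_weight
    return grown

ckpt = torch.randn(256, 64)                 # toy pre-migration checkpoint tensor
print(pad_embedding_rows(ckpt, 257).shape)  # torch.Size([257, 64])

Once the module goes back to `num_tokens` rows, old and new checkpoints agree on shape again, and the shim has nothing left to do.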
@@ -333,11 +313,6 @@ class DiffusionTts(nn.Module):
         else:
             emb = emb1
 
-        # Mask out guidance tokens for un-guided diffusion.
-        if self.training and self.nil_guidance_fwd_proportion > 0:
-            token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
-            tokens = torch.where(token_mask, self.mask_token_id, tokens)
-
         h = x.type(self.dtype)
         for k, module in enumerate(self.input_blocks):
             if isinstance(module, nn.Embedding):
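The forward-pass lines deleted here implemented classifier-free-guidance-style conditioning dropout: during training, each guidance token was independently replaced by the nil token with probability `nil_guidance_fwd_proportion`. A self-contained sketch of that mechanism, outside the module context (`drop_guidance_tokens` and the toy shapes are illustrative only):

import torch

def drop_guidance_tokens(tokens, mask_token_id, p):
    # Replace roughly a fraction p of the guidance tokens with the
    # nil/mask id, mirroring the removed torch.rand / torch.where lines.
    token_mask = torch.rand(tokens.shape, device=tokens.device) < p
    return torch.where(token_mask, torch.full_like(tokens, mask_token_id), tokens)

tokens = torch.randint(0, 256, (2, 10))  # toy batch of guidance tokens
print(drop_guidance_tokens(tokens, mask_token_id=256, p=0.3))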