forked from mrq/DL-Art-School

Allow diffusion model to be trained with masking tokens

commit 0f3ca28e39 (parent 798ed7730a)
@@ -115,7 +115,7 @@ class DiffusionTts(nn.Module):
             self,
             model_channels,
             in_channels=1,
-            num_tokens=30,
+            num_tokens=32,
             out_channels=2,  # mean and variance
             dropout=0,
             # res      1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
@@ -135,6 +135,7 @@ class DiffusionTts(nn.Module):
             scale_factor=2,
             conditioning_inputs_provided=True,
             time_embed_dim_multiplier=4,
+            nil_guidance_fwd_proportion=.3,
     ):
         super().__init__()

@@ -153,6 +154,8 @@ class DiffusionTts(nn.Module):
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.dims = dims
+        self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
+        self.mask_token_id = num_tokens

         padding = 1 if kernel_size == 3 else 2

@@ -183,7 +186,7 @@ class DiffusionTts(nn.Module):

         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in token_conditioning_resolutions:
-                token_conditioning_block = nn.Embedding(num_tokens, ch)
+                token_conditioning_block = nn.Embedding(num_tokens+1, ch)
                 token_conditioning_block.weight.data.normal_(mean=0.0, std=.02)
                 self.input_blocks.append(token_conditioning_block)
                 token_conditioning_blocks.append(token_conditioning_block)
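The two hunks above work as a pair: num_tokens counts the real vocabulary, while each token-conditioning embedding reserves one extra row (num_tokens+1) so that index num_tokens, stored as self.mask_token_id, can stand for "no guidance". A minimal sketch of that convention (every name other than num_tokens/mask_token_id is illustrative):

    import torch
    import torch.nn as nn

    num_tokens = 32             # real vocabulary, indices 0..31
    mask_token_id = num_tokens  # index 32 is reserved for the mask/nil token

    # One extra row so mask_token_id is a valid index.
    emb = nn.Embedding(num_tokens + 1, 64)
    emb.weight.data.normal_(mean=0.0, std=.02)

    tokens = torch.tensor([3, 17, mask_token_id])
    vecs = emb(tokens)          # (3, 64); the last row is the learned "masked" vector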
@@ -286,6 +289,24 @@ class DiffusionTts(nn.Module):
             zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )

+    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
+                        strict: bool = True):
+        # Temporary hack to allow the addition of nil-guidance token embeddings to the existing guidance embeddings.
+        lsd = self.state_dict()
+        revised = 0
+        for i, blk in enumerate(self.input_blocks):
+            if isinstance(blk, nn.Embedding):
+                key = f'input_blocks.{i}.weight'
+                if state_dict[key].shape[0] != lsd[key].shape[0]:
+                    t = torch.randn_like(lsd[key]) * .02
+                    t[:state_dict[key].shape[0]] = state_dict[key]
+                    state_dict[key] = t
+                    revised += 1
+        print(f"Loaded experimental unet_diffusion_net with {revised} modifications.")
+        return super().load_state_dict(state_dict, strict)
+
+
+
     def forward(self, x, timesteps, tokens, conditioning_input=None):
         """
         Apply the model to an input batch.
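The load_state_dict override above is what lets pre-change checkpoints (whose guidance embeddings have num_tokens rows) load into the new model (num_tokens+1 rows): existing rows are copied over and the new mask row keeps a fresh std=.02 initialization. The same row-padding idea in isolation, as a standalone sketch (pad_embedding_rows is a hypothetical helper, not in the diff):

    import torch

    def pad_embedding_rows(old_weight: torch.Tensor, new_rows: int) -> torch.Tensor:
        # Keep the checkpoint's rows; initialize any added rows with a
        # small normal, mirroring the hack in load_state_dict above.
        assert new_rows >= old_weight.shape[0]
        out = torch.randn(new_rows, old_weight.shape[1]) * .02
        out[:old_weight.shape[0]] = old_weight
        return out

    old = torch.randn(32, 64)          # e.g. a 32-row checkpoint embedding
    new = pad_embedding_rows(old, 33)  # grown to 33 rows for the mask token
    assert torch.equal(new[:32], old)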
@@ -307,11 +328,16 @@ class DiffusionTts(nn.Module):
         hs = []
         emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
         if self.conditioning_enabled:
-            emb2 = self.contextual_embedder(conditioning_input)
-            emb = emb1 + emb2
+            actual_cond = self.contextual_embedder(conditioning_input)
+            emb = emb1 + actual_cond
         else:
             emb = emb1

+        # Mask out guidance tokens for un-guided diffusion.
+        if self.nil_guidance_fwd_proportion > 0:
+            token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
+            tokens = torch.where(token_mask, self.mask_token_id, tokens)
+
         h = x.type(self.dtype)
         for k, module in enumerate(self.input_blocks):
             if isinstance(module, nn.Embedding):
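With the forward-pass change above, each guidance token is independently swapped for the mask token with probability nil_guidance_fwd_proportion (default .3), so the model also learns to denoise with partially or fully absent text guidance. The masking step reproduced in isolation (shapes borrowed from the test below; torch.full_like is used here instead of the scalar form for portability across PyTorch versions):

    import torch

    nil_guidance_fwd_proportion = .3
    mask_token_id = 32

    tokens = torch.randint(0, 32, (4, 388))   # (batch, sequence) guidance tokens
    token_mask = torch.rand(tokens.shape) < nil_guidance_fwd_proportion
    tokens = torch.where(token_mask, torch.full_like(tokens, mask_token_id), tokens)
    # roughly 30% of positions now carry the mask token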
@@ -370,13 +396,15 @@ def register_diffusion_tts_experimental(opt_net, opt):

 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
-    clip = torch.randn(2, 1, 86016)
-    tok = torch.randint(0,30, (2,388))
-    cond = torch.randn(2, 1, 44000)
-    ts = torch.LongTensor([555, 556])
+    clip = torch.randn(4, 1, 86016)
+    tok = torch.randint(0,30, (4,388))
+    cond = torch.randn(4, 1, 44000)
+    ts = torch.LongTensor([555, 556, 600, 600])
     model = DiffusionTts(64, channel_mult=[1,1.5,2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[2, 2, 2, 2, 2, 2, 2, 4, 4, 4],
                          token_conditioning_resolutions=[1,4,16,64], attention_resolutions=[256,512], num_heads=4, kernel_size=3,
                          scale_factor=2, conditioning_inputs_provided=True, time_embed_dim_multiplier=4)
+    model(clip, ts, tok, cond)
+
     p, r = model.benchmark(clip, ts, tok, cond)
     p = {k: v / 1000000000 for k, v in p.items()}
     p = sorted(p.items(), key=operator.itemgetter(1))
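Nothing here changes inference, but training with masked tokens is what makes a fully unguided forward pass meaningful later on: a sequence of all mask tokens asks the model for an unconditioned prediction, which could then be blended with a guided one, classifier-free-guidance style. A hypothetical sketch reusing the test-harness tensors above (the blend weight w and this whole usage are assumptions, not part of this commit):

    # Hypothetical classifier-free-style step enabled by mask-token training.
    w = 2.0                                           # assumed guidance weight
    nil_tok = torch.full_like(tok, model.mask_token_id)
    guided = model(clip, ts, tok, cond)
    unguided = model(clip, ts, nil_tok, cond)
    blended = unguided + w * (guided - unguided)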