drop full layers in layerdrop, not half layers

parent 57da6d0ddf
commit 8707a3e0c3
@@ -152,6 +152,7 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         layer_types = default_block * depth
 
         self.layer_types = layer_types
+        self.num_layer_types = len(set(self.layer_types))
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
 
         # calculate token shifting
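For context (not part of the diff): layer_types is a sequence of per-sub-layer type codes, so the new num_layer_types counts how many joined sub-layers make up one block. A minimal sketch of the values involved, assuming the usual x-transformers 'a'/'c'/'f' (attention / cross-attention / feedforward) layout and a stand-in for the equals() helper:

def equals(val):
    # stand-in for the equals() helper referenced in the diff (assumed behaviour)
    return lambda x: x == val

default_block = ('a', 'c', 'f')      # assumed block layout: attention, cross-attention, feedforward
depth = 4
layer_types = default_block * depth  # ('a', 'c', 'f', 'a', 'c', 'f', ...)

num_layer_types = len(set(layer_types))                        # 3 joined sub-layers per block
num_attn_layers = len(list(filter(equals('a'), layer_types)))  # 4 attention layers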
@@ -225,21 +226,24 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
         unused_params = []
+        to_drop = 0
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
-            is_last = ind == (len(self.layers) - 1)
+            if layer_type == 'a':
+                # Do layer drop where applicable. Do not drop the first layer. When doing layer-drop, drop all of the joined layers (e.g. attention + context + feedforward).
+                if self.training and self.layerdrop_percent > 0 and ind != 0 and random.random() < self.layerdrop_percent:
+                    to_drop = self.num_layer_types
 
-            # Do layer drop where applicable. Do not drop first and last layers.
-            if self.training and self.layerdrop_percent > 0 and not is_last and ind != 0 and random.random() < self.layerdrop_percent:
+                hiddens.append(x)
+                layer_mem = mems.pop(0) if mems else None
+
+            if to_drop > 0:
+                to_drop -= 1
                 # Record the unused parameters so they can be used in null-operations later to not trigger DDP.
                 unused_params.extend(list(block.parameters()))
                 unused_params.extend(list(residual_fn.parameters()))
                 unused_params.extend(list(norm.parameters()))
                 continue
 
-            if layer_type == 'a':
-                hiddens.append(x)
-                layer_mem = mems.pop(0) if mems else None
-
             residual = x
 
             pre_branch_norm, post_branch_norm, post_main_norm = norm
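Not part of the commit, but a self-contained sketch of the drop-counter pattern introduced above: the dice are rolled only at an 'a' layer, and a hit sets to_drop = num_layer_types so the whole joined group (attention + context + feedforward) is skipped rather than a single sub-layer. The training check, DDP bookkeeping, and hiddens/mems handling are omitted; the values here are illustrative only.

import random

layer_types = ('a', 'c', 'f') * 4        # assumed 'a'/'c'/'f' grouping, 4 blocks
num_layer_types = len(set(layer_types))  # 3
layerdrop_percent = 0.5                  # exaggerated so drops are visible in a small demo

kept, to_drop = [], 0
for ind, layer_type in enumerate(layer_types):
    # Roll only at the start of a group; never drop the first layer.
    if layer_type == 'a' and ind != 0 and random.random() < layerdrop_percent:
        to_drop = num_layer_types

    if to_drop > 0:
        to_drop -= 1  # skip this sub-layer and the remainder of its group
        continue

    kept.append((ind, layer_type))

print(kept)  # dropped groups disappear as whole ('a', 'c', 'f') triples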
@@ -21,7 +21,7 @@ class DiffusionTtsFlat(nn.Module):
     def __init__(
             self,
             model_channels=512,
-            num_layers=8,
+            num_layers=16,
             in_channels=100,
             in_latent_channels=512,
             in_tokens=8193,
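For reference (not in the diff): this hunk only changes the default, so callers that pass num_layers explicitly are unaffected. A hypothetical usage sketch, assuming the remaining constructor arguments also have defaults:

model = DiffusionTtsFlat()              # now 16 layers by default
small = DiffusionTtsFlat(num_layers=8)  # previous default, still available explicitly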