forked from mrq/DL-Art-School
drop full layers in layerdrop, not half layers
parent 57da6d0ddf
commit 8707a3e0c3

@@ -152,6 +152,7 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         layer_types = default_block * depth

         self.layer_types = layer_types
+        self.num_layer_types = len(set(self.layer_types))
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

         # calculate token shifting
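Note: the new num_layer_types attribute counts the distinct sub-layer types in layer_types; the layerdrop change in the next hunk uses it as the number of consecutive sub-layers that make up one joined block. A minimal sketch of what it evaluates to, assuming a hypothetical default_block of attention ('a'), context attention ('c') and feedforward ('f') sub-layers:

# Hypothetical example: with a default_block of ('a', 'c', 'f') sub-layers repeated depth
# times, num_layer_types is the number of distinct sub-layer types, i.e. the size of one
# joined block that layerdrop should remove as a unit.
default_block = ('a', 'c', 'f')
depth = 16
layer_types = default_block * depth
num_layer_types = len(set(layer_types))
print(num_layer_types)  # 3 -> one layerdrop decision now skips three consecutive sub-layers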
@@ -225,21 +226,24 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

         unused_params = []
+        to_drop = 0
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
-            is_last = ind == (len(self.layers) - 1)
-
-            if layer_type == 'a':
-                hiddens.append(x)
-                layer_mem = mems.pop(0) if mems else None
-
-            # Do layer drop where applicable. Do not drop first and last layers.
-            if self.training and self.layerdrop_percent > 0 and not is_last and ind != 0 and random.random() < self.layerdrop_percent:
+            # Do layer drop where applicable. Do not drop first layer. When doing layer-drop, drop all of the joined layers (e.g. attention + context + feedforward)
+            if self.training and self.layerdrop_percent > 0 and ind != 0 and random.random() < self.layerdrop_percent:
+                to_drop = self.num_layer_types
+
+            if to_drop > 0:
+                to_drop -= 1
                 # Record the unused parameters so they can be used in null-operations later to not trigger DDP.
                 unused_params.extend(list(block.parameters()))
                 unused_params.extend(list(residual_fn.parameters()))
                 unused_params.extend(list(norm.parameters()))
                 continue

+            if layer_type == 'a':
+                hiddens.append(x)
+                layer_mem = mems.pop(0) if mems else None
+
             residual = x

             pre_branch_norm, post_branch_norm, post_main_norm = norm
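Note: instead of the old per-sub-layer coin flip (which could drop an attention sub-layer while keeping its feedforward partner), a single drop decision now sets to_drop = self.num_layer_types, so the following sub-layers are skipped until the whole joined block has been consumed; their parameters are still collected in unused_params so a later null-operation keeps DDP's gradient bookkeeping intact. A standalone sketch of the same counter mechanism, with hypothetical names (plan_layerdrop is not part of the repository):

import random

# Hypothetical illustration of whole-block layerdrop: each transformer block is a run of
# sub-layers ('a' = attention, 'c' = context/cross attention, 'f' = feedforward), and a
# single drop decision skips the entire run instead of a single sub-layer.
def plan_layerdrop(layer_types, layerdrop_percent, training=True, seed=None):
    rng = random.Random(seed)
    num_layer_types = len(set(layer_types))  # size of one joined block, e.g. 3 for ('a', 'c', 'f')
    keep = []
    to_drop = 0
    for ind, _layer_type in enumerate(layer_types):
        # Never start a drop window at the first sub-layer; otherwise open one with
        # probability layerdrop_percent that covers a full block of sub-layers.
        if training and layerdrop_percent > 0 and ind != 0 and rng.random() < layerdrop_percent:
            to_drop = num_layer_types
        if to_drop > 0:
            to_drop -= 1
            keep.append(False)  # skipped; the real model also records these parameters for DDP
            continue
        keep.append(True)
    return keep

if __name__ == "__main__":
    print(plan_layerdrop(('a', 'c', 'f') * 4, layerdrop_percent=0.1, seed=0))

With layer_types = ('a', 'c', 'f') * depth and num_layer_types = 3, a drop triggered at index ind skips indices ind, ind + 1 and ind + 2 together.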
@@ -21,7 +21,7 @@ class DiffusionTtsFlat(nn.Module):
     def __init__(
             self,
             model_channels=512,
-            num_layers=8,
+            num_layers=16,
             in_channels=100,
             in_latent_channels=512,
             in_tokens=8193,