diff --git a/codes/models/audio/tts/diffusion_encoder.py b/codes/models/audio/tts/diffusion_encoder.py
index c1e04905..56d25d61 100644
--- a/codes/models/audio/tts/diffusion_encoder.py
+++ b/codes/models/audio/tts/diffusion_encoder.py
@@ -152,6 +152,7 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
             layer_types = default_block * depth
 
         self.layer_types = layer_types
+        self.num_layer_types = len(set(self.layer_types))
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
 
         # calculate token shifting
@@ -225,21 +226,24 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
         unused_params = []
+        to_drop = 0
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
-            is_last = ind == (len(self.layers) - 1)
+            if layer_type == 'a':
+                # Do layer drop where applicable. Do not drop first layer. When doing layer-drop, drop all of the joined layers (e.g. attention + context + feedforward)
+                if self.training and self.layerdrop_percent > 0 and ind != 0 and random.random() < self.layerdrop_percent:
+                    to_drop = self.num_layer_types
 
-            # Do layer drop where applicable. Do not drop first and last layers.
-            if self.training and self.layerdrop_percent > 0 and not is_last and ind != 0 and random.random() < self.layerdrop_percent:
+                hiddens.append(x)
+                layer_mem = mems.pop(0) if mems else None
+
+            if to_drop > 0:
+                to_drop -= 1
                 # Record the unused parameters so they can be used in null-operations later to not trigger DDP.
                 unused_params.extend(list(block.parameters()))
                 unused_params.extend(list(residual_fn.parameters()))
                 unused_params.extend(list(norm.parameters()))
                 continue
 
-            if layer_type == 'a':
-                hiddens.append(x)
-                layer_mem = mems.pop(0) if mems else None
-
             residual = x
 
             pre_branch_norm, post_branch_norm, post_main_norm = norm
diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py
index 23ea674a..35410212 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py
@@ -21,7 +21,7 @@ class DiffusionTtsFlat(nn.Module):
     def __init__(
             self,
             model_channels=512,
-            num_layers=8,
+            num_layers=16,
             in_channels=100,
             in_latent_channels=512,
             in_tokens=8193,