diff --git a/codes/models/clip/clvp.py b/codes/models/clip/clvp.py
index 35499434..1c6cd251 100644
--- a/codes/models/clip/clvp.py
+++ b/codes/models/clip/clvp.py
@@ -92,7 +92,7 @@ class CLVP(nn.Module):
         return {
             'conditioning': list(self.conditioning_transformer.parameters()),
             'text': list(self.text_transformer.parameters()),
-            'speech': list(self.speech_transformer.parameters()) + list(self.mel_head.parameters()),
+            'speech': list(self.speech_transformer.parameters()),
         }

     def forward(
diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py
index 7f389dba..2e32c096 100644
--- a/codes/models/lucidrains/x_transformers.py
+++ b/codes/models/lucidrains/x_transformers.py
@@ -779,6 +779,7 @@ class AttentionLayers(nn.Module):
         self.dim = dim
         self.depth = depth
         self.layers = nn.ModuleList([])
+        self.causal = causal

         rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@@ -921,6 +922,7 @@ class AttentionLayers(nn.Module):
             return_hiddens=False,
             norm_scale_shift_inp=None,
             past_key_values=None,
+            expected_seq_len=None,
     ):

         assert not (self.cross_attend ^ (exists(context) or exists(
@@ -939,10 +941,14 @@

         rotary_pos_emb = None
         if exists(self.rotary_pos_emb):
+            if not self.training and self.causal:
+                assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
+            elif expected_seq_len is None:
+                expected_seq_len = 0
             seq_len = x.shape[1]
             if past_key_values is not None:
                 seq_len += past_key_values[0][0].shape[-2]
-            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)))
+            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

         present_key_values = []
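A minimal, hypothetical usage sketch of the new argument (not part of the patch): at eval time, a causal AttentionLayers stack with rotary embeddings must now be told the full length the decode is expected to reach, so the rotary table is sized once up front rather than from the current input plus cache alone. Names other than `expected_seq_len` and `past_key_values` (e.g. `attn_layers`, `x`, `cache`) are illustrative.

    attn_layers.eval()              # the new assert only fires when not training and self.causal is set
    h = attn_layers(
        x,                          # (batch, seq, dim) inputs decoded so far
        past_key_values=cache,      # incremental-decoding cache, or None on the first step
        expected_seq_len=250,       # total sequence length the decode may reach;
    )                               # max_rotary_emb_length is raised to at least this value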