James Betker 2022-04-10 21:02:12 -06:00
parent 19ca5b26c1
commit 03d0b90bda
2 changed files with 8 additions and 2 deletions


@@ -92,7 +92,7 @@ class CLVP(nn.Module):
         return {
             'conditioning': list(self.conditioning_transformer.parameters()),
             'text': list(self.text_transformer.parameters()),
-            'speech': list(self.speech_transformer.parameters()) + list(self.mel_head.parameters()),
+            'speech': list(self.speech_transformer.parameters()),
         }
     def forward(
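For context, a dict like the one returned above maps sub-module names to parameter lists, presumably so the training loop can report gradient norms per group; dropping `mel_head` from the 'speech' entry changes what that group covers. The sketch below is a hypothetical consumer of such a dict, not code from this repository; the function name `per_group_grad_norms` and the usage lines are illustrative assumptions.

```python
import torch

def per_group_grad_norms(param_groups):
    """Compute the L2 gradient norm of each named parameter group.

    `param_groups` is a dict like the one in the hunk above:
    {'conditioning': [...], 'text': [...], 'speech': [...]}.
    Parameters without gradients (frozen or unused) are skipped.
    """
    norms = {}
    for name, params in param_groups.items():
        grads = [p.grad.detach().flatten() for p in params if p.grad is not None]
        norms[name] = torch.linalg.norm(torch.cat(grads)).item() if grads else 0.0
    return norms

# Hypothetical usage after loss.backward(), assuming the dict above is
# exposed by some accessor on the CLVP model:
# norms = per_group_grad_norms(clvp_model_parameter_groups)
# print(norms)  # e.g. {'conditioning': 1.3, 'text': 0.8, 'speech': 2.1}
```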


@@ -779,6 +779,7 @@ class AttentionLayers(nn.Module):
         self.dim = dim
         self.depth = depth
         self.layers = nn.ModuleList([])
+        self.causal = causal
         rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@@ -921,6 +922,7 @@ class AttentionLayers(nn.Module):
         return_hiddens=False,
         norm_scale_shift_inp=None,
         past_key_values=None,
+        expected_seq_len=None,
     ):
         assert not (self.cross_attend ^ (exists(context) or exists(
@@ -939,10 +941,14 @@
         rotary_pos_emb = None
         if exists(self.rotary_pos_emb):
+            if not self.training and self.causal:
+                assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
+            elif expected_seq_len is None:
+                expected_seq_len = 0
             seq_len = x.shape[1]
             if past_key_values is not None:
                 seq_len += past_key_values[0][0].shape[-2]
-            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)))
+            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
         present_key_values = []
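The substance of this change is how the rotary-embedding length is chosen when decoding with cached keys/values: because each later decode step feeds only one new token, the rotary table has to be sized for the full target length up front. Below is a minimal standalone sketch of that length decision, using the names from the hunk above (`expected_seq_len`, `past_key_values`, `mems`); the function itself, `rotary_length`, is illustrative and not part of the module.

```python
def rotary_length(x_len, past_key_values=None, mems=(), *,
                  training=False, causal=True, expected_seq_len=None):
    """Decide how many positions the rotary embedding table must cover.

    Mirrors the hunk above: at inference time a causal model must declare
    the final length it will reach (`expected_seq_len`), since with cached
    keys/values `x_len` may be just the single newest token.
    """
    if not training and causal:
        assert expected_seq_len is not None, \
            "decoding with rotary embeddings requires `expected_seq_len`"
    elif expected_seq_len is None:
        expected_seq_len = 0

    seq_len = x_len
    if past_key_values is not None:
        # cached keys are shaped (..., cached_len, dim_head)
        seq_len += past_key_values[0][0].shape[-2]

    # Per-layer memories (if any) extend the effective length further; the
    # declared decode target is also a candidate for the maximum.
    mem_lens = [(m.shape[1] if m is not None else 0) + seq_len for m in mems]
    return max(mem_lens + [seq_len, expected_seq_len])
```

In a sampling loop the caller would pass the maximum number of tokens it intends to generate as `expected_seq_len` on every step, so the rotary embeddings computed at the first step already cover the position of the final token.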