forked from mrq/DL-Art-School
fixes
parent 19ca5b26c1
commit 03d0b90bda
@@ -92,7 +92,7 @@ class CLVP(nn.Module):
         return {
             'conditioning': list(self.conditioning_transformer.parameters()),
             'text': list(self.text_transformer.parameters()),
-            'speech': list(self.speech_transformer.parameters()) + list(self.mel_head.parameters()),
+            'speech': list(self.speech_transformer.parameters()),
         }
 
     def forward(
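The dict above hands out CLVP's parameters in named groups (conditioning, text, speech), and the fix simply stops folding the `mel_head` parameters into the 'speech' group. As a rough illustration of how such a grouping is typically consumed — per-group gradient-norm logging is an assumption here, and the `nn.Linear` stand-ins plus the `grad_norm` helper are hypothetical, not code from the repository:

import torch
import torch.nn as nn

def grad_norm(params):
    # L2 norm over every gradient in one named group; parameters that
    # have not received a gradient yet are skipped.
    grads = [p.grad.detach().flatten() for p in params if p.grad is not None]
    return torch.cat(grads).norm(2) if grads else torch.tensor(0.0)

# Hypothetical stand-ins for the three sub-transformers named in the diff.
groups = {
    'conditioning': list(nn.Linear(8, 8).parameters()),
    'text': list(nn.Linear(8, 8).parameters()),
    'speech': list(nn.Linear(8, 8).parameters()),
}

# After one backward pass, each group's gradient norm can be reported on
# its own, which is the point of returning parameters keyed by name.
loss = sum((p ** 2).sum() for params in groups.values() for p in params)
loss.backward()
for name, params in groups.items():
    print(f"{name}: {float(grad_norm(params)):.4f}")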
@@ -779,6 +779,7 @@ class AttentionLayers(nn.Module):
         self.dim = dim
         self.depth = depth
         self.layers = nn.ModuleList([])
+        self.causal = causal
 
         rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@@ -921,6 +922,7 @@ class AttentionLayers(nn.Module):
             return_hiddens=False,
             norm_scale_shift_inp=None,
             past_key_values=None,
+            expected_seq_len=None,
     ):
 
         assert not (self.cross_attend ^ (exists(context) or exists(
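Taken together, the new `self.causal` flag and the new `expected_seq_len` keyword define a small contract: a causal stack run in eval mode must be told its final sequence length up front. A toy module that mirrors just that contract (the class, names, and shapes here are invented for illustration; this is not the real AttentionLayers):

import torch
import torch.nn as nn

class TinyAttentionStack(nn.Module):
    # Minimal stand-in that mirrors the two additions from the diff:
    # the constructor stores `causal`, and forward() accepts
    # `expected_seq_len` and checks it at inference time.
    def __init__(self, causal=False):
        super().__init__()
        self.causal = causal

    def forward(self, x, past_key_values=None, expected_seq_len=None):
        if not self.training and self.causal:
            assert expected_seq_len is not None, \
                "causal decoding needs the final sequence length up front"
        return x

layers = TinyAttentionStack(causal=True).eval()
x = torch.zeros(1, 1, 8)
layers(x, expected_seq_len=32)  # passes
# layers(x)                     # would trip the assertion added by the fix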
@@ -939,10 +941,14 @@ class AttentionLayers(nn.Module):
 
         rotary_pos_emb = None
         if exists(self.rotary_pos_emb):
+            if not self.training and self.causal:
+                assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
+            elif expected_seq_len is None:
+                expected_seq_len = 0
             seq_len = x.shape[1]
             if past_key_values is not None:
                 seq_len += past_key_values[0][0].shape[-2]
-            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)))
+            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
         present_key_values = []
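The reason the rotary table needs `expected_seq_len` shows up during cached, token-by-token decoding: with `past_key_values`, `x.shape[1]` is only the number of new tokens, so sizing the angle table from the visible length alone would leave later positions uncovered; taking the max with the expected final length builds the whole table once. A self-contained sketch of that sizing logic follows — the `rotary_freqs` helper and the 64-dim / 32-token numbers are made up for the example and are not the repository's implementation:

import torch

def rotary_freqs(max_len, dim, device='cpu', base=10000):
    # Standard rotary-embedding angle table: one row of angles per position.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    positions = torch.arange(max_len, device=device).float()
    return torch.einsum('i,j->ij', positions, inv_freq)  # (max_len, dim // 2)

# Build the table once, for the *final* length, which is exactly what
# passing `expected_seq_len` into forward() makes possible.
expected_seq_len = 32
freqs = rotary_freqs(expected_seq_len, dim=64)

# Incremental decoding with a KV cache: each step only sees one new token,
# but its angles are read from the full table at the correct offset.
cache_len = 0
for step in range(4):
    new_tokens = 1
    step_freqs = freqs[cache_len:cache_len + new_tokens]
    cache_len += new_tokens
    print(step, tuple(step_freqs.shape))  # -> (1, 32) at each step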