diff --git a/codes/models/gpt_voice/lucidrains_gpt.py b/codes/models/gpt_voice/lucidrains_gpt.py
index 07b1c90b..08646884 100644
--- a/codes/models/gpt_voice/lucidrains_gpt.py
+++ b/codes/models/gpt_voice/lucidrains_gpt.py
@@ -48,6 +48,7 @@ class LayerScale(nn.Module):
         scale = torch.zeros(1, 1, dim).fill_(init_eps)
         self.scale = nn.Parameter(scale)
         self.fn = fn
+
     def forward(self, x, **kwargs):
         return self.fn(x, **kwargs) * self.scale
 
@@ -78,8 +79,13 @@ class FeedForward(nn.Module):
             nn.Linear(dim * mult, dim)
         )
 
-    def forward(self, x):
-        return self.net(x)
+    def forward(self, x, only_last_two_elements=False):
+        if only_last_two_elements:
+            h = x[:, -2:]
+            h = self.net(h)
+            return torch.cat([x[:, :-2], h], dim=1)
+        else:
+            return self.net(x)
 
 
 def exists(val):
@@ -110,7 +116,7 @@ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
 class Attention(nn.Module):
     def __init__(self, dim, seq_len, non_causal_sequence_partition = 0, heads = 8, dim_head = 64, dropout = 0., stable = False):
         super().__init__()
-        inner_dim = dim_head * heads
+        inner_dim = dim_head * heads
         self.heads = heads
         self.seq_len = seq_len
         self.scale = dim_head ** -0.5
@@ -124,15 +130,19 @@ class Attention(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, mask = None):
+    def forward(self, x, mask = None, only_last_two_elements=False):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
+        # TODO: Q, K and V do not need to be recomputed for existing elements when intermediate latents are specified. K and V would need to be cached, though.
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
-        q = q * self.scale
+        if only_last_two_elements:
+            q = q[:, :, -2:]
+            assert not exists(mask)  # Don't know how to resolve this (currently)
+        q = q * self.scale
 
         dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
 
         mask_value = max_neg_value(dots)
 
@@ -187,11 +197,40 @@ class Transformer(nn.Module):
                 LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
             ]))
 
+        # TODO: Remove this nonsense. I don't want to support reversible sequences and this is just a mess.
         execute_type = ReversibleSequence if reversible else SequentialSequence
         route_attn = ((True, False),) * depth
         attn_route_map = {'mask': route_attn}
         self.layers = execute_type(layers, args_route = attn_route_map, checkpoint=True)
+        self.depth = depth
 
-    def forward(self, x):
-        return self.layers(x)
+    def forward(self, x, return_intermediates=False):
+        intermediates = []
+        for attn, ff in self.layers.layers:
+            x_ff = attn(x)
+            x = ff(x_ff)
+            if return_intermediates:
+                intermediates.append((x_ff, x))
+        if return_intermediates:
+            return x, intermediates
+        else:
+            return x
+
+    def infer_last_two(self, x, prev_intermediates):
+        """
+        Performs a forward pass only on the last two elements of the given sequence (allowing them to attend to all
+        other elements). This is useful for faster autoregressive decoding.
+
+        The last two elements are important because, in inference, the last element is the prediction candidate and
+        the second-to-last element is a newly selected element from the autoregressive search process.
+        """
+        assert len(prev_intermediates) == self.depth
+        new_intermediates = []
+        for (attn, ff), (int_ff, int_out) in zip(self.layers.layers, prev_intermediates):
+            x = attn(x, only_last_two_elements=True)
+            # Note that x is now only the last two elements in the sequence. Conjoin it with the cached int_ff latent to compute the norm.
+            x_ff = torch.cat([int_ff[:, :-1], x], dim=1)
+            x = ff(x_ff, only_last_two_elements=True)
+            new_intermediates.append((x_ff, x))
+        return x, new_intermediates
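
A note on why slicing the queries in Attention is safe: each row of the attention output depends only on its own query (the full set of keys and values still participates), so computing only the last two query rows reproduces exactly the last two rows of a full pass. A standalone sketch of that property with plain torch tensors, independent of the classes in this patch:

import torch

b, h, n, d = 1, 2, 10, 16
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

full = torch.einsum('b h i d, b h j d -> b h i j', q, k).softmax(dim=-1) @ v
last_two = torch.einsum('b h i d, b h j d -> b h i j', q[:, :, -2:], k).softmax(dim=-1) @ v

# Output rows are independent of one another, so the sliced computation
# matches the last two rows of the full computation.
assert torch.allclose(full[:, :, -2:], last_two, atol=1e-6)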
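
For the new Transformer entry points, the intended decode loop is roughly the sketch below. The constructor arguments, dimensions, and the way the next element is chosen are placeholders for illustration; only forward(return_intermediates=True) and infer_last_two() come from this change.

import torch
from models.gpt_voice.lucidrains_gpt import Transformer  # import path assumed from this repo's layout

model = Transformer(dim=512, depth=4, seq_len=64, heads=8)  # hypothetical constructor args
x = torch.randn(1, 8, 512)  # (batch, sequence, dim) embeddings for the prompt

# One full pass primes the per-layer (x_ff, x) caches.
out, intermediates = model(x, return_intermediates=True)

for _ in range(16):
    # Stand-in for the real autoregressive search step: the sequence grows by one element.
    new_elem = torch.randn(1, 1, 512)
    x = torch.cat([x, new_elem], dim=1)
    # Only the last two positions are recomputed; earlier positions reuse the
    # cached per-layer intermediates from the previous step.
    out, intermediates = model.infer_last_two(x, intermediates)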