From 918e0dbac170af56f6371e801c0c1543c06891f7 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 24 Feb 2025 19:03:53 -0600 Subject: [PATCH] small slop cleanup --- vall_e/models/base.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 3ec7294..f85501d 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -371,14 +371,12 @@ class AudioEncoder(nn.Module): """ # encode by interleaving + # resultant tensor is equal to prior naive attempt seq_len = xi.shape[0] - # (8, seq_len, dim) - x = [ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ] - # => (seq_len, dim * 8) interleaved - x_i = [] - for i in range(xi.shape[0]): - x_i.append(torch.cat([ x[l][i] for l in range(len(self.embs)) ], dim=-1)) - x = torch.stack( x_i, dim=0 ) + # (seq_len, 8, dim) + x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1) + # (seq_len, 8 * dim) + x = x.view(x.shape[0], -1) # => (seq_len, dim) x = self.proj(x)