there's something wrong with it on my 4xV100 rig......
parent db181f8e88
commit de27115bb7
@@ -1367,8 +1367,8 @@ class AR_NAR(Base):
 
 
 def example_usage():
-	cfg.device = "cuda"
-	cfg.trainer.backend = "local"
+	#cfg.device = "cuda"
+	#cfg.trainer.backend = "local"
 
 	from functools import partial
 	from einops import repeat
@@ -356,29 +356,22 @@ class AudioEncoder(nn.Module):
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
 		if xi.shape[0] == 0:
-			return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device, dtype=xi.dtype)
+			dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
+			return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
 		if dropout_mask is not None:
 			xi = _dropout_codes( xi, dropout_mask, dropout_token )
 
 		# old way
-		"""
 		# this probably is a tried and true good way to go about it
 		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
-		"""
-		# naive way to "encode" by flattening
-		"""
-		x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
-		x = self.proj(x)
-		"""
 
-		# encode by interleaving
-		# resultant tensor is equal to prior naive attempt
-		seq_len = xi.shape[0]
-		# (seq_len, 8, dim)
+		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
+		"""
 		x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
 		# (seq_len, 8 * dim)
 		x = x.view(x.shape[0], -1)
 		# => (seq_len, dim)
 		x = self.proj(x)
+		"""
 
 		return x
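For reference, the encoding strategies being juggled here (sum, concatenate-then-project, stack-then-flatten-then-project) all collapse the per-level codebook embeddings into one vector per frame, and the stacked/flattened tensor really is identical to the concatenated one, as the in-code comment claims. A minimal standalone sketch, with assumed sizes (n_levels=8, vocab=1024, d_model=512 are illustrative, not the repo's actual config):

# standalone sketch; embs/proj mirror AudioEncoder's members, all sizes are assumed
import torch
import torch.nn as nn

n_levels, vocab, d_model, seq_len = 8, 1024, 512, 75
embs = nn.ModuleList([nn.Embedding(vocab, d_model) for _ in range(n_levels)])
proj = nn.Linear(n_levels * d_model, d_model)

xi = torch.randint(0, vocab, (seq_len, n_levels))  # one code per level per frame

# "tried and true": sum the per-level embeddings -> (seq_len, d_model)
x_sum = sum(emb(xi[:, l]) for l, emb in enumerate(embs))

# "naive" flattening: concatenate per-level embeddings -> (seq_len, 8 * d_model), then project down
x_cat = torch.cat([emb(xi[:, l]) for l, emb in enumerate(embs)], dim=-1)

# "interleaving" via stack + view yields exactly the same tensor as the cat above
x_stack = torch.stack([emb(xi[:, l]) for l, emb in enumerate(embs)], dim=1)  # (seq_len, 8, d_model)
assert torch.equal(x_stack.view(seq_len, -1), x_cat)

print(x_sum.shape, proj(x_cat).shape)  # both torch.Size([75, 512])

Falling back to the plain sum drops the extra proj parameters entirely, which lines up with the in-diff observation that the flattened variant "doesn't seem to help the model all that much".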
@@ -392,17 +385,17 @@ class AudioDecoder(nn.Module):
 	):
 		super().__init__()
 
-		self.up = nn.Linear( d_model, hidden_size )
-		self.down = nn.Linear( hidden_size, vocab_size * resp_levels )
+		self.resp_levels = resp_levels
+		self.head = nn.Linear( d_model, vocab_size * resp_levels )
 
 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
-		x = self.up( x )
-		x = self.down( x )
+		# prior way up-projected then down-projected, but that's silly
+		x = self.head( x )
 
 		# interleave by reshaping / permuting
-		# at least I hope this does it properly
+		# at least I hope this does it properly, it checks out against my OCR classifier
 		batch_size, seq_len, dim = x.shape
-		x = x.reshape( batch_size, seq_len, 8, dim // 8 )
+		x = x.view( batch_size, seq_len, self.resp_levels, -1 )
 		x = x.permute( 0, 2, 1, 3 )
 
 		return x
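Sanity check for the single-head decoder path: one Linear emits vocab_size logits for every response level at once, and the view + permute un-interleaves that fused dimension into (batch, resp_levels, seq_len, vocab_size). A small sketch with assumed sizes (d_model=512, resp_levels=8, vocab_size=1000; not the repo's actual dimensions):

# standalone sketch; head mirrors AudioDecoder.head, all sizes are assumed
import torch
import torch.nn as nn

batch_size, seq_len, d_model, resp_levels, vocab_size = 2, 5, 512, 8, 1000
head = nn.Linear(d_model, vocab_size * resp_levels)

x = torch.randn(batch_size, seq_len, d_model)
fused = head(x)                                            # (B, T, resp_levels * vocab_size)
logits = fused.view(batch_size, seq_len, resp_levels, -1)  # (B, T, resp_levels, vocab_size)
logits = logits.permute(0, 2, 1, 3)                        # (B, resp_levels, T, vocab_size)

# level l's logits are the l-th contiguous vocab_size-wide slice of the fused output,
# i.e. the reshape/permute splits levels apart rather than scrambling them
l = 3
assert torch.equal(logits[:, l], fused[..., l * vocab_size : (l + 1) * vocab_size])

When resp_levels is 8, this is the same layout the earlier reshape( batch_size, seq_len, 8, dim // 8 ) produced, just without hardcoding the 8.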