there's something wrong with it on my 4xV100 rig......

mrq 2025-02-25 15:14:08 -06:00
parent db181f8e88
commit de27115bb7
2 changed files with 14 additions and 21 deletions

@@ -1367,8 +1367,8 @@ class AR_NAR(Base):
 def example_usage():
-	cfg.device = "cuda"
-	cfg.trainer.backend = "local"
+	#cfg.device = "cuda"
+	#cfg.trainer.backend = "local"
 	from functools import partial
 	from einops import repeat

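For context on the hunk above: a plausible reading of the commit message is that hard-coding cfg.device = "cuda" pins every process to the default GPU, which can misbehave on a multi-GPU box like a 4xV100 rig. The sketch below is not code from this repo; it only illustrates the usual per-rank device selection, and LOCAL_RANK is an assumption about a torchrun-style launcher.

# Illustrative only: bind each process to its own GPU instead of a blanket "cuda".
# Assumes a torchrun-style launcher that sets LOCAL_RANK; falls back to CPU otherwise.
import os
import torch

def pick_device() -> torch.device:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)   # each rank claims its own GPU
        return torch.device(f"cuda:{local_rank}")
    return torch.device("cpu")

if __name__ == "__main__":
    print(pick_device())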

@@ -356,29 +356,22 @@ class AudioEncoder(nn.Module):
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		# empty
 		if xi.shape[0] == 0:
-			return torch.zeros((0, self.proj.weight.shape[0]), device=xi.device, dtype=xi.dtype)
+			dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
+			return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
 
 		if dropout_mask is not None:
 			xi = _dropout_codes( xi, dropout_mask, dropout_token )
 
-		# old way
-		"""
+		# this probably is a tried and true good way to go about it
 		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
-		"""
-		# naive way to "encode" by flattening
-		"""
-		x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
-		x = self.proj(x)
-		"""
-		# encode by interleaving
-		# resultant tensor is equal to prior naive attempt
-		seq_len = xi.shape[0]
-		# (seq_len, 8, dim)
+		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
+		"""
 		x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
 		# (seq_len, 8 * dim)
 		x = x.view(x.shape[0], -1)
 		# => (seq_len, dim)
 		x = self.proj(x)
+		"""
 
 		return x
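To make the encoder change easier to follow, here is a standalone sketch of the two strategies the comments describe: summing the per-level embeddings (the path kept active above) versus stacking, flattening, and projecting them (the path now commented out). The class and dimension names are assumptions for illustration, not the repo's AudioEncoder.

# Illustrative sketch: two ways to encode codes shaped (seq_len, n_levels).
import torch
import torch.nn as nn
from torch import Tensor

class ToyAudioEncoder(nn.Module):
    def __init__(self, n_tokens: int = 1024, n_levels: int = 8, d_model: int = 512):
        super().__init__()
        # one embedding table per RVQ level
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])
        # only needed for the flatten-then-project variant
        self.proj = nn.Linear(n_levels * d_model, d_model)

    def forward_sum(self, xi: Tensor) -> Tensor:
        # sum the per-level embeddings: (seq_len, n_levels) -> (seq_len, d_model)
        return sum(emb(xi[:, l]) for l, emb in enumerate(self.embs))

    def forward_flatten(self, xi: Tensor) -> Tensor:
        # stack per-level embeddings: (seq_len, n_levels, d_model)
        x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
        # flatten levels into the feature dim: (seq_len, n_levels * d_model)
        x = x.view(x.shape[0], -1)
        # project back down: (seq_len, d_model)
        return self.proj(x)

enc = ToyAudioEncoder()
codes = torch.randint(0, 1024, (100, 8))   # (seq_len, n_levels)
print(enc.forward_sum(codes).shape)        # torch.Size([100, 512])
print(enc.forward_flatten(codes).shape)    # torch.Size([100, 512])

Stack-then-flatten yields the same tensor as concatenating along the feature dim, which is what the removed "resultant tensor is equal to prior naive attempt" comment was noting; the diff simply drops that heavier path in favor of the plain sum.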
@@ -392,17 +385,17 @@ class AudioDecoder(nn.Module):
 	):
 		super().__init__()
-		self.up = nn.Linear( d_model, hidden_size )
-		self.down = nn.Linear( hidden_size, vocab_size * resp_levels )
+		self.resp_levels = resp_levels
+		self.head = nn.Linear( d_model, vocab_size * resp_levels )
 
 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
-		x = self.up( x )
-		x = self.down( x )
+		# prior way up-projected then down-projected, but that's silly
+		x = self.head( x )
 
 		# interleave by reshaping / permuting
-		# at least I hope this does it properly
+		# at least I hope this does it properly, it checks out against my OCR classifier
 		batch_size, seq_len, dim = x.shape
-		x = x.reshape( batch_size, seq_len, 8, dim // 8 )
+		x = x.view( batch_size, seq_len, self.resp_levels, -1 )
 		x = x.permute( 0, 2, 1, 3 )
 
 		return x
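The decoder now uses a single head that emits logits for every codebook level at once, and the reshape/permute splits them back out per level. Below is a standalone sketch of that shape bookkeeping; the class and dimension names are assumptions for illustration, not the repo's AudioDecoder.

# Illustrative sketch: one linear head for all levels, then split per level.
import torch
import torch.nn as nn
from torch import Tensor

class ToyAudioDecoder(nn.Module):
    def __init__(self, d_model: int = 512, vocab_size: int = 1024, resp_levels: int = 8):
        super().__init__()
        self.resp_levels = resp_levels
        self.head = nn.Linear(d_model, vocab_size * resp_levels)

    def forward(self, x: Tensor) -> Tensor:
        # x: (batch, seq_len, d_model) -> (batch, seq_len, resp_levels * vocab_size)
        x = self.head(x)
        batch_size, seq_len, dim = x.shape
        # split the feature dim into (resp_levels, vocab_size)
        x = x.view(batch_size, seq_len, self.resp_levels, -1)
        # move the level axis out front: (batch, resp_levels, seq_len, vocab_size)
        return x.permute(0, 2, 1, 3)

dec = ToyAudioDecoder()
logits = dec(torch.randn(2, 100, 512))
print(logits.shape)   # torch.Size([2, 8, 100, 1024])

Using view(..., self.resp_levels, -1) instead of the old hard-coded 8 keeps the split tied to the configured number of levels rather than assuming an 8-codebook codec.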