merged dedicated interleaved AR code with the normal AR code

mrq 2023-09-03 22:46:08 -05:00
parent 3a6bd50322
commit 2f9cd0842f
6 changed files with 73 additions and 55 deletions

View File

@@ -162,7 +162,7 @@ class Model:
     tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
     arch_type: str = "transformer"
     training: bool = True
-    interleave_pattern: str | None = None
+    interleave: bool = False

     @property
     def full_name(self):
@@ -174,6 +174,9 @@ class Model:
         if self.arch_type != "transformer":
             name.append(self.arch_type.replace("/", "-"))
+        if self.interleave:
+            name.append("interleaved")
+            name.append(f'{cfg.models.prom_levels}')
         return "-".join(name)
@@ -228,8 +231,8 @@ class Models:
     _prom_levels: int = 1

     _models: list[Model] = field(default_factory=lambda: [
-        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True),
-        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True),
+        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
+        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
     ])

     def get(self, name=None):
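
The new flag also feeds model naming: an interleaved AR advertises itself and the level count in its checkpoint name. A minimal sketch of the behavior; the constructed Model and the printed value are assumptions inferred from this diff, not output from the repo:

model = Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=True)
print(model.full_name)  # -> "ar-interleaved-8", assuming cfg.models.prom_levels == 8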

View File

@@ -1,2 +0,0 @@
-# From: https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py
-# audiocraft has heavy dependencies, so it doesn't make sense to depend on it just for this file.
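
With the dedicated interleaved model merged into the AR, interleaving reduces to a plain flatten over codebook levels, so this vendored copy of audiocraft's pattern utilities no longer has a caller. A minimal sketch of the round trip the merged code relies on, assuming codes shaped (timesteps, levels):

import torch

codes = torch.arange(12).reshape(4, 3)       # 4 timesteps, 3 codebook levels
flat = codes.flatten()                       # level-major within each timestep
restored = torch.unflatten(flat, 0, (4, 3))  # exact inverse for whole frames
assert torch.equal(codes, restored)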

View File

@@ -23,8 +23,8 @@ class AR(Base):
     @property
     def arch_type(self) -> str:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.arch_type
+        if hasattr(self, "config") and self.config:
+            return self.config.arch_type
         return cfg.models.ar.arch_type

     @property
@@ -33,8 +33,8 @@ class AR(Base):
     @property
     def n_resp_levels(self) -> int:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.resp_levels
+        if hasattr(self, "config") and self.config:
+            return self.config.resp_levels
         return cfg.models.ar.resp_levels

     @property
@@ -55,12 +55,30 @@ class AR(Base):
             return 0
         return cfg.inference.recurrent_chunk_size

+    @property
+    def interleave(self) -> bool:
+        if hasattr(self, "config") and self.config:
+            return self.config.interleave
+        return False
+
     def _prune(self, l: Tensor):
         indices = (l == self.stop_token).nonzero()
         if len(indices) == 0:
             return l
         return l[: indices.min().item()]

+    def _interleave( self, codes ):
+        if not self.interleave:
+            return codes
+        return codes.flatten()
+
+    def _deinterleave( self, codes, length = 0 ):
+        if not self.interleave:
+            return codes
+        return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
+
     @staticmethod
     def _unsqueeze_list(x_list, axis=-1):
         return [x.unsqueeze(dim=axis) for x in x_list]
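
_deinterleave first truncates the flattened sequence to a whole number of frames, so a generation that stops mid-frame cannot break the unflatten. A small worked example, assuming n_prom_levels == 4:

import torch

n = 4
codes = torch.arange(10)        # two complete frames plus two stray tokens
keep = codes.shape[0] // n * n  # 8: the partial trailing frame is dropped
frames = torch.unflatten(codes[:keep], 0, (keep // n, n))
print(frames.shape)             # torch.Size([2, 4])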
@@ -74,7 +92,10 @@ class AR(Base):
         sampling_temperature: float = 1.0,
     ):
         if resps_list is not None:
-            resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
+            if self.interleave:
+                resps_list = [self._interleave(r) for r in resps_list]
+            else:
+                resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels

         return super().forward(
             text_list=text_list,
@@ -94,6 +115,9 @@ class AR(Base):
         state = {} if cfg.inference.recurrent_forward else None

+        if self.interleave:
+            max_steps *= self.n_prom_levels
+
         for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
             # get next in sequence
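
Each audio frame now costs n_prom_levels autoregressive steps in interleaved mode, so the step budget is scaled up to keep the generated duration comparable. A back-of-the-envelope sketch with assumed numbers:

max_steps, n_prom_levels = 450, 8  # 450 frames' worth of steps, 8 codebook levels
interleave = True                  # hypothetical stand-in for self.interleave
if interleave:
    max_steps *= n_prom_levels     # 3600 decoder iterations for the same duration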
@@ -116,9 +140,10 @@ class AR(Base):
             if stopped.all().item():
                 break

-        pruned = [self._prune(r) for r in resps_list]
-        return pruned
+        res = [self._prune(r) for r in resps_list]
+        if self.interleave:
+            res = [self._deinterleave(r) for r in res]
+        return res

 def example_usage():
@@ -163,6 +188,10 @@ def example_usage():
         'n_heads': 16,
         'n_layers': 24,
     }
+    try:
+        kwargs['config'] = cfg.models.ar
+    except Exception as e:
+        pass

     model = AR(**kwargs).to(device)
     engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))

View File

@@ -129,6 +129,10 @@ class Base(nn.Module):
     @property
     def recurrent_chunk_size(self) -> int:
         raise NotImplementedError

+    @property
+    def interleave(self) -> bool:
+        return False
+
     def __init__(
         self,
@@ -137,8 +141,11 @@ class Base(nn.Module):
         n_heads: int = 8,
         n_layers: int = 12,
         p_dropout: float = 0.1,
+        config = None,
     ):
         super().__init__()

+        self.config = config
         self.n_tokens = n_tokens
         self.d_model = d_model
         self.n_heads = n_heads
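
Storing the dataclass config on Base is what lets the subclasses above resolve arch_type, resp_levels, and interleave per instance instead of reaching for globals. A sketch of the intended construction; the keyword values here are illustrative, not repo defaults:

model = AR(n_tokens=1024, d_model=1024, n_heads=16, n_layers=24, config=cfg.models.ar)
assert model.interleave == cfg.models.ar.interleave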

View File

@@ -135,7 +135,7 @@ class Base(nn.Module):
     @property
     def n_resp_levels(self) -> int:
-        return 4
+        return 1

     @property
     def n_max_levels(self) -> int:
@@ -155,7 +155,7 @@ class Base(nn.Module):
     @property
     def interleave_pattern(self) -> str | None:
-        return "musiclm"
+        return "flatten"

     @property
     def stop_token(self):
@@ -192,27 +192,12 @@ class Base(nn.Module):
             return codes
         return codes.flatten()
-        """
-        pattern_provider = _get_pattern_provider( self.interleave_pattern )( self.n_resp_levels )
-        pattern = pattern_provider.get_pattern( codes.shape[0] )
-        res, _, _ = pattern.build_pattern_sequence( codes.t()[None, :, :], self.interleaved_token, keep_only_valid_steps=True )
-        return res[0].t().flatten()
-        """

-    def _deinterleave( self, codes ):
+    def _deinterleave( self, codes, length = 0 ):
         if not self.interleave_pattern:
             return codes
-        return torch.unflatten( codes[:codes.shape[0] // self.n_resp_levels * self.n_resp_levels], 0, ( codes.shape[0] // self.n_resp_levels, self.n_resp_levels ) )
-        """
-        if codes.dim() == 1:
-            codes = torch.unflatten( codes[:codes.shape[0] // self.n_resp_levels * self.n_resp_levels], 0, ( codes.shape[0] // self.n_resp_levels, self.n_resp_levels ) )
-        pattern_provider = _get_pattern_provider( self.interleave_pattern )( self.n_resp_levels )
-        pattern = pattern_provider.get_pattern( codes.shape[0] )
-        res, _, _ = pattern.revert_pattern_sequence( codes, special_token=self.interleaved_token)
-        return res[0].t()
-        """
+        return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )

     def __init__(
         self,
@@ -232,13 +217,13 @@ class Base(nn.Module):
         self.n_layers = n_layers

         # + tasks for each token they represent in the prom
-        n_prom_tokens = n_tokens + (self.n_tasks - 1) + (1 if self.interleave_pattern else 0) # - 1 because tts is an inherent task
+        n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
         # +1 to include the stop token + 1 to include interleave token
-        n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) + (1 if self.interleave_pattern else 0) # AR requires a stop token to... know when to stop
+        n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) # AR requires a stop token to... know when to stop

         self.text_emb = Embedding(n_tokens, d_model)
         self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-        self.resps_emb = MultiEmbedding(1, n_resp_tokens, d_model)
+        self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)

         self.sep = nn.Parameter(torch.randn(d_model))
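
Dropping the dedicated interleave token shrinks both vocabularies back to their plain-AR sizes; worked through with illustrative numbers:

n_tokens, n_tasks, use_stop_token = 1024, 8, True
n_prom_tokens = n_tokens + (n_tasks - 1)                 # 1031; tts is the implicit task
n_resp_tokens = n_tokens + (1 if use_stop_token else 0)  # 1025; only the stop token is extra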
@@ -270,7 +255,6 @@ class Base(nn.Module):
         ))

-        # I imagine because each step returns `resp_level`s tokens at once, so we need to have a classifier for each level
-        #self.classifier = nn.ModuleList([ nn.Linear(d_model, n_resp_tokens) for _ in range(self.n_resp_levels) ]) if self.interleave_pattern else nn.Linear(d_model, n_resp_tokens)
         self.classifier = nn.Linear(d_model, n_resp_tokens)

         self.accuracy_metric = MulticlassAccuracy(
@@ -385,11 +369,6 @@ class Base(nn.Module):
         # Remove padding
         h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]

-        if True:
-            logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
-            ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
-            print( [ r for r in ret ] )

         # compute loss if the target is given
         if targ_list is not None:
             if any([l == 0 for l in map(len, targ_list)]):
@@ -487,6 +466,8 @@ class Base(nn.Module):
         state = {} if cfg.inference.recurrent_forward else None

+        max_steps *= self.n_prom_levels
+
         for n in range(max_steps // max(1, self.recurrent_chunk_size)):
             # get next in sequence
@@ -502,6 +483,7 @@ class Base(nn.Module):
             for i, ri in enumerate(r):
                 if self.stop_token in ri:
                     stopped[i] = True
+                resps_list[i] = torch.cat([resps_list[i], ri])

             # stop token found
@@ -509,12 +491,7 @@ class Base(nn.Module):
             if stopped.all().item():
                 break

-        pruned = [self._prune(r) for r in resps_list]
-        print( [ r for r in pruned ] )
-        deinterleaved = [ self._deinterleave(r) for r in pruned ]
-        print( [ r for r in deinterleaved ] )
-        return deinterleaved
+        return [self._deinterleave(self._prune(r)) for r in resps_list]

 def example_usage():
     from ..config import cfg
@@ -548,7 +525,7 @@ def example_usage():
     for name, model in models.items():
         print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

-    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
+    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=5e-5)) for name, model in models.items() })

     train = True
@@ -565,7 +542,7 @@ def example_usage():
         qnt.to(device),
     ]

-    def sample( filename, steps=450 * 4 ):
+    def sample( filename, steps=450 ):
         AR = None
         engines.eval()
@@ -578,10 +555,10 @@ def example_usage():
         decode_to_file(resps_list[0].cpu(), f"./data/{filename}.wav", device="cpu")

     if train:
-        sample("init", 15)
+        sample("init", 75 )

         engines.train()
-        t = trange(100)
+        t = trange(500)
         for i in t:
             stats = {"step": i}
         """

View File

@@ -17,8 +17,8 @@ class NAR(Base):
     @property
     def arch_type(self) -> str:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.arch_type
+        if hasattr(self, "config") and self.config:
+            return self.config.arch_type
         return cfg.models.nar.arch_type

     @property
@@ -31,8 +31,8 @@ class NAR(Base):
     @property
     def n_resp_levels(self) -> int:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.resp_levels
+        if hasattr(self, "config") and self.config:
+            return self.config.resp_levels
         return cfg.models.nar.resp_levels

     @property
@@ -51,6 +51,10 @@ class NAR(Base):
     def recurrent_chunk_size(self) -> int:
         return 0

+    @property
+    def interleave(self) -> bool:
+        return False
+
     def forward(
         self,
         text_list: list[Tensor],
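
The NAR pins interleave to False: it predicts the residual codebook levels in parallel from the AR's first-level output, so there is no sequence to flatten. A sketch of the intended division of labor; the helper and call signatures are assumptions for illustration, not the repo's exact API:

text_list, proms_list = load_inputs()                 # hypothetical helper
ar_out = ar(text_list, proms_list, max_steps=450)     # first level, deinterleaved if needed
full = nar(text_list, proms_list, resps_list=ar_out)  # fills the remaining levels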