merged dedicated interleaved AR code with the normal AR code
parent 3a6bd50322
commit 2f9cd0842f
@@ -162,7 +162,7 @@ class Model:
     tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
     arch_type: str = "transformer"
     training: bool = True
-    interleave_pattern: str | None = None
+    interleave: bool = False

     @property
     def full_name(self):
@@ -174,6 +174,9 @@ class Model:
         if self.arch_type != "transformer":
             name.append(self.arch_type.replace("/", "-"))

+        if self.interleave:
+            name.append("interleaved")
+
         name.append(f'{cfg.models.prom_levels}')

         return "-".join(name)
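For illustration, a rough sketch of how the name is assembled once interleave is set. This is not part of the diff: the starting "ar" element and the surrounding full_name body are assumed from context, and 8 stands in for cfg.models.prom_levels.

    name = ["ar"]              # assumed base element, set before this hunk
    arch_type = "transformer"  # default, so no arch suffix is appended
    interleave = True

    if arch_type != "transformer":
        name.append(arch_type.replace("/", "-"))
    if interleave:
        name.append("interleaved")
    name.append("8")           # f'{cfg.models.prom_levels}' with the default 8 levels

    assert "-".join(name) == "ar-interleaved-8"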
@@ -228,8 +231,8 @@ class Models:
     _prom_levels: int = 1

     _models: list[Model] = field(default_factory=lambda: [
-        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True),
-        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True),
+        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
+        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
     ])

     def get(self, name=None):

@@ -1,2 +0,0 @@
-# From: https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py
-# audiocraft has heavy dependencies, so it doesn't make sense to depend on it just for this file.
@@ -23,8 +23,8 @@ class AR(Base):

     @property
     def arch_type(self) -> str:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.arch_type
+        if hasattr(self, "config") and self.config:
+            return self.config.arch_type
         return cfg.models.ar.arch_type

     @property
@@ -33,8 +33,8 @@ class AR(Base):

     @property
     def n_resp_levels(self) -> int:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.resp_levels
+        if hasattr(self, "config") and self.config:
+            return self.config.resp_levels
         return cfg.models.ar.resp_levels

     @property
@@ -55,12 +55,30 @@ class AR(Base):
             return 0
         return cfg.inference.recurrent_chunk_size

+    @property
+    def interleave(self) -> bool:
+        if hasattr(self, "config") and self.config:
+            return self.config.interleave
+        return False
+
     def _prune(self, l: Tensor):
         indices = (l == self.stop_token).nonzero()
         if len(indices) == 0:
             return l
         return l[: indices.min().item()]

+    def _interleave( self, codes ):
+        if not self.interleave:
+            return codes
+
+        return codes.flatten()
+
+    def _deinterleave( self, codes, length = 0 ):
+        if not self.interleave:
+            return codes
+
+        return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
+
     @staticmethod
     def _unsqueeze_list(x_list, axis=-1):
         return [x.unsqueeze(dim=axis) for x in x_list]
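For reference, a minimal standalone sketch (illustrative shapes, not from the commit) of the flatten/unflatten round trip these new helpers perform on a (timesteps, levels) code matrix:

    import torch

    n_prom_levels = 8
    codes = torch.randint(0, 1024, (75, n_prom_levels))  # 75 frames, 8 codebook levels

    flat = codes.flatten()  # each frame's 8 levels become contiguous in one stream
    trimmed = flat[: flat.shape[0] // n_prom_levels * n_prom_levels]  # drop any ragged tail
    restored = torch.unflatten(trimmed, 0, (trimmed.shape[0] // n_prom_levels, n_prom_levels))

    assert torch.equal(codes, restored)

The length-floor slice is a no-op here, but matters at inference time, where the model can stop mid-frame and leave a stream whose length is not a multiple of n_prom_levels.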
@@ -74,7 +92,10 @@ class AR(Base):
         sampling_temperature: float = 1.0,
     ):
         if resps_list is not None:
-            resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
+            if self.interleave:
+                resps_list = [self._interleave(r) for r in resps_list]
+            else:
+                resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels

         return super().forward(
             text_list=text_list,
@@ -94,6 +115,9 @@ class AR(Base):

         state = {} if cfg.inference.recurrent_forward else None

+        if self.interleave:
+            max_steps *= self.n_prom_levels
+
         for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
             # get next in sequence

@@ -116,9 +140,10 @@ class AR(Base):
             if stopped.all().item():
                 break

-
-        pruned = [self._prune(r) for r in resps_list]
-        return pruned
+        res = [self._prune(r) for r in resps_list]
+        if self.interleave:
+            res = [self._deinterleave(r) for r in res]
+        return res


 def example_usage():
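A small worked example (hypothetical stop_token value and level count) of the prune-then-deinterleave order used in the return path above: the stop token is cut from the flat stream first, then the remainder is reshaped back to (frames, levels).

    import torch

    stop_token = 1024
    n_levels = 2
    r = torch.tensor([10, 11, 20, 21, 30, stop_token, 99])

    idx = (r == stop_token).nonzero()
    pruned = r if len(idx) == 0 else r[: idx.min().item()]  # tensor([10, 11, 20, 21, 30])
    trimmed = pruned[: pruned.shape[0] // n_levels * n_levels]
    frames = torch.unflatten(trimmed, 0, (trimmed.shape[0] // n_levels, n_levels))
    # frames == tensor([[10, 11], [20, 21]]); the ragged trailing 30 is discarded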
@@ -163,6 +188,10 @@ def example_usage():
         'n_heads': 16,
         'n_layers': 24,
     }
+    try:
+        kwargs['config'] = cfg.models.ar
+    except Exception as e:
+        pass
     model = AR(**kwargs).to(device)
     engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))


@@ -129,6 +129,10 @@ class Base(nn.Module):
     @property
     def recurrent_chunk_size(self) -> int:
         raise NotImplementedError

+    @property
+    def interleave(self) -> bool:
+        return False
+
     def __init__(
         self,
@@ -137,8 +141,11 @@ class Base(nn.Module):
         n_heads: int = 8,
         n_layers: int = 12,
         p_dropout: float = 0.1,
+        config = None,
     ):
         super().__init__()
+        self.config = config
+
         self.n_tokens = n_tokens
         self.d_model = d_model
         self.n_heads = n_heads

@@ -135,7 +135,7 @@ class Base(nn.Module):

     @property
     def n_resp_levels(self) -> int:
-        return 4
+        return 1

     @property
     def n_max_levels(self) -> int:
@@ -155,7 +155,7 @@ class Base(nn.Module):

     @property
     def interleave_pattern(self) -> str | None:
-        return "musiclm"
+        return "flatten"

     @property
     def stop_token(self):
@@ -192,27 +192,12 @@ class Base(nn.Module):
             return codes

         return codes.flatten()
-        """
-        pattern_provider = _get_pattern_provider( self.interleave_pattern )( self.n_resp_levels )
-        pattern = pattern_provider.get_pattern( codes.shape[0] )
-        res, _, _ = pattern.build_pattern_sequence( codes.t()[None, :, :], self.interleaved_token, keep_only_valid_steps=True )
-        return res[0].t().flatten()
-        """

-    def _deinterleave( self, codes ):
+    def _deinterleave( self, codes, length = 0 ):
         if not self.interleave_pattern:
             return codes

-        return torch.unflatten( codes[:codes.shape[0] // self.n_resp_levels * self.n_resp_levels], 0, ( codes.shape[0] // self.n_resp_levels, self.n_resp_levels ) )
-        """
-        if codes.dim() == 1:
-            codes = torch.unflatten( codes[:codes.shape[0] // self.n_resp_levels * self.n_resp_levels], 0, ( codes.shape[0] // self.n_resp_levels, self.n_resp_levels ) )
-
-        pattern_provider = _get_pattern_provider( self.interleave_pattern )( self.n_resp_levels )
-        pattern = pattern_provider.get_pattern( codes.shape[0] )
-        res, _, _ = pattern.revert_pattern_sequence( codes, special_token=self.interleaved_token)
-        return res[0].t()
-        """
+        return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )

     def __init__(
         self,
@@ -232,13 +217,13 @@ class Base(nn.Module):
         self.n_layers = n_layers

         # + tasks for each token they represent in the prom
-        n_prom_tokens = n_tokens + (self.n_tasks - 1) + (1 if self.interleave_pattern else 0) # - 1 because tts is an inherent task
+        n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
         # +1 to include the stop token + 1 to include interleave token
-        n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) + (1 if self.interleave_pattern else 0) # AR requires a stop token to... know when to stop
+        n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) # AR requires a stop token to... know when to stop

         self.text_emb = Embedding(n_tokens, d_model)
         self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-        self.resps_emb = MultiEmbedding(1, n_resp_tokens, d_model)
+        self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)

         self.sep = nn.Parameter(torch.randn(d_model))

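A worked example of the merged token-count arithmetic above, with illustrative values (the 1024-entry codec codebook is an assumption, not stated in the diff):

    n_tokens = 1024        # codec codebook size (assumed)
    n_tasks = 8            # "tts" is inherent, hence the - 1
    use_stop_token = True  # the AR needs one

    n_prom_tokens = n_tokens + (n_tasks - 1)                 # 1031: codes + 7 extra task tokens
    n_resp_tokens = n_tokens + (1 if use_stop_token else 0)  # 1025: codes + stop token

With the dedicated interleave token dropped, both vocabularies shrink by one relative to the old interleaved file.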
@@ -270,7 +255,6 @@ class Base(nn.Module):
         ))

-        # I imagine because each step returns `resp_level`s tokens at once, so we need to have a classifier for each level
         #self.classifier = nn.ModuleList([ nn.Linear(d_model, n_resp_tokens) for _ in range(self.n_resp_levels) ]) if self.interleave_pattern else nn.Linear(d_model, n_resp_tokens)
         self.classifier = nn.Linear(d_model, n_resp_tokens)

         self.accuracy_metric = MulticlassAccuracy(
@@ -385,11 +369,6 @@ class Base(nn.Module):
         # Remove padding
         h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]

-        if True:
-            logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
-            ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
-            print( [ r for r in ret ] )
-
         # compute loss if the target is given
         if targ_list is not None:
             if any([l == 0 for l in map(len, targ_list)]):
@@ -487,6 +466,8 @@ class Base(nn.Module):

         state = {} if cfg.inference.recurrent_forward else None

+        max_steps *= self.n_prom_levels
+
         for n in range(max_steps // max(1, self.recurrent_chunk_size)):
             # get next in sequence

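Since each autoregressive step in this interleaved file emits a single token of a single codebook level, the frame budget has to be multiplied out; a quick sanity check with assumed defaults:

    max_steps = 450       # intended frame budget
    n_prom_levels = 8     # default level count
    max_steps *= n_prom_levels
    assert max_steps == 3600  # decoding steps needed for ~450 frames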
@@ -502,6 +483,7 @@ class Base(nn.Module):
             for i, ri in enumerate(r):
                 if self.stop_token in ri:
                     stopped[i] = True
+
                 resps_list[i] = torch.cat([resps_list[i], ri])

             # stop token found
@@ -509,12 +491,7 @@ class Base(nn.Module):
             if stopped.all().item():
                 break

-
-        pruned = [self._prune(r) for r in resps_list]
-        print( [ r for r in pruned ] )
-        deinterleaved = [ self._deinterleave(r) for r in pruned ]
-        print( [ r for r in deinterleaved ] )
-        return deinterleaved
+        return [self._deinterleave(self._prune(r)) for r in resps_list]

 def example_usage():
     from ..config import cfg
@@ -548,7 +525,7 @@ def example_usage():
     for name, model in models.items():
         print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

-    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
+    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=5e-5)) for name, model in models.items() })

     train = True

@@ -565,7 +542,7 @@ def example_usage():
         qnt.to(device),
     ]

-    def sample( filename, steps=450 * 4 ):
+    def sample( filename, steps=450 ):
         AR = None

         engines.eval()
@@ -578,10 +555,10 @@ def example_usage():
             decode_to_file(resps_list[0].cpu(), f"./data/{filename}.wav", device="cpu")

     if train:
-        sample("init", 15)
+        sample("init", 75 )

         engines.train()
-        t = trange(100)
+        t = trange(500)
         for i in t:
             stats = {"step": i}
             """

@@ -17,8 +17,8 @@ class NAR(Base):

     @property
     def arch_type(self) -> str:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.arch_type
+        if hasattr(self, "config") and self.config:
+            return self.config.arch_type
         return cfg.models.nar.arch_type

     @property
@@ -31,8 +31,8 @@ class NAR(Base):

     @property
     def n_resp_levels(self) -> int:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.resp_levels
+        if hasattr(self, "config") and self.config:
+            return self.config.resp_levels
         return cfg.models.nar.resp_levels

     @property
@@ -51,6 +51,10 @@ class NAR(Base):
     def recurrent_chunk_size(self) -> int:
         return 0

+    @property
+    def interleave(self) -> bool:
+        return False
+
     def forward(
         self,
         text_list: list[Tensor],