merged dedicated interleaved AR code with the normal AR code

parent 3a6bd50322
commit 2f9cd0842f
@@ -162,7 +162,7 @@ class Model:
     tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
     arch_type: str = "transformer"
     training: bool = True
-    interleave_pattern: str | None = None
+    interleave: bool = False

     @property
     def full_name(self):
@@ -174,6 +174,9 @@ class Model:
         if self.arch_type != "transformer":
             name.append(self.arch_type.replace("/", "-"))

+        if self.interleave:
+            name.append("interleaved")
+
         name.append(f'{cfg.models.prom_levels}')

         return "-".join(name)
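To make the naming change concrete, here is a small sketch of what full_name produces for an interleaved AR (my own illustration, not part of the commit; it assumes name starts as [self.name] and that cfg.models.prom_levels is the default 8, both of which sit outside this hunk):

name = ["ar"]                    # assumed starting value of name
arch_type = "transformer"        # the default arch appends nothing
interleave = True

if arch_type != "transformer":
    name.append(arch_type.replace("/", "-"))
if interleave:
    name.append("interleaved")
name.append("8")                 # f'{cfg.models.prom_levels}'

print("-".join(name))            # -> "ar-interleaved-8" under these assumptions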
@@ -228,8 +231,8 @@ class Models:
     _prom_levels: int = 1

     _models: list[Model] = field(default_factory=lambda: [
-        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True),
-        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True),
+        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
+        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
     ])

     def get(self, name=None):
@@ -1,2 +0,0 @@
-# From: https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py
-# audiocraft has heavy dependencies, so it doesn't make sense to depend on it just for this file.
@@ -23,8 +23,8 @@ class AR(Base):

     @property
     def arch_type(self) -> str:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.arch_type
+        if hasattr(self, "config") and self.config:
+            return self.config.arch_type
         return cfg.models.ar.arch_type

     @property
@@ -33,8 +33,8 @@ class AR(Base):

     @property
     def n_resp_levels(self) -> int:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.resp_levels
+        if hasattr(self, "config") and self.config:
+            return self.config.resp_levels
         return cfg.models.ar.resp_levels

     @property
@@ -55,12 +55,30 @@ class AR(Base):
             return 0
         return cfg.inference.recurrent_chunk_size

+    @property
+    def interleave(self) -> bool:
+        if hasattr(self, "config") and self.config:
+            return self.config.interleave
+        return False
+
     def _prune(self, l: Tensor):
         indices = (l == self.stop_token).nonzero()
         if len(indices) == 0:
             return l
         return l[: indices.min().item()]

+    def _interleave( self, codes ):
+        if not self.interleave:
+            return codes
+
+        return codes.flatten()
+
+    def _deinterleave( self, codes, length = 0 ):
+        if not self.interleave:
+            return codes
+
+        return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
+
     @staticmethod
     def _unsqueeze_list(x_list, axis=-1):
         return [x.unsqueeze(dim=axis) for x in x_list]
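For reference, a self-contained sketch of the round trip these two helpers perform (my own example, not from the commit). _interleave flattens a [timesteps, levels] code matrix frame by frame (all levels of step 0, then step 1, and so on), and _deinterleave first truncates to a whole multiple of the level count, so a trailing partial frame (say, one cut short at the stop token) is simply dropped:

import torch

n_levels = 8                                   # stands in for self.n_prom_levels
codes = torch.arange(24).reshape(3, 8)         # [timesteps, levels]

flat = codes.flatten()                         # what _interleave returns, shape [24]
ragged = torch.cat([flat, flat.new_tensor([999])])  # pretend one stray trailing token

# what _deinterleave does: trim to a multiple of n_levels, restore [T, L]
t = ragged.shape[0] // n_levels
restored = torch.unflatten(ragged[: t * n_levels], 0, (t, n_levels))

assert torch.equal(restored, codes)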
@@ -74,7 +92,10 @@ class AR(Base):
         sampling_temperature: float = 1.0,
     ):
         if resps_list is not None:
-            resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
+            if self.interleave:
+                resps_list = [self._interleave(r) for r in resps_list]
+            else:
+                resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels

         return super().forward(
             text_list=text_list,
@@ -94,6 +115,9 @@ class AR(Base):

         state = {} if cfg.inference.recurrent_forward else None

+        if self.interleave:
+            max_steps *= self.n_prom_levels
+
         for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
             # get next in sequence

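The scaling above follows from the sequence layout: an interleaved AR emits one codebook token per decoder step instead of one full frame, so the step budget has to grow by the level count to cover the same duration. A quick sanity check with numbers used elsewhere in this commit (450 is the sample() default below, 8 the configured prom levels):

max_steps = 450                  # frame budget, as in sample() below
n_prom_levels = 8
max_steps *= n_prom_levels       # 3600 decoder steps for the same ~450 frames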
@@ -116,9 +140,10 @@ class AR(Base):
             if stopped.all().item():
                 break

-        pruned = [self._prune(r) for r in resps_list]
-
-        return pruned
+        res = [self._prune(r) for r in resps_list]
+        if self.interleave:
+            res = [self._deinterleave(r) for r in res]
+        return res


 def example_usage():
@@ -163,6 +188,10 @@ def example_usage():
         'n_heads': 16,
         'n_layers': 24,
     }
+    try:
+        kwargs['config'] = cfg.models.ar
+    except Exception as e:
+        pass
     model = AR(**kwargs).to(device)
     engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))

@@ -130,6 +130,10 @@ class Base(nn.Module):
     def recurrent_chunk_size(self) -> int:
         raise NotImplementedError

+    @property
+    def interleave(self) -> bool:
+        return False
+
     def __init__(
         self,
         n_tokens: int = 1024,
@@ -137,8 +141,11 @@ class Base(nn.Module):
         n_heads: int = 8,
         n_layers: int = 12,
         p_dropout: float = 0.1,
+        config = None,
     ):
         super().__init__()
+        self.config = config
+
         self.n_tokens = n_tokens
         self.d_model = d_model
         self.n_heads = n_heads
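This config handle is what the AR/NAR property overrides earlier in the commit consult before falling back to the global cfg. A minimal sketch of that lookup order (illustrative only, with a stand-in fallback value):

from types import SimpleNamespace

class Example:
    def __init__(self, config=None):
        self.config = config                  # mirrors Base.__init__ above

    @property
    def arch_type(self) -> str:
        if hasattr(self, "config") and self.config:
            return self.config.arch_type      # per-model config wins
        return "transformer"                  # stand-in for cfg.models.*.arch_type

assert Example().arch_type == "transformer"
assert Example(SimpleNamespace(arch_type="other")).arch_type == "other"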
@@ -135,7 +135,7 @@ class Base(nn.Module):

     @property
     def n_resp_levels(self) -> int:
-        return 4
+        return 1

     @property
     def n_max_levels(self) -> int:
@@ -155,7 +155,7 @@ class Base(nn.Module):

     @property
     def interleave_pattern(self) -> str | None:
-        return "musiclm"
+        return "flatten"

     @property
     def stop_token(self):
@@ -192,27 +192,12 @@ class Base(nn.Module):
             return codes

         return codes.flatten()
-        """
-        pattern_provider = _get_pattern_provider( self.interleave_pattern )( self.n_resp_levels )
-        pattern = pattern_provider.get_pattern( codes.shape[0] )
-        res, _, _ = pattern.build_pattern_sequence( codes.t()[None, :, :], self.interleaved_token, keep_only_valid_steps=True )
-        return res[0].t().flatten()
-        """

-    def _deinterleave( self, codes ):
+    def _deinterleave( self, codes, length = 0 ):
         if not self.interleave_pattern:
             return codes

-        return torch.unflatten( codes[:codes.shape[0] // self.n_resp_levels * self.n_resp_levels], 0, ( codes.shape[0] // self.n_resp_levels, self.n_resp_levels ) )
-        """
-        if codes.dim() == 1:
-            codes = torch.unflatten( codes[:codes.shape[0] // self.n_resp_levels * self.n_resp_levels], 0, ( codes.shape[0] // self.n_resp_levels, self.n_resp_levels ) )
-
-        pattern_provider = _get_pattern_provider( self.interleave_pattern )( self.n_resp_levels )
-        pattern = pattern_provider.get_pattern( codes.shape[0] )
-        res, _, _ = pattern.revert_pattern_sequence( codes, special_token=self.interleaved_token)
-        return res[0].t()
-        """
+        return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )


     def __init__(
         self,
@@ -232,13 +217,13 @@ class Base(nn.Module):
         self.n_layers = n_layers

         # + tasks for each token they represent in the prom
-        n_prom_tokens = n_tokens + (self.n_tasks - 1) + (1 if self.interleave_pattern else 0) # - 1 because tts is an inherent task
+        n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
         # +1 to include the stop token + 1 to include interleave token
-        n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) + (1 if self.interleave_pattern else 0) # AR requires a stop token to... know when to stop
+        n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) # AR requires a stop token to... know when to stop

         self.text_emb = Embedding(n_tokens, d_model)
         self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-        self.resps_emb = MultiEmbedding(1, n_resp_tokens, d_model)
+        self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)

         self.sep = nn.Parameter(torch.randn(d_model))

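A quick sanity check of the vocabulary arithmetic after this simplification, with the defaults visible in this diff (n_tokens=1024, tasks=8) and assuming the AR's stop token is enabled:

n_tokens = 1024
n_tasks = 8
use_stop_token = True

# prompt side: extra ids for the non-tts tasks (tts is the inherent task)
n_prom_tokens = n_tokens + (n_tasks - 1)                   # 1031
# response side: codes plus the stop token; the dedicated interleave
# token is gone now that interleave_pattern has been removed
n_resp_tokens = n_tokens + (1 if use_stop_token else 0)    # 1025

Note also that resps_emb is now sized by n_resp_levels rather than hard-coded to one level, which is what lets this merged class serve both the single-level AR and multi-level configurations.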
@@ -270,7 +255,6 @@ class Base(nn.Module):
         ))

         # I imagine because each step returns `resp_level`s tokens at once, so we need to have a classifier for each level
-        #self.classifier = nn.ModuleList([ nn.Linear(d_model, n_resp_tokens) for _ in range(self.n_resp_levels) ]) if self.interleave_pattern else nn.Linear(d_model, n_resp_tokens)
         self.classifier = nn.Linear(d_model, n_resp_tokens)

         self.accuracy_metric = MulticlassAccuracy(
@@ -385,11 +369,6 @@ class Base(nn.Module):
         # Remove padding
         h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]

-        if True:
-            logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
-            ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
-            print( [ r for r in ret ] )
-
         # compute loss if the target is given
         if targ_list is not None:
             if any([l == 0 for l in map(len, targ_list)]):
@@ -487,6 +466,8 @@ class Base(nn.Module):

         state = {} if cfg.inference.recurrent_forward else None

+        max_steps *= self.n_prom_levels
+
         for n in range(max_steps // max(1, self.recurrent_chunk_size)):
             # get next in sequence

@@ -502,6 +483,7 @@ class Base(nn.Module):
                 for i, ri in enumerate(r):
                     if self.stop_token in ri:
                         stopped[i] = True
+
                     resps_list[i] = torch.cat([resps_list[i], ri])

             # stop token found
@@ -509,12 +491,7 @@ class Base(nn.Module):
             if stopped.all().item():
                 break

-        pruned = [self._prune(r) for r in resps_list]
-        print( [ r for r in pruned ] )
-        deinterleaved = [ self._deinterleave(r) for r in pruned ]
-        print( [ r for r in deinterleaved ] )
-        return deinterleaved
-
+        return [self._deinterleave(self._prune(r)) for r in resps_list]

 def example_usage():
     from ..config import cfg
@@ -548,7 +525,7 @@ def example_usage():
     for name, model in models.items():
         print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

-    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
+    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=5e-5)) for name, model in models.items() })

     train = True

@@ -565,7 +542,7 @@ def example_usage():
         qnt.to(device),
     ]

-    def sample( filename, steps=450 * 4 ):
+    def sample( filename, steps=450 ):
         AR = None

         engines.eval()
@@ -578,10 +555,10 @@ def example_usage():
         decode_to_file(resps_list[0].cpu(), f"./data/{filename}.wav", device="cpu")

     if train:
-        sample("init", 15)
+        sample("init", 75 )

         engines.train()
-        t = trange(100)
+        t = trange(500)
         for i in t:
             stats = {"step": i}
     """
@@ -17,8 +17,8 @@ class NAR(Base):

     @property
     def arch_type(self) -> str:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.arch_type
+        if hasattr(self, "config") and self.config:
+            return self.config.arch_type
         return cfg.models.nar.arch_type

     @property
@@ -31,8 +31,8 @@ class NAR(Base):

     @property
     def n_resp_levels(self) -> int:
-        if hasattr(self, "_cfg") and self._cfg:
-            return self._cfg.resp_levels
+        if hasattr(self, "config") and self.config:
+            return self.config.resp_levels
         return cfg.models.nar.resp_levels

     @property
@@ -51,6 +51,10 @@ class NAR(Base):
     def recurrent_chunk_size(self) -> int:
         return 0

+    @property
+    def interleave(self) -> bool:
+        return False
+
     def forward(
         self,
         text_list: list[Tensor],