This commit is contained in:
mrq 2024-06-11 23:59:28 -05:00
parent 65a8960305
commit cca542a4c0
4 changed files with 44 additions and 35 deletions

View File

@ -615,8 +615,6 @@ class Dataset(_Dataset):
prom_length = 0
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
print(trim_length / cfg.dataset.frames_per_second)
for _ in range(cfg.dataset.max_prompts):
path = random.choice(choices)
if cfg.dataset.use_hdf5:

View File

@ -2,24 +2,7 @@
def get_model(config, training=True):
name = config.name
if not config.experimental:
from .ar_nar import AR_NAR
model = AR_NAR(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
n_experts=config.experts,
p_dropout=config.dropout,
l_padding = config.input_alignment,
training = training,
config = config,
)
elif "len" in config.capabilities:
if "len" in config.capabilities:
from .nar import NAR
model = NAR(
n_text_tokens=config.text_tokens,
@ -36,7 +19,7 @@ def get_model(config, training=True):
training = training,
config = config,
)
else:
elif config.experimental:
from .experimental import Model as Experimental
model = Experimental(
n_text_tokens=config.text_tokens,
@ -49,6 +32,23 @@ def get_model(config, training=True):
config = config,
)
else:
from .ar_nar import AR_NAR
model = AR_NAR(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
n_experts=config.experts,
p_dropout=config.dropout,
l_padding = config.input_alignment,
training = training,
config = config,
)
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

View File

@ -279,7 +279,9 @@ class Base(nn.Module):
@property
def stop_token(self):
if not self.causal and "len" not in self.capabilities:
if "len" in self.capabilities:
return 0
if not self.causal:
raise ValueError("Not using stop token!")
return self.n_audio_tokens
@ -325,9 +327,15 @@ class Base(nn.Module):
self.l_padding = l_padding
# +1 to include the stop token
n_prom_tokens = n_audio_tokens
if "len" not in self.capabilities:
# +1 to include the stop token
n_resp_tokens = n_audio_tokens + self.causal_size
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
else:
n_resp_tokens = n_audio_tokens
l_tokens = [n_resp_tokens] * self.n_resp_levels
audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
split_classifiers = self.config.split_classifiers if self.config is not None else True
@ -351,7 +359,7 @@ class Base(nn.Module):
)
# [1024 + STOP] + [1024] * 8
self.resps_emb = AudioEmbedding_Old(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
l_tokens, d_model,
levels=self.n_resp_levels if self.version > 3 else None,
)
else:
@ -360,7 +368,7 @@ class Base(nn.Module):
sums=audio_embedding_sums,
)
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
l_tokens, d_model,
sums=audio_embedding_sums,
)
@ -634,13 +642,11 @@ class Base(nn.Module):
self.metrics = None
else:
levels = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
self.classifier = None
self.classifiers = AudioClassifier( levels, d_model )
self.classifiers = AudioClassifier( l_tokens, d_model )
self.accuracy_metric = None
self.precision_metric = None
self.metrics = Metrics( levels )
self.metrics = Metrics( l_tokens )
def _forward(
@ -905,7 +911,7 @@ class Base(nn.Module):
self.loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
)
self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
self.stats = self.metrics( logits, target_list, quant_levels ) if self.metrics is not None else dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
)

View File

@ -145,8 +145,8 @@ class NAR(Base):
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
# is training
assert n_levels == self.n_resp_levels
# assert n_levels == self.n_resp_levels
# to-do: make this YAML configurable
def sample_task():
return "len" if random.random() < p_len_task else "tts"
@ -170,7 +170,12 @@ class NAR(Base):
quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
else:
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
for i, resp in enumerate(resps_list):
if quant_levels[i] >= resp.shape[-1]:
quant_levels[i] = resp.shape[-1] - 1
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
@ -355,7 +360,7 @@ def example_usage():
"""
model = NAR(**kwargs).to(device)
steps = 200
steps = 500
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""