commit cca542a4c0 (parent 65a8960305)

    ugh
@@ -615,8 +615,6 @@ class Dataset(_Dataset):
 		prom_length = 0
 		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
 
-		print(trim_length / cfg.dataset.frames_per_second)
-
 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
 			if cfg.dataset.use_hdf5:
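
The removed print was a leftover debug trace of the sampled prompt length in seconds. A minimal sketch of the trimming math, with hypothetical stand-ins for the config values:

    import random

    # hypothetical values for cfg.dataset.prompt_duration_range / frames_per_second
    prompt_duration_range = (3.0, 6.0)  # seconds
    frames_per_second = 75              # codec frames per second

    # draw a prompt duration in seconds, then convert it to a frame count
    trim_length = int(random.uniform(*prompt_duration_range) * frames_per_second)
    # e.g. a 4.2 s draw at 75 fps gives trim_length == 315 frames
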
@@ -2,24 +2,7 @@
 def get_model(config, training=True):
 	name = config.name
 
-	if not config.experimental:
-		from .ar_nar import AR_NAR
-		model = AR_NAR(
-			n_text_tokens=config.text_tokens,
-			n_audio_tokens=config.audio_tokens,
-			d_model=config.dim,
-			n_heads=config.heads,
-			n_layers=config.layers,
-			n_experts=config.experts,
-
-			p_dropout=config.dropout,
-
-			l_padding = config.input_alignment,
-
-			training = training,
-			config = config,
-		)
-	elif "len" in config.capabilities:
+	if "len" in config.capabilities:
 		from .nar import NAR
 		model = NAR(
 			n_text_tokens=config.text_tokens,
@@ -36,7 +19,7 @@ def get_model(config, training=True):
 			training = training,
 			config = config,
 		)
-	else:
+	elif config.experimental:
 		from .experimental import Model as Experimental
 		model = Experimental(
 			n_text_tokens=config.text_tokens,
@@ -49,6 +32,23 @@ def get_model(config, training=True):
 
 			config = config,
 		)
+	else:
+		from .ar_nar import AR_NAR
+		model = AR_NAR(
+			n_text_tokens=config.text_tokens,
+			n_audio_tokens=config.audio_tokens,
+			d_model=config.dim,
+			n_heads=config.heads,
+			n_layers=config.layers,
+			n_experts=config.experts,
+
+			p_dropout=config.dropout,
+
+			l_padding = config.input_alignment,
+
+			training = training,
+			config = config,
+		)
 
 	print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
 
@@ -279,7 +279,9 @@ class Base(nn.Module):
 
 	@property
 	def stop_token(self):
-		if not self.causal and "len" not in self.capabilities:
+		if "len" in self.capabilities:
+			return 0
+		if not self.causal:
 			raise ValueError("Not using stop token!")
 		return self.n_audio_tokens
 
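
The property now checks the "len" capability before causality: such models repurpose token 0 as their stop/len sentinel, non-causal models still have no stop token, and causal models keep reserving the first index past the audio vocabulary. Restated outside the diff:

    @property
    def stop_token(self):
        # "len"-capable (NAR-only) models use token 0 as the sentinel
        if "len" in self.capabilities:
            return 0
        # a non-causal model never emits a stop token
        if not self.causal:
            raise ValueError("Not using stop token!")
        # causal models reserve the token just past the audio vocabulary
        return self.n_audio_tokens
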
@@ -325,9 +327,15 @@ class Base(nn.Module):
 
 		self.l_padding = l_padding
 
-		# +1 to include the stop token
 		n_prom_tokens = n_audio_tokens
-		n_resp_tokens = n_audio_tokens + self.causal_size
 
+		if "len" not in self.capabilities:
+			# +1 to include the stop token
+			n_resp_tokens = n_audio_tokens + self.causal_size
+			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+		else:
+			n_resp_tokens = n_audio_tokens
+			l_tokens = [n_resp_tokens] * self.n_resp_levels
+
 		audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
 		split_classifiers = self.config.split_classifiers if self.config is not None else True
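
A worked example of the new l_tokens construction, assuming n_audio_tokens == 1024, n_resp_levels == 8, and causal_size == 1 (assumed values, not necessarily the shipped defaults):

    n_audio_tokens, n_resp_levels, causal_size = 1024, 8, 1

    # AR-capable model: only level 0 carries the stop token
    n_resp_tokens = n_audio_tokens + causal_size
    l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
    # -> [1025, 1024, 1024, 1024, 1024, 1024, 1024, 1024]

    # "len"-capable model: no stop token, so every level is uniform
    l_tokens = [n_audio_tokens] * n_resp_levels
    # -> [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]
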
@@ -351,7 +359,7 @@ class Base(nn.Module):
 			)
 			# [1024 + STOP] + [1024] * 8
 			self.resps_emb = AudioEmbedding_Old(
-				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
+				l_tokens, d_model,
 				levels=self.n_resp_levels if self.version > 3 else None,
 			)
 		else:
@@ -360,7 +368,7 @@ class Base(nn.Module):
 				sums=audio_embedding_sums,
 			)
 			self.resps_emb = AudioEmbedding(
-				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
+				l_tokens, d_model,
 				sums=audio_embedding_sums,
 			)
 
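
Both embedding paths now consume the precomputed l_tokens rather than rebuilding the AR-specific token list inline, so the per-level table sizes always match the capability branch above. A toy per-level embedding in that shape (a hypothetical stand-in, not the repo's AudioEmbedding):

    import torch.nn as nn

    class PerLevelEmbedding(nn.Module):  # hypothetical illustration
        def __init__(self, l_tokens, d_model):
            super().__init__()
            # one embedding table per RVQ level, each sized by that level's vocabulary
            self.embeddings = nn.ModuleList(nn.Embedding(n, d_model) for n in l_tokens)
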
@@ -634,13 +642,11 @@ class Base(nn.Module):
 
 			self.metrics = None
 		else:
-			levels = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-
 			self.classifier = None
-			self.classifiers = AudioClassifier( levels, d_model )
+			self.classifiers = AudioClassifier( l_tokens, d_model )
 			self.accuracy_metric = None
 			self.precision_metric = None
-			self.metrics = Metrics( levels )
+			self.metrics = Metrics( l_tokens )
 
 
 	def _forward(
@@ -905,7 +911,7 @@ class Base(nn.Module):
 			self.loss = dict(
 				nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 			)
-			self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
+			self.stats = self.metrics( logits, target_list, quant_levels ) if self.metrics is not None else dict(
 				acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 			)
 
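The stats call previously passed inputs and targets, which are only loop variables inside the comprehensions and are undefined at that point; it now receives the full logits and target_list, matching the NLL term. A self-contained sketch of that per-sequence averaging with toy tensors:

    import torch
    import torch.nn.functional as F

    # toy batch: two sequences of different lengths over a 1025-token vocabulary
    logits = [torch.randn(10, 1025), torch.randn(7, 1025)]
    target_list = [torch.randint(0, 1025, (10,)), torch.randint(0, 1025, (7,))]
    batch_size = len(logits)

    nll = sum(
        F.cross_entropy(inputs, targets)  # ignore_index left at its default here
        for targets, inputs in zip(target_list, logits)
    ) / batch_size
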
@@ -145,8 +145,8 @@ class NAR(Base):
 		n_levels_set = {r.shape[-1] for r in resps_list}
 		n_levels = next(iter(n_levels_set))
 
 		# is training
-		assert n_levels == self.n_resp_levels
+		# assert n_levels == self.n_resp_levels
 
 		# to-do: make this YAML configurable
 		def sample_task():
 			return "len" if random.random() < p_len_task else "tts"
@@ -170,7 +170,12 @@ class NAR(Base):
 			quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
 		else:
 			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-			quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+			quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
+
+		# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
+		for i, resp in enumerate(resps_list):
+			if quant_levels[i] >= resp.shape[-1]:
+				quant_levels[i] = resp.shape[-1] - 1
 
 		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 
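random.randint is inclusive on both ends, so the old call could draw quant_level_range[1] itself; subtracting one presumably keeps the upper bound exclusive. The new clamp then guards against samples encoded with fewer RVQ levels than the drawn level. A runnable sketch with hypothetical shapes:

    import random
    import torch

    quant_level_range = (0, 8)  # hypothetical [min, max) level range
    task_list = ["tts", "len"]
    # one response stores all 8 RVQ levels, the other was saved with only 4
    resps_list = [torch.zeros(100, 8, dtype=torch.long), torch.zeros(60, 4, dtype=torch.long)]

    quant_levels = [
        0 if task_list[i] == "len"
        else random.randint(quant_level_range[0], quant_level_range[1] - 1)
        for i in range(len(resps_list))
    ]

    # clamp levels that exceed what a given sample actually stores
    for i, resp in enumerate(resps_list):
        if quant_levels[i] >= resp.shape[-1]:
            quant_levels[i] = resp.shape[-1] - 1
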
@@ -355,7 +360,7 @@ def example_usage():
 	"""
 
 	model = NAR(**kwargs).to(device)
-	steps = 200
+	steps = 500
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""