nasty bandaid because some of my DAC dataset only has 8 RVQ levels instead of the full 9

This commit is contained in:
mrq 2024-06-29 10:16:37 -05:00
parent c4dd523b6f
commit a8718d35a4
2 changed files with 8 additions and 1 deletions

View File

@ -134,6 +134,8 @@ class AR_NAR(Base):
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | None = None,
max_steps: int = 1000,
max_levels: int = 0,
max_resp_context: int = -1,
@ -157,8 +159,11 @@ class AR_NAR(Base):
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
if training is None:
training = n_levels == self.n_resp_levels
# is training
if n_levels == self.n_resp_levels:
if training:
# to-do: make this YAML configurable
def sample_task():
return "tts"

View File

@ -64,6 +64,8 @@ def train_feeder(engine, batch):
proms_list=batch["proms"],
resps_list=batch["resps"],
lang_list=batch["lang"],
training=True,
)
losses = engine.gather_attribute("loss")