nasty bandaid because some of my DAC dataset only has 8 RVQ levels instead of the full 9
This commit is contained in:
parent
c4dd523b6f
commit
a8718d35a4
|
@ -134,6 +134,8 @@ class AR_NAR(Base):
|
|||
tone_list: list[Tensor] | None = None,
|
||||
len_list: list[Tensor] | None = None,
|
||||
|
||||
training: bool | None = None,
|
||||
|
||||
max_steps: int = 1000,
|
||||
max_levels: int = 0,
|
||||
max_resp_context: int = -1,
|
||||
|
@ -157,8 +159,11 @@ class AR_NAR(Base):
|
|||
n_levels_set = {r.shape[-1] for r in resps_list}
|
||||
n_levels = next(iter(n_levels_set))
|
||||
|
||||
if training is None:
|
||||
training = n_levels == self.n_resp_levels
|
||||
|
||||
# is training
|
||||
if n_levels == self.n_resp_levels:
|
||||
if training:
|
||||
# to-do: make this YAML configurable
|
||||
def sample_task():
|
||||
return "tts"
|
||||
|
|
|
@ -64,6 +64,8 @@ def train_feeder(engine, batch):
|
|||
proms_list=batch["proms"],
|
||||
resps_list=batch["resps"],
|
||||
lang_list=batch["lang"],
|
||||
|
||||
training=True,
|
||||
)
|
||||
|
||||
losses = engine.gather_attribute("loss")
|
||||
|
|
Loading…
Reference in New Issue
Block a user