nasty bandaid because some of my DAC dataset only has 8 RVQ levels instead of the full 9
This commit is contained in:
parent
c4dd523b6f
commit
a8718d35a4
|
@ -134,6 +134,8 @@ class AR_NAR(Base):
|
||||||
tone_list: list[Tensor] | None = None,
|
tone_list: list[Tensor] | None = None,
|
||||||
len_list: list[Tensor] | None = None,
|
len_list: list[Tensor] | None = None,
|
||||||
|
|
||||||
|
training: bool | None = None,
|
||||||
|
|
||||||
max_steps: int = 1000,
|
max_steps: int = 1000,
|
||||||
max_levels: int = 0,
|
max_levels: int = 0,
|
||||||
max_resp_context: int = -1,
|
max_resp_context: int = -1,
|
||||||
|
@ -157,8 +159,11 @@ class AR_NAR(Base):
|
||||||
n_levels_set = {r.shape[-1] for r in resps_list}
|
n_levels_set = {r.shape[-1] for r in resps_list}
|
||||||
n_levels = next(iter(n_levels_set))
|
n_levels = next(iter(n_levels_set))
|
||||||
|
|
||||||
|
if training is None:
|
||||||
|
training = n_levels == self.n_resp_levels
|
||||||
|
|
||||||
# is training
|
# is training
|
||||||
if n_levels == self.n_resp_levels:
|
if training:
|
||||||
# to-do: make this YAML configurable
|
# to-do: make this YAML configurable
|
||||||
def sample_task():
|
def sample_task():
|
||||||
return "tts"
|
return "tts"
|
||||||
|
|
|
@ -64,6 +64,8 @@ def train_feeder(engine, batch):
|
||||||
proms_list=batch["proms"],
|
proms_list=batch["proms"],
|
||||||
resps_list=batch["resps"],
|
resps_list=batch["resps"],
|
||||||
lang_list=batch["lang"],
|
lang_list=batch["lang"],
|
||||||
|
|
||||||
|
training=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
losses = engine.gather_attribute("loss")
|
losses = engine.gather_attribute("loss")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user