reverted automatically disabling split loss calc, since it seems that calculating loss on the prom is actually what causes the oddities, maybe
commit c1fcd889d5
parent 8cf176ab46
@@ -213,7 +213,7 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
-	loss_factors: dict = field(default_factory=lambda: {}) # "text": 0.1, "prom": 0.0, "resp": 1.0 })
+	loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
 	kv_heads: int = 0
 
 	def get(self, name=None):
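For context on the change above: the dict keys name per-component losses (text / prom / resp), and restoring "prom": 0.0 in the defaults effectively zeroes out the loss on the prompt, which matches the commit message's suspicion that calculating loss on the prom is what causes the oddities. Below is a minimal sketch of how such factors could be applied, assuming one loss value per component; the combine_losses helper and the losses dict are hypothetical illustrations, not this repository's actual code.

def combine_losses(losses: dict, loss_factors: dict) -> float:
	# Weight each component loss by its factor; a factor of 0.0 ("prom" here)
	# removes that component from the training signal entirely.
	return sum(loss_factors.get(name, 1.0) * value for name, value in losses.items())

# With the defaults from the diff, the prom term contributes nothing:
# 0.1 * 2.0 + 0.0 * 5.0 + 1.0 * 1.5 ≈ 1.7
print(combine_losses({"text": 2.0, "prom": 5.0, "resp": 1.5},
                     {"text": 0.1, "prom": 0.0, "resp": 1.0}))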
@@ -162,7 +162,7 @@ def train(
 
 	#batch = to_device(batch, torch.cuda.current_device())
 	stats = engines.step(batch=batch, feeder=train_feeder)
-	stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)
+	stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths) * world_size()
 
 	"""
 	stats['batch'] = {
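A quick sanity check on the epoch arithmetic, under the assumption that engines.global_samples is a per-rank counter and the dataset paths are sharded evenly across ranks: each rank only steps through len(paths) / world_size() samples per pass, so the per-rank counter has to be scaled by world_size() to report progress through the whole dataset. The helper below is a standalone illustration, not code from the repository.

def epoch_fraction(global_samples: int, dataset_size: int, world_size: int) -> float:
	# Per-rank samples seen, scaled up by the number of ranks, over the full dataset.
	return global_samples / dataset_size * world_size

# 4 ranks, 10_000 paths, 2_500 samples processed on this rank:
# together the ranks have covered the dataset exactly once.
print(epoch_fraction(2_500, 10_000, 4))  # 1.0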
|