added a guard for hypernet training that will stop early if weights are getting no gradients
This commit is contained in:
parent
1cd3ed7def
commit
7fd90128eb
|
@ -310,6 +310,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||
|
||||
steps_without_grad = 0
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, entries in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
|
@ -332,8 +334,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||
|
||||
optimizer.zero_grad()
|
||||
weights[0].grad = None
|
||||
loss.backward()
|
||||
|
||||
if weights[0].grad is None:
|
||||
steps_without_grad += 1
|
||||
else:
|
||||
steps_without_grad = 0
|
||||
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
|
||||
|
||||
optimizer.step()
|
||||
|
||||
mean_loss = losses.mean()
|
||||
if torch.isnan(mean_loss):
|
||||
raise RuntimeError("Loss diverged.")
|
||||
|
|
Loading…
Reference in New Issue
Block a user