BSO fix
parent 29534180b2
commit e16af944c0
@@ -15,6 +15,12 @@ def create_batch_size_optimizer(opt_train):
     return MegabatchBatchSizeOptimizer(opt_train)
 
 
+def grad(p):
+    if p.grad is None:
+        return torch.tensor(0)
+    return p.grad.detach().clone()
+
+
 # Base class for BatchSizeOptimizers.
 class BatchSizeOptimizer:
     def focus(self, optimizer):
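For context, the new grad() helper guards against parameters whose .grad is still None (for example, layers that did not take part in the last backward pass). A minimal, self-contained sketch of its behaviour, assuming only PyTorch:

import torch

def grad(p):
    # Return a detached copy of the gradient, or a zero scalar when no gradient exists yet.
    if p.grad is None:
        return torch.tensor(0)
    return p.grad.detach().clone()

a = torch.nn.Parameter(torch.ones(3))
b = torch.nn.Parameter(torch.ones(3))
(a * 2).sum().backward()   # only `a` receives a gradient
print(grad(a))             # tensor([2., 2., 2.])
print(grad(b))             # tensor(0) instead of calling .detach() on None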
@@ -66,7 +72,8 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
 
     def focus(self, model):
         if not hasattr(model, '_gradient_direction_optimizer_finished') or model._gradient_direction_optimizer_finished:
-            all_params = list(filter(lambda t: '.weight' in t[0] and t[1].requires_grad, list(model.named_parameters())))  # Extracts weight parameters. Who cares about biases anyways? :)
+            all_params = list(filter(lambda t: '.weight' in t[0] and not hasattr(t[1].requires_grad, 'DO_NOT_TRAIN'),
+                                     list(model.named_parameters())))  # Extracts weight parameters. Who cares about biases anyways? :)
             num_params = min(len(all_params), self.parameters_to_poll)
             model._gradient_direction_optimizer_params = random.sample(all_params, num_params)
             model._gradient_direction_optimizer_prior_directions = [0 for _ in range(num_params)]
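The rewritten filter in focus() still keys off parameter names ending in '.weight'. A standalone illustration of that name-based filtering (the small Sequential model below is illustrative only, not part of the repo):

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
weights = [(name, p) for name, p in net.named_parameters() if '.weight' in name]
print([name for name, _ in weights])   # ['0.weight', '2.weight'] -- biases are skipped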
@@ -79,7 +86,11 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
     def should_step(self, it):
         model = self.current_model
         model._gradient_direction_optimizer_step += 1
-        cur_grads = [p.grad.detach().clone() for k, p in model._gradient_direction_optimizer_params]
+        cur_grads = [grad(p) for k, p in model._gradient_direction_optimizer_params]
+        for cg in cur_grads:
+            if torch.any(torch.isnan(cg)):
+                print("BSO: found NaN. Passing it off to the GradScaler..")
+                return True
         if model._gradient_direction_optimizer_prior_grads is not None:
             cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
             delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
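The NaN check added to should_step() makes the optimizer request a step as soon as any polled gradient contains NaN, on the assumption that the surrounding training loop's GradScaler will then skip or rescale that step. A self-contained sketch of just that check (the helper name here is illustrative, not the repo's API):

import torch

def has_nan_grad(grads):
    # Mirrors the short-circuit added to should_step(): any NaN gradient means
    # "step immediately" and hand the problem off to the GradScaler.
    for cg in grads:
        if torch.any(torch.isnan(cg)):
            print("BSO: found NaN. Passing it off to the GradScaler..")
            return True
    return False

print(has_nan_grad([torch.ones(2), torch.tensor([float('nan'), 1.0])]))  # True
print(has_nan_grad([torch.zeros(3)]))                                    # False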