This commit is contained in:
James Betker 2022-02-12 20:01:04 -07:00
parent 29534180b2
commit e16af944c0

View File

@ -15,6 +15,12 @@ def create_batch_size_optimizer(opt_train):
return MegabatchBatchSizeOptimizer(opt_train)
def grad(p):
if p.grad is None:
return torch.tensor(0)
return p.grad.detach().clone()
# Base class for BatchSizeOptimizers.
class BatchSizeOptimizer:
def focus(self, optimizer):
@ -66,7 +72,8 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
def focus(self, model):
if not hasattr(model, '_gradient_direction_optimizer_finished') or model._gradient_direction_optimizer_finished:
all_params = list(filter(lambda t: '.weight' in t[0] and t[1].requires_grad, list(model.named_parameters()))) # Extracts weight parameters. Who cares about biases anyways? :)
all_params = list(filter(lambda t: '.weight' in t[0] and not hasattr(t[1].requires_grad, 'DO_NOT_TRAIN'),
list(model.named_parameters()))) # Extracts weight parameters. Who cares about biases anyways? :)
num_params = min(len(all_params), self.parameters_to_poll)
model._gradient_direction_optimizer_params = random.sample(all_params, num_params)
model._gradient_direction_optimizer_prior_directions = [0 for _ in range(num_params)]
@ -79,7 +86,11 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
def should_step(self, it):
model = self.current_model
model._gradient_direction_optimizer_step += 1
cur_grads = [p.grad.detach().clone() for k, p in model._gradient_direction_optimizer_params]
cur_grads = [grad(p) for k, p in model._gradient_direction_optimizer_params]
for cg in cur_grads:
if torch.any(torch.isnan(cg)):
print("BSO: found NaN. Passing it off to the GradScaler..")
return True
if model._gradient_direction_optimizer_prior_grads is not None:
cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]