Merge pull request #9 from ditschuk/fix_adam_imports

Add missing imports to adam
This commit is contained in:
Tim Dettmers 2021-11-15 07:58:44 -08:00 committed by GitHub
commit 037022e878
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,12 @@
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math
import os
import torch import torch
import torch.distributed as dist
from bitsandbytes.optim.optimizer import Optimizer2State from bitsandbytes.optim.optimizer import Optimizer2State
import bitsandbytes.functional as F import bitsandbytes.functional as F
@ -220,9 +225,9 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != '' and state['step'] % 100 == 0: if self.savedir != '' and state['step'] % 100 == 0:
if not os.path.exists(self.savedir): os.makedirs(self.savedir) if not os.path.exists(self.savedir): os.makedirs(self.savedir)
shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape]) shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
pathe = join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl') pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
pathrele = join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl') pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
pathcounts = join(self.savedir, f'{p_id}_{shapestr}_counts.pkl') pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
torch.save(e, pathe) torch.save(e, pathe)
torch.save(rele, pathrele) torch.save(rele, pathrele)
torch.save(counts, pathcounts) torch.save(counts, pathcounts)