Add missing imports to adam
This commit is contained in:
parent
22b2877c7f
commit
56f5274848
|
@ -2,7 +2,12 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from bitsandbytes.optim.optimizer import Optimizer2State
|
||||
import bitsandbytes.functional as F
|
||||
|
||||
|
@ -220,9 +225,9 @@ class AnalysisAdam(torch.optim.Optimizer):
|
|||
if self.savedir != '' and state['step'] % 100 == 0:
|
||||
if not os.path.exists(self.savedir): os.makedirs(self.savedir)
|
||||
shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
|
||||
pathe = join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
|
||||
pathrele = join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
|
||||
pathcounts = join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
|
||||
pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
|
||||
pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
|
||||
pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
|
||||
torch.save(e, pathe)
|
||||
torch.save(rele, pathrele)
|
||||
torch.save(counts, pathcounts)
|
||||
|
|
Loading…
Reference in New Issue
Block a user