Fixed unsafe use of eval. #8
This commit is contained in:
parent
b3fe8a6d0f
commit
108cf9fc1f
|
@ -41,8 +41,9 @@ Docs:
|
||||||
### 0.26.0:
|
### 0.26.0:
|
||||||
|
|
||||||
Features:
|
Features:
|
||||||
- Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer
|
- Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer.
|
||||||
- Added AdamW (copy of Adam with weight decay init 1e-2)
|
- Added AdamW (copy of Adam with weight decay init 1e-2). #10
|
||||||
|
|
||||||
Bug fixes:
|
Bug fixes:
|
||||||
- Fixed a bug where weight decay was incorrectly applied to 32-bit Adam
|
- Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13
|
||||||
|
- Fixed an unsafe use of eval. #8
|
||||||
|
|
|
@ -242,8 +242,9 @@ class Optimizer2State(Optimizer8bit):
|
||||||
if not 0.0 <= eps:
|
if not 0.0 <= eps:
|
||||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||||
if isinstance(betas, str):
|
if isinstance(betas, str):
|
||||||
betas = eval(betas)
|
# format: '(beta1, beta2)'
|
||||||
print(betas, 'parsed')
|
betas = betas.replace('(', '').replace(')', '').strip().split(',')
|
||||||
|
betas = [float(b) for b in betas]
|
||||||
for i in range(len(betas)):
|
for i in range(len(betas)):
|
||||||
if not 0.0 <= betas[i] < 1.0:
|
if not 0.0 <= betas[i] < 1.0:
|
||||||
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
|
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
|
||||||
|
|
|
@ -392,3 +392,18 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
|
||||||
#assert s < 3.9
|
#assert s < 3.9
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_betas():
    """Check that string-encoded betas are stored the same as tuple betas.

    Builds one Adam optimizer from the tuple ``(0.80, 0.95)`` and another from
    the string ``'(0.80, 0.95)'`` over the parameters of the same linear layer,
    then asserts that both expose 0.8 and 0.95 through ``defaults['betas']``.
    """
    betas = (0.80, 0.95)
    strbetas = '(0.80, 0.95)'

    model = torch.nn.Linear(10, 10)

    base = bnb.optim.Adam(model.parameters(), betas=betas)
    strbase = bnb.optim.Adam(model.parameters(), betas=strbetas)

    # Tuple-configured optimizer is checked first, then the string-configured
    # one, each against the same expected pair of values.
    for optimizer in (base, strbase):
        assert optimizer.defaults['betas'][0] == 0.8
        assert optimizer.defaults['betas'][1] == 0.95
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user