forked from mrq/bitsandbytes-rocm

run backward

parent 591f60395a
commit 2cd047e35d
@@ -554,11 +554,22 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         assert mlp.fc1.state.idx is not None
     if threshold > 0:
         assert mlp.fc2.state.idx is not None
 
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
 
+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+
+        (o1 * grad_proj).sum().backward()
+
+
+
 
 def test_linear8bitlt_fp32_bias():
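
For context, the added lines run a backward pass through the two int8 linear layers when memory_efficient_backward is enabled. Below is a minimal standalone sketch of the same check, assuming the MLP8bit helper from tests/test_modules.py (two bnb.nn.Linear8bitLt layers wired back to back) and that Linear8bitLt accepts the memory_efficient_backward flag as in this fork; the layer sizes, threshold, and the final gradient assertion are illustrative, not part of the commit.

# Sketch only: mirrors what the new test block does, outside pytest.
import torch
import bitsandbytes as bnb


class MLP8bit(torch.nn.Module):
    # Loosely mirrors the MLP8bit helper defined in tests/test_modules.py.
    def __init__(self, dim1, dim2, has_fp16_weights=True,
                 memory_efficient_backward=False, threshold=0.0):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights,
            memory_efficient_backward=memory_efficient_backward,
            threshold=threshold)
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights,
            memory_efficient_backward=memory_efficient_backward,
            threshold=threshold)

    def forward(self, x):
        return self.fc2(self.fc1(x))


# Half-precision weights first; moving to CUDA then quantizes them to int8.
mlp = MLP8bit(32, 64, has_fp16_weights=False,
              memory_efficient_backward=True, threshold=6.0).half().to("cuda")

# Same input shape and dtype as b1 in the test.
b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
o1 = mlp(b1)
assert o1.dtype == torch.float16 and o1.requires_grad

# Run backward through the int8 layers and check that a gradient
# actually reaches the fp16 input.
grad_proj = torch.randn_like(o1)
(o1 * grad_proj).sum().backward()
assert b1.grad is not None and b1.grad.shape == b1.shape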