forked from mrq/bitsandbytes-rocm
run backward
This commit is contained in:
parent
591f60395a
commit
2cd047e35d
|
@ -554,11 +554,22 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
|
|||
assert mlp.fc1.state.idx is not None
|
||||
if threshold > 0:
|
||||
assert mlp.fc2.state.idx is not None
|
||||
|
||||
assert mlp.fc1.weight.dtype == torch.int8
|
||||
assert mlp.fc2.weight.dtype == torch.int8
|
||||
assert mlp.fc1.weight.device.type == "cuda"
|
||||
assert mlp.fc2.weight.device.type == "cuda"
|
||||
|
||||
if memory_efficient_backward:
|
||||
b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
|
||||
o1 = mlp(b1)
|
||||
assert o1.dtype == torch.float16
|
||||
assert o1.requires_grad
|
||||
grad_proj = torch.randn_like(o1)
|
||||
|
||||
(o1 * grad_proj).sum().backward()
|
||||
|
||||
|
||||
|
||||
|
||||
def test_linear8bitlt_fp32_bias():
|
||||
|
|
Loading…
Reference in New Issue
Block a user