21 lines
328 B
Python
21 lines
328 B
Python
|
# Smoke test for a bitsandbytes installation: run a single Adam step on a
# CUDA parameter and confirm that the parameter values changed.
import bitsandbytes as bnb
import torch

# A 10x10 trainable parameter and a same-shaped input, both moved to the GPU.
p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()

# Sum of the parameter values before the optimizer step, as a Python number.
p1 = p.data.sum().item()

adam = bnb.optim.Adam([p])

# One forward/backward pass so that p has a gradient for the optimizer.
out = a * p
loss = out.sum()
loss.backward()
adam.step()

# Sum of the parameter values after the optimizer step.
p2 = p.data.sum().item()

# Raise explicitly instead of using `assert`: assert statements are removed
# when Python runs with -O, which would let this script print the success
# messages without having verified anything.
if p1 == p2:
    raise AssertionError(
        "bitsandbytes Adam step did not change the parameter values"
    )
print('SUCCESS!')
print('Installation was successful!')
|