T5 model fix

This commit is contained in:
Artidoro Pagnoni 2023-02-27 14:23:21 -08:00
parent 9851a10b46
commit 6c31a5fe99

View File

@ -190,10 +190,10 @@ class LinearFP4(nn.Linear):
         if getattr(self.weight, 'quant_state', None) is None:
             print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
         inp_dtype = x.dtype
         x = x.to(torch.float16)
-        out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias.half(), quant_state=self.weight.quant_state)
+        bias = None if self.bias is None else self.bias.half()
+        out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
         out = out.to(inp_dtype)
         return out