ugh
This commit is contained in:
parent
75f9a6ab93
commit
4b2d6f5595
|
@ -951,16 +951,16 @@ class Base(nn.Module):
|
||||||
|
|
||||||
quant_levels: int | list[int] | Tensor | None = None
|
quant_levels: int | list[int] | Tensor | None = None
|
||||||
):
|
):
|
||||||
if text_list:
|
if text_list and text_list[0] is not None:
|
||||||
device = text_list[0].device
|
device = text_list[0].device
|
||||||
batch_size = len(text_list)
|
batch_size = len(text_list)
|
||||||
elif raw_text_list:
|
elif raw_text_list and raw_text_list[0] is not None:
|
||||||
device = raw_text_list[0].device
|
device = raw_text_list[0].device
|
||||||
batch_size = len(raw_text_list)
|
batch_size = len(raw_text_list)
|
||||||
elif proms_list:
|
elif proms_list and proms_list[0] is not None:
|
||||||
device = proms_list[0].device
|
device = proms_list[0].device
|
||||||
batch_size = len(proms_list)
|
batch_size = len(proms_list)
|
||||||
elif resps_list:
|
elif resps_list and resps_list[0] is not None:
|
||||||
device = resps_list[0].device
|
device = resps_list[0].device
|
||||||
batch_size = len(resps_list)
|
batch_size = len(resps_list)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user