saner mask creation? (it doesn't matter, kv cache won't work)
parent ded746e157
commit 3826f9bae4
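In shape terms, the "saner" mask here is a plain 2-D integer padding mask that only gets a trailing dim where broadcasting needs it; a minimal sketch with illustrative values (not code from this diff):

import torch

lengths = [3, 5]                                  # per-sample token counts in a batch (illustrative)
t = max(lengths)
mask = (torch.arange(t)[None, :] < torch.tensor(lengths)[:, None]).int()  # (2, 5), the new 2-D form
m = mask.unsqueeze(-1)                            # (2, 5, 1), only where (b, t, d) broadcasting needs it
print(mask.shape, m.shape)                        # torch.Size([2, 5]) torch.Size([2, 5, 1])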
@@ -513,6 +513,8 @@ def get_task_symmap():
 	}
 
 def _replace_file_extension(path, suffix):
+	if not isinstance( path, Path ):
+		path = Path(path)
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
 def _get_quant_extension():
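The only change here is that `_replace_file_extension` now coerces plain strings to `Path` first; a quick usage check (file names are illustrative, not from the repo):

from pathlib import Path

def _replace_file_extension(path, suffix):
	if not isinstance( path, Path ):
		path = Path(path)
	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

print(_replace_file_extension("speaker/utt_00.flac", ".enc"))  # speaker/utt_00.enc, a str argument now works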
@@ -72,7 +72,7 @@ def process(
 
 	# easy way to load the model and handle encoding audio
 	if tts is None:
-		tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )
+		tts = init_tts( config=yaml, restart=False, device=device, dtype=dtype )
 
 	features = { key: None for key in metadata_keys }
 
@@ -78,9 +78,11 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
 	l = list(map(len, x_list))
 	x = rearrange(pad_sequence(x_list), pattern)
 	m = _create_mask(l, x_list[0].device)
+	"""
 	m = m.t().unsqueeze(-1) # (t b 1)
 	m = rearrange(m, pattern)
-	m = m.to(x)
+	"""
+	m = m.to(x).int()
 	return x, m
 
 def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
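Put together, `list_to_tensor` now returns the padded batch plus a 2-D integer mask instead of the old `(b t 1)` float mask. A self-contained sketch of the new behaviour; the `_create_mask` body below is an assumption, only its `(b, t)` output shape is implied by the diff:

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange

def _create_mask(l, device):
	# assumed stand-in: True where a real token exists, False where the batch is padded
	seq = torch.arange(max(l), device=device).unsqueeze(0)   # (1, t)
	stop = torch.tensor(l, device=device).unsqueeze(1)       # (b, 1)
	return seq < stop                                        # (b, t)

def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
	l = list(map(len, x_list))
	x = rearrange(pad_sequence(x_list), pattern)             # (b, t, c)
	m = _create_mask(l, x_list[0].device)
	m = m.to(x).int()                                        # (b, t) int mask, no channel dim
	return x, m

x, m = list_to_tensor([torch.randn(3, 8), torch.randn(5, 8)])
print(x.shape, m.shape, m.dtype)                             # (2, 5, 8), (2, 5), torch.int32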
@@ -835,7 +837,7 @@ class Base(nn.Module):
 		output_hidden_states = False,
 	):
 		x = inputs
-		m = mask.squeeze(-1).int()
+		m = mask #.squeeze(-1).int()
 
 		aux_loss = None
 		attentions = None
@@ -844,7 +846,7 @@ class Base(nn.Module):
 		# HF transformer derived model
 		if self.arch_type in ["llama", "mistral", "mixtral"]:
 			kwargs = dict(
-				attention_mask=m,
+				#attention_mask=m,
 				inputs_embeds=x,
 				past_key_values=state,
 				position_ids=position_ids,
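With `attention_mask` commented out, the HF backbone attends over padded positions, and the parenthetical in the commit title still applies: for a working KV cache the mask would also have to grow with the cached positions. A toy sketch against a small `LlamaModel`; config values and tensors are illustrative, not the repo's wiring:

import torch
from transformers import LlamaConfig, LlamaModel

model = LlamaModel(LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                               num_attention_heads=4, vocab_size=32))
b, t, d = 2, 5, 64
x = torch.randn(b, t, d)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])                    # 2-D padding mask, as built above

out = model(inputs_embeds=x, attention_mask=mask, use_cache=True)

# a later step feeds one more embedding per sequence; the mask must now span the cached positions too
x_next = torch.randn(b, 1, d)
mask_next = torch.cat([mask, mask.new_ones(b, 1)], dim=1) # (b, t + 1)
out_next = model(inputs_embeds=x_next, attention_mask=mask_next,
                 past_key_values=out.past_key_values, use_cache=True)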
@@ -1475,7 +1477,9 @@ class Base(nn.Module):
 			return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
 
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
-		x, m = list_to_tensor(x_list)
+
+		x, mask = list_to_tensor(x_list)
+		m = mask.unsqueeze(dim=-1)
 
 		training = self.training
 		device = x.device
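The new `m = mask.unsqueeze(dim=-1)` exists so the 2-D mask can broadcast against `(b, t, d)` activations further down (for example the `* m` on the classifier outputs); schematically, with illustrative shapes:

import torch

b, t, d = 2, 5, 4
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])    # (b, t)
logits = torch.randn(b, t, d)
m = mask.unsqueeze(dim=-1)                # (b, t, 1)
print((logits * m)[0, 3])                 # padded rows come out as zeros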
@@ -1501,16 +1505,17 @@ class Base(nn.Module):
 			# pad mask
 			shape[2] = 1
 			padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
-			m = torch.cat([m, padding], dim=1)
+			mask = torch.cat([mask, padding], dim=1)
 
 		# needs to be done here as we still have our raw inputs
-		position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
+		#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
+		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 
 		classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
 
 		output = self._forward(
 			inputs=x,
-			mask=m,
+			mask=mask,
 			state=state,
 			position_ids=position_ids,
 			output_attentions = output_attentions,
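`inputs_to_position_ids` now receives the 2-D mask directly instead of `m.squeeze(-1).int()`. As a generic illustration of deriving positions from a padding mask (this is not the repo's implementation, which assigns positions per input segment when `unified_position_ids` is off):

import torch

def positions_from_mask(mask: torch.Tensor) -> torch.Tensor:
	# count valid tokens left to right, then zero out the padded slots
	pos = mask.long().cumsum(dim=-1) - 1
	return pos.clamp(min=0) * mask.long()

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
print(positions_from_mask(mask))
# tensor([[0, 1, 2, 0, 0],
#         [0, 1, 2, 3, 4]])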
@@ -1530,7 +1535,7 @@ class Base(nn.Module):
 				hidden_states[i] = self.classifier(hidden_states[i]) * m
 		# to-do: piece-wise classification, now that there's a head for text
 		# although again, one single monolithic head would be preferable instead......
-		if self.classifiers is not None:
+		elif self.classifiers is not None:
 			logits = self.classifiers(logits, levels = classifier_quant_levels) * m
 
 		if hidden_states is not None: