forked from mrq/DL-Art-School

Add evaluation logic for gpt_asr_hf2

commit 04454ee63a
parent 47fe032a3d
@@ -45,7 +45,9 @@ class MelEncoder(nn.Module):
         )

     def forward(self, x):
-        return self.encoder(x)
+        for e in self.encoder:
+            x = e(x)
+        return x


 class GPT2InferenceModel(GPT2PreTrainedModel):
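The change above unrolls MelEncoder's forward pass: instead of invoking the nn.Sequential directly, each child module is applied in turn, a form that makes it easy to hook or checkpoint intermediate activations. A minimal sketch, assuming a toy encoder whose layer sizes and input shapes are not taken from the repo, showing that the two forms compute the same thing:

import torch
import torch.nn as nn

# Toy stand-in for MelEncoder's self.encoder; layer sizes are assumptions.
encoder = nn.Sequential(
    nn.Conv1d(80, 128, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv1d(128, 128, kernel_size=3, padding=1),
)

x = torch.randn(2, 80, 200)  # (batch, mel_channels, frames) -- assumed layout

y_direct = encoder(x)        # calling the Sequential runs its children in order

y_loop = x                   # iterating a Sequential yields those same children
for e in encoder:
    y_loop = e(y_loop)

assert torch.equal(y_direct, y_loop)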
@@ -262,22 +264,21 @@ class GptAsrHf2(nn.Module):

         mel_emb = self.mel_encoder(mel_inputs)
         assert mel_emb.shape[-1] <= self.max_mel_frames
-        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         self.inference_model.store_mel_emb(mel_emb)

         # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
         if cond_text is None:
-            fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
-            fake_inputs[:,-1] = self.NUMBER_SYMBOLS
+            fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
+            fake_inputs[:,-1] = self.START_TOKEN
         else:
             cond_used = 10
-            fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
-            fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS
+            fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
+            fake_inputs[:,-1-cond_used] = self.START_TOKEN
             fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
-        gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
-                                            max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
+        gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_TOKEN, pad_token_id=0, eos_token_id=0,
+                                            max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
         return gen[:, self.max_mel_frames:]

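In the hunk above, the inference path stops padding mel_emb out to self.max_mel_frames: fake_inputs is now sized from the actual encoded length mel_emb.shape[1], the start slot is filled with self.START_TOKEN instead of self.NUMBER_SYMBOLS, and the generate() length cap scales with the true MEL length. A minimal sketch of the unconditioned branch's layout, with all concrete values hypothetical:

import torch

batch, mel_len = 2, 150          # mel_len stands in for mel_emb.shape[1]
START_TOKEN = 9999               # hypothetical token id
max_symbols_per_phrase = 200     # hypothetical text-token budget

# One placeholder slot per encoded MEL frame plus one slot for the start
# token; per the comment in the diff, the placeholders are later swapped
# for the stored mel_emb inside the inference model.
fake_inputs = torch.full((batch, mel_len + 1), fill_value=1, dtype=torch.long)
fake_inputs[:, -1] = START_TOKEN

# generate() must cover the MEL prefix plus the decoded text, so the cap
# follows the true MEL length rather than the fixed max_mel_frames.
max_length = max_symbols_per_phrase + mel_len

In the conditioned branch the layout gains cond_used extra trailing slots filled from cond_text, i.e. [MEL placeholders | START_TOKEN | first cond_used conditioning tokens].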
@@ -44,7 +44,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_hf.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_hf2.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
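The eval script's default options file now points at the gpt_asr_hf2 config. An assumed invocation (the script's filename is not shown in this diff, so eval_script.py is a placeholder):

python eval_script.py -opt ../options/test_gpt_asr_hf2.yml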