From 6b43915eb89dd4aa01362eb334e1b0d601d8d91b Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 28 May 2022 22:27:45 -0600 Subject: [PATCH] support projecting to vectors --- codes/models/audio/mel2vec.py | 9 ++++++--- .../audio/preparation/combine_phonetic_and_text.py | 2 +- codes/trainer/injectors/audio_injectors.py | 3 ++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 5972c8ac..251767e7 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -653,11 +653,14 @@ class ContrastiveTrainingWrapper(nn.Module): } return groups - def get_codes(self, mel): + def get_codes(self, mel, project=False): proj = self.m2v.input_blocks(mel).permute(0,2,1) _, proj = self.m2v.projector(proj) - codes = self.quantizer.get_codes(proj) - return codes + if project: + proj, _ = self.quantizer(proj) + return proj + else: + return self.quantizer.get_codes(proj) def reconstruct(self, mel): proj = self.m2v.input_blocks(mel).permute(0,2,1) diff --git a/codes/scripts/audio/preparation/combine_phonetic_and_text.py b/codes/scripts/audio/preparation/combine_phonetic_and_text.py index 587e4f78..cd3d582c 100644 --- a/codes/scripts/audio/preparation/combine_phonetic_and_text.py +++ b/codes/scripts/audio/preparation/combine_phonetic_and_text.py @@ -1,7 +1,7 @@ import os if __name__ == '__main__': - basepath = 'Y:/clips/podcasts-0' + basepath = 'Y:\\bigasr_dataset\\hifi_tts' english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv') if not os.path.exists(english_file): diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index 98a52c19..ea1daed4 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -330,13 +330,14 @@ class Mel2vecCodesInjector(Injector): self.m2v = get_music_codegen() del self.m2v.m2v.encoder # This is a big memory sink which will not get used. self.needs_move = True + self.inj_vector = opt_get(opt, ['vector'], False) def forward(self, state): mels = state[self.input] with torch.no_grad(): if self.needs_move: self.m2v = self.m2v.to(mels.device) - codes = self.m2v.get_codes(mels) + codes = self.m2v.get_codes(mels, project=self.inj_vector) return {self.output: codes}