support projecting to vectors
This commit is contained in:
parent
86694aef4e
commit
6b43915eb8
|
@ -653,11 +653,14 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
}
|
||||
return groups
|
||||
|
||||
def get_codes(self, mel):
|
||||
def get_codes(self, mel, project=False):
|
||||
proj = self.m2v.input_blocks(mel).permute(0,2,1)
|
||||
_, proj = self.m2v.projector(proj)
|
||||
codes = self.quantizer.get_codes(proj)
|
||||
return codes
|
||||
if project:
|
||||
proj, _ = self.quantizer(proj)
|
||||
return proj
|
||||
else:
|
||||
return self.quantizer.get_codes(proj)
|
||||
|
||||
def reconstruct(self, mel):
|
||||
proj = self.m2v.input_blocks(mel).permute(0,2,1)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
|
||||
if __name__ == '__main__':
|
||||
basepath = 'Y:/clips/podcasts-0'
|
||||
basepath = 'Y:\\bigasr_dataset\\hifi_tts'
|
||||
|
||||
english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv')
|
||||
if not os.path.exists(english_file):
|
||||
|
|
|
@ -330,13 +330,14 @@ class Mel2vecCodesInjector(Injector):
|
|||
self.m2v = get_music_codegen()
|
||||
del self.m2v.m2v.encoder # This is a big memory sink which will not get used.
|
||||
self.needs_move = True
|
||||
self.inj_vector = opt_get(opt, ['vector'], False)
|
||||
|
||||
def forward(self, state):
|
||||
mels = state[self.input]
|
||||
with torch.no_grad():
|
||||
if self.needs_move:
|
||||
self.m2v = self.m2v.to(mels.device)
|
||||
codes = self.m2v.get_codes(mels)
|
||||
codes = self.m2v.get_codes(mels, project=self.inj_vector)
|
||||
return {self.output: codes}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user