From 3044339bdddea29398e84fa1fce2db813a4ef7cc Mon Sep 17 00:00:00 2001 From: Jan Buethe Date: Tue, 26 Sep 2023 14:35:36 +0200 Subject: changed checkpoint format --- dnn/torch/neural-pitch/evaluation.py | 126 ++------------------- .../neural-pitch/export_neuralpitch_weights.py | 5 +- dnn/torch/neural-pitch/neural_pitch_update.py | 15 ++- dnn/torch/neural-pitch/training.py | 17 ++- 4 files changed, 30 insertions(+), 133 deletions(-) diff --git a/dnn/torch/neural-pitch/evaluation.py b/dnn/torch/neural-pitch/evaluation.py index 0369cafa..b7f8d318 100644 --- a/dnn/torch/neural-pitch/evaluation.py +++ b/dnn/torch/neural-pitch/evaluation.py @@ -120,31 +120,9 @@ def rpa(model,device = 'cpu',data_format = 'if'): cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)).astype('int') - # if (model == 'penn'): - # model_frequency, _ = penn.from_audio( - # torch.from_numpy(audio).unsqueeze(0).float(), - # 16000, - # hopsize=0.01, - # fmin=(16000.0/256), - # fmax=500, - # checkpoint=penn.DEFAULT_CHECKPOINT, - # batch_size=32, - # pad=True, - # interp_unvoiced_at=0.065, - # gpu=0) - # model_frequency = model_frequency.cpu().detach().squeeze().numpy() - # model_cents = 1200*np.log2(model_frequency/(16000/256)) - - # elif (model == 'crepe'): - # _, model_frequency, _, _ = crepe.predict(audio, 16000, viterbi=vflag,center=True,verbose=0) - # lpcnet_file_name = '/home/ubuntu/Code/Datasets/SPEECH_DATA/lpcnet_f0_16k_residual/' + file_name + '_f0.f32' - # period_lpcnet = np.fromfile(lpcnet_file_name, dtype='float32') - # model_frequency = 16000/(period_lpcnet + 1.0e-6) - # model_cents = 1200*np.log2(model_frequency/(16000/256)) - # else: + model_cents = model(torch.from_numpy(np.copy(np.expand_dims(feature,0))).float().to(device)) model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy() - # model_cents = np.roll(model_cents,-1*3) num_frames = min(cent.shape[0],model_cents.shape[0]) pitch = pitch[:num_frames] @@ -158,131 +136,62 @@ def rpa(model,device = 'cpu',data_format = 'if'): voicing_all[force_out_of_pitch] = 0 C_all = C_all + np.where(voicing_all != 0)[0].shape[0] - # list_rca_model_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0])) list_rca_model_all.append(rca(cent,model_cents,voicing_all,thresh)) - # list_rca_model_all.append(np.count_nonzero(np.where(np.abs(cent - model_cents)))) if "mic_M" in audio_file: - # list_rca_male_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0])) list_rca_male_all.append(rca(cent,model_cents,voicing_all,thresh)) C_all_m = C_all_m + np.where(voicing_all != 0)[0].shape[0] else: - # list_rca_female_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0])) list_rca_female_all.append(rca(cent,model_cents,voicing_all,thresh)) C_all_f = C_all_f + np.where(voicing_all != 0)[0].shape[0] - """ - # Low pitch estimation - voicing_lp = np.copy(voicing) - force_out_of_pitch = np.where(np.logical_or(pitch < 65,pitch > 125)==True) - voicing_lp[force_out_of_pitch] = 0 - C_lp = C_lp + np.where(voicing_lp != 0)[0].shape[0] - - # list_rca_model_lp.append(sweep_rca(cent,model_cents,voicing_lp,thresh,[0])) - list_rca_model_lp.append(rca(cent,model_cents,voicing_lp,thresh)) - - if "mic_M" in audio_file: - # list_rca_male_lp.append(sweep_rca(cent,model_cents,voicing_lp,thresh,[0])) - list_rca_male_lp.append(rca(cent,model_cents,voicing_lp,thresh)) - C_lp_m = C_lp_m + np.where(voicing_lp != 0)[0].shape[0] - else: - # list_rca_female_lp.append(sweep_rca(cent,model_cents,voicing_lp,thresh,[0])) - 
list_rca_female_lp.append(rca(cent,model_cents,voicing_lp,thresh)) - C_lp_f = C_lp_f + np.where(voicing_lp != 0)[0].shape[0] - - # High pitch estimation - voicing_hp = np.copy(voicing) - force_out_of_pitch = np.where(np.logical_or(pitch < 125,pitch > 500)==True) - voicing_hp[force_out_of_pitch] = 0 - C_hp = C_hp + np.where(voicing_hp != 0)[0].shape[0] - - # list_rca_model_hp.append(sweep_rca(cent,model_cents,voicing_hp,thresh,[0])) - list_rca_model_hp.append(rca(cent,model_cents,voicing_hp,thresh)) - - if "mic_M" in audio_file: - # list_rca_male_hp.append(sweep_rca(cent,model_cents,voicing_hp,thresh,[0])) - list_rca_male_hp.append(rca(cent,model_cents,voicing_hp,thresh)) - C_hp_m = C_hp_m + np.where(voicing_hp != 0)[0].shape[0] - else: - # list_rca_female_hp.append(sweep_rca(cent,model_cents,voicing_hp,thresh,[0])) - list_rca_female_hp.append(rca(cent,model_cents,voicing_hp,thresh)) - C_hp_f = C_hp_f + np.where(voicing_hp != 0)[0].shape[0] - # list_rca_model.append(acc_model) - # list_rca_crepe.append(acc_crepe) - # list_rca_lpcnet.append(acc_lpcnet) - # list_rca_penn.append(acc_penn) - """ - - # list_rca_crepe = np.array(list_rca_crepe) - # list_rca_model_lp = np.array(list_rca_model_lp) - # list_rca_male_lp = np.array(list_rca_male_lp) - # list_rca_female_lp = np.array(list_rca_female_lp) - - # list_rca_model_hp = np.array(list_rca_model_hp) - # list_rca_male_hp = np.array(list_rca_male_hp) - # list_rca_female_hp = np.array(list_rca_female_hp) - list_rca_model_all = np.array(list_rca_model_all) list_rca_male_all = np.array(list_rca_male_all) list_rca_female_all = np.array(list_rca_female_all) - # list_rca_lpcnet = np.array(list_rca_lpcnet) - # list_rca_penn = np.array(list_rca_penn) + x = PrettyTable() x.field_names = ["Experiment", "Mean RPA"] x.add_row(["Both all pitches", np.sum(list_rca_model_all)/C_all]) - # x.add_row(["Both low pitches", np.sum(list_rca_model_lp)/C_lp]) - # x.add_row(["Both high pitches", np.sum(list_rca_model_hp)/C_hp]) x.add_row(["Male all pitches", np.sum(list_rca_male_all)/C_all_m]) - # x.add_row(["Male low pitches", np.sum(list_rca_male_lp)/C_lp_m]) - # x.add_row(["Male high pitches", np.sum(list_rca_male_hp)/C_hp_m]) x.add_row(["Female all pitches", np.sum(list_rca_female_all)/C_all_f]) - # x.add_row(["Female low pitches", np.sum(list_rca_female_lp)/C_lp_f]) - # x.add_row(["Female high pitches", np.sum(list_rca_female_hp)/C_hp_f]) print(x) return None -def cycle_eval(list_files_pth, noise_type = 'synthetic', noise_dataset = None, list_snr = [-20,-15,-10,-5,0,5,10,15,20], ptdb_dataset_path = None,fraction = 0.1,thresh = 50): +def cycle_eval(checkpoint_list, noise_type = 'synthetic', noise_dataset = None, list_snr = [-20,-15,-10,-5,0,5,10,15,20], ptdb_dataset_path = None,fraction = 0.1,thresh = 50): """ - Cycle through SNR evaluation for list of .pth files + Cycle through SNR evaluation for list of checkpoints """ - # list_files = glob.glob('/home/ubuntu/Code/Datasets/SPEECH DATA/combined_mic_16k_raw/*.raw') - # dir_f0 = '/home/ubuntu/Code/Datasets/SPEECH DATA/combine_f0_ptdb/' - # random_shuffle = list(np.random.permutation(len(list_files))) list_files = glob.glob(ptdb_dataset_path + 'combined_mic_16k/*.raw') dir_f0 = ptdb_dataset_path + 'combined_reference_f0/' random.shuffle(list_files) list_files = list_files[:(int)(fraction*len(list_files))] - # list_nfiles = ['DKITCHEN','NFIELD','OHALLWAY','PCAFETER','SPSQUARE','TCAR','DLIVING','NPARK','OMEETING','PRESTO','STRAFFIC','TMETRO','DWASHING','NRIVER','OOFFICE','PSTATION','TBUS'] - dict_models = {} 
list_snr.append(np.inf) - # thresh = 50 - for f in list_files_pth: + for f in checkpoint_list: if (f!='crepe') and (f!='lpcnet'): - fname = os.path.basename(f).split('_')[0] + '_' + os.path.basename(f).split('_')[-1][:-4] - config_path = os.path.dirname(f) + '/' + os.path.basename(f).split('_')[0] + '_' + 'config_' + os.path.basename(f).split('_')[-1][:-4] + '.json' - with open(config_path) as json_file: - dict_params = json.load(json_file) + + checkpoint = torch.load(f, map_location='cpu') + dict_params = checkpoint['config'] if dict_params['data_format'] == 'if': from models import large_if_ccode as model - pitch_nn = model(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim']).to(device) + pitch_nn = model(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim']) elif dict_params['data_format'] == 'xcorr': from models import large_xcorr as model - pitch_nn = model(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim']).to(device) + pitch_nn = model(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim']) else: from models import large_joint as model - pitch_nn = model(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim']).to(device) + pitch_nn = model(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim']) - pitch_nn.load_state_dict(torch.load(f)) + pitch_nn.load_state_dict(checkpoint['state_dict']) N = dict_params['window_size'] H = dict_params['hop_factor'] @@ -356,15 +265,8 @@ def cycle_eval(list_files_pth, noise_type = 'synthetic', noise_dataset = None, l cent = np.rint(1200*np.log2(np.divide(pitch, (16000/256), out=np.zeros_like(pitch), where=pitch!=0) + 1.0e-8)).astype('int') - # if os.path.basename(f) == 'crepe': - # elif (model == 'crepe'): - # _, model_frequency, _, _ = crepe.predict(np.concatenate([np.zeros(80),audio]), 16000, viterbi=True,center=True,verbose=0) - # model_cents = 1200*np.log2(model_frequency/(16000/256)) - # else: - # else: model_cents = pitch_nn(torch.from_numpy(np.copy(np.expand_dims(feature,0))).float().to(device)) model_cents = 20*model_cents.argmax(dim=1).cpu().detach().squeeze().numpy() - # model_cents = np.roll(model_cents,-1*3) num_frames = min(cent.shape[0],model_cents.shape[0]) pitch = pitch[:num_frames] @@ -378,9 +280,7 @@ def cycle_eval(list_files_pth, noise_type = 'synthetic', noise_dataset = None, l voicing_all[force_out_of_pitch] = 0 C_all = C_all + np.where(voicing_all != 0)[0].shape[0] - # list_rca_model_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0])) C_correct = C_correct + rca(cent,model_cents,voicing_all,thresh) - # list_rca_model_all.append(np.count_nonzero(np.where(np.abs(cent - model_cents)))) list_mean.append(C_correct/C_all) else: fname = f @@ -453,9 +353,7 @@ def cycle_eval(list_files_pth, noise_type = 'synthetic', noise_dataset = None, l voicing_all[force_out_of_pitch] = 0 C_all = C_all + np.where(voicing_all != 0)[0].shape[0] - # list_rca_model_all.append(sweep_rca(cent,model_cents,voicing_all,thresh,[0])) C_correct = C_correct + rca(cent,model_cents,voicing_all,thresh) - # list_rca_model_all.append(np.count_nonzero(np.where(np.abs(cent - model_cents)))) list_mean.append(C_correct/C_all) dict_models[fname] = {} dict_models[fname]['list_SNR'] = list_mean[:-1] diff --git a/dnn/torch/neural-pitch/export_neuralpitch_weights.py b/dnn/torch/neural-pitch/export_neuralpitch_weights.py index be374281..a56784a9 100644 --- 
a/dnn/torch/neural-pitch/export_neuralpitch_weights.py
+++ b/dnn/torch/neural-pitch/export_neuralpitch_weights.py
@@ -36,7 +36,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
 
 parser = argparse.ArgumentParser()
 
-parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint')
+parser.add_argument('checkpoint', type=str, help='model checkpoint')
 parser.add_argument('output_dir', type=str, help='output folder')
 
 args = parser.parse_args()
@@ -85,5 +85,6 @@ if __name__ == "__main__":
 
     os.makedirs(args.output_dir, exist_ok=True)
 
     model = large_if_ccode()
-    model.load_state_dict(torch.load(args.checkpoint,map_location='cpu'))
+    checkpoint = torch.load(args.checkpoint ,map_location='cpu')
+    model.load_state_dict(checkpoint['state_dict'])
     c_export(args, model)
diff --git a/dnn/torch/neural-pitch/neural_pitch_update.py b/dnn/torch/neural-pitch/neural_pitch_update.py
index 5d8074cf..a72abee6 100644
--- a/dnn/torch/neural-pitch/neural_pitch_update.py
+++ b/dnn/torch/neural-pitch/neural_pitch_update.py
@@ -4,7 +4,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument('features', type=str, help='Features generated from dump_data')
 parser.add_argument('data', type=str, help='Data generated from dump_data (offset by 5ms)')
 parser.add_argument('output', type=str, help='output .f32 feature file with replaced neural pitch')
-parser.add_argument('pth_file', type=str, help='.pth file to use for pitch')
+parser.add_argument('checkpoint', type=str, help='model checkpoint file')
 parser.add_argument('path_lpcnet_extractor', type=str, help='path to LPCNet extractor object file (generated on compilation)')
 parser.add_argument('--device', type=str, help='compute device',default = None,required = False)
 parser.add_argument('--replace_xcorr', type = bool, default = False, help='Replace LPCNet xcorr with updated one')
@@ -26,21 +26,20 @@ if device is not None:
     device = torch.device(args.device)
 
 # Loading the appropriate model
-config_path = os.path.dirname(args.pth_file) + '/' + os.path.basename(args.pth_file).split('_')[0] + '_' + 'config_' + os.path.basename(args.pth_file).split('_')[-1][:-4] + '.json'
-with open(config_path) as json_file:
-    dict_params = json.load(json_file)
+checkpoint = torch.load(args.checkpoint, map_location='cpu')
+dict_params = checkpoint['config']
 
 if dict_params['data_format'] == 'if':
     from models import large_if_ccode as model
-    pitch_nn = model(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim']).to(device)
+    pitch_nn = model(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim'])
 elif dict_params['data_format'] == 'xcorr':
     from models import large_xcorr as model
-    pitch_nn = model(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim']).to(device)
+    pitch_nn = model(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
 else:
     from models import large_joint as model
-    pitch_nn = model(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim']).to(device)
+    pitch_nn = model(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
 
-pitch_nn.load_state_dict(torch.load(args.pth_file))
+pitch_nn.load_state_dict(checkpoint['state_dict'])
 pitch_nn = pitch_nn.to(device)
 
 N = dict_params['window_size']
 H = dict_params['hop_factor']
diff --git a/dnn/torch/neural-pitch/training.py b/dnn/torch/neural-pitch/training.py
index 3faf165a..f4139222 100644
--- a/dnn/torch/neural-pitch/training.py
+++ b/dnn/torch/neural-pitch/training.py
@@ -3,6 +3,7 @@ Training the neural pitch estimator
 """
 
+import os
 import argparse
 
 parser = argparse.ArgumentParser()
@@ -22,6 +23,7 @@ parser.add_argument('--output_dim', type=int, help='Output dimension',default =
 parser.add_argument('--learning_rate', type=float, help='Learning Rate',default = 1.0e-3,required = False)
 parser.add_argument('--epochs', type=int, help='Number of training epochs',default = 50,required = False)
 parser.add_argument('--choice_cel', type=str, help='Choice of Cross Entropy Loss (default or robust)',choices=['default','robust'],default = 'default',required = False)
+parser.add_argument('--prefix', type=str, help="prefix for model export, default: model", default='model')
 
 args = parser.parse_args()
@@ -163,12 +165,9 @@ choice_cel = args.choice_cel,
 context = args.context,
 )
 
-now = datetime.now()
-dir_pth_save = args.output_folder
-dir_network = dir_pth_save + str(now) + '_net_' + args.data_format + '.pth'
-dir_dictparams = dir_pth_save + str(now) + '_config_' + args.data_format + '.json'
-# Save Weights
-torch.save(pitch_nn.state_dict(), dir_network)
-# Save Config
-with open(dir_dictparams, 'w') as fp:
-    json.dump(config, fp)
+model_save_path = os.path.join(args.output, f"{args.prefix}_{args.data_format}.pth")
+checkpoint = {
+    'state_dict': pitch_nn.state_dict(),
+    'config': config
+}
+torch.save(checkpoint, model_save_path)
-- 
cgit v1.2.3
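
Note on the new checkpoint format: this commit replaces the earlier two-file convention (a '<timestamp>_net_<data_format>.pth' state dict plus a separate '<timestamp>_config_<data_format>.json') with a single torch checkpoint that carries both the weights ('state_dict') and the training configuration ('config'). The sketch below shows the save/load round trip implied by the patch; the helper name load_pitch_model is illustrative only, while the checkpoint keys and the selection of large_if_ccode / large_xcorr / large_joint based on config['data_format'] follow the code in this diff.

    import torch

    def load_pitch_model(checkpoint_path, device='cpu'):
        # Load the combined checkpoint written by training.py
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        config = checkpoint['config']

        # Rebuild the architecture from the embedded config, the same way
        # evaluation.py and neural_pitch_update.py now do
        if config['data_format'] == 'if':
            from models import large_if_ccode as model
            pitch_nn = model(config['freq_keep']*3, config['gru_dim'], config['output_dim'])
        elif config['data_format'] == 'xcorr':
            from models import large_xcorr as model
            pitch_nn = model(config['xcorr_dim'], config['gru_dim'], config['output_dim'])
        else:
            from models import large_joint as model
            pitch_nn = model(config['freq_keep']*3, config['xcorr_dim'], config['gru_dim'], config['output_dim'])

        pitch_nn.load_state_dict(checkpoint['state_dict'])
        return pitch_nn.to(device), config

    # Saving, as training.py now does:
    # torch.save({'state_dict': pitch_nn.state_dict(), 'config': config}, 'model_if.pth')

Because the configuration travels inside the checkpoint, the consumer scripts no longer need to reconstruct a separate _config_*.json path from the checkpoint filename.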