Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-09-29 16:31:45 +0300
committerJan Buethe <jbuethe@amazon.de>2023-09-29 16:31:45 +0300
commitce28695844c12f43c31b4ee739749883c8b44b17 (patch)
tree1b9b0dbf6cd62f3a0223e3134aae2beb9831cb3e
parent49014454907d515e3c8ca8b06add78ad74c417d1 (diff)
refactoring and cleanup
-rw-r--r--dnn/torch/neural-pitch/evaluation.py23
-rw-r--r--dnn/torch/neural-pitch/models.py121
-rw-r--r--dnn/torch/neural-pitch/neural_pitch_update.py35
-rw-r--r--dnn/torch/neural-pitch/training.py69
4 files changed, 75 insertions, 173 deletions
diff --git a/dnn/torch/neural-pitch/evaluation.py b/dnn/torch/neural-pitch/evaluation.py
index b7f8d318..38ba5765 100644
--- a/dnn/torch/neural-pitch/evaluation.py
+++ b/dnn/torch/neural-pitch/evaluation.py
@@ -21,6 +21,8 @@ from utils import stft, random_filter, feature_xform
import subprocess
import crepe
+from models import PitchDNN, PitchDNNIF, PitchDNNXcorr
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def rca(reference,input,voicing,thresh = 25):
@@ -43,20 +45,6 @@ def rpa(model,device = 'cpu',data_format = 'if'):
random.shuffle(list_files)
list_files = list_files[:1000]
- # C_lp = 0
- # C_lp_m = 0
- # C_lp_f = 0
- # list_rca_model_lp = []
- # list_rca_male_lp = []
- # list_rca_female_lp = []
-
- # C_hp = 0
- # C_hp_m = 0
- # C_hp_f = 0
- # list_rca_model_hp = []
- # list_rca_male_hp = []
- # list_rca_female_hp = []
-
C_all = 0
C_all_m = 0
C_all_f = 0
@@ -180,16 +168,15 @@ def cycle_eval(checkpoint_list, noise_type = 'synthetic', noise_dataset = None,
checkpoint = torch.load(f, map_location='cpu')
dict_params = checkpoint['config']
-
if dict_params['data_format'] == 'if':
from models import large_if_ccode as model
- pitch_nn = model(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim'])
+ pitch_nn = PitchDNNIF(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim'])
elif dict_params['data_format'] == 'xcorr':
from models import large_xcorr as model
- pitch_nn = model(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
+ pitch_nn = PitchDNNXcorr(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
else:
from models import large_joint as model
- pitch_nn = model(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
+ pitch_nn = PitchDNN(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
pitch_nn.load_state_dict(checkpoint['state_dict'])
diff --git a/dnn/torch/neural-pitch/models.py b/dnn/torch/neural-pitch/models.py
index 6822353e..2d24f9a5 100644
--- a/dnn/torch/neural-pitch/models.py
+++ b/dnn/torch/neural-pitch/models.py
@@ -6,16 +6,16 @@ Pitch Estimation Models and dataloaders
import torch
import numpy as np
-class large_if_ccode(torch.nn.Module):
+class PitchDNNIF(torch.nn.Module):
- def __init__(self,input_dim = 88,gru_dim = 64,output_dim = 192):
- super(large_if_ccode,self).__init__()
+ def __init__(self, input_dim=88, gru_dim=64, output_dim=192):
+ super().__init__()
self.activation = torch.nn.Tanh()
- self.initial = torch.nn.Linear(input_dim,gru_dim)
- self.hidden = torch.nn.Linear(gru_dim,gru_dim)
- self.gru = torch.nn.GRU(input_size = gru_dim,hidden_size = gru_dim,batch_first = True)
- self.upsample = torch.nn.Linear(gru_dim,output_dim)
+ self.initial = torch.nn.Linear(input_dim, gru_dim)
+ self.hidden = torch.nn.Linear(gru_dim, gru_dim)
+ self.gru = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, batch_first=True)
+ self.upsample = torch.nn.Linear(gru_dim, output_dim)
def forward(self, x):
@@ -30,71 +30,53 @@ class large_if_ccode(torch.nn.Module):
return x
-class large_xcorr(torch.nn.Module):
+class PitchDNNXcorr(torch.nn.Module):
- def __init__(self,input_dim = 90,gru_dim = 64,output_dim = 192):
- super(large_xcorr,self).__init__()
+ def __init__(self, input_dim=90, gru_dim=64, output_dim=192):
+ super().__init__()
self.activation = torch.nn.Tanh()
self.conv = torch.nn.Sequential(
- torch.nn.ZeroPad2d((2,0,1,1)),
- torch.nn.Conv2d(1, 8, 3, bias = True),
+ torch.nn.ZeroPad2d((2, 0, 1, 1)),
+ torch.nn.Conv2d(1, 8, 3, bias=True),
self.activation,
torch.nn.ZeroPad2d((2,0,1,1)),
- torch.nn.Conv2d(8, 8, 3, bias = True),
+ torch.nn.Conv2d(8, 8, 3, bias=True),
self.activation,
torch.nn.ZeroPad2d((2,0,1,1)),
- torch.nn.Conv2d(8, 1, 3, bias = True),
+ torch.nn.Conv2d(8, 1, 3, bias=True),
self.activation,
)
- # self.conv = torch.nn.Sequential(
- # torch.nn.ConstantPad1d((2,0),0),
- # torch.nn.Conv1d(64,10,3),
- # self.activation,
- # torch.nn.ConstantPad1d((2,0),0),
- # torch.nn.Conv1d(10,64,3),
- # self.activation,
- # )
-
self.downsample = torch.nn.Sequential(
- torch.nn.Linear(input_dim,gru_dim),
+ torch.nn.Linear(input_dim, gru_dim),
self.activation
)
- self.GRU = torch.nn.GRU(input_size = gru_dim,hidden_size = gru_dim,num_layers = 1,batch_first = True)
+ self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True)
self.upsample = torch.nn.Sequential(
torch.nn.Linear(gru_dim,output_dim),
self.activation
)
def forward(self, x):
- # x = x[:,:,:257].unsqueeze(-1)
x = self.conv(x.unsqueeze(-1).permute(0,3,2,1)).squeeze(1)
- # print(x.shape)
- # x = self.conv(x.permute(0,3,2,1)).squeeze(1)
x,_ = self.GRU(self.downsample(x.permute(0,2,1)))
x = self.upsample(x).permute(0,2,1)
- # x = self.downsample(x)
- # x = self.activation(x)
- # x = self.conv(x.permute(0,2,1)).permute(0,2,1)
- # x,_ = self.GRU(x)
- # x = self.upsample(x).permute(0,2,1)
return x
-class large_joint(torch.nn.Module):
+class PitchDNN(torch.nn.Module):
"""
Joint IF-xcorr
1D CNN on IF, merge with xcorr, 2D CNN on merged + GRU
"""
- def __init__(self,input_IF_dim = 88,input_xcorr_dim = 224,gru_dim = 64,output_dim = 192):
- super(large_joint,self).__init__()
+ def __init__(self,input_IF_dim=88, input_xcorr_dim=224, gru_dim=64, output_dim=192):
+ super().__init__()
self.activation = torch.nn.Tanh()
- print("dim=", input_IF_dim)
self.if_upsample = torch.nn.Sequential(
torch.nn.Linear(input_IF_dim,64),
self.activation,
@@ -102,54 +84,34 @@ class large_joint(torch.nn.Module):
self.activation,
)
- # self.if_upsample = torch.nn.Sequential(
- # torch.nn.ConstantPad1d((2,0),0),
- # torch.nn.Conv1d(90,10,3),
- # self.activation,
- # torch.nn.ConstantPad1d((2,0),0),
- # torch.nn.Conv1d(10,257,3),
- # self.activation,
- # )
-
self.conv = torch.nn.Sequential(
torch.nn.ZeroPad2d((2,0,1,1)),
- torch.nn.Conv2d(1, 8, 3, bias = True),
+ torch.nn.Conv2d(1, 8, 3, bias=True),
self.activation,
torch.nn.ZeroPad2d((2,0,1,1)),
- torch.nn.Conv2d(8, 8, 3, bias = True),
+ torch.nn.Conv2d(8, 8, 3, bias=True),
self.activation,
torch.nn.ZeroPad2d((2,0,1,1)),
- torch.nn.Conv2d(8, 1, 3, bias = True),
+ torch.nn.Conv2d(8, 1, 3, bias=True),
self.activation,
)
- # self.conv = torch.nn.Sequential(
- # torch.nn.ConstantPad1d((2,0),0),
- # torch.nn.Conv1d(257,10,3),
- # self.activation,
- # torch.nn.ConstantPad1d((2,0),0),
- # torch.nn.Conv1d(10,64,3),
- # self.activation,
- # )
-
self.downsample = torch.nn.Sequential(
- torch.nn.Linear(64 + input_xcorr_dim,gru_dim),
+ torch.nn.Linear(64 + input_xcorr_dim, gru_dim),
self.activation
)
- self.GRU = torch.nn.GRU(input_size = gru_dim,hidden_size = gru_dim,num_layers = 1,batch_first = True)
+ self.GRU = torch.nn.GRU(input_size=gru_dim, hidden_size=gru_dim, num_layers=1, batch_first=True)
self.upsample = torch.nn.Sequential(
- torch.nn.Linear(gru_dim,output_dim),
+ torch.nn.Linear(gru_dim, output_dim),
self.activation
)
def forward(self, x):
xcorr_feat = x[:,:,:224]
if_feat = x[:,:,224:]
- # x = torch.cat([xcorr_feat.unsqueeze(-1),self.if_upsample(if_feat).unsqueeze(-1)],axis = -1)
xcorr_feat = self.conv(xcorr_feat.unsqueeze(-1).permute(0,3,2,1)).squeeze(1).permute(0,2,1)
if_feat = self.if_upsample(if_feat)
x = torch.cat([xcorr_feat,if_feat],axis = - 1)
- # x = self.conv(x.permute(0,3,2,1)).squeeze(1)
x,_ = self.GRU(self.downsample(x))
x = self.upsample(x).permute(0,2,1)
@@ -157,8 +119,8 @@ class large_joint(torch.nn.Module):
# Dataloaders
-class loader(torch.utils.data.Dataset):
- def __init__(self, features_if, file_pitch,confidence_threshold = 0.4,dimension_if = 30,context = 100):
+class Loader(torch.utils.data.Dataset):
+ def __init__(self, features_if, file_pitch, confidence_threshold=0.4, dimension_if=30, context=100):
self.if_feat = np.memmap(features_if, dtype=np.float32).reshape(-1,3*dimension_if)
# Resolution of 20 cents
@@ -170,24 +132,24 @@ class loader(torch.utils.data.Dataset):
self.confidence[self.confidence < confidence_threshold] = 0
self.context = context
# Clip both to same size
- size_common = min(self.if_feat.shape[0],self.cents.shape[0])
+ size_common = min(self.if_feat.shape[0], self.cents.shape[0])
self.if_feat = self.if_feat[:size_common,:]
self.cents = self.cents[:size_common]
self.confidence = self.confidence[:size_common]
frame_max = self.if_feat.shape[0]//context
- self.if_feat = np.reshape(self.if_feat[:frame_max*context,:],(frame_max,context,3*dimension_if))
- self.cents = np.reshape(self.cents[:frame_max*context],(frame_max,context))
- self.confidence = np.reshape(self.confidence[:frame_max*context],(frame_max,context))
+ self.if_feat = np.reshape(self.if_feat[:frame_max*context, :],(frame_max, context,3*dimension_if))
+ self.cents = np.reshape(self.cents[:frame_max * context],(frame_max, context))
+ self.confidence = np.reshape(self.confidence[:frame_max*context],(frame_max, context))
def __len__(self):
return self.if_feat.shape[0]
def __getitem__(self, index):
- return torch.from_numpy(self.if_feat[index,:,:]),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
+ return torch.from_numpy(self.if_feat[index,:,:]), torch.from_numpy(self.cents[index]), torch.from_numpy(self.confidence[index])
-class loader_joint(torch.utils.data.Dataset):
- def __init__(self, features, file_pitch, confidence_threshold = 0.4,context = 100, choice_data = 'both'):
+class PitchDNNDataloader(torch.utils.data.Dataset):
+ def __init__(self, features, file_pitch, confidence_threshold=0.4, context=100, choice_data='both'):
self.feat = np.memmap(features, mode='r', dtype=np.int8).reshape(-1,312)
self.xcorr = self.feat[:,:224]
self.if_feat = self.feat[:,224:]
@@ -199,24 +161,21 @@ class loader_joint(torch.utils.data.Dataset):
# Filter confidence for CREPE
self.confidence[self.confidence < confidence_threshold] = 0
self.context = context
- print(np.mean(self.confidence), np.mean(self.cents))
self.choice_data = choice_data
frame_max = self.if_feat.shape[0]//context
- self.if_feat = np.reshape(self.if_feat[:frame_max*context,:],(frame_max,context,88))
- self.cents = np.reshape(self.cents[:frame_max*context],(frame_max,context))
- self.xcorr = np.reshape(self.xcorr[:frame_max*context,:],(frame_max,context,224))
- # self.cents = np.rint(60*np.log2(256/(self.periods + 1.0e-8))).astype('int')
- # self.cents = np.clip(self.cents,0,239)
- self.confidence = np.reshape(self.confidence[:frame_max*context],(frame_max,context))
- # print(self.if_feat.shape)
+ self.if_feat = np.reshape(self.if_feat[:frame_max*context,:], (frame_max, context, 88))
+ self.cents = np.reshape(self.cents[:frame_max*context], (frame_max,context))
+ self.xcorr = np.reshape(self.xcorr[:frame_max*context,:], (frame_max,context, 224))
+ self.confidence = np.reshape(self.confidence[:frame_max*context], (frame_max, context))
+
def __len__(self):
return self.if_feat.shape[0]
def __getitem__(self, index):
if self.choice_data == 'both':
- return torch.cat([torch.from_numpy((1./127)*self.xcorr[index,:,:]),torch.from_numpy((1./127)*self.if_feat[index,:,:])],dim=-1),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
+ return torch.cat([torch.from_numpy((1./127)*self.xcorr[index,:,:]), torch.from_numpy((1./127)*self.if_feat[index,:,:])], dim=-1), torch.from_numpy(self.cents[index]), torch.from_numpy(self.confidence[index])
elif self.choice_data == 'if':
return torch.from_numpy((1./127)*self.if_feat[index,:,:]),torch.from_numpy(self.cents[index]),torch.from_numpy(self.confidence[index])
else:
diff --git a/dnn/torch/neural-pitch/neural_pitch_update.py b/dnn/torch/neural-pitch/neural_pitch_update.py
index a72abee6..aa2caf99 100644
--- a/dnn/torch/neural-pitch/neural_pitch_update.py
+++ b/dnn/torch/neural-pitch/neural_pitch_update.py
@@ -20,6 +20,7 @@ import json
import torch
import tqdm
+from models import PitchDNNIF, PitchDNNXcorr, PitchDNN
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device is not None:
@@ -30,14 +31,11 @@ checkpoint = torch.load(args.checkpoint, map_location='cpu')
dict_params = checkpoint['config']
if dict_params['data_format'] == 'if':
- from models import large_if_ccode as model
- pitch_nn = model(dict_params['freq_keep']*3,dict_params['gru_dim'],dict_params['output_dim'])
+ pitch_nn = PitchDNNIF(dict_params['freq_keep']*3, dict_params['gru_dim'], dict_params['output_dim'])
elif dict_params['data_format'] == 'xcorr':
- from models import large_xcorr as model
- pitch_nn = model(dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
+ pitch_nn = PitchDNNXcorr(dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
else:
- from models import large_joint as model
- pitch_nn = model(dict_params['freq_keep']*3,dict_params['xcorr_dim'],dict_params['gru_dim'],dict_params['output_dim'])
+ pitch_nn = PitchDNN(dict_params['freq_keep']*3, dict_params['xcorr_dim'], dict_params['gru_dim'], dict_params['output_dim'])
pitch_nn.load_state_dict(checkpoint['state_dict'])
pitch_nn = pitch_nn.to(device)
@@ -46,22 +44,8 @@ N = dict_params['window_size']
H = dict_params['hop_factor']
freq_keep = dict_params['freq_keep']
-# import os
-# import argparse
-
-
-
-# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["OMP_NUM_THREADS"] = "16"
-# parser = argparse.ArgumentParser()
-
-# parser.add_argument('features', type=str, help='input features')
-# parser.add_argument('data', type=str, help='input data')
-# parser.add_argument('output', type=str, help='output features')
-# parser.add_argument('--add-confidence', action='store_true', help='add CREPE confidence to features')
-# parser.add_argument('--viterbi', action='store_true', help='enable viterbi algo for pitch tracking')
-
def run_lpc(signal, lpcs, frame_length=160):
num_frames, lpc_order = lpcs.shape
@@ -85,9 +69,6 @@ if __name__ == "__main__":
assert feature_dim == 36
- # if args.add_confidence:
- # feature_dim += 1
-
output = np.memmap(args.output, dtype=np.float32, shape=(num_frames, feature_dim), mode='w+')
output[:, :36] = features
@@ -96,7 +77,6 @@ if __name__ == "__main__":
sig = data[:, 1]
# parameters
- # use_viterbi=args.viterbi
# constants
pitch_min = 32
@@ -125,7 +105,6 @@ if __name__ == "__main__":
break
chunk = np.concatenate((history, sig[signal_start:signal_stop]))
chunk_la = np.concatenate((history, sig[signal_start:signal_stop + 80]))
- # time, frequency, confidence, _ = crepe.predict(chunk, fs, center=True, viterbi=True,verbose=0)
# Feature computation
spec = stft(x = np.concatenate([np.zeros(80),chunk_la/(2**15 - 1)]), w = 'boxcar', N = N, H = H).T
@@ -160,20 +139,14 @@ if __name__ == "__main__":
frequency = 62.5*2**(model_cents/1200)
frequency = frequency[overlap_frames : overlap_frames + frame_stop - frame_start]
- # confidence = confidence[overlap_frames : overlap_frames + frame_stop - frame_start]
# convert frequencies to periods
periods = np.round(fs / frequency)
- # adjust to pitch range
- # confidence[periods < pitch_min] = 0
- # confidence[periods > pitch_max] = 0
periods = np.clip(periods, pitch_min, pitch_max)
output[frame_start:frame_stop, pitch_position] = (periods - 100) / 50
- # if args.replace_xcorr:
- # re-calculate xcorr
frame_offset = (pitch_max + frame_length - 1) // frame_length
offset = frame_offset * frame_length
padding = lpc_order
diff --git a/dnn/torch/neural-pitch/training.py b/dnn/torch/neural-pitch/training.py
index f4139222..04b3deb1 100644
--- a/dnn/torch/neural-pitch/training.py
+++ b/dnn/torch/neural-pitch/training.py
@@ -37,33 +37,25 @@ import time
np_seed = int(time.time())
torch_seed = int(time.time())
-import json
import torch
torch.manual_seed(torch_seed)
import numpy as np
np.random.seed(np_seed)
from utils import count_parameters
import tqdm
-import sys
-from datetime import datetime
-#from evaluation import rpa
+from models import PitchDNN, PitchDNNIF, PitchDNNXcorr, PitchDNNDataloader
-# print(list(range(torch.cuda.device_count())))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# device = 'cpu'
-from models import loader_joint as loader
+
if args.data_format == 'if':
- from models import large_if_ccode as model
- pitch_nn = model(args.freq_keep*3,args.gru_dim,args.output_dim)
+ pitch_nn = PitchDNNIF(3 * args.freq_keep - 2, args.gru_dim, args.output_dim)
elif args.data_format == 'xcorr':
- from models import large_xcorr as model
- pitch_nn = model(args.xcorr_dimension,args.gru_dim,args.output_dim)
+ pitch_nn = PitchDNNXcorr(args.xcorr_dimension, args.gru_dim, args.output_dim)
else:
- from models import large_joint as model
- pitch_nn = model(88,224,args.gru_dim,args.output_dim)
+ pitch_nn = PitchDNN(3 * args.freq_keep - 2, 224, args.gru_dim, args.output_dim)
-dataset_training = loader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
+dataset_training = PitchDNNDataloader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
logits_softmax = torch.nn.Softmax(dim = 1)(logits).permute(0,2,1)
@@ -84,23 +76,15 @@ def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
def accuracy(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
logits_softmax = torch.nn.Softmax(dim = 1)(logits).permute(0,2,1)
pred_pitch = torch.argmax(logits_softmax, 2)
- #print(pred_pitch.shape, labels.long().shape)
accuracy = (pred_pitch != labels.long())*1.
- #print(accuracy.shape, confidence.shape)
return 1.-torch.mean(confidence*accuracy)
-# features = args.features
-# pitch = args.crepe_pitch
-# dataset_training = loader(features,pitch,args.confidence_threshold,args.freq_keep,args.context)
-# dataset_training = loader(features,pitch,'../../../../testing/testing_features_10pct_xcorr.f32')
-
-train_dataset, test_dataset = torch.utils.data.random_split(dataset_training, [0.95,0.05],generator=torch.Generator().manual_seed(torch_seed))
+train_dataset, test_dataset = torch.utils.data.random_split(dataset_training, [0.95,0.05], generator=torch.Generator().manual_seed(torch_seed))
batch_size = 256
-train_dataloader = torch.utils.data.DataLoader(dataset = train_dataset,batch_size = batch_size,shuffle = True,num_workers = 0, pin_memory = False)
-test_dataloader = torch.utils.data.DataLoader(dataset = test_dataset,batch_size = batch_size,shuffle = True,num_workers = 0, pin_memory = False)
+train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
+test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
-# pitch_nn = model(args.freq_keep*3,args.gru_dim,args.output_dim).to(device)
pitch_nn = pitch_nn.to(device)
num_params = count_parameters(pitch_nn)
learning_rate = args.learning_rate
@@ -143,26 +127,25 @@ for epoch in range(num_epochs):
test_epoch.set_postfix({"Epoch" : epoch, "Test Loss":avg_loss})
pitch_nn.eval()
-#rpa(pitch_nn,device,data_format = args.data_format)
config = dict(
-data_format = args.data_format,
-epochs = num_epochs,
-window_size = args.N,
-hop_factor = args.H,
-freq_keep = args.freq_keep,
-batch_size = batch_size,
-learning_rate = learning_rate,
-confidence_threshold = args.confidence_threshold,
-model_parameters = num_params,
-np_seed = np_seed,
-torch_seed = torch_seed,
-xcorr_dim = args.xcorr_dimension,
-dim_input = 3*args.freq_keep,
-gru_dim = args.gru_dim,
-output_dim = args.output_dim,
-choice_cel = args.choice_cel,
-context = args.context,
+ data_format=args.data_format,
+ epochs=num_epochs,
+ window_size= args.N,
+ hop_factor= args.H,
+ freq_keep=args.freq_keep,
+ batch_size=batch_size,
+ learning_rate=learning_rate,
+ confidence_threshold=args.confidence_threshold,
+ model_parameters=num_params,
+ np_seed=np_seed,
+ torch_seed=torch_seed,
+ xcorr_dim=args.xcorr_dimension,
+ dim_input=3*args.freq_keep - 2,
+ gru_dim=args.gru_dim,
+ output_dim=args.output_dim,
+ choice_cel=args.choice_cel,
+ context=args.context,
)
model_save_path = os.path.join(args.output, f"{args.prefix}_{args.data_format}.pth")