diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-07-24 01:11:15 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-07-24 01:11:15 +0300 |
commit | d15be43af425013e27ef872ac672700e0b642ac1 (patch) | |
tree | 2006a8a972d101f27097676bf40a5be5d7825c5d | |
parent | 9d40e5cb0813464a9c6089210cbb72d99b94f253 (diff) |
Make bias/subias/diag/scale optional
-rw-r--r-- | dnn/parse_lpcnet_weights.c | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/dnn/parse_lpcnet_weights.c b/dnn/parse_lpcnet_weights.c index 833f972f..0f2def8b 100644 --- a/dnn/parse_lpcnet_weights.c +++ b/dnn/parse_lpcnet_weights.c @@ -124,16 +124,22 @@ int linear_init(LinearLayer *layer, const WeightArray *arrays, int nb_inputs, int nb_outputs) { - int total_blocks; - if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1; - if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1; + layer->bias = NULL; + layer->subias = NULL; layer->weights = NULL; layer->float_weights = NULL; layer->weights_idx = NULL; - if (weights_idx != NULL) { - if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_outputs, nb_inputs, &total_blocks)) == NULL) return 1; + layer->diag = NULL; + layer->scale = NULL; + if (bias != NULL) { + if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1; + } + if (subias != NULL) { + if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1; } if (weights_idx != NULL) { + int total_blocks; + if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_outputs, nb_inputs, &total_blocks)) == NULL) return 1; if (weights != NULL) { if ((layer->weights = find_array_check(arrays, weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->weights[0]))) == NULL) return 1; } @@ -148,8 +154,12 @@ int linear_init(LinearLayer *layer, const WeightArray *arrays, if ((layer->float_weights = find_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]))) == NULL) return 1; } } - if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1; - if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1; + if (diag != NULL) { + if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1; + } + if (weights != NULL) { + if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1; + } layer->nb_inputs = nb_inputs; layer->nb_outputs = nb_outputs; return 0; |