diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-07-24 01:25:14 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-07-24 01:25:14 +0300 |
commit | 8e7080903dcac51ebbddc541429c338259a4d031 (patch) | |
tree | 81c4863c6edb7797e8f34e54375276dce0375e60 | |
parent | d15be43af425013e27ef872ac672700e0b642ac1 (diff) |
Make float_weights optional
-rw-r--r-- | dnn/parse_lpcnet_weights.c | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/dnn/parse_lpcnet_weights.c b/dnn/parse_lpcnet_weights.c index 0f2def8b..06332fc7 100644 --- a/dnn/parse_lpcnet_weights.c +++ b/dnn/parse_lpcnet_weights.c @@ -88,6 +88,13 @@ static const void *find_array_check(const WeightArray *arrays, const char *name, else return NULL; } +static const void *opt_array_check(const WeightArray *arrays, const char *name, int size, int *error) { + const WeightArray *a = find_array_entry(arrays, name); + *error = (a != NULL && a->size != size); + if (a && a->size == size) return a->data; + else return NULL; +} + static const void *find_idx_check(const WeightArray *arrays, const char *name, int nb_in, int nb_out, int *total_blocks) { int remain; const int *idx; @@ -124,6 +131,7 @@ int linear_init(LinearLayer *layer, const WeightArray *arrays, int nb_inputs, int nb_outputs) { + int err; layer->bias = NULL; layer->subias = NULL; layer->weights = NULL; @@ -144,14 +152,16 @@ int linear_init(LinearLayer *layer, const WeightArray *arrays, if ((layer->weights = find_array_check(arrays, weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->weights[0]))) == NULL) return 1; } if (float_weights != NULL) { - if ((layer->float_weights = find_array_check(arrays, float_weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->float_weights[0]))) == NULL) return 1; + layer->float_weights = opt_array_check(arrays, float_weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->float_weights[0]), &err); + if (err) return 1; } } else { if (weights != NULL) { if ((layer->weights = find_array_check(arrays, weights, nb_inputs*nb_outputs*sizeof(layer->weights[0]))) == NULL) return 1; } if (float_weights != NULL) { - if ((layer->float_weights = find_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]))) == NULL) return 1; + layer->float_weights = opt_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]), &err); + if (err) return 1; } } if (diag != NULL) { |