From 10b5fe0259e4a195d0e86076c8236a31b820ca7b Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Mon, 30 May 2011 13:46:27 -0400 Subject: Training now stops when stuck in a minimum --- src/mlp_train.c | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/mlp_train.c b/src/mlp_train.c index 1c1fa28e..a9d548b1 100644 --- a/src/mlp_train.c +++ b/src/mlp_train.c @@ -80,7 +80,7 @@ MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int std = .001; std = 1/sqrt(inDim*std); for (k=0;kweights[0][k*(topo[0]+1)+j+1] = randn(4*std); + net->weights[0][k*(topo[0]+1)+j+1] = randn(std); } net->in_rate[0] = 1; for (j=0;jtopo; inDim = net->topo[0]; hiddenDim = net->topo[1]; @@ -313,10 +315,10 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam float mean_rate = 0, min_rate = 1e10; rms = (rms/(outDim*nbSamples)); error_rate = (error_rate/(outDim*nbSamples)); - fprintf (stderr, "%f (%f %f) ", error_rate, rms, last_rms); - if (rms < last_rms) + fprintf (stderr, "%f (%f %f) ", error_rate, rms, best_rms); + if (rms < best_rms) { - last_rms = rms; + best_rms = rms; for (i=0;i last_rms) { + count_retries=0; + } else { count_worse++; - if (count_worse>20) + if (count_worse>30) { + count_retries++; count_worse=0; for (i=0;i10) + break; for (i=0;i 0) @@ -386,7 +392,10 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam W1[i] += W1_grad[i]*W1_rate[i]; } mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]; - fprintf (stderr, "%g (min %g) %d\n", mean_rate, min_rate, e); + fprintf (stderr, "%g %d", mean_rate, e); + if (count_retries) + fprintf(stderr, " %d", count_retries); + fprintf(stderr, "\n"); if (stopped) break; } @@ -403,7 +412,7 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam free(W1_grad); free(W0_rate); free(W1_rate); - return last_rms; + return best_rms; } int main(int argc, char **argv) -- cgit v1.2.3