Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-09-27 23:41:35 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-09-27 23:41:35 +0400
commit07a8194cf5d9c737af1b291d2b5b057a1f369437 (patch)
tree1eae39f3761646165463ac1c94ce97144ffe340d /generic
parent9d5ffa2a97b32286af7b690884f86ffce7ad3cc2 (diff)
Added new criterions.
Diffstat (limited to 'generic')
-rw-r--r--generic/DistMarginCriterion.c187
1 files changed, 187 insertions, 0 deletions
diff --git a/generic/DistMarginCriterion.c b/generic/DistMarginCriterion.c
new file mode 100644
index 0000000..6e94c9d
--- /dev/null
+++ b/generic/DistMarginCriterion.c
@@ -0,0 +1,187 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/DistMarginCriterion.c"
+#else
+
+static int nn_(DistMarginCriterion_forward)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
+ real *input_data, *target_data;
+ long nframe, dim;
+ long t, d, m;
+ THTensor *target_;
+ THTensor *target;
+ real sum;
+
+ THArgCheck((input->nDimension == 1) || (input->nDimension == 2), 2, "vector or matrix expected");
+
+ if(input->nDimension == 1) {
+ nframe = 1;
+ dim = input->size[0];
+ target_ = luaT_checkudata(L, 3, torch_(Tensor_id));
+ target = THTensor_(new)();
+ THTensor_(set)(target, target_);
+ THTensor_(resize2d)(target, 1, dim);
+ }
+ else {
+ nframe = input->size[0];
+ dim = input->size[1];
+ target_ = luaT_checkudata(L, 3, torch_(Tensor_id));
+ THArgCheck((target_->nDimension == 2) && (target_->size[0] == nframe) && (target_->size[1] == dim),
+ 3, "inconsistent target size");
+ target = THTensor_(newContiguous)(target_);
+ }
+
+ for(t = 0; t < nframe; t++) {
+ for(d = 0; d < dim; d++) {
+ real idx = THTensor_(get2d)(target, t, d);
+ THArgCheck((idx >= 0) && (idx <= dim), 3, "target out of range");
+ }
+ }
+
+ input = THTensor_(newContiguous)(input);
+ input_data = THTensor_(data)(input);
+ target_data = THTensor_(data)(target);
+
+ sum = 0;
+ for(t = 0; t < nframe; t++) {
+ real input_target = THInf;
+ for (m = 0; m < dim; m++) {
+ long target_idx = (long)(target_data[m]-1);
+ if (target_idx == -1) break;
+ if (input_target > input_data[target_idx]) input_target = input_data[target_idx];
+ }
+ for(d = 0; d < dim; d++) {
+ int isatarget = 0;
+ for(m = 0; m < dim; m++) {
+ long target_idx = (long)(target_data[m]-1);
+ if (target_idx == -1) break;
+ else if(d == target_idx) {
+ isatarget = 1;
+ break;
+ }
+ }
+ if (isatarget) continue;
+
+ real z = 1 - input_target + input_data[d];
+ if(z > 0) sum += z;
+ }
+ input_data += dim;
+ target_data += dim;
+ }
+
+ if(sizeAverage)
+ sum /= dim;
+
+ lua_pushnumber(L, sum);
+ lua_setfield(L, 1, "output");
+
+ THTensor_(free)(input);
+ THTensor_(free)(target);
+ lua_pushnumber(L, sum);
+ return 1;
+}
+
+static int nn_(DistMarginCriterion_backward)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id));
+ real *input_data;
+ real *gradInput_data;
+ real *target_data;
+ THTensor *target_;
+ THTensor *target;
+ long nframe, dim;
+ long t, d, m;
+ real g;
+ real sum;
+
+ THArgCheck((input->nDimension == 1) || (input->nDimension == 2), 2, "vector or matrix expected");
+
+ if(input->nDimension == 1) {
+ nframe = 1;
+ dim = input->size[0];
+ target_ = luaT_checkudata(L, 3, torch_(Tensor_id));
+ target = THTensor_(new)();
+ THTensor_(set)(target, target_);
+ THTensor_(resize2d)(target, 1, dim);
+ }
+ else {
+ nframe = input->size[0];
+ dim = input->size[1];
+ target_ = luaT_checkudata(L, 3, torch_(Tensor_id));
+ THArgCheck((target_->nDimension == 2) && (target_->size[0] == nframe) && (target_->size[1] == dim),
+ 3, "inconsistent target size");
+ target = THTensor_(newContiguous)(target_);
+ }
+
+ g = (sizeAverage ? 1./((real)dim) : 1.);
+
+ input = THTensor_(newContiguous)(input);
+ input_data = THTensor_(data)(input);
+
+ THTensor_(resizeAs)(gradInput, input);
+ gradInput_data = THTensor_(data)(gradInput);
+
+ target_data = THTensor_(data)(target);
+
+ for(t = 0; t < nframe; t++) {
+ real input_target = THInf;
+ int min_idx = -1;
+ for (m = 0; m < dim; m++) {
+ long target_idx = (long)(target_data[m]-1);
+ if (target_idx == -1) break;
+ if (input_target > input_data[target_idx]) {
+ min_idx = target_idx;
+ input_target = input_data[target_idx];
+ }
+ }
+ real gradInput_target = 0;
+ for(d = 0; d < dim; d++) {
+ int isatarget = 0;
+ for(m = 0; m < dim; m++) {
+ long target_idx = (long)(target_data[m]-1);
+ if (target_idx == -1) break;
+ else if(d == target_idx) {
+ isatarget = 1;
+ break;
+ }
+ }
+ if (isatarget) continue;
+
+ real z = 1 - input_target + input_data[d];
+ if(z > 0) {
+ gradInput_target -= g;
+ gradInput_data[d] = g;
+ }
+ else
+ gradInput_data[d] = 0;
+ }
+ gradInput_data[min_idx] = gradInput_target;
+
+ input_data += dim;
+ gradInput_data += dim;
+ target_data += dim;
+ }
+
+
+ THTensor_(free)(input);
+ THTensor_(free)(target);
+ return 1;
+}
+
+static const struct luaL_Reg nn_(DistMarginCriterion__) [] = {
+ {"DistMarginCriterion_forward", nn_(DistMarginCriterion_forward)},
+ {"DistMarginCriterion_backward", nn_(DistMarginCriterion_backward)},
+ {NULL, NULL}
+};
+
+static void nn_(DistMarginCriterion_init)(lua_State *L)
+{
+ luaT_pushmetaclass(L, torch_(Tensor_id));
+ luaT_registeratname(L, nn_(DistMarginCriterion__), "nn");
+ lua_pop(L,1);
+}
+
+#endif