diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-09-27 23:41:35 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-09-27 23:41:35 +0400 |
commit | 07a8194cf5d9c737af1b291d2b5b057a1f369437 (patch) | |
tree | 1eae39f3761646165463ac1c94ce97144ffe340d /generic | |
parent | 9d5ffa2a97b32286af7b690884f86ffce7ad3cc2 (diff) |
Added new criterions.
Diffstat (limited to 'generic')
-rw-r--r-- | generic/DistMarginCriterion.c | 187 |
1 files changed, 187 insertions, 0 deletions
diff --git a/generic/DistMarginCriterion.c b/generic/DistMarginCriterion.c new file mode 100644 index 0000000..6e94c9d --- /dev/null +++ b/generic/DistMarginCriterion.c @@ -0,0 +1,187 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/DistMarginCriterion.c" +#else + +static int nn_(DistMarginCriterion_forward)(lua_State *L) +{ + THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id)); + int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage"); + real *input_data, *target_data; + long nframe, dim; + long t, d, m; + THTensor *target_; + THTensor *target; + real sum; + + THArgCheck((input->nDimension == 1) || (input->nDimension == 2), 2, "vector or matrix expected"); + + if(input->nDimension == 1) { + nframe = 1; + dim = input->size[0]; + target_ = luaT_checkudata(L, 3, torch_(Tensor_id)); + target = THTensor_(new)(); + THTensor_(set)(target, target_); + THTensor_(resize2d)(target, 1, dim); + } + else { + nframe = input->size[0]; + dim = input->size[1]; + target_ = luaT_checkudata(L, 3, torch_(Tensor_id)); + THArgCheck((target_->nDimension == 2) && (target_->size[0] == nframe) && (target_->size[1] == dim), + 3, "inconsistent target size"); + target = THTensor_(newContiguous)(target_); + } + + for(t = 0; t < nframe; t++) { + for(d = 0; d < dim; d++) { + real idx = THTensor_(get2d)(target, t, d); + THArgCheck((idx >= 0) && (idx <= dim), 3, "target out of range"); + } + } + + input = THTensor_(newContiguous)(input); + input_data = THTensor_(data)(input); + target_data = THTensor_(data)(target); + + sum = 0; + for(t = 0; t < nframe; t++) { + real input_target = THInf; + for (m = 0; m < dim; m++) { + long target_idx = (long)(target_data[m]-1); + if (target_idx == -1) break; + if (input_target > input_data[target_idx]) input_target = input_data[target_idx]; + } + for(d = 0; d < dim; d++) { + int isatarget = 0; + for(m = 0; m < dim; m++) { + long target_idx = (long)(target_data[m]-1); + if (target_idx == -1) break; + else if(d == target_idx) { + isatarget = 1; + break; + } + } + if (isatarget) continue; + + real z = 1 - input_target + input_data[d]; + if(z > 0) sum += z; + } + input_data += dim; + target_data += dim; + } + + if(sizeAverage) + sum /= dim; + + lua_pushnumber(L, sum); + lua_setfield(L, 1, "output"); + + THTensor_(free)(input); + THTensor_(free)(target); + lua_pushnumber(L, sum); + return 1; +} + +static int nn_(DistMarginCriterion_backward)(lua_State *L) +{ + THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id)); + int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage"); + THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id)); + real *input_data; + real *gradInput_data; + real *target_data; + THTensor *target_; + THTensor *target; + long nframe, dim; + long t, d, m; + real g; + real sum; + + THArgCheck((input->nDimension == 1) || (input->nDimension == 2), 2, "vector or matrix expected"); + + if(input->nDimension == 1) { + nframe = 1; + dim = input->size[0]; + target_ = luaT_checkudata(L, 3, torch_(Tensor_id)); + target = THTensor_(new)(); + THTensor_(set)(target, target_); + THTensor_(resize2d)(target, 1, dim); + } + else { + nframe = input->size[0]; + dim = input->size[1]; + target_ = luaT_checkudata(L, 3, torch_(Tensor_id)); + THArgCheck((target_->nDimension == 2) && (target_->size[0] == nframe) && (target_->size[1] == dim), + 3, "inconsistent target size"); + target = THTensor_(newContiguous)(target_); + } + + g = (sizeAverage ? 1./((real)dim) : 1.); + + input = THTensor_(newContiguous)(input); + input_data = THTensor_(data)(input); + + THTensor_(resizeAs)(gradInput, input); + gradInput_data = THTensor_(data)(gradInput); + + target_data = THTensor_(data)(target); + + for(t = 0; t < nframe; t++) { + real input_target = THInf; + int min_idx = -1; + for (m = 0; m < dim; m++) { + long target_idx = (long)(target_data[m]-1); + if (target_idx == -1) break; + if (input_target > input_data[target_idx]) { + min_idx = target_idx; + input_target = input_data[target_idx]; + } + } + real gradInput_target = 0; + for(d = 0; d < dim; d++) { + int isatarget = 0; + for(m = 0; m < dim; m++) { + long target_idx = (long)(target_data[m]-1); + if (target_idx == -1) break; + else if(d == target_idx) { + isatarget = 1; + break; + } + } + if (isatarget) continue; + + real z = 1 - input_target + input_data[d]; + if(z > 0) { + gradInput_target -= g; + gradInput_data[d] = g; + } + else + gradInput_data[d] = 0; + } + gradInput_data[min_idx] = gradInput_target; + + input_data += dim; + gradInput_data += dim; + target_data += dim; + } + + + THTensor_(free)(input); + THTensor_(free)(target); + return 1; +} + +static const struct luaL_Reg nn_(DistMarginCriterion__) [] = { + {"DistMarginCriterion_forward", nn_(DistMarginCriterion_forward)}, + {"DistMarginCriterion_backward", nn_(DistMarginCriterion_backward)}, + {NULL, NULL} +}; + +static void nn_(DistMarginCriterion_init)(lua_State *L) +{ + luaT_pushmetaclass(L, torch_(Tensor_id)); + luaT_registeratname(L, nn_(DistMarginCriterion__), "nn"); + lua_pop(L,1); +} + +#endif |