Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-05-18 00:49:43 +0400
committernicholas-leonard <nick@nikopia.org>2014-05-18 00:49:43 +0400
commitd55ce4806210e235efe8a79410b15e7eee25a541 (patch)
tree8437daa7bdf704205812bf3c490ce14bddc99fa1 /generic
parent35e28f3d0d91bef1181f0f9e10604632f6d40887 (diff)
implemented SoftMaxTree_accGradParameters
Diffstat (limited to 'generic')
-rw-r--r--generic/SoftMaxTree.c69
1 files changed, 58 insertions, 11 deletions
diff --git a/generic/SoftMaxTree.c b/generic/SoftMaxTree.c
index 8b697b9..71518d9 100644
--- a/generic/SoftMaxTree.c
+++ b/generic/SoftMaxTree.c
@@ -227,21 +227,21 @@ static int nn_(SoftMaxTree_accGradParameters)(lua_State *L)
THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
THIntTensor *target = (THIntTensor*)luaT_checkudata(L, 4, "torch.IntTensor");
real scale = luaL_optnumber(L, 5, 1);
+ long rootId = (long)(luaT_getfieldcheckint(L, 1, "rootId") - 1);
int inputSize = luaT_getfieldcheckint(L, 1, "inputSize");
THIntTensor *childParent = (THIntTensor*)luaT_getfieldcheckudata(L, 1, "childParent", "torch.IntTensor");
THIntTensor *parentChildren = (THIntTensor*)luaT_getfieldcheckudata(L, 1, "parentChildren", "torch.IntTensor");
- THTensor *linearOutput = luaT_getfieldcheckudata(L, 1, "_linearGradOutput", torch_Tensor);;
+ THTensor *linearGradOutput = luaT_getfieldcheckudata(L, 1, "_linearGradOutput", torch_Tensor);;
- THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
- THTensor *bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
THTensor *gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
THTensor *gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
+ luaT_getfield(L, 1, "updates");
+
THIntTensor *node;
- THTensor *nodeWeight, *nodeBias, *nodeOutput, *nodeInput, *nodeGradInter, *nodeGradOutput;
- real *input_data, *output_data;
+ THTensor *nodeGradWeight, *nodeGradBias, *nodeInput, *nodeGradOutput;
long i, d;
long n = 0;
@@ -249,16 +249,63 @@ static int nn_(SoftMaxTree_accGradParameters)(lua_State *L)
luaL_argcheck(L, input->nDimension == 2, 2, "2D(batch mode) tensor expected");
luaL_argcheck(L, input->size[1] == inputSize, 2, "invalid input size");
- luaL_argcheck(L, gradOutput->nDimension == 1, 2, "1D tensor expected");
-
node = THIntTensor_new();
- nodeWeight = THTensor_(new)();
- nodeBias = THTensor_(new)();
- nodeOutput = THTensor_(new)();
+ nodeGradWeight = THTensor_(new)();
+ nodeGradBias = THTensor_(new)();
nodeGradOutput = THTensor_(new)();
nodeInput = THTensor_(new)();
- nodeGradInter = THTensor_(new)();
+ for(i = 0; i < input->size[0]; i++)
+ {
+ long childId = (long)(THIntTensor_get1d(target, i)) - 1;
+ real grad = THTensor_(get1d)(gradOutput, i);
+
+ THTensor_(select)(nodeGradOutput, gradOutput, 0, i);
+
+ while(1)
+ {
+ long parentId, parentIdx, childIdx, nChildren;
+ /* get next Node in Tree */
+ THIntTensor_select(node, childParent, 0, childId);
+ parentId = (long)(THIntTensor_get1d(node, 0)) - 1;
+ childIdx = (long)(THIntTensor_get1d(node, 1)) - 1;
+
+ luaL_argcheck(L, parentId != -2, 2, "Non-root node has no parent in tree.");
+
+ THIntTensor_select(node, parentChildren, 0, parentId);
+ parentIdx = (long)(THIntTensor_get1d(node, 0)) - 1;
+ nChildren = (long)(THIntTensor_get1d(node, 1));
+
+ THTensor_(narrow)(nodeGradOutput, linearGradOutput, 0, n, nChildren);
+
+ THTensor_(addr)(nodeGradWeight, 1, nodeGradWeight, scale, nodeInput, nodeGradOutput);
+ THTensor_(cadd)(nodeGradBias, nodeGradBias, scale, nodeGradOutput);
+
+ /* updates will contain parentId (key) sum of scales (value)*/
+ lua_pushinteger(L, (int)parentId);
+ lua_gettable(L, -2);
+ double count = lua_tonumber(L, -1) + scale;
+
+ lua_pushinteger(L, (int)parentId); /* key */
+ lua_pushnumber(L, count); /* value */
+ lua_settable(L, -3);
+
+ n += nChildren;
+ /* Break when root is reached */
+ if (parentId == rootId)
+ {
+ break;
+ }
+ childId = parentId;
+ }
+ }
+
+ THIntTensor_free(node);
+ THTensor_(free)(nodeGradWeight);
+ THTensor_(free)(nodeGradBias);
+ THTensor_(free)(nodeGradOutput);
+ THTensor_(free)(nodeInput);
+
return 0;
}