diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-05-18 00:49:43 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-05-18 00:49:43 +0400 |
commit | d55ce4806210e235efe8a79410b15e7eee25a541 (patch) | |
tree | 8437daa7bdf704205812bf3c490ce14bddc99fa1 /generic | |
parent | 35e28f3d0d91bef1181f0f9e10604632f6d40887 (diff) |
implemented SoftMaxTree_accGradParameters
Diffstat (limited to 'generic')
-rw-r--r-- | generic/SoftMaxTree.c | 69 |
1 files changed, 58 insertions, 11 deletions
diff --git a/generic/SoftMaxTree.c b/generic/SoftMaxTree.c index 8b697b9..71518d9 100644 --- a/generic/SoftMaxTree.c +++ b/generic/SoftMaxTree.c @@ -227,21 +227,21 @@ static int nn_(SoftMaxTree_accGradParameters)(lua_State *L) THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor); THIntTensor *target = (THIntTensor*)luaT_checkudata(L, 4, "torch.IntTensor"); real scale = luaL_optnumber(L, 5, 1); + long rootId = (long)(luaT_getfieldcheckint(L, 1, "rootId") - 1); int inputSize = luaT_getfieldcheckint(L, 1, "inputSize"); THIntTensor *childParent = (THIntTensor*)luaT_getfieldcheckudata(L, 1, "childParent", "torch.IntTensor"); THIntTensor *parentChildren = (THIntTensor*)luaT_getfieldcheckudata(L, 1, "parentChildren", "torch.IntTensor"); - THTensor *linearOutput = luaT_getfieldcheckudata(L, 1, "_linearGradOutput", torch_Tensor);; + THTensor *linearGradOutput = luaT_getfieldcheckudata(L, 1, "_linearGradOutput", torch_Tensor);; - THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor); - THTensor *bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor); THTensor *gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor); THTensor *gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor); + luaT_getfield(L, 1, "updates"); + THIntTensor *node; - THTensor *nodeWeight, *nodeBias, *nodeOutput, *nodeInput, *nodeGradInter, *nodeGradOutput; - real *input_data, *output_data; + THTensor *nodeGradWeight, *nodeGradBias, *nodeInput, *nodeGradOutput; long i, d; long n = 0; @@ -249,16 +249,63 @@ static int nn_(SoftMaxTree_accGradParameters)(lua_State *L) luaL_argcheck(L, input->nDimension == 2, 2, "2D(batch mode) tensor expected"); luaL_argcheck(L, input->size[1] == inputSize, 2, "invalid input size"); - luaL_argcheck(L, gradOutput->nDimension == 1, 2, "1D tensor expected"); - node = THIntTensor_new(); - nodeWeight = THTensor_(new)(); - nodeBias = THTensor_(new)(); - nodeOutput = THTensor_(new)(); + nodeGradWeight = THTensor_(new)(); + nodeGradBias = THTensor_(new)(); nodeGradOutput = THTensor_(new)(); nodeInput = THTensor_(new)(); - nodeGradInter = THTensor_(new)(); + for(i = 0; i < input->size[0]; i++) + { + long childId = (long)(THIntTensor_get1d(target, i)) - 1; + real grad = THTensor_(get1d)(gradOutput, i); + + THTensor_(select)(nodeGradOutput, gradOutput, 0, i); + + while(1) + { + long parentId, parentIdx, childIdx, nChildren; + /* get next Node in Tree */ + THIntTensor_select(node, childParent, 0, childId); + parentId = (long)(THIntTensor_get1d(node, 0)) - 1; + childIdx = (long)(THIntTensor_get1d(node, 1)) - 1; + + luaL_argcheck(L, parentId != -2, 2, "Non-root node has no parent in tree."); + + THIntTensor_select(node, parentChildren, 0, parentId); + parentIdx = (long)(THIntTensor_get1d(node, 0)) - 1; + nChildren = (long)(THIntTensor_get1d(node, 1)); + + THTensor_(narrow)(nodeGradOutput, linearGradOutput, 0, n, nChildren); + + THTensor_(addr)(nodeGradWeight, 1, nodeGradWeight, scale, nodeInput, nodeGradOutput); + THTensor_(cadd)(nodeGradBias, nodeGradBias, scale, nodeGradOutput); + + /* updates will contain parentId (key) sum of scales (value)*/ + lua_pushinteger(L, (int)parentId); + lua_gettable(L, -2); + double count = lua_tonumber(L, -1) + scale; + + lua_pushinteger(L, (int)parentId); /* key */ + lua_pushnumber(L, count); /* value */ + lua_settable(L, -3); + + n += nChildren; + /* Break when root is reached */ + if (parentId == rootId) + { + break; + } + childId = parentId; + } + } + + THIntTensor_free(node); + THTensor_(free)(nodeGradWeight); + THTensor_(free)(nodeGradBias); + THTensor_(free)(nodeGradOutput); + THTensor_(free)(nodeInput); + return 0; } |