Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-05-17 01:49:16 +0400
committernicholas-leonard <nick@nikopia.org>2014-05-17 01:49:16 +0400
commitb283ac68ab370df7eeb26c8033ef5c7b6a083907 (patch)
tree40bbc3049ede33e02d17bd93aae68ba78003ff65 /SoftMaxTree.lua
parent1ff10e2b01446685030736a7604eecbd2eb99c70 (diff)
SoftMaxTree:forward works (debugged, unit tested)
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--SoftMaxTree.lua21
1 files changed, 17 insertions, 4 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index f87a35e..e8aa21c 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -10,6 +10,7 @@ local SoftMaxTree, parent = torch.class('nn.SoftMaxTree', 'nn.Module')
-- differ setup after init
-- verify that each parent has a parent (except root)
-- nodeIds - 1?
+-- types
------------------------------------------------------------------------
function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose)
@@ -49,6 +50,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose)
print("in order to waste less memory on indexes.")
end
end
+
self.nChildNode = nChildNode
self.nParentNode = nParentNode
self.minNodeId = minNodeId
@@ -73,7 +75,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose)
for parentId, children in pairs(hierarchy) do
local node = self.parentChildren:select(1, parentId)
node[1] = start
- local nChildren = children:size()
+ local nChildren = children:size(1)
node[2] = nChildren
self.childIds:narrow(1, start, nChildren):copy(children)
start = start + nChildren
@@ -81,7 +83,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose)
-- index of parent by childId
self.childParent = torch.IntTensor(self.maxChildId, 2):fill(-1)
- for parendIdx=1,self.parentIds:size(1) do
+ for parentIdx=1,self.parentIds:size(1) do
local parentId = self.parentIds[parentIdx]
local node = self.parentChildren:select(1, parentId)
local start = node[1]
@@ -99,12 +101,23 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose)
self._linearOutput = torch.Tensor()
self._logSoftMaxOutput = torch.Tensor()
self._narrowOutput = torch.Tensor()
- -- used to store pointers to intermediate outputs
- --self._nodes = torch.Tensor()
+ print("parentIds", self.parentIds)
+ print("parentChildren", self.parentChildren)
+ print("children", self.childIds)
+ print("childParent", self.childParent)
self:reset()
end
+function SoftMaxTree:getNodeParameters(parentId)
+ local node = self.parentChildren:select(1,parentId)
+ local start = node[1]
+ local nChildren = node[2]
+ local weight = self.weight:narrow(1, start, nChildren)
+ local bias = self.bias:narrow(1, start, nChildren)
+ return weight, bias
+end
+
function SoftMaxTree:reset(stdv)
if stdv then
stdv = stdv * math.sqrt(3)