diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-05-17 01:49:16 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-05-17 01:49:16 +0400 |
commit | b283ac68ab370df7eeb26c8033ef5c7b6a083907 (patch) | |
tree | 40bbc3049ede33e02d17bd93aae68ba78003ff65 /SoftMaxTree.lua | |
parent | 1ff10e2b01446685030736a7604eecbd2eb99c70 (diff) |
SoftMaxTree:forward works (debugged, unit tested)
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index f87a35e..e8aa21c 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -10,6 +10,7 @@ local SoftMaxTree, parent = torch.class('nn.SoftMaxTree', 'nn.Module') -- differ setup after init -- verify that each parent has a parent (except root) -- nodeIds - 1? +-- types ------------------------------------------------------------------------ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose) @@ -49,6 +50,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose) print("in order to waste less memory on indexes.") end end + self.nChildNode = nChildNode self.nParentNode = nParentNode self.minNodeId = minNodeId @@ -73,7 +75,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose) for parentId, children in pairs(hierarchy) do local node = self.parentChildren:select(1, parentId) node[1] = start - local nChildren = children:size() + local nChildren = children:size(1) node[2] = nChildren self.childIds:narrow(1, start, nChildren):copy(children) start = start + nChildren @@ -81,7 +83,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose) -- index of parent by childId self.childParent = torch.IntTensor(self.maxChildId, 2):fill(-1) - for parendIdx=1,self.parentIds:size(1) do + for parentIdx=1,self.parentIds:size(1) do local parentId = self.parentIds[parentIdx] local node = self.parentChildren:select(1, parentId) local start = node[1] @@ -99,12 +101,23 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose) self._linearOutput = torch.Tensor() self._logSoftMaxOutput = torch.Tensor() self._narrowOutput = torch.Tensor() - -- used to store pointers to intermediate outputs - --self._nodes = torch.Tensor() + print("parentIds", self.parentIds) + print("parentChildren", self.parentChildren) + print("children", self.childIds) + print("childParent", self.childParent) self:reset() end +function SoftMaxTree:getNodeParameters(parentId) + local node = self.parentChildren:select(1,parentId) + local start = node[1] + local nChildren = node[2] + local weight = self.weight:narrow(1, start, nChildren) + local bias = self.bias:narrow(1, start, nChildren) + return weight, bias +end + function SoftMaxTree:reset(stdv) if stdv then stdv = stdv * math.sqrt(3) |