author | nicholas-leonard <nick@nikopia.org> | 2014-05-18 05:49:34 +0400
---|---|---
committer | nicholas-leonard <nick@nikopia.org> | 2014-05-18 05:49:34 +0400
commit | 954acf95fdef51b37ef132502deeab1b2d2f3c22 (patch) |
tree | b02f41d7687e3072118b86f4aae86a38d9b8144a /SoftMaxTree.lua |
parent | d55ce4806210e235efe8a79410b15e7eee25a541 (diff) |
SoftMaxTree works (debugged and unit tested)
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 73
1 file changed, 52 insertions, 21 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 904adcf..3d2649d 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -1,17 +1,13 @@
 local SoftMaxTree, parent = torch.class('nn.SoftMaxTree', 'nn.Module')
 ------------------------------------------------------------------------
 --[[ SoftMaxTree ]]--
--- Generates an output tensor of size 1D
 -- Computes the log of a product of softmaxes in a path
--- One parent per child
+-- Returns an output tensor of size 1D
+-- Only works with a tree (one parent per child)
 -- TODO:
--- a shareClone method to make speedier clones
--- differ setup after init
--- verify that each parent has a parent (except root)
--- nodeIds - 1?
--- types
--- scrap narrowOutput and linearOutput
+-- a shareClone method to make clones without wasting memory
+-- which may require differing setup after initialization
 ------------------------------------------------------------------------

 function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose)
@@ -106,22 +102,9 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose)
    self._linearGradOutput = torch.Tensor()
    self._logSoftMaxOutput = torch.Tensor()
-   print("parentIds", self.parentIds)
-   print("parentChildren", self.parentChildren)
-   print("children", self.childIds)
-   print("childParent", self.childParent)
    self:reset()
 end

-function SoftMaxTree:getNodeParameters(parentId)
-   local node = self.parentChildren:select(1,parentId)
-   local start = node[1]
-   local nChildren = node[2]
-   local weight = self.weight:narrow(1, start, nChildren)
-   local bias = self.bias:narrow(1, start, nChildren)
-   return weight, bias
-end
-
 function SoftMaxTree:reset(stdv)
    if stdv then
       stdv = stdv * math.sqrt(3)
@@ -150,5 +133,53 @@ function SoftMaxTree:accGradParameters(inputTable, gradOutput, scale)
    input.nn.SoftMaxTree_accGradParameters(self, input, gradOutput, target, scale)
 end

+function SoftMaxTree:parameters()
+   local params, grads = {}, {}
+   for parentId, scale in pairs(self.updates) do
+      local node = self.parentChildren:select(1, parentId)
+      local parentIdx = node[1]
+      local nChildren = node[2]
+      table.insert(params, self.weight:narrow(1, parentIdx, nChildren))
+      table.insert(params, self.bias:narrow(1, parentIdx, nChildren))
+      table.insert(grads, self.gradWeight:narrow(1, parentIdx, nChildren))
+      table.insert(grads, self.gradBias:narrow(1, parentIdx, nChildren))
+   end
+   if #params == 0 then
+      return {self.weight, self.bias}, {self.gradWeight, self.gradBias}
+   end
+   return params, grads
+end
+
+function SoftMaxTree:getNodeParameters(parentId)
+   local node = self.parentChildren:select(1,parentId)
+   local start = node[1]
+   local nChildren = node[2]
+   local weight = self.weight:narrow(1, start, nChildren)
+   local bias = self.bias:narrow(1, start, nChildren)
+   local gradWeight = self.gradWeight:narrow(1, start, nChildren)
+   local gradBias = self.gradBias:narrow(1, start, nChildren)
+   return {weight, bias}, {gradWeight, gradBias}
+end
+
+function SoftMaxTree:zeroGradParameters()
+   local _,gradParams = self:parameters()
+   if gradParams then
+      for i=1,#gradParams do
+         gradParams[i]:zero()
+      end
+   end
+   self.updates = {}
+end
+
+function SoftMaxTree:type(type)
+   if type and (type == 'torch.FloatTensor' or type == 'torch.DoubleTensor') then
+      self.weight = self.weight:type(type)
+      self.bias = self.bias:type(type)
+      self.gradWeight = self.gradWeight:type(type)
+      self.gradBias = self.gradBias:type(type)
+   end
+   return self
+end
+
 -- we do not need to accumulate parameters when sharing
 SoftMaxTree.sharedAccUpdateGradParameters = SoftMaxTree.accUpdateGradParameters
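For context, the `parameters` override added here returns only the weight and bias rows of parent nodes touched since the last update (tracked in `self.updates`), and falls back to the full tensors when no node has been visited. The sketch below shows how these new accessors might be used; the hierarchy layout (a table mapping each parentId to an IntTensor of its childIds) and the tensor sizes are illustrative assumptions, not taken from this commit.

```lua
require 'nnx'

-- Hypothetical 5-node tree: root 1 has children {2, 3}; node 2 has children {4, 5}.
-- The hierarchy format (parentId -> IntTensor of childIds) is an assumed convention.
local hierarchy = {
   [1] = torch.IntTensor{2, 3},
   [2] = torch.IntTensor{4, 5},
}
local smt = nn.SoftMaxTree(10, hierarchy, 1)  -- inputSize = 10, rootId = 1

-- Right after construction, self.updates (the set of visited nodes) is presumably
-- empty, so parameters() returns the full weight/bias and gradWeight/gradBias tensors.
local params, grads = smt:parameters()

-- Per-node view added by this commit: {weight, bias}, {gradWeight, gradBias}
-- narrowed to the rows belonging to parent node 2's children.
local nodeParams, nodeGrads = smt:getNodeParameters(2)

-- Zeroes whichever gradient slices parameters() returned and resets self.updates.
smt:zeroGradParameters()
```

Restricting `parameters()` to the visited slices means a sparse-update optimizer only has to touch the rows of the (potentially very large) weight matrix that were actually used in the last forward/backward pass.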