diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-05-20 07:32:05 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-05-20 07:32:05 +0400 |
commit | fbdc1bbfb0b303cf419c09ee09f2c19759d290be (patch) | |
tree | 84cd4071cb63f18f975351c4132cabfb112d9e47 /SoftMaxTree.lua | |
parent | 954acf95fdef51b37ef132502deeab1b2d2f3c22 (diff) |
debugged SoftMaxTree:type()
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index 3d2649d..9127306 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -31,10 +31,12 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose) local maxChildrenId = children:max() maxChildId = math.max(maxChildrenId, maxChildId) maxNodeId = math.max(parentId, maxNodeId, maxChildrenId) - minNodeId = math.max(parentId, minNodeId, children:min()) + minNodeId = math.min(parentId, minNodeId, children:min()) table.insert(parentIds, parentId) end - assert(minNodeId >= 0, "nodeIds must must be non-negative") + if minNodeId < 0 then + error("nodeIds must must be positive: "..minNodeId, 2) + end if verbose then print("Hierachy has :") print(nParentNode.." parent nodes") @@ -177,6 +179,11 @@ function SoftMaxTree:type(type) self.bias = self.bias:type(type) self.gradWeight = self.gradWeight:type(type) self.gradBias = self.gradBias:type(type) + self._linearOutput = self._linearOutput:type(type) + self._linearGradOutput = self._linearGradOutput:type(type) + self._logSoftMaxOutput = self._logSoftMaxOutput:type(type) + self.output = self.output:type(type) + self.gradInput = self.gradInput:type(type) end return self end |