Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-05-20 07:32:05 +0400
committernicholas-leonard <nick@nikopia.org>2014-05-20 07:32:05 +0400
commitfbdc1bbfb0b303cf419c09ee09f2c19759d290be (patch)
tree84cd4071cb63f18f975351c4132cabfb112d9e47 /SoftMaxTree.lua
parent954acf95fdef51b37ef132502deeab1b2d2f3c22 (diff)
debugged SoftMaxTree:type()
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--SoftMaxTree.lua11
1 files changed, 9 insertions, 2 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 3d2649d..9127306 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -31,10 +31,12 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose)
local maxChildrenId = children:max()
maxChildId = math.max(maxChildrenId, maxChildId)
maxNodeId = math.max(parentId, maxNodeId, maxChildrenId)
- minNodeId = math.max(parentId, minNodeId, children:min())
+ minNodeId = math.min(parentId, minNodeId, children:min())
table.insert(parentIds, parentId)
end
- assert(minNodeId >= 0, "nodeIds must must be non-negative")
+ if minNodeId < 0 then
+ error("nodeIds must must be positive: "..minNodeId, 2)
+ end
if verbose then
print("Hierachy has :")
print(nParentNode.." parent nodes")
@@ -177,6 +179,11 @@ function SoftMaxTree:type(type)
self.bias = self.bias:type(type)
self.gradWeight = self.gradWeight:type(type)
self.gradBias = self.gradBias:type(type)
+ self._linearOutput = self._linearOutput:type(type)
+ self._linearGradOutput = self._linearGradOutput:type(type)
+ self._logSoftMaxOutput = self._logSoftMaxOutput:type(type)
+ self.output = self.output:type(type)
+ self.gradInput = self.gradInput:type(type)
end
return self
end