Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'SoftMaxForest.lua')
-rw-r--r--SoftMaxForest.lua55
1 files changed, 55 insertions, 0 deletions
diff --git a/SoftMaxForest.lua b/SoftMaxForest.lua
new file mode 100644
index 0000000..0fee1fd
--- /dev/null
+++ b/SoftMaxForest.lua
@@ -0,0 +1,55 @@
+local SoftMaxForest, parent = torch.class("nn.SoftMaxForest", "nn.Container")
+
+function SoftMaxForest:__init(inputSize, trees, rootIds, gaterSize, gaterAct, accUpdate)
+ local gaterAct = gaterAct or nn.Tanh()
+ local gaterSize = gaterSize or {}
+
+ -- experts
+ self.experts = nn.ConcatTable()
+ self.smts = {}
+ for i,tree in ipairs(trees) do
+ local smt = nn.SoftMaxTree(inputSize, tree, rootIds[i], accUpdate)
+ table.insert(self._smts, smt)
+ self.experts:add(smt)
+ end
+
+ -- gater
+ self.gater = nn.Sequential()
+ self.gater:add(nn.SelectTable(1)) -- ignore targets
+ for i,hiddenSize in ipairs(gaterSize) do
+ self.gater:add(nn.Linear(inputSize, hiddenSize))
+ self.gater:add(gaterAct:clone())
+ inputSize = hiddenSize
+ end
+ self.gater:add(nn.Linear(inputSize, self.experts:size()))
+ self.gater:add(nn.SoftMax())
+
+ -- mixture
+ self.trunk = nn.ConcatTable()
+ self.trunk:add(self._gater)
+ self.trunk:add(self._experts)
+ self.mixture = nn.MixtureTable()
+ self.module = nn.Sequential()
+ self.module:add(self.trunk)
+ self.module:add(self.mixture)
+ parent.__init(self)
+ self.modules[1] = self.module
+end
+
+function SoftMaxForest:updateOutput(input)
+ self.output = self.module:updateOutput(input)
+ return self.output
+end
+
+function SoftMaxForest:updateGradInput(input, gradOutput)
+ self.gradInput = self.module:updateGradInput(input, gradOutput)
+ return self.gradInput
+end
+
+function SoftMaxForest:accGradParameters(input, gradOutput, scale)
+ self.module:accGradParameters(input, gradOutput, scale)
+end
+
+function SoftMaxForest:accUpdateGradParameters(input, gradOutput, lr)
+ self.module:accUpdateGradParameters(input, gradOutput, lr)
+end