diff options
Diffstat (limited to 'SoftMaxForest.lua')
-rw-r--r-- | SoftMaxForest.lua | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/SoftMaxForest.lua b/SoftMaxForest.lua new file mode 100644 index 0000000..0fee1fd --- /dev/null +++ b/SoftMaxForest.lua @@ -0,0 +1,55 @@ +local SoftMaxForest, parent = torch.class("nn.SoftMaxForest", "nn.Container") + +function SoftMaxForest:__init(inputSize, trees, rootIds, gaterSize, gaterAct, accUpdate) + local gaterAct = gaterAct or nn.Tanh() + local gaterSize = gaterSize or {} + + -- experts + self.experts = nn.ConcatTable() + self.smts = {} + for i,tree in ipairs(trees) do + local smt = nn.SoftMaxTree(inputSize, tree, rootIds[i], accUpdate) + table.insert(self._smts, smt) + self.experts:add(smt) + end + + -- gater + self.gater = nn.Sequential() + self.gater:add(nn.SelectTable(1)) -- ignore targets + for i,hiddenSize in ipairs(gaterSize) do + self.gater:add(nn.Linear(inputSize, hiddenSize)) + self.gater:add(gaterAct:clone()) + inputSize = hiddenSize + end + self.gater:add(nn.Linear(inputSize, self.experts:size())) + self.gater:add(nn.SoftMax()) + + -- mixture + self.trunk = nn.ConcatTable() + self.trunk:add(self._gater) + self.trunk:add(self._experts) + self.mixture = nn.MixtureTable() + self.module = nn.Sequential() + self.module:add(self.trunk) + self.module:add(self.mixture) + parent.__init(self) + self.modules[1] = self.module +end + +function SoftMaxForest:updateOutput(input) + self.output = self.module:updateOutput(input) + return self.output +end + +function SoftMaxForest:updateGradInput(input, gradOutput) + self.gradInput = self.module:updateGradInput(input, gradOutput) + return self.gradInput +end + +function SoftMaxForest:accGradParameters(input, gradOutput, scale) + self.module:accGradParameters(input, gradOutput, scale) +end + +function SoftMaxForest:accUpdateGradParameters(input, gradOutput, lr) + self.module:accUpdateGradParameters(input, gradOutput, lr) +end |