Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SoftMaxForest.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0fee1fda1a79857960e7a9e29066c532a82ae178 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
--- nn.SoftMaxForest: a mixture of several nn.SoftMaxTree experts.
-- A gater network produces a softmax over the experts, and an
-- nn.MixtureTable combines the gater output with the experts' outputs.
-- The input is presumably a table {input, targets}: the gater only looks at
-- element 1 (see the SelectTable below) -- TODO confirm against callers.
local SoftMaxForest, parent = torch.class("nn.SoftMaxForest", "nn.Container")

--- Build the forest of SoftMaxTree experts and the gater that mixes them.
-- @param inputSize number of input features; passed to every SoftMaxTree and
--        used as the input width of the first gater Linear layer
-- @param trees sequence of hierarchies; one nn.SoftMaxTree is built per entry
-- @param rootIds sequence; rootIds[i] is passed as the root id for trees[i]
-- @param gaterSize optional sequence of gater hidden-layer sizes (default {})
-- @param gaterAct optional activation module, cloned after each gater hidden
--        layer (default nn.Tanh())
-- @param accUpdate forwarded unchanged to each nn.SoftMaxTree
function SoftMaxForest:__init(inputSize, trees, rootIds, gaterSize, gaterAct, accUpdate)
   -- initialise the Container (sets up self.modules) before adding fields
   parent.__init(self)
   gaterAct = gaterAct or nn.Tanh()
   gaterSize = gaterSize or {}

   -- experts: one SoftMaxTree per hierarchy; ConcatTable feeds each the same input
   self.experts = nn.ConcatTable()
   self.smts = {}
   for i, tree in ipairs(trees) do
      local smt = nn.SoftMaxTree(inputSize, tree, rootIds[i], accUpdate)
      self.smts[#self.smts + 1] = smt
      self.experts:add(smt)
   end

   -- gater: optional hidden layers, then a softmax with one output per expert
   self.gater = nn.Sequential()
   self.gater:add(nn.SelectTable(1)) -- ignore targets
   local gaterInputSize = inputSize  -- width of the layer currently being added
   for _, hiddenSize in ipairs(gaterSize) do
      self.gater:add(nn.Linear(gaterInputSize, hiddenSize))
      self.gater:add(gaterAct:clone())
      gaterInputSize = hiddenSize
   end
   self.gater:add(nn.Linear(gaterInputSize, self.experts:size()))
   self.gater:add(nn.SoftMax())

   -- mixture: trunk yields {gaterOutput, expertOutputs}, MixtureTable combines them
   self.trunk = nn.ConcatTable()
   self.trunk:add(self.gater)
   self.trunk:add(self.experts)
   self.mixture = nn.MixtureTable()
   self.module = nn.Sequential()
   self.module:add(self.trunk)
   self.module:add(self.mixture)
   self.modules[1] = self.module
end

--- Forward pass: delegate to the wrapped trunk+mixture Sequential.
-- @param input forwarded unchanged to self.module
-- @return the wrapped module's output (also stored in self.output)
function SoftMaxForest:updateOutput(input)
   local out = self.module:updateOutput(input)
   self.output = out
   return out
end

--- Backward pass for the input gradient, delegated to the wrapped module.
-- @param input the input given to the forward pass
-- @param gradOutput gradient w.r.t. this module's output
-- @return the wrapped module's gradInput (also stored in self.gradInput)
function SoftMaxForest:updateGradInput(input, gradOutput)
   local grad = self.module:updateGradInput(input, gradOutput)
   self.gradInput = grad
   return grad
end

--- Accumulate parameter gradients by delegating to the wrapped module.
-- @param input the input given to the forward pass
-- @param gradOutput gradient w.r.t. this module's output
-- @param scale forwarded unchanged to the wrapped module
function SoftMaxForest:accGradParameters(input, gradOutput, scale)
   local inner = self.module
   inner:accGradParameters(input, gradOutput, scale)
end

--- Accumulate gradients and update parameters in one step, via the wrapped module.
-- @param input the input given to the forward pass
-- @param gradOutput gradient w.r.t. this module's output
-- @param lr learning rate, forwarded unchanged to the wrapped module
function SoftMaxForest:accUpdateGradParameters(input, gradOutput, lr)
   local inner = self.module
   inner:accUpdateGradParameters(input, gradOutput, lr)
end