diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-10 23:49:39 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-10 23:49:39 +0400 |
commit | d52181d3e5d7f93e8f666c0ff506f65d4c949948 (patch) | |
tree | 4c5aeda69d108507cf9b095efc3f4930e1b2bdc0 /MixtureTable.lua | |
parent | 5e5d7f244a00ab12a5e8d5a0171c6f3bc3c4e9cc (diff) |
MixtureTable:updateOutput
Diffstat (limited to 'MixtureTable.lua')
-rw-r--r-- | MixtureTable.lua | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/MixtureTable.lua b/MixtureTable.lua new file mode 100644 index 0000000..5403800 --- /dev/null +++ b/MixtureTable.lua @@ -0,0 +1,39 @@ +local MixtureTable, parent = torch.class('nn.MixtureTable', 'nn.Module') + +function MixtureTable:__init() + parent.__init(self) + self._gate = torch.Tensor() + self.size = torch.LongTensor() + self.batchSize = 0 +end + +function MixtureTable:updateOutput(input) + local gaterInput, expertInputs = unpack(input) + if gaterInput:dim() == 2 then + if gaterInput:size(2) ~= #expertInputs then + error"Should be one gater output per expert" + end + local expertInput = expertInputs[1] + if self.batchSize ~= expertInput:size(1) then + self.size:resize(expertInput:dim()):fill(1) + self.size[1] = expertInput:size(1) + self.output:resizeAs(expertInput) + self.batchSize ~= expertInput:size(1) + end + self.output:zero() + for i,expertInput in ipairs(expertInputs) do + -- multiply each gater output (a gate) by its + -- commensurate expert + self._gate:resize(self.size:storage()) + self._gate:copy(gaterInput:select(2,i)) + self.output:addcmul(expertInput,gate:expandAs(expertInput)) + end + end + + return self.output +end + +function MixtureTable:updateGradInput(input, gradOutput) + + return self.gradInput +end |