Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-10 23:49:39 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-10 23:49:39 +0400
commitd52181d3e5d7f93e8f666c0ff506f65d4c949948 (patch)
tree4c5aeda69d108507cf9b095efc3f4930e1b2bdc0 /MixtureTable.lua
parent5e5d7f244a00ab12a5e8d5a0171c6f3bc3c4e9cc (diff)
MixtureTable:updateOutput
Diffstat (limited to 'MixtureTable.lua')
-rw-r--r--MixtureTable.lua39
1 files changed, 39 insertions, 0 deletions
diff --git a/MixtureTable.lua b/MixtureTable.lua
new file mode 100644
index 0000000..5403800
--- /dev/null
+++ b/MixtureTable.lua
@@ -0,0 +1,39 @@
+local MixtureTable, parent = torch.class('nn.MixtureTable', 'nn.Module')
+
+function MixtureTable:__init()
+ parent.__init(self)
+ self._gate = torch.Tensor()
+ self.size = torch.LongTensor()
+ self.batchSize = 0
+end
+
+function MixtureTable:updateOutput(input)
+ local gaterInput, expertInputs = unpack(input)
+ if gaterInput:dim() == 2 then
+ if gaterInput:size(2) ~= #expertInputs then
+ error"Should be one gater output per expert"
+ end
+ local expertInput = expertInputs[1]
+ if self.batchSize ~= expertInput:size(1) then
+ self.size:resize(expertInput:dim()):fill(1)
+ self.size[1] = expertInput:size(1)
+ self.output:resizeAs(expertInput)
+ self.batchSize ~= expertInput:size(1)
+ end
+ self.output:zero()
+ for i,expertInput in ipairs(expertInputs) do
+ -- multiply each gater output (a gate) by its
+ -- commensurate expert
+ self._gate:resize(self.size:storage())
+ self._gate:copy(gaterInput:select(2,i))
+ self.output:addcmul(expertInput,gate:expandAs(expertInput))
+ end
+ end
+
+ return self.output
+end
+
+function MixtureTable:updateGradInput(input, gradOutput)
+
+ return self.gradInput
+end