Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-14 03:08:37 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-14 03:08:37 +0400
commitba82d19f47ed2e377594f72fad60498efda4b5c8 (patch)
tree8096d9ba1a7fee75d7b2a010ff54f3bb73407b61 /MixtureTable.lua
parent0391fd5f95bf265dbbc1e05f550f807b52a8f881 (diff)
MixtureTable works for 3D tensors (unit tested)
Diffstat (limited to 'MixtureTable.lua')
-rw-r--r--  MixtureTable.lua  110
1 file changed, 61 insertions(+), 49 deletions(-)
diff --git a/MixtureTable.lua b/MixtureTable.lua
index 030d75c..47a0a16 100644
--- a/MixtureTable.lua
+++ b/MixtureTable.lua
@@ -2,7 +2,7 @@ local MixtureTable, parent = torch.class('nn.MixtureTable', 'nn.Module')
function MixtureTable:__init(dim)
parent.__init(self)
- self.dim = dim
+ self.dim = dim or 2
self._gaterView = torch.Tensor()
self._expert = torch.Tensor()
self._expertView = torch.Tensor()
@@ -10,31 +10,19 @@ function MixtureTable:__init(dim)
self.size = torch.LongStorage()
self.batchSize = 0
self.gradInput = {torch.Tensor(), {}}
- if self.dim then
- self.gradInput[2] = torch.Tensor()
- self.size2 = torch.LongStorage()
- self._expertView2 = torch.Tensor()
- end
+ self._gradInput = torch.Tensor()
+ self.size2 = torch.LongStorage()
+ self._expertView2 = torch.Tensor()
+ self._expert2 = torch.Tensor()
self.backwardSetup = false
end
function MixtureTable:updateOutput(input)
local gaterInput, expertInputs = unpack(input)
- if gaterInput:dim() == 2 then
- if self.dim then -- expertInputs is a Tensor :
- if self.batchSize ~= expertInputs:size(1) then
- self.size:resize(expertInputs:dim()):fill(1)
- self.size[1] = gaterInput:size(1)
- self.size[self.dim] = gaterInput:size(2)
- self.output:resizeAs(expertInputs:select(self.dim, 1))
- self.batchSize = expertInputs:size(1)
- self.backwardSetup = false
- end
- self._gaterView:view(gaterInput, self.size)
- self._expert:cmul(self._gaterView:expandAs(expertInputs), expertInputs)
- self.output:sum(self._expert, self.dim)
- self.output:resizeAs(expertInputs:select(self.dim, 1))
- else -- expertInputs is a Table :
+ if gaterInput:dim() > 1 then
+ if self.table or torch.type(expertInputs) == 'table' then
+ -- expertInputs is a Table :
+ self.table = true
if gaterInput:size(2) ~= #expertInputs then
error"Should be one gater output per expert"
end
@@ -42,18 +30,36 @@ function MixtureTable:updateOutput(input)
if self.batchSize ~= expertInput:size(1) then
self.size:resize(expertInput:dim()+1):fill(1)
self.size[1] = gaterInput:size(1)
- self.size[2] = gaterInput:size(2)
+ self.size[self.dim] = gaterInput:size(2)
self.output:resizeAs(expertInput)
self.batchSize = expertInput:size(1)
+ if torch.type(self.gradInput[2]) ~= 'table' then
+ self.gradInput[2] = {}
+ end
self.backwardSetup = false
end
self._gaterView:view(gaterInput, self.size)
self.output:zero()
-- multiply accumulate gater outputs by their commensurate expert
for i,expertInput in ipairs(expertInputs) do
- local gate = self._gaterView:select(2,i):expandAs(expertInput)
+ local gate = self._gaterView:select(self.dim,i):expandAs(expertInput)
self.output:addcmul(expertInput, gate)
end
+ else
+ -- expertInputs is a Tensor :
+ if self.batchSize ~= expertInputs:size(1) then
+ self.size:resize(expertInputs:dim()):fill(1)
+ self.size[1] = gaterInput:size(1)
+ self.size[self.dim] = gaterInput:size(2)
+ self.output:resizeAs(expertInputs:select(self.dim, 1))
+ self.batchSize = expertInputs:size(1)
+ self.gradInput[2] = self._gradInput
+ self.backwardSetup = false
+ end
+ self._gaterView:view(gaterInput, self.size)
+ self._expert:cmul(self._gaterView:expandAs(expertInputs), expertInputs)
+ self.output:sum(self._expert, self.dim)
+ self.output:resizeAs(expertInputs:select(self.dim, 1))
end
else
error"Only works with mini-batches"
@@ -64,27 +70,8 @@ end
function MixtureTable:updateGradInput(input, gradOutput)
local gaterInput, expertInputs = unpack(input)
local gaterGradInput, expertGradInputs = unpack(self.gradInput)
- if gradOutput:dim() == 2 then
- if self.dim then
- if not self.backwardSetup then
- self.size2:resize(expertInputs:dim())
- self.size2:copy(expertInputs:size())
- self.size2[self.dim] = 1
- gaterGradInput:resizeAs(gaterInput)
- self.backwardSetup = true
- end
-
- -- gater updateGradInput
- self._expertView:view(gradOutput, self.size2)
- local gradOutput = self._expertView:expandAs(expertInputs)
- self._expert:cmul(gradOutput, expertInputs)
- self._expertView2:view(self._expert, gaterInput:size(1), gaterInput:size(2), -1)
- gaterGradInput:sum(self._expertView2, 3)
- gaterGradInput:resizeAs(gaterInput)
-
- -- expert updateGradInput
- expertGradInputs:cmul(self._gaterView:expandAs(expertInputs), gradOutput)
- else
+ if gradOutput:dim() > 1 then
+ if self.table then
if not self.backwardSetup then
for i,expertInput in ipairs(expertInputs) do
local expertGradInput = expertGradInputs[i] or expertInput:clone()
@@ -105,9 +92,34 @@ function MixtureTable:updateGradInput(input, gradOutput)
gaterGradInput:select(2,i):copy(self._sum:select(2,1))
-- expert updateGradInput
- local gate = self._gaterView:select(2,i):expandAs(expertGradInput)
+ local gate = self._gaterView:select(self.dim,i):expandAs(expertGradInput)
expertGradInput:cmul(gate, gradOutput)
end
+ else
+ if not self.backwardSetup then
+ self.size2:resize(expertInputs:dim())
+ self.size2:copy(expertInputs:size())
+ self.size2[self.dim] = 1
+ gaterGradInput:resizeAs(gaterInput)
+ self.backwardSetup = true
+ end
+
+ -- gater updateGradInput
+ self._expertView:view(gradOutput, self.size2)
+ local gradOutput = self._expertView:expandAs(expertInputs)
+ self._expert:cmul(gradOutput, expertInputs)
+ local expert = self._expert:transpose(self.dim, 2)
+ if not expert:isContiguous() then
+ self._expert2:resizeAs(expert)
+ self._expert2:copy(expert)
+ expert = self._expert2
+ end
+ self._expertView2:view(expert, gaterInput:size(1), gaterInput:size(2), -1)
+ gaterGradInput:sum(self._expertView2, 3)
+ gaterGradInput:resizeAs(gaterInput)
+
+ -- expert updateGradInput
+ expertGradInputs:cmul(self._gaterView:expandAs(expertInputs), gradOutput)
end
else
error"Only works with mini-batches"
@@ -118,10 +130,7 @@ end
function MixtureTable:type(type)
self.output = self.output:type(type)
self.gradInput[1] = self.gradInput[1]:type(type)
- if self.dim then
- self.gradInput[2] = self.gradInput[2]:type(type)
- self._expertView2 = self._expertView2:type(type)
- else
+ if torch.type(self.gradInput[2]) == 'table' then
for i,expertGradInput in ipairs(self.gradInput[2]) do
self.gradInput[2][i] = expertGradInput:type(type)
end
@@ -130,4 +139,7 @@ function MixtureTable:type(type)
self._expert = self._expert:type(type)
self._expertView = self._expertView:type(type)
self._sum = self._sum:type(type)
+ self._gradInput = self._gradInput:type(type)
+ self._expert2 = self._expert2:type(type)
+ self._expertView2 = self._expertView2:type(type)
end