diff options
author:    nicholas-leonard <nick@nikopia.org>  2014-07-14 03:08:37 +0400
committer: nicholas-leonard <nick@nikopia.org>  2014-07-14 03:08:37 +0400
commit:    ba82d19f47ed2e377594f72fad60498efda4b5c8 (patch)
tree:      8096d9ba1a7fee75d7b2a010ff54f3bb73407b61 /MixtureTable.lua
parent:    0391fd5f95bf265dbbc1e05f550f807b52a8f881 (diff)
MixtureTable works for 3D tensors (unit tested)
Diffstat (limited to 'MixtureTable.lua')
-rw-r--r--   MixtureTable.lua   110
1 file changed, 61 insertions(+), 49 deletions(-)
diff --git a/MixtureTable.lua b/MixtureTable.lua
index 030d75c..47a0a16 100644
--- a/MixtureTable.lua
+++ b/MixtureTable.lua
@@ -2,7 +2,7 @@ local MixtureTable, parent = torch.class('nn.MixtureTable', 'nn.Module')
 
 function MixtureTable:__init(dim)
    parent.__init(self)
-   self.dim = dim
+   self.dim = dim or 2
    self._gaterView = torch.Tensor()
    self._expert = torch.Tensor()
    self._expertView = torch.Tensor()
@@ -10,31 +10,19 @@ function MixtureTable:__init(dim)
    self.size = torch.LongStorage()
    self.batchSize = 0
    self.gradInput = {torch.Tensor(), {}}
-   if self.dim then
-      self.gradInput[2] = torch.Tensor()
-      self.size2 = torch.LongStorage()
-      self._expertView2 = torch.Tensor()
-   end
+   self._gradInput = torch.Tensor()
+   self.size2 = torch.LongStorage()
+   self._expertView2 = torch.Tensor()
+   self._expert2 = torch.Tensor()
    self.backwardSetup = false
 end
 
 function MixtureTable:updateOutput(input)
    local gaterInput, expertInputs = unpack(input)
-   if gaterInput:dim() == 2 then
-      if self.dim then -- expertInputs is a Tensor :
-         if self.batchSize ~= expertInputs:size(1) then
-            self.size:resize(expertInputs:dim()):fill(1)
-            self.size[1] = gaterInput:size(1)
-            self.size[self.dim] = gaterInput:size(2)
-            self.output:resizeAs(expertInputs:select(self.dim, 1))
-            self.batchSize = expertInputs:size(1)
-            self.backwardSetup = false
-         end
-         self._gaterView:view(gaterInput, self.size)
-         self._expert:cmul(self._gaterView:expandAs(expertInputs), expertInputs)
-         self.output:sum(self._expert, self.dim)
-         self.output:resizeAs(expertInputs:select(self.dim, 1))
-      else -- expertInputs is a Table :
+   if gaterInput:dim() > 1 then
+      if self.table or torch.type(expertInputs) == 'table' then
+         -- expertInputs is a Table :
+         self.table = true
          if gaterInput:size(2) ~= #expertInputs then
            error"Should be one gater output per expert"
         end
@@ -42,18 +30,36 @@ function MixtureTable:updateOutput(input)
          if self.batchSize ~= expertInput:size(1) then
             self.size:resize(expertInput:dim()+1):fill(1)
             self.size[1] = gaterInput:size(1)
-            self.size[2] = gaterInput:size(2)
+            self.size[self.dim] = gaterInput:size(2)
             self.output:resizeAs(expertInput)
             self.batchSize = expertInput:size(1)
+            if torch.type(self.gradInput[2]) ~= 'table' then
+               self.gradInput[2] = {}
+            end
             self.backwardSetup = false
          end
          self._gaterView:view(gaterInput, self.size)
          self.output:zero()
          -- multiply accumulate gater outputs by their commensurate expert
          for i,expertInput in ipairs(expertInputs) do
-            local gate = self._gaterView:select(2,i):expandAs(expertInput)
+            local gate = self._gaterView:select(self.dim,i):expandAs(expertInput)
             self.output:addcmul(expertInput, gate)
          end
+      else
+         -- expertInputs is a Tensor :
+         if self.batchSize ~= expertInputs:size(1) then
+            self.size:resize(expertInputs:dim()):fill(1)
+            self.size[1] = gaterInput:size(1)
+            self.size[self.dim] = gaterInput:size(2)
+            self.output:resizeAs(expertInputs:select(self.dim, 1))
+            self.batchSize = expertInputs:size(1)
+            self.gradInput[2] = self._gradInput
+            self.backwardSetup = false
+         end
+         self._gaterView:view(gaterInput, self.size)
+         self._expert:cmul(self._gaterView:expandAs(expertInputs), expertInputs)
+         self.output:sum(self._expert, self.dim)
+         self.output:resizeAs(expertInputs:select(self.dim, 1))
       end
    else
       error"Only works with mini-batches"
@@ -64,27 +70,8 @@ end
 
 function MixtureTable:updateGradInput(input, gradOutput)
    local gaterInput, expertInputs = unpack(input)
    local gaterGradInput, expertGradInputs = unpack(self.gradInput)
-   if gradOutput:dim() == 2 then
-      if self.dim then
-         if not self.backwardSetup then
-            self.size2:resize(expertInputs:dim())
-            self.size2:copy(expertInputs:size())
-            self.size2[self.dim] = 1
-            gaterGradInput:resizeAs(gaterInput)
-            self.backwardSetup = true
-         end
-
-         -- gater updateGradInput
-         self._expertView:view(gradOutput, self.size2)
-         local gradOutput = self._expertView:expandAs(expertInputs)
-         self._expert:cmul(gradOutput, expertInputs)
-         self._expertView2:view(self._expert, gaterInput:size(1), gaterInput:size(2), -1)
-         gaterGradInput:sum(self._expertView2, 3)
-         gaterGradInput:resizeAs(gaterInput)
-
-         -- expert updateGradInput
-         expertGradInputs:cmul(self._gaterView:expandAs(expertInputs), gradOutput)
-      else
+   if gradOutput:dim() > 1 then
+      if self.table then
          if not self.backwardSetup then
            for i,expertInput in ipairs(expertInputs) do
               local expertGradInput = expertGradInputs[i] or expertInput:clone()
@@ -105,9 +92,34 @@ function MixtureTable:updateGradInput(input, gradOutput)
             gaterGradInput:select(2,i):copy(self._sum:select(2,1))
 
             -- expert updateGradInput
-            local gate = self._gaterView:select(2,i):expandAs(expertGradInput)
+            local gate = self._gaterView:select(self.dim,i):expandAs(expertGradInput)
             expertGradInput:cmul(gate, gradOutput)
          end
+      else
+         if not self.backwardSetup then
+            self.size2:resize(expertInputs:dim())
+            self.size2:copy(expertInputs:size())
+            self.size2[self.dim] = 1
+            gaterGradInput:resizeAs(gaterInput)
+            self.backwardSetup = true
+         end
+
+         -- gater updateGradInput
+         self._expertView:view(gradOutput, self.size2)
+         local gradOutput = self._expertView:expandAs(expertInputs)
+         self._expert:cmul(gradOutput, expertInputs)
+         local expert = self._expert:transpose(self.dim, 2)
+         if not expert:isContiguous() then
+            self._expert2:resizeAs(expert)
+            self._expert2:copy(expert)
+            expert = self._expert2
+         end
+         self._expertView2:view(expert, gaterInput:size(1), gaterInput:size(2), -1)
+         gaterGradInput:sum(self._expertView2, 3)
+         gaterGradInput:resizeAs(gaterInput)
+
+         -- expert updateGradInput
+         expertGradInputs:cmul(self._gaterView:expandAs(expertInputs), gradOutput)
       end
    else
       error"Only works with mini-batches"
@@ -118,10 +130,7 @@ end
 
 function MixtureTable:type(type)
    self.output = self.output:type(type)
    self.gradInput[1] = self.gradInput[1]:type(type)
-   if self.dim then
-      self.gradInput[2] = self.gradInput[2]:type(type)
-      self._expertView2 = self._expertView2:type(type)
-   else
+   if torch.type(self.gradInput[2]) == 'table' then
       for i,expertGradInput in ipairs(self.gradInput[2]) do
          self.gradInput[2][i] = expertGradInput:type(type)
       end
@@ -130,4 +139,7 @@ function MixtureTable:type(type)
    self._expert = self._expert:type(type)
    self._expertView = self._expertView:type(type)
    self._sum = self._sum:type(type)
+   self._gradInput = self._gradInput:type(type)
+   self._expert2 = self._expert2:type(type)
+   self._expertView2 = self._expertView2:type(type)
 end