diff options
author | nicholas-leonard <nick@nikopia.org> | 2015-01-06 23:07:53 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2015-01-06 23:25:47 +0300 |
commit | 3dcecf7c22745b13fd5ba85848423a346ec42ad2 (patch) | |
tree | 26a17b3482d624bc80734056556fbf4ae6fd4958 /Parallel.lua | |
parent | 4e0a96d801060121521ccc46f7294aeb3b247965 (diff) |
Parallel optimization. ParallelTable inherits Container. unit tests
Diffstat (limited to 'Parallel.lua')
-rw-r--r-- | Parallel.lua | 74 |
1 files changed, 28 insertions, 46 deletions
diff --git a/Parallel.lua b/Parallel.lua index ef42723..b63b211 100644 --- a/Parallel.lua +++ b/Parallel.lua @@ -9,28 +9,29 @@ function Parallel:__init(inputDimension,outputDimension) end function Parallel:updateOutput(input) - - local modules=input:size(self.inputDimension) + local nModule=input:size(self.inputDimension) + local outputs = {} - for i=1,modules do - local currentOutput = - self.modules[i]:updateOutput(input:select(self.inputDimension,i)) + for i=1,nModule do + local currentInput = input:select(self.inputDimension,i) + local currentOutput = self.modules[i]:updateOutput(currentInput) + table.insert(outputs, currentOutput) + local outputSize = currentOutput:size(self.outputDimension) if i == 1 then self.size:resize(currentOutput:dim()):copy(currentOutput:size()) else - self.size[self.outputDimension] = self.size[self.outputDimension] - + currentOutput:size(self.outputDimension) + self.size[self.outputDimension] = self.size[self.outputDimension] + outputSize end + end self.output:resize(self.size) local offset = 1 - for i=1,modules do - local currentOutput = self.modules[i]:updateOutput(input:select(self.inputDimension,i)) - - self.output:narrow(self.outputDimension, offset, - currentOutput:size(self.outputDimension)):copy(currentOutput) + for i=1,nModule do + local currentOutput = outputs[i] + local outputSize = currentOutput:size(self.outputDimension) + self.output:narrow(self.outputDimension, offset, outputSize):copy(currentOutput) offset = offset + currentOutput:size(self.outputDimension) end return self.output @@ -42,15 +43,16 @@ function Parallel:updateGradInput(input, gradOutput) local offset = 1 for i=1,nModule do - local module=self.modules[i]; + local module=self.modules[i] + local currentInput = input:select(self.inputDimension,i) local currentOutput = module.output - local currentGradInput = - module:updateGradInput(input:select(self.inputDimension,i), - gradOutput:narrow(self.outputDimension, - offset, currentOutput:size(self.outputDimension))) + local outputSize = currentOutput:size(self.outputDimension) + local currentGradOutput = gradOutput:narrow(self.outputDimension, offset, outputSize) + + local currentGradInput = module:updateGradInput(currentInput, currentGradOutput) self.gradInput:select(self.inputDimension,i):copy(currentGradInput) - offset = offset + currentOutput:size(self.outputDimension) + offset = offset + outputSize end return self.gradInput end @@ -60,16 +62,17 @@ function Parallel:accGradParameters(input, gradOutput, scale) local offset = 1 for i=1,nModule do - local module = self.modules[i]; + local module = self.modules[i] local currentOutput = module.output + local outputSize = currentOutput:size(self.outputDimension) + module:accGradParameters( input:select(self.inputDimension,i), - gradOutput:narrow( - self.outputDimension, offset, - currentOutput:size(self.outputDimension)), - scale) + gradOutput:narrow(self.outputDimension, offset,outputSize), + scale + ) - offset = offset + currentOutput:size(self.outputDimension) + offset = offset + outputSize end end @@ -81,6 +84,7 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr) local module = self.modules[i]; local currentOutput = module.output module:accUpdateGradParameters( + input:select(self.inputDimension,i), gradOutput:narrow(self.outputDimension, offset, currentOutput:size(self.outputDimension)), @@ -89,28 +93,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr) offset = offset + currentOutput:size(self.outputDimension) end end - -function Parallel:parameters() - local function tinsert(to, from) - if type(from) == 'table' then - for i=1,#from do - tinsert(to,from[i]) - end - else - table.insert(to,from) - end - end - local w = {} - local gw = {} - for i=1,#self.modules do - local mw,mgw = self.modules[i]:parameters() - if mw then - tinsert(w,mw) - tinsert(gw,mgw) - end - end - return w,gw -end function Parallel:__tostring__() local tab = ' ' |