Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2015-01-06 23:07:53 +0300
committernicholas-leonard <nick@nikopia.org>2015-01-06 23:25:47 +0300
commit3dcecf7c22745b13fd5ba85848423a346ec42ad2 (patch)
tree26a17b3482d624bc80734056556fbf4ae6fd4958 /Parallel.lua
parent4e0a96d801060121521ccc46f7294aeb3b247965 (diff)
Parallel optimization. ParallelTable inherits Container. unit tests
Diffstat (limited to 'Parallel.lua')
-rw-r--r--Parallel.lua74
1 files changed, 28 insertions, 46 deletions
diff --git a/Parallel.lua b/Parallel.lua
index ef42723..b63b211 100644
--- a/Parallel.lua
+++ b/Parallel.lua
@@ -9,28 +9,29 @@ function Parallel:__init(inputDimension,outputDimension)
end
function Parallel:updateOutput(input)
-
- local modules=input:size(self.inputDimension)
+ local nModule=input:size(self.inputDimension)
+ local outputs = {}
- for i=1,modules do
- local currentOutput =
- self.modules[i]:updateOutput(input:select(self.inputDimension,i))
+ for i=1,nModule do
+ local currentInput = input:select(self.inputDimension,i)
+ local currentOutput = self.modules[i]:updateOutput(currentInput)
+ table.insert(outputs, currentOutput)
+ local outputSize = currentOutput:size(self.outputDimension)
if i == 1 then
self.size:resize(currentOutput:dim()):copy(currentOutput:size())
else
- self.size[self.outputDimension] = self.size[self.outputDimension]
- + currentOutput:size(self.outputDimension)
+ self.size[self.outputDimension] = self.size[self.outputDimension] + outputSize
end
+
end
self.output:resize(self.size)
local offset = 1
- for i=1,modules do
- local currentOutput = self.modules[i]:updateOutput(input:select(self.inputDimension,i))
-
- self.output:narrow(self.outputDimension, offset,
- currentOutput:size(self.outputDimension)):copy(currentOutput)
+ for i=1,nModule do
+ local currentOutput = outputs[i]
+ local outputSize = currentOutput:size(self.outputDimension)
+ self.output:narrow(self.outputDimension, offset, outputSize):copy(currentOutput)
offset = offset + currentOutput:size(self.outputDimension)
end
return self.output
@@ -42,15 +43,16 @@ function Parallel:updateGradInput(input, gradOutput)
local offset = 1
for i=1,nModule do
- local module=self.modules[i];
+ local module=self.modules[i]
+ local currentInput = input:select(self.inputDimension,i)
local currentOutput = module.output
- local currentGradInput =
- module:updateGradInput(input:select(self.inputDimension,i),
- gradOutput:narrow(self.outputDimension,
- offset, currentOutput:size(self.outputDimension)))
+ local outputSize = currentOutput:size(self.outputDimension)
+ local currentGradOutput = gradOutput:narrow(self.outputDimension, offset, outputSize)
+
+ local currentGradInput = module:updateGradInput(currentInput, currentGradOutput)
self.gradInput:select(self.inputDimension,i):copy(currentGradInput)
- offset = offset + currentOutput:size(self.outputDimension)
+ offset = offset + outputSize
end
return self.gradInput
end
@@ -60,16 +62,17 @@ function Parallel:accGradParameters(input, gradOutput, scale)
local offset = 1
for i=1,nModule do
- local module = self.modules[i];
+ local module = self.modules[i]
local currentOutput = module.output
+ local outputSize = currentOutput:size(self.outputDimension)
+
module:accGradParameters(
input:select(self.inputDimension,i),
- gradOutput:narrow(
- self.outputDimension, offset,
- currentOutput:size(self.outputDimension)),
- scale)
+ gradOutput:narrow(self.outputDimension, offset,outputSize),
+ scale
+ )
- offset = offset + currentOutput:size(self.outputDimension)
+ offset = offset + outputSize
end
end
@@ -81,6 +84,7 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
local module = self.modules[i];
local currentOutput = module.output
module:accUpdateGradParameters(
+
input:select(self.inputDimension,i),
gradOutput:narrow(self.outputDimension, offset,
currentOutput:size(self.outputDimension)),
@@ -89,28 +93,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
offset = offset + currentOutput:size(self.outputDimension)
end
end
-
-function Parallel:parameters()
- local function tinsert(to, from)
- if type(from) == 'table' then
- for i=1,#from do
- tinsert(to,from[i])
- end
- else
- table.insert(to,from)
- end
- end
- local w = {}
- local gw = {}
- for i=1,#self.modules do
- local mw,mgw = self.modules[i]:parameters()
- if mw then
- tinsert(w,mw)
- tinsert(gw,mgw)
- end
- end
- return w,gw
-end
function Parallel:__tostring__()
local tab = ' '