Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--AddConstant.lua29
-rw-r--r--MulConstant.lua32
-rwxr-xr-xdoc/transfer.md12
-rw-r--r--test.lua74
4 files changed, 134 insertions, 13 deletions
diff --git a/AddConstant.lua b/AddConstant.lua
index bcf33ed..7eff2b1 100644
--- a/AddConstant.lua
+++ b/AddConstant.lua
@@ -1,20 +1,37 @@
local AddConstant, parent = torch.class('nn.AddConstant', 'nn.Module')
-function AddConstant:__init(constant_scalar)
+function AddConstant:__init(constant_scalar,ip)
parent.__init(self)
assert(type(constant_scalar) == 'number', 'input is not scalar!')
self.constant_scalar = constant_scalar
+
+ -- default for inplace is false
+ self.inplace = ip or false
+ if (ip and type(ip) ~= 'boolean') then
+ error('in-place flag must be boolean')
+ end
end
function AddConstant:updateOutput(input)
- self.output:resizeAs(input)
- self.output:copy(input)
- self.output:add(self.constant_scalar)
+ if self.inplace then
+ input:add(self.constant_scalar)
+ self.output = input
+ else
+ self.output:resizeAs(input)
+ self.output:copy(input)
+ self.output:add(self.constant_scalar)
+ end
return self.output
end
function AddConstant:updateGradInput(input, gradOutput)
- self.gradInput:resizeAs(gradOutput)
- self.gradInput:copy(gradOutput)
+ if self.inplace then
+ self.gradInput = gradOutput
+ -- restore previous input value
+ input:add(-self.constant_scalar)
+ else
+ self.gradInput:resizeAs(gradOutput)
+ self.gradInput:copy(gradOutput)
+ end
return self.gradInput
end
diff --git a/MulConstant.lua b/MulConstant.lua
index 982ab41..eb41d36 100644
--- a/MulConstant.lua
+++ b/MulConstant.lua
@@ -1,21 +1,39 @@
local MulConstant, parent = torch.class('nn.MulConstant', 'nn.Module')
-function MulConstant:__init(constant_scalar)
+function MulConstant:__init(constant_scalar,ip)
parent.__init(self)
assert(type(constant_scalar) == 'number', 'input is not scalar!')
self.constant_scalar = constant_scalar
+
+ -- default for inplace is false
+ self.inplace = ip or false
+ if (ip and type(ip) ~= 'boolean') then
+ error('in-place flag must be boolean')
+ end
end
function MulConstant:updateOutput(input)
- self.output:resizeAs(input)
- self.output:copy(input)
- self.output:mul(self.constant_scalar)
+ if self.inplace then
+ input:mul(self.constant_scalar)
+ self.output = input
+ else
+ self.output:resizeAs(input)
+ self.output:copy(input)
+ self.output:mul(self.constant_scalar)
+ end
return self.output
end
function MulConstant:updateGradInput(input, gradOutput)
- self.gradInput:resizeAs(gradOutput)
- self.gradInput:copy(gradOutput)
- self.gradInput:mul(self.constant_scalar)
+ if self.inplace then
+ gradOutput:mul(self.constant_scalar)
+ self.gradInput = gradOutput
+ -- restore previous input value
+ input:div(self.constant_scalar)
+ else
+ self.gradInput:resizeAs(gradOutput)
+ self.gradInput:copy(gradOutput)
+ self.gradInput:mul(self.constant_scalar)
+ end
return self.gradInput
end
diff --git a/doc/transfer.md b/doc/transfer.md
index ce7b874..c03017d 100755
--- a/doc/transfer.md
+++ b/doc/transfer.md
@@ -272,7 +272,19 @@ Note that weight decay should not be used on it. For reference see http://arxiv.
Adds a (non-learnable) scalar constant. This module is sometimes useful for debugging purposes: `f(x)` = `x + k`, where `k` is a scalar.
+Can optionally do its operation in-place without using extra state memory:
+```lua
+m=nn.AddConstant(k,true) -- true = in-place, false = keeping separate state.
+```
+In-place mode restores the original input value after the backward pass, allowing its use after other in-place modules, like [MulConstant](#nn.MulConstant).
+
<a name="nn.MulConstant"/>
## MulConstant ##
Multiplies input tensor by a (non-learnable) scalar constant. This module is sometimes useful for debugging purposes: `f(x)` = `k * x`, where `k` is a scalar.
+
+Can optionally do its operation in-place without using extra state memory:
+```lua
+m=nn.MulConstant(k,true) -- true = in-place, false = keeping separate state.
+```
+In-place mode restores the original input value after the backward pass, allowing its use after other in-place modules, like [AddConstant](#nn.AddConstant).
diff --git a/test.lua b/test.lua
index 23c7fbd..e0afbbb 100644
--- a/test.lua
+++ b/test.lua
@@ -2528,6 +2528,43 @@ function nntest.AddConstant()
-- Test BPROP
local err = jac.testJacobian(mod, input)
mytester:assertlt(err, precision, 'bprop error ')
+
+ -- inplace comparisons
+ local ini = math.random(3,5)
+ local inj = math.random(3,5)
+ local ink = math.random(3,5)
+ local constant = torch.uniform()*math.random(1,10)
+
+ local input1 = torch.rand(ink, inj, ini)
+ local input2 = input1:clone()
+
+ local module1 = nn.AddConstant(constant,true)
+ local module2 = nn.AddConstant(constant)
+
+ local gradOutput1 = torch.rand(ink, inj, ini)
+ local gradOutput2 = gradOutput1:clone()
+
+ local out1 = module1:forward(input1)
+ local out2 = module2:forward(input2)
+
+ mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) ..
+ ' - in-place forward err ')
+
+ local gradInput1 = module1:backward(input1, gradOutput1)
+ local gradInput2 = module2:backward(input2, gradOutput2)
+
+ mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
+ torch.typename(module1) .. ' - in-place backward err ')
+
+ local input1 = torch.rand(ink, inj, ini)
+ local input2 = input1:clone()
+
+ module1:forward(input1)
+ module1:backward(module1.output,torch.rand(input1:size()))
+
+ local err = (input1-input2):abs():max()
+ mytester:asserteq(err, 0, torch.typename(module1) ..
+ ' - inplace input change err ')
end
function nntest.MulConstant()
@@ -2548,6 +2585,43 @@ function nntest.MulConstant()
-- Test BPROP
local err = jac.testJacobian(mod, input)
mytester:assertlt(err, precision, 'bprop error ')
+
+ -- inplace comparisons
+ local ini = math.random(3,5)
+ local inj = math.random(3,5)
+ local ink = math.random(3,5)
+ local constant = torch.uniform()*math.random(1,10)
+
+ local input1 = torch.rand(ink, inj, ini)
+ local input2 = input1:clone()
+
+ local module1 = nn.MulConstant(constant,true)
+ local module2 = nn.MulConstant(constant)
+
+ local gradOutput1 = torch.rand(ink, inj, ini)
+ local gradOutput2 = gradOutput1:clone()
+
+ local out1 = module1:forward(input1)
+ local out2 = module2:forward(input2)
+
+ mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) ..
+ ' - in-place forward err ')
+
+ local gradInput1 = module1:backward(input1, gradOutput1)
+ local gradInput2 = module2:backward(input2, gradOutput2)
+
+ mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
+ torch.typename(module1) .. ' - in-place backward err ')
+
+ local input1 = torch.rand(ink, inj, ini)
+ local input2 = input1:clone()
+
+ module1:forward(input1)
+ module1:backward(module1.output,torch.rand(input1:size()))
+
+ local err = (input1-input2):abs():max()
+ mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) ..
+ ' - inplace input change err ')
end
function nntest.Copy()