diff options
-rw-r--r-- | AddConstant.lua | 29 | ||||
-rw-r--r-- | MulConstant.lua | 32 | ||||
-rwxr-xr-x | doc/transfer.md | 12 | ||||
-rw-r--r-- | test.lua | 74 |
4 files changed, 134 insertions, 13 deletions
diff --git a/AddConstant.lua b/AddConstant.lua index bcf33ed..7eff2b1 100644 --- a/AddConstant.lua +++ b/AddConstant.lua @@ -1,20 +1,37 @@ local AddConstant, parent = torch.class('nn.AddConstant', 'nn.Module') -function AddConstant:__init(constant_scalar) +function AddConstant:__init(constant_scalar,ip) parent.__init(self) assert(type(constant_scalar) == 'number', 'input is not scalar!') self.constant_scalar = constant_scalar + + -- default for inplace is false + self.inplace = ip or false + if (ip and type(ip) ~= 'boolean') then + error('in-place flag must be boolean') + end end function AddConstant:updateOutput(input) - self.output:resizeAs(input) - self.output:copy(input) - self.output:add(self.constant_scalar) + if self.inplace then + input:add(self.constant_scalar) + self.output = input + else + self.output:resizeAs(input) + self.output:copy(input) + self.output:add(self.constant_scalar) + end return self.output end function AddConstant:updateGradInput(input, gradOutput) - self.gradInput:resizeAs(gradOutput) - self.gradInput:copy(gradOutput) + if self.inplace then + self.gradInput = gradOutput + -- restore previous input value + input:add(-self.constant_scalar) + else + self.gradInput:resizeAs(gradOutput) + self.gradInput:copy(gradOutput) + end return self.gradInput end diff --git a/MulConstant.lua b/MulConstant.lua index 982ab41..eb41d36 100644 --- a/MulConstant.lua +++ b/MulConstant.lua @@ -1,21 +1,39 @@ local MulConstant, parent = torch.class('nn.MulConstant', 'nn.Module') -function MulConstant:__init(constant_scalar) +function MulConstant:__init(constant_scalar,ip) parent.__init(self) assert(type(constant_scalar) == 'number', 'input is not scalar!') self.constant_scalar = constant_scalar + + -- default for inplace is false + self.inplace = ip or false + if (ip and type(ip) ~= 'boolean') then + error('in-place flag must be boolean') + end end function MulConstant:updateOutput(input) - self.output:resizeAs(input) - self.output:copy(input) - 
self.output:mul(self.constant_scalar) + if self.inplace then + input:mul(self.constant_scalar) + self.output = input + else + self.output:resizeAs(input) + self.output:copy(input) + self.output:mul(self.constant_scalar) + end return self.output end function MulConstant:updateGradInput(input, gradOutput) - self.gradInput:resizeAs(gradOutput) - self.gradInput:copy(gradOutput) - self.gradInput:mul(self.constant_scalar) + if self.inplace then + gradOutput:mul(self.constant_scalar) + self.gradInput = gradOutput + -- restore previous input value + input:div(self.constant_scalar) + else + self.gradInput:resizeAs(gradOutput) + self.gradInput:copy(gradOutput) + self.gradInput:mul(self.constant_scalar) + end return self.gradInput end diff --git a/doc/transfer.md b/doc/transfer.md index ce7b874..c03017d 100755 --- a/doc/transfer.md +++ b/doc/transfer.md @@ -272,7 +272,19 @@ Note that weight decay should not be used on it. For reference see http://arxiv. Adds a (non-learnable) scalar constant. This module is sometimes useful for debuggging purposes: `f(x)` = `x + k`, where `k` is a scalar. +Can optionally do its operation in-place without using extra state memory: +```lua +m=nn.AddConstant(k,true) -- true = in-place, false = keeping separate state. +``` +In-place mode restores the original input value after the backward pass, allowing its use after other in-place modules, like [MulConstant](#nn.MulConstant). + <a name="nn.MulConstant"/> ## MulConstant ## Multiplies input tensor by a (non-learnable) scalar constant. This module is sometimes useful for debuggging purposes: `f(x)` = `k * x`, where `k` is a scalar. + +Can optionally do its operation in-place without using extra state memory: +```lua +m=nn.MulConstant(k,true) -- true = in-place, false = keeping separate state. +``` +In-place mode restores the original input value after the backward pass, allowing its use after other in-place modules, like [AddConstant](#nn.AddConstant). 
@@ -2528,6 +2528,43 @@ function nntest.AddConstant() -- Test BPROP local err = jac.testJacobian(mod, input) mytester:assertlt(err, precision, 'bprop error ') + + -- inplace comparisons + local ini = math.random(3,5) + local inj = math.random(3,5) + local ink = math.random(3,5) + local constant = torch.uniform()*math.random(1,10) + + local input1 = torch.rand(ink, inj, ini) + local input2 = input1:clone() + + local module1 = nn.AddConstant(constant,true) + local module2 = nn.AddConstant(constant) + + local gradOutput1 = torch.rand(ink, inj, ini) + local gradOutput2 = gradOutput1:clone() + + local out1 = module1:forward(input1) + local out2 = module2:forward(input2) + + mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) .. + ' - in-place forward err ') + + local gradInput1 = module1:backward(input1, gradOutput1) + local gradInput2 = module2:backward(input2, gradOutput2) + + mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), + torch.typename(module1) .. ' - in-place backward err ') + + local input1 = torch.rand(ink, inj, ini) + local input2 = input1:clone() + + module1:forward(input1) + module1:backward(module1.output,torch.rand(input1:size())) + + local err = (input1-input2):abs():max() + mytester:asserteq(err, 0, torch.typename(module1) .. 
+ ' - inplace input change err ') end function nntest.MulConstant() @@ -2548,6 +2585,43 @@ function nntest.MulConstant() -- Test BPROP local err = jac.testJacobian(mod, input) mytester:assertlt(err, precision, 'bprop error ') + + -- inplace comparisons + local ini = math.random(3,5) + local inj = math.random(3,5) + local ink = math.random(3,5) + local constant = torch.uniform()*math.random(1,10) + + local input1 = torch.rand(ink, inj, ini) + local input2 = input1:clone() + + local module1 = nn.MulConstant(constant,true) + local module2 = nn.MulConstant(constant) + + local gradOutput1 = torch.rand(ink, inj, ini) + local gradOutput2 = gradOutput1:clone() + + local out1 = module1:forward(input1) + local out2 = module2:forward(input2) + + mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) .. + ' - in-place forward err ') + + local gradInput1 = module1:backward(input1, gradOutput1) + local gradInput2 = module2:backward(input2, gradOutput2) + + mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), + torch.typename(module1) .. ' - in-place backward err ') + + local input1 = torch.rand(ink, inj, ini) + local input2 = input1:clone() + + module1:forward(input1) + module1:backward(module1.output,torch.rand(input1:size())) + + local err = (input1-input2):abs():max() + mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) .. + ' - inplace input change err ') end function nntest.Copy() |