diff options
author | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-06-25 20:14:12 +0400 |
---|---|---|
committer | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-06-25 20:14:12 +0400 |
commit | e406feabe2b1bd7f6b7a15826aff1e925fe713d6 (patch) | |
tree | 8177ddef413e068c95882e00d228feca0900ebaf | |
parent | ea9cc1df751ddb144c08a13aab3add1ab0ce90a1 (diff) |
Added very simple add and mul constant modules.
-rw-r--r-- | AddConstant.lua | 20 | ||||
-rw-r--r-- | MulConstant.lua | 21 | ||||
-rwxr-xr-x[-rw-r--r--] | doc/transfer.md | 10 | ||||
-rw-r--r-- | init.lua | 2 | ||||
-rw-r--r-- | test/test.lua | 40 |
5 files changed, 93 insertions, 0 deletions
diff --git a/AddConstant.lua b/AddConstant.lua new file mode 100644 index 0000000..bcf33ed --- /dev/null +++ b/AddConstant.lua @@ -0,0 +1,20 @@ +local AddConstant, parent = torch.class('nn.AddConstant', 'nn.Module') + +function AddConstant:__init(constant_scalar) + parent.__init(self) + assert(type(constant_scalar) == 'number', 'input is not scalar!') + self.constant_scalar = constant_scalar +end + +function AddConstant:updateOutput(input) + self.output:resizeAs(input) + self.output:copy(input) + self.output:add(self.constant_scalar) + return self.output +end + +function AddConstant:updateGradInput(input, gradOutput) + self.gradInput:resizeAs(gradOutput) + self.gradInput:copy(gradOutput) + return self.gradInput +end diff --git a/MulConstant.lua b/MulConstant.lua new file mode 100644 index 0000000..982ab41 --- /dev/null +++ b/MulConstant.lua @@ -0,0 +1,21 @@ +local MulConstant, parent = torch.class('nn.MulConstant', 'nn.Module') + +function MulConstant:__init(constant_scalar) + parent.__init(self) + assert(type(constant_scalar) == 'number', 'input is not scalar!') + self.constant_scalar = constant_scalar +end + +function MulConstant:updateOutput(input) + self.output:resizeAs(input) + self.output:copy(input) + self.output:mul(self.constant_scalar) + return self.output +end + +function MulConstant:updateGradInput(input, gradOutput) + self.gradInput:resizeAs(gradOutput) + self.gradInput:copy(gradOutput) + self.gradInput:mul(self.constant_scalar) + return self.gradInput +end diff --git a/doc/transfer.md b/doc/transfer.md index 6eb641b..9c9b56d 100644..100755 --- a/doc/transfer.md +++ b/doc/transfer.md @@ -231,3 +231,13 @@ gnuplot.grid(true) ``` ![](image/tanh.png) +<a name="nn.AddConstant"/> +## AddConstant ## + +Adds a (non-learnable) scalar constant. This module is sometimes useful for debugging purposes: `f(x)` = `x + k`, where `k` is a scalar. + +<a name="nn.MulConstant"/> +## MulConstant ## + +Multiplies input tensor by a (non-learnable) scalar constant. 
This module is sometimes useful for debugging purposes: `f(x)` = `k * x`, where `k` is a scalar. + @@ -22,7 +22,9 @@ include('Mean.lua') include('Sum.lua') include('CMul.lua') include('Mul.lua') +include('MulConstant.lua') include('Add.lua') +include('AddConstant.lua') include('CAddTable.lua') include('CDivTable.lua') diff --git a/test/test.lua b/test/test.lua index be17fd7..775dded 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1789,6 +1789,46 @@ function nntest.LookupTable() mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end + +function nntest.AddConstant() + local nbatch = torch.random(3, 5) + local f = torch.random(3, 5) + local h = torch.random(10,20) + local w = torch.random(10,20) + local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10] + + local constant = torch.randn(1):squeeze() + local mod = nn.AddConstant(constant) + + -- Test FPROP + local output = mod:forward(input) + local delta = output - input + mytester:assertlt(delta:add(-constant):abs():max(), precision, 'fprop error') + + -- Test BPROP + local err = jac.testJacobian(mod, input) + mytester:assertlt(err, precision, 'bprop error ') +end + +function nntest.MulConstant() + local nbatch = torch.random(3, 5) + local f = torch.random(3, 5) + local h = torch.random(10,20) + local w = torch.random(10,20) + local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10] + + local constant = torch.randn(1):squeeze() + local mod = nn.MulConstant(constant) + + -- Test FPROP + local output = mod:forward(input) + local scale = output:clone():cdiv(input) + mytester:assertlt(scale:add(-constant):abs():max(), precision, 'fprop error') + + -- Test BPROP + local err = jac.testJacobian(mod, input) + mytester:assertlt(err, precision, 'bprop error ') +end mytester:add(nntest) |