Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--AddConstant.lua20
-rw-r--r--MulConstant.lua21
-rwxr-xr-x[-rw-r--r--]doc/transfer.md10
-rw-r--r--init.lua2
-rw-r--r--test/test.lua40
5 files changed, 93 insertions, 0 deletions
diff --git a/AddConstant.lua b/AddConstant.lua
new file mode 100644
index 0000000..bcf33ed
--- /dev/null
+++ b/AddConstant.lua
@@ -0,0 +1,20 @@
+local AddConstant, parent = torch.class('nn.AddConstant', 'nn.Module')
+
+function AddConstant:__init(constant_scalar)
+ parent.__init(self)
+ assert(type(constant_scalar) == 'number', 'input is not scalar!')
+ self.constant_scalar = constant_scalar
+end
+
+function AddConstant:updateOutput(input)
+ self.output:resizeAs(input)
+ self.output:copy(input)
+ self.output:add(self.constant_scalar)
+ return self.output
+end
+
+function AddConstant:updateGradInput(input, gradOutput)
+ self.gradInput:resizeAs(gradOutput)
+ self.gradInput:copy(gradOutput)
+ return self.gradInput
+end
diff --git a/MulConstant.lua b/MulConstant.lua
new file mode 100644
index 0000000..982ab41
--- /dev/null
+++ b/MulConstant.lua
@@ -0,0 +1,21 @@
+local MulConstant, parent = torch.class('nn.MulConstant', 'nn.Module')
+
+function MulConstant:__init(constant_scalar)
+ parent.__init(self)
+ assert(type(constant_scalar) == 'number', 'input is not scalar!')
+ self.constant_scalar = constant_scalar
+end
+
+function MulConstant:updateOutput(input)
+ self.output:resizeAs(input)
+ self.output:copy(input)
+ self.output:mul(self.constant_scalar)
+ return self.output
+end
+
+function MulConstant:updateGradInput(input, gradOutput)
+ self.gradInput:resizeAs(gradOutput)
+ self.gradInput:copy(gradOutput)
+ self.gradInput:mul(self.constant_scalar)
+ return self.gradInput
+end
diff --git a/doc/transfer.md b/doc/transfer.md
index 6eb641b..9c9b56d 100644..100755
--- a/doc/transfer.md
+++ b/doc/transfer.md
@@ -231,3 +231,13 @@ gnuplot.grid(true)
```
![](image/tanh.png)
+<a name="nn.AddConstant"/>
+## AddConstant ##
+
+Adds a (non-learnable) scalar constant. This module is sometimes useful for debuggging purposes: `f(x)` = `x + k`, where `k` is a scalar.
+
+<a name="nn.MullConstant"/>
+## MulConstant ##
+
+Multiplies input tensor by a (non-learnable) scalar constant. This module is sometimes useful for debuggging purposes: `f(x)` = `k * x`, where `k` is a scalar.
+
diff --git a/init.lua b/init.lua
index 1fba70a..db1ab47 100644
--- a/init.lua
+++ b/init.lua
@@ -22,7 +22,9 @@ include('Mean.lua')
include('Sum.lua')
include('CMul.lua')
include('Mul.lua')
+include('MulConstant.lua')
include('Add.lua')
+include('AddConstant.lua')
include('CAddTable.lua')
include('CDivTable.lua')
diff --git a/test/test.lua b/test/test.lua
index be17fd7..775dded 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1789,6 +1789,46 @@ function nntest.LookupTable()
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
+
+function nntest.AddConstant()
+ local nbatch = torch.random(3, 5)
+ local f = torch.random(3, 5)
+ local h = torch.random(10,20)
+ local w = torch.random(10,20)
+ local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10]
+
+ local constant = torch.randn(1):squeeze()
+ local mod = nn.AddConstant(constant)
+
+ -- Test FPROP
+ local output = mod:forward(input)
+ local delta = output - input
+ mytester:assertlt(delta:add(-constant):abs():max(), precision, 'fprop error')
+
+ -- Test BPROP
+ local err = jac.testJacobian(mod, input)
+ mytester:assertlt(err, precision, 'bprop error ')
+end
+
+function nntest.MulConstant()
+ local nbatch = torch.random(3, 5)
+ local f = torch.random(3, 5)
+ local h = torch.random(10,20)
+ local w = torch.random(10,20)
+ local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10]
+
+ local constant = torch.randn(1):squeeze()
+ local mod = nn.MulConstant(constant)
+
+ -- Test FPROP
+ local output = mod:forward(input)
+ local scale = output:clone():cdiv(input)
+ mytester:assertlt(scale:add(-constant):abs():max(), precision, 'fprop error')
+
+ -- Test BPROP
+ local err = jac.testJacobian(mod, input)
+ mytester:assertlt(err, precision, 'bprop error ')
+end
mytester:add(nntest)