Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2012-04-01 18:45:20 +0400
committerClement Farabet <clement.farabet@gmail.com>2012-04-01 18:45:20 +0400
commit29cfb869afa0c3fe5545d92a1b978ddb754b2625 (patch)
tree9ca39051c499b09a6139444ffd294035696179c4
parent6051e93dabb85f43df95cb3315776b6023dce4cc (diff)
Added Log module.
-rw-r--r--Log.lua20
-rw-r--r--init.lua1
-rw-r--r--test/test.lua15
3 files changed, 36 insertions, 0 deletions
diff --git a/Log.lua b/Log.lua
new file mode 100644
index 0000000..fec4664
--- /dev/null
+++ b/Log.lua
@@ -0,0 +1,20 @@
+local Log, parent = torch.class('nn.Log', 'nn.Module')
+
+function Log:__init(inputSize)
+ parent.__init(self)
+end
+
+function Log:updateOutput(input)
+ self.output:resizeAs(input)
+ self.output:copy(input)
+ self.output:log()
+ return self.output
+end
+
+function Log:updateGradInput(input, gradOutput)
+ self.gradInput:resizeAs(input)
+ self.gradInput:fill(1)
+ self.gradInput:cdiv(input)
+ self.gradInput:cmul(gradOutput)
+ return self.gradInput
+end
diff --git a/init.lua b/init.lua
index c6e7df0..d53a803 100644
--- a/init.lua
+++ b/init.lua
@@ -35,6 +35,7 @@ torch.include('nn', 'CosineDistance.lua')
torch.include('nn', 'DotProduct.lua')
torch.include('nn', 'Exp.lua')
+torch.include('nn', 'Log.lua')
torch.include('nn', 'HardTanh.lua')
torch.include('nn', 'LogSigmoid.lua')
torch.include('nn', 'LogSoftMax.lua')
diff --git a/test/test.lua b/test/test.lua
index b1eded7..d67165f 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -74,6 +74,21 @@ function nntest.Exp()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
+function nntest.Log()
+ local ini = math.random(10,20)
+ local inj = math.random(10,20)
+ local ink = math.random(10,20)
+ local input = torch.Tensor(ini,inj,ink):zero()
+ local module = nn.Log()
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+end
+
function nntest.HardTanh()
local ini = math.random(5,10)
local inj = math.random(5,10)