
github.com/torch/nn.git
author     karpathy <andrej.karpathy@gmail.com>   2015-05-06 00:55:07 +0300
committer  Soumith Chintala <soumith@gmail.com>   2015-05-13 07:50:22 +0300
commit     905ea8c1a4033af2af0f90e92e16597f442f3512
tree       206d026415810a399d9db96c75a8625a46b6f11c
parent     28b0d2a80f302e39876002a3978bb4f70c4ee171
Adding Batch L2 Normalization Layer that makes all rows of the input Tensor have unit L2 norm
-rw-r--r--   L2Normalize.lua   40
-rw-r--r--   init.lua           1
-rw-r--r--   test.lua          23
3 files changed, 64 insertions, 0 deletions
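
In symbols (editor's note, not part of the commit): given an [n x d] input, each row x is mapped to

    y = x / \|x\|_2

so every row of the output has unit L2 norm. Rows are normalized independently; a row of all zeros would divide by zero, which the patch does not guard against.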
diff --git a/L2Normalize.lua b/L2Normalize.lua
new file mode 100644
index 0000000..f1dfd0e
--- /dev/null
+++ b/L2Normalize.lua
@@ -0,0 +1,40 @@
+
+--[[
+ This layer expects an [n x d] Tensor and normalizes each
+ row to have unit L2 norm.
+]]--
+local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')
+function L2Normalize:__init()
+ parent.__init(self)
+end
+function L2Normalize:updateOutput(input)
+ assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
+ .. input:dim() .. 'D tensor instead')
+ self.output:resizeAs(input)
+ self.buffer = self.buffer or input.new()
+ self.normSquared = self.normSquared or input.new()
+ self.normSquared:sum(self.buffer:cmul(input, input), 2)
+ self.buffer:sqrt(self.normSquared)
+ self.output:copy(input):cdiv(self.buffer:expandAs(input))
+ return self.output
+end
+
+function L2Normalize:updateGradInput(input, gradOutput)
+ assert(input:dim() == 2, 'only mini-batch supported')
+ assert(gradOutput:dim() == 2, 'only mini-batch supported')
+ local n = input:size(1) -- batch size
+ local d = input:size(2) -- dimensionality of vectors
+ -- compute diagonal term
+ self.eye = self.eye or torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
+ self.diag = self.diag or self.eye.new()
+ self.diag:cmul(self.eye, self.normSquared:view(n,1,1):expand(n,d,d))
+ -- compute cross term
+ local b1 = input:view(n,d,1)
+ local b2 = input:view(n,1,d)
+ self.diag:add(-torch.bmm(b1,b2))
+ -- compute the local gradient of the L2 transformation
+ self.diag:cdiv(torch.pow(self.buffer,3):view(n,1,1):expand(n,d,d))
+ -- chain the gradient
+ self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
+ return self.gradInput
+end
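
Editor's note on the backward pass (a sketch of the math the patch implements, not part of the diff): for one row x with output y = x / \|x\|, the per-example Jacobian materialized above is

    \frac{\partial y_j}{\partial x_i} = \frac{\delta_{ij} \|x\|^2 - x_i x_j}{\|x\|^3}

i.e. (\|x\|^2 I - x x^T) / \|x\|^3. The cmul with self.eye builds the diagonal term \|x\|^2 I, the bmm of b1 and b2 supplies the cross term x x^T that is subtracted, and the cdiv divides by \|x\|^3 before the final bmm chains the Jacobian with gradOutput. Note that self.eye is cached with the batch size n of the first call, so the code as committed assumes a fixed batch size across calls.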
diff --git a/init.lua b/init.lua
index b1d36db..520b66e 100644
--- a/init.lua
+++ b/init.lua
@@ -46,6 +46,7 @@ include('WeightedEuclidean.lua')
include('PairwiseDistance.lua')
include('CosineDistance.lua')
include('DotProduct.lua')
+include('L2Normalize.lua')
include('Exp.lua')
include('Log.lua')
diff --git a/test.lua b/test.lua
index 9414a66..959c369 100644
--- a/test.lua
+++ b/test.lua
@@ -3554,6 +3554,29 @@ function nntest.Padding()
mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error")
end
+function nntest.L2Normalize()
+ local ini = math.random(6,8)
+ local inj = math.random(3,5)
+ local input = torch.randn(ini, inj)
+
+ local module = nn.L2Normalize()
+
+ -- test correctness of output
+ local output = module:forward(input)
+ local norms = torch.norm(output, 2, 2)
+ local desired_norms = torch.ones(ini, 1)
+ mytester:assertTensorEq(norms, desired_norms, 0.000001, 'L2Normalize forward err')
+
+ -- test the Jacobian
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error on state ')
+
+ -- test IO correctness
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+end
+
mytester:add(nntest)
if not nn then
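
For context, a minimal usage sketch (editor's addition, assuming a Torch7 install with this patch applied; not part of the diff):

    require 'nn'

    -- a batch of 4 vectors, each 5-dimensional
    local x = torch.randn(4, 5)

    -- forward: normalize every row to unit L2 norm
    local module = nn.L2Normalize()
    local y = module:forward(x)
    print(y:norm(2, 2))  -- a 4x1 tensor of (approximately) ones

    -- backward: gradient of some loss w.r.t. the input
    local gradOutput = torch.randn(4, 5)
    local gradInput = module:backward(x, gradOutput)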