Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominik Grewe <dominikg@google.com>2015-01-23 18:05:03 +0300
committerDominik Grewe <dominikg@google.com>2015-01-27 20:43:38 +0300
commita61e94efee98ca9c2861567a54f3adcb222dbf42 (patch)
tree6e7bd4e44b051f6d1481d62793b6d158e80dcbfd /test.lua
parent8e3c4d93378249a1e730c3226cd4a9902cea8a79 (diff)
Matrix matrix multiplication layer.
Multiplies two matrices (or two batches of matrices). This is a pure Lua implementation taking advantage of the new torch.bmm function for minibatch inputs.
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua171
1 files changed, 171 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 5018de6..6db38d3 100644
--- a/test.lua
+++ b/test.lua
@@ -2843,6 +2843,177 @@ function nntest.DepthConcat()
mytester:assertTensorEq(gradInput, gradInputConcat, 0.000001, "Error in SpatialConcat:updateGradInput")
end
+local function createMatrixInputSizes()
+ local M = torch.random(10, 20)
+ local N = torch.random(10, 20)
+ local P = torch.random(10, 20)
+ return M, N, P
+end
+
+function nntest.MM()
+ local mm = nn.MM(false, true)
+ local M, N, P = createMatrixInputSizes()
+ local A = torch.randn(M, N)
+ local B = torch.randn(P, N)
+
+ -- Test forward pass.
+ local output = mm:forward({A, B})
+ mytester:assertTableEq(output:size():totable(), {M, P},
+ 'Output has wrong dimensionality')
+ mytester:assertTensorEq(output, A * B:t(), 1e-10,
+ 'Wrong output')
+
+ -- Test backward pass.
+ local gradOutput = torch.randn(M, P)
+ local gradInput = mm:backward({A, B}, gradOutput)
+ mytester:assert(#gradInput == 2, 'gradInput must be table of size 2')
+ local gradA, gradB = unpack(gradInput)
+ mytester:assertTableEq(gradA:size():totable(), A:size():totable(),
+ 'Gradient for input A has wrong size')
+ mytester:assertTableEq(gradB:size():totable(), B:size():totable(),
+ 'Gradient for input B has wrong size')
+ mytester:assertTensorEq(gradA, gradOutput * B, 1e-10,
+ 'Wrong gradient for input A')
+ mytester:assertTensorEq(gradB, gradOutput:t() * A, 1e-10,
+ 'Wrong gradient for input B')
+end
+
+function nntest.BatchMMNoTranspose()
+ local mm = nn.MM()
+ local M, N, P = createMatrixInputSizes()
+ for bSize = 1, 11, 5 do
+ local A = torch.randn(bSize, M, N)
+ local B = torch.randn(bSize, N, P)
+
+ -- Test forward pass.
+ local output = mm:forward({A, B})
+ mytester:assertTableEq(output:size():totable(), {bSize, M, P},
+ 'Output has wrong dimensionality')
+ for i = 1, bSize do
+ mytester:assertTensorEq(output[i], A[i] * B[i], 1e-10,
+ 'Output wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ end
+
+ -- Test backward pass.
+ local gradOutput = torch.randn(bSize, M, P)
+ local gradInput = mm:backward({A, B}, gradOutput)
+ mytester:assert(#gradInput == 2, 'gradInput must be table of size 2')
+ local gradA, gradB = unpack(gradInput)
+ mytester:assertTableEq(gradA:size():totable(), A:size():totable(),
+ 'Gradient for input A has wrong size')
+ mytester:assertTableEq(gradB:size():totable(), B:size():totable(),
+ 'Gradient for input B has wrong size')
+ for i = 1, bSize do
+ mytester:assertTensorEq(gradA[i], gradOutput[i] * B[i]:t(), 1e-10,
+ 'Gradient for input A wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ mytester:assertTensorEq(gradB[i], A[i]:t() * gradOutput[i], 1e-10,
+ 'Gradient for input B wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ end
+ end
+end
+
+function nntest.BatchMMTransposeA()
+ local mm = nn.MM(true, false)
+ local M, N, P = createMatrixInputSizes()
+ for bSize = 1, 11, 5 do
+ local A = torch.randn(bSize, N, M)
+ local B = torch.randn(bSize, N, P)
+
+ -- Test forward pass.
+ local output = mm:forward({A, B})
+ mytester:assertTableEq(output:size():totable(), {bSize, M, P},
+ 'Output has wrong dimensionality')
+ for i = 1, bSize do
+ mytester:assertTensorEq(output[i], A[i]:t() * B[i], 1e-10,
+ 'Output wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ end
+
+ -- Test backward pass.
+ local gradOutput = torch.randn(bSize, M, P)
+ local gradInput = mm:backward({A, B}, gradOutput)
+ mytester:assert(#gradInput == 2, 'gradInput must be table of size 2')
+ local gradA, gradB = unpack(gradInput)
+ mytester:assertTableEq(gradA:size():totable(), A:size():totable(),
+ 'Gradient for input A has wrong size')
+ mytester:assertTableEq(gradB:size():totable(), B:size():totable(),
+ 'Gradient for input B has wrong size')
+ for i = 1, bSize do
+ mytester:assertTensorEq(gradA[i], B[i] * gradOutput[i]:t(), 1e-10,
+ 'Gradient for input A wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ mytester:assertTensorEq(gradB[i], A[i] * gradOutput[i], 1e-10,
+ 'Gradient for input B wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ end
+ end
+end
+
+function nntest.BatchMMTransposeB()
+ local mm = nn.MM(false, true)
+ local M, N, P = createMatrixInputSizes()
+ for bSize = 1, 11, 5 do
+ local A = torch.randn(bSize, M, N)
+ local B = torch.randn(bSize, P, N)
+
+ -- Test forward pass.
+ local output = mm:forward({A, B})
+ mytester:assertTableEq(output:size():totable(), {bSize, M, P},
+ 'Output has wrong dimensionality')
+ for i = 1, bSize do
+ mytester:assertTensorEq(output[i], A[i] * B[i]:t(), 1e-10,
+ 'Output wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ end
+
+ -- Test backward pass.
+ local gradOutput = torch.randn(bSize, M, P)
+ local gradInput = mm:backward({A, B}, gradOutput)
+ mytester:assert(#gradInput == 2, 'gradInput must be table of size 2')
+ local gradA, gradB = unpack(gradInput)
+ mytester:assertTableEq(gradA:size():totable(), A:size():totable(),
+ 'Gradient for input A has wrong size')
+ mytester:assertTableEq(gradB:size():totable(), B:size():totable(),
+ 'Gradient for input B has wrong size')
+ for i = 1, bSize do
+ mytester:assertTensorEq(gradA[i], gradOutput[i] * B[i], 1e-10,
+ 'Gradient for input A wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ mytester:assertTensorEq(gradB[i], gradOutput[i]:t() * A[i], 1e-10,
+ 'Gradient for input B wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ end
+ end
+end
+
+function nntest.BatchMMTransposeBoth()
+ local mm = nn.MM(true, true)
+ local M, N, P = createMatrixInputSizes()
+ for bSize = 1, 11, 5 do
+ local A = torch.randn(bSize, N, M)
+ local B = torch.randn(bSize, P, N)
+
+ -- Test forward pass.
+ local output = mm:forward({A, B})
+ mytester:assertTableEq(output:size():totable(), {bSize, M, P},
+ 'Output has wrong dimensionality')
+ for i = 1, bSize do
+ mytester:assertTensorEq(output[i], A[i]:t() * B[i]:t(), 1e-10,
+ 'Output wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ end
+
+ -- Test backward pass.
+ local gradOutput = torch.randn(bSize, M, P)
+ local gradInput = mm:backward({A, B}, gradOutput)
+ mytester:assert(#gradInput == 2, 'gradInput must be table of size 2')
+ local gradA, gradB = unpack(gradInput)
+ mytester:assertTableEq(gradA:size():totable(), A:size():totable(),
+ 'Gradient for input A has wrong size')
+ mytester:assertTableEq(gradB:size():totable(), B:size():totable(),
+ 'Gradient for input B has wrong size')
+ for i = 1, bSize do
+ mytester:assertTensorEq(gradA[i], B[i]:t() * gradOutput[i]:t(), 1e-10,
+ 'Gradient for input A wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ mytester:assertTensorEq(gradB[i], gradOutput[i]:t() * A[i]:t(), 1e-10,
+ 'Gradient for input B wrong for bSize = ' .. bSize .. ' and i = ' .. i)
+ end
+ end
+end
+
mytester:add(nntest)
if not nn then