diff options
author | Dominik Grewe <dominikg@google.com> | 2015-01-23 18:05:03 +0300 |
---|---|---|
committer | Dominik Grewe <dominikg@google.com> | 2015-01-27 20:43:38 +0300 |
commit | a61e94efee98ca9c2861567a54f3adcb222dbf42 (patch) | |
tree | 6e7bd4e44b051f6d1481d62793b6d158e80dcbfd /test.lua | |
parent | 8e3c4d93378249a1e730c3226cd4a9902cea8a79 (diff) |
Matrix matrix multiplication layer.
Multiplies two matrices (or two batches of matrices).
This is a pure Lua implementation taking advantage of the new torch.bmm
function for minibatch inputs.
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 171 |
1 files changed, 171 insertions, 0 deletions
@@ -2843,6 +2843,177 @@ function nntest.DepthConcat() mytester:assertTensorEq(gradInput, gradInputConcat, 0.000001, "Error in SpatialConcat:updateGradInput") end +local function createMatrixInputSizes() + local M = torch.random(10, 20) + local N = torch.random(10, 20) + local P = torch.random(10, 20) + return M, N, P +end + +function nntest.MM() + local mm = nn.MM(false, true) + local M, N, P = createMatrixInputSizes() + local A = torch.randn(M, N) + local B = torch.randn(P, N) + + -- Test forward pass. + local output = mm:forward({A, B}) + mytester:assertTableEq(output:size():totable(), {M, P}, + 'Output has wrong dimensionality') + mytester:assertTensorEq(output, A * B:t(), 1e-10, + 'Wrong output') + + -- Test backward pass. + local gradOutput = torch.randn(M, P) + local gradInput = mm:backward({A, B}, gradOutput) + mytester:assert(#gradInput == 2, 'gradInput must be table of size 2') + local gradA, gradB = unpack(gradInput) + mytester:assertTableEq(gradA:size():totable(), A:size():totable(), + 'Gradient for input A has wrong size') + mytester:assertTableEq(gradB:size():totable(), B:size():totable(), + 'Gradient for input B has wrong size') + mytester:assertTensorEq(gradA, gradOutput * B, 1e-10, + 'Wrong gradient for input A') + mytester:assertTensorEq(gradB, gradOutput:t() * A, 1e-10, + 'Wrong gradient for input B') +end + +function nntest.BatchMMNoTranspose() + local mm = nn.MM() + local M, N, P = createMatrixInputSizes() + for bSize = 1, 11, 5 do + local A = torch.randn(bSize, M, N) + local B = torch.randn(bSize, N, P) + + -- Test forward pass. + local output = mm:forward({A, B}) + mytester:assertTableEq(output:size():totable(), {bSize, M, P}, + 'Output has wrong dimensionality') + for i = 1, bSize do + mytester:assertTensorEq(output[i], A[i] * B[i], 1e-10, + 'Output wrong for bSize = ' .. bSize .. ' and i = ' .. i) + end + + -- Test backward pass. + local gradOutput = torch.randn(bSize, M, P) + local gradInput = mm:backward({A, B}, gradOutput) + mytester:assert(#gradInput == 2, 'gradInput must be table of size 2') + local gradA, gradB = unpack(gradInput) + mytester:assertTableEq(gradA:size():totable(), A:size():totable(), + 'Gradient for input A has wrong size') + mytester:assertTableEq(gradB:size():totable(), B:size():totable(), + 'Gradient for input B has wrong size') + for i = 1, bSize do + mytester:assertTensorEq(gradA[i], gradOutput[i] * B[i]:t(), 1e-10, + 'Gradient for input A wrong for bSize = ' .. bSize .. ' and i = ' .. i) + mytester:assertTensorEq(gradB[i], A[i]:t() * gradOutput[i], 1e-10, + 'Gradient for input B wrong for bSize = ' .. bSize .. ' and i = ' .. i) + end + end +end + +function nntest.BatchMMTransposeA() + local mm = nn.MM(true, false) + local M, N, P = createMatrixInputSizes() + for bSize = 1, 11, 5 do + local A = torch.randn(bSize, N, M) + local B = torch.randn(bSize, N, P) + + -- Test forward pass. + local output = mm:forward({A, B}) + mytester:assertTableEq(output:size():totable(), {bSize, M, P}, + 'Output has wrong dimensionality') + for i = 1, bSize do + mytester:assertTensorEq(output[i], A[i]:t() * B[i], 1e-10, + 'Output wrong for bSize = ' .. bSize .. ' and i = ' .. i) + end + + -- Test backward pass. + local gradOutput = torch.randn(bSize, M, P) + local gradInput = mm:backward({A, B}, gradOutput) + mytester:assert(#gradInput == 2, 'gradInput must be table of size 2') + local gradA, gradB = unpack(gradInput) + mytester:assertTableEq(gradA:size():totable(), A:size():totable(), + 'Gradient for input A has wrong size') + mytester:assertTableEq(gradB:size():totable(), B:size():totable(), + 'Gradient for input B has wrong size') + for i = 1, bSize do + mytester:assertTensorEq(gradA[i], B[i] * gradOutput[i]:t(), 1e-10, + 'Gradient for input A wrong for bSize = ' .. bSize .. ' and i = ' .. i) + mytester:assertTensorEq(gradB[i], A[i] * gradOutput[i], 1e-10, + 'Gradient for input B wrong for bSize = ' .. bSize .. ' and i = ' .. i) + end + end +end + +function nntest.BatchMMTransposeB() + local mm = nn.MM(false, true) + local M, N, P = createMatrixInputSizes() + for bSize = 1, 11, 5 do + local A = torch.randn(bSize, M, N) + local B = torch.randn(bSize, P, N) + + -- Test forward pass. + local output = mm:forward({A, B}) + mytester:assertTableEq(output:size():totable(), {bSize, M, P}, + 'Output has wrong dimensionality') + for i = 1, bSize do + mytester:assertTensorEq(output[i], A[i] * B[i]:t(), 1e-10, + 'Output wrong for bSize = ' .. bSize .. ' and i = ' .. i) + end + + -- Test backward pass. + local gradOutput = torch.randn(bSize, M, P) + local gradInput = mm:backward({A, B}, gradOutput) + mytester:assert(#gradInput == 2, 'gradInput must be table of size 2') + local gradA, gradB = unpack(gradInput) + mytester:assertTableEq(gradA:size():totable(), A:size():totable(), + 'Gradient for input A has wrong size') + mytester:assertTableEq(gradB:size():totable(), B:size():totable(), + 'Gradient for input B has wrong size') + for i = 1, bSize do + mytester:assertTensorEq(gradA[i], gradOutput[i] * B[i], 1e-10, + 'Gradient for input A wrong for bSize = ' .. bSize .. ' and i = ' .. i) + mytester:assertTensorEq(gradB[i], gradOutput[i]:t() * A[i], 1e-10, + 'Gradient for input B wrong for bSize = ' .. bSize .. ' and i = ' .. i) + end + end +end + +function nntest.BatchMMTransposeBoth() + local mm = nn.MM(true, true) + local M, N, P = createMatrixInputSizes() + for bSize = 1, 11, 5 do + local A = torch.randn(bSize, N, M) + local B = torch.randn(bSize, P, N) + + -- Test forward pass. + local output = mm:forward({A, B}) + mytester:assertTableEq(output:size():totable(), {bSize, M, P}, + 'Output has wrong dimensionality') + for i = 1, bSize do + mytester:assertTensorEq(output[i], A[i]:t() * B[i]:t(), 1e-10, + 'Output wrong for bSize = ' .. bSize .. ' and i = ' .. i) + end + + -- Test backward pass. + local gradOutput = torch.randn(bSize, M, P) + local gradInput = mm:backward({A, B}, gradOutput) + mytester:assert(#gradInput == 2, 'gradInput must be table of size 2') + local gradA, gradB = unpack(gradInput) + mytester:assertTableEq(gradA:size():totable(), A:size():totable(), + 'Gradient for input A has wrong size') + mytester:assertTableEq(gradB:size():totable(), B:size():totable(), + 'Gradient for input B has wrong size') + for i = 1, bSize do + mytester:assertTensorEq(gradA[i], B[i]:t() * gradOutput[i]:t(), 1e-10, + 'Gradient for input A wrong for bSize = ' .. bSize .. ' and i = ' .. i) + mytester:assertTensorEq(gradB[i], gradOutput[i]:t() * A[i]:t(), 1e-10, + 'Gradient for input B wrong for bSize = ' .. bSize .. ' and i = ' .. i) + end + end +end + mytester:add(nntest) if not nn then |