-rw-r--r--   PartialLinear.lua | 113
-rw-r--r--   doc/simple.md     |  38
-rw-r--r--   init.lua          |   1
-rw-r--r--   test.lua          |  63
4 files changed, 215 insertions, 0 deletions
diff --git a/PartialLinear.lua b/PartialLinear.lua
new file mode 100644
index 0000000..9f9eef3
--- /dev/null
+++ b/PartialLinear.lua
@@ -0,0 +1,113 @@
+local PartialLinear, Module = torch.class('nn.PartialLinear', 'nn.Module')
+
+--[[
+
+PartialLinear is a Linear layer that allows the user to set a collection of
+column indices. When the column indices are set, the layer will behave like a
+Linear layer that only has those columns. Meanwhile, all parameters are
+preserved, so resetting the PartialLinear layer will result in a module that
+behaves just like a regular Linear layer.
+
+This module is useful, for instance, when you want to do forward-backward on
+only a subset of a Linear layer during training but use the full Linear layer
+at test time.
+
+]]--
+
+function PartialLinear:__init(inputsize, outputsize, bias)
+   local bias = ((bias == nil) and true) or bias
+   Module.__init(self)
+
+   -- define the layer as a small network:
+   local pt = nn.ParallelTable()
+   pt:add(nn.Identity()):add(nn.LookupTable(outputsize, inputsize))
+   self.network = nn.Sequential():add(pt):add(nn.MM(false, true))
+   if bias then
+      self.bias     = torch.Tensor(1, outputsize):zero()
+      self.gradBias = torch.Tensor(1, outputsize):zero()
+   end
+
+   -- set partition:
+   self.inputsize  = inputsize
+   self.outputsize = outputsize
+   self.allcolumns = torch.range(1, self.outputsize)
+   self:resetPartition()
+end
+
+function PartialLinear:setPartition(indices)
+   self.partition = indices:type(self.allcolumns:type())
+end
+
+function PartialLinear:resetPartition()
+   self.partition = self.allcolumns
+end
+
+function PartialLinear:parameters()
+   return {self.network:get(1):get(2).weight,     self.bias},
+          {self.network:get(1):get(2).gradWeight, self.gradBias}
+end -- should return only the relevant partition?
+
+function PartialLinear:updateOutput(input)
+   self.output = self.network:forward{input, self.partition}
+   if self.bias then
+      self.output:add(
+         self.bias:index(2, self.partition:long()):expandAs(self.output)
+      )
+      self.addBuffer = self.addBuffer or input.new()
+      if self.addBuffer:nElement() ~= input:size(1) then
+         self.addBuffer:resize(input:size(1)):fill(1)
+      end
+   end
+   return self.output
+end
+
+function PartialLinear:updateGradInput(input, gradOutput)
+   if self.gradInput then
+      self.network:updateGradInput({input, self.partition}, gradOutput)
+      self.gradInput = self.network.gradInput[1]
+   end
+   return self.gradInput
+end
+
+function PartialLinear:accGradParameters(input, gradOutput, scale)
+   local scale = scale or 1
+   self.network:accGradParameters({input, self.partition}, gradOutput, scale)
+   if self.bias then
+      self.buffer = self.buffer or input.new()
+      self.buffer:resize(gradOutput:size(2))
+      self.buffer:mv(gradOutput:t(), self.addBuffer):mul(scale)
+      self.gradBias:indexAdd(
+         2, self.partition:long(), self.buffer:view(1, self.buffer:nElement())
+      )
+   end
+end
+
+function PartialLinear:accUpdateGradParameters(input, gradOutput, lr)
+   local gradWeight = self.network:get(1):get(2).gradWeight
+   local gradBias   = self.gradBias
+   self.network:get(1):get(2).gradWeight = self.network:get(1):get(2).weight
+   self.gradBias = self.bias
+   self:accGradParameters(input, gradOutput, -lr)
+   self.network:get(1):get(2).gradWeight = gradWeight
+   self.gradBias = gradBias
+end
+
+function PartialLinear:zeroGradParameters()
+   self.network:zeroGradParameters()
+   if self.gradBias then self.gradBias:zero() end
+end
+
+function PartialLinear:updateParameters(learningRate)
+   self.network:updateParameters(learningRate)
+   if self.bias then self.bias:add(-learningRate, self.gradBias) end
+end
+
+-- we do not need to accumulate parameters when sharing
+PartialLinear.sharedAccUpdateGradParameters =
+   PartialLinear.accUpdateGradParameters
+
+function PartialLinear:__tostring__()
+   return torch.type(self) ..
+      string.format('(%d -> %d)', self.inputsize, self.outputsize) ..
+      (self.bias == nil and ' without bias' or '')
+end
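The train/test pattern described in the docstring is easiest to see in code. Below is a minimal sketch, not part of the patch; the sizes, the learning rate, and the random column sampling are assumptions for illustration:

```lua
require 'nn'

-- hypothetical sizes: 50 input features, 100 output columns, batch of 8
local module = nn.PartialLinear(50, 100)
local input  = torch.randn(8, 50)

-- training step: forward-backward touches only 10 sampled output columns
local indices = torch.randperm(100):narrow(1, 1, 10)
module:setPartition(indices)
module:zeroGradParameters()
local output = module:forward(input)         -- 8 x 10 output
module:backward(input, torch.randn(8, 10))   -- gradients for those 10 columns
module:updateParameters(0.01)

-- test time: reset the partition to recover the full 8 x 100 output
module:resetPartition()
local fullOutput = module:forward(input)
```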
diff --git a/doc/simple.md b/doc/simple.md
index 9a2f022..20f45f9 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -6,6 +6,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
  * [Linear](#nn.Linear) : a linear transformation ;
  * [SparseLinear](#nn.SparseLinear) : a linear transformation with sparse inputs ;
  * [Bilinear](#nn.Bilinear) : a bilinear transformation with sparse inputs ;
+ * [PartialLinear](#nn.PartialLinear) : a linear transformation with the option of computing only a subset of its outputs ;
  * [Add](#nn.Add) : adds a bias term to the incoming data ;
  * [Mul](#nn.Mul) : multiply a single scalar factor to the incoming data ;
  * [CMul](#nn.CMul) : a component-wise multiplication to the incoming data ;
@@ -144,6 +145,43 @@ Input data for this layer would look as follows:
 module:forward(input)
 ```
 
+<a name="nn.PartialLinear"></a>
+## PartialLinear ##
+
+```lua
+module = nn.PartialLinear(inputSize, outputSize, [bias = true])
+```
+
+PartialLinear is a Linear layer that allows the user to set a collection of
+column indices. When the column indices are set, the layer will behave like a
+Linear layer that only has those columns. Meanwhile, all parameters are
+preserved, so resetting the PartialLinear layer will result in a module that
+behaves just like a regular Linear layer.
+
+This module is useful, for instance, when you want to do forward-backward on
+only a subset of a Linear layer during training but use the full Linear layer
+at test time.
+
+You can create a layer in the following way:
+
+```lua
+ module = nn.PartialLinear(5, 3) -- 5 inputs, 3 outputs
+```
+
+Input data for this layer would look as follows:
+```lua
+ input = torch.randn(128, 5) -- 128 input examples
+ module:forward(input)
+```
+
+One can set the partition of indices to compute using the function `setPartition(indices)`, where `indices` is a tensor containing the indices to compute.
+```lua
+module = nn.PartialLinear(5, 3) -- 5 inputs, 3 outputs
+module:setPartition(torch.Tensor({1, 3})) -- only compute the 1st and 3rd of the 3 output units
+```
+
+One can reset the partition via the `resetPartition()` function, which resets the partition to compute all indices, making its behaviour equivalent to `nn.Linear`.
+
 <a name="nn.Dropout"></a>
 ## Dropout ##
 
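As a quick check of the semantics documented above, the output under a partition should equal the corresponding columns of the full output, since the underlying weight and bias are shared. A small sketch, illustrative only and not part of the patch:

```lua
require 'nn'

local module = nn.PartialLinear(5, 3)
local input  = torch.randn(4, 5)

-- full output over all 3 output units (clone: the output buffer is reused)
module:resetPartition()
local full = module:forward(input):clone()

-- the partitioned output over units {1, 3} should match those columns
module:setPartition(torch.Tensor({1, 3}))
local part = module:forward(input)
print((part - full:index(2, torch.LongTensor({1, 3}))):abs():max())  -- ~0
```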
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -16,6 +16,7 @@
 require('nn.DepthConcat')
 require('nn.Linear')
 require('nn.Bilinear')
+require('nn.PartialLinear')
 require('nn.SparseLinear')
 require('nn.Reshape')
 require('nn.View')
diff --git a/test.lua b/test.lua
--- a/test.lua
+++ b/test.lua
@@ -865,6 +865,69 @@ function nntest.Bilinear()
 end
+
+function nntest.PartialLinear()
+
+   -- settings for experiment:
+   local N = 10
+   local D = 5
+   local K = 15
+
+   -- test forward-backward pass of module:
+   local module = nn.PartialLinear(D, K)
+   for sub_K = 1,K do
+
+      -- get random test case:
+      local input = torch.randn(N, D)
+      local partition = torch.randperm(K):narrow(1, 1, sub_K)
+
+      -- do forward-backward pass:
+      module:setPartition(partition)
+      module:forward(input)
+      mytester:asserteq(module.output:size(1), N)
+      mytester:asserteq(module.output:size(2), sub_K)
+      module:backward(input, torch.ones(N, sub_K))
+      mytester:asserteq(module.gradInput:size(1), input:size(1))
+      mytester:asserteq(module.gradInput:size(2), input:size(2))
+
+      -- do parameter update:
+      local lr = .01
+      module:updateParameters(lr)
+   end
+   module:resetPartition()
+
+   -- compare output with linear layer:
+   local module2 = nn.Linear(D, K)
+   module2.weight:copy(module.network:get(1):get(2).weight)
+   module2.bias:fill(0)
+   if module.bias then module2.bias:copy(module.bias) end
+   local input = torch.randn(N, D)
+   local diff = (module:forward(input) - module2:forward(input)):abs():sum()
+   mytester:assertlt(diff, 1e-7)
+
+   -- gradient checks:
+   local sub_K = 5
+   local partition = torch.randperm(K):narrow(1, 1, sub_K)
+   module:setPartition(partition)
+   local err = sjac.testJacobian(module, input)
+   mytester:assertlt(err, precision, 'error on state ')
+
+   local err = sjac.testJacobianParameters(module, input, module.network:get(1):get(2).weight, module.network:get(1):get(2).gradWeight)
+   mytester:assertlt(err, precision, 'error on weight ')
+
+   local err = sjac.testJacobianParameters(module, input, module.bias, module.gradBias)
+   mytester:assertlt(err, precision, 'error on bias ')
+
+   local err = sjac.testJacobianUpdateParameters(module, input, module.network:get(1):get(2).weight)
+   mytester:assertlt(err, precision, 'error on weight [direct update] ')
+
+   local err = sjac.testJacobianUpdateParameters(module, input, module.bias)
+   mytester:assertlt(err, precision, 'error on bias [direct update] ')
+
+   local ferr, berr = sjac.testIO(module, input)
+   mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
+   mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
+end
 
 function nntest.Euclidean()
    local ini = math.random(5,7)
    local inj = math.random(5,7)
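Assuming the standard harness in test.lua (where `nn.test` dispatches to the `nntest` table via the tester), the new case can be run on its own:

```lua
require 'nn'
nn.test{'PartialLinear'}  -- run only the PartialLinear unit test
```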