diff options
author | Alykhan Tejani <alykhan.tejani@gmail.com> | 2017-01-01 20:37:35 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-01-01 20:37:35 +0300 |
commit | 422374f615e596e4d4418a7d07e49bde49668a27 (patch) | |
tree | c773142206916871b43082a68f52b94c8d7a591d | |
parent | 8251438690b9b9d90efe8ecef3c4a8cbe3f13653 (diff) |
added CReLU transfer function + tests (#1075)
-rw-r--r-- | CReLU.lua | 57 | ||||
-rw-r--r-- | doc/transfer.md | 36 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | test.lua | 73 |
4 files changed, 167 insertions, 0 deletions
diff --git a/CReLU.lua b/CReLU.lua new file mode 100644 index 0000000..8da6e79 --- /dev/null +++ b/CReLU.lua @@ -0,0 +1,57 @@ +local CReLU, parent = torch.class('nn.CReLU', 'nn.Sequential') + +-- Implements the CReLU activation function as described by +-- W. Shang et al. in "Understanding and Improving Convolutional Neural Networks +-- via Concatenated Rectified Linear Units" (https://arxiv.org/abs/1603.05201) +function CReLU:__init(nInputDims, inplace) + parent.__init(self) + self.nInputDims = nInputDims + self.inplace = inplace or false + + local concatTable = nn.ConcatTable() + concatTable:add(nn.Identity()) + concatTable:add(nn.MulConstant(-1)) + self:add(concatTable) + self:add(nn.JoinTable(2)) + self:add(nn.ReLU(self.inplace)) +end + +function CReLU:updateOutput(input) + local input_ + local batched = input:dim() == (self.nInputDims + 1) + if not batched then + input_ = input:view(1, -1) + else + input_ = input:view(input:size(1), -1) + end + parent.updateOutput(self, input_) + local osize = input:size() + if not batched then + osize[1] = osize[1] * 2 + else + osize[2] = osize[2] * 2 + end + self.output:resize(osize) + return self.output +end + +function CReLU:backward(input, gradOutput) + return self:updateGradInput(input, gradOutput) +end + +function CReLU:updateGradInput(input, gradOutput) + local batched = input:dim() == (self.nInputDims + 1) + if not batched then + parent.updateGradInput(self, input:view(1, -1), gradOutput:view(1, -1)) + else + parent.updateGradInput(self, input:view(input:size(1), -1), + gradOutput:view(input:size(1), -1)) + end + + self.gradInput:resizeAs(input) + return self.gradInput +end + +function CReLU:__tostring__() + return "CReLU()" +end diff --git a/doc/transfer.md b/doc/transfer.md index 814aedf..964030a 100644 --- a/doc/transfer.md +++ b/doc/transfer.md @@ -465,6 +465,42 @@ gnuplot.grid(true) ![](image/rrelu.png) +<a name="nn.CReLU"></a> +## CReLU ## +``` +f = nn.CReLU(nInputDims, [inplace]) +``` + +Applies the Concatenated Rectified Linear Unit (`CReLU`) 
function to the input Tensor, outputting a `Tensor` with twice as many channels. The parameter `nInputDims` is the number of non-batched dimensions; inputs with more dimensions than that value will be considered batches. +`CReLU` is defined as: + +``` +f(x) = concat(max(0, x), max(0, -x)) +``` + +i.e. `CReLU` applies `ReLU` to the input, `x`, and the negated input, `-x`, and concatenates the output along the 1st non-batched dimension. + +``` +crelu = nn.CReLU(3) +input = torch.Tensor(2, 3, 20, 20):uniform(-1, 1) +output = crelu:forward(input) +output:size() +2 +6 +20 +20 +[torch.LongStorage of size 4] + +input = torch.Tensor(3, 20, 20):uniform(-1, 1) +output = crelu:forward(input) +output:size() +6 +20 +20 +[torch.LongStorage of size 3] +``` + +For reference see [Understanding and Improving Convolutional Neural Networks via Concatenated Rectified Linear Units](https://arxiv.org/abs/1603.05201). <a name="nn.ELU"></a> ## ELU ## @@ -89,6 +89,7 @@ require('nn.Threshold') require('nn.ReLU') require('nn.ReLU6') require('nn.PReLU') +require('nn.CReLU') require('nn.LeakyReLU') require('nn.SpatialSoftMax') require('nn.SpatialLogSoftMax') @@ -605,6 +605,79 @@ function nntest.GatedLinearUnit() mytester:assert(err < precision, 'Gated Linear gradient with other layers') end +function nntest.CReLU() + local function _verifyCReLU(featureMaps, concatenatedFeatureMaps) + local rectifiedFeatureMaps = nn.ReLU():forward(featureMaps) + local rectifiedNegFeatureMaps = nn.ReLU():forward(-featureMaps) + + mytester:asserteq(concatenatedFeatureMaps:size(1), featureMaps:size(1) * 2, + "CReLU should double the number of feature maps") + + for i = 1, rectifiedFeatureMaps:size(1) do + local found = false + for j = 1, concatenatedFeatureMaps:size(1) do + found = found or rectifiedFeatureMaps[i]:equal(concatenatedFeatureMaps[j]) + end + mytester:assert(found, "Original (rectified) feature maps should be in the output of CReLU") + end + + for i = 1, rectifiedNegFeatureMaps:size(1) do + local found = false + for j = 1, 
concatenatedFeatureMaps:size(1) do + found = found or rectifiedNegFeatureMaps[i]:equal(concatenatedFeatureMaps[j]) + end + mytester:assert(found, "The negative of the original (rectified) feature maps should be in the output of CReLU") + end + end + + local model = nn.Sequential() + model:add(nn.SpatialConvolution(1, 3, 3, 3, 1, 1, 1, 1)) + + for _, inplace in pairs({true, false}) do + --batched + local crelu = nn.CReLU(3, inplace) + local input = torch.Tensor(2, 1, 20, 20):uniform() + local featureMaps = model:forward(input) + local concatenatedFeatureMaps = crelu:forward(featureMaps) + for i = 1, input:size(1) do + _verifyCReLU(featureMaps[i], concatenatedFeatureMaps[i]) + end + + --non-batched + local input = torch.Tensor(1, 20, 20):uniform() + local featureMaps = model:forward(input) + local concatenatedFeatureMaps = crelu:forward(featureMaps) + _verifyCReLU(featureMaps, concatenatedFeatureMaps) + end + + --test gradients w.r.t input + local jac = nn.Jacobian + + for _, inplace in pairs({true, false}) do + local crelu = nn.CReLU(3, inplace) + --batched + local input = torch.Tensor(2, 3, 20, 20):uniform() + local err = jac.testJacobian(crelu, input) + mytester:assertlt(err, precision, "error computing gradients w.r.t. inputs") + + --I/O + local fwdErr,bkwdErr = jac.testIO(crelu,input) + mytester:asserteq(fwdErr, 0, torch.typename(crelu) .. " - i/o forward err ") + mytester:asserteq(bkwdErr, 0, torch.typename(crelu) .. " - i/o backward err ") + + --non-batched + input = torch.Tensor(3, 20, 20):uniform() + err = jac.testJacobian(crelu,input) + mytester:assertlt(err, precision, "error computing gradients w.r.t. inputs") + + --I/O + local fwdErr,bkwdErr = jac.testIO(crelu,input) + mytester:asserteq(fwdErr, 0, torch.typename(crelu) .. " - i/o forward err ") + mytester:asserteq(bkwdErr, 0, torch.typename(crelu) .. " - i/o backward err ") + end + +end + function nntest.Exp() local ini = math.random(3,5) local inj = math.random(3,5) |