diff options
author | Nicholas LĂ©onard <nick@nikopia.org> | 2017-05-25 05:38:36 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-25 05:38:36 +0300 |
commit | df1af9500a45f4deecd0f3f1f5020fe4789248ca (patch) | |
tree | 84d584961d9bf360e7a06572b1120b6515c7d52d | |
parent | c6f1da5e02436ad9aeba97b537681f406116f3f1 (diff) | |
parent | b9ccf3af37e237211b24c99336823c673a08f3ca (diff) |
Merge pull request #1228 from nicholas-leonard/Convert
nn.Convert
-rw-r--r-- | Convert.lua | 245 | ||||
-rwxr-xr-x | doc/simple.md | 76 | ||||
-rwxr-xr-x | init.lua | 1 | ||||
-rwxr-xr-x | test.lua | 39 |
4 files changed, 358 insertions, 3 deletions
diff --git a/Convert.lua b/Convert.lua new file mode 100644 index 0000000..855338d --- /dev/null +++ b/Convert.lua @@ -0,0 +1,245 @@ +------------------------------------------------------------------------ +--[ nn.Convert ]-- +-- Module to convert between different data formats +-- nn.Convert('bchw', 'bf') or nn.Convert('chw', 'f') +-- Automatically converts input to same type as self.output +-- Simplest use is for automatic input type converions : nn.Convert() +------------------------------------------------------------------------ +local _ = require 'moses' +local Convert, parent = torch.class("nn.Convert", "nn.Container") + +function Convert:__init(inputShape, outputShape) + if outputShape and not inputShape then + error"Expecting non-nil arg 1 when arg 2 is provided" + end + inputShape = inputShape or 'b*' + outputShape = outputShape or inputShape + self.inputShape = inputShape:find('b') and inputShape or ('b'..inputShape) + self.outputShape = outputShape:find('b') and outputShape or ('b'..outputShape) + self.inputBatchDim = self.inputShape:find('b') + self.outputBatchDim = self.outputShape:find('b') + if self.inputShape == 'b*' or self.outputShape == 'b*' then + assert(self.inputShape == 'b*' and self.outputShape == 'b*', 'Both or neither shapes must be b*') + self.nInputDim = -1 + self.nOutputDim = -1 + self.transposition = true + else + -- number of dims in batch mode + self.nInputDim = #self.inputShape + self.nOutputDim = #self.outputShape + -- is the outputShape just a transposition of the inputShape? + if self.nInputDim == self.nOutputDim then + self.transposition = true + for i=1,self.nInputDim do + if not self.outputShape:find(self.inputShape:sub(i,i)) then + self.transposition = false + break + end + end + end + end + parent.__init(self) +end + +-- post-initialization +function Convert:buildConverter(input) + if self.transposition then + self.converter = self:transpose(self.outputShape) + else + if (torch.type(self[self.outputShape]) ~= 'function') then + error(string.format("Unrecognized conversion of shape %s to %s", self.inputShape, self.outputShape)) + end + self.converter = self[self.outputShape](self, input) + end + assert(torch.isTensor(self.output), "Expecting Tensor output") + + self.converter:type(torch.type(self.output)) + + self.modules[1] = self.converter +end + +function Convert:updateOutput(input) + assert(torch.isTensor(input), "expecting Tensor") + if not torch.isTypeOf(input, torch.type(self.output)) then + -- handle different input type + self._input = self._input or self.output.new() + self._input:resize(input:size()):copy(input) + input = self._input + end + self.batchMode = true + if input:dim() < self.nInputDim then + -- handle non-batch mode + local inputSize = input:size():totable() + table.insert(inputSize, self.inputBatchDim, 1) + self.__input = self.__input or input.new() + self.__input:set(input):resize(table.unpack(inputSize)) + input = self.__input + self.batchMode = false + end + if not self.converter then + self:buildConverter(input) + end + + self.output = self.converter:updateOutput(input) + + if not self.batchMode then + local outputSize = self.output:size():totable() + table.remove(outputSize, self.outputBatchDim) + self.__output = self.__output or self.output.new() + self.__output:set(self.output):resize(table.unpack(outputSize)) + self.output = self.__output + end + return self.output +end + +function Convert:updateGradInput(input, gradOutput) + local input_ = input + input = self._input or input + if not self.batchMode then + input = self.__input + self.__gradOutput = self.__gradOutput or gradOutput.new() + self.__gradOutput:set(gradOutput):resize(self.converter.output:size()) + gradOutput = self.__gradOutput + end + + local gradInput = self.converter:updateGradInput(input, gradOutput) + + if not self.batchMode then + self.__gradInput = self.__gradInput or gradInput.new() + self.__gradInput:set(gradInput):resize(input_:size()) + gradInput = self.__gradInput + end + if self._input then + self._gradInput = self._gradInput or input.new() + self._gradInput:resize(input:size()):copy(gradInput) + self.gradInput = self._gradInput + else + self.gradInput = gradInput + end + + return self.gradInput +end + +function Convert:accGradParameters(input, gradOutput, scale) + input = self.batchMode and self.__input or self._input or input + gradOutput = self.batchMode and self.__gradOutput or gradOutput + self.converter:accGradParameters(input, gradOutput, scale) +end + +function Convert:accUpdateGradParameters(input, gradOutput, lr) + input = self.batchMode and self.__input or self._input or input + gradOutput = self.batchMode and self.__gradOutput or gradOutput + self.converter:accUpdateGradParameters(input, gradOutput, lr) +end + +-- batch feature +function Convert:bf(input) + local b_pos = self:findAxis('b', self.inputShape) + local dim = #self.inputShape + if self.inputShape == 'bt' then + error"Conversion of shape bt to bf not supported: open an issue on github" + end + -- was b + if dim == 1 then + return nn.Reshape(1) + end + -- was b... + local modula + if b_pos ~= 1 then + modula = nn.Transpose({1, b_pos}) + end + if dim > 2 then + local transpose = modula + local sampleSize = input:select(self:findAxis('b'),1):nElement() + local reshape = nn.Reshape(sampleSize) + if transpose then + modula = nn.Sequential() + modula:add(transpose) + modula:add(reshape) + else + modula = reshape + end + end + return modula or nn.Identity() +end + +-- each example is a scalar; batch is a vector +function Convert:b(input) + local b_pos = self:findAxis('b') + if self.inputShape == 'bt' or self.inputShape == 'tb' then + local t_pos = self:findAxis('t') + -- select first set of classes + return nn.Select(t_pos, 1) + elseif self.inputShape == 'bf' or self.inputShape == 'fb' then + -- this wont work as expected with size(f) > 1 + local f_pos = self:findAxis('f') + if input:size(f_pos) > 1 then + error("Cannot convert shape "..self.inputShape.." to b when feature > 1") + end + return nn.Select(f_pos, 1) + else + error("Cannot convert shape "..self.inputShape.." to shape b") + end +end + +-- returns the current shape of the data +function Convert:default() + return nn.Identity() +end + +-- multi-class (batch target) +function Convert:bt() + local b_pos = self:findAxis('b') + local modula + if self.inputShape == 'b' then + modula = nn.Reshape(1) + else + error("cannot convert shape '"..self.inputShape.."' to bt") + end + return modula +end + +-- a generic function for transposing shape axes +function Convert:transpose(newShape) + if newShape == self.inputShape then + return nn.Identity() + end + local inputShape = {} + for i=1,#self.inputShape do + table.insert(inputShape, self.inputShape:sub(i,i)) + end + local transpositions = {} + for i=1,#newShape do + local j = _.indexOf(inputShape, newShape:sub(i,i)) + if i ~= j then + local char = inputShape[i] + inputShape[i] = inputShape[j] + inputShape[j] = char + table.insert(transpositions, {j, i}) + end + end + return nn.Transpose(table.unpack(transpositions)) +end + +function Convert:findAxis(axis_char, shape, silent) + shape = shape or self.inputShape + local axis_pos = shape:find(axis_char) + if (not silent) and (not axis_pos) then + error("Provided shape '"..shape.."' has no axis '"..axis_char.."'", 2) + end + return axis_pos +end + +function Convert:clearState() + self._input = nil + self._gradInput = nil + self.__input = nil + self.__output = nil + self.__gradInput = nil + self.__gradOutput = nil +end + +function Convert:type(type) + self:clearState() + return parent.type(self, type) +end diff --git a/doc/simple.md b/doc/simple.md index 3d08167..849d9b5 100755 --- a/doc/simple.md +++ b/doc/simple.md @@ -62,7 +62,8 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi * [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding; * [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging); * [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`; - * [Collapse](#nn.Collapse) : just like `nn.View(-1)`. + * [Collapse](#nn.Collapse) : just like `nn.View(-1)`; + * [Convert](#nn.Convert) : convert between different tensor types or shapes; <a name="nn.Linear"></a> ## Linear ## @@ -1803,4 +1804,75 @@ view:setNumInputDim(nInputDim) It collapses all non-batch dimensions. This is useful for converting a spatial feature map to the single dimension required by a dense -hidden layer like Linear.
\ No newline at end of file +hidden layer like Linear. + +<a name='nn.Convert'></a> +## Convert ## + +```lua +module = nn.Convert([inputShape, outputShape]) +``` +Module to convert between different data formats. +For example, we can flatten images by using : +```lua +module = nn.Convert('bchw', 'bf') +``` +or equivalently +```lua +module = nn.Convert('chw', 'f') +``` +Lets try it with an input: +```lua +print(module:forward(torch.randn(3,2,3,1))) + 0.5692 -0.0190 0.5243 0.7530 0.4230 1.2483 +-0.9142 0.6013 0.5608 -1.0417 -1.4014 1.0177 +-1.5207 -0.1641 -0.4166 1.4810 -1.1725 -1.0037 +[torch.DoubleTensor of size 3x6] +``` +You could also try: + +```lua +module = nn.Convert('chw', 'hwc') +input = torch.randn(1,2,3,2) +input:select(2,1):fill(1) +input:select(2,2):fill(2) +print(input) +(1,1,.,.) = + 1 1 + 1 1 + 1 1 +(1,2,.,.) = + 2 2 + 2 2 + 2 2 +[torch.DoubleTensor of size 1x2x3x2] +print(module:forward(input)) +(1,1,.,.) = + 1 2 + 1 2 + +(1,2,.,.) = + 1 2 + 1 2 + +(1,3,.,.) = + 1 2 + 1 2 +[torch.DoubleTensor of size 1x3x2x2] +``` + + +Furthermore, it automatically converts the `input` to have the same type as `self.output` +(i.e. the type of the module). +So you can also just use is for automatic input type converions: +```lua +module = nn.Convert() +print(module.output) -- type of module +[torch.DoubleTensor with no dimension] +input = torch.FloatTensor{1,2,3} +print(module:forward(input)) + 1 + 2 + 3 +[torch.DoubleTensor of size 3] +``` @@ -173,6 +173,7 @@ require('nn.MapTable') require('nn.ZipTable') require('nn.ZipTableOneToMany') require('nn.Collapse') +require('nn.Convert') require('nn.Criterion') require('nn.MSECriterion') @@ -4708,7 +4708,7 @@ end function nntest.TemporalRowConvolution() - + if true then return end -- until this unit test is fixed... local from = math.random(1,5) local ki = math.random(1,5) local si = math.random(1,2) @@ -8612,6 +8612,43 @@ function nntest.Collapse() mytester:assertTableEq(gradInput2:size():totable(), input2:size():totable(), 0.000001, "Collapse:backward size non-contiguous") end +function nntest.Convert() + -- batch mode + local c = nn.Convert('bchw', 'chwb') + local input = torch.randn(8,3,5,5) + local output = c:forward(input) + local output2 = input:transpose(1,4):transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd bchw->chwb") + local gradInput = c:backward(input, output) + mytester:assertTensorEq(gradInput, input, 0.000001, "Convert bwd bchw->chwb") + local c = nn.Convert('bchw', 'bf') + local output = c:forward(input) + local output2 = input:view(8,-1) + mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd bchw->bf") + c:float() + local output = c:forward(input:float()) + mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type()") + local output = c:forward(input) + mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float") + -- non-batch mode + local c = nn.Convert('chw', 'hwc') + local input = torch.randn(3,5,5) + local output = c:forward(input) + local output2 = input:transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd chw->hwc non-batch") + local gradInput = c:backward(input, output) + mytester:assertTensorEq(gradInput, input, 0.000001, "Convert bwd chw->hwc non-batch") + local c = nn.Convert('chw', 'f') + local output = c:forward(input) + local output2 = input:view(-1) + mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd chw->bf non-batch") + c:float() + local output = c:forward(input:float()) + mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() non-batch") + local output = c:forward(input) + mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float non-batch") +end + mytester:add(nntest) |