
github.com/torch/nn.git
author     Nicholas Leonard <nleonard@twitter.com>    2017-05-25 04:41:28 +0300
committer  Nicholas Leonard <nleonard@twitter.com>    2017-05-25 04:41:28 +0300
commit     6714cebc861db18ede18e3d9d56e05340669998c (patch)
tree       b69b6d90025f719df32948972841c55af437507f
parent     eb6548a0c30db70465de4779d866bfac781ec0b1 (diff)
nn.Convert
-rw-r--r--  Convert.lua     244
-rwxr-xr-x  doc/simple.md    76
-rwxr-xr-x  init.lua          1
-rwxr-xr-x  test.lua         37
4 files changed, 356 insertions(+), 2 deletions(-)
diff --git a/Convert.lua b/Convert.lua
new file mode 100644
index 0000000..308bdac
--- /dev/null
+++ b/Convert.lua
@@ -0,0 +1,244 @@
+------------------------------------------------------------------------
+--[ nn.Convert ]--
+-- Module to convert between different data formats
+-- nn.Convert('bchw', 'bf') or nn.Convert('chw', 'f')
+-- Automatically converts input to same type as self.output
+-- Simplest use is for automatic input type conversions: nn.Convert()
+------------------------------------------------------------------------
+local _ = require 'moses'
+local Convert, parent = torch.class("nn.Convert", "nn.Container")
+
+function Convert:__init(inputShape, outputShape)
+   if outputShape and not inputShape then
+      error"Expecting non-nil arg 1 when arg 2 is provided"
+   end
+   inputShape = inputShape or 'b*'
+   outputShape = outputShape or inputShape
+   self.inputShape = inputShape:find('b') and inputShape or ('b'..inputShape)
+   self.outputShape = outputShape:find('b') and outputShape or ('b'..outputShape)
+   self.inputBatchDim = self.inputShape:find('b')
+   self.outputBatchDim = self.outputShape:find('b')
+   if self.inputShape == 'b*' or self.outputShape == 'b*' then
+      assert(self.inputShape == 'b*' and self.outputShape == 'b*', 'Both or neither shapes must be b*')
+      self.nInputDim = -1
+      self.nOutputDim = -1
+      self.transposition = true
+   else
+      -- number of dims in batch mode
+      self.nInputDim = #self.inputShape
+      self.nOutputDim = #self.outputShape
+      -- is the outputShape just a transposition of the inputShape?
+      if self.nInputDim == self.nOutputDim then
+         self.transposition = true
+         for i=1,self.nInputDim do
+            if not self.outputShape:find(self.inputShape:sub(i,i)) then
+               self.transposition = false
+               break
+            end
+         end
+      end
+   end
+   parent.__init(self)
+end
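+-- e.g. nn.Convert('chw', 'f') prepends the missing batch axis, giving
+-- inputShape 'bchw' and outputShape 'bf'; the two shapes have different
+-- lengths, so no transposition is flagged and buildConverter (below)
+-- dispatches to Convert:bf() instead.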
+
+-- post-initialization
+function Convert:buildConverter(input)
+   if self.transposition then
+      self.converter = self:transpose(self.outputShape)
+   else
+      if (torch.type(self[self.outputShape]) ~= 'function') then
+         error(string.format("Unrecognized conversion of shape %s to %s", self.inputShape, self.outputShape))
+      end
+      self.converter = self[self.outputShape](self, input)
+   end
+   assert(torch.isTensor(self.output), "Expecting Tensor output")
+
+   self.converter:type(torch.type(self.output))
+   self.converter:serialMode(self.dpnn_serialEmpty, self.dpnn_serialType)
+
+   self.modules[1] = self.converter
+end
+
+function Convert:updateOutput(input)
+   assert(torch.isTensor(input), "expecting Tensor")
+   if not torch.isTypeOf(input, torch.type(self.output)) then
+      -- handle different input type
+      self._input = self._input or self.output.new()
+      self._input:resize(input:size()):copy(input)
+      input = self._input
+   end
+   self.batchMode = true
+   if input:dim() < self.nInputDim then
+      -- handle non-batch mode
+      local inputSize = input:size():totable()
+      table.insert(inputSize, self.inputBatchDim, 1)
+      self.__input = self.__input or input.new()
+      self.__input:set(input):resize(table.unpack(inputSize))
+      input = self.__input
+      self.batchMode = false
+   end
+   if not self.converter then
+      self:buildConverter(input)
+   end
+
+   self.output = self.converter:updateOutput(input)
+
+   if not self.batchMode then
+      local outputSize = self.output:size():totable()
+      table.remove(outputSize, self.outputBatchDim)
+      self.__output = self.__output or self.output.new()
+      self.__output:set(self.output):resize(table.unpack(outputSize))
+      self.output = self.__output
+   end
+   return self.output
+end
+
+function Convert:updateGradInput(input, gradOutput)
+   local input_ = input
+   input = self._input or input
+   if not self.batchMode then
+      input = self.__input
+      self.__gradOutput = self.__gradOutput or gradOutput.new()
+      self.__gradOutput:set(gradOutput):resize(self.converter.output:size())
+      gradOutput = self.__gradOutput
+   end
+
+   local gradInput = self.converter:updateGradInput(input, gradOutput)
+
+   if not self.batchMode then
+      self.__gradInput = self.__gradInput or gradInput.new()
+      self.__gradInput:set(gradInput):resize(input_:size())
+      gradInput = self.__gradInput
+   end
+   if self._input then
+      self._gradInput = self._gradInput or input.new()
+      self._gradInput:resize(input:size()):copy(gradInput)
+      self.gradInput = self._gradInput
+   else
+      self.gradInput = gradInput
+   end
+
+   return self.gradInput
+end
+
+function Convert:accGradParameters(input, gradOutput, scale)
+   input = self.batchMode and self.__input or self._input or input
+   gradOutput = self.batchMode and self.__gradOutput or gradOutput
+   self.converter:accGradParameters(input, gradOutput, scale)
+end
+
+function Convert:accUpdateGradParameters(input, gradOutput, lr)
+   input = self.batchMode and self.__input or self._input or input
+   gradOutput = self.batchMode and self.__gradOutput or gradOutput
+   self.converter:accUpdateGradParameters(input, gradOutput, lr)
+end
+
+-- batch feature
+function Convert:bf(input)
+   local b_pos = self:findAxis('b', self.inputShape)
+   local dim = #self.inputShape
+   if self.inputShape == 'bt' then
+      error"Conversion of shape bt to bf not supported: open an issue on github"
+   end
+   -- was b
+   if dim == 1 then
+      return nn.Reshape(1)
+   end
+   -- was b...
+   local modula
+   if b_pos ~= 1 then
+      modula = nn.Transpose({1, b_pos})
+   end
+   if dim > 2 then
+      local transpose = modula
+      local sampleSize = input:select(self:findAxis('b'),1):nElement()
+      local reshape = nn.Reshape(sampleSize)
+      if transpose then
+         modula = nn.Sequential()
+         modula:add(transpose)
+         modula:add(reshape)
+      else
+         modula = reshape
+      end
+   end
+   return modula or nn.Identity()
+end
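+-- e.g. with inputShape 'bchw' and an 8x3x5x5 input, b_pos is 1 and dim is 4,
+-- so bf returns nn.Reshape(75) (3*5*5 elements per sample); for a shape like
+-- 'chwb' it would first add nn.Transpose({1,4}) to move the batch axis to the
+-- front before reshaping.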
+
+-- each example is a scalar; batch is a vector
+function Convert:b(input)
+   local b_pos = self:findAxis('b')
+   if self.inputShape == 'bt' or self.inputShape == 'tb' then
+      local t_pos = self:findAxis('t')
+      -- select first set of classes
+      return nn.Select(t_pos, 1)
+   elseif self.inputShape == 'bf' or self.inputShape == 'fb' then
+      -- this won't work as expected with size(f) > 1
+      local f_pos = self:findAxis('f')
+      if input:size(f_pos) > 1 then
+         error("Cannot convert shape "..self.inputShape.." to b when feature > 1")
+      end
+      return nn.Select(f_pos, 1)
+   else
+      error("Cannot convert shape "..self.inputShape.." to shape b")
+   end
+end
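+-- e.g. nn.Convert('bt', 'b') yields nn.Select(2, 1), i.e. the first target of
+-- every example; 'bf' -> 'b' only works when the feature axis has size 1.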
+
+-- returns the current shape of the data
+function Convert:default()
+   return nn.Identity()
+end
+
+-- multi-class (batch target)
+function Convert:bt()
+   local b_pos = self:findAxis('b')
+   local modula
+   if self.inputShape == 'b' then
+      modula = nn.Reshape(1)
+   else
+      error("cannot convert shape '"..self.inputShape.."' to bt")
+   end
+   return modula
+end
+
+-- a generic function for transposing shape axes
+function Convert:transpose(newShape)
+   if newShape == self.inputShape then
+      return nn.Identity()
+   end
+   local inputShape = {}
+   for i=1,#self.inputShape do
+      table.insert(inputShape, self.inputShape:sub(i,i))
+   end
+   local transpositions = {}
+   for i=1,#newShape do
+      local j = _.indexOf(inputShape, newShape:sub(i,i))
+      if i ~= j then
+         local char = inputShape[i]
+         inputShape[i] = inputShape[j]
+         inputShape[j] = char
+         table.insert(transpositions, {j, i})
+      end
+   end
+   return nn.Transpose(table.unpack(transpositions))
+end
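+-- e.g. nn.Convert('bchw', 'chwb') is a pure transposition, so buildConverter
+-- calls transpose('chwb'), which returns nn.Transpose({2,1}, {3,2}, {4,3})
+-- (reordering bchw -> cbhw -> chbw -> chwb).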
+
+function Convert:findAxis(axis_char, shape, silent)
+   shape = shape or self.inputShape
+   local axis_pos = shape:find(axis_char)
+   if (not silent) and (not axis_pos) then
+      error("Provided shape '"..shape.."' has no axis '"..axis_char.."'", 2)
+   end
+   return axis_pos
+end
+
+function Convert:type(type)
+   if not torch.isTypeOf(self.output, type) then
+      self._input = nil
+      self._gradInput = nil
+      self.__input = nil
+      self.__output = nil
+      self.__gradInput = nil
+      self.__gradOutput = nil
+   end
+   return parent.type(self, type)
+end
diff --git a/doc/simple.md b/doc/simple.md
index 3d08167..849d9b5 100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -62,7 +62,8 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding;
* [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging);
* [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`;
- * [Collapse](#nn.Collapse) : just like `nn.View(-1)`.
+ * [Collapse](#nn.Collapse) : just like `nn.View(-1)`;
+ * [Convert](#nn.Convert) : convert between different tensor types or shapes;
<a name="nn.Linear"></a>
## Linear ##
@@ -1803,4 +1804,75 @@ view:setNumInputDim(nInputDim)
It collapses all non-batch dimensions. This is useful for converting
a spatial feature map to the single dimension required by a dense
-hidden layer like Linear.
\ No newline at end of file
+hidden layer like Linear.
+
+<a name='nn.Convert'></a>
+## Convert ##
+
+```lua
+module = nn.Convert([inputShape, outputShape])
+```
+Module to convert between different data formats.
+For example, we can flatten images by using:
+```lua
+module = nn.Convert('bchw', 'bf')
+```
+or equivalently
+```lua
+module = nn.Convert('chw', 'f')
+```
+Let's try it with an input:
+```lua
+print(module:forward(torch.randn(3,2,3,1)))
+ 0.5692 -0.0190 0.5243 0.7530 0.4230 1.2483
+-0.9142 0.6013 0.5608 -1.0417 -1.4014 1.0177
+-1.5207 -0.1641 -0.4166 1.4810 -1.1725 -1.0037
+[torch.DoubleTensor of size 3x6]
+```
+You could also try:
+
+```lua
+module = nn.Convert('chw', 'hwc')
+input = torch.randn(1,2,3,2)
+input:select(2,1):fill(1)
+input:select(2,2):fill(2)
+print(input)
+(1,1,.,.) =
+ 1 1
+ 1 1
+ 1 1
+(1,2,.,.) =
+ 2 2
+ 2 2
+ 2 2
+[torch.DoubleTensor of size 1x2x3x2]
+print(module:forward(input))
+(1,1,.,.) =
+ 1 2
+ 1 2
+
+(1,2,.,.) =
+ 1 2
+ 1 2
+
+(1,3,.,.) =
+ 1 2
+ 1 2
+[torch.DoubleTensor of size 1x3x2x2]
+```
+
+
+Furthermore, it automatically converts the `input` to have the same type as `self.output`
+(i.e. the type of the module).
+So you can also just use it for automatic input type conversions:
+```lua
+module = nn.Convert()
+print(module.output) -- type of module
+[torch.DoubleTensor with no dimension]
+input = torch.FloatTensor{1,2,3}
+print(module:forward(input))
+ 1
+ 2
+ 3
+[torch.DoubleTensor of size 3]
+```
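+
+The shape strings describe a single example, so the module also works in
+non-batch mode: when the input has fewer dimensions than the (batch) `inputShape`,
+a batch axis of size 1 is added internally and stripped from the output again.
+A small sketch (using a deterministic input rather than `torch.randn`):
+```lua
+module = nn.Convert('chw', 'f')
+input = torch.range(1,24):resize(2,3,4) -- a single 2x3x4 "image"
+output = module:forward(input)          -- 1D tensor of 24 elements, no batch axis
+```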
diff --git a/init.lua b/init.lua
index bd4a5b0..503d2c2 100755
--- a/init.lua
+++ b/init.lua
@@ -173,6 +173,7 @@ require('nn.MapTable')
require('nn.ZipTable')
require('nn.ZipTableOneToMany')
require('nn.Collapse')
+require('nn.Convert')
require('nn.Criterion')
require('nn.MSECriterion')
diff --git a/test.lua b/test.lua
index e776f26..1b1b2c1 100755
--- a/test.lua
+++ b/test.lua
@@ -8612,6 +8612,43 @@ function nntest.Collapse()
mytester:assertTableEq(gradInput2:size():totable(), input2:size():totable(), 0.000001, "Collapse:backward size non-contiguous")
end
+function nntest.Convert()
+   -- batch mode
+   local c = nn.Convert('bchw', 'chwb')
+   local input = torch.randn(8,3,5,5)
+   local output = c:forward(input)
+   local output2 = input:transpose(1,4):transpose(1,3):transpose(1,2)
+   mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd bchw->chwb")
+   local gradInput = c:backward(input, output)
+   mytester:assertTensorEq(gradInput, input, 0.000001, "Convert bwd bchw->chwb")
+   local c = nn.Convert('bchw', 'bf')
+   local output = c:forward(input)
+   local output2 = input:view(8,-1)
+   mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd bchw->bf")
+   c:float()
+   local output = c:forward(input:float())
+   mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type()")
+   local output = c:forward(input)
+   mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float")
+   -- non-batch mode
+   local c = nn.Convert('chw', 'hwc')
+   local input = torch.randn(3,5,5)
+   local output = c:forward(input)
+   local output2 = input:transpose(1,3):transpose(1,2)
+   mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd chw->hwc non-batch")
+   local gradInput = c:backward(input, output)
+   mytester:assertTensorEq(gradInput, input, 0.000001, "Convert bwd chw->hwc non-batch")
+   local c = nn.Convert('chw', 'f')
+   local output = c:forward(input)
+   local output2 = input:view(-1)
+   mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd chw->f non-batch")
+   c:float()
+   local output = c:forward(input:float())
+   mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() non-batch")
+   local output = c:forward(input)
+   mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float non-batch")
+end
+
mytester:add(nntest)