Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--View.lua36
-rw-r--r--doc/simple.md69
-rw-r--r--init.lua1
-rw-r--r--test/test.lua19
4 files changed, 125 insertions, 0 deletions
diff --git a/View.lua b/View.lua
new file mode 100644
index 0000000..75eda26
--- /dev/null
+++ b/View.lua
@@ -0,0 +1,36 @@
+local View, parent = torch.class('nn.View', 'nn.Module')
+
+-- Constructor. Accepts either a single torch.LongStorage or a list of
+-- numbers giving the target view sizes, e.g. nn.View(2, 8) or
+-- nn.View(torch.LongStorage{2, 8}).
+function View:__init(...)
+ parent.__init(self)
+ self.size = ...
+ -- More than one argument, or a single numeric argument, means the
+ -- sizes were given as plain numbers; pack them into a LongStorage.
+ if select('#', ...) > 1 or type(self.size) == "number" then
+ self.size = torch.LongStorage({...})
+ end
+ assert(torch.typename(self.size)=="torch.LongStorage", "expecting a LongStorage")
+ -- Total element count of the configured view; used later to decide
+ -- whether an input should be treated as a minibatch.
+ self.numElements = 1
+ for i = 1,#self.size do
+ self.numElements = self.numElements * self.size[i]
+ end
+
+ -- Deliberately nil (not pre-allocated tensors): output and gradInput
+ -- are produced as views of the inputs in updateOutput/updateGradInput.
+ self.output = nil
+ self.gradInput = nil
+end
+
+-- A tensor is treated as a minibatch when it has more than one dimension
+-- and everything past the first (batch) dimension accounts for exactly
+-- `numElements` elements, i.e. each sample matches the configured view.
+local function isMinibatch(input, numElements)
+ return input:dim() > 1 and
+ input:nElement()/input:size(1) == numElements
+end
+
+-- Returns a view of `input` with the sizes given at construction. For
+-- minibatch input the first (batch) dimension is kept and only the
+-- per-sample dimensions are reshaped.
+-- NOTE(review): tensor view() presumably requires contiguous storage --
+-- confirm callers never feed a non-contiguous input.
+function View:updateOutput(input)
+ if isMinibatch(input, self.numElements) then
+ self.output = input:view(input:size(1), unpack(self.size:totable()))
+ else
+ self.output = input:view(self.size)
+ end
+ return self.output
+end
+
+-- The gradient w.r.t. the input is just the incoming gradient viewed
+-- back into the input's original shape (a view has no parameters and
+-- is element-wise identity).
+-- NOTE(review): same contiguity assumption as updateOutput -- a
+-- non-contiguous gradOutput would presumably make view() fail; verify.
+function View:updateGradInput(input, gradOutput)
+ self.gradInput = gradOutput:view(input:size())
+ return self.gradInput
+end
diff --git a/doc/simple.md b/doc/simple.md
index 78b0609..cf5acd8 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -468,8 +468,77 @@ Example:
12
16
[torch.Tensor of dimension 16]
+```
+
+<a name="nn.View"/>
+## View ##
+
+`module` = `View(sizes)`
+
+This module creates a new view of the input tensor using the `sizes` passed to
+the constructor. The parameter `sizes` can either be a `LongStorage` or numbers.
+
+Example:
+```lua
+> x=torch.Tensor(4,4)
+> for i=1,4 do
+> for j=1,4 do
+> x[i][j]=(i-1)*4+j;
+> end
+> end
+> print(x)
+
+x=torch.Tensor(4,4)
+for i=1,4 do
+ for j=1,4 do
+ x[i][j]=(i-1)*4+j;
+ end
+end
+print(x)
+
+ 1 2 3 4
+ 5 6 7 8
+ 9 10 11 12
+ 13 14 15 16
+[torch.Tensor of dimension 4x4]
+
+> print(nn.View(2,8):forward(x))
+
+ 1 2 3 4 5 6 7 8
+ 9 10 11 12 13 14 15 16
+[torch.DoubleTensor of dimension 2x8]
+
+> print(nn.View(torch.LongStorage{8,2}):forward(x))
+
+ 1 2
+ 3 4
+ 5 6
+ 7 8
+ 9 10
+ 11 12
+ 13 14
+ 15 16
+[torch.DoubleTensor of dimension 8x2]
+> print(nn.View(16):forward(x))
+ 1
+ 2
+ 3
+ 4
+ 5
+ 6
+ 7
+ 8
+ 9
+ 10
+ 11
+ 12
+ 13
+ 14
+ 15
+ 16
+[torch.DoubleTensor of dimension 16]
```
diff --git a/init.lua b/init.lua
index 1fba70a..dfdd3a9 100644
--- a/init.lua
+++ b/init.lua
@@ -10,6 +10,7 @@ include('Sequential.lua')
include('Linear.lua')
include('SparseLinear.lua')
include('Reshape.lua')
+include('View.lua')
include('Select.lua')
include('Narrow.lua')
include('Replicate.lua')
diff --git a/test/test.lua b/test/test.lua
index be17fd7..5a127db 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1790,6 +1790,25 @@ function nntest.LookupTable()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
+-- Unit test for nn.View: checks both constructor forms (LongStorage and
+-- plain numbers) and minibatch handling.
+function nntest.View()
+ local input = torch.rand(10)
+ local template = torch.rand(5,2)
+ local target = template:size():totable()
+ local module = nn.View(template:size())
+ mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (1)")
+ -- reuse the same local rather than shadowing it with a second `local`
+ module = nn.View(unpack(target))
+ mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (2)")
+
+ -- Minibatch: the batch dimension and total element count must survive.
+ -- asserteq (not assertTableEq): these are plain numbers, not tables.
+ local minibatch = torch.rand(5,10)
+ mytester:asserteq(module:forward(minibatch):size(1),
+ minibatch:size(1),
+ "Error in minibatch dimension")
+ mytester:asserteq(module:forward(minibatch):nElement(),
+ minibatch:nElement(),
+ "Error in minibatch nElement")
+end
+
mytester:add(nntest)
if not nn then