Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2015-01-07 00:37:32 +0300
committernicholas-leonard <nick@nikopia.org>2015-01-07 00:37:32 +0300
commitbd7b315d7427f8320d44d8b29077154db87df633 (patch)
treebaf59604cc58335e530b69553de893adc6747dbe
parent4e0a96d801060121521ccc46f7294aeb3b247965 (diff)
new Copy constructor arguments
-rw-r--r--Copy.lua13
-rw-r--r--doc/simple.md10
-rw-r--r--test.lua14
3 files changed, 32 insertions, 5 deletions
diff --git a/Copy.lua b/Copy.lua
index 83be6ab..61c7f1f 100644
--- a/Copy.lua
+++ b/Copy.lua
@@ -1,14 +1,16 @@
local Copy, parent = torch.class('nn.Copy', 'nn.Module')
-function Copy:__init(intype, outtype)
+function Copy:__init(intype, outtype, forceCopy, dontCast)
intype = intype or torch.Tensor.__typename
outtype = outtype or torch.Tensor.__typename
+
+ self.dontCast = dontCast
parent.__init(self)
self.gradInput = torch.getmetatable(intype).new()
self.output = torch.getmetatable(outtype).new()
- if intype == outtype then
+ if (not forceCopy) and intype == outtype then
self.updateOutput = function(self, input)
self.output = input
@@ -31,3 +33,10 @@ function Copy:updateGradInput(input, gradOutput)
self.gradInput:resize(gradOutput:size()):copy(gradOutput)
return self.gradInput
end
+
+function Copy:type(type)
+ if type and self.dontCast then
+ return self
+ end
+ return parent.type(self, type)
+end
diff --git a/doc/simple.md b/doc/simple.md
index 8cbb017..4b1b182 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -468,11 +468,15 @@ end
<a name="nn.Copy"/>
## Copy ##
-`module` = `Copy(inputType,outputType)`
+`module` = `Copy(inputType,outputType,[forceCopy,dontCast])`
This layer copies the input to output with type casting from input
-type from `inputType` to `outputType`.
-
+type from `inputType` to `outputType`. Unless `forceCopy` is true, when
+the first two arguments are the same, the input isn't copied, only transferred
+as the output. The default `forceCopy` is false.
+When `dontCast` is true, a call to `nn.Copy:type(type)` will not cast
+the module's `output` and `gradInput` Tensors to the new type. The default
+is false.
<a name="nn.Narrow"/>
## Narrow ##
diff --git a/test.lua b/test.lua
index 3cf6a58..38a659d 100644
--- a/test.lua
+++ b/test.lua
@@ -2181,6 +2181,20 @@ function nntest.MulConstant()
mytester:assertlt(err, precision, 'bprop error ')
end
+function nntest.Copy()
+ local input = torch.randn(3,4):double()
+ local c = nn.Copy('torch.DoubleTensor', 'torch.FloatTensor')
+ local output = c:forward(input)
+ mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err')
+ mytester:assertTensorEq(output, input:float(), 0.000001, 'copy forward value err')
+ local gradInput = c:backward(input, output)
+ mytester:assert(torch.type(gradInput) == 'torch.DoubleTensor', 'copy backward type err')
+ mytester:assertTensorEq(gradInput, input, 0.000001, 'copy backward value err')
+ c.dontCast = true
+ c:double()
+ mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err')
+end
+
function nntest.JoinTable()
local tensor = torch.rand(3,4,5)
local input = {tensor, tensor}