diff options
author | nicholas-leonard <nick@nikopia.org> | 2015-01-07 00:37:32 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2015-01-07 00:37:32 +0300 |
commit | bd7b315d7427f8320d44d8b29077154db87df633 (patch) | |
tree | baf59604cc58335e530b69553de893adc6747dbe | |
parent | 4e0a96d801060121521ccc46f7294aeb3b247965 (diff) |
new Copy constructor arguments
-rw-r--r-- | Copy.lua | 13 | ||||
-rw-r--r-- | doc/simple.md | 10 | ||||
-rw-r--r-- | test.lua | 14 |
3 files changed, 32 insertions, 5 deletions
@@ -1,14 +1,16 @@ local Copy, parent = torch.class('nn.Copy', 'nn.Module') -function Copy:__init(intype, outtype) +function Copy:__init(intype, outtype, forceCopy, dontCast) intype = intype or torch.Tensor.__typename outtype = outtype or torch.Tensor.__typename + + self.dontCast = dontCast parent.__init(self) self.gradInput = torch.getmetatable(intype).new() self.output = torch.getmetatable(outtype).new() - if intype == outtype then + if (not forceCopy) and intype == outtype then self.updateOutput = function(self, input) self.output = input @@ -31,3 +33,10 @@ function Copy:updateGradInput(input, gradOutput) self.gradInput:resize(gradOutput:size()):copy(gradOutput) return self.gradInput end + +function Copy:type(type) + if type and self.dontCopy then + return self + end + return parent.type(self, type) +end diff --git a/doc/simple.md b/doc/simple.md index 8cbb017..4b1b182 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -468,11 +468,15 @@ end <a name="nn.Copy"/> ## Copy ## -`module` = `Copy(inputType,outputType)` +`module` = `Copy(inputType,outputType,[forceCopy,dontCast])` This layer copies the input to output with type casting from input -type from `inputType` to `outputType`. - +type from `inputType` to `outputType`. Unless `forceCopy` is true, when +the first two arguments are the same, the input isn't copied, only transfered +as the output. The default `forceCopy` is false. +When `dontCast` is true, a call to `nn.Copy:type(type)` will not cast +the module's `output` and `gradInput` Tensors to the new type. The default +is false. <a name="nn.Narrow"/> ## Narrow ## @@ -2181,6 +2181,20 @@ function nntest.MulConstant() mytester:assertlt(err, precision, 'bprop error ') end +function nntest.Copy() + local input = torch.randn(3,4):double() + local c = nn.Copy('torch.DoubleTensor', 'torch.FloatTensor') + local output = c:forward(input) + mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err') + mytester:assertTensorEq(output, input:float(), 0.000001, 'copy forward value err') + local gradInput = c:backward(input, output) + mytester:assert(torch.type(gradInput) == 'torch.DoubleTensor', 'copy backward type err') + mytester:assertTensorEq(gradInput, input, 0.000001, 'copy backward value err') + c.dontCast = true + c:double() + mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err') +end + function nntest.JoinTable() local tensor = torch.rand(3,4,5) local input = {tensor, tensor} |