Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-08-17 02:34:34 +0300
committersoumith <soumith@fb.com>2016-08-17 02:48:34 +0300
commitb7f5185b40f10c87d978db378ee2782a414c8d70 (patch)
tree8f7cd2e0bb86492fb801899b1eca803aec475a21 /Normalize.lua
parent9700f3dbba4e9fe31d84a70d642af2c40fde1399 (diff)
fixing nn.Normalize for new cutorch types
Diffstat (limited to 'Normalize.lua')
-rw-r--r--Normalize.lua15
1 files changed, 3 insertions, 12 deletions
diff --git a/Normalize.lua b/Normalize.lua
index 24c1d07..5cd4857 100644
--- a/Normalize.lua
+++ b/Normalize.lua
@@ -25,7 +25,7 @@ function Normalize:updateOutput(input)
-- specialization for the infinity norm
self._indices = self._indices or
(torch.type(self.output) == 'torch.CudaTensor' and
- torch.CudaTensor() or torch.LongTensor())
+ torch.CudaLongTensor() or torch.LongTensor())
self.buffer:abs(input)
torch.max(self.norm, self._indices, self.buffer, 2)
@@ -127,18 +127,9 @@ function Normalize:__tostring__()
end
function Normalize:type(type, tensorCache)
- -- torch.max expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor.
- if type == 'torch.CudaTensor' then
+ self._indices = nil
parent.type(self, type, tensorCache)
- else
- -- self._indices must be a LongTensor. Setting it to nil temporarily avoids
- -- unnecessary memory allocations.
- local indices
- indices, self._indices = self._indices, nil
- parent.type(self, type, tensorCache)
- self._indices = indices and indices:long() or nil
- end
- return self
+ return self
end
function Normalize:clearState()