diff options
author | soumith <soumith@fb.com> | 2016-08-17 02:34:34 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-08-17 02:48:34 +0300 |
commit | b7f5185b40f10c87d978db378ee2782a414c8d70 (patch) | |
tree | 8f7cd2e0bb86492fb801899b1eca803aec475a21 /Normalize.lua | |
parent | 9700f3dbba4e9fe31d84a70d642af2c40fde1399 (diff) |
fixing nn.Normalize for new cutorch types
Diffstat (limited to 'Normalize.lua')
-rw-r--r-- | Normalize.lua | 15 |
1 files changed, 3 insertions, 12 deletions
diff --git a/Normalize.lua b/Normalize.lua index 24c1d07..5cd4857 100644 --- a/Normalize.lua +++ b/Normalize.lua @@ -25,7 +25,7 @@ function Normalize:updateOutput(input) -- specialization for the infinity norm self._indices = self._indices or (torch.type(self.output) == 'torch.CudaTensor' and - torch.CudaTensor() or torch.LongTensor()) + torch.CudaLongTensor() or torch.LongTensor()) self.buffer:abs(input) torch.max(self.norm, self._indices, self.buffer, 2) @@ -127,18 +127,9 @@ function Normalize:__tostring__() end function Normalize:type(type, tensorCache) - -- torch.max expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor. - if type == 'torch.CudaTensor' then + self._indices = nil parent.type(self, type, tensorCache) - else - -- self._indices must be a LongTensor. Setting it to nil temporarily avoids - -- unnecessary memory allocations. - local indices - indices, self._indices = self._indices, nil - parent.type(self, type, tensorCache) - self._indices = indices and indices:long() or nil - end - return self + return self end function Normalize:clearState() |