diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-08-17 22:36:58 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-17 22:36:58 +0300 |
commit | 5cfd34497d0cc7e180ee370767a3ca684176a10c (patch) | |
tree | 9c12af6458d3d8802bddd688eb5351fc6d9a0ffe | |
parent | d98a782dea757c2dc48a3135bc8fb8208ac78438 (diff) | |
parent | b7f5185b40f10c87d978db378ee2782a414c8d70 (diff) |
Merge pull request #933 from torch/maxminfix
fixing Max and Min for new cutorch types
-rw-r--r-- | Max.lua | 15 | ||||
-rw-r--r-- | Min.lua | 15 | ||||
-rw-r--r-- | Normalize.lua | 15 |
3 files changed, 9 insertions, 36 deletions
@@ -21,7 +21,7 @@ end function Max:_lazyInit() self._output = self._output or self.output.new() self._indices = self._indices or - (torch.type(self.output) == 'torch.CudaTensor' and torch.CudaTensor() or torch.LongTensor()) + (torch.type(self.output) == 'torch.CudaTensor' and torch.CudaLongTensor() or torch.LongTensor()) end function Max:updateOutput(input) @@ -50,18 +50,9 @@ function Max:updateGradInput(input, gradOutput) end function Max:type(type, tensorCache) - -- torch.max expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor. - if type == 'torch.CudaTensor' then + self._indices = nil parent.type(self, type, tensorCache) - else - -- self._indices must be a LongTensor. Setting it to nil temporarily avoids - -- unnecessary memory allocations. - local indices - indices, self._indices = self._indices, nil - parent.type(self, type, tensorCache) - self._indices = indices and indices:long() or nil - end - return self + return self end function Max:clearState() @@ -21,7 +21,7 @@ end function Min:_lazyInit() self._output = self._output or self.output.new() self._indices = self._indices or - (torch.type(self.output) == 'torch.CudaTensor' and torch.CudaTensor() or torch.LongTensor()) + (torch.type(self.output) == 'torch.CudaTensor' and torch.CudaLongTensor() or torch.LongTensor()) end function Min:updateOutput(input) @@ -50,18 +50,9 @@ function Min:updateGradInput(input, gradOutput) end function Min:type(type, tensorCache) - -- torch.min expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor. - if type == 'torch.CudaTensor' then + self._indices = nil parent.type(self, type, tensorCache) - else - -- self._indices must be a LongTensor. Setting it to nil temporarily avoids - -- unnecessary memory allocations. - local indices - indices, self._indices = self._indices, nil - parent.type(self, type, tensorCache) - self._indices = indices and indices:long() or nil - end - return self + return self end function Min:clearState() diff --git a/Normalize.lua b/Normalize.lua index 24c1d07..5cd4857 100644 --- a/Normalize.lua +++ b/Normalize.lua @@ -25,7 +25,7 @@ function Normalize:updateOutput(input) -- specialization for the infinity norm self._indices = self._indices or (torch.type(self.output) == 'torch.CudaTensor' and - torch.CudaTensor() or torch.LongTensor()) + torch.CudaLongTensor() or torch.LongTensor()) self.buffer:abs(input) torch.max(self.norm, self._indices, self.buffer, 2) @@ -127,18 +127,9 @@ function Normalize:__tostring__() end function Normalize:type(type, tensorCache) - -- torch.max expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor. - if type == 'torch.CudaTensor' then + self._indices = nil parent.type(self, type, tensorCache) - else - -- self._indices must be a LongTensor. Setting it to nil temporarily avoids - -- unnecessary memory allocations. - local indices - indices, self._indices = self._indices, nil - parent.type(self, type, tensorCache) - self._indices = indices and indices:long() or nil - end - return self + return self end function Normalize:clearState() |