Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-08-17 22:36:58 +0300
committerGitHub <noreply@github.com>2016-08-17 22:36:58 +0300
commit5cfd34497d0cc7e180ee370767a3ca684176a10c (patch)
tree9c12af6458d3d8802bddd688eb5351fc6d9a0ffe
parentd98a782dea757c2dc48a3135bc8fb8208ac78438 (diff)
parentb7f5185b40f10c87d978db378ee2782a414c8d70 (diff)
Merge pull request #933 from torch/maxminfix
fixing Max and Min for new cutorch types
-rw-r--r--Max.lua15
-rw-r--r--Min.lua15
-rw-r--r--Normalize.lua15
3 files changed, 9 insertions, 36 deletions
diff --git a/Max.lua b/Max.lua
index 691fe9d..1392d8a 100644
--- a/Max.lua
+++ b/Max.lua
@@ -21,7 +21,7 @@ end
function Max:_lazyInit()
self._output = self._output or self.output.new()
self._indices = self._indices or
- (torch.type(self.output) == 'torch.CudaTensor' and torch.CudaTensor() or torch.LongTensor())
+ (torch.type(self.output) == 'torch.CudaTensor' and torch.CudaLongTensor() or torch.LongTensor())
end
function Max:updateOutput(input)
@@ -50,18 +50,9 @@ function Max:updateGradInput(input, gradOutput)
end
function Max:type(type, tensorCache)
- -- torch.max expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor.
- if type == 'torch.CudaTensor' then
+ self._indices = nil
parent.type(self, type, tensorCache)
- else
- -- self._indices must be a LongTensor. Setting it to nil temporarily avoids
- -- unnecessary memory allocations.
- local indices
- indices, self._indices = self._indices, nil
- parent.type(self, type, tensorCache)
- self._indices = indices and indices:long() or nil
- end
- return self
+ return self
end
function Max:clearState()
diff --git a/Min.lua b/Min.lua
index f1d2b45..dc07cf9 100644
--- a/Min.lua
+++ b/Min.lua
@@ -21,7 +21,7 @@ end
function Min:_lazyInit()
self._output = self._output or self.output.new()
self._indices = self._indices or
- (torch.type(self.output) == 'torch.CudaTensor' and torch.CudaTensor() or torch.LongTensor())
+ (torch.type(self.output) == 'torch.CudaTensor' and torch.CudaLongTensor() or torch.LongTensor())
end
function Min:updateOutput(input)
@@ -50,18 +50,9 @@ function Min:updateGradInput(input, gradOutput)
end
function Min:type(type, tensorCache)
- -- torch.min expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor.
- if type == 'torch.CudaTensor' then
+ self._indices = nil
parent.type(self, type, tensorCache)
- else
- -- self._indices must be a LongTensor. Setting it to nil temporarily avoids
- -- unnecessary memory allocations.
- local indices
- indices, self._indices = self._indices, nil
- parent.type(self, type, tensorCache)
- self._indices = indices and indices:long() or nil
- end
- return self
+ return self
end
function Min:clearState()
diff --git a/Normalize.lua b/Normalize.lua
index 24c1d07..5cd4857 100644
--- a/Normalize.lua
+++ b/Normalize.lua
@@ -25,7 +25,7 @@ function Normalize:updateOutput(input)
-- specialization for the infinity norm
self._indices = self._indices or
(torch.type(self.output) == 'torch.CudaTensor' and
- torch.CudaTensor() or torch.LongTensor())
+ torch.CudaLongTensor() or torch.LongTensor())
self.buffer:abs(input)
torch.max(self.norm, self._indices, self.buffer, 2)
@@ -127,18 +127,9 @@ function Normalize:__tostring__()
end
function Normalize:type(type, tensorCache)
- -- torch.max expects a LongTensor as indices, whereas cutorch.max expects a CudaTensor.
- if type == 'torch.CudaTensor' then
+ self._indices = nil
parent.type(self, type, tensorCache)
- else
- -- self._indices must be a LongTensor. Setting it to nil temporarily avoids
- -- unnecessary memory allocations.
- local indices
- indices, self._indices = self._indices, nil
- parent.type(self, type, tensorCache)
- self._indices = indices and indices:long() or nil
- end
- return self
+ return self
end
function Normalize:clearState()