diff options
Diffstat (limited to 'SpatialAdaptiveMaxPooling.lua')
-rw-r--r-- | SpatialAdaptiveMaxPooling.lua | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/SpatialAdaptiveMaxPooling.lua b/SpatialAdaptiveMaxPooling.lua index 74d4cd6..a2cf104 100644 --- a/SpatialAdaptiveMaxPooling.lua +++ b/SpatialAdaptiveMaxPooling.lua @@ -2,13 +2,18 @@ local SpatialAdaptiveMaxPooling, parent = torch.class('nn.SpatialAdaptiveMaxPool function SpatialAdaptiveMaxPooling:__init(W, H) parent.__init(self) - + self.W = W self.H = H end function SpatialAdaptiveMaxPooling:updateOutput(input) - self.indices = self.indices or input.new() + self.indices = self.indices or torch.LongTensor() + if torch.typename(input):find('torch%.Cuda.*Tensor') then + self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices + else + self.indices = self.indices:long() + end input.THNN.SpatialAdaptiveMaxPooling_updateOutput( input:cdata(), self.output:cdata(), |