
github.com/torch/nn.git
author    Jonathan Uesato <juesato@mit.edu>  2017-02-02 05:20:14 +0300
committer Soumith Chintala <soumith@gmail.com>  2017-02-02 05:20:14 +0300
commit    db4244e6ee30b0ec689815a00fcc0c8c45f91b12 (patch)
tree      f1877ec9a78b3046319a79d2ef79c9d6315dcc34 /CMinTable.lua
parent    02ebb69c1838c5c63a64b374faceff40671fdaeb (diff)
Make CMaxTable and CMinTable cunn-compatible (#954)
Diffstat (limited to 'CMinTable.lua'):
 -rw-r--r--  CMinTable.lua | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)
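In short, the patch stops building a fresh mask with torch.lt()/torch.eq() on every pass and instead reuses persistent buffers (self.mask, self.minVals, self.gradMaxVals) plus a lazily created self.maskByteTensor, which is a torch.CudaByteTensor when the module runs on CUDA tensors and a torch.ByteTensor otherwise, so that maskedFill/maskedSelect/maskedCopy receive a mask of the type they expect. A minimal usage sketch of what this enables (my illustration, not part of the commit):

-- Element-wise minimum over a table of tensors; after this patch the
-- same module should also work under cunn.
require 'nn'

local m = nn.CMinTable()
local a = torch.Tensor({1, 5, 3})
local b = torch.Tensor({4, 2, 6})

print(m:forward({a, b}))   -- 1, 2, 3: the element-wise minimum

-- Gradients are routed only to the inputs that held the minimum:
local grad = m:backward({a, b}, torch.Tensor({1, 1, 1}))
print(grad[1])             -- 1, 0, 1
print(grad[2])             -- 0, 1, 0

-- With the cunn package installed, the same module should now run on
-- CUDA tensors; this is the path the patch addresses by converting the
-- comparison mask into a torch.CudaByteTensor before the masked ops:
--   require 'cunn'
--   m = m:cuda()
--   print(m:forward({a:cuda(), b:cuda()}))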
diff --git a/CMinTable.lua b/CMinTable.lua
index a8385e8..25b9a19 100644
--- a/CMinTable.lua
+++ b/CMinTable.lua
@@ -4,25 +4,38 @@ function CMinTable:__init()
    parent.__init(self)
    self.gradInput = {}
    self.minIdx = torch.Tensor()
+   self.mask = torch.Tensor()
+   self.minVals = torch.Tensor()
+   self.gradMaxVals = torch.Tensor()
 end
 
 function CMinTable:updateOutput(input)
    self.output:resizeAs(input[1]):copy(input[1])
    self.minIdx:resizeAs(input[1]):fill(1)
    for i=2,#input do
-      local mask = torch.lt(input[i], self.output)
-      self.minIdx:maskedFill(mask, i)
-      self.output:maskedCopy(mask, input[i][mask])
+      self.maskByteTensor = self.maskByteTensor or
+         (torch.type(self.output) == 'torch.CudaTensor' and
+          torch.CudaByteTensor() or torch.ByteTensor())
+      self.mask:lt(input[i], self.output)
+      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
+      self.minIdx:maskedFill(self.maskByteTensor, i)
+      self.minVals:maskedSelect(input[i], self.maskByteTensor)
+      self.output:maskedCopy(self.maskByteTensor, self.minVals)
    end
    return self.output
 end
 
 function CMinTable:updateGradInput(input, gradOutput)
    for i=1,#input do
-      self.gradInput[i] = torch.Tensor()
+      self.gradInput[i] = self.gradInput[i] or input[i].new()
       self.gradInput[i]:resizeAs(input[i]):fill(0.0)
-      local mask = torch.eq(self.minIdx, i)
-      self.gradInput[i]:maskedCopy(mask, gradOutput[mask])
+      self.maskByteTensor = self.maskByteTensor or
+         (torch.type(self.output) == 'torch.CudaTensor' and
+          torch.CudaByteTensor() or torch.ByteTensor())
+      self.mask:eq(self.minIdx, i)
+      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
+      self.gradMaxVals:maskedSelect(gradOutput, self.maskByteTensor)
+      self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMaxVals)
    end
    for i=#input+1, #self.gradInput do