diff options
author | Jonathan Uesato <juesato@mit.edu> | 2017-02-02 05:20:14 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-02-02 05:20:14 +0300 |
commit | db4244e6ee30b0ec689815a00fcc0c8c45f91b12 (patch) | |
tree | f1877ec9a78b3046319a79d2ef79c9d6315dcc34 /CMinTable.lua | |
parent | 02ebb69c1838c5c63a64b374faceff40671fdaeb (diff) |
Make CMaxTable and CMinTable cunn-compatible (#954)
Diffstat (limited to 'CMinTable.lua')
-rw-r--r-- | CMinTable.lua | 25 |
1 file changed, 19 insertions, 6 deletions
diff --git a/CMinTable.lua b/CMinTable.lua index a8385e8..25b9a19 100644 --- a/CMinTable.lua +++ b/CMinTable.lua @@ -4,25 +4,38 @@ function CMinTable:__init() parent.__init(self) self.gradInput = {} self.minIdx = torch.Tensor() + self.mask = torch.Tensor() + self.minVals = torch.Tensor() + self.gradMaxVals = torch.Tensor() end function CMinTable:updateOutput(input) self.output:resizeAs(input[1]):copy(input[1]) self.minIdx:resizeAs(input[1]):fill(1) for i=2,#input do - local mask = torch.lt(input[i], self.output) - self.minIdx:maskedFill(mask, i) - self.output:maskedCopy(mask, input[i][mask]) + self.maskByteTensor = self.maskByteTensor or + (torch.type(self.output) == 'torch.CudaTensor' and + torch.CudaByteTensor() or torch.ByteTensor()) + self.mask:lt(input[i], self.output) + self.maskByteTensor:resize(self.mask:size()):copy(self.mask) + self.minIdx:maskedFill(self.maskByteTensor, i) + self.minVals:maskedSelect(input[i], self.maskByteTensor) + self.output:maskedCopy(self.maskByteTensor, self.minVals) end return self.output end function CMinTable:updateGradInput(input, gradOutput) for i=1,#input do - self.gradInput[i] = torch.Tensor() + self.gradInput[i] = self.gradInput[i] or input[i].new() self.gradInput[i]:resizeAs(input[i]):fill(0.0) - local mask = torch.eq(self.minIdx, i) - self.gradInput[i]:maskedCopy(mask, gradOutput[mask]) + self.maskByteTensor = self.maskByteTensor or + (torch.type(self.output) == 'torch.CudaTensor' and + torch.CudaByteTensor() or torch.ByteTensor()) + self.mask:eq(self.minIdx, i) + self.maskByteTensor:resize(self.mask:size()):copy(self.mask) + self.gradMaxVals:maskedSelect(gradOutput, self.maskByteTensor) + self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMaxVals) end for i=#input+1, #self.gradInput do |