Welcome to mirror list, hosted at ThFree Co, Russian Federation.

MaskedSelect.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: c3f7834e17460ee40f0356843452a452d9db88a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
-- Compatibility shim: use the global `unpack` when it exists, otherwise fall
-- back to `table.unpack`.
local unpack = unpack or table.unpack

-- Register the class 'nn.MaskedSelect' with 'nn.Module' as its parent class.
-- `parent` is used below to invoke the parent constructor.
local MaskedSelect, parent = torch.class('nn.MaskedSelect', 'nn.Module')

--[[ Constructs the module and allocates its (initially empty) work buffers.

Takes no arguments. Every buffer is resized on demand by updateGradInput:
  _gradBuffer          gradient w.r.t. the data input (default tensor type)
  _gradMask            all-zero gradient returned for the mask input
  _maskIndexBuffer     linear indices 1..N laid out with the mask's shape
  _maskIndexBufferCPU  float staging buffer used to build those indices
                       when the input is a CUDA tensor
  _maskIndices         the linear indices picked out by the mask
]]
function MaskedSelect:__init()
  parent.__init(self)

  -- Gradient buffers handed back through gradInput.
  self._gradBuffer = torch.Tensor()
  self._gradMask = torch.ByteTensor()

  -- Index scratch space used to map gradOutput back onto the input.
  self._maskIndexBufferCPU = torch.FloatTensor()
  self._maskIndexBuffer = torch.LongTensor()
  self._maskIndices = torch.LongTensor()
end

--[[ Forward pass: gathers the elements of the data tensor selected by the mask.

Parameters:
  input - table {data, mask}; both entries are passed straight to
          Tensor:maskedSelect.

Returns self.output, filled in place with the selected elements.
]]
function MaskedSelect:updateOutput(input)
  local data, mask = unpack(input)
  local out = self.output
  out:maskedSelect(data, mask)
  return self.output
end

--[[ Backward pass: routes gradOutput back to the positions the mask selected.

Parameters:
  input      - table {data, mask}, as given to updateOutput.
  gradOutput - gradient w.r.t. the output of updateOutput.

Returns self.gradInput, a table {gradData, gradMask}: gradData has the size
of the data tensor, is zero everywhere except at the selected positions
(which receive the entries of gradOutput); gradMask has the size of the mask
and is filled with zeros.
]]
function MaskedSelect:updateGradInput(input, gradOutput)
  local data, mask = unpack(input)
  local indexBuffer = self._maskIndexBuffer
  local isCuda = data:type() == 'torch.CudaTensor'

  -- Step 1: lay out the linear indices 1..N in the shape of the mask.
  if not isCuda then
    indexBuffer:range(1, mask:nElement()):resize(mask:size())
  else
    -- For CUDA inputs the range is generated in the float staging buffer
    -- and then copied into the index buffer.
    local staging = self._maskIndexBufferCPU
    staging:range(1, mask:nElement()):resize(mask:size())
    indexBuffer:resize(staging:size()):copy(staging)
  end

  -- Step 2: keep only the linear indices of the selected elements.
  local selected = self._maskIndices
  selected:maskedSelect(indexBuffer, mask)

  -- Step 3: scatter gradOutput into a zeroed flat buffer at those indices,
  -- then give the buffer the shape of the data tensor.
  local gradData = self._gradBuffer
  gradData:resize(data:nElement()):zero()
  gradData:scatter(1, selected, gradOutput)
  gradData:resize(data:size())

  -- Step 4: the mask receives an all-zero gradient of its own size.
  local gradMask = self._gradMask:resize(mask:size()):fill(0)

  self.gradInput = {gradData, gradMask}
  return self.gradInput
end

--[[ Queries or changes the tensor type used by the module.

Parameters:
  type        - tensor type name (e.g. 'torch.FloatTensor'). When nil or
                false the call is a query and nothing is converted.
  tensorCache - accepted for signature compatibility; not used here.

Returns the current type string (self._type) for a query, otherwise the
module itself after conversion.

The output and gradient buffer are converted to `type`. The index and mask
buffers become CUDA tensors only for 'torch.CudaTensor'; for every other
type they are converted to long / byte tensors.
]]
function MaskedSelect:type(type, tensorCache)
  if not type then
    return self._type
  end
  self._gradBuffer = self._gradBuffer:type(type)
  self.output = self.output:type(type)

  -- These casts apply when switching between cuda/non-cuda types
  if type == 'torch.CudaTensor' then
    self._maskIndexBuffer = self._maskIndexBuffer:cuda()
    self._maskIndices = self._maskIndices:cuda()
    self._gradMask = self._gradMask:cuda()
  else
    self._maskIndexBuffer = self._maskIndexBuffer:long()
    self._maskIndices = self._maskIndices:long()
    self._gradMask = self._gradMask:byte()
  end

  -- updateGradInput replaces gradInput with the table {_gradBuffer, _gradMask},
  -- and a plain table has no :type method, so it is cast directly only while
  -- it is still a tensor. A populated table is rebuilt from the buffers
  -- converted above; any other table (e.g. an empty one) is left untouched.
  -- The parameter `type` shadows the builtin, hence torch.type / torch.isTensor.
  if torch.isTensor(self.gradInput) then
    self.gradInput = self.gradInput:type(type)
  elseif torch.type(self.gradInput) == 'table' and #self.gradInput > 0 then
    self.gradInput = {self._gradBuffer, self._gradMask}
  end

  self._type = type
  return self
end

--[[ Clears the module's output, gradInput and internal work buffers by
handing their field names to nn.utils.clear.

Returns whatever nn.utils.clear returns.
]]
function MaskedSelect:clearState()
  local stateFields = {
    'output',
    'gradInput',
    '_maskIndexBuffer',
    '_maskIndexBufferCPU',
    '_maskIndices',
    '_gradBuffer',
    '_gradMask',
  }
  return nn.utils.clear(self, stateFields)
end