Welcome to mirror list, hosted at ThFree Co, Russian Federation.

RReLU.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 843415f7e9e9af6a79707f4b34c62ba15d63d65d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
local ffi = require 'ffi'
-- Register the class 'nn.RReLU' with 'nn.Module' named as its base class.
-- `RReLU` receives the methods defined below; `parent` is used by those
-- methods to invoke the base-class implementations (__init, clearState).
local RReLU, parent = torch.class('nn.RReLU', 'nn.Module')

-- Construct an RReLU module.
--
-- Parameters:
--   l  : lower bound, stored in self.lower (defaults to 1/8 when nil)
--   u  : upper bound, stored in self.upper (defaults to 1/3 when nil)
--   ip : when truthy, stored as-is in self.inplace; otherwise false
--
-- Raises an error when the bounds do not satisfy 0 <= lower <= upper.
--
-- NOTE(review): the bounds are forwarded verbatim to the THNN RReLU kernels;
-- presumably they delimit the range the negative-side slope is drawn from
-- during training -- confirm against the THNN implementation.
function RReLU:__init(l, u, ip)
   parent.__init(self)
   self.lower = l or 1/8
   self.upper = u or 1/3
   -- Report the offending values instead of a bare "assertion failed!".
   assert(
      self.lower <= self.upper and self.lower >= 0 and self.upper >= 0,
      string.format(
         'nn.RReLU: expected 0 <= lower <= upper, got lower=%s, upper=%s',
         tostring(self.lower), tostring(self.upper)
      )
   )
   -- Empty tensor handed to the kernels as their noise buffer.
   self.noise = torch.Tensor()
   -- The train flag starts out true and is passed to both kernels.
   self.train = true
   self.inplace = ip or false
end

-- Forward pass: delegates to the THNN kernel RReLU_updateOutput looked up on
-- the input tensor and returns self.output.
--
-- Parameters:
--   input : tensor providing both the data and the THNN backend table
--
-- The kernel receives the raw cdata of input, self.output and self.noise,
-- the configured bounds, the train/inplace flags and the generator pointer.
function RReLU:updateOutput(input)
   -- torch._gen is reinterpreted as THGenerator** and dereferenced once to
   -- obtain the generator pointer expected by the kernel.
   local generatorPtrType = ffi.typeof('THGenerator**')
   local generator = generatorPtrType(torch._gen)[0]

   local forwardKernel = input.THNN.RReLU_updateOutput
   local inputData = input:cdata()
   local outputData = self.output:cdata()
   local noiseData = self.noise:cdata()

   forwardKernel(
      inputData, outputData, noiseData,
      self.lower, self.upper,
      self.train, self.inplace,
      generator
   )

   return self.output
end

-- Backward pass: delegates to the THNN kernel RReLU_updateGradInput looked up
-- on the input tensor and returns self.gradInput.
--
-- Parameters:
--   input      : tensor given to the forward pass (also supplies the backend)
--   gradOutput : gradient with respect to this module's output
--
-- self.noise is passed along unchanged.
-- NOTE(review): presumably it holds the values recorded by the preceding
-- forward call -- confirm against the THNN implementation.
function RReLU:updateGradInput(input, gradOutput)
   local backwardKernel = input.THNN.RReLU_updateGradInput
   local inputData = input:cdata()
   local gradOutputData = gradOutput:cdata()
   local gradInputData = self.gradInput:cdata()
   local noiseData = self.noise:cdata()

   backwardKernel(
      inputData, gradOutputData, gradInputData, noiseData,
      self.lower, self.upper,
      self.train, self.inplace
   )

   return self.gradInput
end

-- Human-readable description: the module's torch type name followed by its
-- lower and upper bounds, each formatted with %f.
function RReLU:__tostring__()
  local template = '%s (l:%f, u:%f)'
  local typeName = torch.type(self)
  return template:format(typeName, self.lower, self.upper)
end

-- Reset the noise buffer (when one exists) and then defer to the base-class
-- clearState, returning whatever it returns.
function RReLU:clearState()
   local noise = self.noise
   if noise then
      -- set() with no arguments; presumably detaches the tensor from its
      -- storage so the buffer is released -- confirm against the Tensor docs.
      noise:set()
   end
   return parent.clearState(self)
end