Welcome to mirror list, hosted at ThFree Co, Russian Federation.

ELU.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 48a6caa2cc7c80fc7ef87da2c2ec6caef1842e9d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
-- nn.ELU: Exponential Linear Unit layer, registered as a subclass of
-- nn.Module. `parent` is the nn.Module class table; the constructor below
-- uses it to run the base-class initializer.
--
-- Reference:
--    Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter
--    Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
--    http://arxiv.org/pdf/1511.07289.pdf
local ELU, parent = torch.class('nn.ELU', 'nn.Module')

--- Construct an ELU layer.
-- @tparam[opt=1] number alpha scale of the negative branch; defaults to 1
--   when omitted (nil) or false
-- @tparam[opt=false] boolean inplace forwarded to the THNN kernels by
--   updateOutput/updateGradInput; defaults to false when omitted
-- Raises an error with a descriptive message if either argument has the
-- wrong type (previously the asserts only reported "assertion failed!").
function ELU:__init(alpha, inplace)
   parent.__init(self)
   self.alpha = alpha or 1
   assert(type(self.alpha) == 'number',
      'nn.ELU: alpha must be a number, got ' .. type(self.alpha))
   self.inplace = inplace or false
   assert(type(self.inplace) == 'boolean',
      'nn.ELU: inplace must be a boolean, got ' .. type(self.inplace))
end

--- Forward pass: apply ELU to `input` via the THNN backend.
-- The backend is looked up from the tensor itself (`input.THNN`), so the
-- kernel matches the input's tensor type. `self.output` is passed as the
-- destination buffer and returned.
-- NOTE(review): when `inplace` is true the backend presumably reuses the
-- input's storage for the output — confirm against the THNN ELU kernel.
-- @param input tensor to activate
-- @return self.output
function ELU:updateOutput(input)
   -- A nil `inplace` field is coerced to false before reaching THNN.
   local useInplace = self.inplace
   if not useInplace then
      useInplace = false
   end

   local backend = input.THNN
   backend.ELU_updateOutput(input:cdata(), self.output:cdata(), self.alpha, useInplace)
   return self.output
end

--- Backward pass: compute the gradient w.r.t. `input` via the THNN backend.
-- Passes the input, the incoming `gradOutput`, the destination
-- `self.gradInput`, and the stored `self.output` to the kernel.
-- NOTE(review): because `self.output` is read here, this presumably must
-- follow a matching updateOutput call — confirm.
-- @param input tensor given to the forward pass
-- @param gradOutput gradient flowing back from the next layer
-- @return self.gradInput
function ELU:updateGradInput(input, gradOutput)
   -- A nil `inplace` field is coerced to false before reaching THNN.
   local useInplace = self.inplace
   if not useInplace then
      useInplace = false
   end

   local backend = input.THNN
   backend.ELU_updateGradInput(input:cdata(), gradOutput:cdata(),
      self.gradInput:cdata(), self.output:cdata(), self.alpha, useInplace)
   return self.gradInput
end

--- Human-readable description: the torch type name followed by alpha.
function ELU:__tostring__()
   local typeName = torch.type(self)
   return string.format('%s (alpha:%f)', typeName, self.alpha)
end