Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'TanhShrink.lua')
-rw-r--r--TanhShrink.lua20
1 files changed, 20 insertions, 0 deletions
diff --git a/TanhShrink.lua b/TanhShrink.lua
new file mode 100644
index 0000000..96df6c5
--- /dev/null
+++ b/TanhShrink.lua
@@ -0,0 +1,20 @@
+local TanhShrink, parent = torch.class('nn.TanhShrink','nn.Module')
+
+function TanhShrink:__init()
+ parent.__init(self)
+ self.tanh = nn.Tanh()
+end
+
+function TanhShrink:updateOutput(input)
+ local th = self.tanh:updateOutput(input)
+ self.output:resizeAs(input):copy(input)
+ self.output:add(-1,th)
+ return self.output
+end
+
+function TanhShrink:updateGradInput(input, gradOutput)
+ local dth = self.tanh:updateGradInput(input,gradOutput)
+ self.gradInput:resizeAs(input):copy(gradOutput)
+ self.gradInput:add(-1,dth)
+ return self.gradInput
+end