diff options
author | Sasank Chilamkurthy <sasankchilamkurthy@gmail.com> | 2016-08-11 15:11:23 +0300 |
---|---|---|
committer | Sasank Chilamkurthy <sasankchilamkurthy@gmail.com> | 2016-08-11 15:11:23 +0300 |
commit | a2e5587fa931a7d0fbda4194354e200763656e85 (patch) | |
tree | 43d95612409a4b4c1c8c8f37d678c50d91ac5b49 | |
parent | 59ac9afc330e3b848fb204afeb8d6f2835902540 (diff) |
Add cudnn.VolumetricLogSoftMax cudnn.VolumetricSoftMax.lua
-rw-r--r-- | VolumetricLogSoftMax.lua | 7 | ||||
-rw-r--r-- | VolumetricSoftMax.lua | 47 |
2 files changed, 54 insertions, 0 deletions
diff --git a/VolumetricLogSoftMax.lua b/VolumetricLogSoftMax.lua new file mode 100644 index 0000000..a23ed60 --- /dev/null +++ b/VolumetricLogSoftMax.lua @@ -0,0 +1,7 @@ +local SoftMax, parent = torch.class('cudnn.VolumetricLogSoftMax', 'cudnn.VolumetricSoftMax') + +function SoftMax:__init(fast) + parent.__init(self, fast) + self.ssm.mode = 'CUDNN_SOFTMAX_MODE_CHANNEL' + self.ssm.algorithm = 'CUDNN_SOFTMAX_LOG' +end diff --git a/VolumetricSoftMax.lua b/VolumetricSoftMax.lua new file mode 100644 index 0000000..7a463a2 --- /dev/null +++ b/VolumetricSoftMax.lua @@ -0,0 +1,47 @@ +local VolumetricSoftMax, parent = torch.class('cudnn.VolumetricSoftMax', 'nn.Module') + +function VolumetricSoftMax:__init(fast) + parent.__init(self) + self.ssm = cudnn.SpatialSoftMax(fast) +end + +local fold = function(input) + -- Fold time and height into one dimension + if input:dim() == 4 then + -- dthw -> d(t*h)w + input = input:view(input:size(1), input:size(2)*input:size(3), + input:size(4)) + else + -- bdthw -> bd(t*h)w + input = input:view(input:size(1), input:size(2), + input:size(3)*input:size(4), input:size(5)) + end + return input +end + +function VolumetricSoftMax:updateOutput(input) + assert(input:dim() == 4 or input:dim() == 5, + 'input should either be a 3d image or a minibatch of them') + local originalInputSize = input:size() + + -- Apply SpatialSoftMax to folded input + self.ssm:updateOutput(fold(input)) + self.output = self.ssm.output:view(originalInputSize) + return self.output +end + +function VolumetricSoftMax:updateGradInput(input, gradOutput) + assert(input:dim() == 4 or input:dim() == 5, + 'input should either be a 3d image or a minibatch of them') + + local originalInputSize = input:size() + self.ssm:updateGradInput(fold(input), fold(gradOutput)) + + self.gradInput = self.ssm.gradInput:view(originalInputSize) + return self.gradInput +end + +function VolumetricSoftMax:clearState() + self.ssm:clearState() + return parent.clearState(self) +end
\ No newline at end of file |