diff options
Diffstat (limited to 'DotProduct.lua')
-rw-r--r-- | DotProduct.lua | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/DotProduct.lua b/DotProduct.lua new file mode 100644 index 0000000..d16d295 --- /dev/null +++ b/DotProduct.lua @@ -0,0 +1,29 @@ +local DotProduct, parent = torch.class('nn.DotProduct', 'nn.Module') + +function DotProduct:__init() + parent.__init(self) + self.gradInput = {torch.Tensor(), torch.Tensor()} + self.output=torch.Tensor(1) +end + +function DotProduct:updateOutput(input,y) + self.output[1] = input[1]:dot(input[2]) + return self.output +end + +function DotProduct:updateGradInput(input, gradOutput) + local v1 = input[1] + local v2 = input[2] + local gw1=self.gradInput[1]; + local gw2=self.gradInput[2]; + gw1:resizeAs(v1) + gw2:resizeAs(v1) + + gw1:copy( v2) + gw1:mul(gradOutput[1]) + + gw2:copy( v1) + gw2:mul(gradOutput[1]) + + return self.gradInput +end |