Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndreas Köpf <andreas.koepf@xamla.com>2015-12-13 00:52:35 +0300
committersoumith <soumith@gmail.com>2015-12-30 00:38:06 +0300
commit1e3abcb83ee71478a5f772b757cfbaf5c66be3c4 (patch)
tree63993b7f944fd500b32112af5e5324cdccc81d9a /THNN.lua
parentb887940ea683415d1d94614fd15ddddf124af68b (diff)
Add functional version of AbsCriterion using metatable call
THNN state is now passed implicitely.
Diffstat (limited to 'THNN.lua')
-rw-r--r--THNN.lua24
1 files changed, 20 insertions, 4 deletions
diff --git a/THNN.lua b/THNN.lua
index adb07de..fd0b269 100644
--- a/THNN.lua
+++ b/THNN.lua
@@ -12,6 +12,19 @@ TH_API void THNN_(Abs_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput);
+
+TH_API void THNN_(AbsCriterion_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ real *output,
+ bool sizeAverage);
+TH_API void THNN_(AbsCriterion_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ THTensor *gradInput,
+ bool sizeAverage);
]]
-- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h
@@ -82,14 +95,14 @@ local function extract_function_names(s)
return t
end
-function THNN.bind(lib, base_names, type_name)
+function THNN.bind(lib, base_names, type_name, state_getter)
local ftable = {}
local prefix = 'THNN_' .. type_name
for i,n in ipairs(base_names) do
-- use pcall since some libs might not support all functions (e.g. cunn)
local ok,v = pcall(function() return lib[prefix .. n] end)
if ok then
- ftable[n] = v
+ ftable[n] = function(...) v(state_getter(), ...) end -- implicitely add state
end
end
return ftable
@@ -99,8 +112,11 @@ end
local function_names = extract_function_names(generic_THNN_h)
THNN.kernels = {}
-THNN.kernels['torch.FloatTensor'] = THNN.bind(THNN.C, function_names, 'Float')
-THNN.kernels['torch.DoubleTensor'] = THNN.bind(THNN.C, function_names, 'Double')
+THNN.kernels['torch.FloatTensor'] = THNN.bind(THNN.C, function_names, 'Float', THNN.getState)
+THNN.kernels['torch.DoubleTensor'] = THNN.bind(THNN.C, function_names, 'Double', THNN.getState)
+
+torch.getmetatable('torch.FloatTensor').THNN = THNN.kernels['torch.FloatTensor']
+torch.getmetatable('torch.DoubleTensor').THNN = THNN.kernels['torch.DoubleTensor']
function THNN.runKernel(f, type, ...)
local ftable = THNN.kernels[type]