Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMevlana Gemici <mevlana@google.com>2016-02-02 09:01:30 +0300
committerMevlana Gemici <mevlana@google.com>2016-02-02 09:01:30 +0300
commit683ce28d63eead4c935c102f98ad373fb45f940d (patch)
treef2e24ebde1e5988297e0a3b4c3d6d732b7a507d8 /TensorMath.lua
parentc350ba40f74c8a394d5f43da7991f93e507914f3 (diff)
Readding the argmin argmax as second argument
Diffstat (limited to 'TensorMath.lua')
-rw-r--r--TensorMath.lua19
1 files changed, 17 insertions, 2 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 88820d4..e27cdc2 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -500,8 +500,23 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
for _,name in ipairs({"min", "max"}) do
wrap(name,
cname(name .. "all"),
- {{name=Tensor},
- {name=real, creturned=true}},
+ {{name=real, creturned=true},
+ {name="IndexTensor", default=true, returned=true, noreadadd=true,
+ precall = function(arg)
+ local txt = {}
+ return table.concat(txt, '\n')
+ end,
+ postcall = function(arg)
+ local txt = {}
+ table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ table.insert(txt, string.format('else')) -- means we did a new()
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.LongTensor");', arg.i))
+ table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, 1);", arg.i, arg.i));
+ return table.concat(txt, '\n')
+ end
+ },
+ {name=Tensor}},
cname(name),
{{name=Tensor, default=true, returned=true},
{name="IndexTensor", default=true, returned=true, noreadadd=true},