diff options
author | Mevlana Gemici <mevlana@google.com> | 2016-02-02 09:01:30 +0300 |
---|---|---|
committer | Mevlana Gemici <mevlana@google.com> | 2016-02-02 09:01:30 +0300 |
commit | 683ce28d63eead4c935c102f98ad373fb45f940d (patch) | |
tree | f2e24ebde1e5988297e0a3b4c3d6d732b7a507d8 /TensorMath.lua | |
parent | c350ba40f74c8a394d5f43da7991f93e507914f3 (diff) |
Readding the argmin argmax as second argument
Diffstat (limited to 'TensorMath.lua')
-rw-r--r-- | TensorMath.lua | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 88820d4..e27cdc2 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -500,8 +500,23 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor", for _,name in ipairs({"min", "max"}) do wrap(name, cname(name .. "all"), - {{name=Tensor}, - {name=real, creturned=true}}, + {{name=real, creturned=true}, + {name="IndexTensor", default=true, returned=true, noreadadd=true, + precall = function(arg) + local txt = {} + return table.concat(txt, '\n') + end, + postcall = function(arg) + local txt = {} + table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg + table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i)) + table.insert(txt, string.format('else')) -- means we did a new() + table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.LongTensor");', arg.i)) + table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, 1);", arg.i, arg.i)); + return table.concat(txt, '\n') + end + }, + {name=Tensor}}, cname(name), {{name=Tensor, default=true, returned=true}, {name="IndexTensor", default=true, returned=true, noreadadd=true}, |