diff options
author | Trevor Killeen <killeent@users.noreply.github.com> | 2017-04-25 17:39:20 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-04-25 17:39:20 +0300 |
commit | 93a6864de6a861f44c80fcf33e937acb768bccdd (patch) | |
tree | 5394343d6e655c708886641043f464807e4b65c3 /TensorMath.lua | |
parent | 181a86935614d1abc10ef4b9b95ad33f4fc911dd (diff) |
Generic TopK implementation (#744)
* move TopK to generic
* partial genericization of kernel code
* introduce TopKTypeConfig, specialize radix type and conversion for floats
* implement topk for byte tensor
* implement for char tensor
* implement for int tensor, extend test to check indices as well
* works for longs too
* make bitfield set/get a struct, add support for 64-bit types
* extend to double tensor
* implement for half tensor
* asserts; test fix
Diffstat (limited to 'TensorMath.lua')
-rw-r--r-- | TensorMath.lua | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 01341e8..e1b5a0f 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -885,6 +885,16 @@ for k, Tensor_ in pairs(handledTypenames) do {name="boolean", default=0}} ) + wrap("topk", + cname("topk"), + {{name=Tensor, default=true, returned=true}, + {name="CudaLongTensor", default=true, returned=true, noreadadd=true}, + {name=Tensor}, + {name="long", default=1}, + {name="index", default=lastdim(3)}, + {name="boolean", default=0}, + {name="boolean", default=0}}) + wrap("mode", cname("mode"), {{name=Tensor, default=true, returned=true, noreadadd=true}, |