Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeff Johnson <jhj@fb.com>2016-03-14 23:21:54 +0300
committerJeff Johnson <jhj@fb.com>2016-03-14 23:41:04 +0300
commit565a0f1294aa90610fc7a1d07a656ffdacc9a017 (patch)
tree7c10391cf9a9d8ce02fc2c0b308b1049520f36aa /README.md
parenta7147d00e61a5e182a277995f5d1e99ec3bdf0f8 (diff)
kernel p2p access and non-blocking streams
Diffstat (limited to 'README.md')
-rw-r--r--README.md6
1 files changed, 5 insertions, 1 deletions
diff --git a/README.md b/README.md
index 234ac1c..7b929ec 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,7 @@ This new tensor type behaves exactly like a `torch.FloatTensor`, but has a coupl
- `cutorch.withDevice(devID, f)` - This is a convenience for multi-GPU code, that takes in a device ID as well as a function f. It switches cutorch to the new device, executes the function f, and switches back cutorch to the original device.
#### Low-level streams functions (dont use this as a user, easy to shoot yourself in the foot):
-- `cutorch.reserveStreams(n)`: creates n user streams for use on every device.
+- `cutorch.reserveStreams(n [, nonblocking])`: creates n user streams for use on every device. NOTE: stream index `s` on device 1 is a different cudaStream_t than stream `s` on device 2. Takes an optional non-blocking flag; by default, this is assumed to be false. If true, then the stream is created with cudaStreamNonBlocking.
- `n = cutorch.getNumStreams()`: returns the number of user streams available on every device. By `default`, this is `0`, meaning only the default stream (stream 0) is available.
- `cutorch.setStream(n)`: specifies that the current stream active for the current device (or any other device) is `n`. This is preserved across device switches. 1-N are user streams, `0` is the default stream.
- `n = cutorch.getStream()`: returns the current stream active. By default, returns `0`.
@@ -41,6 +41,10 @@ This new tensor type behaves exactly like a `torch.FloatTensor`, but has a coupl
- `cutorch.streamBarrier({streams...})`: an N-to-N-way barrier between all the streams; all streams will wait for the completion of all other streams on the current device only. More efficient than creating the same N-to-N-way dependency via `streamWaitFor`.
- `cutorch.streamBarrierMultiDevice({[device]={streamsToWaitOn...}...})`: As with streamBarrier but allows barriers between streams on arbitrary devices. Creates a cross-device N-to-N-way barrier between all (device, stream) values listed.
- `cutorch.streamSynchronize(stream)`: equivalent to `cudaStreamSynchronize(stream)` for the current device. Blocks the CPU until stream completes its queued kernels/events.
+- `cutorch.setPeerToPeerAccess(dev, devToAccess, f)`: explicitly enable (`f` true) or disable p2p access (`f` false) from `dev` accessing memory on `devToAccess`. Affects copy efficiency (if disabled, copies will be d2d rather than p2p; i.e., the CPU intermediates), and affects kernel p2p access as well. Can only be enabled if the underlying hardware supports p2p access. p2p access is enabled by default for all pairs of devices if the underlying hardware supports it.
+- `cutorch.getPeerToPeerAccess(dev, devToAccess)`: returns whether or not p2p access is currently enabled or disabled, for reasons of a prior call of `setPeerToPeerAccess` or underlying hardware support.
+- `cutorch.setKernelPeerToPeerAccess(f)`: by default, kernels running on one device cannot directly access memory on another device. This is a check imposed by cutorch, to prevent synchronization and performance issues. To disable the check, call this with `f` true. Kernel p2p access is actually only allowed for a pair of devices if both this is true and the underlying `getPeerToPeerAccess` for the pair involved is true.
+- `cutorch.getKernelPeerToPeerAccess()`: returns whether or not kernel p2p checks are enabled or disabled.
##### Common Examples
Transfering a FloatTensor `src` to the GPU: