diff options
Diffstat (limited to 'intern/cycles/device/device_network.cpp')
-rw-r--r-- | intern/cycles/device/device_network.cpp | 457 |
1 files changed, 363 insertions, 94 deletions
diff --git a/intern/cycles/device/device_network.cpp b/intern/cycles/device/device_network.cpp index 23c1a10fa0a..90339b89cce 100644 --- a/intern/cycles/device/device_network.cpp +++ b/intern/cycles/device/device_network.cpp @@ -20,9 +20,25 @@ #include "util_foreach.h" +#if defined(WITH_NETWORK) + CCL_NAMESPACE_BEGIN -#ifdef WITH_NETWORK +typedef map<device_ptr, device_ptr> PtrMap; +typedef vector<uint8_t> DataVector; +typedef map<device_ptr, DataVector> DataMap; + +/* tile list */ +typedef vector<RenderTile> TileList; + +/* search a list of tiles and find the one that matches the passed render tile */ +static TileList::iterator tile_list_find(TileList& tile_list, RenderTile& tile) +{ + for(TileList::iterator it = tile_list.begin(); it != tile_list.end(); ++it) + if(tile.x == it->x && tile.y == it->y && tile.start_sample == it->start_sample) + return it; + return tile_list.end(); +} class NetworkDevice : public Device { @@ -32,8 +48,10 @@ public: device_ptr mem_counter; DeviceTask the_task; /* todo: handle multiple tasks */ - NetworkDevice(Stats &stats, const char *address) - : Device(stats), socket(io_service) + thread_mutex rpc_lock; + + NetworkDevice(DeviceInfo& info, Stats &stats, const char *address) + : Device(info, stats, true), socket(io_service) { stringstream portstr; portstr << SERVER_PORT; @@ -64,6 +82,8 @@ public: void mem_alloc(device_memory& mem, MemoryType type) { + thread_scoped_lock lock(rpc_lock); + mem.device_pointer = ++mem_counter; RPCSend snd(socket, "mem_alloc"); @@ -75,6 +95,8 @@ public: void mem_copy_to(device_memory& mem) { + thread_scoped_lock lock(rpc_lock); + RPCSend snd(socket, "mem_copy_to"); snd.add(mem); @@ -84,6 +106,10 @@ public: void mem_copy_from(device_memory& mem, int y, int w, int h, int elem) { + thread_scoped_lock lock(rpc_lock); + + size_t data_size = mem.memory_size(); + RPCSend snd(socket, "mem_copy_from"); snd.add(mem); @@ -94,11 +120,13 @@ public: snd.write(); RPCReceive rcv(socket); - rcv.read_buffer((void*)mem.data_pointer, mem.memory_size()); + rcv.read_buffer((void*)mem.data_pointer, data_size); } void mem_zero(device_memory& mem) { + thread_scoped_lock lock(rpc_lock); + RPCSend snd(socket, "mem_zero"); snd.add(mem); @@ -108,6 +136,8 @@ public: void mem_free(device_memory& mem) { if(mem.device_pointer) { + thread_scoped_lock lock(rpc_lock); + RPCSend snd(socket, "mem_free"); snd.add(mem); @@ -119,6 +149,8 @@ public: void const_copy_to(const char *name, void *host, size_t size) { + thread_scoped_lock lock(rpc_lock); + RPCSend snd(socket, "const_copy_to"); string name_string(name); @@ -131,6 +163,8 @@ public: void tex_alloc(const char *name, device_memory& mem, bool interpolation, bool periodic) { + thread_scoped_lock lock(rpc_lock); + mem.device_pointer = ++mem_counter; RPCSend snd(socket, "tex_alloc"); @@ -148,6 +182,8 @@ public: void tex_free(device_memory& mem) { if(mem.device_pointer) { + thread_scoped_lock lock(rpc_lock); + RPCSend snd(socket, "tex_free"); snd.add(mem); @@ -157,8 +193,25 @@ public: } } + bool load_kernels(bool experimental) + { + thread_scoped_lock lock(rpc_lock); + + RPCSend snd(socket, "load_kernels"); + snd.add(experimental); + snd.write(); + + bool result; + RPCReceive rcv(socket); + rcv.read(result); + + return result; + } + void task_add(DeviceTask& task) { + thread_scoped_lock lock(rpc_lock); + the_task = task; RPCSend snd(socket, "task_add"); @@ -168,55 +221,73 @@ public: void task_wait() { + thread_scoped_lock lock(rpc_lock); + RPCSend snd(socket, "task_wait"); snd.write(); - list<RenderTile> the_tiles; + lock.unlock(); + + TileList the_tiles; /* todo: run this threaded for connecting to multiple clients */ for(;;) { - RPCReceive rcv(socket); RenderTile tile; + lock.lock(); + RPCReceive rcv(socket); + if(rcv.name == "acquire_tile") { + lock.unlock(); + /* todo: watch out for recursive calls! */ if(the_task.acquire_tile(this, tile)) { /* write return as bool */ the_tiles.push_back(tile); + lock.lock(); RPCSend snd(socket, "acquire_tile"); snd.add(tile); snd.write(); + lock.unlock(); } else { + lock.lock(); RPCSend snd(socket, "acquire_tile_none"); snd.write(); + lock.unlock(); } } else if(rcv.name == "release_tile") { rcv.read(tile); + lock.unlock(); - for(list<RenderTile>::iterator it = the_tiles.begin(); it != the_tiles.end(); it++) { - if(tile.x == it->x && tile.y == it->y && tile.start_sample == it->start_sample) { - tile.buffers = it->buffers; - the_tiles.erase(it); - break; - } + TileList::iterator it = tile_list_find(the_tiles, tile); + if (it != the_tiles.end()) { + tile.buffers = it->buffers; + the_tiles.erase(it); } assert(tile.buffers != NULL); the_task.release_tile(tile); + lock.lock(); RPCSend snd(socket, "release_tile"); snd.write(); + lock.unlock(); } - else if(rcv.name == "task_wait_done") + else if(rcv.name == "task_wait_done") { + lock.unlock(); break; + } + else + lock.unlock(); } } void task_cancel() { + thread_scoped_lock lock(rpc_lock); RPCSend snd(socket, "task_cancel"); snd.write(); } @@ -224,7 +295,7 @@ public: Device *device_network_create(DeviceInfo& info, Stats &stats, const char *address) { - return new NetworkDevice(stats, address); + return new NetworkDevice(info, stats, address); } void device_network_info(vector<DeviceInfo>& devices) @@ -243,8 +314,10 @@ void device_network_info(vector<DeviceInfo>& devices) class DeviceServer { public: + thread_mutex rpc_lock; + DeviceServer(Device *device_, tcp::socket& socket_) - : device(device_), socket(socket_) + : device(device_), socket(socket_), stop(false), blocked_waiting(false) { } @@ -252,56 +325,151 @@ public: { /* receive remote function calls */ for(;;) { - RPCReceive rcv(socket); + listen_step(); - if(rcv.name == "stop") + if(stop) break; - - process(rcv); } } protected: - void process(RPCReceive& rcv) + void listen_step() + { + thread_scoped_lock lock(rpc_lock); + RPCReceive rcv(socket); + + if(rcv.name == "stop") + stop = true; + else + process(rcv, lock); + } + + /* create a memory buffer for a device buffer and insert it into mem_data */ + DataVector &data_vector_insert(device_ptr client_pointer, size_t data_size) + { + /* create a new DataVector and insert it into mem_data */ + pair<DataMap::iterator,bool> data_ins = mem_data.insert( + DataMap::value_type(client_pointer, DataVector())); + + /* make sure it was a unique insertion */ + assert(data_ins.second); + + /* get a reference to the inserted vector */ + DataVector &data_v = data_ins.first->second; + + /* size the vector */ + data_v.resize(data_size); + + return data_v; + } + + DataVector &data_vector_find(device_ptr client_pointer) { - // fprintf(stderr, "receive process %s\n", rcv.name.c_str()); + DataMap::iterator i = mem_data.find(client_pointer); + assert(i != mem_data.end()); + return i->second; + } + + /* setup mapping and reverse mapping of client_pointer<->real_pointer */ + void pointer_mapping_insert(device_ptr client_pointer, device_ptr real_pointer) + { + pair<PtrMap::iterator,bool> mapins; + + /* insert mapping from client pointer to our real device pointer */ + mapins = ptr_map.insert(PtrMap::value_type(client_pointer, real_pointer)); + assert(mapins.second); + + /* insert reverse mapping from real our device pointer to client pointer */ + mapins = ptr_imap.insert(PtrMap::value_type(real_pointer, client_pointer)); + assert(mapins.second); + } + + device_ptr device_ptr_from_client_pointer(device_ptr client_pointer) + { + PtrMap::iterator i = ptr_map.find(client_pointer); + assert(i != ptr_map.end()); + return i->second; + } + + device_ptr device_ptr_from_client_pointer_erase(device_ptr client_pointer) + { + PtrMap::iterator i = ptr_map.find(client_pointer); + assert(i != ptr_map.end()); + + device_ptr result = i->second; + /* erase the mapping */ + ptr_map.erase(i); + + /* erase the reverse mapping */ + PtrMap::iterator irev = ptr_imap.find(result); + assert(irev != ptr_imap.end()); + ptr_imap.erase(irev); + + /* erase the data vector */ + DataMap::iterator idata = mem_data.find(client_pointer); + assert(idata != mem_data.end()); + mem_data.erase(idata); + + return result; + } + + /* note that the lock must be already acquired upon entry. + * This is necessary because the caller often peeks at + * the header and delegates control to here when it doesn't + * specifically handle the current RPC. + * The lock must be unlocked before returning */ + void process(RPCReceive& rcv, thread_scoped_lock &lock) + { if(rcv.name == "mem_alloc") { MemoryType type; network_device_memory mem; - device_ptr remote_pointer; + device_ptr client_pointer; rcv.read(mem); rcv.read(type); - /* todo: CPU needs mem.data_pointer */ + lock.unlock(); + + client_pointer = mem.device_pointer; - remote_pointer = mem.device_pointer; + /* create a memory buffer for the device buffer */ + size_t data_size = mem.memory_size(); + DataVector &data_v = data_vector_insert(client_pointer, data_size); - mem_data[remote_pointer] = vector<uint8_t>(); - mem_data[remote_pointer].resize(mem.memory_size()); - if(mem.memory_size()) - mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]); + if(data_size) + mem.data_pointer = (device_ptr)&(data_v[0]); else mem.data_pointer = 0; + /* perform the allocation on the actual device */ device->mem_alloc(mem, type); - ptr_map[remote_pointer] = mem.device_pointer; - ptr_imap[mem.device_pointer] = remote_pointer; + /* store a mapping to/from client_pointer and real device pointer */ + pointer_mapping_insert(client_pointer, mem.device_pointer); } else if(rcv.name == "mem_copy_to") { network_device_memory mem; rcv.read(mem); + lock.unlock(); - device_ptr remote_pointer = mem.device_pointer; - mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]); + device_ptr client_pointer = mem.device_pointer; - rcv.read_buffer((uint8_t*)mem.data_pointer, mem.memory_size()); + DataVector &data_v = data_vector_find(client_pointer); - mem.device_pointer = ptr_map[remote_pointer]; + size_t data_size = mem.memory_size(); + /* get pointer to memory buffer for device buffer */ + mem.data_pointer = (device_ptr)&data_v[0]; + + /* copy data from network into memory buffer */ + rcv.read_buffer((uint8_t*)mem.data_pointer, data_size); + + /* translate the client pointer to a real device pointer */ + mem.device_pointer = device_ptr_from_client_pointer(client_pointer); + + /* copy the data from the memory buffer to the device buffer */ device->mem_copy_to(mem); } else if(rcv.name == "mem_copy_from") { @@ -314,37 +482,47 @@ protected: rcv.read(h); rcv.read(elem); - device_ptr remote_pointer = mem.device_pointer; - mem.device_pointer = ptr_map[remote_pointer]; - mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]); + device_ptr client_pointer = mem.device_pointer; + mem.device_pointer = device_ptr_from_client_pointer(client_pointer); + + DataVector &data_v = data_vector_find(client_pointer); + + mem.data_pointer = (device_ptr)&(data_v[0]); device->mem_copy_from(mem, y, w, h, elem); + size_t data_size = mem.memory_size(); + RPCSend snd(socket); snd.write(); - snd.write_buffer((uint8_t*)mem.data_pointer, mem.memory_size()); + snd.write_buffer((uint8_t*)mem.data_pointer, data_size); + lock.unlock(); } else if(rcv.name == "mem_zero") { network_device_memory mem; rcv.read(mem); - device_ptr remote_pointer = mem.device_pointer; - mem.device_pointer = ptr_map[mem.device_pointer]; - mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]); + lock.unlock(); + + device_ptr client_pointer = mem.device_pointer; + mem.device_pointer = device_ptr_from_client_pointer(client_pointer); + + DataVector &data_v = data_vector_find(client_pointer); + + mem.data_pointer = (device_ptr)&(data_v[0]); device->mem_zero(mem); } else if(rcv.name == "mem_free") { network_device_memory mem; - device_ptr remote_pointer; + device_ptr client_pointer; rcv.read(mem); + lock.unlock(); + + client_pointer = mem.device_pointer; - remote_pointer = mem.device_pointer; - mem.device_pointer = ptr_map[mem.device_pointer]; - ptr_map.erase(remote_pointer); - ptr_imap.erase(mem.device_pointer); - mem_data.erase(remote_pointer); + mem.device_pointer = device_ptr_from_client_pointer_erase(client_pointer); device->mem_free(mem); } @@ -357,6 +535,7 @@ protected: vector<char> host_vector(size); rcv.read_buffer(&host_vector[0], size); + lock.unlock(); device->const_copy_to(name_string.c_str(), &host_vector[0], size); } @@ -365,53 +544,76 @@ protected: string name; bool interpolation; bool periodic; - device_ptr remote_pointer; + device_ptr client_pointer; rcv.read(name); rcv.read(mem); rcv.read(interpolation); rcv.read(periodic); + lock.unlock(); + + client_pointer = mem.device_pointer; + + size_t data_size = mem.memory_size(); - remote_pointer = mem.device_pointer; + DataVector &data_v = data_vector_insert(client_pointer, data_size); - mem_data[remote_pointer] = vector<uint8_t>(); - mem_data[remote_pointer].resize(mem.memory_size()); - if(mem.memory_size()) - mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]); + if(data_size) + mem.data_pointer = (device_ptr)&(data_v[0]); else mem.data_pointer = 0; - rcv.read_buffer((uint8_t*)mem.data_pointer, mem.memory_size()); + rcv.read_buffer((uint8_t*)mem.data_pointer, data_size); device->tex_alloc(name.c_str(), mem, interpolation, periodic); - ptr_map[remote_pointer] = mem.device_pointer; - ptr_imap[mem.device_pointer] = remote_pointer; + pointer_mapping_insert(client_pointer, mem.device_pointer); } else if(rcv.name == "tex_free") { network_device_memory mem; - device_ptr remote_pointer; + device_ptr client_pointer; rcv.read(mem); + lock.unlock(); - remote_pointer = mem.device_pointer; - mem.device_pointer = ptr_map[mem.device_pointer]; - ptr_map.erase(remote_pointer); - ptr_map.erase(mem.device_pointer); - mem_data.erase(remote_pointer); + client_pointer = mem.device_pointer; + + mem.device_pointer = device_ptr_from_client_pointer_erase(client_pointer); device->tex_free(mem); } + else if(rcv.name == "load_kernels") { + bool experimental; + rcv.read(experimental); + + bool result; + result = device->load_kernels(experimental); + RPCSend snd(socket); + snd.add(result); + snd.write(); + lock.unlock(); + } else if(rcv.name == "task_add") { DeviceTask task; rcv.read(task); + lock.unlock(); + + if(task.buffer) + task.buffer = device_ptr_from_client_pointer(task.buffer); + + if(task.rgba_half) + task.rgba_half = device_ptr_from_client_pointer(task.rgba_half); + + if(task.rgba_byte) + task.rgba_byte = device_ptr_from_client_pointer(task.rgba_byte); + + if(task.shader_input) + task.shader_input = device_ptr_from_client_pointer(task.shader_input); + + if(task.shader_output) + task.shader_output = device_ptr_from_client_pointer(task.shader_output); - if(task.buffer) task.buffer = ptr_map[task.buffer]; - if(task.rgba_byte) task.rgba_byte = ptr_map[task.rgba_byte]; - if(task.rgba_half) task.rgba_half = ptr_map[task.rgba_half]; - if(task.shader_input) task.shader_input = ptr_map[task.shader_input]; - if(task.shader_output) task.shader_output = ptr_map[task.shader_output]; task.acquire_tile = function_bind(&DeviceServer::task_acquire_tile, this, _1, _2); task.release_tile = function_bind(&DeviceServer::task_release_tile, this, _1); @@ -422,14 +624,44 @@ protected: device->task_add(task); } else if(rcv.name == "task_wait") { + lock.unlock(); + + blocked_waiting = true; device->task_wait(); + blocked_waiting = false; + lock.lock(); RPCSend snd(socket, "task_wait_done"); snd.write(); + lock.unlock(); } else if(rcv.name == "task_cancel") { + lock.unlock(); device->task_cancel(); } + else if(rcv.name == "acquire_tile") { + AcquireEntry entry; + entry.name = rcv.name; + rcv.read(entry.tile); + acquire_queue.push_back(entry); + lock.unlock(); + } + else if(rcv.name == "acquire_tile_none") { + AcquireEntry entry; + entry.name = rcv.name; + acquire_queue.push_back(entry); + lock.unlock(); + } + else if(rcv.name == "release_tile") { + AcquireEntry entry; + entry.name = rcv.name; + acquire_queue.push_back(entry); + lock.unlock(); + } + else { + cout << "Error: unexpected RPC receive call \"" + rcv.name + "\"\n"; + lock.unlock(); + } } bool task_acquire_tile(Device *device, RenderTile& tile) @@ -441,23 +673,34 @@ protected: RPCSend snd(socket, "acquire_tile"); snd.write(); - while(1) { - RPCReceive rcv(socket); + do { + if(blocked_waiting) + listen_step(); - if(rcv.name == "acquire_tile") { - rcv.read(tile); + /* todo: avoid busy wait loop */ + thread_scoped_lock lock(rpc_lock); - if(tile.buffer) tile.buffer = ptr_map[tile.buffer]; - if(tile.rng_state) tile.rng_state = ptr_map[tile.rng_state]; + if(!acquire_queue.empty()) { + AcquireEntry entry = acquire_queue.front(); + acquire_queue.pop_front(); - result = true; - break; + if(entry.name == "acquire_tile") { + tile = entry.tile; + + if(tile.buffer) tile.buffer = ptr_map[tile.buffer]; + if(tile.rng_state) tile.rng_state = ptr_map[tile.rng_state]; + + result = true; + break; + } + else if(entry.name == "acquire_tile_none") { + break; + } + else { + cout << "Error: unexpected acquire RPC receive call \"" + entry.name + "\"\n"; + } } - else if(rcv.name == "acquire_tile_none") - break; - else - process(rcv); - } + } while(acquire_queue.empty() && !stop); return result; } @@ -479,18 +722,34 @@ protected: if(tile.buffer) tile.buffer = ptr_imap[tile.buffer]; if(tile.rng_state) tile.rng_state = ptr_imap[tile.rng_state]; - RPCSend snd(socket, "release_tile"); - snd.add(tile); - snd.write(); + { + thread_scoped_lock lock(rpc_lock); + RPCSend snd(socket, "release_tile"); + snd.add(tile); + snd.write(); + lock.unlock(); + } - while(1) { - RPCReceive rcv(socket); + do { + if(blocked_waiting) + listen_step(); - if(rcv.name == "release_tile") - break; - else - process(rcv); - } + /* todo: avoid busy wait loop */ + thread_scoped_lock lock(rpc_lock); + + if(!acquire_queue.empty()) { + AcquireEntry entry = acquire_queue.front(); + acquire_queue.pop_front(); + + if(entry.name == "release_tile") { + lock.unlock(); + break; + } + else { + cout << "Error: unexpected release RPC receive call \"" + entry.name + "\"\n"; + } + } + } while(acquire_queue.empty() && !stop); } bool task_get_cancel() @@ -503,11 +762,20 @@ protected: tcp::socket& socket; /* mapping of remote to local pointer */ - map<device_ptr, device_ptr> ptr_map; - map<device_ptr, device_ptr> ptr_imap; - map<device_ptr, vector<uint8_t> > mem_data; + PtrMap ptr_map; + PtrMap ptr_imap; + DataMap mem_data; + + struct AcquireEntry { + string name; + RenderTile tile; + }; thread_mutex acquire_mutex; + list<AcquireEntry> acquire_queue; + + bool stop; + bool blocked_waiting; /* todo: free memory and device (osl) on network error */ }; @@ -540,7 +808,8 @@ void Device::server_run() } } +CCL_NAMESPACE_END + #endif -CCL_NAMESPACE_END |