/*
 * Copyright 2011-2013 Blender Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __DEVICE_NETWORK_H__
#define __DEVICE_NETWORK_H__

#ifdef WITH_NETWORK

#  include <boost/archive/binary_iarchive.hpp>
#  include <boost/archive/binary_oarchive.hpp>
#  include <boost/archive/text_iarchive.hpp>
#  include <boost/archive/text_oarchive.hpp>
#  include <boost/array.hpp>
#  include <boost/asio.hpp>
#  include <boost/bind.hpp>
#  include <boost/serialization/vector.hpp>
#  include <boost/thread.hpp>

#  include <algorithm>
#  include <cstdio>
#  include <deque>
#  include <iomanip>
#  include <iostream>
#  include <sstream>

#  include "render/buffers.h"

#  include "util/util_foreach.h"
#  include "util/util_list.h"
#  include "util/util_map.h"
#  include "util/util_param.h"
#  include "util/util_string.h"

CCL_NAMESPACE_BEGIN

using std::cerr;
using std::cout;
using std::exception;
using std::hex;
using std::setw;

using boost::asio::ip::tcp;

/* Port numbers and the two message strings exchanged by ServerDiscovery.
 * DISCOVER_PORT is the UDP port ServerDiscovery listens on and broadcasts to;
 * SERVER_PORT is not referenced in this header. */
static const int SERVER_PORT = 5120;
static const int DISCOVER_PORT = 5121;
static const string DISCOVER_REQUEST_MSG = "REQUEST_RENDER_SERVER_IP";
static const string DISCOVER_REPLY_MSG = "REPLY_RENDER_SERVER_IP";

/* Archive types used by RPCSend / RPCReceive. The text variant is kept
 * switchable at compile time; the binary variant is the one that is built. */
#  if 0
typedef boost::archive::text_oarchive o_archive;
typedef boost::archive::text_iarchive i_archive;
#  else
typedef boost::archive::binary_oarchive o_archive;
typedef boost::archive::binary_iarchive i_archive;
#  endif

/* Serialization of device memory */

/* device_memory created with an empty name and MEM_READ_ONLY type, plus a
 * byte buffer (local_data) owned by this object. */
class network_device_memory : public device_memory {
 public:
  network_device_memory(Device *device) : device_memory(device, "", MEM_READ_ONLY)
  {
  }

  ~network_device_memory()
  {
    /* NOTE(review): presumably cleared so the base-class destructor does not
     * treat the pointer as a live allocation - confirm against device_memory. */
    device_pointer = 0;
  };

  vector<char> local_data;
};

/* Common network error function / object for both DeviceNetwork and DeviceServer. */

/* Remembers the most recent error message and how many errors were reported. */
class NetworkError {
 public:
  NetworkError() : error(""), error_count(0)
  {
  }

  ~NetworkError()
  {
  }

  /* Record `message` as the latest error and bump the error counter. */
  void network_error(const string &message)
  {
    error = message;
    error_count += 1;
  }

  /* True once at least one error has been reported. */
  bool have_error() const
  {
    return error_count > 0;
  }

 private:
  string error;
  int error_count;
};

/* Remote procedure call Send */

/* Builds one message in an in-memory archive and sends it over `socket`.
 *
 * The constructor stores the call name as the first archive entry, add()
 * appends arguments, and write() sends an 8 character hexadecimal length
 * header followed by the archive bytes. Errors are reported to the
 * NetworkError passed to the constructor, which must outlive this object. */
class RPCSend {
 public:
  RPCSend(tcp::socket &socket_, NetworkError *e, const string &name_ = "")
      : name(name_), socket(socket_), archive(archive_stream), sent(false), error_func(e)
  {
    archive &name_;

    fprintf(stderr, "rpc send %s\n", name.c_str());
  }

  ~RPCSend()
  {
  }

  /* Append the description of a device memory block (not its contents). */
  void add(const device_memory &mem)
  {
    archive &mem.data_type &mem.data_elements &mem.data_size;
    archive &mem.data_width &mem.data_height &mem.data_depth &mem.device_pointer;
    archive &mem.type &string(mem.name);
    archive &mem.interpolation &mem.extension;
    /* device_pointer is written a second time on purpose: this must mirror
     * RPCReceive::read(network_device_memory &, string &) field for field. */
    archive &mem.device_pointer;
  }

  /* Append any value the archive knows how to serialize. */
  template<typename T> void add(const T &data)
  {
    archive &data;
  }

  /* Append a device task; the task type is transferred as a plain int. */
  void add(const DeviceTask &task)
  {
    int type = (int)task.type;
    archive &type &task.x &task.y &task.w &task.h;
    archive &task.rgba_byte &task.rgba_half &task.buffer &task.sample &task.num_samples;
    archive &task.offset &task.stride;
    archive &task.shader_input &task.shader_output &task.shader_eval_type;
    archive &task.shader_x &task.shader_w;
    archive &task.need_finish_queue;
  }

  /* Append a render tile description. */
  void add(const RenderTile &tile)
  {
    archive &tile.x &tile.y &tile.w &tile.h;
    archive &tile.start_sample &tile.num_samples &tile.sample;
    archive &tile.resolution &tile.offset &tile.stride;
    archive &tile.buffer;
  }

  /* Send the archived message: fixed size header first, then the payload. */
  void write()
  {
    boost::system::error_code error;

    /* get string from stream */
    const string archive_str = archive_stream.str();

    /* first send fixed size header with size of following data */
    ostringstream header_stream;
    header_stream << setw(8) << hex << archive_str.size();
    const string header_str = header_stream.str();

    boost::asio::write(
        socket, boost::asio::buffer(header_str), boost::asio::transfer_all(), error);

    if (error.value()) {
      /* Without a header the receiver cannot frame the payload, so sending
       * it anyway would only put undecodable bytes on the stream. */
      error_func->network_error(error.message());
      return;
    }

    /* then send actual data */
    boost::asio::write(
        socket, boost::asio::buffer(archive_str), boost::asio::transfer_all(), error);

    if (error.value())
      error_func->network_error(error.message());

    sent = true;
  }

  /* Send `size` raw bytes from `buffer`, without any header. */
  void write_buffer(void *buffer, size_t size)
  {
    boost::system::error_code error;

    boost::asio::write(
        socket, boost::asio::buffer(buffer, size), boost::asio::transfer_all(), error);

    if (error.value())
      error_func->network_error(error.message());
  }

 protected:
  string name;
  tcp::socket &socket;
  ostringstream archive_stream;
  o_archive archive;
  bool sent;
  NetworkError *error_func;
};

/* Remote procedure call Receive */

/* Receives one message written by RPCSend::write() and decodes its fields.
 *
 * The constructor blocks until the header and payload have been read, then
 * decodes the call name into `name`. When receiving fails the error is
 * reported to the NetworkError passed in and no archive is created; every
 * read() overload then reports an error and leaves its output untouched, so
 * callers should check NetworkError::have_error(). */
class RPCReceive {
 public:
  RPCReceive(tcp::socket &socket_, NetworkError *e)
      : socket(socket_), archive_stream(NULL), archive(NULL), error_func(e)
  {
    /* read head with fixed size */
    vector<char> header(8);
    boost::system::error_code error;
    const size_t header_len = boost::asio::read(socket, boost::asio::buffer(header), error);

    if (error.value()) {
      error_func->network_error(error.message());
    }

    /* verify if we got something */
    if (header_len != header.size()) {
      error_func->network_error("Network receive error: invalid header size");
      return;
    }

    /* decode header */
    const string header_str(&header[0], header.size());
    istringstream header_stream(header_str);

    size_t data_size = 0;

    if (!(header_stream >> hex >> data_size)) {
      error_func->network_error("Network receive error: can't decode data size from header");
      return;
    }

    /* NOTE(review): data_size comes straight from the peer and is not
     * bounded before allocating. */
    vector<char> data(data_size);
    const size_t data_len = boost::asio::read(socket, boost::asio::buffer(data), error);

    if (error.value())
      error_func->network_error(error.message());

    if (data_len != data_size) {
      error_func->network_error("Network receive error: data size doesn't match header");
      return;
    }

    archive_str = (data.size()) ? string(&data[0], data.size()) : string("");

    archive_stream = new istringstream(archive_str);
    try {
      archive = new i_archive(*archive_stream);
      *archive &name;
    }
    catch (...) {
      /* The destructor does not run when the constructor throws, so free
       * what was allocated here before passing the exception on. */
      delete archive;
      archive = NULL;
      delete archive_stream;
      archive_stream = NULL;
      throw;
    }

    fprintf(stderr, "rpc receive %s\n", name.c_str());
  }

  ~RPCReceive()
  {
    delete archive;
    delete archive_stream;
  }

  /* Decode a device memory description written by RPCSend::add().
   * `name` receives the memory name and backs mem.name, so it must stay
   * alive for as long as mem.name is used. */
  void read(network_device_memory &mem, string &name)
  {
    if (!archive_ready())
      return;

    *archive &mem.data_type &mem.data_elements &mem.data_size;
    *archive &mem.data_width &mem.data_height &mem.data_depth &mem.device_pointer;
    *archive &mem.type &name;
    *archive &mem.interpolation &mem.extension;
    /* Second copy of device_pointer, matching what RPCSend::add() writes. */
    *archive &mem.device_pointer;

    mem.name = name.c_str();
    mem.host_pointer = 0;

    /* Can't transfer OpenGL texture over network. */
    if (mem.type == MEM_PIXELS) {
      mem.type = MEM_READ_WRITE;
    }
  }

  /* Decode any value the archive knows how to deserialize. */
  template<typename T> void read(T &data)
  {
    if (!archive_ready())
      return;

    *archive &data;
  }

  /* Read `size` raw bytes from the socket into `buffer` (no header). */
  void read_buffer(void *buffer, size_t size)
  {
    boost::system::error_code error;
    const size_t len = boost::asio::read(socket, boost::asio::buffer(buffer, size), error);

    if (error.value()) {
      error_func->network_error(error.message());
    }
    else if (len != size) {
      /* Reported through NetworkError like every other failure in this
       * class, so callers polling have_error() see it. */
      error_func->network_error(
          "Network receive error: buffer size doesn't match expected size");
    }
  }

  /* Decode a device task; the type arrives as an int and is cast back. */
  void read(DeviceTask &task)
  {
    if (!archive_ready())
      return;

    int type;

    *archive &type &task.x &task.y &task.w &task.h;
    *archive &task.rgba_byte &task.rgba_half &task.buffer &task.sample &task.num_samples;
    *archive &task.offset &task.stride;
    *archive &task.shader_input &task.shader_output &task.shader_eval_type;
    *archive &task.shader_x &task.shader_w;
    *archive &task.need_finish_queue;

    task.type = (DeviceTask::Type)type;
  }

  /* Decode a render tile description; tile.buffers is reset to NULL. */
  void read(RenderTile &tile)
  {
    if (!archive_ready())
      return;

    *archive &tile.x &tile.y &tile.w &tile.h;
    *archive &tile.start_sample &tile.num_samples &tile.sample;
    *archive &tile.resolution &tile.offset &tile.stride;
    *archive &tile.buffer;

    tile.buffers = NULL;
  }

  /* Name of the received call, decoded by the constructor. */
  string name;

 protected:
  /* True when the constructor managed to create the archive; otherwise
   * reports an error so that decoding never dereferences a NULL archive. */
  bool archive_ready()
  {
    if (archive != NULL)
      return true;

    error_func->network_error("Network receive error: no message available to decode");
    return false;
  }

  tcp::socket &socket;
  string archive_str;
  istringstream *archive_stream;
  i_archive *archive;
  NetworkError *error_func;
};

/* Server auto discovery */

/* Finds render servers with UDP broadcasts on DISCOVER_PORT.
 *
 * Constructed with discover == true it broadcasts DISCOVER_REQUEST_MSG and
 * collects the addresses of everyone answering with DISCOVER_REPLY_MSG.
 * Constructed with discover == false it answers incoming requests instead.
 * Receiving runs on a thread owned by this object until destruction. */
class ServerDiscovery {
 public:
  explicit ServerDiscovery(bool discover = false)
      : listen_socket(io_service), thread(NULL), work(NULL), collect_servers(discover)
  {
    /* setup listen socket */
    listen_endpoint.address(boost::asio::ip::address_v4::any());
    listen_endpoint.port(DISCOVER_PORT);

    listen_socket.open(listen_endpoint.protocol());

    boost::asio::socket_base::reuse_address option(true);
    listen_socket.set_option(option);

    listen_socket.bind(listen_endpoint);

    /* setup receive callback */
    async_receive();

    /* start server discovery */
    if (discover) {
      broadcast_message(DISCOVER_REQUEST_MSG);
    }

    /* start thread */
    work = new boost::asio::io_service::work(io_service);
    thread = new boost::thread(boost::bind(&boost::asio::io_service::run, &io_service));
  }

  ~ServerDiscovery()
  {
    io_service.stop();
    thread->join();
    delete thread;
    delete work;
  }

  /* Snapshot of the server addresses collected so far. */
  vector<string> get_server_list()
  {
    /* Scoped lock: the mutex is released even if copying the list throws. */
    boost::mutex::scoped_lock lock(mutex);
    return vector<string>(servers.begin(), servers.end());
  }

 private:
  /* Completion handler for async_receive(); runs on the io_service thread. */
  void handle_receive_from(const boost::system::error_code &error, size_t size)
  {
    if (error) {
      /* Receiving is not re-armed after an error. */
      cout << "Server discovery receive error: " << error.message() << "\n";
      return;
    }

    if (size > 0) {
      const string msg(receive_buffer, size);

      /* handle incoming message */
      if (collect_servers) {
        if (msg == DISCOVER_REPLY_MSG) {
          const string address = receive_endpoint.address().to_string();

          boost::mutex::scoped_lock lock(mutex);

          /* add address if it's not already in the list */
          if (std::find(servers.begin(), servers.end(), address) == servers.end())
            servers.push_back(address);
        }
      }
      else if (msg == DISCOVER_REQUEST_MSG) {
        /* reply to request */
        try {
          broadcast_message(DISCOVER_REPLY_MSG);
        }
        catch (const exception &e) {
          /* An exception leaving this handler would escape the discovery
           * thread; a failed reply only means this request goes unanswered. */
          cout << "Server discovery reply error: " << e.what() << "\n";
        }
      }
    }

    async_receive();
  }

  /* Queue the next asynchronous receive into receive_buffer. */
  void async_receive()
  {
    listen_socket.async_receive_from(boost::asio::buffer(receive_buffer),
                                     receive_endpoint,
                                     boost::bind(&ServerDiscovery::handle_receive_from,
                                                 this,
                                                 boost::asio::placeholders::error,
                                                 boost::asio::placeholders::bytes_transferred));
  }

  /* Send `msg` to 255.255.255.255 on DISCOVER_PORT using a temporary socket.
   * Uses the throwing Boost.Asio overloads. */
  void broadcast_message(const string &msg)
  {
    /* setup broadcast socket */
    boost::asio::ip::udp::socket socket(io_service);

    socket.open(boost::asio::ip::udp::v4());

    boost::asio::socket_base::broadcast option(true);
    socket.set_option(option);

    boost::asio::ip::udp::endpoint broadcast_endpoint(
        boost::asio::ip::address::from_string("255.255.255.255"), DISCOVER_PORT);

    /* broadcast message */
    socket.send_to(boost::asio::buffer(msg), broadcast_endpoint);
  }

  /* network service and socket */
  boost::asio::io_service io_service;
  boost::asio::ip::udp::endpoint listen_endpoint;
  boost::asio::ip::udp::socket listen_socket;

  /* threading */
  boost::thread *thread;
  boost::asio::io_service::work *work;
  boost::mutex mutex;

  /* buffer and endpoint for receiving messages */
  char receive_buffer[256];
  boost::asio::ip::udp::endpoint receive_endpoint;

  // os, version, devices, status, host name, group name, ip as far as fields go
  /* Not used anywhere else in this header. */
  struct ServerInfo {
    string cycles_version;
    string os;
    int device_count;
    string status;
    string host_name;
    string group_name;
    string host_addr;
  };

  /* collection of server addresses in list */
  bool collect_servers;
  vector<string> servers;
};

CCL_NAMESPACE_END

#endif

#endif /* __DEVICE_NETWORK_H__ */