Logo Search packages:      
Sourcecode: wesnoth-1.7 version File versions  Download package

network_worker.cpp

/* $Id: network_worker.cpp 40489 2010-01-01 13:16:49Z mordante $ */
/*
   Copyright (C) 2003 - 2010 by David White <dave@whitevine.net>
   Part of the Battle for Wesnoth Project http://www.wesnoth.org/

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License version 2
   or at your option any later version.
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY.

   See the COPYING file for more details.
*/

/**
 * Network worker handles data transfers in threads
 * Remember to use mutexs as little as possible
 * All global vars should be used in mutex
 * FIXME: @todo: All code which holds a mutex should run O(1) time
 * for scalability. Implement read/write locks.
 *  (postponed for 1.5)
 */

#include "global.hpp"

#include "scoped_resource.hpp"
#include "log.hpp"
#include "network_worker.hpp"
#include "filesystem.hpp"
#include "thread.hpp"
#include "serialization/binary_or_text.hpp"
#include "serialization/binary_wml.hpp"
#include "serialization/parser.hpp"
#include "wesconfig.h"

#include <cerrno>
#include <deque>
#include <sstream>

#ifdef HAVE_SENDFILE
#include <sys/sendfile.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#endif


#ifdef __AMIGAOS4__
#include <unistd.h>
//#include <sys/clib2_net.h>
#endif

#if defined(_WIN32) || defined(__WIN32__) || defined (WIN32)
#  undef INADDR_ANY
#  undef INADDR_BROADCAST
#  undef INADDR_NONE
#  ifndef NOMINMAX
#   define NOMINMAX
#  endif
#  include <windows.h>
#  define USE_SELECT 1
typedef int socklen_t;
#else
#  include <sys/types.h>
#  include <sys/socket.h>
#  ifdef __BEOS__
#    include <socket.h>
#  else
#    include <fcntl.h>
#  endif
#  define SOCKET int
#  ifdef HAVE_POLL_H
#    define USE_POLL 1
#    include <poll.h>
#  elif defined(HAVE_SYS_POLL_H)
#    define USE_POLL 1
#    include <sys/poll.h>
#  endif
#  ifndef USE_POLL
#    define USE_SELECT 1
#    ifdef HAVE_SYS_SELECT_H
#      include <sys/select.h>
#    else
#      include <sys/time.h>
#      include <sys/types.h>
#      include <unistd.h>
#    endif
#  endif
#endif

static lg::log_domain log_network("network");
#define DBG_NW LOG_STREAM(debug, log_network)
#define LOG_NW LOG_STREAM(info, log_network)
#define ERR_NW LOG_STREAM(err, log_network)

namespace {
struct _TCPsocket {
      int ready;
      SOCKET channel;
      IPaddress remoteAddress;
      IPaddress localAddress;
      int sflag;
};

#ifndef NUM_SHARDS
#define NUM_SHARDS 1
#endif

unsigned int waiting_threads[NUM_SHARDS];
size_t min_threads = 0;
size_t max_threads = 0;

size_t get_shard(TCPsocket sock) { return reinterpret_cast<uintptr_t>(sock)%NUM_SHARDS; }

struct buffer {
      explicit buffer(TCPsocket sock) :
            sock(sock),
            config_buf(),
            config_error(""),
            stream(),
            gzipped(false),
            raw_buffer()
            {}

      TCPsocket sock;
      mutable config config_buf;
      std::string config_error;
      std::ostringstream stream;

      /**
       * Do we wish to send the data gzipped, if not use binary wml. This needs
       * to stay until the last user of binary_wml has been removed.
       */
      bool gzipped;

      /**
       * This field is used if we're sending a raw buffer instead of through a
       * config object. It will contain the entire contents of the buffer being
       * sent.
       */
      std::vector<char> raw_buffer;
};


bool managed = false, raw_data_only = false;
typedef std::vector< buffer* > buffer_set;
buffer_set outgoing_bufs[NUM_SHARDS];

struct schema_pair
{
      schema_pair() :
            incoming(),
            outgoing()
      {
      }

      compression_schema incoming, outgoing;
};

typedef std::map<TCPsocket,schema_pair> schema_map;

schema_map schemas; //schemas_mutex

/** a queue of sockets that we are waiting to receive on */
typedef std::vector<TCPsocket> receive_list;
receive_list pending_receives[NUM_SHARDS];

typedef std::deque<buffer*> received_queue;
received_queue received_data_queue;  // receive_mutex

enum SOCKET_STATE { SOCKET_READY, SOCKET_LOCKED, SOCKET_ERRORED, SOCKET_INTERRUPT };
typedef std::map<TCPsocket,SOCKET_STATE> socket_state_map;
typedef std::map<TCPsocket, std::pair<network::statistics,network::statistics> > socket_stats_map;

socket_state_map sockets_locked[NUM_SHARDS];
socket_stats_map transfer_stats; // stats_mutex

int socket_errors[NUM_SHARDS];
threading::mutex* shard_mutexes[NUM_SHARDS];
threading::mutex* stats_mutex = NULL;
threading::mutex* schemas_mutex = NULL;
threading::mutex* received_mutex = NULL;
threading::condition* cond[NUM_SHARDS];

std::map<Uint32,threading::thread*> threads[NUM_SHARDS];
std::vector<Uint32> to_clear[NUM_SHARDS];
int system_send_buffer_size = 0;
bool network_use_system_sendfile = false;

int receive_bytes(TCPsocket s, char* buf, size_t nbytes)
{
#ifdef NETWORK_USE_RAW_SOCKETS
      _TCPsocket* sock = reinterpret_cast<_TCPsocket*>(s);
      int res = 0;
      do {
            errno = 0;
            res = recv(sock->channel, buf, nbytes, MSG_DONTWAIT);
      } while(errno == EINTR);
      sock->ready = 0;
      return res;
#else
      return SDLNet_TCP_Recv(s, buf, nbytes);
#endif
}


void check_send_buffer_size(TCPsocket& s)
{
      if (system_send_buffer_size)
            return;
      _TCPsocket* sock = reinterpret_cast<_TCPsocket*>(s);
      socklen_t len = sizeof(system_send_buffer_size);
#ifdef _WIN32
      getsockopt(sock->channel, SOL_SOCKET, SO_RCVBUF,reinterpret_cast<char*>(&system_send_buffer_size), &len);
#else
      getsockopt(sock->channel, SOL_SOCKET, SO_RCVBUF,&system_send_buffer_size, &len);
#endif
      --system_send_buffer_size;
      DBG_NW << "send buffer size: " << system_send_buffer_size << "\n";
}

bool receive_with_timeout(TCPsocket s, char* buf, size_t nbytes,
            bool update_stats=false, int idle_timeout_ms=30000,
            int total_timeout_ms=300000)
{
#if !defined(USE_POLL) && !defined(USE_SELECT)
      int startTicks = SDL_GetTicks();
      int time_used = 0;
#endif
      int timeout_ms = idle_timeout_ms;
      while(nbytes > 0) {
            const int bytes_read = receive_bytes(s, buf, nbytes);
            if(bytes_read == 0) {
                  return false;
            } else if(bytes_read < 0) {
#if defined(EAGAIN) && !defined(__BEOS__) && !defined(_WIN32)
                  if(errno == EAGAIN)
#elif defined(EWOULDBLOCK)
                  if(errno == EWOULDBLOCK)
#elif defined(_WIN32) && defined(WSAEWOULDBLOCK)
                  if(WSAGetLastError() == WSAEWOULDBLOCK)
#else
                  // assume non-recoverable error.
                  if(false)
#endif
                  {
#ifdef USE_POLL
                        struct pollfd fd = { ((_TCPsocket*)s)->channel, POLLIN, 0 };
                        int poll_res;

                        //we timeout of the poll every 100ms. This lets us check to
                        //see if we have been disconnected, in which case we should
                        //abort the receive.
                        const int poll_timeout = std::min(timeout_ms, 100);
                        do {
                              poll_res = poll(&fd, 1, poll_timeout);

                              if(poll_res == 0) {
                                    timeout_ms -= poll_timeout;
                                    total_timeout_ms -= poll_timeout;
                                    if(timeout_ms <= 0 || total_timeout_ms <= 0) {
                                          //we've been waiting too long; abort the receive
                                          //as having failed due to timeout.
                                          return false;
                                    }

                                    //check to see if we've been interrupted
                                    const size_t shard = get_shard(s);
                                    const threading::lock lock(*shard_mutexes[shard]);
                                    socket_state_map::iterator lock_it = sockets_locked[shard].find(s);
                                    assert(lock_it != sockets_locked[shard].end());
                                    if(lock_it->second == SOCKET_INTERRUPT) {
                                          return false;
                                    }
                              }

                        } while(poll_res == 0 || (poll_res == -1 && errno == EINTR));

                        if (poll_res < 1)
                              return false;
#elif defined(USE_SELECT)
                        int retval;
                        const int select_timeout = std::min(timeout_ms, 100);
                        do {
                              fd_set readfds;
                              FD_ZERO(&readfds);
                              FD_SET(((_TCPsocket*)s)->channel, &readfds);
                              struct timeval tv;
                              tv.tv_sec = select_timeout/1000;
                              tv.tv_usec = select_timeout % 1000 * 1000;
                              retval = select(((_TCPsocket*)s)->channel + 1, &readfds, NULL, NULL, &tv);
                              DBG_NW << "select retval: " << retval << ", timeout idle " << timeout_ms
                                    << " total " << total_timeout_ms << " (ms)\n";
                              if(retval == 0) {
                                    timeout_ms -= select_timeout;
                                    total_timeout_ms -= select_timeout;
                                    if(timeout_ms <= 0 || total_timeout_ms <= 0) {
                                          //we've been waiting too long; abort the receive
                                          //as having failed due to timeout.
                                          return false;
                                    }

                                    //check to see if we've been interrupted
                                    const size_t shard = get_shard(s);
                                    const threading::lock lock(*shard_mutexes[shard]);
                                    socket_state_map::iterator lock_it = sockets_locked[shard].find(s);
                                    assert(lock_it != sockets_locked[shard].end());
                                    if(lock_it->second == SOCKET_INTERRUPT) {
                                          return false;
                                    }
                              }
                        } while(retval == 0 || retval == -1 && errno == EINTR);

                        if (retval < 1) {
                              return false;
                        }
#else
                        //TODO: consider replacing this with a select call
                        time_used = SDL_GetTicks() - startTicks;
                        if(time_used >= timeout_ms) {
                              return false;
                        }
                        SDL_Delay(20);
#endif
                  } else {
                        return false;
                  }
            } else {
                  timeout_ms = idle_timeout_ms;
                  buf += bytes_read;
                  if(update_stats && !raw_data_only) {
                        const threading::lock lock(*stats_mutex);
                        transfer_stats[s].second.transfer(static_cast<size_t>(bytes_read));
                  }

                  if(bytes_read > static_cast<int>(nbytes)) {
                        return false;
                  }
                  nbytes -= bytes_read;
                  // We got some data from server so reset start time so slow conenction won't timeout.
#if !defined(USE_POLL) && !defined(USE_SELECT)
                  startTicks = SDL_GetTicks();
#endif
            }
            {
                  const size_t shard = get_shard(s);
                  const threading::lock lock(*shard_mutexes[shard]);
                  socket_state_map::iterator lock_it = sockets_locked[shard].find(s);
                  assert(lock_it != sockets_locked[shard].end());
                  if(lock_it->second == SOCKET_INTERRUPT) {
                        return false;
                  }
            }
      }

      return true;
}

static void output_to_buffer(TCPsocket sock, const config& cfg, std::ostringstream& compressor, bool gzipped)
{
      if(gzipped) {
            config_writer writer(compressor, true);
            writer.write(cfg);
      } else {
            compression_schema *compress;
            {
                  const threading::lock lock(*schemas_mutex);
                  compress = &schemas.insert(std::pair<TCPsocket,schema_pair>(sock,schema_pair())).first->second.outgoing;
            }
            write_compressed(compressor, cfg, *compress);
      }
}

static void make_network_buffer(const char* input, int len, std::vector<char>& buf)
{
      buf.resize(4 + len);
      SDLNet_Write32(len, &buf[0]);
      memcpy(&buf[4], input, len);
}

static SOCKET_STATE send_buffer(TCPsocket sock, std::vector<char>& buf, int in_size = -1)
{
#ifdef __BEOS__
      int timeout = 60000;
#endif
//    check_send_buffer_size(sock);
      size_t upto = 0;
      size_t size = buf.size();
      if (in_size != -1)
            size = in_size;
      int send_len = 0;

      if (!raw_data_only)
      {
            const threading::lock lock(*stats_mutex);
            transfer_stats[sock].first.fresh_current(size);
      }
#ifdef __BEOS__
      while(upto < size && timeout > 0) {
#else
      while(true) {
#endif
            {
                  const size_t shard = get_shard(sock);
                  // check if the socket is still locked
                  const threading::lock lock(*shard_mutexes[shard]);
                  if(sockets_locked[shard][sock] != SOCKET_LOCKED)
                  {
                        return SOCKET_ERRORED;
                  }
            }
            send_len = static_cast<int>(size - upto);
            const int res = SDLNet_TCP_Send(sock, &buf[upto],send_len);


            if( res == send_len) {
                  if (!raw_data_only)
                  {
                        const threading::lock lock(*stats_mutex);
                        transfer_stats[sock].first.transfer(static_cast<size_t>(res));
                  }
                  return SOCKET_READY;
            }
#if defined(_WIN32)
            if(WSAGetLastError() == WSAEWOULDBLOCK)
#elif defined(EAGAIN) && !defined(__BEOS__)
            if(errno == EAGAIN)
#elif defined(EWOULDBLOCK)
            if(errno == EWOULDBLOCK)
#endif
            {
                  // update how far we are
                  upto += static_cast<size_t>(res);
                  if (!raw_data_only)
                  {
                        const threading::lock lock(*stats_mutex);
                        transfer_stats[sock].first.transfer(static_cast<size_t>(res));
                  }

#ifdef USE_POLL
                  struct pollfd fd = { ((_TCPsocket*)sock)->channel, POLLOUT, 0 };
                  int poll_res;
                  do {
                        poll_res = poll(&fd, 1, 60000);
                  } while(poll_res == -1 && errno == EINTR);


                  if(poll_res > 0)
                        continue;
#elif defined(USE_SELECT) && !defined(__BEOS__)
                  fd_set writefds;
                  FD_ZERO(&writefds);
                  FD_SET(((_TCPsocket*)sock)->channel, &writefds);
                  int retval;
                  struct timeval tv;
                  tv.tv_sec = 60;
                  tv.tv_usec = 0;

                  do {
                        retval = select(((_TCPsocket*)sock)->channel + 1, NULL, &writefds, NULL, &tv);
                  } while(retval == -1 && errno == EINTR);

                  if(retval > 0)
                        continue;
#elif defined(__BEOS__)
                  if(res > 0) {
                        // some data was sent, reset timeout
                        timeout = 60000;
                  } else {
                        // sleep for 100 milliseconds
                        SDL_Delay(100);
                        timeout -= 100;
                  }
                  continue;
#endif
            }

            return SOCKET_ERRORED;
      }
}

#ifdef HAVE_SENDFILE
struct cork_setter {
      cork_setter(int socket) : cork_(1), socket_(socket)
      {
            setsockopt(socket_, IPPROTO_TCP, TCP_CORK, &cork_, sizeof(cork_));;
      }
      ~cork_setter()
      {
            cork_ = 0;
            setsockopt(socket_, IPPROTO_TCP, TCP_CORK, &cork_, sizeof(cork_));
      }
      private:
      int cork_;
      int socket_;
};

struct close_fd {
          void operator()(int fd) const { close(fd); }
};
typedef util::scoped_resource<int, close_fd> scoped_fd;
#endif

static SOCKET_STATE send_file(buffer* buf)
{
      size_t upto = 0;
      size_t filesize = file_size(buf->config_error);
#ifdef HAVE_SENDFILE
      // implements linux sendfile support
      LOG_NW << "send_file use system sendfile: " << (network_use_system_sendfile?"yes":"no") << "\n";
      if (network_use_system_sendfile)
      {
            std::vector<char> buffer;
            buffer.resize(4);
            SDLNet_Write32(filesize,&buffer[0]);
            int socket = reinterpret_cast<_TCPsocket*>(buf->sock)->channel;
            const scoped_fd in_file(open(buf->config_error.c_str(), O_RDONLY));
            cork_setter set_socket_cork(socket);
            int poll_res;
            struct pollfd fd = {socket, POLLOUT, 0 };
            do {
                  poll_res = poll(&fd, 1, 600000);
            } while(poll_res == -1 && errno == EINTR);

            SOCKET_STATE result;
            if (poll_res > 0)
                  result = send_buffer(buf->sock, buffer, 4);
            else
                  result = SOCKET_ERRORED;


            if (result != SOCKET_READY)
            {
                  return result;
            }
            result = SOCKET_READY;

            while (true)
            {

                  do {
                        poll_res = poll(&fd, 1, 600000);
                  } while(poll_res == -1 && errno == EINTR);

                  if (poll_res <= 0 )
                  {
                        result = SOCKET_ERRORED;
                        break;
                  }


                  int bytes = ::sendfile(socket, in_file, 0, filesize);

                  if (bytes == -1)
                  {
                        if (errno == EAGAIN)
                              continue;
                        result = SOCKET_ERRORED;
                        break;
                  }
                  upto += bytes;


                  if (upto == filesize)
                  {
                        break;
                  }
            }

            return result;
      }
#endif
      // default sendfile implementation
      // if no system implementation is enabled
      int send_size = 0;
      // reserve 1024*8 bytes buffer
      buf->raw_buffer.resize(std::min<size_t>(1024*8, filesize));
      SDLNet_Write32(filesize,&buf->raw_buffer[0]);
      scoped_istream file_stream = istream_file(buf->config_error);
      SOCKET_STATE result = send_buffer(buf->sock, buf->raw_buffer, 4);

      if (!file_stream->good()) {
            ERR_NW << "send_file: Couldn't open file " << buf->config_error << "\n";
      }
      if (result != SOCKET_READY)
      {
            return result;
      }
      while (file_stream->good())
      {
            // read data
            file_stream->read(&buf->raw_buffer[0], buf->raw_buffer.size());
            send_size = file_stream->gcount();
            upto += send_size;
            // send data to socket
            result = send_buffer(buf->sock, buf->raw_buffer, send_size);
            if (result != SOCKET_READY)
            {
                  break;
            }
            if (upto == filesize)
            {
                  break;
            }

      }
      if (upto != filesize && !file_stream->good()) {
            ERR_NW << "send_file failed because the stream from file '"
                  << buf->config_error << "' is not good. Sent up to: " << upto
                  << " of file size: " << filesize << "\n";
      }
      return result;
}

static SOCKET_STATE receive_buf(TCPsocket sock, std::vector<char>& buf)
{
      char num_buf[4] ALIGN_4;
      bool res = receive_with_timeout(sock,num_buf,4,false);

      if(!res) {
            return SOCKET_ERRORED;
      }

      const int len = SDLNet_Read32(reinterpret_cast<void*>(num_buf));

      if(len < 1 || len > 100000000) {
            return SOCKET_ERRORED;
      }

      buf.resize(len);
      char* beg = &buf[0];
      const char* const end = beg + len;

      if (!raw_data_only)
      {
            const threading::lock lock(*stats_mutex);
            transfer_stats[sock].second.fresh_current(len);
      }

      res = receive_with_timeout(sock, beg, end - beg, true);
      if(!res) {
            return SOCKET_ERRORED;
      }

      return SOCKET_READY;
}

inline void check_socket_result(TCPsocket& sock, SOCKET_STATE& result)
{
      const size_t shard = get_shard(sock);
      const threading::lock lock(*shard_mutexes[shard]);
      socket_state_map::iterator lock_it = sockets_locked[shard].find(sock);
      assert(lock_it != sockets_locked[shard].end());
      lock_it->second = result;
      if(result == SOCKET_ERRORED) {
            ++socket_errors[shard];
      }
}

static int process_queue(void* shard_num)
{
      size_t shard = static_cast<size_t>(reinterpret_cast<uintptr_t>(shard_num));
      DBG_NW << "thread started...\n";
      for(;;) {

            //if we find a socket to send data to, sent_buf will be non-NULL. If we find a socket
            //to receive data from, sent_buf will be NULL. 'sock' will always refer to the socket
            //that data is being sent to/received from
            TCPsocket sock = NULL;
            buffer* sent_buf = 0;

            {
                  const threading::lock lock(*shard_mutexes[shard]);
                  while(managed && !to_clear[shard].empty()) {
                        Uint32 tmp = to_clear[shard].back();
                        to_clear[shard].pop_back();
                        threading::thread *zombie = threads[shard][tmp];
                        threads[shard].erase(tmp);
                        delete zombie;

                  }
                  if(min_threads && waiting_threads[shard] >= min_threads) {
                              DBG_NW << "worker thread exiting... not enough jobs\n";
                              to_clear[shard].push_back(threading::get_current_thread_id());
                              return 0;
                  }
                  waiting_threads[shard]++;
                  for(;;) {

                        buffer_set::iterator itor = outgoing_bufs[shard].begin(), itor_end = outgoing_bufs[shard].end();
                        for(; itor != itor_end; ++itor) {
                              socket_state_map::iterator lock_it = sockets_locked[shard].find((*itor)->sock);
                              assert(lock_it != sockets_locked[shard].end());
                              if(lock_it->second == SOCKET_READY) {
                                    lock_it->second = SOCKET_LOCKED;
                                    sent_buf = *itor;
                                    sock = sent_buf->sock;
                                    outgoing_bufs[shard].erase(itor);
                                    break;
                              }
                        }

                        if(sock == NULL) {
                              receive_list::iterator itor = pending_receives[shard].begin(), itor_end = pending_receives[shard].end();
                              for(; itor != itor_end; ++itor) {
                                    socket_state_map::iterator lock_it = sockets_locked[shard].find(*itor);
                                    assert(lock_it != sockets_locked[shard].end());
                                    if(lock_it->second == SOCKET_READY) {
                                          lock_it->second = SOCKET_LOCKED;
                                          sock = *itor;
                                          pending_receives[shard].erase(itor);
                                          break;
                                    }
                              }
                        }

                        if(sock != NULL) {
                              break;
                        }

                        if(managed == false) {
                              DBG_NW << "worker thread exiting...\n";
                              waiting_threads[shard]--;
                              to_clear[shard].push_back(threading::get_current_thread_id());
                              return 0;
                        }

                        cond[shard]->wait(*shard_mutexes[shard]); // temporarily release the mutex and wait for a buffer
                  }
                  waiting_threads[shard]--;
                  // if we are the last thread in the pool, create a new one
                  if(!waiting_threads[shard] && managed == true) {
                        // max_threads of 0 is unlimited
                        if(!max_threads || max_threads >threads[shard].size()) {
                              threading::thread * tmp = new threading::thread(process_queue,shard_num);
                              threads[shard][tmp->get_id()] =tmp;
                        }
                  }
            }

            assert(sock);

            DBG_NW << "thread found a buffer...\n";

            SOCKET_STATE result = SOCKET_READY;
            std::vector<char> buf;

            if(sent_buf) {

                  if(!sent_buf->config_error.empty())
                  {
                        // We have file to send over net
                        result = send_file(sent_buf);
                  } else {
                        if(sent_buf->raw_buffer.empty()) {
                              const std::string &value = sent_buf->stream.str();
                              make_network_buffer(value.c_str(), value.size(), sent_buf->raw_buffer);
                        }

                        result = send_buffer(sent_buf->sock, sent_buf->raw_buffer);
                  }
                  delete sent_buf;
            } else {
                  result = receive_buf(sock,buf);
            }


            if(result != SOCKET_READY || buf.empty())
            {
                  check_socket_result(sock,result);
                        continue;
            }
            //if we received data, add it to the queue
            buffer* received_data = new buffer(sock);

            if(raw_data_only) {
                  received_data->raw_buffer.swap(buf);
            } else {
                  std::string buffer(buf.begin(), buf.end());
                  std::istringstream stream(buffer);
                  // Binary wml starts with a char < 4, the first char of a gzip header is 31
                  // so test that here and use the proper reader.
                  try {
                        if(stream.peek() == 31) {
                              read_gz(received_data->config_buf, stream);
                              received_data->gzipped = true;
                        } else {
                              compression_schema *compress;
                              {
                                    const threading::lock lock_schemas(*schemas_mutex);
                                    compress = &schemas.insert(std::pair<TCPsocket,schema_pair>(sock,schema_pair())).first->second.incoming;
                              }
                              read_compressed(received_data->config_buf, stream, *compress);
                              received_data->gzipped = false;
                        }
                  } catch(config::error &e)
                  {
                        received_data->config_error = e.message;
                  }
            }

            {
                  // Now add data
                  const threading::lock lock_received(*received_mutex);
                  received_data_queue.push_back(received_data);
            }
            check_socket_result(sock,result);
      }
      // unreachable
}

} //anonymous namespace

namespace network_worker_pool
{

manager::manager(size_t p_min_threads,size_t p_max_threads) : active_(!managed)
{
      if(active_) {
            managed = true;
            for(int i = 0; i != NUM_SHARDS; ++i) {
                  shard_mutexes[i] = new threading::mutex();
                  cond[i] = new threading::condition();
            }
            stats_mutex = new threading::mutex();
            schemas_mutex = new threading::mutex();
            received_mutex = new threading::mutex();

            min_threads = p_min_threads;
            max_threads = p_max_threads;

            for(size_t shard = 0; shard != NUM_SHARDS; ++shard) {
                  const threading::lock lock(*shard_mutexes[shard]);
                  for(size_t n = 0; n != p_min_threads; ++n) {
                        threading::thread * tmp = new threading::thread(process_queue,(void*)uintptr_t(shard));
                        threads[shard][tmp->get_id()] = tmp;
                  }
            }
      }
}

manager::~manager()
{
      if(active_) {
            managed = false;

            for(size_t shard = 0; shard != NUM_SHARDS; ++shard) {
                  {
                        const threading::lock lock(*shard_mutexes[shard]);
                        socket_errors[shard] = 0;
                  }

                  cond[shard]->notify_all();

                  for(std::map<Uint32,threading::thread*>::const_iterator i = threads[shard].begin(); i != threads[shard].end(); ++i) {

                        DBG_NW << "waiting for thread " << i->first << " to exit...\n";
                        delete i->second;
                        DBG_NW << "thread exited...\n";
                  }

                  // Condition variables must be deleted first as
                  // they make reference to mutexs. If the mutexs
                  // are destroyed first, the condition variables
                  // will access memory already freed by way of
                  // stale mutex. Bad things will follow. ;)
                  threads[shard].clear();
                  // Have to clean up to_clear so no bogus clearing of threads
                  to_clear[shard].clear();
                  delete cond[shard];
                  cond[shard] = NULL;
                  delete shard_mutexes[shard];
                  shard_mutexes[shard] = NULL;
            }

            delete stats_mutex;
            delete schemas_mutex;
            delete received_mutex;
            stats_mutex = 0;
            schemas_mutex = 0;
            received_mutex = 0;

            for(int i = 0; i != NUM_SHARDS; ++i) {
                  sockets_locked[i].clear();
            }
            transfer_stats.clear();

            DBG_NW << "exiting manager::~manager()\n";
      }
}

network::pending_statistics get_pending_stats()
{
      network::pending_statistics stats;
      stats.npending_sends = 0;
      stats.nbytes_pending_sends = 0;
      for(size_t shard = 0; shard != NUM_SHARDS; ++shard) {
            const threading::lock lock(*shard_mutexes[shard]);
            stats.npending_sends += outgoing_bufs[shard].size();
            for(buffer_set::const_iterator i = outgoing_bufs[shard].begin(); i != outgoing_bufs[shard].end(); ++i) {
                  stats.nbytes_pending_sends += (*i)->raw_buffer.size();
            }
      }

      return stats;
}

void set_raw_data_only()
{
      raw_data_only = true;
}

void set_use_system_sendfile(bool use)
{
      network_use_system_sendfile = use;
}

void receive_data(TCPsocket sock)
{
      {
            const size_t shard = get_shard(sock);
            const threading::lock lock(*shard_mutexes[shard]);
            pending_receives[shard].push_back(sock);

            socket_state_map::const_iterator i = sockets_locked[shard].insert(std::pair<TCPsocket,SOCKET_STATE>(sock,SOCKET_READY)).first;
            if(i->second == SOCKET_READY || i->second == SOCKET_ERRORED) {
                  cond[shard]->notify_one();
            }
      }
}

TCPsocket get_received_data(TCPsocket sock, config& cfg, bool* gzipped,network::bandwidth_in_ptr& bandwidth_in)
{
      assert(!raw_data_only);
      const threading::lock lock_received(*received_mutex);
      received_queue::iterator itor = received_data_queue.begin();
      if(sock != NULL) {
            for(; itor != received_data_queue.end(); ++itor) {
                  if((*itor)->sock == sock) {
                        break;
                  }
            }
      }

      if(itor == received_data_queue.end()) {
            return NULL;
      } else if (!(*itor)->config_error.empty()){
            // throw the error in parent thread
            std::string error = (*itor)->config_error;
            buffer* buf = *itor;
            received_data_queue.erase(itor);
            if (gzipped)
                  *gzipped = buf->gzipped;
            delete buf;
            throw config::error(error);
      } else {
            cfg.swap((*itor)->config_buf);
            const TCPsocket res = (*itor)->sock;
            buffer* buf = *itor;
            if (gzipped)
                  *gzipped = buf->gzipped;
            bandwidth_in.reset(new network::bandwidth_in((*itor)->raw_buffer.size()));
            received_data_queue.erase(itor);
            delete buf;
            return res;
      }
}

TCPsocket get_received_data(std::vector<char>& out)
{
      assert(raw_data_only);
      const threading::lock lock_received(*received_mutex);
      if(received_data_queue.empty()) {
            return NULL;
      }

      buffer* buf = received_data_queue.front();
      received_data_queue.pop_front();
      out.swap(buf->raw_buffer);
      const TCPsocket res = buf->sock;
      delete buf;
      return res;
}

static void queue_buffer(TCPsocket sock, buffer* queued_buf)
{
      const size_t shard = get_shard(sock);
      const threading::lock lock(*shard_mutexes[shard]);
      outgoing_bufs[shard].push_back(queued_buf);
      socket_state_map::const_iterator i = sockets_locked[shard].insert(std::pair<TCPsocket,SOCKET_STATE>(sock,SOCKET_READY)).first;
      if(i->second == SOCKET_READY || i->second == SOCKET_ERRORED) {
            cond[shard]->notify_one();
      }

}

void queue_raw_data(TCPsocket sock, const char* buf, int len)
{
      buffer* queued_buf = new buffer(sock);
      assert(*buf == 31);
      make_network_buffer(buf, len, queued_buf->raw_buffer);
      queue_buffer(sock, queued_buf);
}


void queue_file(TCPsocket sock, const std::string& filename)
{
      buffer* queued_buf = new buffer(sock);
      queued_buf->config_error = filename;
      queue_buffer(sock, queued_buf);
}

size_t queue_data(TCPsocket sock,const config& buf, const bool gzipped, const std::string& packet_type)
{
      DBG_NW << "queuing data...\n";

      buffer* queued_buf = new buffer(sock);
      output_to_buffer(sock, buf, queued_buf->stream, gzipped);
      queued_buf->gzipped = gzipped;
      const size_t size = queued_buf->stream.str().size();

      network::add_bandwidth_out(packet_type, size);
      queue_buffer(sock, queued_buf);
      return size;
}

namespace
{

/** Caller has to make sure to own the mutex for this shard */
void remove_buffers(TCPsocket sock)
{
      {
            const size_t shard = get_shard(sock);
            for(buffer_set::iterator i = outgoing_bufs[shard].begin(); i != outgoing_bufs[shard].end();) {
                  if ((*i)->sock == sock)
                  {
                        buffer* buf = *i;
                        i = outgoing_bufs[shard].erase(i);
                        delete buf;
                  }
                  else
                  {
                        ++i;
                  }
            }
      }

      {
            const threading::lock lock_receive(*received_mutex);

            for(received_queue::iterator j = received_data_queue.begin(); j != received_data_queue.end(); ) {
                  if((*j)->sock == sock) {
                        buffer *buf = *j;
                        j = received_data_queue.erase(j);
                        delete buf;
                  } else {
                        ++j;
                  }
            }
      }
}

} // anonymous namespace

bool is_locked(const TCPsocket sock) {
      const size_t shard = get_shard(sock);
      const threading::lock lock(*shard_mutexes[shard]);
      const socket_state_map::iterator lock_it = sockets_locked[shard].find(sock);
      if (lock_it == sockets_locked[shard].end()) return false;
      return (lock_it->second == SOCKET_LOCKED);
}

bool close_socket(TCPsocket sock)
{
      {
            const size_t shard = get_shard(sock);
            const threading::lock lock(*shard_mutexes[shard]);

            pending_receives[shard].erase(std::remove(pending_receives[shard].begin(),pending_receives[shard].end(),sock),pending_receives[shard].end());
            {
                  const threading::lock lock_schemas(*schemas_mutex);
                  schemas.erase(sock);
            }

            const socket_state_map::iterator lock_it = sockets_locked[shard].find(sock);
            if(lock_it == sockets_locked[shard].end()) {
                  remove_buffers(sock);
                  return true;
            }
            if (!(lock_it->second == SOCKET_LOCKED || lock_it->second == SOCKET_INTERRUPT)) {
                  sockets_locked[shard].erase(lock_it);
                  remove_buffers(sock);
                  return true;
            } else {
                  lock_it->second = SOCKET_INTERRUPT;
                  return false;
            }

      }


}

TCPsocket detect_error()
{
      for(size_t shard = 0; shard != NUM_SHARDS; ++shard) {
            const threading::lock lock(*shard_mutexes[shard]);
            if(socket_errors[shard] > 0) {
                  for(socket_state_map::iterator i = sockets_locked[shard].begin(); i != sockets_locked[shard].end();) {
                        if(i->second == SOCKET_ERRORED) {
                              --socket_errors[shard];
                              const TCPsocket sock = i->first;
                              sockets_locked[shard].erase(i++);
                              pending_receives[shard].erase(std::remove(pending_receives[shard].begin(),pending_receives[shard].end(),sock),pending_receives[shard].end());
                              remove_buffers(sock);
                              const threading::lock lock_schema(*schemas_mutex);
                              schemas.erase(sock);
                              return sock;
                        }
                        else
                        {
                              ++i;
                        }
                  }
            }

            socket_errors[shard] = 0;
      }

      return 0;
}

std::pair<network::statistics,network::statistics> get_current_transfer_stats(TCPsocket sock)
{
      const threading::lock lock(*stats_mutex);
      return transfer_stats[sock];
}

} // network_worker_pool namespace



Generated by  Doxygen 1.6.0   Back to index