DEADSOFTWARE

master: increase enet_host_service() timeout
[d2df-sdl.git] / src / mastersrv / master.c
index 0b02b2d2854bdf397125affff7506ebbb764723a..6573cc62337cd76c634b40626bf940ee069b2775 100644 (file)
@@ -8,6 +8,7 @@
 #include <time.h>
 #include <signal.h>
 
+#define ENET_DEBUG 1
 #include <enet/enet.h>
 #include <enet/types.h>
 
@@ -21,7 +22,9 @@
 #define DEFAULT_SPAM_CAP 10
 #define DEFAULT_MAX_SERVERS MS_MAX_SERVERS
 #define DEFAULT_MAX_PER_HOST 4
-#define DEFAULT_TIMEOUT 100
+#define DEFAULT_SERVER_TIMEOUT 100
+#define DEFAULT_CLIENT_TIMEOUT 3
+#define DEFAULT_SPAM_TIMEOUT 1
 #define DEFAULT_PORT 25665
 
 #define NET_BUFSIZE 65536
@@ -89,6 +92,7 @@ typedef struct server_s {
   char        map[MAX_STRLEN + 2];
   time_t      death_time;
   time_t      timestamp;
+  ENetPeer   *peer; // who sent this server in
 } server_t;
 
 // real servers
@@ -122,7 +126,9 @@ static ban_record_t *banlist;
 
 // settings
 static int ms_port = DEFAULT_PORT;
-static int ms_timeout = DEFAULT_TIMEOUT;
+static int ms_sv_timeout = DEFAULT_SERVER_TIMEOUT;
+static int ms_cl_timeout = DEFAULT_CLIENT_TIMEOUT;
+static int ms_spam_timeout = DEFAULT_SPAM_TIMEOUT;
 static int ms_spam_cap = DEFAULT_SPAM_CAP;
 static char ms_motd[MAX_STRLEN + 1] = "";
 static char ms_urgent[MAX_STRLEN + 1] = "";
@@ -338,24 +344,32 @@ void b_write_server(enet_buf_t *buf, const server_t *s) {
 
 /* server functions */
 
-static void sv_remove(const enet_uint32 host, const enet_uint16 port) {
-  for (int i = 0; i < max_servers; ++i) {
-    if (servers[i].host == host && servers[i].port == port) {
-      servers[i].host = 0;
-      servers[i].port = 0;
-      --num_servers;
+static inline void sv_remove(server_t *sv) {
+  if (sv->host) {
+    // drop the associated peer, if any
+    if (sv->peer && sv->peer->state == ENET_PEER_STATE_CONNECTED && sv->peer->data == sv) {
+      sv->peer->data = NULL;
+      sv->peer = NULL;
+      enet_peer_reset(sv->peer);
     }
+    sv->host = 0;
+    sv->port = 0;
+    --num_servers;
+  }
+}
+
+static void sv_remove_by_addr(const enet_uint32 host, const enet_uint16 port) {
+  for (int i = 0; i < max_servers; ++i) {
+    if (servers[i].host == host && servers[i].port == port)
+      sv_remove(servers + i);
   }
 }
 
 static void sv_remove_by_host(enet_uint32 host, enet_uint32 mask) {
   host &= mask;
   for (int i = 0; i < max_servers; ++i) {
-    if (servers[i].host && (servers[i].host & mask) == host) {
-      servers[i].host = 0;
-      servers[i].port = 0;
-      --num_servers;
-    }
+    if (servers[i].host && (servers[i].host & mask) == host)
+      sv_remove(servers + i);
   }
 }
 
@@ -580,6 +594,7 @@ static void ban_add(const enet_uint32 host, const char *reason) {
 static inline void ban_peer(ENetPeer *peer, const char *reason) {
   if (peer) {
     ban_add(peer->address.host, reason);
+    peer->data = NULL;
     enet_peer_reset(peer);
   }
 }
@@ -648,8 +663,15 @@ static bool handle_msg(const enet_uint8 msgid, ENetPeer *peer) {
           return true;
         }
         // only then update the times
-        sv->death_time = now + ms_timeout;
+        sv->death_time = now + ms_sv_timeout;
         sv->timestamp = now;
+        // check if we're updating from a new peer
+        if (sv->peer != peer) {
+          // if there was an old one, kill it
+          if (sv->peer)
+            enet_peer_reset(peer);
+          sv->peer = peer;
+        }
         u_log(LOG_NOTE, "updated server #%d:", sv - servers);
         u_printsv(sv);
       } else {
@@ -675,7 +697,7 @@ static bool handle_msg(const enet_uint8 msgid, ENetPeer *peer) {
         // then add that shit
         *sv = tmpsv;
         sv->host = peer->address.host;
-        sv->death_time = now + ms_timeout;
+        sv->death_time = now + ms_sv_timeout;
         sv->timestamp = now;
         if (!ban_sanity_check(sv)) {
           sv->host = 0;
@@ -683,6 +705,8 @@ static bool handle_msg(const enet_uint8 msgid, ENetPeer *peer) {
           ban_peer(peer, "tripped sanity check");
           return true;
         }
+        sv->peer = peer;
+        peer->data = sv;
         ++num_servers;
         u_log(LOG_NOTE, "added new server #%d:", sv - servers);
         u_printsv(sv);
@@ -695,7 +719,10 @@ static bool handle_msg(const enet_uint8 msgid, ENetPeer *peer) {
         ban_peer(peer, "malformed MSG_RM");
         return true;
       }
-      sv_remove(peer->address.host, tmpsv.port);
+      sv_remove_by_addr(peer->address.host, tmpsv.port);
+      // this peer can be disconnected pretty much immediately since he has no servers left, tell him to fuck off
+      peer->data = NULL;
+      enet_peer_disconnect_later(peer, 0);
       return true;
 
     case NET_MSG_LIST:
@@ -741,7 +768,10 @@ static bool handle_msg(const enet_uint8 msgid, ENetPeer *peer) {
 
       ENetPacket *p = enet_packet_create(buf_send.data, buf_send.pos, ENET_PACKET_FLAG_RELIABLE);
       enet_peer_send(peer, NET_CH_MAIN, p);
-      enet_host_flush(ms_host);
+      // enet_host_flush(ms_host);
+
+      // this peer can be disconnected pretty much immediately after receiving the server list, tell him to fuck off
+      enet_peer_disconnect_later(peer, 0);
 
       u_log(LOG_NOTE, "sent server list to %s:%d (ver %s)", u_iptostr(peer->address.host), peer->address.port, clientver[0] ? clientver : "<old>");
       return true;
@@ -758,10 +788,12 @@ static void print_usage(void) {
   printf("Available options:\n");
   printf("-h     show this message and exit\n");
   printf("-p N   listen on port N (default: %d)\n", DEFAULT_PORT);
-  printf("-t N   seconds before server is removed from list (default: %d)\n", DEFAULT_TIMEOUT);
+  printf("-t N   seconds before server is removed from list (default: %d)\n", DEFAULT_SERVER_TIMEOUT);
+  printf("-c N   how long a client is allowed to hold the connection active (default: %d)\n", DEFAULT_CLIENT_TIMEOUT);
   printf("-s N   max number of servers in server list, 1-%d (default: %d)\n", MS_MAX_SERVERS, DEFAULT_MAX_SERVERS);
   printf("-d N   if N > 0, disallow more than N servers on the same IP (default: %d)\n", DEFAULT_MAX_PER_HOST);
-  printf("-f N   crappy spam filter: ban people after they send N requests in a row too fast (default: %d)\n", DEFAULT_SPAM_CAP);
+  printf("-f N   crappy spam filter: ban clients after they send N requests in a row too fast (default: %d)\n", DEFAULT_SPAM_CAP);
+  printf("-w N   how often does a client have to send packets for the filter to kick in, i.e. once every N sec (default: %d)\n", DEFAULT_SPAM_TIMEOUT);
   fflush(stdout);
 }
 
@@ -796,10 +828,12 @@ static bool parse_args(int argc, char **argv) {
   for (int i = 1; i < argc; ++i) {
     const bool success =
          parse_int_arg(argc, argv, i, "-p", 1, 0xFFFF, &ms_port)
-      || parse_int_arg(argc, argv, i, "-t", 1, 0x7FFFFFFF, &ms_timeout)
+      || parse_int_arg(argc, argv, i, "-t", 1, 0x7FFFFFFF, &ms_sv_timeout)
+      || parse_int_arg(argc, argv, i, "-c", 1, 0x7FFFFFFF, &ms_cl_timeout)
       || parse_int_arg(argc, argv, i, "-s", 1, MS_MAX_SERVERS, &max_servers)
       || parse_int_arg(argc, argv, i, "-d", 0, MS_MAX_SERVERS, &max_servers_per_host)
-      || parse_int_arg(argc, argv, i, "-f", 0, 0xFFFF, &ms_spam_cap);
+      || parse_int_arg(argc, argv, i, "-f", 0, 0xFFFF, &ms_spam_cap)
+      || parse_int_arg(argc, argv, i, "-w", 1, 0x7FFFFFFF, &ms_spam_timeout);
     if (success) {
       ++i;
     } else {
@@ -812,11 +846,10 @@ static bool parse_args(int argc, char **argv) {
 }
 
 // a stupid thing to filter sustained spam from a single IP
-static bool spam_filter(ENetPeer *peer) {
-  const time_t now = time(NULL);
+static bool spam_filter(ENetPeer *peer, const time_t now) {
   if (peer->address.host == cl_last_addr) {
     // spam === sending shit faster than once a second
-    if (now - cl_last_time < 1) {
+    if (now - cl_last_time < ms_spam_timeout) {
       if (cl_spam_cnt > 1)
         u_log(LOG_WARN, "address %s is sending packets too fast", u_iptostr(peer->address.host));
       if (++cl_spam_cnt >= ms_spam_cap) {
@@ -835,6 +868,11 @@ static bool spam_filter(ENetPeer *peer) {
   return false;
 }
 
+// filter incoming UDP packets before the protocol kicks in
+static int packet_filter(ENetHost *host, ENetEvent *event) {
+  return !!ban_check(host->receivedAddress.host);
+}
+
 int main(int argc, char **argv) {
   if (enet_initialize() != 0)
     u_fatal("could not init enet");
@@ -865,22 +903,28 @@ int main(int argc, char **argv) {
   ENetAddress addr;
   addr.host = 0;
   addr.port = ms_port;
-  ms_host = enet_host_create(&addr, MS_MAX_CLIENTS, NET_CH_COUNT, 0, 0);
+  ms_host = enet_host_create(&addr, MS_MAX_CLIENTS, NET_CH_COUNT + 1, 0, 0);
   if (!ms_host)
     u_fatal("could not create enet host on port %d", ms_port);
 
+  ms_host->intercept = packet_filter;
+
   bool running = true;
   enet_uint8 msgid = 0;
   ENetEvent event;
   while (running) {
-    while (enet_host_service(ms_host, &event, 1000) > 0) {
-      bool filtered = !event.peer || ban_check(event.peer->address.host);
-      if (!filtered && ms_spam_cap) filtered = spam_filter(event.peer);
+    while (enet_host_service(ms_host, &event, 500) > 0) {
+      const time_t now = time(NULL);
+      const bool filtered = !event.peer || (ms_spam_cap && spam_filter(event.peer, now));
 
       if (!filtered) {
         switch (event.type) {
           case ENET_EVENT_TYPE_CONNECT:
             u_log(LOG_NOTE, "%s:%d connected", u_iptostr(event.peer->address.host), event.peer->address.port);
+            if (event.peer->channelCount != NET_CH_COUNT)
+              ban_peer(event.peer, "what is this");
+            else
+              enet_peer_timeout(event.peer, 0, 0, ms_cl_timeout * 1000);
             break;
 
           case ENET_EVENT_TYPE_RECEIVE:
@@ -901,10 +945,17 @@ int main(int argc, char **argv) {
             }
             break;
 
+          case ENET_EVENT_TYPE_DISCONNECT:
+            event.peer->data = NULL;
+            // u_log(LOG_NOTE, "%s:%d disconnected", u_iptostr(event.peer->address.host), event.peer->address.port);
+            break;
+
           default:
             break;
         }
       } else if (event.peer) {
+        // u_log(LOG_WARN, "filtered event %d from %s", event.type, u_iptostr(event.peer->address.host));
+        event.peer->data = NULL;
         enet_peer_reset(event.peer);
       }
 
@@ -915,14 +966,12 @@ int main(int argc, char **argv) {
     }
 
     const time_t now = time(NULL);
+
+    // time out servers
     for (int i = 0; i < max_servers; ++i) {
-      if (servers[i].host) {
-        if (servers[i].death_time <= now) {
-          u_log(LOG_NOTE, "server #%d %s:%d timed out", i, u_iptostr(servers[i].host), servers[i].port);
-          servers[i].host = 0;
-          servers[i].port = 0;
-          --num_servers;
-        }
+      if (servers[i].host && servers[i].death_time <= now) {
+        u_log(LOG_NOTE, "server #%d %s:%d timed out", i, u_iptostr(servers[i].host), servers[i].port);
+        sv_remove(servers + i);
       }
     }
   }