DEADSOFTWARE

master: anti-overflow patch
[d2df-sdl.git] / src / mastersrv / master.c
index b6d2cd7d8e64b5449ec7df340e8bb868cdc161c6..ebeb2c5c9a5d7640b8ff04e3fdb9bd4a6a85c6d9 100644 (file)
@@ -1,12 +1,15 @@
 #include <stdlib.h>
 #include <stdio.h>
+#include <ctype.h>
 #include <string.h>
+#include <time.h>
+
 #include <enet/enet.h>
 #include <enet/types.h>
-#include <time.h>
 
 #define MS_VERSION "0.2"
 #define MS_MAXSRVS 128
+#define MS_MAXHOST 5
 #define MS_MAXBANS 256
 #define MS_TIMEOUT 100
 #define MS_BANTIME (3 * 86400)
 #define LC_MS_NOBANS "\nCould not load ban list from file\n"
 #define LC_MS_BADADR "\nBad address in file: %s\n"
 #define LC_MS_BANHEUR "tripped heuristic check"
+#define LC_MS_BANTOOMUCH "created too many servers"
+#define LC_MS_BANSPAM "suspicious multiple server activity"
 #define LC_MS_BANLIST "address in ban list"
+#define LC_MS_BANTRASH "garbage server data"
+#define LC_MS_BANINVAL "invalid message ID"
 #define LC_MS_OOM "\nOut of memory\n"
 
 #define MS_URGENT_FILE "urgent.txt"
@@ -57,9 +64,9 @@ typedef struct ms_ban_record_s {
 
 struct ms_server_s {
   enet_uint8 used;
-  char s_ip[17];
-  char s_name[256];
-  char s_map[256];
+  char s_ip[18];
+  char s_name[257];
+  char s_map[257];
   enet_uint8  s_pw;
   enet_uint8  s_plrs;
   enet_uint8  s_maxplrs;
@@ -72,17 +79,24 @@ struct ms_server_s {
 
 typedef struct ms_server_s ms_server;
 
+typedef struct enet_buf_s {
+  enet_uint8 *data;
+  size_t size;
+  size_t pos;
+  int overflow;
+} enet_buf;
+
 const char ms_game_ver[] = "0.63";
 char ms_motd[255] = "";
 char ms_urgent[255] = "";
 
 int ms_port = 25660;
 int ms_timeout = 100;
+int ms_checkmultiple = 0;
 
-size_t b_read = 0;
-size_t b_write = 0;
-
-enet_uint8 b_send[NET_BUFSIZE];
+enet_uint8 b_send_data[NET_BUFSIZE];
+enet_buf b_send = { .data = b_send_data, .size = sizeof(b_send_data) };
+enet_buf b_recv;
 
 ENetHost  *ms_host = NULL;
 ENetPeer  *ms_peers[NET_MAXCLIENTS];
@@ -160,7 +174,10 @@ void d_getargs (int argc, char *argv[]) {
       }
     } else if (!strcmp(argv[i], "-t") & (i + 1 < argc)) {
         ms_timeout = atoi(argv[++i]);
+    } else if (!strcmp(argv[i], "--check-multihost")) {
+        ms_checkmultiple = 1;
     }
+
   }
 }
 
@@ -202,66 +219,81 @@ const char *d_strtime(const time_t t) {
 }
 
 
-enet_uint8 b_read_uint8 (enet_uint8 buf[], size_t *pos) {
-  return buf[(*pos)++];
+static inline int b_enough_left(enet_buf *buf, size_t size) {
+  if (buf->pos + size > buf->size) {
+    buf->overflow = 1;
+    return 0;
+  }
+  return 1;
+}
+
+
+enet_uint8 b_read_uint8 (enet_buf *buf) {
+  if (b_enough_left(buf, 1))
+    return buf->data[buf->pos++];
+  return 0;
 }
 
 
-enet_uint16 b_read_uint16 (enet_uint8 buf[], size_t *pos) {
+enet_uint16 b_read_uint16 (enet_buf *buf) {
   enet_uint16 ret = 0;
 
-  ret = *(enet_uint16*)(buf + *pos);
-  *pos += sizeof(enet_uint16);
+  if (b_enough_left(buf, sizeof(ret))) {
+    ret = *(enet_uint16*)(buf->data + buf->pos);
+    buf->pos += sizeof(ret);
+  }
 
   return ret;
 }
 
 
-char* b_read_dstring (enet_uint8 buf[], size_t *pos) {
+char* b_read_dstring (enet_buf *buf) {
   char *ret = NULL;
 
-  size_t len = b_read_uint8(buf, pos);
-
-  ret = malloc(len + 1);
-
-  memmove(ret, (char*)(buf + *pos), len);
-  ret[len] = '\0';
-  *pos += len;
+  if (b_enough_left(buf, 1)) {
+    size_t len = b_read_uint8(buf);
+    if (b_enough_left(buf, len)) {
+      ret = malloc(len + 1);
+      memmove(ret, (char*)(buf->data + buf->pos), len);
+      buf->pos += len;
+      ret[len] = '\0';
+    }
+  }
 
   return ret;
 }
 
 
-void b_write_uint8 (enet_uint8 buf[], size_t *pos, enet_uint8 val) {
-  buf[(*pos)++] = val;
+void b_write_uint8 (enet_buf *buf, enet_uint8 val) {
+  buf->data[buf->pos++] = val;
 }
 
 
-void b_write_uint16 (enet_uint8 buf[], size_t *pos, enet_uint16 val) {
-  *(enet_uint16*)(buf + *pos) = val;
-  *pos += sizeof(enet_uint16);
+void b_write_uint16 (enet_buf *buf, enet_uint16 val) {
+  *(enet_uint16*)(buf->data + buf->pos) = val;
+  buf->pos += sizeof(val);
 }
 
 
-void b_write_dstring (enet_uint8 buf[], size_t *pos, const char* val) {
+void b_write_dstring (enet_buf *buf, const char* val) {
   enet_uint8 len = strlen(val);
-  b_write_uint8(buf, pos, len);
+  b_write_uint8(buf, len);
 
-  memmove((char*)(buf + *pos), val, len);
-  *pos += len;
+  memmove((char*)(buf->data + buf->pos), val, len);
+  buf->pos += len;
 }
 
 
-void b_write_server (enet_uint8 buf[], size_t *pos, ms_server s) {
-  b_write_dstring(b_send, pos, s.s_ip);
-  b_write_uint16 (b_send, pos, s.s_port);
-  b_write_dstring(b_send, pos, s.s_name);
-  b_write_dstring(b_send, pos, s.s_map);
-  b_write_uint8  (b_send, pos, s.s_mode);
-  b_write_uint8  (b_send, pos, s.s_plrs);
-  b_write_uint8  (b_send, pos, s.s_maxplrs);
-  b_write_uint8  (b_send, pos, s.s_protocol);
-  b_write_uint8  (b_send, pos, s.s_pw);
+void b_write_server (enet_buf *b_send, ms_server s) {
+  b_write_dstring(b_send, s.s_ip);
+  b_write_uint16 (b_send, s.s_port);
+  b_write_dstring(b_send, s.s_name);
+  b_write_dstring(b_send, s.s_map);
+  b_write_uint8  (b_send, s.s_mode);
+  b_write_uint8  (b_send, s.s_plrs);
+  b_write_uint8  (b_send, s.s_maxplrs);
+  b_write_uint8  (b_send, s.s_protocol);
+  b_write_uint8  (b_send, s.s_pw);
 }
 
 
@@ -384,7 +416,7 @@ int ban_heur (const ms_server *srv, const time_t now) {
   int score = 0;
 
   // can't have more than 24 maxplayers; can't have more than max
-  if (srv->s_plrs > srv->s_maxplrs || srv->s_maxplrs > 24)
+  if (srv->s_plrs > srv->s_maxplrs || srv->s_maxplrs > 24 || srv->s_maxplrs == 0)
     score += MS_MAXHEUR;
 
   // name and map have to be non-garbage
@@ -418,6 +450,27 @@ int ban_heur (const ms_server *srv, const time_t now) {
   return score;
 }
 
+void erase_banned_host(const char *ip) {
+  for (int i = 0; i < MS_MAXSRVS; ++i) {
+    if (!strcmp(ms_srv[i].s_ip, ip)) {
+      if (ms_srv[i].used) {
+        ms_srv[i].used = 0;
+        ms_count--;
+      }
+    }
+  }
+}
+
+time_t get_sum_lasttime(char* ip) {
+  time_t sumLastTime = 0;
+  const time_t now = time(NULL);
+  for (int i = 0; i < MS_MAXSRVS; ++i) {
+    if (ms_srv[i].used && (strncmp(ip, ms_srv[i].s_ip, 16) == 0)) {
+      sumLastTime = sumLastTime + (now - ms_srv[i].lasttime);
+    }
+  }
+  return sumLastTime;
+}
 
 void ban_add (const ENetAddress *addr, const char *reason) {
   const time_t now = time(NULL);
@@ -431,6 +484,7 @@ void ban_add (const ENetAddress *addr, const char *reason) {
   printf(LC_MS_BANNED, rec->ip, d_strtime(rec->cur_ban), reason, rec->ban_count);
 
   ban_save_list(MS_BAN_FILE);
+  erase_banned_host(rec->ip);
 }
 
 
@@ -438,6 +492,22 @@ void d_deinit(void) {
   ban_save_list(MS_BAN_FILE);
 }
 
+int count_servers(char* ip) {
+  int sameHostServers = 0;
+  for (int i = 0; i < MS_MAXSRVS; ++i) {
+    if ((strncmp(ip, ms_srv[i].s_ip, 16) == 0)) {
+      ++sameHostServers;
+    }
+  }
+  return sameHostServers;
+}
+
+
+#define CHECK_RECV_OVERFLOW(addr) \
+  if (b_recv.overflow) { \
+    ban_add(addr, LC_MS_BANTRASH); \
+    break; \
+  }
 
 int main (int argc, char *argv[]) {
   d_getargs(argc, argv);
@@ -509,39 +579,44 @@ int main (int argc, char *argv[]) {
         case ENET_EVENT_TYPE_RECEIVE:
           if (!event.peer) continue;
 
-          b_read = 0;
-          msg = b_read_uint8(event.packet->data, &b_read);
+          b_recv.pos = 0;
+          b_recv.overflow = 0;
+          b_recv.data = event.packet->data;
+          b_recv.size = event.packet->dataLength;
+          msg = b_read_uint8(&b_recv);
 
           switch (msg) {
             case NET_MSG_ADD:
               enet_address_get_host_ip(&(event.peer->address), ip, 17);
-              port = b_read_uint16(event.packet->data, &b_read);
+              port = b_read_uint16(&b_recv);
+
+              name = b_read_dstring(&b_recv);
+              map = b_read_dstring(&b_recv);
+              gm  = b_read_uint8(&b_recv);
 
-              name = b_read_dstring(event.packet->data, &b_read);
-              map = b_read_dstring(event.packet->data, &b_read);
-              gm  = b_read_uint8(event.packet->data, &b_read);
+              pl = b_read_uint8(&b_recv);
+              mpl = b_read_uint8(&b_recv);
 
-              pl = b_read_uint8(event.packet->data, &b_read);
-              mpl = b_read_uint8(event.packet->data, &b_read);
+              proto = b_read_uint8(&b_recv);
+              pw = b_read_uint8(&b_recv);
 
-              proto = b_read_uint8(event.packet->data, &b_read);
-              pw = b_read_uint8(event.packet->data, &b_read);
+              CHECK_RECV_OVERFLOW(&(event.peer->address));
 
               for (int i = 0; i < MS_MAXSRVS; ++i) {
                 if (ms_srv[i].used) {
                   if ((strncmp(ip, ms_srv[i].s_ip, 16) == 0) && (ms_srv[i].s_port == port)) {
-                    if (ban_heur(ms_srv + i, now) >= MS_MAXHEUR) {
-                      ban_add(&(event.peer->address), LC_MS_BANHEUR);
-                      break;
-                    }
-
-                    strncpy(ms_srv[i].s_map, map, sizeof(ms_srv[i].s_map));
-                    strncpy(ms_srv[i].s_name, name, sizeof(ms_srv[i].s_name));
+                    strncpy(ms_srv[i].s_map, map, sizeof(ms_srv[i].s_map) - 1);
+                    strncpy(ms_srv[i].s_name, name, sizeof(ms_srv[i].s_name) - 1);
                     ms_srv[i].s_plrs = pl;
                     ms_srv[i].s_maxplrs = mpl;
                     ms_srv[i].s_pw = pw;
                     ms_srv[i].s_mode = gm;
 
+                    if (ban_heur(ms_srv + i, now) >= MS_MAXHEUR) {
+                      ban_add(&(event.peer->address), LC_MS_BANHEUR);
+                      break;
+                    }
+
                     ms_srv[i].deathtime = now + ms_timeout;
                     ms_srv[i].lasttime = now;
 
@@ -549,9 +624,20 @@ int main (int argc, char *argv[]) {
                     break;
                   }
                 } else {
-                    strncpy(ms_srv[i].s_ip, ip, sizeof(ms_srv[i].s_ip));
-                    strncpy(ms_srv[i].s_map, map, sizeof(ms_srv[i].s_map));
-                    strncpy(ms_srv[i].s_name, name, sizeof(ms_srv[i].s_name));
+                    int countServer = count_servers(ip);
+                    if (countServer > MS_MAXHOST) {
+                      ban_add(&(event.peer->address), LC_MS_BANTOOMUCH);
+                      break;
+                    }
+                    else if (ms_checkmultiple && countServer > 1) {
+                      if (get_sum_lasttime(ip) < (countServer*3)) {
+                        ban_add(&(event.peer->address), LC_MS_BANSPAM);
+                        break;
+                      }
+                    }
+                    strncpy(ms_srv[i].s_ip, ip, sizeof(ms_srv[i].s_ip) - 1);
+                    strncpy(ms_srv[i].s_map, map, sizeof(ms_srv[i].s_map) - 1);
+                    strncpy(ms_srv[i].s_name, name, sizeof(ms_srv[i].s_name) - 1);
                     ms_srv[i].s_port = port;
                     ms_srv[i].s_plrs = pl;
                     ms_srv[i].s_maxplrs = mpl;
@@ -580,7 +666,8 @@ int main (int argc, char *argv[]) {
 
             case NET_MSG_RM:
               enet_address_get_host_ip(&(event.peer->address), ip, 17);
-              port = b_read_uint16(event.packet->data, &b_read);
+              port = b_read_uint16(&b_recv);
+              CHECK_RECV_OVERFLOW(&(event.peer->address));
               for (int i = 0; i < MS_MAXSRVS; ++i) {
                 if (ms_srv[i].used) {
                   if ((strncmp(ip, ms_srv[i].s_ip, 16) == 0) && (ms_srv[i].s_port == port)) {
@@ -597,34 +684,37 @@ int main (int argc, char *argv[]) {
               break;
 
             case NET_MSG_LIST:
-              b_write = 0;
-              b_write_uint8(b_send, &b_write, NET_MSG_LIST);
+              b_send.pos = 0;
+              b_write_uint8(&b_send, NET_MSG_LIST);
 
               if (event.packet->dataLength > 2) {
                 // holy shit a fresh client
-                clientver = b_read_dstring(event.packet->data, &b_read);
-                b_write_uint8(b_send, &b_write, ms_count);
+                clientver = b_read_dstring(&b_recv);
+                b_write_uint8(&b_send, ms_count);
               } else {
                 // old client, feed them bullshit first
-                b_write_uint8(b_send, &b_write, ms_count + 2);
+                b_write_uint8(&b_send, ms_count + 2);
                 for (int i = 0; i < MS_FAKESRVS; ++i)
-                  b_write_server(b_send, &b_write, ms_fake_srv[i]);
+                  b_write_server(&b_send, ms_fake_srv[i]);
               }
 
+              CHECK_RECV_OVERFLOW(&(event.peer->address));
+
               for (int i = 0; i < MS_MAXSRVS; ++i) {
-                if (ms_srv[i].used) b_write_server(b_send, &b_write, ms_srv[i]);
+                if (ms_srv[i].used)
+                  b_write_server(&b_send, ms_srv[i]);
               }
 
               if (clientver) {
                 // TODO: check if this client is outdated (?) and send back new verstring
                 // for now just write the same shit back
-                b_write_dstring(b_send, &b_write, clientver);
+                b_write_dstring(&b_send, clientver);
                 // write the motd and urgent message
-                b_write_dstring(b_send, &b_write, ms_motd);
-                b_write_dstring(b_send, &b_write, ms_urgent);
+                b_write_dstring(&b_send, ms_motd);
+                b_write_dstring(&b_send, ms_urgent);
               }
 
-              ENetPacket *p = enet_packet_create(b_send, b_write, ENET_PACKET_FLAG_RELIABLE);
+              ENetPacket *p = enet_packet_create(b_send.data, b_send.pos, ENET_PACKET_FLAG_RELIABLE);
               enet_peer_send(event.peer, NET_CH_MAIN, p);
               enet_host_flush(ms_host);
 
@@ -632,6 +722,11 @@ int main (int argc, char *argv[]) {
               free(clientver);
               clientver = NULL;
               break;
+
+            default:
+              // cheeky cunt sending invalid messages
+              ban_add(&(event.peer->address), LC_MS_BANINVAL);
+              break;
           }
 
           enet_packet_destroy(event.packet);