DEADSOFTWARE

master: anti-overflow patch
authorTerminalHash <lyashuk.voxx@gmail.com>
Sat, 6 May 2023 22:24:27 +0000 (01:24 +0300)
committerTerminalHash <lyashuk.voxx@gmail.com>
Sat, 6 May 2023 22:24:27 +0000 (01:24 +0300)
src/mastersrv/master.c

index e30abb8c3d26d8c0cca42e4e6e01ef244b20a5a1..ebeb2c5c9a5d7640b8ff04e3fdb9bd4a6a85c6d9 100644 (file)
@@ -43,6 +43,8 @@
 #define LC_MS_BANTOOMUCH "created too many servers"
 #define LC_MS_BANSPAM "suspicious multiple server activity"
 #define LC_MS_BANLIST "address in ban list"
+#define LC_MS_BANTRASH "garbage server data"
+#define LC_MS_BANINVAL "invalid message ID"
 #define LC_MS_OOM "\nOut of memory\n"
 
 #define MS_URGENT_FILE "urgent.txt"
@@ -62,9 +64,9 @@ typedef struct ms_ban_record_s {
 
 struct ms_server_s {
   enet_uint8 used;
-  char s_ip[17];
-  char s_name[256];
-  char s_map[256];
+  char s_ip[18];
+  char s_name[257];
+  char s_map[257];
   enet_uint8  s_pw;
   enet_uint8  s_plrs;
   enet_uint8  s_maxplrs;
@@ -77,6 +79,13 @@ struct ms_server_s {
 
 typedef struct ms_server_s ms_server;
 
+typedef struct enet_buf_s {
+  enet_uint8 *data;
+  size_t size;
+  size_t pos;
+  int overflow;
+} enet_buf;
+
 const char ms_game_ver[] = "0.63";
 char ms_motd[255] = "";
 char ms_urgent[255] = "";
@@ -85,10 +94,9 @@ int ms_port = 25660;
 int ms_timeout = 100;
 int ms_checkmultiple = 0;
 
-size_t b_read = 0;
-size_t b_write = 0;
-
-enet_uint8 b_send[NET_BUFSIZE];
+enet_uint8 b_send_data[NET_BUFSIZE];
+enet_buf b_send = { .data = b_send_data, .size = sizeof(b_send_data) };
+enet_buf b_recv;
 
 ENetHost  *ms_host = NULL;
 ENetPeer  *ms_peers[NET_MAXCLIENTS];
@@ -211,66 +219,81 @@ const char *d_strtime(const time_t t) {
 }
 
 
-enet_uint8 b_read_uint8 (enet_uint8 buf[], size_t *pos) {
-  return buf[(*pos)++];
+static inline int b_enough_left(enet_buf *buf, size_t size) {
+  if (buf->pos + size > buf->size) {
+    buf->overflow = 1;
+    return 0;
+  }
+  return 1;
+}
+
+
+enet_uint8 b_read_uint8 (enet_buf *buf) {
+  if (b_enough_left(buf, 1))
+    return buf->data[buf->pos++];
+  return 0;
 }
 
 
-enet_uint16 b_read_uint16 (enet_uint8 buf[], size_t *pos) {
+enet_uint16 b_read_uint16 (enet_buf *buf) {
   enet_uint16 ret = 0;
 
-  ret = *(enet_uint16*)(buf + *pos);
-  *pos += sizeof(enet_uint16);
+  if (b_enough_left(buf, sizeof(ret))) {
+    ret = *(enet_uint16*)(buf->data + buf->pos);
+    buf->pos += sizeof(ret);
+  }
 
   return ret;
 }
 
 
-char* b_read_dstring (enet_uint8 buf[], size_t *pos) {
+char* b_read_dstring (enet_buf *buf) {
   char *ret = NULL;
 
-  size_t len = b_read_uint8(buf, pos);
-
-  ret = malloc(len + 1);
-
-  memmove(ret, (char*)(buf + *pos), len);
-  ret[len] = '\0';
-  *pos += len;
+  if (b_enough_left(buf, 1)) {
+    size_t len = b_read_uint8(buf);
+    if (b_enough_left(buf, len)) {
+      ret = malloc(len + 1);
+      memmove(ret, (char*)(buf->data + buf->pos), len);
+      buf->pos += len;
+      ret[len] = '\0';
+    }
+  }
 
   return ret;
 }
 
 
-void b_write_uint8 (enet_uint8 buf[], size_t *pos, enet_uint8 val) {
-  buf[(*pos)++] = val;
+void b_write_uint8 (enet_buf *buf, enet_uint8 val) {
+  buf->data[buf->pos++] = val;
 }
 
 
-void b_write_uint16 (enet_uint8 buf[], size_t *pos, enet_uint16 val) {
-  *(enet_uint16*)(buf + *pos) = val;
-  *pos += sizeof(enet_uint16);
+void b_write_uint16 (enet_buf *buf, enet_uint16 val) {
+  *(enet_uint16*)(buf->data + buf->pos) = val;
+  buf->pos += sizeof(val);
 }
 
 
-void b_write_dstring (enet_uint8 buf[], size_t *pos, const char* val) {
+void b_write_dstring (enet_buf *buf, const char* val) {
   enet_uint8 len = strlen(val);
-  b_write_uint8(buf, pos, len);
+  b_write_uint8(buf, len);
 
-  memmove((char*)(buf + *pos), val, len);
-  *pos += len;
+  memmove((char*)(buf->data + buf->pos), val, len);
+  buf->pos += len;
 }
 
 
-void b_write_server (enet_uint8 buf[], size_t *pos, ms_server s) {
-  b_write_dstring(b_send, pos, s.s_ip);
-  b_write_uint16 (b_send, pos, s.s_port);
-  b_write_dstring(b_send, pos, s.s_name);
-  b_write_dstring(b_send, pos, s.s_map);
-  b_write_uint8  (b_send, pos, s.s_mode);
-  b_write_uint8  (b_send, pos, s.s_plrs);
-  b_write_uint8  (b_send, pos, s.s_maxplrs);
-  b_write_uint8  (b_send, pos, s.s_protocol);
-  b_write_uint8  (b_send, pos, s.s_pw);
+void b_write_server (enet_buf *b_send, ms_server s) {
+  b_write_dstring(b_send, s.s_ip);
+  b_write_uint16 (b_send, s.s_port);
+  b_write_dstring(b_send, s.s_name);
+  b_write_dstring(b_send, s.s_map);
+  b_write_uint8  (b_send, s.s_mode);
+  b_write_uint8  (b_send, s.s_plrs);
+  b_write_uint8  (b_send, s.s_maxplrs);
+  b_write_uint8  (b_send, s.s_protocol);
+  b_write_uint8  (b_send, s.s_pw);
 }
 
 
@@ -427,10 +450,13 @@ int ban_heur (const ms_server *srv, const time_t now) {
   return score;
 }
 
-void erase_banned_host(char* ip) {
+void erase_banned_host(const char *ip) {
   for (int i = 0; i < MS_MAXSRVS; ++i) {
-    if (ms_srv[i].s_ip == ip) {
-      ms_srv[i].used = 0;
+    if (!strcmp(ms_srv[i].s_ip, ip)) {
+      if (ms_srv[i].used) {
+        ms_srv[i].used = 0;
+        ms_count--;
+      }
     }
   }
 }
@@ -477,7 +503,11 @@ int count_servers(char* ip) {
 }
 
 
-
+#define CHECK_RECV_OVERFLOW(addr) \
+  if (b_recv.overflow) { \
+    ban_add(addr, LC_MS_BANTRASH); \
+    break; \
+  }
 
 int main (int argc, char *argv[]) {
   d_getargs(argc, argv);
@@ -549,39 +579,44 @@ int main (int argc, char *argv[]) {
         case ENET_EVENT_TYPE_RECEIVE:
           if (!event.peer) continue;
 
-          b_read = 0;
-          msg = b_read_uint8(event.packet->data, &b_read);
+          b_recv.pos = 0;
+          b_recv.overflow = 0;
+          b_recv.data = event.packet->data;
+          b_recv.size = event.packet->dataLength;
+          msg = b_read_uint8(&b_recv);
 
           switch (msg) {
             case NET_MSG_ADD:
               enet_address_get_host_ip(&(event.peer->address), ip, 17);
-              port = b_read_uint16(event.packet->data, &b_read);
+              port = b_read_uint16(&b_recv);
 
-              name = b_read_dstring(event.packet->data, &b_read);
-              map = b_read_dstring(event.packet->data, &b_read);
-              gm  = b_read_uint8(event.packet->data, &b_read);
+              name = b_read_dstring(&b_recv);
+              map = b_read_dstring(&b_recv);
+              gm  = b_read_uint8(&b_recv);
 
-              pl = b_read_uint8(event.packet->data, &b_read);
-              mpl = b_read_uint8(event.packet->data, &b_read);
+              pl = b_read_uint8(&b_recv);
+              mpl = b_read_uint8(&b_recv);
 
-              proto = b_read_uint8(event.packet->data, &b_read);
-              pw = b_read_uint8(event.packet->data, &b_read);
+              proto = b_read_uint8(&b_recv);
+              pw = b_read_uint8(&b_recv);
+
+              CHECK_RECV_OVERFLOW(&(event.peer->address));
 
               for (int i = 0; i < MS_MAXSRVS; ++i) {
                 if (ms_srv[i].used) {
                   if ((strncmp(ip, ms_srv[i].s_ip, 16) == 0) && (ms_srv[i].s_port == port)) {
-                    if (ban_heur(ms_srv + i, now) >= MS_MAXHEUR) {
-                      ban_add(&(event.peer->address), LC_MS_BANHEUR);
-                      break;
-                    }
-
-                    strncpy(ms_srv[i].s_map, map, sizeof(ms_srv[i].s_map));
-                    strncpy(ms_srv[i].s_name, name, sizeof(ms_srv[i].s_name));
+                    strncpy(ms_srv[i].s_map, map, sizeof(ms_srv[i].s_map) - 1);
+                    strncpy(ms_srv[i].s_name, name, sizeof(ms_srv[i].s_name) - 1);
                     ms_srv[i].s_plrs = pl;
                     ms_srv[i].s_maxplrs = mpl;
                     ms_srv[i].s_pw = pw;
                     ms_srv[i].s_mode = gm;
 
+                    if (ban_heur(ms_srv + i, now) >= MS_MAXHEUR) {
+                      ban_add(&(event.peer->address), LC_MS_BANHEUR);
+                      break;
+                    }
+
                     ms_srv[i].deathtime = now + ms_timeout;
                     ms_srv[i].lasttime = now;
 
@@ -600,9 +635,9 @@ int main (int argc, char *argv[]) {
                         break;
                       }
                     }
-                    strncpy(ms_srv[i].s_ip, ip, sizeof(ms_srv[i].s_ip));
-                    strncpy(ms_srv[i].s_map, map, sizeof(ms_srv[i].s_map));
-                    strncpy(ms_srv[i].s_name, name, sizeof(ms_srv[i].s_name));
+                    strncpy(ms_srv[i].s_ip, ip, sizeof(ms_srv[i].s_ip) - 1);
+                    strncpy(ms_srv[i].s_map, map, sizeof(ms_srv[i].s_map) - 1);
+                    strncpy(ms_srv[i].s_name, name, sizeof(ms_srv[i].s_name) - 1);
                     ms_srv[i].s_port = port;
                     ms_srv[i].s_plrs = pl;
                     ms_srv[i].s_maxplrs = mpl;
@@ -631,7 +666,8 @@ int main (int argc, char *argv[]) {
 
             case NET_MSG_RM:
               enet_address_get_host_ip(&(event.peer->address), ip, 17);
-              port = b_read_uint16(event.packet->data, &b_read);
+              port = b_read_uint16(&b_recv);
+              CHECK_RECV_OVERFLOW(&(event.peer->address));
               for (int i = 0; i < MS_MAXSRVS; ++i) {
                 if (ms_srv[i].used) {
                   if ((strncmp(ip, ms_srv[i].s_ip, 16) == 0) && (ms_srv[i].s_port == port)) {
@@ -648,34 +684,37 @@ int main (int argc, char *argv[]) {
               break;
 
             case NET_MSG_LIST:
-              b_write = 0;
-              b_write_uint8(b_send, &b_write, NET_MSG_LIST);
+              b_send.pos = 0;
+              b_write_uint8(&b_send, NET_MSG_LIST);
 
               if (event.packet->dataLength > 2) {
                 // holy shit a fresh client
-                clientver = b_read_dstring(event.packet->data, &b_read);
-                b_write_uint8(b_send, &b_write, ms_count);
+                clientver = b_read_dstring(&b_recv);
+                b_write_uint8(&b_send, ms_count);
               } else {
                 // old client, feed them bullshit first
-                b_write_uint8(b_send, &b_write, ms_count + 2);
+                b_write_uint8(&b_send, ms_count + 2);
                 for (int i = 0; i < MS_FAKESRVS; ++i)
-                  b_write_server(b_send, &b_write, ms_fake_srv[i]);
+                  b_write_server(&b_send, ms_fake_srv[i]);
               }
 
+              CHECK_RECV_OVERFLOW(&(event.peer->address));
+
               for (int i = 0; i < MS_MAXSRVS; ++i) {
-                if (ms_srv[i].used) b_write_server(b_send, &b_write, ms_srv[i]);
+                if (ms_srv[i].used)
+                  b_write_server(&b_send, ms_srv[i]);
               }
 
               if (clientver) {
                 // TODO: check if this client is outdated (?) and send back new verstring
                 // for now just write the same shit back
-                b_write_dstring(b_send, &b_write, clientver);
+                b_write_dstring(&b_send, clientver);
                 // write the motd and urgent message
-                b_write_dstring(b_send, &b_write, ms_motd);
-                b_write_dstring(b_send, &b_write, ms_urgent);
+                b_write_dstring(&b_send, ms_motd);
+                b_write_dstring(&b_send, ms_urgent);
               }
 
-              ENetPacket *p = enet_packet_create(b_send, b_write, ENET_PACKET_FLAG_RELIABLE);
+              ENetPacket *p = enet_packet_create(b_send.data, b_send.pos, ENET_PACKET_FLAG_RELIABLE);
               enet_peer_send(event.peer, NET_CH_MAIN, p);
               enet_host_flush(ms_host);
 
@@ -683,6 +722,11 @@ int main (int argc, char *argv[]) {
               free(clientver);
               clientver = NULL;
               break;
+
+            default:
+              // cheeky cunt sending invalid messages
+              ban_add(&(event.peer->address), LC_MS_BANINVAL);
+              break;
           }
 
           enet_packet_destroy(event.packet);