DEADSOFTWARE

master: check client timeouts manually; fix channel processing
[d2df-sdl.git] / src / mastersrv / master.c
#include <stdlib.h>
#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>
#include <stdarg.h>
#include <ctype.h>
#include <string.h>
#include <time.h>
#include <signal.h>
#include <errno.h>

#define ENET_DEBUG 1
#include <enet/enet.h>
#include <enet/types.h>
/* master server identity, limits and data files (relative paths, so they are
   resolved against the working directory) */
#define MS_VERSION "0.3"
#define MS_MAX_SERVERS 128
#define MS_MAX_CLIENTS (MS_MAX_SERVERS + 1) /* enet peer slots */
#define MS_URGENT_FILE "urgent.txt"
#define MS_MOTD_FILE "motd.txt"
#define MS_BAN_FILE "master_bans.txt"

/* defaults for the command line options (see print_usage) */
#define DEFAULT_SPAM_CAP 10                /* -f */
#define DEFAULT_MAX_SERVERS MS_MAX_SERVERS /* -s */
#define DEFAULT_MAX_PER_HOST 4             /* -d */
#define DEFAULT_SERVER_TIMEOUT 100         /* -t, seconds */
#define DEFAULT_CLIENT_TIMEOUT 3           /* -c, seconds */
#define DEFAULT_SPAM_TIMEOUT 1             /* -w, seconds */
#define DEFAULT_PORT 25665                 /* -p */

/* networking */
#define NET_BUFSIZE 65536       /* size of the outgoing message buffer */
#define NET_FULLMASK 0xFFFFFFFF /* IPv4 netmask that covers a single address */

/* limits used when validating server records */
#define SV_PROTO_MIN 140
#define SV_PROTO_MAX 210
#define SV_NAME_MAX 64
#define SV_MAP_MAX 64
#define SV_MAX_PLAYERS 24
#define SV_MAX_GAMEMODE 5
#define SV_NEW_SERVER_INTERVAL 3 /* seconds; only referenced by a disabled rate check */

/* 0xFF: the largest length a one-byte string prefix can express; also the
   base size of the string buffers */
#define MAX_STRLEN 0xFF
/* severity levels accepted by u_log() */
enum log_severity_e {
  LOG_NOTE,
  LOG_WARN,
  LOG_ERROR
};

/* enet channels; a connecting peer must open exactly NET_CH_COUNT of them */
enum net_ch_e {
  NET_CH_MAIN,
  NET_CH_UPD,
  NET_CH_COUNT
};

/* message ids, sent as the first byte of every packet */
enum net_msg_e {
  NET_MSG_ADD = 200,  /* a server registers or refreshes itself */
  NET_MSG_RM = 201,   /* a server removes itself */
  NET_MSG_LIST = 202  /* a client asks for the server list */
};

/* bits stored in server_t.flags */
enum sv_flags_e {
  SV_FL_PASSWORD = 1 << 0,
  SV_FL_VERIFIED = 1 << 1,
  SV_FL_MAX = SV_FL_PASSWORD | SV_FL_VERIFIED, /* highest valid flags value */
};
67 typedef struct enet_buf_s {
68 enet_uint8 *data;
69 size_t size;
70 size_t pos;
71 int overflow;
72 } enet_buf_t;
74 typedef struct ban_record_s {
75 enet_uint32 host;
76 enet_uint32 mask;
77 int ban_count;
78 time_t cur_ban;
79 struct ban_record_s *next;
80 struct ban_record_s *prev;
81 } ban_record_t;
83 typedef struct server_s {
84 enet_uint32 host; // BE; 0 means this slot is unused
85 enet_uint16 port; // LE, which is what the game and enet both expect
86 enet_uint8 flags;
87 enet_uint8 proto;
88 enet_uint8 gamemode;
89 enet_uint8 players;
90 enet_uint8 maxplayers;
91 char name[MAX_STRLEN + 2];
92 char map[MAX_STRLEN + 2];
93 time_t death_time;
94 time_t timestamp;
95 } server_t;
97 // real servers
98 static server_t servers[MS_MAX_SERVERS];
99 static int max_servers = DEFAULT_MAX_SERVERS;
100 static int max_servers_per_host = DEFAULT_MAX_PER_HOST;
101 static int num_servers = 0;
103 // fake servers to show on old versions of the game
104 static const server_t fake_servers[] = {
106 .name = "! \xc2\xc0\xd8\xc0 \xca\xce\xcf\xc8\xdf \xc8\xc3\xd0\xdb "
107 "\xd3\xd1\xd2\xc0\xd0\xc5\xcb\xc0! "
108 "\xd1\xca\xc0\xd7\xc0\xc9\xd2\xc5 \xcd\xce\xc2\xd3\xde C "
109 "doom2d.org !",
110 .map = "! Your game is outdated. "
111 "Get latest version at doom2d.org !",
112 .proto = 255,
113 },
115 .name = "! \xcf\xd0\xce\xc1\xd0\xce\xd1\xdcTE \xcf\xce\xd0\xd2\xdb "
116 "25666 \xc8 57133 HA CEPBEPE \xcf\xc5\xd0\xc5\xc4 \xc8\xc3\xd0\xce\xc9 !",
117 .map = "! Forward ports 25666 and 57133 before hosting !",
118 .proto = 255,
119 },
120 };
121 static const int num_fake_servers = sizeof(fake_servers) / sizeof(*fake_servers);
123 // ban list
124 static ban_record_t *banlist;
126 // settings
127 static int ms_port = DEFAULT_PORT;
128 static int ms_sv_timeout = DEFAULT_SERVER_TIMEOUT;
129 static int ms_cl_timeout = DEFAULT_CLIENT_TIMEOUT;
130 static int ms_spam_timeout = DEFAULT_SPAM_TIMEOUT;
131 static int ms_spam_cap = DEFAULT_SPAM_CAP;
132 static char ms_motd[MAX_STRLEN + 1] = "";
133 static char ms_urgent[MAX_STRLEN + 1] = "";
134 static ENetHost *ms_host;
136 // network buffers
137 static enet_uint8 buf_send_data[NET_BUFSIZE];
138 static enet_buf_t buf_send = { .data = buf_send_data, .size = sizeof(buf_send_data) };
139 static enet_buf_t buf_recv; // rx data supplied by enet packets
141 // stupid client spam filter
142 static enet_uint32 cl_last_addr;
143 static time_t cl_last_time;
144 static int cl_spam_cnt;
146 /* common utility functions */
148 static char *u_vabuf(void) {
149 static char vabuf[4][MAX_STRLEN];
150 static int idx = 0;
151 char *ret = vabuf[idx++];
152 if (idx >= 4) idx = 0;
153 return ret;
156 static const char *u_strtime(const time_t t) {
157 char *buf = u_vabuf();
158 struct tm *ptm = localtime(&t);
159 strftime(buf, MAX_STRLEN - 1, "%d/%m/%y %H:%M:%S", ptm);
160 return buf;
163 static inline const char *u_logprefix(const enum log_severity_e s) {
164 switch (s) {
165 case LOG_WARN: return "WARNING: ";
166 case LOG_ERROR: return "ERROR: ";
167 default: return "";
171 static void u_log(const enum log_severity_e severity, const char *fmt, ...) {
172 printf("[%s] %s", u_strtime(time(NULL)), u_logprefix(severity));
173 va_list args;
174 va_start(args, fmt);
175 vprintf(fmt, args);
176 va_end(args);
177 printf("\n");
/* Prints a timestamped fatal error message to stderr and exits with status 1
   (atexit handlers still run). */
static void __attribute__((noreturn)) u_fatal(const char *fmt, ...) {
  va_list ap;
  fprintf(stderr, "[%s] FATAL ERROR:\n", u_strtime(time(NULL)));
  va_start(ap, fmt);
  vfprintf(stderr, fmt, ap);
  va_end(ap);
  fputc('\n', stderr);
  fflush(stderr);
  exit(1);
}
/* Returns true if str is non-NULL, non-empty and free of garbage characters.
   Text is treated as cp1251: bytes below 0x20, DEL (0x7F), NBSP (0xA0) and
   SHY (0xAD) are rejected; every other byte, including Cyrillic letters
   (>= 0x80), is accepted. */
static bool u_strisprint(const char *str) {
  if (!str || !*str)
    return false;
  // inspect raw byte values: with a signed char every byte >= 0x80 would be
  // negative, so all non-ASCII names failed the `< 0x20` test and the
  // 0xA0/0xAD comparisons could never match
  for (const unsigned char *p = (const unsigned char *)str; *p; ++p) {
    if (*p < 0x20 || *p == 0x7F || *p == 0xA0 || *p == 0xAD)
      return false;
  }
  return true;
}
/* Returns true if str is a non-empty version string made only of the
   characters 0-9, '.' and ' '. */
static bool u_strisver(const char *str) {
  if (!str || !*str)
    return false;
  for (const unsigned char *p = (const unsigned char *)str; *p; ++p) {
    // isdigit() requires an unsigned char value (or EOF); handing it a
    // negative plain char is undefined behavior
    if (!isdigit(*p) && *p != '.' && *p != ' ')
      return false;
  }
  return true;
}
213 static const char *u_iptostr(const enet_uint32 host) {
214 ENetAddress addr = { .host = host, .port = 0 };
215 char *buf = u_vabuf();
216 enet_address_get_host_ip(&addr, buf, MAX_STRLEN - 1);
217 return buf;
/* Reads the text file fname into buf (capacity max bytes, terminator
   included), concatenating all lines with their CR/LF endings removed.
   A line is cut at its first CR or LF. The result is truncated to max - 1
   characters and NUL-terminated.
   Returns false without touching buf if max is 0 or the file cannot be
   opened. */
static bool u_readtextfile(const char *fname, char *buf, size_t max) {
  // max == 0 used to compute buf - 1 and then write the terminator out of bounds
  if (!buf || max == 0)
    return false;

  FILE *f = fopen(fname, "r");
  if (!f)
    return false;

  char *const end = buf + max - 1; // last byte, reserved for the terminator
  char *p = buf;
  // fixed-size chunk buffer instead of a caller-sized VLA on the stack;
  // longer lines are simply read in several chunks
  char ln[256];
  while (p < end && fgets(ln, sizeof(ln), f)) {
    for (const char *n = ln; *n && *n != '\r' && *n != '\n' && p < end; ++n)
      *p++ = *n;
  }
  *p = '\0';

  fclose(f);
  return true;
}
240 static inline enet_uint32 u_prefixtomask(const enet_uint32 prefix) {
241 return ENET_HOST_TO_NET_32((0xFFFFFFFF << (32 - prefix)) & 0xFFFFFFFF);
244 static inline enet_uint32 u_masktoprefix(const enet_uint32 mask) {
245 return (32 - __builtin_ctz(mask));
248 static inline void u_printsv(const server_t *sv) {
249 printf("* addr: %s:%d\n", u_iptostr(sv->host), sv->port);
250 printf("* name: %s\n", sv->name);
251 printf("* map: %s (mode %d)\n", sv->map, sv->gamemode);
252 printf("* plrs: %d/%d\n", sv->players, sv->maxplayers);
253 printf("* flag: %04x\n", sv->flags);
256 /* buffer utility functions */
258 static inline int b_enough_left(enet_buf_t *buf, size_t size) {
259 if (buf->pos + size > buf->size) {
260 buf->overflow = 1;
261 return 0;
263 return 1;
266 static enet_uint8 b_read_uint8(enet_buf_t *buf) {
267 if (b_enough_left(buf, 1))
268 return buf->data[buf->pos++];
269 return 0;
272 static enet_uint16 b_read_uint16(enet_buf_t *buf) {
273 enet_uint16 ret = 0;
275 if (b_enough_left(buf, sizeof(ret))) {
276 ret = *(enet_uint16*)(buf->data + buf->pos);
277 buf->pos += sizeof(ret);
280 return ret;
283 static char *b_read_dstring(enet_buf_t *buf) {
284 char *ret = NULL;
286 if (b_enough_left(buf, 1)) {
287 const size_t len = b_read_uint8(buf);
288 if (b_enough_left(buf, len)) {
289 ret = malloc(len + 1);
290 memmove(ret, (char*)(buf->data + buf->pos), len);
291 buf->pos += len;
292 ret[len] = '\0';
296 return ret;
299 static char *b_read_dstring_to(enet_buf_t *buf, char *out, size_t out_size) {
300 if (b_enough_left(buf, 1)) {
301 const size_t len = b_read_uint8(buf);
302 if (b_enough_left(buf, len)) {
303 if (len < out_size) {
304 memmove(out, (char*)(buf->data + buf->pos), len);
305 out[len] = '\0';
306 } else if (out_size) {
307 out[0] = '\0';
309 buf->pos += len;
310 return out;
313 return NULL;
316 static void b_write_uint8(enet_buf_t *buf, enet_uint8 val) {
317 buf->data[buf->pos++] = val;
320 static void b_write_uint16(enet_buf_t *buf, enet_uint16 val) {
321 *(enet_uint16*)(buf->data + buf->pos) = val;
322 buf->pos += sizeof(val);
325 static void b_write_dstring(enet_buf_t *buf, const char* val) {
326 enet_uint8 len = strlen(val);
327 b_write_uint8(buf, len);
328 memmove((char*)(buf->data + buf->pos), val, len);
329 buf->pos += len;
332 void b_write_server(enet_buf_t *buf, const server_t *s) {
333 b_write_dstring(buf, u_iptostr(s->host));
334 b_write_uint16 (buf, s->port);
335 b_write_dstring(buf, s->name);
336 b_write_dstring(buf, s->map);
337 b_write_uint8 (buf, s->gamemode);
338 b_write_uint8 (buf, s->players);
339 b_write_uint8 (buf, s->maxplayers);
340 b_write_uint8 (buf, s->proto);
341 b_write_uint8 (buf, (s->flags & SV_FL_PASSWORD));
344 /* server functions */
346 static void sv_remove(const enet_uint32 host, const enet_uint16 port) {
347 for (int i = 0; i < max_servers; ++i) {
348 if (servers[i].host == host && servers[i].port == port) {
349 servers[i].host = 0;
350 servers[i].port = 0;
351 --num_servers;
356 static void sv_remove_by_host(enet_uint32 host, enet_uint32 mask) {
357 host &= mask;
358 for (int i = 0; i < max_servers; ++i) {
359 if (servers[i].host && (servers[i].host & mask) == host) {
360 servers[i].host = 0;
361 servers[i].port = 0;
362 --num_servers;
367 static int sv_count_by_host(enet_uint32 host, enet_uint32 mask) {
368 host &= mask;
369 int count = 0;
370 for (int i = 0; i < max_servers; ++i) {
371 if (servers[i].host && (servers[i].host & mask) == host)
372 ++count;
374 return count;
377 static time_t sv_last_timestamp_for_host(enet_uint32 host, enet_uint32 mask) {
378 host &= mask;
379 time_t last = 0;
380 for (int i = 0; i < max_servers; ++i) {
381 if (servers[i].host && (servers[i].host & mask) == host) {
382 if (servers[i].timestamp > last)
383 last = servers[i].timestamp;
386 return last;
389 static inline server_t *sv_find_or_add(const enet_uint32 host, const enet_uint32 port) {
390 server_t *empty = NULL;
391 for (int i = 0; i < max_servers; ++i) {
392 server_t *s = servers + i;
393 if (s->host == host && s->port == port)
394 return s; // this server already exists
395 if (!s->host && !empty)
396 empty = s; // remember the first empty slot in case it's needed later
398 return empty;
/* ban list functions */

/* Returns the ban length in seconds for ban level cnt. The lengths are 5 min,
   30 min, 1 h, 1 day, 3 days, 30 days and 365 days. Levels past the end of
   the table, and negative levels, get the longest ban. */
static inline time_t ban_get_time(const int cnt) {
  static const time_t durations[] = {
    5 * 60,         // 5 minutes
    30 * 60,        // 30 minutes
    60 * 60,        // 1 hour
    24 * 60 * 60,   // 1 day
    72 * 60 * 60,   // 3 days
    720 * 60 * 60,  // 30 days
    8760 * 60 * 60, // 365 days
  };
  static const size_t count = sizeof(durations) / sizeof(durations[0]);

  if (cnt < 0 || (size_t)cnt >= count)
    return durations[count - 1];
  return durations[cnt];
}
422 static ban_record_t *ban_check(const enet_uint32 host) {
423 const time_t now = time(NULL);
425 for (ban_record_t *b = banlist; b; b = b->next) {
426 if ((b->host & b->mask) == (host & b->mask)) {
427 if (b->cur_ban > now)
428 return b;
432 return NULL;
435 static inline ban_record_t *ban_record_check(const enet_uint32 host) {
436 for (ban_record_t *b = banlist; b; b = b->next) {
437 if ((b->host & b->mask) == (host & b->mask))
438 return b;
440 return NULL;
443 static ban_record_t *ban_record_add_addr(const enet_uint32 host, const enet_uint32 mask, const int cnt, const time_t cur) {
444 ban_record_t *rec = ban_record_check(host);
445 if (rec) return rec;
447 rec = calloc(1, sizeof(*rec));
448 if (!rec) return NULL;
450 rec->host = host & mask;
451 rec->mask = mask;
452 if (rec->mask == 0) rec->mask = NET_FULLMASK;
453 rec->ban_count = cnt;
454 rec->cur_ban = cur;
456 if (banlist) banlist->prev = rec;
457 rec->next = banlist;
458 banlist = rec;
460 return rec;
463 static ban_record_t *ban_record_add_ip(const char *ip, const int cnt, const time_t cur) {
464 enet_uint32 prefix = 32;
466 // find and get the prefix length, if any
467 char ip_copy[24] = { 0 };
468 strncpy(ip_copy, ip, sizeof(ip_copy) - 1);
469 char *slash = strrchr(ip_copy, '/');
470 if (slash) {
471 *slash++ = '\0'; // strip the prefix length off
472 if (*slash) prefix = atoi(slash);
475 ENetAddress addr = { 0 };
476 if (enet_address_set_host_ip(&addr, ip_copy) != 0) {
477 u_log(LOG_ERROR, "banlist: `%s` is not a valid IP address", ip_copy);
478 return NULL;
481 // transform prefix length into mask
482 const enet_uint32 mask = u_prefixtomask(prefix);
484 return ban_record_add_addr(addr.host, mask, cnt, cur);
487 static void ban_free_list(void) {
488 ban_record_t *rec = banlist;
489 while (rec) {
490 ban_record_t *next = rec->next;
491 free(rec);
492 rec = next;
494 banlist = NULL;
497 static void ban_load_list(const char *fname) {
498 FILE *f = fopen(fname, "r");
499 if (!f) {
500 u_log(LOG_WARN, "banlist: could not open %s for reading", fname);
501 return;
504 char ln[MAX_STRLEN] = { 0 };
506 while (fgets(ln, sizeof(ln), f)) {
507 for (int i = sizeof(ln) - 1; i >= 0; --i)
508 if (ln[i] == '\n' || ln[i] == '\r')
509 ln[i] = 0;
511 if (ln[0] == 0)
512 continue;
514 char ip[21] = { 0 }; // optionally includes the "/nn" prefix length at the end
515 time_t exp = 0;
516 int count = 0;
517 if (sscanf(ln, "%20s %ld %d", ip, &exp, &count) < 3) {
518 u_log(LOG_ERROR, "banlist: malformed line: `%s`", ln);
519 continue;
522 if (ban_record_add_ip(ip, count, exp))
523 u_log(LOG_NOTE, "banlist: banned %s until %s (ban level %d)", ip, u_strtime(exp), count);
526 fclose(f);
529 static void ban_save_list(const char *fname) {
530 FILE *f = fopen(fname, "w");
531 if (!f) {
532 u_log(LOG_ERROR, "banlist: could not open %s for writing", fname);
533 return;
536 for (ban_record_t *rec = banlist; rec; rec = rec->next) {
537 if (rec->ban_count)
538 fprintf(f, "%s/%u %ld %d\n", u_iptostr(rec->host), u_masktoprefix(rec->mask), rec->cur_ban, rec->ban_count);
541 fclose(f);
544 static bool ban_sanity_check(const server_t *srv) {
545 // can't have more than 24 maxplayers; can't have more than max
546 if (srv->players > srv->maxplayers || srv->maxplayers > SV_MAX_PLAYERS || srv->maxplayers == 0)
547 return false;
548 // name and map have to be non-garbage
549 if (!u_strisprint(srv->map) || !u_strisprint(srv->name))
550 return false;
551 // these protocols don't exist
552 if (srv->proto < SV_PROTO_MIN || srv->proto > SV_PROTO_MAX)
553 return false;
554 // the game doesn't allow server names longer than 64 chars
555 if (strlen(srv->name) > SV_NAME_MAX)
556 return false;
557 // game mode has to actually exist
558 if (srv->gamemode > SV_MAX_GAMEMODE)
559 return false;
560 // flags field can't be higher than the sum of all the flags
561 if (srv->flags > SV_FL_MAX)
562 return false;
563 return true;
566 static void ban_add(const enet_uint32 host, const char *reason) {
567 const time_t now = time(NULL);
569 ban_record_t *rec = ban_record_add_addr(host, NET_FULLMASK, 0, 0);
570 if (!rec) u_fatal("OOM trying to ban %s", u_iptostr(host));
572 rec->cur_ban = now + ban_get_time(rec->ban_count);
573 rec->ban_count++;
575 u_log(LOG_NOTE, "banned %s until %s, reason: %s, ban level: %d", u_iptostr(rec->host), u_strtime(rec->cur_ban), reason, rec->ban_count);
577 ban_save_list(MS_BAN_FILE);
579 sv_remove_by_host(host, NET_FULLMASK);
581 if (host == cl_last_addr)
582 cl_last_addr = 0;
585 static inline void ban_peer(ENetPeer *peer, const char *reason) {
586 if (peer) {
587 ban_add(peer->address.host, reason);
588 peer->data = NULL;
589 enet_peer_reset(peer);
593 /* main */
595 static void deinit(void) {
596 // ban_save_list(MS_BAN_FILE);
597 ban_free_list();
598 if (ms_host) {
599 enet_host_destroy(ms_host);
600 ms_host = NULL;
602 enet_deinitialize();
605 #ifdef SIGUSR1
606 static void sigusr_handler(int signum) {
607 if (signum == SIGUSR1) {
608 u_log(LOG_WARN, "received SIGUSR1, reloading banlist");
609 ban_free_list();
610 ban_load_list(MS_BAN_FILE);
613 #endif
615 static bool handle_msg(const enet_uint8 msgid, ENetPeer *peer) {
616 server_t *sv = NULL;
617 server_t tmpsv = { 0 };
618 char clientver[MAX_STRLEN] = { 0 };
619 const time_t now = time(NULL);
621 switch (msgid) {
622 case NET_MSG_ADD:
623 tmpsv.port = b_read_uint16(&buf_recv);
624 b_read_dstring_to(&buf_recv, tmpsv.name, sizeof(tmpsv.name));
625 b_read_dstring_to(&buf_recv, tmpsv.map, sizeof(tmpsv.map));
626 tmpsv.gamemode = b_read_uint8(&buf_recv);
627 tmpsv.players = b_read_uint8(&buf_recv);
628 tmpsv.maxplayers = b_read_uint8(&buf_recv);
629 tmpsv.proto = b_read_uint8(&buf_recv);
630 tmpsv.flags = b_read_uint8(&buf_recv);
632 if (buf_recv.overflow) {
633 ban_peer(peer, "malformed MSG_ADD");
634 return true;
637 sv = sv_find_or_add(peer->address.host, tmpsv.port);
638 if (!sv) {
639 u_log(LOG_ERROR, "ran out of server slots trying to add %s:%d", u_iptostr(peer->address.host), tmpsv.port);
640 return true;
643 if (sv->host == peer->address.host) {
644 // old server; update it
645 memcpy(sv->map, tmpsv.map, sizeof(sv->map));
646 memcpy(sv->name, tmpsv.name, sizeof(sv->name));
647 sv->players = tmpsv.players;
648 sv->maxplayers = tmpsv.maxplayers;
649 sv->flags = tmpsv.flags;
650 sv->gamemode = tmpsv.gamemode;
651 // first check if the new values are garbage
652 if (!ban_sanity_check(sv)) {
653 ban_peer(peer, "tripped sanity check");
654 return true;
656 // only then update the times
657 sv->death_time = now + ms_sv_timeout;
658 sv->timestamp = now;
659 u_log(LOG_NOTE, "updated server #%d:", sv - servers);
660 u_printsv(sv);
661 } else {
662 // new server; first check if this host is creating too many servers in the list
663 if (max_servers_per_host) {
664 const int count = sv_count_by_host(peer->address.host, NET_FULLMASK);
665 if (count >= max_servers_per_host) {
666 ban_peer(peer, "too many servers in list");
667 return true;
669 /*
670 // FIXME: commented out as this might trip when the master restarts
671 if (count > 0) {
672 // check if this is too soon to create a new server
673 const time_t delta = now - sv_last_timestamp_for_host(peer->address.host, NET_FULLMASK);
674 if (delta < count * SV_NEW_SERVER_INTERVAL) {
675 ban_peer(peer, "creating servers too fast");
676 return true;
679 */
681 // then add that shit
682 *sv = tmpsv;
683 sv->host = peer->address.host;
684 sv->death_time = now + ms_sv_timeout;
685 sv->timestamp = now;
686 if (!ban_sanity_check(sv)) {
687 sv->host = 0;
688 sv->port = 0;
689 ban_peer(peer, "tripped sanity check");
690 return true;
692 ++num_servers;
693 u_log(LOG_NOTE, "added new server #%d:", sv - servers);
694 u_printsv(sv);
696 return true;
698 case NET_MSG_RM:
699 tmpsv.port = b_read_uint16(&buf_recv);
700 if (buf_recv.overflow) {
701 ban_peer(peer, "malformed MSG_RM");
702 return true;
704 sv_remove(peer->address.host, tmpsv.port);
705 return true;
707 case NET_MSG_LIST:
708 buf_send.pos = 0;
709 buf_send.overflow = 0;
710 b_write_uint8(&buf_send, NET_MSG_LIST);
712 clientver[0] = 0;
713 if (buf_recv.size > 2) {
714 // holy shit a fresh client
715 b_read_dstring_to(&buf_recv, clientver, sizeof(clientver));
716 b_write_uint8(&buf_send, num_servers);
717 } else {
718 // old client; feed him fake servers first
719 b_write_uint8(&buf_send, num_servers + num_fake_servers);
720 for (int i = 0; i < num_fake_servers; ++i)
721 b_write_server(&buf_send, &fake_servers[i]);
724 if (buf_recv.overflow) {
725 ban_peer(peer, "malformed MSG_LIST");
726 return true;
729 if (clientver[0] && !u_strisver(clientver)) {
730 ban_peer(peer, "malformed MSG_LIST clientver");
731 return true;
734 for (int i = 0; i < max_servers; ++i) {
735 if (servers[i].host)
736 b_write_server(&buf_send, servers + i);
739 if (clientver[0]) {
740 // TODO: check if this client is outdated (?) and send back new verstring
741 // for now just write the same shit back
742 b_write_dstring(&buf_send, clientver);
743 // write the motd and urgent message
744 b_write_dstring(&buf_send, ms_motd);
745 b_write_dstring(&buf_send, ms_urgent);
748 ENetPacket *p = enet_packet_create(buf_send.data, buf_send.pos, ENET_PACKET_FLAG_RELIABLE);
749 enet_peer_send(peer, NET_CH_MAIN, p);
750 // enet_host_flush(ms_host);
752 u_log(LOG_NOTE, "sent server list to %s:%d (ver %s)", u_iptostr(peer->address.host), peer->address.port, clientver[0] ? clientver : "<old>");
753 return true;
755 default:
756 break;
759 return false;
762 static void print_usage(void) {
763 printf("Usage: d2df_master [OPTIONS...]\n");
764 printf("Available options:\n");
765 printf("-h show this message and exit\n");
766 printf("-p N listen on port N (default: %d)\n", DEFAULT_PORT);
767 printf("-t N seconds before server is removed from list (default: %d)\n", DEFAULT_SERVER_TIMEOUT);
768 printf("-c N how long a client is allowed to hold the connection active (default: %d)\n", DEFAULT_CLIENT_TIMEOUT);
769 printf("-s N max number of servers in server list, 1-%d (default: %d)\n", MS_MAX_SERVERS, DEFAULT_MAX_SERVERS);
770 printf("-d N if N > 0, disallow more than N servers on the same IP (default: %d)\n", DEFAULT_MAX_PER_HOST);
771 printf("-f N crappy spam filter: ban clients after they send N requests in a row too fast (default: %d)\n", DEFAULT_SPAM_CAP);
772 printf("-w N how often does a client have to send packets for the filter to kick in, i.e. once every N sec (default: %d)\n", DEFAULT_SPAM_TIMEOUT);
773 fflush(stdout);
/* If argv[i] is exactly `name`, parses argv[i + 1] as a base-10 integer in
   [vmin, vmax] and stores it in *outval.
   Returns true on success. Returns false if argv[i] is a different option.
   Also returns false, after a message on stderr, if the value is missing,
   is not a complete number, or is out of range. *outval is only written on
   success. */
static inline bool parse_int_arg(int argc, char **argv, const int i, const char *name, int vmin, int vmax, int *outval) {
  if (strcmp(name, argv[i]))
    return false;

  if (i >= argc - 1) {
    fprintf(stderr, "expected integer value after %s\n", name);
    return false;
  }

  // strtol instead of atoi: atoi() silently accepts garbage ("abc" -> 0,
  // "5x" -> 5) and has undefined behavior on overflow
  const char *const str = argv[i + 1];
  char *end = NULL;
  errno = 0;
  const long v = strtol(str, &end, 10);
  if (end == str || *end != '\0' || errno == ERANGE) {
    fprintf(stderr, "expected integer value after %s\n", name);
    return false;
  }

  if (v < vmin || v > vmax) {
    fprintf(stderr, "expected integer value in range %d - %d\n", vmin, vmax);
    return false;
  }

  *outval = (int)v;
  return true;
}
795 static bool parse_args(int argc, char **argv) {
796 if (argc < 2)
797 return true;
799 if (!strcmp(argv[1], "-h")) {
800 print_usage();
801 return false;
804 for (int i = 1; i < argc; ++i) {
805 const bool success =
806 parse_int_arg(argc, argv, i, "-p", 1, 0xFFFF, &ms_port)
807 || parse_int_arg(argc, argv, i, "-t", 1, 0x7FFFFFFF, &ms_sv_timeout)
808 || parse_int_arg(argc, argv, i, "-c", 1, 0x7FFFFFFF, &ms_cl_timeout)
809 || parse_int_arg(argc, argv, i, "-s", 1, MS_MAX_SERVERS, &max_servers)
810 || parse_int_arg(argc, argv, i, "-d", 0, MS_MAX_SERVERS, &max_servers_per_host)
811 || parse_int_arg(argc, argv, i, "-f", 0, 0xFFFF, &ms_spam_cap)
812 || parse_int_arg(argc, argv, i, "-w", 1, 0x7FFFFFFF, &ms_spam_timeout);
813 if (success) {
814 ++i;
815 } else {
816 fprintf(stderr, "unknown or invalid argument: %s\n", argv[i]);
817 return false;
821 return true;
824 // a stupid thing to filter sustained spam from a single IP
825 static bool spam_filter(ENetPeer *peer, const time_t now) {
826 if (peer->address.host == cl_last_addr) {
827 // spam === sending shit faster than once a second
828 if (now - cl_last_time < ms_spam_timeout) {
829 if (cl_spam_cnt > 1)
830 u_log(LOG_WARN, "address %s is sending packets too fast", u_iptostr(peer->address.host));
831 if (++cl_spam_cnt >= ms_spam_cap) {
832 ban_peer(peer, "spam");
833 cl_last_addr = 0;
834 return true;
836 } else {
837 cl_spam_cnt = 0;
839 } else {
840 cl_last_addr = peer->address.host;
841 cl_spam_cnt = 0;
843 cl_last_time = now;
844 return false;
847 // filter incoming UDP packets before the protocol kicks in
848 static int packet_filter(ENetHost *host, ENetEvent *event) {
849 return !!ban_check(host->receivedAddress.host);
852 int main(int argc, char **argv) {
853 if (enet_initialize() != 0)
854 u_fatal("could not init enet");
856 if (!parse_args(argc, argv))
857 return 1; // early exit
859 u_log(LOG_NOTE, "d2df master server starting on port %d", ms_port);
861 if (!u_readtextfile(MS_MOTD_FILE, ms_motd, sizeof(ms_motd)))
862 u_log(LOG_NOTE, "couldn't read motd from %s", MS_MOTD_FILE);
863 else
864 u_log(LOG_NOTE, "motd: %s", ms_motd);
866 if (!u_readtextfile(MS_URGENT_FILE, ms_urgent, sizeof(ms_urgent)))
867 u_log(LOG_NOTE, "couldn't read urgentmsg from %s", MS_URGENT_FILE);
868 else
869 u_log(LOG_NOTE, "urgentmsg: %s", ms_urgent);
871 ban_load_list(MS_BAN_FILE);
873 atexit(deinit);
875 #ifdef SIGUSR1
876 signal(SIGUSR1, sigusr_handler);
877 #endif
879 ENetAddress addr;
880 addr.host = 0;
881 addr.port = ms_port;
882 ms_host = enet_host_create(&addr, MS_MAX_CLIENTS, NET_CH_COUNT + 1, 0, 0);
883 if (!ms_host)
884 u_fatal("could not create enet host on port %d", ms_port);
886 ms_host->intercept = packet_filter;
888 bool running = true;
889 enet_uint8 msgid = 0;
890 ENetEvent event;
891 while (running) {
892 while (enet_host_service(ms_host, &event, 10) > 0) {
893 const time_t now = time(NULL);
894 bool filtered = !event.peer || (ms_spam_cap && spam_filter(event.peer, now));
895 if (!filtered && event.peer->data) {
896 // kick people that have overstayed their welcome
897 const time_t timeout = (time_t)(intptr_t)event.peer->data;
898 if (timeout < now) filtered = true;
901 if (!filtered) {
902 switch (event.type) {
903 case ENET_EVENT_TYPE_CONNECT:
904 u_log(LOG_NOTE, "%s:%d connected", u_iptostr(event.peer->address.host), event.peer->address.port);
905 if (event.peer->channelCount != NET_CH_COUNT)
906 ban_peer(event.peer, "what is this");
907 else // store timeout in the data field
908 event.peer->data = (void *)(intptr_t)(now + ms_cl_timeout);
909 break;
911 case ENET_EVENT_TYPE_RECEIVE:
912 if (!event.packet || event.packet->dataLength == 0) {
913 ban_peer(event.peer, "empty packet");
914 break;
916 // set up receive buffer
917 buf_recv.pos = 0;
918 buf_recv.overflow = 0;
919 buf_recv.data = event.packet->data;
920 buf_recv.size = event.packet->dataLength;
921 // read message id and handle the message
922 msgid = b_read_uint8(&buf_recv);
923 if (!handle_msg(msgid, event.peer)) {
924 // cheeky cunt sending invalid messages
925 ban_peer(event.peer, "unknown message");
926 } else {
927 // can't reset connection right now because we still have packets to dispatch
928 enet_peer_disconnect_later(event.peer, 0);
930 break;
932 case ENET_EVENT_TYPE_DISCONNECT:
933 event.peer->data = NULL;
934 // u_log(LOG_NOTE, "%s:%d disconnected", u_iptostr(event.peer->address.host), event.peer->address.port);
935 break;
937 default:
938 break;
940 } else if (event.peer) {
941 // u_log(LOG_WARN, "filtered event %d from %s", event.type, u_iptostr(event.peer->address.host));
942 event.peer->data = NULL;
943 enet_peer_reset(event.peer);
946 if (event.packet) {
947 buf_recv.data = NULL;
948 enet_packet_destroy(event.packet);
952 const time_t now = time(NULL);
954 // time out servers
955 for (int i = 0; i < max_servers; ++i) {
956 if (servers[i].host) {
957 if (servers[i].death_time <= now) {
958 u_log(LOG_NOTE, "server #%d %s:%d timed out", i, u_iptostr(servers[i].host), servers[i].port);
959 servers[i].host = 0;
960 servers[i].port = 0;
961 --num_servers;
966 // time out clients
967 if (ms_host && ms_host->peers) {
968 for (size_t i = 0; i < ms_host->peerCount; ++i) {
969 ENetPeer *peer = ms_host->peers + i;
970 if ((peer->state >= ENET_PEER_STATE_CONNECTING && peer->state <= ENET_PEER_STATE_DISCONNECT_LATER) && peer->data) {
971 const time_t timeout = (time_t)(intptr_t)peer->data;
972 if (timeout < now) {
973 u_log(LOG_NOTE, "client %s:%d timed out", u_iptostr(peer->address.host), peer->address.port);
974 peer->data = NULL;
975 enet_peer_reset(peer);