From 364d089250d1acf459e9e8580161e7bb06268106 Mon Sep 17 00:00:00 2001
From: Wang Mingyu <wangmy@fujitsu.com>
Date: Tue, 15 Oct 2024 02:47:38 +0000
Subject: [PATCH] Fix off-by-one overflow in the IP protocol table.

Fixes #2896, closes #2897, closes #2900

Upstream-Status: Backport [https://github.com/nmap/nmap/commit/efa0dc36f2ecade6ba8d2ed25dd4d5fbffdea308]

Signed-off-by: Wang Mingyu <wangmy@fujitsu.com>
---
 CHANGELOG     |  3 +++
 portlist.cc   |  8 ++++----
 protocols.cc  |  6 +++---
 protocols.h   |  2 ++
 scan_lists.cc | 10 +++++-----
 5 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG b/CHANGELOG
index f01262c..5b204bd 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,5 +1,8 @@
 #Nmap Changelog ($Id: CHANGELOG 38849 2024-04-18 17:16:42Z dmiller $); -*-text-*-
 
+o [GH#2900, GH#2896, GH#2897] Nmap is now able to scan IP protocol 255.
+  [nnposter]
+
 Nmap 7.95 [2024-04-19]
 
 o [Windows] Upgraded Npcap (our Windows raw packet capturing and
diff --git a/portlist.cc b/portlist.cc
index 8258853..cd08437 100644
--- a/portlist.cc
+++ b/portlist.cc
@@ -480,7 +480,7 @@ void PortList::setPortState(u16 portno, u8 protocol, int state, int *oldstate) {
       state != PORT_CLOSEDFILTERED)
     fatal("%s: attempt to add port number %d with illegal state %d\n", __func__, portno, state);
 
-  assert(protocol!=IPPROTO_IP || portno<256);
+  assert(protocol!=IPPROTO_IP || portno<=MAX_IPPROTONUM);
 
   bool created = false;
   current = createPort(portno, protocol, &created);
@@ -566,7 +566,7 @@ Port *PortList::nextPort(const Port *cur, Port *next,
   if (cur) {
     proto = INPROTO2PORTLISTPROTO(cur->proto);
     assert(port_map[proto]!=NULL); // Hmm, it's not possible to handle port that doesn't have anything in map
-    assert(cur->proto!=IPPROTO_IP || cur->portno<256);
+    assert(cur->proto!=IPPROTO_IP || cur->portno<=MAX_IPPROTONUM);
     mapped_pno = port_map[proto][cur->portno];
     mapped_pno++; //  we're interested in next port after current
   } else { // running for the first time
@@ -615,7 +615,7 @@ void PortList::mapPort(u16 *portno, u8 *protocol) const {
   mapped_protocol = INPROTO2PORTLISTPROTO(*protocol);
 
   if (*protocol == IPPROTO_IP)
-    assert(*portno < 256);
+    assert(*portno <= MAX_IPPROTONUM);
   if(port_map[mapped_protocol]==NULL || port_list[mapped_protocol]==NULL) {
     fatal("%s(%i,%i): you're trying to access uninitialized protocol", __func__, *portno, *protocol);
   }
@@ -713,7 +713,7 @@ int PortList::port_list_count[PORTLIST_PROTO_MAX];
  * should be sorted. */
 void PortList::initializePortMap(int protocol, u16 *ports, int portcount) {
   int i;
-  int ports_max = (protocol == IPPROTO_IP) ? 256 : 65536;
+  int ports_max = (protocol == IPPROTO_IP) ? MAX_IPPROTONUM + 1 : 65536;
   int proto = INPROTO2PORTLISTPROTO(protocol);
 
   if (port_map[proto] != NULL || port_map_rev[proto] != NULL)
diff --git a/protocols.cc b/protocols.cc
index 76e42c7..85e55e4 100644
--- a/protocols.cc
+++ b/protocols.cc
@@ -79,7 +79,7 @@ struct strcmp_comparator {
 
 // IP Protocol number is 8 bits wide
 // protocol_table[IPPROTO_TCP] == {"tcp", 6}
-static struct nprotoent *protocol_table[UCHAR_MAX];
+static struct nprotoent *protocol_table[MAX_IPPROTONUM + 1];
 // proto_map["tcp"] = {"tcp", 6}
 typedef std::map<const char *, struct nprotoent, strcmp_comparator> ProtoMap;
 static ProtoMap proto_map;
@@ -119,7 +119,7 @@ static int nmap_protocols_init() {
     if (*p == '#' || *p == '\0')
       continue;
     res = sscanf(line, "%127s %hu", protocolname, &protno);
-    if (res !=2 || protno > UCHAR_MAX) {
+    if (res !=2 || protno > MAX_IPPROTONUM) {
       error("Parse error in protocols file %s line %d", filename, lineno);
       continue;
     }
@@ -191,7 +191,7 @@ const struct nprotoent *nmap_getprotbynum(int num) {
   if (nmap_protocols_init() == -1)
     return NULL;
 
-  assert(num >= 0 && num < UCHAR_MAX);
+  assert(num >= 0 && num <= MAX_IPPROTONUM);
   return protocol_table[num];
 }
 
diff --git a/protocols.h b/protocols.h
index 8934284..2de0aa4 100644
--- a/protocols.h
+++ b/protocols.h
@@ -79,6 +79,8 @@ int addprotocolsfromservmask(char *mask, u8 *porttbl);
 const struct nprotoent *nmap_getprotbynum(int num);
 const struct nprotoent *nmap_getprotbyname(const char *name);
 
+#define MAX_IPPROTONUM 255
+
 #define MAX_IPPROTOSTRLEN 4
 #define IPPROTO2STR(p)		\
   ((p)==IPPROTO_TCP ? "tcp" :	\
diff --git a/scan_lists.cc b/scan_lists.cc
index f02e279..ebe1357 100644
--- a/scan_lists.cc
+++ b/scan_lists.cc
@@ -165,7 +165,7 @@ void getpts(const char *origexpr, struct scan_lists *ports) {
       ports->udp_count++;
     if (porttbl[i] & SCAN_SCTP_PORT)
       ports->sctp_count++;
-    if (porttbl[i] & SCAN_PROTOCOLS && i < 256)
+    if (porttbl[i] & SCAN_PROTOCOLS && i <= MAX_IPPROTONUM)
       ports->prot_count++;
   }
 
@@ -192,7 +192,7 @@ void getpts(const char *origexpr, struct scan_lists *ports) {
       ports->udp_ports[udpi++] = i;
     if (porttbl[i] & SCAN_SCTP_PORT)
       ports->sctp_ports[sctpi++] = i;
-    if (porttbl[i] & SCAN_PROTOCOLS && i < 256)
+    if (porttbl[i] & SCAN_PROTOCOLS && i <= MAX_IPPROTONUM)
       ports->prots[proti++] = i;
   }
 
@@ -388,7 +388,7 @@ static void getpts_aux(const char *origexpr, int nested, u8 *porttbl, int range_
     } else if (isdigit((int) (unsigned char) *current_range)) {
       rangestart = strtol(current_range, &endptr, 10);
       if (range_type & SCAN_PROTOCOLS) {
-        if (rangestart < 0 || rangestart > 255)
+        if (rangestart < 0 || rangestart > MAX_IPPROTONUM)
           fatal("Protocols specified must be between 0 and 255 inclusive");
       } else {
         if (rangestart < 0 || rangestart > 65535)
@@ -429,13 +429,13 @@ static void getpts_aux(const char *origexpr, int nested, u8 *porttbl, int range_
       if (!*current_range || *current_range == ',' || *current_range == ']') {
         /* Ended with a -, meaning up until the last possible port */
         if (range_type & SCAN_PROTOCOLS)
-          rangeend = 255;
+          rangeend = MAX_IPPROTONUM;
         else
           rangeend = 65535;
       } else if (isdigit((int) (unsigned char) *current_range)) {
         rangeend = strtol(current_range, &endptr, 10);
         if (range_type & SCAN_PROTOCOLS) {
-          if (rangeend < 0 || rangeend > 255)
+          if (rangeend < 0 || rangeend > MAX_IPPROTONUM)
             fatal("Protocols specified must be between 0 and 255 inclusive");
         } else {
           if (rangeend < 0 || rangeend > 65535)
-- 
2.34.1

