diff mbox series

[ovs-dev] conntrack: handle SNAT with NULL IP address

Message ID 161719474291.264572.15825110114977895266.stgit@fed.void
State Superseded, archived
Headers show
Series [ovs-dev] conntrack: handle SNAT with NULL IP address | expand

Commit Message

Paolo Valerio March 31, 2021, 12:47 p.m. UTC
this patch introduces for the userspace datapath the handling
of rules like the following:

ct(commit,nat(src=0.0.0.0),...)

Kernel datapath already handle this case that is particularly
handy in scenarios like the following:

Given A: 10.1.1.1, B: 192.168.2.100, C: 10.1.1.2

A opens a connection toward B on port 80 selecting as source port 10000.
B's IP gets dnat'ed to C's IP (10.1.1.1:10000 -> 192.168.2.100:80).

This will result in:

tcp,orig=(src=10.1.1.1,dst=192.168.2.100,sport=10000,dport=80),reply=(src=10.1.1.2,dst=10.1.1.1,sport=80,dport=10000),protoinfo=(state=ESTABLISHED)

A now tries to establish another connection with C using source port
10000, this time using C's IP address (10.1.1.1:10000 -> 10.1.1.2:80).

This second connection, if processed by conntrack with no SNAT/DNAT
involved, collides with the reverse tuple of the first connection,
so the entry for this valid connection doesn't get created.

With this commit, and adding a NULL SNAT rule for
10.1.1.1:10000 -> 10.1.1.2:80 will allow to create the conn entry:

tcp,orig=(src=10.1.1.1,dst=10.1.1.2,sport=10000,dport=80),reply=(src=10.1.1.2,dst=10.1.1.1,sport=80,dport=10001),protoinfo=(state=ESTABLISHED)
tcp,orig=(src=10.1.1.1,dst=192.168.2.100,sport=10000,dport=80),reply=(src=10.1.1.2,dst=10.1.1.1,sport=80,dport=10000),protoinfo=(state=ESTABLISHED)

The issue exists even in the opposite case (with A trying to connect
to C using B's IP after establishing a direct connection from A to C).

This commit refactors the relevant function in a way that both of the
previously mentioned cases are handled as well.

Suggested-by: Eelco Chaudron <echaudro@redhat.com>
Signed-off-by: Paolo Valerio <pvalerio@redhat.com>
---
Unit test for userspace will be added to [1] once merged.

[1] https://patchwork.ozlabs.org/project/openvswitch/patch/161710710690.181407.5749135681436588686.stgit@ebuild/

 lib/conntrack.c |  340 ++++++++++++++++++++++++++++++++++---------------------
 lib/conntrack.h |   15 ++
 2 files changed, 228 insertions(+), 127 deletions(-)

Comments

Eelco Chaudron March 31, 2021, 12:59 p.m. UTC | #1
On 31 Mar 2021, at 14:47, Paolo Valerio wrote:

> this patch introduces for the userspace datapath the handling
> of rules like the following:
>
> ct(commit,nat(src=0.0.0.0),...)
>
> Kernel datapath already handle this case that is particularly
> handy in scenarios like the following:
>
> Given A: 10.1.1.1, B: 192.168.2.100, C: 10.1.1.2
>
> A opens a connection toward B on port 80 selecting as source port 
> 10000.
> B's IP gets dnat'ed to C's IP (10.1.1.1:10000 -> 192.168.2.100:80).
>
> This will result in:
>
> tcp,orig=(src=10.1.1.1,dst=192.168.2.100,sport=10000,dport=80),reply=(src=10.1.1.2,dst=10.1.1.1,sport=80,dport=10000),protoinfo=(state=ESTABLISHED)
>
> A now tries to establish another connection with C using source port
> 10000, this time using C's IP address (10.1.1.1:10000 -> 10.1.1.2:80).
>
> This second connection, if processed by conntrack with no SNAT/DNAT
> involved, collides with the reverse tuple of the first connection,
> so the entry for this valid connection doesn't get created.
>
> With this commit, and adding a NULL SNAT rule for
> 10.1.1.1:10000 -> 10.1.1.2:80 will allow to create the conn entry:
>
> tcp,orig=(src=10.1.1.1,dst=10.1.1.2,sport=10000,dport=80),reply=(src=10.1.1.2,dst=10.1.1.1,sport=80,dport=10001),protoinfo=(state=ESTABLISHED)
> tcp,orig=(src=10.1.1.1,dst=192.168.2.100,sport=10000,dport=80),reply=(src=10.1.1.2,dst=10.1.1.1,sport=80,dport=10000),protoinfo=(state=ESTABLISHED)
>
> The issue exists even in the opposite case (with A trying to connect
> to C using B's IP after establishing a direct connection from A to C).
>
> This commit refactors the relevant function in a way that both of the
> previously mentioned cases are handled as well.
>
> Suggested-by: Eelco Chaudron <echaudro@redhat.com>
> Signed-off-by: Paolo Valerio <pvalerio@redhat.com>
> ---
> Unit test for userspace will be added to [1] once merged.
>
> [1] 
> https://patchwork.ozlabs.org/project/openvswitch/patch/161710710690.181407.5749135681436588686.stgit@ebuild/

Paolo you need to update tests/system-userspace-macros.at to execute the 
test, i.e. remove the AT_SKIP_IF([:]) part.

>  lib/conntrack.c |  340 
> ++++++++++++++++++++++++++++++++++---------------------
>  lib/conntrack.h |   15 ++
>  2 files changed, 228 insertions(+), 127 deletions(-)
>
> diff --git a/lib/conntrack.c b/lib/conntrack.c
> index 99198a601..da69f63ef 100644
> --- a/lib/conntrack.c
> +++ b/lib/conntrack.c
> @@ -108,9 +108,8 @@ static void set_label(struct dp_packet *, struct 
> conn *,
>  static void *clean_thread_main(void *f_);
>
>  static bool
> -nat_select_range_tuple(struct conntrack *ct, const struct conn *conn,
> -                       struct conn *nat_conn);
> -
> +nat_get_unique_tuple(struct conntrack *ct, const struct conn *conn,
> +                     struct conn *nat_conn);
>  static uint8_t
>  reverse_icmp_type(uint8_t type);
>  static uint8_t
> @@ -728,11 +727,11 @@ pat_packet(struct dp_packet *pkt, const struct 
> conn *conn)
>          }
>      } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
>          if (conn->key.nw_proto == IPPROTO_TCP) {
> -            struct tcp_header *th = dp_packet_l4(pkt);
> -            packet_set_tcp_port(pkt, th->tcp_src, 
> conn->rev_key.src.port);
> +            packet_set_tcp_port(pkt, conn->rev_key.dst.port,
> +                                conn->rev_key.src.port);
>          } else if (conn->key.nw_proto == IPPROTO_UDP) {
> -            struct udp_header *uh = dp_packet_l4(pkt);
> -            packet_set_udp_port(pkt, uh->udp_src, 
> conn->rev_key.src.port);
> +            packet_set_udp_port(pkt, conn->rev_key.dst.port,
> +                                conn->rev_key.src.port);
>          }
>      }
>  }
> @@ -786,11 +785,9 @@ un_pat_packet(struct dp_packet *pkt, const struct 
> conn *conn)
>          }
>      } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
>          if (conn->key.nw_proto == IPPROTO_TCP) {
> -            struct tcp_header *th = dp_packet_l4(pkt);
> -            packet_set_tcp_port(pkt, conn->key.dst.port, 
> th->tcp_dst);
> +            packet_set_tcp_port(pkt, conn->key.dst.port, 
> conn->key.src.port);
>          } else if (conn->key.nw_proto == IPPROTO_UDP) {
> -            struct udp_header *uh = dp_packet_l4(pkt);
> -            packet_set_udp_port(pkt, conn->key.dst.port, 
> uh->udp_dst);
> +            packet_set_udp_port(pkt, conn->key.dst.port, 
> conn->key.src.port);
>          }
>      }
>  }
> @@ -810,12 +807,10 @@ reverse_pat_packet(struct dp_packet *pkt, const 
> struct conn *conn)
>          }
>      } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
>          if (conn->key.nw_proto == IPPROTO_TCP) {
> -            struct tcp_header *th_in = dp_packet_l4(pkt);
> -            packet_set_tcp_port(pkt, th_in->tcp_src,
> +            packet_set_tcp_port(pkt, conn->key.src.port,
>                                  conn->key.dst.port);
>          } else if (conn->key.nw_proto == IPPROTO_UDP) {
> -            struct udp_header *uh_in = dp_packet_l4(pkt);
> -            packet_set_udp_port(pkt, uh_in->udp_src,
> +            packet_set_udp_port(pkt, conn->key.src.port,
>                                  conn->key.dst.port);
>          }
>      }
> @@ -1029,14 +1024,14 @@ conn_not_found(struct conntrack *ct, struct 
> dp_packet *pkt,
>                  }
>              } else {
>                  memcpy(nat_conn, nc, sizeof *nat_conn);
> -                bool nat_res = nat_select_range_tuple(ct, nc, 
> nat_conn);
> +                bool nat_res = nat_get_unique_tuple(ct, nc, 
> nat_conn);
>
>                  if (!nat_res) {
>                      goto nat_res_exhaustion;
>                  }
>
>                  /* Update nc with nat adjustments made to nat_conn by
> -                 * nat_select_range_tuple(). */
> +                 * nat_get_unique_tuple(). */
>                  memcpy(nc, nat_conn, sizeof *nc);
>              }
>
> @@ -1391,7 +1386,6 @@ process_one(struct conntrack *ct, struct 
> dp_packet *pkt,
>
>      set_cached_conn(nat_action_info, ctx, conn, pkt);
>  }
> -
>  /* Sends the packets in '*pkt_batch' through the connection tracker 
> 'ct'.  All
>   * the packets must have the same 'dl_type' (IPv4 or IPv6) and should 
> have
>   * the l3 and and l4 offset properly set.  Performs fragment 
> reassembly with
> @@ -1436,7 +1430,6 @@ conntrack_execute(struct conntrack *ct, struct 
> dp_packet_batch *pkt_batch,
>      }
>
>      ipf_postprocess_conntrack(ct->ipf, pkt_batch, now, dl_type);
> -
>      return 0;
>  }
>
> @@ -2210,130 +2203,223 @@ nat_range_hash(const struct conn *conn, 
> uint32_t basis)
>      return hash_finish(hash, 0);
>  }
>
> -static bool
> -nat_select_range_tuple(struct conntrack *ct, const struct conn *conn,
> -                       struct conn *nat_conn)
> -{
> -    enum { MIN_NAT_EPHEMERAL_PORT = 1024,
> -           MAX_NAT_EPHEMERAL_PORT = 65535 };
> -
> -    uint16_t min_port;
> -    uint16_t max_port;
> -    uint16_t first_port;
> -    uint32_t hash = nat_range_hash(conn, ct->hash_basis);
> +/* Ports are stored in host byte order for convenience. */
> +static void
> +set_sport_range(struct nat_action_info_t *ni, const struct conn_key 
> *k,
> +                uint32_t hash, uint16_t *curr, uint16_t *min,
> +                uint16_t *max)
> +{
> +    if (((ni->nat_action & NAT_ACTION_SNAT_ALL) == NAT_ACTION_SRC) ||
> +        ((ni->nat_action & NAT_ACTION_DST))) {
> +        *curr = ntohs(k->src.port);
> +        *min = MIN_NAT_EPHEMERAL_PORT;
> +        *max = MAX_NAT_EPHEMERAL_PORT;
> +    } else {
> +        *min = ni->min_port;
> +        *max = ni->max_port;
> +        *curr = *min + (hash % ((*max - *min) + 1));
> +    }
> +}
>
> -    if ((conn->nat_info->nat_action & NAT_ACTION_SRC) &&
> -        (!(conn->nat_info->nat_action & NAT_ACTION_SRC_PORT))) {
> -        min_port = ntohs(conn->key.src.port);
> -        max_port = ntohs(conn->key.src.port);
> -        first_port = min_port;
> -    } else if ((conn->nat_info->nat_action & NAT_ACTION_DST) &&
> -               (!(conn->nat_info->nat_action & NAT_ACTION_DST_PORT))) 
> {
> -        min_port = ntohs(conn->key.dst.port);
> -        max_port = ntohs(conn->key.dst.port);
> -        first_port = min_port;
> +static void
> +set_dport_range(struct nat_action_info_t *ni, const struct conn_key 
> *k,
> +                uint32_t hash, uint16_t *curr, uint16_t *min,
> +                uint16_t *max)
> +{
> +    if (ni->nat_action & NAT_ACTION_DST_PORT) {
> +        *min = ni->min_port;
> +        *max = ni->max_port;
> +        *curr = *min + (hash % ((*max - *min) + 1));
>      } else {
> -        uint16_t deltap = conn->nat_info->max_port - 
> conn->nat_info->min_port;
> -        uint32_t port_index = hash % (deltap + 1);
> -        first_port = conn->nat_info->min_port + port_index;
> -        min_port = conn->nat_info->min_port;
> -        max_port = conn->nat_info->max_port;
> +        *curr = ntohs(k->dst.port);
> +        *min = *max = *curr;
>      }
> +}
>
> -    uint32_t deltaa = 0;
> -    uint32_t address_index;
> -    union ct_addr ct_addr;
> -    memset(&ct_addr, 0, sizeof ct_addr);
> -    union ct_addr max_ct_addr;
> -    memset(&max_ct_addr, 0, sizeof max_ct_addr);
> -    max_ct_addr = conn->nat_info->max_addr;
> +/* Gets the initial in range address based on the hash.
> + * Addresses are kept in network order. */
> +static void
> +get_addr_in_range(union ct_addr *min, union ct_addr *max,
> +                  union ct_addr *curr, uint32_t hash,
> +                  bool ipv4)
> +{
> +    uint32_t offt, range;
>
> -    if (conn->key.dl_type == htons(ETH_TYPE_IP)) {
> -        deltaa = ntohl(conn->nat_info->max_addr.ipv4) -
> -                 ntohl(conn->nat_info->min_addr.ipv4);
> -        address_index = hash % (deltaa + 1);
> -        ct_addr.ipv4 = htonl(
> -            ntohl(conn->nat_info->min_addr.ipv4) + address_index);
> +    if (ipv4) {
> +        range = (ntohl(max->ipv4) - ntohl(min->ipv4)) + 1;
> +        offt = hash % range;
> +        curr->ipv4 = htonl(ntohl(min->ipv4) + offt);
>      } else {
> -        deltaa = nat_ipv6_addrs_delta(&conn->nat_info->min_addr.ipv6,
> -                                      
> &conn->nat_info->max_addr.ipv6);
> -        /* deltaa must be within 32 bits for full hash coverage. A 64 
> or
> +        range = nat_ipv6_addrs_delta(&min->ipv6,
> +                                     &max->ipv6) + 1;
> +        /* range must be within 32 bits for full hash coverage. A 64 
> or
>           * 128 bit hash is unnecessary and hence not used here. Most 
> code
>           * is kept common with V4; nat_ipv6_addrs_delta() will do the
>           * enforcement via max_ct_addr. */
> -        max_ct_addr = conn->nat_info->min_addr;
> -        nat_ipv6_addr_increment(&max_ct_addr.ipv6, deltaa);
> -        address_index = hash % (deltaa + 1);
> -        ct_addr.ipv6 = conn->nat_info->min_addr.ipv6;
> -        nat_ipv6_addr_increment(&ct_addr.ipv6, address_index);
> -    }
> -
> -    uint16_t port = first_port;
> -    bool all_ports_tried = false;
> -    /* For DNAT or for specified port ranges, we don't use ephemeral 
> ports. */
> -    bool ephemeral_ports_tried
> -        = conn->nat_info->nat_action & NAT_ACTION_DST ||
> -              conn->nat_info->nat_action & NAT_ACTION_SRC_PORT
> -          ? true : false;
> -    union ct_addr first_addr = ct_addr;
> -    bool pat_enabled = conn->key.nw_proto == IPPROTO_TCP ||
> -                       conn->key.nw_proto == IPPROTO_UDP;
> -
> -    while (true) {
> +        offt = hash % range;
> +        curr->ipv6 = min->ipv6;
> +        nat_ipv6_addr_increment(&curr->ipv6, offt);
> +    }
> +}
> +
> +static void
> +get_initial_addr(const struct conn *conn, union ct_addr *min,
> +                 union ct_addr *max, union ct_addr *curr,
> +                 uint32_t hash, bool ipv4)
> +{
> +    const union ct_addr zero_ip = {0};
> +
> +    /* NULL CASE */
> +    if (!memcmp(min, &zero_ip, sizeof(*min))) {
>          if (conn->nat_info->nat_action & NAT_ACTION_SRC) {
> -            nat_conn->rev_key.dst.addr = ct_addr;
> -            if (pat_enabled) {
> -                nat_conn->rev_key.dst.port = htons(port);
> -            }
> -        } else {
> -            nat_conn->rev_key.src.addr = ct_addr;
> -            if (pat_enabled) {
> -                nat_conn->rev_key.src.port = htons(port);
> -            }
> +            *curr = conn->key.src.addr;
> +        } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
> +            *curr = conn->key.dst.addr;
>          }
> +    } else {
> +        get_addr_in_range(min, max, curr, hash, ipv4);
> +    }
> +}
>
> -        bool found = conn_lookup(ct, &nat_conn->rev_key, time_msec(), 
> NULL,
> -                                 NULL);
> -        if (!found) {
> +/* if action is src, store to dst, otherwise store src
> + * if src is NULL, do not store anything. */
> +static void
> +store_addr_to_key(union ct_addr *addr, struct conn_key *key,
> +                  uint16_t action)
> +{
> +    if (action & NAT_ACTION_SRC) {
> +        key->dst.addr = *addr;
> +    } else {
> +        key->src.addr = *addr;
> +    }
> +}
> +
> +static void
> +next_addr_in_range(union ct_addr *curr, union ct_addr *min,
> +                   union ct_addr *max, bool ipv4)
> +{
> +    if (ipv4) {
> +        /* this check could be unified with IPv6, but let's avoid
> +         * an unneeded memcmp() in case of IPv4. */
> +        if (min->ipv4 == max->ipv4) {
> +            return;
> +        }
> +
> +        curr->ipv4 = (curr->ipv4 == max->ipv4) ?
> +                      min->ipv4 :
> +                      htonl(ntohl(curr->ipv4) + 1);
> +    } else {
> +        if (!memcmp(min, max, sizeof(*min))) {
> +            return;
> +        }
> +
> +        if (!memcmp(curr, max, sizeof(*curr))) {
> +            *curr = *min;
> +            return;
> +        }
> +
> +        nat_ipv6_addr_increment(&curr->ipv6, 1);
> +    }
> +}
> +
> +static bool
> +next_addr_in_range_guarded(union ct_addr *curr, union ct_addr *min,
> +                           union ct_addr *max, union ct_addr *guard,
> +                           bool ipv4)
> +{
> +    bool exhausted;
> +
> +    next_addr_in_range(curr, min, max, ipv4);
> +
> +    if (ipv4) {
> +        exhausted = (curr->ipv4 == guard->ipv4);
> +    } else {
> +        exhausted = !memcmp(curr, guard, sizeof(*curr));
> +    }
> +
> +    return exhausted;
> +}
> +
> +/* This function tries to get a unique tuple.
> + * Every iteration checks that the reverse tuple doesn't
> + * collide with any existing one.
> + *
> + * in case of SNAT:
> + *    - for each src IP address in the range (if any)
> + *        - try to find a source port in range (if any)
> + *        - if no port range exists, use the whole
> + *          ephemeral range (starting from the port
> + *          used by the client)
> + *
> + * in case of DNAT:
> + *    - for each dst IP address in the range (if any)
> + *        - for each dport in range (if any)
> + *             - try to find a source port in the ephemeral range
> + *               (starting from the port used by the client)
> + *
> + * If none can be found, return exhaustion to the caller. */
> +static bool
> +nat_get_unique_tuple(struct conntrack *ct, const struct conn *conn,
> +                     struct conn *nat_conn)
> +{
> +    union ct_addr min_addr = {0}, max_addr = {0}, curr_addr = {0},
> +                  guard_addr = {0};
> +    uint32_t hash = nat_range_hash(conn, ct->hash_basis);
> +    bool pat_proto = conn->key.nw_proto == IPPROTO_TCP ||
> +                     conn->key.nw_proto == IPPROTO_UDP;
> +    uint16_t min_dport, max_dport, curr_dport;
> +    uint16_t min_sport, max_sport, curr_sport;
> +
> +    min_addr = conn->nat_info->min_addr;
> +    max_addr = conn->nat_info->max_addr;
> +
> +    get_initial_addr(conn, &min_addr, &max_addr, &curr_addr, hash,
> +                     (conn->key.dl_type == htons(ETH_TYPE_IP)));
> +
> +    /* save the address we started from so that
> +     * we can stop once we reach it. */
> +    guard_addr = curr_addr;
> +
> +    set_sport_range(conn->nat_info, &conn->key, hash, &curr_sport,
> +                    &min_sport, &max_sport);
> +    set_dport_range(conn->nat_info, &conn->key, hash, &curr_dport,
> +                    &min_dport, &max_dport);
> +
> +another_round:
> +    store_addr_to_key(&curr_addr, &nat_conn->rev_key,
> +                      conn->nat_info->nat_action);
> +
> +    if (!pat_proto) {
> +        if (!conn_lookup(ct, &nat_conn->rev_key,
> +                         time_msec(), NULL, NULL)) {
>              return true;
> -        } else if (pat_enabled && !all_ports_tried) {
> -            if (min_port == max_port) {
> -                all_ports_tried = true;
> -            } else if (port == max_port) {
> -                port = min_port;
> -            } else {
> -                port++;
> -            }
> -            if (port == first_port) {
> -                all_ports_tried = true;
> -            }
> -        } else {
> -            if (memcmp(&ct_addr, &max_ct_addr, sizeof ct_addr)) {
> -                if (conn->key.dl_type == htons(ETH_TYPE_IP)) {
> -                    ct_addr.ipv4 = htonl(ntohl(ct_addr.ipv4) + 1);
> -                } else {
> -                    nat_ipv6_addr_increment(&ct_addr.ipv6, 1);
> -                }
> -            } else {
> -                ct_addr = conn->nat_info->min_addr;
> -            }
> -            if (!memcmp(&ct_addr, &first_addr, sizeof ct_addr)) {
> -                if (pat_enabled && !ephemeral_ports_tried) {
> -                    ephemeral_ports_tried = true;
> -                    ct_addr = conn->nat_info->min_addr;
> -                    first_addr = ct_addr;
> -                    min_port = MIN_NAT_EPHEMERAL_PORT;
> -                    max_port = MAX_NAT_EPHEMERAL_PORT;
> -                } else {
> -                    break;
> -                }
> +        }
> +
> +        goto next_addr;
> +    }
> +
> +    int i, j;
> +    FOR_EACH_PORT_IN_RANGE(i, curr_dport, min_dport, max_dport) {
> +        nat_conn->rev_key.src.port = htons(curr_dport);
> +        FOR_EACH_PORT_IN_RANGE(j, curr_sport, min_sport, max_sport) {
> +            nat_conn->rev_key.dst.port = htons(curr_sport);
> +            if (!conn_lookup(ct, &nat_conn->rev_key,
> +                             time_msec(), NULL, NULL)) {
> +                return true;
>              }
> -            first_port = min_port;
> -            port = first_port;
> -            all_ports_tried = false;
>          }
>      }
> -    return false;
> +
> +    /* Check if next IP is in range and respin. Otherwise, notify
> +     * exhaustion to the caller. */
> +next_addr:
> +    if (next_addr_in_range_guarded(&curr_addr, &min_addr,
> +                                   &max_addr, &guard_addr,
> +                                   conn->key.dl_type == 
> htons(ETH_TYPE_IP))) {
> +        return false;
> +    }
> +
> +    goto another_round;
>  }
>
>  static enum ct_update_res
> diff --git a/lib/conntrack.h b/lib/conntrack.h
> index 9553b188a..6ce1cd216 100644
> --- a/lib/conntrack.h
> +++ b/lib/conntrack.h
> @@ -77,6 +77,14 @@ enum nat_action_e {
>      NAT_ACTION_DST_PORT = 1 << 3,
>  };
>
> +#define NAT_ACTION_SNAT_ALL (NAT_ACTION_SRC | NAT_ACTION_SRC_PORT)
> +#define NAT_ACTION_DNAT_ALL (NAT_ACTION_DST | NAT_ACTION_DST_PORT)
> +
> +enum {
> +    MIN_NAT_EPHEMERAL_PORT = 1024,
> +    MAX_NAT_EPHEMERAL_PORT = 65535
> +};
> +
>  struct nat_action_info_t {
>      union ct_addr min_addr;
>      union ct_addr max_addr;
> @@ -85,6 +93,13 @@ struct nat_action_info_t {
>      uint16_t nat_action;
>  };
>
> +#define NEXT_PORT_IN_RANGE(curr, min, max) \
> +    curr = (curr == max) ? min : curr + 1
> +
> +#define FOR_EACH_PORT_IN_RANGE(idx, curr, min, max) \
> +    for (idx = 0; idx < (max - min) + 1; idx++, \
> +             NEXT_PORT_IN_RANGE(curr, min, max))
> +
>  struct conntrack *conntrack_init(void);
>  void conntrack_destroy(struct conntrack *);
Paolo Valerio March 31, 2021, 5:02 p.m. UTC | #2
"Eelco Chaudron" <echaudro@redhat.com> writes:

> On 31 Mar 2021, at 14:47, Paolo Valerio wrote:
>
>> this patch introduces for the userspace datapath the handling
>> of rules like the following:
>>
>> ct(commit,nat(src=0.0.0.0),...)
>>
>> Kernel datapath already handle this case that is particularly
>> handy in scenarios like the following:
>>
>> Given A: 10.1.1.1, B: 192.168.2.100, C: 10.1.1.2
>>
>> A opens a connection toward B on port 80 selecting as source port 
>> 10000.
>> B's IP gets dnat'ed to C's IP (10.1.1.1:10000 -> 192.168.2.100:80).
>>
>> This will result in:
>>
>> tcp,orig=(src=10.1.1.1,dst=192.168.2.100,sport=10000,dport=80),reply=(src=10.1.1.2,dst=10.1.1.1,sport=80,dport=10000),protoinfo=(state=ESTABLISHED)
>>
>> A now tries to establish another connection with C using source port
>> 10000, this time using C's IP address (10.1.1.1:10000 -> 10.1.1.2:80).
>>
>> This second connection, if processed by conntrack with no SNAT/DNAT
>> involved, collides with the reverse tuple of the first connection,
>> so the entry for this valid connection doesn't get created.
>>
>> With this commit, and adding a NULL SNAT rule for
>> 10.1.1.1:10000 -> 10.1.1.2:80 will allow to create the conn entry:
>>
>> tcp,orig=(src=10.1.1.1,dst=10.1.1.2,sport=10000,dport=80),reply=(src=10.1.1.2,dst=10.1.1.1,sport=80,dport=10001),protoinfo=(state=ESTABLISHED)
>> tcp,orig=(src=10.1.1.1,dst=192.168.2.100,sport=10000,dport=80),reply=(src=10.1.1.2,dst=10.1.1.1,sport=80,dport=10000),protoinfo=(state=ESTABLISHED)
>>
>> The issue exists even in the opposite case (with A trying to connect
>> to C using B's IP after establishing a direct connection from A to C).
>>
>> This commit refactors the relevant function in a way that both of the
>> previously mentioned cases are handled as well.
>>
>> Suggested-by: Eelco Chaudron <echaudro@redhat.com>
>> Signed-off-by: Paolo Valerio <pvalerio@redhat.com>
>> ---
>> Unit test for userspace will be added to [1] once merged.
>>
>> [1] 
>> https://patchwork.ozlabs.org/project/openvswitch/patch/161710710690.181407.5749135681436588686.stgit@ebuild/
>
> Paolo you need to update tests/system-userspace-macros.at to execute the 
> test, i.e. remove the AT_SKIP_IF([:]) part.
>

Ok. It's not merged yet, so I was unsure about adding a dependency.
I'll respin a v2 with a note for the maintainers. Thanks.

>>  lib/conntrack.c |  340 
>> ++++++++++++++++++++++++++++++++++---------------------
>>  lib/conntrack.h |   15 ++
>>  2 files changed, 228 insertions(+), 127 deletions(-)
>>
>> diff --git a/lib/conntrack.c b/lib/conntrack.c
>> index 99198a601..da69f63ef 100644
>> --- a/lib/conntrack.c
>> +++ b/lib/conntrack.c
>> @@ -108,9 +108,8 @@ static void set_label(struct dp_packet *, struct 
>> conn *,
>>  static void *clean_thread_main(void *f_);
>>
>>  static bool
>> -nat_select_range_tuple(struct conntrack *ct, const struct conn *conn,
>> -                       struct conn *nat_conn);
>> -
>> +nat_get_unique_tuple(struct conntrack *ct, const struct conn *conn,
>> +                     struct conn *nat_conn);
>>  static uint8_t
>>  reverse_icmp_type(uint8_t type);
>>  static uint8_t
>> @@ -728,11 +727,11 @@ pat_packet(struct dp_packet *pkt, const struct 
>> conn *conn)
>>          }
>>      } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
>>          if (conn->key.nw_proto == IPPROTO_TCP) {
>> -            struct tcp_header *th = dp_packet_l4(pkt);
>> -            packet_set_tcp_port(pkt, th->tcp_src, 
>> conn->rev_key.src.port);
>> +            packet_set_tcp_port(pkt, conn->rev_key.dst.port,
>> +                                conn->rev_key.src.port);
>>          } else if (conn->key.nw_proto == IPPROTO_UDP) {
>> -            struct udp_header *uh = dp_packet_l4(pkt);
>> -            packet_set_udp_port(pkt, uh->udp_src, 
>> conn->rev_key.src.port);
>> +            packet_set_udp_port(pkt, conn->rev_key.dst.port,
>> +                                conn->rev_key.src.port);
>>          }
>>      }
>>  }
>> @@ -786,11 +785,9 @@ un_pat_packet(struct dp_packet *pkt, const struct 
>> conn *conn)
>>          }
>>      } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
>>          if (conn->key.nw_proto == IPPROTO_TCP) {
>> -            struct tcp_header *th = dp_packet_l4(pkt);
>> -            packet_set_tcp_port(pkt, conn->key.dst.port, 
>> th->tcp_dst);
>> +            packet_set_tcp_port(pkt, conn->key.dst.port, 
>> conn->key.src.port);
>>          } else if (conn->key.nw_proto == IPPROTO_UDP) {
>> -            struct udp_header *uh = dp_packet_l4(pkt);
>> -            packet_set_udp_port(pkt, conn->key.dst.port, 
>> uh->udp_dst);
>> +            packet_set_udp_port(pkt, conn->key.dst.port, 
>> conn->key.src.port);
>>          }
>>      }
>>  }
>> @@ -810,12 +807,10 @@ reverse_pat_packet(struct dp_packet *pkt, const 
>> struct conn *conn)
>>          }
>>      } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
>>          if (conn->key.nw_proto == IPPROTO_TCP) {
>> -            struct tcp_header *th_in = dp_packet_l4(pkt);
>> -            packet_set_tcp_port(pkt, th_in->tcp_src,
>> +            packet_set_tcp_port(pkt, conn->key.src.port,
>>                                  conn->key.dst.port);
>>          } else if (conn->key.nw_proto == IPPROTO_UDP) {
>> -            struct udp_header *uh_in = dp_packet_l4(pkt);
>> -            packet_set_udp_port(pkt, uh_in->udp_src,
>> +            packet_set_udp_port(pkt, conn->key.src.port,
>>                                  conn->key.dst.port);
>>          }
>>      }
>> @@ -1029,14 +1024,14 @@ conn_not_found(struct conntrack *ct, struct 
>> dp_packet *pkt,
>>                  }
>>              } else {
>>                  memcpy(nat_conn, nc, sizeof *nat_conn);
>> -                bool nat_res = nat_select_range_tuple(ct, nc, 
>> nat_conn);
>> +                bool nat_res = nat_get_unique_tuple(ct, nc, 
>> nat_conn);
>>
>>                  if (!nat_res) {
>>                      goto nat_res_exhaustion;
>>                  }
>>
>>                  /* Update nc with nat adjustments made to nat_conn by
>> -                 * nat_select_range_tuple(). */
>> +                 * nat_get_unique_tuple(). */
>>                  memcpy(nc, nat_conn, sizeof *nc);
>>              }
>>
>> @@ -1391,7 +1386,6 @@ process_one(struct conntrack *ct, struct 
>> dp_packet *pkt,
>>
>>      set_cached_conn(nat_action_info, ctx, conn, pkt);
>>  }
>> -
>>  /* Sends the packets in '*pkt_batch' through the connection tracker 
>> 'ct'.  All
>>   * the packets must have the same 'dl_type' (IPv4 or IPv6) and should 
>> have
>>   * the l3 and and l4 offset properly set.  Performs fragment 
>> reassembly with
>> @@ -1436,7 +1430,6 @@ conntrack_execute(struct conntrack *ct, struct 
>> dp_packet_batch *pkt_batch,
>>      }
>>
>>      ipf_postprocess_conntrack(ct->ipf, pkt_batch, now, dl_type);
>> -
>>      return 0;
>>  }
>>
>> @@ -2210,130 +2203,223 @@ nat_range_hash(const struct conn *conn, 
>> uint32_t basis)
>>      return hash_finish(hash, 0);
>>  }
>>
>> -static bool
>> -nat_select_range_tuple(struct conntrack *ct, const struct conn *conn,
>> -                       struct conn *nat_conn)
>> -{
>> -    enum { MIN_NAT_EPHEMERAL_PORT = 1024,
>> -           MAX_NAT_EPHEMERAL_PORT = 65535 };
>> -
>> -    uint16_t min_port;
>> -    uint16_t max_port;
>> -    uint16_t first_port;
>> -    uint32_t hash = nat_range_hash(conn, ct->hash_basis);
>> +/* Ports are stored in host byte order for convenience. */
>> +static void
>> +set_sport_range(struct nat_action_info_t *ni, const struct conn_key 
>> *k,
>> +                uint32_t hash, uint16_t *curr, uint16_t *min,
>> +                uint16_t *max)
>> +{
>> +    if (((ni->nat_action & NAT_ACTION_SNAT_ALL) == NAT_ACTION_SRC) ||
>> +        ((ni->nat_action & NAT_ACTION_DST))) {
>> +        *curr = ntohs(k->src.port);
>> +        *min = MIN_NAT_EPHEMERAL_PORT;
>> +        *max = MAX_NAT_EPHEMERAL_PORT;
>> +    } else {
>> +        *min = ni->min_port;
>> +        *max = ni->max_port;
>> +        *curr = *min + (hash % ((*max - *min) + 1));
>> +    }
>> +}
>>
>> -    if ((conn->nat_info->nat_action & NAT_ACTION_SRC) &&
>> -        (!(conn->nat_info->nat_action & NAT_ACTION_SRC_PORT))) {
>> -        min_port = ntohs(conn->key.src.port);
>> -        max_port = ntohs(conn->key.src.port);
>> -        first_port = min_port;
>> -    } else if ((conn->nat_info->nat_action & NAT_ACTION_DST) &&
>> -               (!(conn->nat_info->nat_action & NAT_ACTION_DST_PORT))) 
>> {
>> -        min_port = ntohs(conn->key.dst.port);
>> -        max_port = ntohs(conn->key.dst.port);
>> -        first_port = min_port;
>> +static void
>> +set_dport_range(struct nat_action_info_t *ni, const struct conn_key 
>> *k,
>> +                uint32_t hash, uint16_t *curr, uint16_t *min,
>> +                uint16_t *max)
>> +{
>> +    if (ni->nat_action & NAT_ACTION_DST_PORT) {
>> +        *min = ni->min_port;
>> +        *max = ni->max_port;
>> +        *curr = *min + (hash % ((*max - *min) + 1));
>>      } else {
>> -        uint16_t deltap = conn->nat_info->max_port - 
>> conn->nat_info->min_port;
>> -        uint32_t port_index = hash % (deltap + 1);
>> -        first_port = conn->nat_info->min_port + port_index;
>> -        min_port = conn->nat_info->min_port;
>> -        max_port = conn->nat_info->max_port;
>> +        *curr = ntohs(k->dst.port);
>> +        *min = *max = *curr;
>>      }
>> +}
>>
>> -    uint32_t deltaa = 0;
>> -    uint32_t address_index;
>> -    union ct_addr ct_addr;
>> -    memset(&ct_addr, 0, sizeof ct_addr);
>> -    union ct_addr max_ct_addr;
>> -    memset(&max_ct_addr, 0, sizeof max_ct_addr);
>> -    max_ct_addr = conn->nat_info->max_addr;
>> +/* Gets the initial in range address based on the hash.
>> + * Addresses are kept in network order. */
>> +static void
>> +get_addr_in_range(union ct_addr *min, union ct_addr *max,
>> +                  union ct_addr *curr, uint32_t hash,
>> +                  bool ipv4)
>> +{
>> +    uint32_t offt, range;
>>
>> -    if (conn->key.dl_type == htons(ETH_TYPE_IP)) {
>> -        deltaa = ntohl(conn->nat_info->max_addr.ipv4) -
>> -                 ntohl(conn->nat_info->min_addr.ipv4);
>> -        address_index = hash % (deltaa + 1);
>> -        ct_addr.ipv4 = htonl(
>> -            ntohl(conn->nat_info->min_addr.ipv4) + address_index);
>> +    if (ipv4) {
>> +        range = (ntohl(max->ipv4) - ntohl(min->ipv4)) + 1;
>> +        offt = hash % range;
>> +        curr->ipv4 = htonl(ntohl(min->ipv4) + offt);
>>      } else {
>> -        deltaa = nat_ipv6_addrs_delta(&conn->nat_info->min_addr.ipv6,
>> -                                      
>> &conn->nat_info->max_addr.ipv6);
>> -        /* deltaa must be within 32 bits for full hash coverage. A 64 
>> or
>> +        range = nat_ipv6_addrs_delta(&min->ipv6,
>> +                                     &max->ipv6) + 1;
>> +        /* range must be within 32 bits for full hash coverage. A 64 
>> or
>>           * 128 bit hash is unnecessary and hence not used here. Most 
>> code
>>           * is kept common with V4; nat_ipv6_addrs_delta() will do the
>>           * enforcement via max_ct_addr. */
>> -        max_ct_addr = conn->nat_info->min_addr;
>> -        nat_ipv6_addr_increment(&max_ct_addr.ipv6, deltaa);
>> -        address_index = hash % (deltaa + 1);
>> -        ct_addr.ipv6 = conn->nat_info->min_addr.ipv6;
>> -        nat_ipv6_addr_increment(&ct_addr.ipv6, address_index);
>> -    }
>> -
>> -    uint16_t port = first_port;
>> -    bool all_ports_tried = false;
>> -    /* For DNAT or for specified port ranges, we don't use ephemeral 
>> ports. */
>> -    bool ephemeral_ports_tried
>> -        = conn->nat_info->nat_action & NAT_ACTION_DST ||
>> -              conn->nat_info->nat_action & NAT_ACTION_SRC_PORT
>> -          ? true : false;
>> -    union ct_addr first_addr = ct_addr;
>> -    bool pat_enabled = conn->key.nw_proto == IPPROTO_TCP ||
>> -                       conn->key.nw_proto == IPPROTO_UDP;
>> -
>> -    while (true) {
>> +        offt = hash % range;
>> +        curr->ipv6 = min->ipv6;
>> +        nat_ipv6_addr_increment(&curr->ipv6, offt);
>> +    }
>> +}
>> +
>> +static void
>> +get_initial_addr(const struct conn *conn, union ct_addr *min,
>> +                 union ct_addr *max, union ct_addr *curr,
>> +                 uint32_t hash, bool ipv4)
>> +{
>> +    const union ct_addr zero_ip = {0};
>> +
>> +    /* NULL CASE */
>> +    if (!memcmp(min, &zero_ip, sizeof(*min))) {
>>          if (conn->nat_info->nat_action & NAT_ACTION_SRC) {
>> -            nat_conn->rev_key.dst.addr = ct_addr;
>> -            if (pat_enabled) {
>> -                nat_conn->rev_key.dst.port = htons(port);
>> -            }
>> -        } else {
>> -            nat_conn->rev_key.src.addr = ct_addr;
>> -            if (pat_enabled) {
>> -                nat_conn->rev_key.src.port = htons(port);
>> -            }
>> +            *curr = conn->key.src.addr;
>> +        } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
>> +            *curr = conn->key.dst.addr;
>>          }
>> +    } else {
>> +        get_addr_in_range(min, max, curr, hash, ipv4);
>> +    }
>> +}
>>
>> -        bool found = conn_lookup(ct, &nat_conn->rev_key, time_msec(), 
>> NULL,
>> -                                 NULL);
>> -        if (!found) {
>> +/* if action is src, store to dst, otherwise store src
>> + * if src is NULL, do not store anything. */
>> +static void
>> +store_addr_to_key(union ct_addr *addr, struct conn_key *key,
>> +                  uint16_t action)
>> +{
>> +    if (action & NAT_ACTION_SRC) {
>> +        key->dst.addr = *addr;
>> +    } else {
>> +        key->src.addr = *addr;
>> +    }
>> +}
>> +
>> +static void
>> +next_addr_in_range(union ct_addr *curr, union ct_addr *min,
>> +                   union ct_addr *max, bool ipv4)
>> +{
>> +    if (ipv4) {
>> +        /* this check could be unified with IPv6, but let's avoid
>> +         * an unneeded memcmp() in case of IPv4. */
>> +        if (min->ipv4 == max->ipv4) {
>> +            return;
>> +        }
>> +
>> +        curr->ipv4 = (curr->ipv4 == max->ipv4) ?
>> +                      min->ipv4 :
>> +                      htonl(ntohl(curr->ipv4) + 1);
>> +    } else {
>> +        if (!memcmp(min, max, sizeof(*min))) {
>> +            return;
>> +        }
>> +
>> +        if (!memcmp(curr, max, sizeof(*curr))) {
>> +            *curr = *min;
>> +            return;
>> +        }
>> +
>> +        nat_ipv6_addr_increment(&curr->ipv6, 1);
>> +    }
>> +}
>> +
>> +static bool
>> +next_addr_in_range_guarded(union ct_addr *curr, union ct_addr *min,
>> +                           union ct_addr *max, union ct_addr *guard,
>> +                           bool ipv4)
>> +{
>> +    bool exhausted;
>> +
>> +    next_addr_in_range(curr, min, max, ipv4);
>> +
>> +    if (ipv4) {
>> +        exhausted = (curr->ipv4 == guard->ipv4);
>> +    } else {
>> +        exhausted = !memcmp(curr, guard, sizeof(*curr));
>> +    }
>> +
>> +    return exhausted;
>> +}
>> +
>> +/* This function tries to get a unique tuple.
>> + * Every iteration checks that the reverse tuple doesn't
>> + * collide with any existing one.
>> + *
>> + * in case of SNAT:
>> + *    - for each src IP address in the range (if any)
>> + *        - try to find a source port in range (if any)
>> + *        - if no port range exists, use the whole
>> + *          ephemeral range (starting from the port
>> + *          used by the client)
>> + *
>> + * in case of DNAT:
>> + *    - for each dst IP address in the range (if any)
>> + *        - for each dport in range (if any)
>> + *             - try to find a source port in the ephemeral range
>> + *               (starting from the port used by the client)
>> + *
>> + * If none can be found, return exhaustion to the caller. */
>> +static bool
>> +nat_get_unique_tuple(struct conntrack *ct, const struct conn *conn,
>> +                     struct conn *nat_conn)
>> +{
>> +    union ct_addr min_addr = {0}, max_addr = {0}, curr_addr = {0},
>> +                  guard_addr = {0};
>> +    uint32_t hash = nat_range_hash(conn, ct->hash_basis);
>> +    bool pat_proto = conn->key.nw_proto == IPPROTO_TCP ||
>> +                     conn->key.nw_proto == IPPROTO_UDP;
>> +    uint16_t min_dport, max_dport, curr_dport;
>> +    uint16_t min_sport, max_sport, curr_sport;
>> +
>> +    min_addr = conn->nat_info->min_addr;
>> +    max_addr = conn->nat_info->max_addr;
>> +
>> +    get_initial_addr(conn, &min_addr, &max_addr, &curr_addr, hash,
>> +                     (conn->key.dl_type == htons(ETH_TYPE_IP)));
>> +
>> +    /* save the address we started from so that
>> +     * we can stop once we reach it. */
>> +    guard_addr = curr_addr;
>> +
>> +    set_sport_range(conn->nat_info, &conn->key, hash, &curr_sport,
>> +                    &min_sport, &max_sport);
>> +    set_dport_range(conn->nat_info, &conn->key, hash, &curr_dport,
>> +                    &min_dport, &max_dport);
>> +
>> +another_round:
>> +    store_addr_to_key(&curr_addr, &nat_conn->rev_key,
>> +                      conn->nat_info->nat_action);
>> +
>> +    if (!pat_proto) {
>> +        if (!conn_lookup(ct, &nat_conn->rev_key,
>> +                         time_msec(), NULL, NULL)) {
>>              return true;
>> -        } else if (pat_enabled && !all_ports_tried) {
>> -            if (min_port == max_port) {
>> -                all_ports_tried = true;
>> -            } else if (port == max_port) {
>> -                port = min_port;
>> -            } else {
>> -                port++;
>> -            }
>> -            if (port == first_port) {
>> -                all_ports_tried = true;
>> -            }
>> -        } else {
>> -            if (memcmp(&ct_addr, &max_ct_addr, sizeof ct_addr)) {
>> -                if (conn->key.dl_type == htons(ETH_TYPE_IP)) {
>> -                    ct_addr.ipv4 = htonl(ntohl(ct_addr.ipv4) + 1);
>> -                } else {
>> -                    nat_ipv6_addr_increment(&ct_addr.ipv6, 1);
>> -                }
>> -            } else {
>> -                ct_addr = conn->nat_info->min_addr;
>> -            }
>> -            if (!memcmp(&ct_addr, &first_addr, sizeof ct_addr)) {
>> -                if (pat_enabled && !ephemeral_ports_tried) {
>> -                    ephemeral_ports_tried = true;
>> -                    ct_addr = conn->nat_info->min_addr;
>> -                    first_addr = ct_addr;
>> -                    min_port = MIN_NAT_EPHEMERAL_PORT;
>> -                    max_port = MAX_NAT_EPHEMERAL_PORT;
>> -                } else {
>> -                    break;
>> -                }
>> +        }
>> +
>> +        goto next_addr;
>> +    }
>> +
>> +    int i, j;
>> +    FOR_EACH_PORT_IN_RANGE(i, curr_dport, min_dport, max_dport) {
>> +        nat_conn->rev_key.src.port = htons(curr_dport);
>> +        FOR_EACH_PORT_IN_RANGE(j, curr_sport, min_sport, max_sport) {
>> +            nat_conn->rev_key.dst.port = htons(curr_sport);
>> +            if (!conn_lookup(ct, &nat_conn->rev_key,
>> +                             time_msec(), NULL, NULL)) {
>> +                return true;
>>              }
>> -            first_port = min_port;
>> -            port = first_port;
>> -            all_ports_tried = false;
>>          }
>>      }
>> -    return false;
>> +
>> +    /* Check if next IP is in range and respin. Otherwise, notify
>> +     * exhaustion to the caller. */
>> +next_addr:
>> +    if (next_addr_in_range_guarded(&curr_addr, &min_addr,
>> +                                   &max_addr, &guard_addr,
>> +                                   conn->key.dl_type == 
>> htons(ETH_TYPE_IP))) {
>> +        return false;
>> +    }
>> +
>> +    goto another_round;
>>  }
>>
>>  static enum ct_update_res
>> diff --git a/lib/conntrack.h b/lib/conntrack.h
>> index 9553b188a..6ce1cd216 100644
>> --- a/lib/conntrack.h
>> +++ b/lib/conntrack.h
>> @@ -77,6 +77,14 @@ enum nat_action_e {
>>      NAT_ACTION_DST_PORT = 1 << 3,
>>  };
>>
>> +#define NAT_ACTION_SNAT_ALL (NAT_ACTION_SRC | NAT_ACTION_SRC_PORT)
>> +#define NAT_ACTION_DNAT_ALL (NAT_ACTION_DST | NAT_ACTION_DST_PORT)
>> +
>> +enum {
>> +    MIN_NAT_EPHEMERAL_PORT = 1024,
>> +    MAX_NAT_EPHEMERAL_PORT = 65535
>> +};
>> +
>>  struct nat_action_info_t {
>>      union ct_addr min_addr;
>>      union ct_addr max_addr;
>> @@ -85,6 +93,13 @@ struct nat_action_info_t {
>>      uint16_t nat_action;
>>  };
>>
>> +#define NEXT_PORT_IN_RANGE(curr, min, max) \
>> +    curr = (curr == max) ? min : curr + 1
>> +
>> +#define FOR_EACH_PORT_IN_RANGE(idx, curr, min, max) \
>> +    for (idx = 0; idx < (max - min) + 1; idx++, \
>> +             NEXT_PORT_IN_RANGE(curr, min, max))
>> +
>>  struct conntrack *conntrack_init(void);
>>  void conntrack_destroy(struct conntrack *);
diff mbox series

Patch

diff --git a/lib/conntrack.c b/lib/conntrack.c
index 99198a601..da69f63ef 100644
--- a/lib/conntrack.c
+++ b/lib/conntrack.c
@@ -108,9 +108,8 @@  static void set_label(struct dp_packet *, struct conn *,
 static void *clean_thread_main(void *f_);
 
 static bool
-nat_select_range_tuple(struct conntrack *ct, const struct conn *conn,
-                       struct conn *nat_conn);
-
+nat_get_unique_tuple(struct conntrack *ct, const struct conn *conn,
+                     struct conn *nat_conn);
 static uint8_t
 reverse_icmp_type(uint8_t type);
 static uint8_t
@@ -728,11 +727,11 @@  pat_packet(struct dp_packet *pkt, const struct conn *conn)
         }
     } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
         if (conn->key.nw_proto == IPPROTO_TCP) {
-            struct tcp_header *th = dp_packet_l4(pkt);
-            packet_set_tcp_port(pkt, th->tcp_src, conn->rev_key.src.port);
+            packet_set_tcp_port(pkt, conn->rev_key.dst.port,
+                                conn->rev_key.src.port);
         } else if (conn->key.nw_proto == IPPROTO_UDP) {
-            struct udp_header *uh = dp_packet_l4(pkt);
-            packet_set_udp_port(pkt, uh->udp_src, conn->rev_key.src.port);
+            packet_set_udp_port(pkt, conn->rev_key.dst.port,
+                                conn->rev_key.src.port);
         }
     }
 }
@@ -786,11 +785,9 @@  un_pat_packet(struct dp_packet *pkt, const struct conn *conn)
         }
     } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
         if (conn->key.nw_proto == IPPROTO_TCP) {
-            struct tcp_header *th = dp_packet_l4(pkt);
-            packet_set_tcp_port(pkt, conn->key.dst.port, th->tcp_dst);
+            packet_set_tcp_port(pkt, conn->key.dst.port, conn->key.src.port);
         } else if (conn->key.nw_proto == IPPROTO_UDP) {
-            struct udp_header *uh = dp_packet_l4(pkt);
-            packet_set_udp_port(pkt, conn->key.dst.port, uh->udp_dst);
+            packet_set_udp_port(pkt, conn->key.dst.port, conn->key.src.port);
         }
     }
 }
@@ -810,12 +807,10 @@  reverse_pat_packet(struct dp_packet *pkt, const struct conn *conn)
         }
     } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
         if (conn->key.nw_proto == IPPROTO_TCP) {
-            struct tcp_header *th_in = dp_packet_l4(pkt);
-            packet_set_tcp_port(pkt, th_in->tcp_src,
+            packet_set_tcp_port(pkt, conn->key.src.port,
                                 conn->key.dst.port);
         } else if (conn->key.nw_proto == IPPROTO_UDP) {
-            struct udp_header *uh_in = dp_packet_l4(pkt);
-            packet_set_udp_port(pkt, uh_in->udp_src,
+            packet_set_udp_port(pkt, conn->key.src.port,
                                 conn->key.dst.port);
         }
     }
@@ -1029,14 +1024,14 @@  conn_not_found(struct conntrack *ct, struct dp_packet *pkt,
                 }
             } else {
                 memcpy(nat_conn, nc, sizeof *nat_conn);
-                bool nat_res = nat_select_range_tuple(ct, nc, nat_conn);
+                bool nat_res = nat_get_unique_tuple(ct, nc, nat_conn);
 
                 if (!nat_res) {
                     goto nat_res_exhaustion;
                 }
 
                 /* Update nc with nat adjustments made to nat_conn by
-                 * nat_select_range_tuple(). */
+                 * nat_get_unique_tuple(). */
                 memcpy(nc, nat_conn, sizeof *nc);
             }
 
@@ -1391,7 +1386,6 @@  process_one(struct conntrack *ct, struct dp_packet *pkt,
 
     set_cached_conn(nat_action_info, ctx, conn, pkt);
 }
-
 /* Sends the packets in '*pkt_batch' through the connection tracker 'ct'.  All
  * the packets must have the same 'dl_type' (IPv4 or IPv6) and should have
  * the l3 and and l4 offset properly set.  Performs fragment reassembly with
@@ -1436,7 +1430,6 @@  conntrack_execute(struct conntrack *ct, struct dp_packet_batch *pkt_batch,
     }
 
     ipf_postprocess_conntrack(ct->ipf, pkt_batch, now, dl_type);
-
     return 0;
 }
 
@@ -2210,130 +2203,223 @@  nat_range_hash(const struct conn *conn, uint32_t basis)
     return hash_finish(hash, 0);
 }
 
-static bool
-nat_select_range_tuple(struct conntrack *ct, const struct conn *conn,
-                       struct conn *nat_conn)
-{
-    enum { MIN_NAT_EPHEMERAL_PORT = 1024,
-           MAX_NAT_EPHEMERAL_PORT = 65535 };
-
-    uint16_t min_port;
-    uint16_t max_port;
-    uint16_t first_port;
-    uint32_t hash = nat_range_hash(conn, ct->hash_basis);
+/* Ports are stored in host byte order for convenience. */
+static void
+set_sport_range(struct nat_action_info_t *ni, const struct conn_key *k,
+                uint32_t hash, uint16_t *curr, uint16_t *min,
+                uint16_t *max)
+{
+    if (((ni->nat_action & NAT_ACTION_SNAT_ALL) == NAT_ACTION_SRC) ||
+        ((ni->nat_action & NAT_ACTION_DST))) {
+        *curr = ntohs(k->src.port);
+        *min = MIN_NAT_EPHEMERAL_PORT;
+        *max = MAX_NAT_EPHEMERAL_PORT;
+    } else {
+        *min = ni->min_port;
+        *max = ni->max_port;
+        *curr = *min + (hash % ((*max - *min) + 1));
+    }
+}
 
-    if ((conn->nat_info->nat_action & NAT_ACTION_SRC) &&
-        (!(conn->nat_info->nat_action & NAT_ACTION_SRC_PORT))) {
-        min_port = ntohs(conn->key.src.port);
-        max_port = ntohs(conn->key.src.port);
-        first_port = min_port;
-    } else if ((conn->nat_info->nat_action & NAT_ACTION_DST) &&
-               (!(conn->nat_info->nat_action & NAT_ACTION_DST_PORT))) {
-        min_port = ntohs(conn->key.dst.port);
-        max_port = ntohs(conn->key.dst.port);
-        first_port = min_port;
+static void
+set_dport_range(struct nat_action_info_t *ni, const struct conn_key *k,
+                uint32_t hash, uint16_t *curr, uint16_t *min,
+                uint16_t *max)
+{
+    if (ni->nat_action & NAT_ACTION_DST_PORT) {
+        *min = ni->min_port;
+        *max = ni->max_port;
+        *curr = *min + (hash % ((*max - *min) + 1));
     } else {
-        uint16_t deltap = conn->nat_info->max_port - conn->nat_info->min_port;
-        uint32_t port_index = hash % (deltap + 1);
-        first_port = conn->nat_info->min_port + port_index;
-        min_port = conn->nat_info->min_port;
-        max_port = conn->nat_info->max_port;
+        *curr = ntohs(k->dst.port);
+        *min = *max = *curr;
     }
+}
 
-    uint32_t deltaa = 0;
-    uint32_t address_index;
-    union ct_addr ct_addr;
-    memset(&ct_addr, 0, sizeof ct_addr);
-    union ct_addr max_ct_addr;
-    memset(&max_ct_addr, 0, sizeof max_ct_addr);
-    max_ct_addr = conn->nat_info->max_addr;
+/* Gets the initial in range address based on the hash.
+ * Addresses are kept in network order. */
+static void
+get_addr_in_range(union ct_addr *min, union ct_addr *max,
+                  union ct_addr *curr, uint32_t hash,
+                  bool ipv4)
+{
+    uint32_t offt, range;
 
-    if (conn->key.dl_type == htons(ETH_TYPE_IP)) {
-        deltaa = ntohl(conn->nat_info->max_addr.ipv4) -
-                 ntohl(conn->nat_info->min_addr.ipv4);
-        address_index = hash % (deltaa + 1);
-        ct_addr.ipv4 = htonl(
-            ntohl(conn->nat_info->min_addr.ipv4) + address_index);
+    if (ipv4) {
+        range = (ntohl(max->ipv4) - ntohl(min->ipv4)) + 1;
+        offt = hash % range;
+        curr->ipv4 = htonl(ntohl(min->ipv4) + offt);
     } else {
-        deltaa = nat_ipv6_addrs_delta(&conn->nat_info->min_addr.ipv6,
-                                      &conn->nat_info->max_addr.ipv6);
-        /* deltaa must be within 32 bits for full hash coverage. A 64 or
+        range = nat_ipv6_addrs_delta(&min->ipv6,
+                                     &max->ipv6) + 1;
+        /* range must be within 32 bits for full hash coverage. A 64 or
          * 128 bit hash is unnecessary and hence not used here. Most code
          * is kept common with V4; nat_ipv6_addrs_delta() will do the
          * enforcement via max_ct_addr. */
-        max_ct_addr = conn->nat_info->min_addr;
-        nat_ipv6_addr_increment(&max_ct_addr.ipv6, deltaa);
-        address_index = hash % (deltaa + 1);
-        ct_addr.ipv6 = conn->nat_info->min_addr.ipv6;
-        nat_ipv6_addr_increment(&ct_addr.ipv6, address_index);
-    }
-
-    uint16_t port = first_port;
-    bool all_ports_tried = false;
-    /* For DNAT or for specified port ranges, we don't use ephemeral ports. */
-    bool ephemeral_ports_tried
-        = conn->nat_info->nat_action & NAT_ACTION_DST ||
-              conn->nat_info->nat_action & NAT_ACTION_SRC_PORT
-          ? true : false;
-    union ct_addr first_addr = ct_addr;
-    bool pat_enabled = conn->key.nw_proto == IPPROTO_TCP ||
-                       conn->key.nw_proto == IPPROTO_UDP;
-
-    while (true) {
+        offt = hash % range;
+        curr->ipv6 = min->ipv6;
+        nat_ipv6_addr_increment(&curr->ipv6, offt);
+    }
+}
+
+static void
+get_initial_addr(const struct conn *conn, union ct_addr *min,
+                 union ct_addr *max, union ct_addr *curr,
+                 uint32_t hash, bool ipv4)
+{
+    const union ct_addr zero_ip = {0};
+
+    /* NULL CASE */
+    if (!memcmp(min, &zero_ip, sizeof(*min))) {
         if (conn->nat_info->nat_action & NAT_ACTION_SRC) {
-            nat_conn->rev_key.dst.addr = ct_addr;
-            if (pat_enabled) {
-                nat_conn->rev_key.dst.port = htons(port);
-            }
-        } else {
-            nat_conn->rev_key.src.addr = ct_addr;
-            if (pat_enabled) {
-                nat_conn->rev_key.src.port = htons(port);
-            }
+            *curr = conn->key.src.addr;
+        } else if (conn->nat_info->nat_action & NAT_ACTION_DST) {
+            *curr = conn->key.dst.addr;
         }
+    } else {
+        get_addr_in_range(min, max, curr, hash, ipv4);
+    }
+}
 
-        bool found = conn_lookup(ct, &nat_conn->rev_key, time_msec(), NULL,
-                                 NULL);
-        if (!found) {
+/* if action is src, store to dst, otherwise store src
+ * if src is NULL, do not store anything. */
+static void
+store_addr_to_key(union ct_addr *addr, struct conn_key *key,
+                  uint16_t action)
+{
+    if (action & NAT_ACTION_SRC) {
+        key->dst.addr = *addr;
+    } else {
+        key->src.addr = *addr;
+    }
+}
+
+static void
+next_addr_in_range(union ct_addr *curr, union ct_addr *min,
+                   union ct_addr *max, bool ipv4)
+{
+    if (ipv4) {
+        /* this check could be unified with IPv6, but let's avoid
+         * an unneeded memcmp() in case of IPv4. */
+        if (min->ipv4 == max->ipv4) {
+            return;
+        }
+
+        curr->ipv4 = (curr->ipv4 == max->ipv4) ?
+                      min->ipv4 :
+                      htonl(ntohl(curr->ipv4) + 1);
+    } else {
+        if (!memcmp(min, max, sizeof(*min))) {
+            return;
+        }
+
+        if (!memcmp(curr, max, sizeof(*curr))) {
+            *curr = *min;
+            return;
+        }
+
+        nat_ipv6_addr_increment(&curr->ipv6, 1);
+    }
+}
+
+static bool
+next_addr_in_range_guarded(union ct_addr *curr, union ct_addr *min,
+                           union ct_addr *max, union ct_addr *guard,
+                           bool ipv4)
+{
+    bool exhausted;
+
+    next_addr_in_range(curr, min, max, ipv4);
+
+    if (ipv4) {
+        exhausted = (curr->ipv4 == guard->ipv4);
+    } else {
+        exhausted = !memcmp(curr, guard, sizeof(*curr));
+    }
+
+    return exhausted;
+}
+
+/* This function tries to get a unique tuple.
+ * Every iteration checks that the reverse tuple doesn't
+ * collide with any existing one.
+ *
+ * in case of SNAT:
+ *    - for each src IP address in the range (if any)
+ *        - try to find a source port in range (if any)
+ *        - if no port range exists, use the whole
+ *          ephemeral range (starting from the port
+ *          used by the client)
+ *
+ * in case of DNAT:
+ *    - for each dst IP address in the range (if any)
+ *        - for each dport in range (if any)
+ *             - try to find a source port in the ephemeral range
+ *               (starting from the port used by the client)
+ *
+ * If none can be found, return exhaustion to the caller. */
+static bool
+nat_get_unique_tuple(struct conntrack *ct, const struct conn *conn,
+                     struct conn *nat_conn)
+{
+    union ct_addr min_addr = {0}, max_addr = {0}, curr_addr = {0},
+                  guard_addr = {0};
+    uint32_t hash = nat_range_hash(conn, ct->hash_basis);
+    bool pat_proto = conn->key.nw_proto == IPPROTO_TCP ||
+                     conn->key.nw_proto == IPPROTO_UDP;
+    uint16_t min_dport, max_dport, curr_dport;
+    uint16_t min_sport, max_sport, curr_sport;
+
+    min_addr = conn->nat_info->min_addr;
+    max_addr = conn->nat_info->max_addr;
+
+    get_initial_addr(conn, &min_addr, &max_addr, &curr_addr, hash,
+                     (conn->key.dl_type == htons(ETH_TYPE_IP)));
+
+    /* save the address we started from so that
+     * we can stop once we reach it. */
+    guard_addr = curr_addr;
+
+    set_sport_range(conn->nat_info, &conn->key, hash, &curr_sport,
+                    &min_sport, &max_sport);
+    set_dport_range(conn->nat_info, &conn->key, hash, &curr_dport,
+                    &min_dport, &max_dport);
+
+another_round:
+    store_addr_to_key(&curr_addr, &nat_conn->rev_key,
+                      conn->nat_info->nat_action);
+
+    if (!pat_proto) {
+        if (!conn_lookup(ct, &nat_conn->rev_key,
+                         time_msec(), NULL, NULL)) {
             return true;
-        } else if (pat_enabled && !all_ports_tried) {
-            if (min_port == max_port) {
-                all_ports_tried = true;
-            } else if (port == max_port) {
-                port = min_port;
-            } else {
-                port++;
-            }
-            if (port == first_port) {
-                all_ports_tried = true;
-            }
-        } else {
-            if (memcmp(&ct_addr, &max_ct_addr, sizeof ct_addr)) {
-                if (conn->key.dl_type == htons(ETH_TYPE_IP)) {
-                    ct_addr.ipv4 = htonl(ntohl(ct_addr.ipv4) + 1);
-                } else {
-                    nat_ipv6_addr_increment(&ct_addr.ipv6, 1);
-                }
-            } else {
-                ct_addr = conn->nat_info->min_addr;
-            }
-            if (!memcmp(&ct_addr, &first_addr, sizeof ct_addr)) {
-                if (pat_enabled && !ephemeral_ports_tried) {
-                    ephemeral_ports_tried = true;
-                    ct_addr = conn->nat_info->min_addr;
-                    first_addr = ct_addr;
-                    min_port = MIN_NAT_EPHEMERAL_PORT;
-                    max_port = MAX_NAT_EPHEMERAL_PORT;
-                } else {
-                    break;
-                }
+        }
+
+        goto next_addr;
+    }
+
+    int i, j;
+    FOR_EACH_PORT_IN_RANGE(i, curr_dport, min_dport, max_dport) {
+        nat_conn->rev_key.src.port = htons(curr_dport);
+        FOR_EACH_PORT_IN_RANGE(j, curr_sport, min_sport, max_sport) {
+            nat_conn->rev_key.dst.port = htons(curr_sport);
+            if (!conn_lookup(ct, &nat_conn->rev_key,
+                             time_msec(), NULL, NULL)) {
+                return true;
             }
-            first_port = min_port;
-            port = first_port;
-            all_ports_tried = false;
         }
     }
-    return false;
+
+    /* Check if next IP is in range and respin. Otherwise, notify
+     * exhaustion to the caller. */
+next_addr:
+    if (next_addr_in_range_guarded(&curr_addr, &min_addr,
+                                   &max_addr, &guard_addr,
+                                   conn->key.dl_type == htons(ETH_TYPE_IP))) {
+        return false;
+    }
+
+    goto another_round;
 }
 
 static enum ct_update_res
diff --git a/lib/conntrack.h b/lib/conntrack.h
index 9553b188a..6ce1cd216 100644
--- a/lib/conntrack.h
+++ b/lib/conntrack.h
@@ -77,6 +77,14 @@  enum nat_action_e {
     NAT_ACTION_DST_PORT = 1 << 3,
 };
 
+#define NAT_ACTION_SNAT_ALL (NAT_ACTION_SRC | NAT_ACTION_SRC_PORT)
+#define NAT_ACTION_DNAT_ALL (NAT_ACTION_DST | NAT_ACTION_DST_PORT)
+
+enum {
+    MIN_NAT_EPHEMERAL_PORT = 1024,
+    MAX_NAT_EPHEMERAL_PORT = 65535
+};
+
 struct nat_action_info_t {
     union ct_addr min_addr;
     union ct_addr max_addr;
@@ -85,6 +93,13 @@  struct nat_action_info_t {
     uint16_t nat_action;
 };
 
+#define NEXT_PORT_IN_RANGE(curr, min, max) \
+    curr = (curr == max) ? min : curr + 1
+
+#define FOR_EACH_PORT_IN_RANGE(idx, curr, min, max) \
+    for (idx = 0; idx < (max - min) + 1; idx++, \
+             NEXT_PORT_IN_RANGE(curr, min, max))
+
 struct conntrack *conntrack_init(void);
 void conntrack_destroy(struct conntrack *);