[ovs-dev,2/3,v4] datapath-windows: Add NAT module in conntrack

Submitted by Yin Lin on April 19, 2017, 8:19 p.m.

Details

Message ID BY2PR0501MB1848C646F5D4345096C9A37ED4180@BY2PR0501MB1848.namprd05.prod.outlook.com
State New
Headers show

Commit Message

Yin Lin April 19, 2017, 8:19 p.m.
Signed-off-by: Yin Lin <linyi@vmware.com>

Issue: #
Change-Id: I6f37360c36525548b343f0016304015fec8aba7d
---
 datapath-windows/automake.mk            |   2 +
 datapath-windows/ovsext/Conntrack-nat.c | 424 ++++++++++++++++++++++++++++++++
 datapath-windows/ovsext/Conntrack-nat.h |  39 +++
 3 files changed, 465 insertions(+)
 create mode 100644 datapath-windows/ovsext/Conntrack-nat.c
 create mode 100644 datapath-windows/ovsext/Conntrack-nat.h

Patch hide | download patch | download mbox

diff --git a/datapath-windows/automake.mk b/datapath-windows/automake.mk
index 53983ae..296e785 100644
--- a/datapath-windows/automake.mk
+++ b/datapath-windows/automake.mk
@@ -16,7 +16,9 @@  EXTRA_DIST += \
 	datapath-windows/ovsext/Conntrack-icmp.c \
 	datapath-windows/ovsext/Conntrack-other.c \
 	datapath-windows/ovsext/Conntrack-related.c \
+    datapath-windows/ovsext/Conntrack-nat.c \
 	datapath-windows/ovsext/Conntrack-tcp.c \
+    datapath-windows/ovsext/Conntrack-nat.h \
 	datapath-windows/ovsext/Conntrack.c \
 	datapath-windows/ovsext/Conntrack.h \
 	datapath-windows/ovsext/Datapath.c \
diff --git a/datapath-windows/ovsext/Conntrack-nat.c b/datapath-windows/ovsext/Conntrack-nat.c
new file mode 100644
index 0000000..edf5f8f
--- /dev/null
+++ b/datapath-windows/ovsext/Conntrack-nat.c
@@ -0,0 +1,424 @@ 
+#include "Conntrack-nat.h"
+#include "Jhash.h"
+
+PLIST_ENTRY ovsNatTable = NULL;
+PLIST_ENTRY ovsUnNatTable = NULL;
+
+/*
+ *---------------------------------------------------------------------------
+ * OvsHashNatKey
+ *     Hash NAT related fields in a Conntrack key.
+ *---------------------------------------------------------------------------
+ */
+static __inline UINT32
+OvsHashNatKey(const OVS_CT_KEY *key)
+{
+    UINT32 hash = 0;
+#define HASH_ADD(field) \
+    hash = OvsJhashBytes(&key->field, sizeof(key->field), hash)
+
+    HASH_ADD(src.addr.ipv4_aligned);
+    HASH_ADD(dst.addr.ipv4_aligned);
+    HASH_ADD(src.port);
+    HASH_ADD(dst.port);
+    HASH_ADD(zone);
+#undef HASH_ADD
+    return hash;
+}
+
+/*
+ *---------------------------------------------------------------------------
+ * OvsNatKeyAreSame
+ *     Compare NAT related fields in a Conntrack key.
+ *---------------------------------------------------------------------------
+ */
+static __inline BOOLEAN
+OvsNatKeyAreSame(const OVS_CT_KEY *key1, const OVS_CT_KEY *key2)
+{
+    // XXX: Compare IPv6 key as well
+#define FIELD_COMPARE(field) \
+    if (key1->field != key2->field) return FALSE
+
+    FIELD_COMPARE(src.addr.ipv4_aligned);
+    FIELD_COMPARE(dst.addr.ipv4_aligned);
+    FIELD_COMPARE(src.port);
+    FIELD_COMPARE(dst.port);
+    FIELD_COMPARE(zone);
+    return TRUE;
+#undef FIELD_COMPARE
+}
+
+/*
+ *---------------------------------------------------------------------------
+ * OvsNaGetBucket
+ *     Returns the row of NAT table that has the same hash as the given NAT
+ *     hash key. If isReverse is TRUE, returns the row of reverse NAT table
+ *     instead.
+ *---------------------------------------------------------------------------
+ */
+static __inline PLIST_ENTRY
+OvsNatGetBucket(const OVS_CT_KEY *key, BOOLEAN isReverse)
+{
+    uint32_t hash = OvsHashNatKey(key);
+    if (isReverse) {
+        return &ovsUnNatTable[hash & NAT_HASH_TABLE_MASK];
+    } else {
+        return &ovsNatTable[hash & NAT_HASH_TABLE_MASK];
+    }
+}
+
+/*
+ *---------------------------------------------------------------------------
+ * OvsNatInit
+ *     Initialize NAT related resources.
+ *---------------------------------------------------------------------------
+ */
+NTSTATUS OvsNatInit()
+{
+    ASSERT(ovsNatTable == NULL);
+
+    /* Init the Hash Buffer */
+    ovsNatTable = OvsAllocateMemoryWithTag(
+        sizeof(LIST_ENTRY) * NAT_HASH_TABLE_SIZE,
+        OVS_CT_POOL_TAG);
+    if (ovsNatTable == NULL) {
+        goto failNoMem;
+    }
+
+    ovsUnNatTable = OvsAllocateMemoryWithTag(
+        sizeof(LIST_ENTRY) * NAT_HASH_TABLE_SIZE,
+        OVS_CT_POOL_TAG);
+    if (ovsUnNatTable == NULL) {
+        goto freeNatTable;
+    }
+
+    for (int i = 0; i < NAT_HASH_TABLE_SIZE; i++) {
+        InitializeListHead(&ovsNatTable[i]);
+        InitializeListHead(&ovsUnNatTable[i]);
+    }
+    return STATUS_SUCCESS;
+
+freeNatTable:
+    OvsFreeMemoryWithTag(ovsNatTable, OVS_CT_POOL_TAG);
+failNoMem:
+    return STATUS_INSUFFICIENT_RESOURCES;
+}
+
+/*
+ *----------------------------------------------------------------------------
+ * OvsNatFlush
+ *     Flushes out all NAT entries that match the given zone.
+ *----------------------------------------------------------------------------
+ */
+VOID OvsNatFlush(UINT16 zone)
+{
+    PLIST_ENTRY link, next;
+    for (int i = 0; i < NAT_HASH_TABLE_SIZE; i++) {
+        LIST_FORALL_SAFE(&ovsNatTable[i], link, next) {
+            POVS_NAT_ENTRY entry =
+                CONTAINING_RECORD(link, OVS_NAT_ENTRY, link);
+            /* zone is a non-zero value */
+            if (!zone || zone == entry->key.zone) {
+                OvsNatDeleteEntry(entry);
+            }
+        }
+    }
+}
+
+/*
+ *----------------------------------------------------------------------------
+ * OvsNatCleanup
+ *     Releases all NAT related resources.
+ *----------------------------------------------------------------------------
+ */
+VOID OvsNatCleanup()
+{
+    if (ovsNatTable == NULL) return;
+    OvsFreeMemoryWithTag(ovsNatTable, OVS_CT_POOL_TAG);
+    OvsFreeMemoryWithTag(ovsUnNatTable, OVS_CT_POOL_TAG);
+    ovsNatTable = NULL;
+    ovsUnNatTable = NULL;
+}
+
+/*
+ *----------------------------------------------------------------------------
+ * OvsNatPacket
+ *     Performs NAT operation on the packet by replacing the source/destinaton
+ *     address/port based on natAction. If reverse is TRUE, perform unNAT
+ *     instead.
+ *----------------------------------------------------------------------------
+ */
+VOID
+OvsNatPacket(OvsForwardingContext *ovsFwdCtx,
+             const OVS_CT_ENTRY *entry,
+             UINT16 natAction,
+             OvsFlowKey *key,
+             BOOLEAN reverse)
+{
+    UINT32 natFlag;
+    const struct ct_endpoint* endpoint;
+    /* When it is NAT, only entry->rev_key contains NATTED address;
+       When it is unNAT, only entry->key contains the UNNATTED address;*/
+    const OVS_CT_KEY *ctKey = reverse ? &entry->key : &entry->rev_key;
+    BOOLEAN isSrcNat;
+
+    if (!(natAction & (NAT_ACTION_SRC | NAT_ACTION_DST))) {
+        return;
+    }
+    isSrcNat = (((natAction & NAT_ACTION_SRC) && !reverse) ||
+                ((natAction & NAT_ACTION_DST) && reverse));
+
+    if (isSrcNat) {
+        /* Flag is set to SNAT for SNAT case and the reverse DNAT case */
+        natFlag = OVS_CS_F_SRC_NAT;
+        /* Note that ctKey is the key in the other direction, so
+           endpoint has to be reverted, i.e. ctKey->dst for SNAT
+           and ctKey->src for DNAT */
+        endpoint = &ctKey->dst;
+    } else {
+        natFlag = OVS_CS_F_DST_NAT;
+        endpoint = &ctKey->src;
+    }
+    key->ct.state |= natFlag;
+    if (ctKey->dl_type == htons(ETH_TYPE_IPV4)) {
+        OvsUpdateAddressAndPort(ovsFwdCtx,
+                                endpoint->addr.ipv4_aligned,
+                                endpoint->port, isSrcNat,
+                                !reverse);
+        if (isSrcNat) {
+            key->ipKey.nwSrc = endpoint->addr.ipv4_aligned;
+        } else {
+            key->ipKey.nwDst = endpoint->addr.ipv4_aligned;
+        }
+    } else if (ctKey->dl_type == htons(ETH_TYPE_IPV6)){
+        // XXX: IPv6 packet not supported yet.
+        return;
+    }
+    if (natAction & (NAT_ACTION_SRC_PORT | NAT_ACTION_DST_PORT)) {
+        if (isSrcNat) {
+            if (key->ipKey.l4.tpSrc != 0) {
+                key->ipKey.l4.tpSrc = endpoint->port;
+            }
+        } else {
+            if (key->ipKey.l4.tpDst != 0) {
+                key->ipKey.l4.tpDst = endpoint->port;
+            }
+        }
+    }
+}
+
+
+/*
+ *----------------------------------------------------------------------------
+ * OvsNatHashRange
+ *     Compute hash for a range of addresses specified in natInfo.
+ *----------------------------------------------------------------------------
+ */
+static UINT32 OvsNatHashRange(const OVS_CT_ENTRY *entry, UINT32 basis)
+{
+    UINT32 hash = basis;
+#define HASH_ADD(field) \
+    hash = OvsJhashBytes(&field, sizeof(field), hash)
+
+    HASH_ADD(entry->natInfo.minAddr);
+    HASH_ADD(entry->natInfo.maxAddr);
+    HASH_ADD(entry->key.dl_type);
+    HASH_ADD(entry->key.nw_proto);
+    HASH_ADD(entry->key.zone);
+#undef HASH_ADD
+    return hash;
+}
+
+/*
+ *----------------------------------------------------------------------------
+ * OvsNatAddEntry
+ *     Add an entry to the NAT table. Also updates the reverse NAT lookup
+ *     table.
+ *----------------------------------------------------------------------------
+ */
+VOID
+OvsNatAddEntry(OVS_NAT_ENTRY* entry)
+{
+    InsertHeadList(OvsNatGetBucket(&entry->key, FALSE),
+                   &entry->link);
+    InsertHeadList(OvsNatGetBucket(&entry->value, TRUE),
+                   &entry->reverseLink);
+}
+
+/*
+ *----------------------------------------------------------------------------
+ * OvsNatCtEntry
+ *     Update an Conntrack entry with NAT information. Translated address and
+ *     port will be generated and write back to the conntrack entry as a
+ *     result.
+ *----------------------------------------------------------------------------
+ */
+BOOLEAN
+OvsNatCtEntry(OVS_CT_ENTRY *entry)
+{
+    const uint16_t MIN_NAT_EPHEMERAL_PORT = 1024;
+    const uint16_t MAX_NAT_EPHEMERAL_PORT = 65535;
+
+    uint16_t minPort;
+    uint16_t maxPort;
+    uint16_t firstPort;
+
+    uint32_t hash = OvsNatHashRange(entry, 0);
+
+    if ((entry->natInfo.natAction & NAT_ACTION_SRC) &&
+        (!(entry->natInfo.natAction & NAT_ACTION_SRC_PORT))) {
+        firstPort = minPort = maxPort = ntohs(entry->key.src.port);
+    } else if ((entry->natInfo.natAction & NAT_ACTION_DST) &&
+               (!(entry->natInfo.natAction & NAT_ACTION_DST_PORT))) {
+        firstPort = minPort = maxPort = ntohs(entry->key.dst.port);
+    } else {
+        uint16_t portDelta = entry->natInfo.maxPort - entry->natInfo.minPort;
+        uint16_t portIndex = (uint16_t) hash % (portDelta + 1);
+        firstPort = entry->natInfo.minPort + portIndex;
+        minPort = entry->natInfo.minPort;
+        maxPort = entry->natInfo.maxPort;
+    }
+
+    uint32_t addrDelta = 0;
+    uint32_t addrIndex;
+    struct ct_addr ctAddr, maxCtAddr;
+    memset(&ctAddr, 0, sizeof ctAddr);
+    memset(&maxCtAddr, 0, sizeof maxCtAddr);
+    maxCtAddr = entry->natInfo.maxAddr;
+
+    if (entry->key.dl_type == htons(ETH_TYPE_IPV4)) {
+        addrDelta = ntohl(entry->natInfo.maxAddr.ipv4_aligned) -
+                    ntohl(entry->natInfo.minAddr.ipv4_aligned);
+        addrIndex = hash % (addrDelta + 1);
+        ctAddr.ipv4_aligned = htonl(
+            ntohl(entry->natInfo.minAddr.ipv4_aligned) + addrIndex);
+    } else {
+        // XXX: IPv6 not supported
+        return FALSE;
+    }
+
+    uint16_t port = firstPort;
+    BOOLEAN allPortsTried = FALSE;
+    BOOLEAN originalPortsTried = FALSE;
+    struct ct_addr firstAddr = ctAddr;
+    for (;;) {
+        if (entry->natInfo.natAction & NAT_ACTION_SRC) {
+            entry->rev_key.dst.addr = ctAddr;
+            entry->rev_key.dst.port = htons(port);
+        } else {
+            entry->rev_key.src.addr = ctAddr;
+            entry->rev_key.src.port = htons(port);
+        }
+
+        OVS_NAT_ENTRY *natEntry = OvsNatLookup(&entry->rev_key, TRUE);
+
+        if (!natEntry) {
+            natEntry = OvsAllocateMemoryWithTag(sizeof(*natEntry),
+                                                OVS_CT_POOL_TAG);
+            memcpy(&natEntry->key, &entry->key,
+                   sizeof natEntry->key);
+            memcpy(&natEntry->value, &entry->rev_key,
+                   sizeof natEntry->value);
+            natEntry->ctEntry = entry;
+            OvsNatAddEntry(natEntry);
+            return TRUE;
+        } else if (!allPortsTried) {
+            if (minPort == maxPort) {
+                allPortsTried = TRUE;
+            } else if (port == maxPort) {
+                port = minPort;
+            } else {
+                port++;
+            }
+            if (port == firstPort) {
+                allPortsTried = TRUE;
+            }
+        } else {
+            if (memcmp(&ctAddr, &maxCtAddr, sizeof ctAddr)) {
+                if (entry->key.dl_type == htons(ETH_TYPE_IPV4)) {
+                    ctAddr.ipv4_aligned = htonl(
+                        ntohl(ctAddr.ipv4_aligned) + 1);
+                } else {
+                    // XXX: IPv6 not supported
+                    return FALSE;
+                }
+            } else {
+                ctAddr = entry->natInfo.minAddr;
+            }
+            if (!memcmp(&ctAddr, &firstAddr, sizeof ctAddr)) {
+                if (!originalPortsTried) {
+                    originalPortsTried = TRUE;
+                    ctAddr = entry->natInfo.minAddr;
+                    minPort = MIN_NAT_EPHEMERAL_PORT;
+                    maxPort = MAX_NAT_EPHEMERAL_PORT;
+                } else {
+                    break;
+                }
+            }
+            firstPort = minPort;
+            port = firstPort;
+            allPortsTried = FALSE;
+        }
+    }
+    return FALSE;
+}
+
+/*
+ *----------------------------------------------------------------------------
+ * OvsNatLookup
+ *     Look up a NAT entry with the given key in the NAT table.
+ *     If reverse is TRUE, look up a NAT entry with the given value instead.
+ *----------------------------------------------------------------------------
+ */
+POVS_NAT_ENTRY
+OvsNatLookup(const OVS_CT_KEY *ctKey, BOOLEAN reverse)
+{
+    PLIST_ENTRY link;
+    POVS_NAT_ENTRY entry;
+
+    LIST_FORALL(OvsNatGetBucket(ctKey, reverse), link) {
+        if (reverse) {
+            entry = CONTAINING_RECORD(link, OVS_NAT_ENTRY, reverseLink);
+
+            if (OvsNatKeyAreSame(ctKey, &entry->value)) {
+                return entry;
+            }
+        } else {
+            entry = CONTAINING_RECORD(link, OVS_NAT_ENTRY, link);
+
+            if (OvsNatKeyAreSame(ctKey, &entry->key)) {
+                return entry;
+            }
+        }
+    }
+    return NULL;
+}
+
+/*
+ *----------------------------------------------------------------------------
+ * OvsNatDeleteEntry
+ *     Delete a NAT entry.
+ *----------------------------------------------------------------------------
+ */
+VOID
+OvsNatDeleteEntry(POVS_NAT_ENTRY entry)
+{
+    if (entry == NULL) {
+        return;
+    }
+    RemoveEntryList(&entry->link);
+    RemoveEntryList(&entry->reverseLink);
+    OvsFreeMemoryWithTag(entry, OVS_CT_POOL_TAG);
+}
+
+/*
+ *----------------------------------------------------------------------------
+ * OvsNatDeleteKey
+ *     Delete a NAT entry with the given key.
+ *----------------------------------------------------------------------------
+ */
+VOID
+OvsNatDeleteKey(const OVS_CT_KEY *key)
+{
+    OvsNatDeleteEntry(OvsNatLookup(key, FALSE));
+}
diff --git a/datapath-windows/ovsext/Conntrack-nat.h b/datapath-windows/ovsext/Conntrack-nat.h
new file mode 100644
index 0000000..aaa8ceb
--- /dev/null
+++ b/datapath-windows/ovsext/Conntrack-nat.h
@@ -0,0 +1,39 @@ 
+#ifndef _CONNTRACK_NAT_H
+#define _CONNTRACK_NAT_H
+
+#include "precomp.h"
+#include "Flow.h"
+#include "Debug.h"
+#include <stddef.h>
+#include "Conntrack.h"
+
+#define NAT_HASH_TABLE_SIZE ((UINT32)1 << 10)
+#define NAT_HASH_TABLE_MASK (NAT_HASH_TABLE_SIZE - 1)
+
+typedef struct OVS_NAT_ENTRY {
+    LIST_ENTRY link;
+    LIST_ENTRY reverseLink;
+    OVS_CT_KEY key;
+    OVS_CT_KEY value;
+    POVS_CT_ENTRY  ctEntry;
+} OVS_NAT_ENTRY, *POVS_NAT_ENTRY;
+
+__inline static BOOLEAN OvsIsForwardNat(UINT16 natAction) {
+    return !!(natAction & (NAT_ACTION_SRC | NAT_ACTION_DST));
+}
+
+NTSTATUS OvsNatInit();
+VOID OvsNatFlush(UINT16 zone);
+
+VOID OvsNatAddEntry(OVS_NAT_ENTRY* entry);
+
+VOID OvsNatDeleteEntry(POVS_NAT_ENTRY entry);
+VOID OvsNatDeleteKey(const OVS_CT_KEY *key);
+VOID OvsNatCleanup();
+
+POVS_NAT_ENTRY OvsNatLookup(const OVS_CT_KEY *ctKey, BOOLEAN reverse);
+BOOLEAN OvsNatCtEntry(OVS_CT_ENTRY *ctEntry);
+VOID OvsNatPacket(OvsForwardingContext *ovsFwdCtx, const OVS_CT_ENTRY *entry,
+                  UINT16 natAction, OvsFlowKey *key, BOOLEAN reverse);
+
+#endif