From patchwork Tue Jul 21 10:58:45 2015 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: Dexuan Cui X-Patchwork-Id: 498129 X-Patchwork-Delegate: davem@davemloft.net Return-Path: X-Original-To: patchwork-incoming@ozlabs.org Delivered-To: patchwork-incoming@ozlabs.org Received: from vger.kernel.org (vger.kernel.org [209.132.180.67]) by ozlabs.org (Postfix) with ESMTP id 934181402AC for ; Tue, 21 Jul 2015 19:50:54 +1000 (AEST) Received: (majordomo@vger.kernel.org) by vger.kernel.org via listexpand id S1754486AbbGUJeb (ORCPT ); Tue, 21 Jul 2015 05:34:31 -0400 Received: from p3plsmtps2ded04.prod.phx3.secureserver.net ([208.109.80.198]:54498 "EHLO p3plsmtps2ded04.prod.phx3.secureserver.net" rhost-flags-OK-OK-OK-OK) by vger.kernel.org with ESMTP id S1754445AbbGUJeX (ORCPT ); Tue, 21 Jul 2015 05:34:23 -0400 Received: from linuxonhyperv.com ([72.167.245.219]) by p3plsmtps2ded04.prod.phx3.secureserver.net with : DED : id uxaP1q01X4kklxU01xaP8c; Tue, 21 Jul 2015 02:34:23 -0700 x-originating-ip: 72.167.245.219 Received: by linuxonhyperv.com (Postfix, from userid 518) id 7D4A219020D; Tue, 21 Jul 2015 03:58:45 -0700 (PDT) From: Dexuan Cui To: gregkh@linuxfoundation.org, davem@davemloft.net, stephen@networkplumber.org, netdev@vger.kernel.org, linux-kernel@vger.kernel.org, driverdev-devel@linuxdriverproject.org, olaf@aepfle.de, apw@canonical.com, jasowang@redhat.com, kys@microsoft.com, pebolle@tiscali.nl, stefanha@redhat.com Subject: [PATCH V3 6/7] hvsock: introduce Hyper-V VM Sockets feature Date: Tue, 21 Jul 2015 03:58:45 -0700 Message-Id: <1437476325-6940-1-git-send-email-decui@microsoft.com> X-Mailer: git-send-email 1.7.4.1 Sender: netdev-owner@vger.kernel.org Precedence: bulk List-ID: X-Mailing-List: netdev@vger.kernel.org Hyper-V VM sockets (hvsock) supplies a byte-stream based communication mechanism between the host and a guest. It's kind of TCP over VMBus, but the transportation layer (VMBus) is much simpler than IP. With Hyper-V VM Sockets, applications between the host and a guest can talk with each other directly by the traditional BSD-style socket APIs. Hyper-V VM Sockets is only available on Windows 10 host and later. The patch implements the necessary support in the guest side by introducing a new socket address family AF_HYPERV. Signed-off-by: Dexuan Cui --- Changes since v1: - added __init and __exit for the module init/exit functions - net/hv_sock/Kconfig: "default m" -> "default m if HYPERV" - MODULE_LICENSE: "Dual MIT/GPL" -> "Dual BSD/GPL" Changes since v2: - fixed indentation issues - removed pr_debug I know the kernel has already had a VM Sockets driver (AF_VSOCK) based on VMware's VMCI (net/vmw_vsock/, drivers/misc/vmw_vmci), and KVM is proposing AF_VSOCK of virtio version: http://thread.gmane.org/gmane.linux.network/365205. However, though Hyper-V VM Sockets may seem conceptually similar to AF_VOSCK, there are differences in the transportation layer, and IMO these make the direct code reusing impractical: 1. In AF_VSOCK, the endpoint type is: , but in AF_HYPERV, the endpoint type is: . Here GUID is 128-bit. 2. AF_VSOCK supports SOCK_DGRAM, while AF_HYPERV doesn't. 3. AF_VSOCK supports some special sock opts, like SO_VM_SOCKETS_BUFFER_SIZE, SO_VM_SOCKETS_BUFFER_MIN/MAX_SIZE and SO_VM_SOCKETS_CONNECT_TIMEOUT. These are meaningless to AF_HYPERV. 4. Some AF_VSOCK's VMCI transportation ops are meanless to AF_HYPERV/VMBus, like .notify_recv_init .notify_recv_pre_block .notify_recv_pre_dequeue .notify_recv_post_dequeue .notify_send_init .notify_send_pre_block .notify_send_pre_enqueue .notify_send_post_enqueue etc. So I think we'd better introduce a new address family: AF_HYPERV. MAINTAINERS | 2 + include/linux/socket.h | 4 +- include/net/af_hvsock.h | 44 ++ include/uapi/linux/hyperv.h | 16 + net/Kconfig | 1 + net/Makefile | 1 + net/hv_sock/Kconfig | 10 + net/hv_sock/Makefile | 3 + net/hv_sock/af_hvsock.c | 1430 +++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 1510 insertions(+), 1 deletion(-) create mode 100644 include/net/af_hvsock.h create mode 100644 net/hv_sock/Kconfig create mode 100644 net/hv_sock/Makefile create mode 100644 net/hv_sock/af_hvsock.c diff --git a/MAINTAINERS b/MAINTAINERS index e7bdbac..a4a7e03 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -4941,7 +4941,9 @@ F: drivers/input/serio/hyperv-keyboard.c F: drivers/net/hyperv/ F: drivers/scsi/storvsc_drv.c F: drivers/video/fbdev/hyperv_fb.c +F: net/hv_sock/ F: include/linux/hyperv.h +F: include/net/af_hvsock.h F: tools/hv/ I2C OVER PARALLEL PORT diff --git a/include/linux/socket.h b/include/linux/socket.h index 5bf59c8..d5ef612 100644 --- a/include/linux/socket.h +++ b/include/linux/socket.h @@ -200,7 +200,8 @@ struct ucred { #define AF_ALG 38 /* Algorithm sockets */ #define AF_NFC 39 /* NFC sockets */ #define AF_VSOCK 40 /* vSockets */ -#define AF_MAX 41 /* For now.. */ +#define AF_HYPERV 41 /* Hyper-V virtual sockets */ +#define AF_MAX 42 /* For now.. */ /* Protocol families, same as address families. */ #define PF_UNSPEC AF_UNSPEC @@ -246,6 +247,7 @@ struct ucred { #define PF_ALG AF_ALG #define PF_NFC AF_NFC #define PF_VSOCK AF_VSOCK +#define PF_HYPERV AF_HYPERV #define PF_MAX AF_MAX /* Maximum queue length specifiable by listen. */ diff --git a/include/net/af_hvsock.h b/include/net/af_hvsock.h new file mode 100644 index 0000000..9951658 --- /dev/null +++ b/include/net/af_hvsock.h @@ -0,0 +1,44 @@ +#ifndef __AF_HVSOCK_H__ +#define __AF_HVSOCK_H__ + +#include +#include +#include + +#define VMBUS_RINGBUFFER_SIZE_HVSOCK_RECV (5 * PAGE_SIZE) +#define VMBUS_RINGBUFFER_SIZE_HVSOCK_SEND (5 * PAGE_SIZE) + +#define HVSOCK_RCV_BUF_SZ VMBUS_RINGBUFFER_SIZE_HVSOCK_RECV +#define HVSOCK_SND_BUF_SZ PAGE_SIZE + +#define sk_to_hvsock(__sk) ((struct hvsock_sock *)(__sk)) +#define hvsock_to_sk(__hvsk) ((struct sock *)(__hvsk)) + +struct hvsock_sock { + /* sk must be the first member. */ + struct sock sk; + + struct sockaddr_hv local_addr; + struct sockaddr_hv remote_addr; + + /* protected by the global hvsock_mutex */ + struct list_head bound_list; + struct list_head connected_list; + + struct list_head accept_queue; + /* used by enqueue and dequeue */ + struct mutex accept_queue_mutex; + + struct delayed_work dwork; + + u32 peer_shutdown; + + struct vmbus_channel *channel; + + char send_buf[HVSOCK_SND_BUF_SZ]; + char recv_buf[HVSOCK_RCV_BUF_SZ]; + unsigned int recv_data_len; + unsigned int recv_data_offset; +}; + +#endif /* __AF_HVSOCK_H__ */ diff --git a/include/uapi/linux/hyperv.h b/include/uapi/linux/hyperv.h index e4c0a35..23c29c9 100644 --- a/include/uapi/linux/hyperv.h +++ b/include/uapi/linux/hyperv.h @@ -26,6 +26,7 @@ #define _UAPI_HYPERV_H #include +#include /* * Framework version for util services. @@ -395,4 +396,19 @@ struct hv_kvp_ip_msg { struct hv_kvp_ipaddr_value kvp_ip_val; } __attribute__((packed)); +/* This is the Hyper-V socket's address format. */ +struct sockaddr_hv { + __kernel_sa_family_t shv_family; /* Address family */ + __le16 reserved; /* Must be zero */ + uuid_le shv_vm_id; /* Not used. Must be Zero. */ + uuid_le shv_service_id; /* Service ID */ +}; + +#define SHV_VMID_GUEST NULL_UUID_LE +#define SHV_VMID_HOST NULL_UUID_LE + +#define SHV_SERVICE_ID_ANY NULL_UUID_LE + +#define SHV_PROTO_RAW 1 + #endif /* _UAPI_HYPERV_H */ diff --git a/net/Kconfig b/net/Kconfig index 57a7c5a..9ad9f66 100644 --- a/net/Kconfig +++ b/net/Kconfig @@ -228,6 +228,7 @@ source "net/dns_resolver/Kconfig" source "net/batman-adv/Kconfig" source "net/openvswitch/Kconfig" source "net/vmw_vsock/Kconfig" +source "net/hv_sock/Kconfig" source "net/netlink/Kconfig" source "net/mpls/Kconfig" source "net/hsr/Kconfig" diff --git a/net/Makefile b/net/Makefile index 3995613..d95ff12 100644 --- a/net/Makefile +++ b/net/Makefile @@ -69,6 +69,7 @@ obj-$(CONFIG_BATMAN_ADV) += batman-adv/ obj-$(CONFIG_NFC) += nfc/ obj-$(CONFIG_OPENVSWITCH) += openvswitch/ obj-$(CONFIG_VSOCKETS) += vmw_vsock/ +obj-$(CONFIG_HYPERV_SOCK) += hv_sock/ obj-$(CONFIG_MPLS) += mpls/ obj-$(CONFIG_HSR) += hsr/ ifneq ($(CONFIG_NET_SWITCHDEV),) diff --git a/net/hv_sock/Kconfig b/net/hv_sock/Kconfig new file mode 100644 index 0000000..900373f --- /dev/null +++ b/net/hv_sock/Kconfig @@ -0,0 +1,10 @@ +config HYPERV_SOCK + tristate "Microsoft Hyper-V Socket (EXPERIMENTAL)" + depends on HYPERV + default m if HYPERV + help + Hyper-V Socket is a socket protocol similar to TCP, allowing + communication between a Linux guest and the host. + + To compile this driver as a module, choose M here: the module + will be called hv_sock. diff --git a/net/hv_sock/Makefile b/net/hv_sock/Makefile new file mode 100644 index 0000000..716c012 --- /dev/null +++ b/net/hv_sock/Makefile @@ -0,0 +1,3 @@ +obj-$(CONFIG_HYPERV_SOCK) += hv_sock.o + +hv_sock-y += af_hvsock.o diff --git a/net/hv_sock/af_hvsock.c b/net/hv_sock/af_hvsock.c new file mode 100644 index 0000000..873a609 --- /dev/null +++ b/net/hv_sock/af_hvsock.c @@ -0,0 +1,1430 @@ +/* + * Hyper-V vSockets driver + * + * Copyright(c) 2015, Microsoft Corporation. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The name of the author may not be used to endorse or promote + * products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING + * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include + +static struct proto hvsock_proto = { + .name = "HV_SOCK", + .owner = THIS_MODULE, + .obj_size = sizeof(struct hvsock_sock), +}; + +#define SS_LISTEN 255 + +static LIST_HEAD(hvsock_bound_list); +static LIST_HEAD(hvsock_connected_list); +static DEFINE_MUTEX(hvsock_mutex); + +static bool uuid_equals(uuid_le u1, uuid_le u2) +{ + return !uuid_le_cmp(u1, u2); +} + +/* NOTE: hvsock_mutex must be held when the below helper functions, whose + * names begin with __ hvsock, are invoked. + */ +static void __hvsock_insert_bound(struct list_head *list, + struct hvsock_sock *hvsk) +{ + sock_hold(&hvsk->sk); + list_add(&hvsk->bound_list, list); +} + +static void __hvsock_insert_connected(struct list_head *list, + struct hvsock_sock *hvsk) +{ + sock_hold(&hvsk->sk); + list_add(&hvsk->connected_list, list); +} + +static void __hvsock_remove_bound(struct hvsock_sock *hvsk) +{ + list_del_init(&hvsk->bound_list); + sock_put(&hvsk->sk); +} + +static void __hvsock_remove_connected(struct hvsock_sock *hvsk) +{ + list_del_init(&hvsk->connected_list); + sock_put(&hvsk->sk); +} + +static struct sock *__hvsock_find_bound_socket(const struct sockaddr_hv *addr) +{ + struct hvsock_sock *hvsk; + + list_for_each_entry(hvsk, &hvsock_bound_list, bound_list) + if (uuid_equals(addr->shv_service_id, + hvsk->local_addr.shv_service_id)) + return hvsock_to_sk(hvsk); + return NULL; +} + +static struct sock *__hvsock_find_connected_socket_by_channel( + const struct vmbus_channel *channel) +{ + struct hvsock_sock *hvsk; + + list_for_each_entry(hvsk, &hvsock_connected_list, connected_list) + if (hvsk->channel == channel) + return hvsock_to_sk(hvsk); + return NULL; +} + +static bool __hvsock_in_bound_list(struct hvsock_sock *hvsk) +{ + return !list_empty(&hvsk->bound_list); +} + +static bool __hvsock_in_connected_list(struct hvsock_sock *hvsk) +{ + return !list_empty(&hvsk->connected_list); +} + +static void hvsock_insert_connected(struct hvsock_sock *hvsk) +{ + __hvsock_insert_connected(&hvsock_connected_list, hvsk); +} + +static +void hvsock_enqueue_accept(struct sock *listener, struct sock *connected) +{ + struct hvsock_sock *hvlistener; + struct hvsock_sock *hvconnected; + + hvlistener = sk_to_hvsock(listener); + hvconnected = sk_to_hvsock(connected); + + sock_hold(connected); + sock_hold(listener); + + mutex_lock(&hvlistener->accept_queue_mutex); + list_add_tail(&hvconnected->accept_queue, &hvlistener->accept_queue); + listener->sk_ack_backlog++; + mutex_unlock(&hvlistener->accept_queue_mutex); +} + +static struct sock *hvsock_dequeue_accept(struct sock *listener) +{ + struct hvsock_sock *hvlistener; + struct hvsock_sock *hvconnected; + + hvlistener = sk_to_hvsock(listener); + + mutex_lock(&hvlistener->accept_queue_mutex); + + if (list_empty(&hvlistener->accept_queue)) { + mutex_unlock(&hvlistener->accept_queue_mutex); + return NULL; + } + + hvconnected = list_entry(hvlistener->accept_queue.next, + struct hvsock_sock, accept_queue); + + list_del_init(&hvconnected->accept_queue); + listener->sk_ack_backlog--; + + mutex_unlock(&hvlistener->accept_queue_mutex); + + sock_put(listener); + /* The caller will need a reference on the connected socket so we let + * it call sock_put(). + */ + + return hvsock_to_sk(hvconnected); +} + +static bool hvsock_is_accept_queue_empty(struct sock *sk) +{ + struct hvsock_sock *hvsk = sk_to_hvsock(sk); + int ret; + + mutex_lock(&hvsk->accept_queue_mutex); + ret = list_empty(&hvsk->accept_queue); + mutex_unlock(&hvsk->accept_queue_mutex); + + return ret; +} + +static void hvsock_addr_init(struct sockaddr_hv *addr, uuid_le service_id) +{ + memset(addr, 0, sizeof(*addr)); + addr->shv_family = AF_HYPERV; + addr->shv_service_id = service_id; +} + +static int hvsock_addr_validate(const struct sockaddr_hv *addr) +{ + if (!addr) + return -EFAULT; + + if (addr->shv_family != AF_HYPERV) + return -EAFNOSUPPORT; + + if (addr->reserved != 0) + return -EINVAL; + + if (!uuid_equals(addr->shv_vm_id, NULL_UUID_LE)) + return -EINVAL; + + return 0; +} + +static bool hvsock_addr_bound(const struct sockaddr_hv *addr) +{ + return !uuid_equals(addr->shv_service_id, SHV_SERVICE_ID_ANY); +} + +static int hvsock_addr_cast(const struct sockaddr *addr, size_t len, + struct sockaddr_hv **out_addr) +{ + if (len < sizeof(**out_addr)) + return -EFAULT; + + *out_addr = (struct sockaddr_hv *)addr; + return hvsock_addr_validate(*out_addr); +} + +static int __hvsock_do_bind(struct hvsock_sock *hvsk, + struct sockaddr_hv *addr) +{ + struct sockaddr_hv hv_addr; + int ret = 0; + + hvsock_addr_init(&hv_addr, addr->shv_service_id); + + mutex_lock(&hvsock_mutex); + + if (uuid_equals(addr->shv_service_id, SHV_SERVICE_ID_ANY)) { + do { + uuid_le_gen(&hv_addr.shv_service_id); + } while (__hvsock_find_bound_socket(&hv_addr)); + } else { + if (__hvsock_find_bound_socket(&hv_addr)) { + ret = -EADDRINUSE; + goto out; + } + } + + hvsock_addr_init(&hvsk->local_addr, hv_addr.shv_service_id); + __hvsock_insert_bound(&hvsock_bound_list, hvsk); + +out: + mutex_unlock(&hvsock_mutex); + + return ret; +} + +static int __hvsock_bind(struct sock *sk, struct sockaddr_hv *addr) +{ + struct hvsock_sock *hvsk = sk_to_hvsock(sk); + int ret; + + if (hvsock_addr_bound(&hvsk->local_addr)) + return -EINVAL; + + switch (sk->sk_socket->type) { + case SOCK_STREAM: + ret = __hvsock_do_bind(hvsk, addr); + break; + + default: + ret = -EINVAL; + break; + } + + return ret; +} + +/* Autobind this socket to the local address if necessary. */ +static int hvsock_auto_bind(struct hvsock_sock *hvsk) +{ + struct sock *sk = hvsock_to_sk(hvsk); + struct sockaddr_hv local_addr; + + if (hvsock_addr_bound(&hvsk->local_addr)) + return 0; + hvsock_addr_init(&local_addr, SHV_SERVICE_ID_ANY); + return __hvsock_bind(sk, &local_addr); +} + +/* hvsock_release() can be invoked in 2 paths: + * 1. on process termination: + * hvsock_sk_destruct+0x1a/0x20 + * __sk_free+0x1d/0x130 + * sk_free+0x19/0x20 + * hvsock_release+0x138/0x160 + * sock_release+0x1f/0x90 + * sock_close+0x12/0x20 + * __fput+0xdf/0x1f0 + * ____fput+0xe/0x10 + * task_work_run+0xd4/0xf0 + * do_exit+0x334/0xb90 + * ? __do_page_fault+0x1e1/0x490 + * ? lockdep_sys_exit_thunk+0x35/0x67 + * do_group_exit+0x54/0xe0 + * SyS_exit_group+0x14/0x20 + * system_call_fastpath+0x16/0x1b + * + * 2. when accept() returns -ENITR: + * hvsock_release+0x151/0x160 + * sock_release+0x1f/0x90 + * sock_close+0x12/0x20 + * __fput+0xdf/0x1f0 + * ____fput+0xe/0x10 + * task_work_run+0xb7/0xf0 + * get_signal+0x750/0x770 + * do_signal+0x28/0xbb0 + * ? put_unused_fd+0x52/0x60 + * ? SYSC_accept4+0x1ca/0x220 + * do_notify_resume+0x4f/0x90 + * int_signal+0x12/0x17 + */ + +static void hvsock_sk_destruct(struct sock *sk) +{ + struct hvsock_sock *hvsk = sk_to_hvsock(sk); + + /* We use the mutex to serialize this function + * with hvsock_process_closing_connection(), otherwise, this function + * may free the channel in vmbus_close_internal() while + *hvsock_process_closing_connection() is still referencing the channel. + */ + mutex_lock(&hvsock_mutex); + + if (hvsk->channel) { + hvsk->channel->rescind = true; + vmbus_close(hvsk->channel); + hvsk->channel = NULL; + } + + mutex_unlock(&hvsock_mutex); +} + +static void __hvsock_release(struct sock *sk) +{ + struct hvsock_sock *hvsk; + struct sock *pending; + + hvsk = sk_to_hvsock(sk); + + mutex_lock(&hvsock_mutex); + if (__hvsock_in_bound_list(hvsk)) + __hvsock_remove_bound(hvsk); + + if (__hvsock_in_connected_list(hvsk)) + __hvsock_remove_connected(hvsk); + mutex_unlock(&hvsock_mutex); + + lock_sock(sk); + sock_orphan(sk); + sk->sk_shutdown = SHUTDOWN_MASK; + + /* Clean up any sockets that never were accepted. */ + while ((pending = hvsock_dequeue_accept(sk)) != NULL) { + __hvsock_release(pending); + sock_put(pending); + } + + release_sock(sk); + sock_put(sk); +} + +static int hvsock_release(struct socket *sock) +{ + /* sock->sk is NULL, if accept() is interrupted by a signal */ + if (sock->sk) { + __hvsock_release(sock->sk); + sock->sk = NULL; + } + + sock->state = SS_FREE; + return 0; +} + +static struct sock *__hvsock_create(struct net *net, struct socket *sock, + gfp_t priority, unsigned short type) +{ + struct hvsock_sock *hvsk; + struct sock *sk; + + sk = sk_alloc(net, AF_HYPERV, priority, &hvsock_proto, 0); + if (!sk) + return NULL; + + sock_init_data(sock, sk); + + /* sk->sk_type is normally set in sock_init_data, but only if sock is + * non-NULL. We make sure that our sockets always have a type by + * setting it here if needed. + */ + if (!sock) + sk->sk_type = type; + + hvsk = sk_to_hvsock(sk); + hvsock_addr_init(&hvsk->local_addr, SHV_SERVICE_ID_ANY); + hvsock_addr_init(&hvsk->remote_addr, SHV_SERVICE_ID_ANY); + + sk->sk_destruct = hvsock_sk_destruct; + + /* Looks stream-based socket doesn't need this. */ + sk->sk_backlog_rcv = NULL; + + sk->sk_state = 0; + sock_reset_flag(sk, SOCK_DONE); + + INIT_LIST_HEAD(&hvsk->bound_list); + INIT_LIST_HEAD(&hvsk->connected_list); + + INIT_LIST_HEAD(&hvsk->accept_queue); + mutex_init(&hvsk->accept_queue_mutex); + + hvsk->peer_shutdown = 0; + + hvsk->recv_data_len = 0; + hvsk->recv_data_offset = 0; + + return sk; +} + +static int hvsock_bind(struct socket *sock, struct sockaddr *addr, + int addr_len) +{ + struct sockaddr_hv *hv_addr; + struct sock *sk; + int ret; + + sk = sock->sk; + + if (hvsock_addr_cast(addr, addr_len, &hv_addr) != 0) + return -EINVAL; + + lock_sock(sk); + ret = __hvsock_bind(sk, hv_addr); + release_sock(sk); + + return ret; +} + +static int hvsock_getname(struct socket *sock, + struct sockaddr *addr, int *addr_len, int peer) +{ + struct sockaddr_hv *hv_addr; + struct hvsock_sock *hvsk; + struct sock *sk; + int ret; + + sk = sock->sk; + hvsk = sk_to_hvsock(sk); + ret = 0; + + lock_sock(sk); + + if (peer) { + if (sock->state != SS_CONNECTED) { + ret = -ENOTCONN; + goto out; + } + hv_addr = &hvsk->remote_addr; + } else { + hv_addr = &hvsk->local_addr; + } + + __sockaddr_check_size(sizeof(*hv_addr)); + + memcpy(addr, hv_addr, sizeof(*hv_addr)); + *addr_len = sizeof(*hv_addr); + +out: + release_sock(sk); + return ret; +} + +static int hvsock_shutdown(struct socket *sock, int mode) +{ + struct sock *sk; + + if (mode < SHUT_RD || mode > SHUT_RDWR) + return -EINVAL; + /* This maps: + * SHUT_RD (0) -> RCV_SHUTDOWN (1) + * SHUT_WR (1) -> SEND_SHUTDOWN (2) + * SHUT_RDWR (2) -> SHUTDOWN_MASK (3) + */ + ++mode; + + if (sock->state == SS_UNCONNECTED) + return -ENOTCONN; + + sock->state = SS_DISCONNECTING; + + sk = sock->sk; + + lock_sock(sk); + + sk->sk_shutdown |= mode; + sk->sk_state_change(sk); + + /* TODO: how to send a FIN if we haven't done that? */ + if (mode & SEND_SHUTDOWN) + ; + + release_sock(sk); + + return 0; +} + +static unsigned int hvsock_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct vmbus_channel *channel; + bool can_read, can_write; + struct hvsock_sock *hvsk; + struct sock *sk; + unsigned int mask; + + sk = sock->sk; + hvsk = sk_to_hvsock(sk); + + poll_wait(file, sk_sleep(sk), wait); + mask = 0; + + if (sk->sk_err) + /* Signify that there has been an error on this socket. */ + mask |= POLLERR; + + /* INET sockets treat local write shutdown and peer write shutdown as a + * case of POLLHUP set. + */ + if ((sk->sk_shutdown == SHUTDOWN_MASK) || + ((sk->sk_shutdown & SEND_SHUTDOWN) && + (hvsk->peer_shutdown & SEND_SHUTDOWN))) { + mask |= POLLHUP; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN || + hvsk->peer_shutdown & SEND_SHUTDOWN) { + mask |= POLLRDHUP; + } + + lock_sock(sk); + + /* Listening sockets that have connections in their accept + * queue can be read. + */ + if (sk->sk_state == SS_LISTEN && !hvsock_is_accept_queue_empty(sk)) + mask |= POLLIN | POLLRDNORM; + + /* The mutex is to against hvsock_process_new_connection() */ + mutex_lock(&hvsock_mutex); + + channel = hvsk->channel; + if (channel) { + /* If there is something in the queue then we can read */ + vmbus_get_hvsock_rw_status(channel, &can_read, &can_write); + + if (!can_read && hvsk->recv_data_len > 0) + can_read = true; + + if (!(sk->sk_shutdown & RCV_SHUTDOWN) && can_read) + mask |= POLLIN | POLLRDNORM; + } else { + can_read = false; + can_write = false; + } + + mutex_unlock(&hvsock_mutex); + + /* Sockets whose connections have been closed terminated should + * also be considered read, and we check the shutdown flag for that. + */ + if (sk->sk_shutdown & RCV_SHUTDOWN || + hvsk->peer_shutdown & SEND_SHUTDOWN) { + mask |= POLLIN | POLLRDNORM; + } + + /* Connected sockets that can produce data can be written. */ + if (sk->sk_state == SS_CONNECTED && can_write && + !(sk->sk_shutdown & SEND_SHUTDOWN)) { + /* Remove POLLWRBAND since INET sockets are not setting it. + */ + mask |= POLLOUT | POLLWRNORM; + } + + /* Simulate INET socket poll behaviors, which sets + * POLLOUT|POLLWRNORM when peer is closed and nothing to read, + * but local send is not shutdown. + */ + if (sk->sk_state == SS_UNCONNECTED && + !(sk->sk_shutdown & SEND_SHUTDOWN)) + mask |= POLLOUT | POLLWRNORM; + + release_sock(sk); + + return mask; +} + +/* This function runs in the tasklet context of process_chn_event() */ +static void hvsock_events(void *ctx) +{ + struct sock *sk = (struct sock *)ctx; + struct hvsock_sock *hvsk = sk_to_hvsock(sk); + struct vmbus_channel *channel = hvsk->channel; + bool can_read, can_write; + + BUG_ON(!channel); + + vmbus_get_hvsock_rw_status(channel, &can_read, &can_write); + + if (can_read) + sk->sk_data_ready(sk); + + if (can_write) + sk->sk_write_space(sk); +} + +static int hvsock_process_new_connection(struct vmbus_channel *channel) +{ + struct hvsock_sock *hvsk, *new_hvsk; + struct sockaddr_hv hv_addr; + struct sock *sk, *new_sk; + + uuid_le *instance, *service_id; + int ret; + + instance = &channel->offermsg.offer.if_instance; + service_id = &channel->offermsg.offer.if_type; + + hvsock_addr_init(&hv_addr, *instance); + + mutex_lock(&hvsock_mutex); + + sk = __hvsock_find_bound_socket(&hv_addr); + + if (sk) { + /* It is from the guest client's connect() */ + if (sk->sk_state != SS_CONNECTING) { + ret = -ENXIO; + goto out; + } + + hvsk = sk_to_hvsock(sk); + hvsk->channel = channel; + set_channel_read_state(channel, false); + ret = vmbus_open(channel, VMBUS_RINGBUFFER_SIZE_HVSOCK_SEND, + VMBUS_RINGBUFFER_SIZE_HVSOCK_RECV, NULL, 0, + hvsock_events, sk); + if (ret != 0) { + hvsk->channel = NULL; + goto out; + } + + set_channel_pending_send_size(channel, + HVSOCK_PKT_LEN(PAGE_SIZE)); + sk->sk_state = SS_CONNECTED; + sk->sk_socket->state = SS_CONNECTED; + hvsock_insert_connected(hvsk); + sk->sk_state_change(sk); + goto out; + } + + /* Now we suppose it is from a host client's connect() */ + hvsock_addr_init(&hv_addr, *service_id); + sk = __hvsock_find_bound_socket(&hv_addr); + + /* No guest server listening? Well, let's ignore the offer */ + if (!sk || sk->sk_state != SS_LISTEN) { + ret = -ENXIO; + goto out; + } + + if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) { + ret = -EMFILE; + goto out; + } + + new_sk = __hvsock_create(sock_net(sk), NULL, GFP_KERNEL, sk->sk_type); + if (!new_sk) { + ret = -ENOMEM; + goto out; + } + + new_hvsk = sk_to_hvsock(new_sk); + new_sk->sk_state = SS_CONNECTING; + hvsock_addr_init(&new_hvsk->local_addr, *service_id); + hvsock_addr_init(&new_hvsk->remote_addr, *instance); + + set_channel_read_state(channel, false); + new_hvsk->channel = channel; + ret = vmbus_open(channel, VMBUS_RINGBUFFER_SIZE_HVSOCK_SEND, + VMBUS_RINGBUFFER_SIZE_HVSOCK_RECV, NULL, 0, + hvsock_events, new_sk); + if (ret != 0) { + new_hvsk->channel = NULL; + sock_put(new_sk); + goto out; + } + set_channel_pending_send_size(channel, HVSOCK_PKT_LEN(PAGE_SIZE)); + + new_sk->sk_state = SS_CONNECTED; + hvsock_insert_connected(new_hvsk); + hvsock_enqueue_accept(sk, new_sk); + sk->sk_state_change(sk); +out: + mutex_unlock(&hvsock_mutex); + return ret; +} + +/* We don't invoke vmbus_close() in this function. Instead, we invoke + * vmbus_close() in hvsock_sk_destruct(). +*/ +static void hvsock_process_closing_connection(struct vmbus_channel *channel) +{ + struct hvsock_sock *hvsk; + struct sock *sk; + + mutex_lock(&hvsock_mutex); + + sk = __hvsock_find_connected_socket_by_channel(channel); + + /* The guest has already closed the connection? */ + if (!sk) + goto out; + + sk->sk_socket->state = SS_UNCONNECTED; + sk->sk_state = SS_UNCONNECTED; + sock_set_flag(sk, SOCK_DONE); + + hvsk = sk_to_hvsock(sk); + hvsk->peer_shutdown |= SEND_SHUTDOWN | RCV_SHUTDOWN; + + sk->sk_state_change(sk); + +out: + mutex_unlock(&hvsock_mutex); +} + +static void hvsock_connect_timeout(struct work_struct *work) +{ + struct hvsock_sock *hvsk; + struct sock *sk; + + hvsk = container_of(work, struct hvsock_sock, dwork.work); + sk = hvsock_to_sk(hvsk); + + lock_sock(sk); + if ((sk->sk_state == SS_CONNECTING) && + (sk->sk_shutdown != SHUTDOWN_MASK)) { + sk->sk_state = SS_UNCONNECTED; + sk->sk_err = ETIMEDOUT; + sk->sk_error_report(sk); + } + release_sock(sk); + + sock_put(sk); +} + +static int hvsock_connect(struct socket *sock, struct sockaddr *addr, + int addr_len, int flags) +{ + struct sockaddr_hv *remote_addr; + struct hvsock_sock *hvsk; + struct sock *sk; + + DEFINE_WAIT(wait); + long timeout; + + int ret = 0; + + sk = sock->sk; + hvsk = sk_to_hvsock(sk); + + lock_sock(sk); + + switch (sock->state) { + case SS_CONNECTED: + ret = -EISCONN; + goto out; + case SS_DISCONNECTING: + ret = -EINVAL; + goto out; + case SS_CONNECTING: + /* This continues on so we can move sock into the SS_CONNECTED + * state once the connection has completed (at which point err + * will be set to zero also). Otherwise, we will either wait + * for the connection or return -EALREADY should this be a + * non-blocking call. + */ + ret = -EALREADY; + break; + default: + if ((sk->sk_state == SS_LISTEN) || + hvsock_addr_cast(addr, addr_len, &remote_addr) != 0) { + ret = -EINVAL; + goto out; + } + + /* Set the remote address that we are connecting to. */ + memcpy(&hvsk->remote_addr, remote_addr, + sizeof(hvsk->remote_addr)); + + ret = hvsock_auto_bind(hvsk); + if (ret) + goto out; + + sk->sk_state = SS_CONNECTING; + + ret = vmbus_send_tl_connect_request( + &hvsk->local_addr.shv_service_id, + &hvsk->remote_addr.shv_service_id); + if (ret < 0) + goto out; + + /* Mark sock as connecting and set the error code to in + * progress in case this is a non-blocking connect. + */ + sock->state = SS_CONNECTING; + ret = -EINPROGRESS; + } + + /* The receive path will handle all communication until we are able to + * enter the connected state. Here we wait for the connection to be + * completed or a notification of an error. + */ + timeout = 30 * HZ; + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + while (sk->sk_state != SS_CONNECTED && sk->sk_err == 0) { + if (flags & O_NONBLOCK) { + /* If we're not going to block, we schedule a timeout + * function to generate a timeout on the connection + * attempt, in case the peer doesn't respond in a + * timely manner. We hold on to the socket until the + * timeout fires. + */ + sock_hold(sk); + INIT_DELAYED_WORK(&hvsk->dwork, + hvsock_connect_timeout); + schedule_delayed_work(&hvsk->dwork, timeout); + + /* Skip ahead to preserve error code set above. */ + goto out_wait; + } + + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + + if (signal_pending(current)) { + ret = sock_intr_errno(timeout); + goto out_wait_error; + } else if (timeout == 0) { + ret = -ETIMEDOUT; + goto out_wait_error; + } + + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + } + + if (sk->sk_err) { + ret = -sk->sk_err; + goto out_wait_error; + } else { + ret = 0; + } + +out_wait: + finish_wait(sk_sleep(sk), &wait); +out: + release_sock(sk); + return ret; + +out_wait_error: + sk->sk_state = SS_UNCONNECTED; + sock->state = SS_UNCONNECTED; + goto out_wait; +} + +static +int hvsock_accept(struct socket *sock, struct socket *newsock, int flags) +{ + struct hvsock_sock *hvconnected; + struct sock *connected; + struct sock *listener; + + DEFINE_WAIT(wait); + long timeout; + + int ret = 0; + + listener = sock->sk; + + lock_sock(listener); + + if (sock->type != SOCK_STREAM) { + ret = -EOPNOTSUPP; + goto out; + } + + if (listener->sk_state != SS_LISTEN) { + ret = -EINVAL; + goto out; + } + + /* Wait for children sockets to appear; these are the new sockets + * created upon connection establishment. + */ + timeout = sock_sndtimeo(listener, flags & O_NONBLOCK); + prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); + + while ((connected = hvsock_dequeue_accept(listener)) == NULL && + listener->sk_err == 0) { + release_sock(listener); + timeout = schedule_timeout(timeout); + lock_sock(listener); + + if (signal_pending(current)) { + ret = sock_intr_errno(timeout); + goto out_wait; + } else if (timeout == 0) { + ret = -EAGAIN; + goto out_wait; + } + + prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); + } + + if (listener->sk_err) + ret = -listener->sk_err; + + if (connected) { + lock_sock(connected); + hvconnected = sk_to_hvsock(connected); + + /* If the listener socket has received an error, then we should + * reject this socket and return. Note that we simply mark the + * socket rejected, drop our reference, and let the cleanup + * function handle the cleanup; the fact that we found it in + * the listener's accept queue guarantees that the cleanup + * function hasn't run yet. + */ + if (ret) { + release_sock(connected); + sock_put(connected); + goto out_wait; + } + + newsock->state = SS_CONNECTED; + sock_graft(connected, newsock); + release_sock(connected); + sock_put(connected); + } + +out_wait: + finish_wait(sk_sleep(listener), &wait); +out: + release_sock(listener); + return ret; +} + +static int hvsock_listen(struct socket *sock, int backlog) +{ + struct hvsock_sock *hvsk; + struct sock *sk; + int ret = 0; + + sk = sock->sk; + lock_sock(sk); + + if (sock->type != SOCK_STREAM) { + ret = -EOPNOTSUPP; + goto out; + } + + if (sock->state != SS_UNCONNECTED) { + ret = -EINVAL; + goto out; + } + + if (backlog <= 0) { + ret = -EINVAL; + goto out; + } + /* This is an artificial limit */ + if (backlog > 128) + backlog = 128; + + hvsk = sk_to_hvsock(sk); + if (!hvsock_addr_bound(&hvsk->local_addr)) { + ret = -EINVAL; + goto out; + } + + sk->sk_ack_backlog = 0; + sk->sk_max_ack_backlog = backlog; + sk->sk_state = SS_LISTEN; +out: + release_sock(sk); + return ret; +} + +static int hvsock_setsockopt(struct socket *sock, + int level, + int optname, + char __user *optval, unsigned int optlen) +{ + return -ENOPROTOOPT; +} + +static int hvsock_getsockopt(struct socket *sock, + int level, + int optname, + char __user *optval, int __user *optlen) +{ + return -ENOPROTOOPT; +} + +static int hvsock_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct vmbus_channel *channel; + struct hvsock_sock *hvsk; + struct sock *sk; + + size_t total_to_write = len; + size_t total_written = 0; + + bool can_write; + long timeout; + int ret = 0; + + DEFINE_WAIT(wait); + + if (len == 0) + return -EINVAL; + + if (msg->msg_flags & ~MSG_DONTWAIT) { + pr_err("hvsock_sendmsg: unsupported flags=0x%x\n", + msg->msg_flags); + return -EOPNOTSUPP; + } + + sk = sock->sk; + hvsk = sk_to_hvsock(sk); + channel = hvsk->channel; + + lock_sock(sk); + + /* Callers should not provide a destination with stream sockets. */ + if (msg->msg_namelen) { + ret = -EOPNOTSUPP; + goto out; + } + + /* Send data only if both sides are not shutdown in the direction. */ + if (sk->sk_shutdown & SEND_SHUTDOWN || + hvsk->peer_shutdown & RCV_SHUTDOWN) { + ret = -EPIPE; + goto out; + } + + if (sk->sk_state != SS_CONNECTED || + !hvsock_addr_bound(&hvsk->local_addr)) { + ret = -ENOTCONN; + goto out; + } + + if (!hvsock_addr_bound(&hvsk->remote_addr)) { + ret = -EDESTADDRREQ; + goto out; + } + + timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + while (total_to_write > 0) { + u32 to_write; + + while (1) { + vmbus_get_hvsock_rw_status(channel, NULL, &can_write); + + if (can_write || sk->sk_err != 0 || + (sk->sk_shutdown & SEND_SHUTDOWN) || + (hvsk->peer_shutdown & RCV_SHUTDOWN)) + break; + + /* Don't wait for non-blocking sockets. */ + if (timeout == 0) { + ret = -EAGAIN; + goto out_wait; + } + + release_sock(sk); + + timeout = schedule_timeout(timeout); + + lock_sock(sk); + if (signal_pending(current)) { + ret = sock_intr_errno(timeout); + goto out_wait; + } else if (timeout == 0) { + ret = -EAGAIN; + goto out_wait; + } + + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + } + + /* These checks occur both as part of and after the loop + * conditional since we need to check before and after + * sleeping. + */ + if (sk->sk_err) { + ret = -sk->sk_err; + goto out_wait; + } else if ((sk->sk_shutdown & SEND_SHUTDOWN) || + (hvsk->peer_shutdown & RCV_SHUTDOWN)) { + ret = -EPIPE; + goto out_wait; + } + + /* Note: that write will only write as many bytes as possible + * in the ringbuffer. It is the caller's responsibility to + * check how many bytes we actually wrote. + */ + do { + to_write = min_t(size_t, HVSOCK_SND_BUF_SZ, + total_to_write); + ret = memcpy_from_msg(hvsk->send_buf, msg, to_write); + if (ret != 0) + goto out_wait; + + ret = vmbus_sendpacket_hvsock(channel, + hvsk->send_buf, + to_write); + if (ret != 0) + goto out_wait; + + total_written += to_write; + total_to_write -= to_write; + } while (total_to_write > 0); + } +out_wait: + if (total_written > 0) + ret = total_written; + + finish_wait(sk_sleep(sk), &wait); +out: + release_sock(sk); + + /* ret is a bigger-than-0 total_written or a negative err code. */ + BUG_ON(ret == 0); + + return ret; +} + +static int hvsock_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct vmbus_channel *channel; + struct hvsock_sock *hvsk; + struct sock *sk; + + size_t total_to_read = len; + size_t copied; + + bool can_read; + long timeout; + + int ret = 0; + + DEFINE_WAIT(wait); + + sk = sock->sk; + hvsk = sk_to_hvsock(sk); + channel = hvsk->channel; + + lock_sock(sk); + + if (sk->sk_state != SS_CONNECTED) { + /* Recvmsg is supposed to return 0 if a peer performs an + * orderly shutdown. Differentiate between that case and when a + * peer has not connected or a local shutdown occurred with the + * SOCK_DONE flag. + */ + if (sock_flag(sk, SOCK_DONE)) + ret = 0; + else + ret = -ENOTCONN; + + goto out; + } + + /* We ignore msg->addr_name/len. */ + if (flags & ~MSG_DONTWAIT) { + pr_err("hvsock_recvmsg: unsupported flags=0x%x\n", flags); + ret = -EOPNOTSUPP; + goto out; + } + + /* We don't check peer_shutdown flag here since peer may actually shut + * down, but there can be data in the queue that a local socket can + * receive. + */ + if (sk->sk_shutdown & RCV_SHUTDOWN) { + ret = 0; + goto out; + } + + /* It is valid on Linux to pass in a zero-length receive buffer. This + * is not an error. We may as well bail out now. + */ + if (!len) { + ret = 0; + goto out; + } + + timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + copied = 0; + + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + while (1) { + bool need_refill = hvsk->recv_data_len == 0; + + if (need_refill) + vmbus_get_hvsock_rw_status(channel, &can_read, NULL); + else + can_read = true; + + if (can_read) { + u32 payload_len; + + if (need_refill) { + ret = vmbus_recvpacket_hvsock(channel, + hvsk->recv_buf, + HVSOCK_RCV_BUF_SZ, + &payload_len); + if (ret != 0 || payload_len == 0) { + ret = -EIO; + goto out_wait; + } + BUG_ON(payload_len > HVSOCK_RCV_BUF_SZ); + + hvsk->recv_data_len = payload_len; + hvsk->recv_data_offset = 0; + } + + if (hvsk->recv_data_len <= total_to_read) { + ret = memcpy_to_msg(msg, hvsk->recv_buf + + hvsk->recv_data_offset, + hvsk->recv_data_len); + if (ret != 0) + break; + + copied += hvsk->recv_data_len; + total_to_read -= hvsk->recv_data_len; + hvsk->recv_data_len = 0; + hvsk->recv_data_offset = 0; + + if (total_to_read == 0) + break; + } else { + ret = memcpy_to_msg(msg, hvsk->recv_buf + + hvsk->recv_data_offset, + total_to_read); + if (ret != 0) + break; + + copied += total_to_read; + hvsk->recv_data_len -= total_to_read; + hvsk->recv_data_offset += total_to_read; + total_to_read = 0; + break; + } + } else { + if (sk->sk_err || (sk->sk_shutdown & RCV_SHUTDOWN) || + (hvsk->peer_shutdown & SEND_SHUTDOWN)) + break; + + /* Don't wait for non-blocking sockets. */ + if (timeout == 0) { + ret = -EAGAIN; + break; + } + + if (copied > 0) + break; + + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + + if (signal_pending(current)) { + ret = sock_intr_errno(timeout); + break; + } else if (timeout == 0) { + ret = -EAGAIN; + break; + } + + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + } + } + + if (sk->sk_err) + ret = -sk->sk_err; + else if (sk->sk_shutdown & RCV_SHUTDOWN) + ret = 0; + + if (copied > 0) { + ret = copied; + + /* If the other side has shutdown for sending and there + * is nothing more to read, then we modify the socket + * state. + */ + if ((hvsk->peer_shutdown & SEND_SHUTDOWN) && + hvsk->recv_data_len == 0) { + vmbus_get_hvsock_rw_status(channel, &can_read, NULL); + if (!can_read) { + sk->sk_state = SS_UNCONNECTED; + sock_set_flag(sk, SOCK_DONE); + sk->sk_state_change(sk); + } + } + } +out_wait: + finish_wait(sk_sleep(sk), &wait); +out: + release_sock(sk); + return ret; +} + +static const struct proto_ops hvsock_ops = { + .family = PF_HYPERV, + .owner = THIS_MODULE, + .release = hvsock_release, + .bind = hvsock_bind, + .connect = hvsock_connect, + .socketpair = sock_no_socketpair, + .accept = hvsock_accept, + .getname = hvsock_getname, + .poll = hvsock_poll, + .ioctl = sock_no_ioctl, + .listen = hvsock_listen, + .shutdown = hvsock_shutdown, + .setsockopt = hvsock_setsockopt, + .getsockopt = hvsock_getsockopt, + .sendmsg = hvsock_sendmsg, + .recvmsg = hvsock_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static int hvsock_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + if (!capable(CAP_SYS_ADMIN) && !capable(CAP_NET_ADMIN)) + return -EPERM; + + if (protocol != 0 && protocol != SHV_PROTO_RAW) + return -EPROTONOSUPPORT; + + switch (sock->type) { + case SOCK_STREAM: + sock->ops = &hvsock_ops; + break; + default: + return -ESOCKTNOSUPPORT; + } + + sock->state = SS_UNCONNECTED; + + return __hvsock_create(net, sock, GFP_KERNEL, 0) ? 0 : -ENOMEM; +} + +static const struct net_proto_family hvsock_family_ops = { + .family = AF_HYPERV, + .create = hvsock_create, + .owner = THIS_MODULE, +}; + +static int __init hvsock_init(void) +{ + int ret; + + /* Hyper-V socket requires at least VMBus 4.0 */ + if ((vmbus_proto_version >> 16) < 4) { + pr_err("failed to load: VMBus 4 or later is required\n"); + return -ENODEV; + } + + ret = proto_register(&hvsock_proto, 0); + if (ret) { + pr_err("failed to register protocol\n"); + goto err1; + } + + ret = sock_register(&hvsock_family_ops); + if (ret) { + pr_err("failed to register address family\n"); + goto err2; + } + + vmbus_register_hvsock_callbacks(&hvsock_process_new_connection, + &hvsock_process_closing_connection); + return 0; + +err2: + proto_unregister(&hvsock_proto); +err1: + return ret; +} + +static void __exit hvsock_exit(void) +{ + vmbus_unregister_hvsock_callbacks(); + sock_unregister(AF_HYPERV); + proto_unregister(&hvsock_proto); +} + +module_init(hvsock_init); +module_exit(hvsock_exit); + +MODULE_DESCRIPTION("Microsoft Hyper-V Virtual Socket Family"); +MODULE_VERSION("0.1"); +MODULE_LICENSE("Dual BSD/GPL");