diff mbox series

[2/6] Drivers: hv: vmbus: Introduce vmbus_sendpacket_getid()

Message ID 20220928172646.19337-3-tim.gardner@canonical.com
State New
Headers show
Series Azure: PCI: Fix synchronization | expand

Commit Message

Tim Gardner Sept. 28, 2022, 5:26 p.m. UTC
From: "Andrea Parri (Microsoft)" <parri.andrea@gmail.com>

BugLink: https://bugs.launchpad.net/bugs/1991134

The function can be used to send a VMbus packet and retrieve the
corresponding transaction ID.  It will be used by hv_pci.

No functional change.

Suggested-by: Michael Kelley <mikelley@microsoft.com>
Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
Reviewed-by: Michael Kelley <mikelley@microsoft.com>
Link: https://lore.kernel.org/r/20220419122325.10078-4-parri.andrea@gmail.com
Signed-off-by: Wei Liu <wei.liu@kernel.org>
(cherry picked from commit b03afa57c65e1e045e02df49777e953742745f4c)
Signed-off-by: Tim Gardner <tim.gardner@canonical.com>
---
 drivers/hv/channel.c      | 38 ++++++++++++++++++++++++++++++++------
 drivers/hv/hyperv_vmbus.h |  2 +-
 drivers/hv/ring_buffer.c  | 14 +++++++++++---
 include/linux/hyperv.h    |  7 +++++++
 4 files changed, 51 insertions(+), 10 deletions(-)
diff mbox series

Patch

diff --git a/drivers/hv/channel.c b/drivers/hv/channel.c
index 20fc8d50a039..585a8084848b 100644
--- a/drivers/hv/channel.c
+++ b/drivers/hv/channel.c
@@ -1022,11 +1022,13 @@  void vmbus_close(struct vmbus_channel *channel)
 EXPORT_SYMBOL_GPL(vmbus_close);
 
 /**
- * vmbus_sendpacket() - Send the specified buffer on the given channel
+ * vmbus_sendpacket_getid() - Send the specified buffer on the given channel
  * @channel: Pointer to vmbus_channel structure
  * @buffer: Pointer to the buffer you want to send the data from.
  * @bufferlen: Maximum size of what the buffer holds.
  * @requestid: Identifier of the request
+ * @trans_id: Identifier of the transaction associated to this request, if
+ *            the send is successful; undefined, otherwise.
  * @type: Type of packet that is being sent e.g. negotiate, time
  *	  packet etc.
  * @flags: 0 or VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED
@@ -1036,8 +1038,8 @@  EXPORT_SYMBOL_GPL(vmbus_close);
  *
  * Mainly used by Hyper-V drivers.
  */
-int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
-			   u32 bufferlen, u64 requestid,
+int vmbus_sendpacket_getid(struct vmbus_channel *channel, void *buffer,
+			   u32 bufferlen, u64 requestid, u64 *trans_id,
 			   enum vmbus_packet_type type, u32 flags)
 {
 	struct vmpacket_descriptor desc;
@@ -1063,7 +1065,31 @@  int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
 	bufferlist[2].iov_base = &aligned_data;
 	bufferlist[2].iov_len = (packetlen_aligned - packetlen);
 
-	return hv_ringbuffer_write(channel, bufferlist, num_vecs, requestid);
+	return hv_ringbuffer_write(channel, bufferlist, num_vecs, requestid, trans_id);
+}
+EXPORT_SYMBOL(vmbus_sendpacket_getid);
+
+/**
+ * vmbus_sendpacket() - Send the specified buffer on the given channel
+ * @channel: Pointer to vmbus_channel structure
+ * @buffer: Pointer to the buffer you want to send the data from.
+ * @bufferlen: Maximum size of what the buffer holds.
+ * @requestid: Identifier of the request
+ * @type: Type of packet that is being sent e.g. negotiate, time
+ *	  packet etc.
+ * @flags: 0 or VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED
+ *
+ * Sends data in @buffer directly to Hyper-V via the vmbus.
+ * This will send the data unparsed to Hyper-V.
+ *
+ * Mainly used by Hyper-V drivers.
+ */
+int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
+		     u32 bufferlen, u64 requestid,
+		     enum vmbus_packet_type type, u32 flags)
+{
+	return vmbus_sendpacket_getid(channel, buffer, bufferlen,
+				      requestid, NULL, type, flags);
 }
 EXPORT_SYMBOL(vmbus_sendpacket);
 
@@ -1122,7 +1148,7 @@  int vmbus_sendpacket_pagebuffer(struct vmbus_channel *channel,
 	bufferlist[2].iov_base = &aligned_data;
 	bufferlist[2].iov_len = (packetlen_aligned - packetlen);
 
-	return hv_ringbuffer_write(channel, bufferlist, 3, requestid);
+	return hv_ringbuffer_write(channel, bufferlist, 3, requestid, NULL);
 }
 EXPORT_SYMBOL_GPL(vmbus_sendpacket_pagebuffer);
 
@@ -1160,7 +1186,7 @@  int vmbus_sendpacket_mpb_desc(struct vmbus_channel *channel,
 	bufferlist[2].iov_base = &aligned_data;
 	bufferlist[2].iov_len = (packetlen_aligned - packetlen);
 
-	return hv_ringbuffer_write(channel, bufferlist, 3, requestid);
+	return hv_ringbuffer_write(channel, bufferlist, 3, requestid, NULL);
 }
 EXPORT_SYMBOL_GPL(vmbus_sendpacket_mpb_desc);
 
diff --git a/drivers/hv/hyperv_vmbus.h b/drivers/hv/hyperv_vmbus.h
index 3a1f007b678a..64c0b9cbe183 100644
--- a/drivers/hv/hyperv_vmbus.h
+++ b/drivers/hv/hyperv_vmbus.h
@@ -181,7 +181,7 @@  void hv_ringbuffer_cleanup(struct hv_ring_buffer_info *ring_info);
 
 int hv_ringbuffer_write(struct vmbus_channel *channel,
 			const struct kvec *kv_list, u32 kv_count,
-			u64 requestid);
+			u64 requestid, u64 *trans_id);
 
 int hv_ringbuffer_read(struct vmbus_channel *channel,
 		       void *buffer, u32 buflen, u32 *buffer_actual_len,
diff --git a/drivers/hv/ring_buffer.c b/drivers/hv/ring_buffer.c
index 1602c16729f8..4cb3d16b8385 100644
--- a/drivers/hv/ring_buffer.c
+++ b/drivers/hv/ring_buffer.c
@@ -281,7 +281,7 @@  void hv_ringbuffer_cleanup(struct hv_ring_buffer_info *ring_info)
 /* Write to the ring buffer. */
 int hv_ringbuffer_write(struct vmbus_channel *channel,
 			const struct kvec *kv_list, u32 kv_count,
-			u64 requestid)
+			u64 requestid, u64 *trans_id)
 {
 	int i;
 	u32 bytes_avail_towrite;
@@ -292,7 +292,7 @@  int hv_ringbuffer_write(struct vmbus_channel *channel,
 	unsigned long flags;
 	struct hv_ring_buffer_info *outring_info = &channel->outbound;
 	struct vmpacket_descriptor *desc = kv_list[0].iov_base;
-	u64 rqst_id = VMBUS_NO_RQSTOR;
+	u64 __trans_id, rqst_id = VMBUS_NO_RQSTOR;
 
 	if (channel->rescind)
 		return -ENODEV;
@@ -351,7 +351,15 @@  int hv_ringbuffer_write(struct vmbus_channel *channel,
 		}
 	}
 	desc = hv_get_ring_buffer(outring_info) + old_write;
-	desc->trans_id = (rqst_id == VMBUS_NO_RQSTOR) ? requestid : rqst_id;
+	__trans_id = (rqst_id == VMBUS_NO_RQSTOR) ? requestid : rqst_id;
+	/*
+	 * Ensure the compiler doesn't generate code that reads the value of
+	 * the transaction ID from the ring buffer, which is shared with the
+	 * Hyper-V host and subject to being changed at any time.
+	 */
+	WRITE_ONCE(desc->trans_id, __trans_id);
+	if (trans_id)
+		*trans_id = __trans_id;
 
 	/* Set previous packet start */
 	prev_indices = hv_get_ring_bufferindices(outring_info);
diff --git a/include/linux/hyperv.h b/include/linux/hyperv.h
index 57ad6de53fbe..000bc299d12d 100644
--- a/include/linux/hyperv.h
+++ b/include/linux/hyperv.h
@@ -1175,6 +1175,13 @@  extern int vmbus_open(struct vmbus_channel *channel,
 
 extern void vmbus_close(struct vmbus_channel *channel);
 
+extern int vmbus_sendpacket_getid(struct vmbus_channel *channel,
+				  void *buffer,
+				  u32 bufferLen,
+				  u64 requestid,
+				  u64 *trans_id,
+				  enum vmbus_packet_type type,
+				  u32 flags);
 extern int vmbus_sendpacket(struct vmbus_channel *channel,
 				  void *buffer,
 				  u32 bufferLen,