@@ -578,11 +578,21 @@ int vmbus_sendpacket_pagebuffer_bounce(
(struct hv_page_range *)desc->range, io_type);
if (unlikely(!bounce_pkt))
return -EAGAIN;
+
+ /*
+ * This assignment must be before hv_ringbuffer_write(), because as
+ * soon as hv_ringbuffer_write() returns, the channel callback may
+ * be running, and the callback needs request->bounce_pkt, which is
+ * assigned in this function. Note: if hv_ringbuffer_write() fails,
+ * *pbounce_pkt must be reset to NULL.
+ */
+ *pbounce_pkt = bounce_pkt;
+
ret = hv_ringbuffer_write(channel, bufferlist, 3, requestid);
- if (unlikely(ret < 0))
+ if (unlikely(ret < 0)) {
hv_bounce_resources_release(channel, bounce_pkt);
- else
- *pbounce_pkt = bounce_pkt;
+ *pbounce_pkt = NULL;
+ }
return ret;
}
@@ -619,13 +629,23 @@ int vmbus_sendpacket_mpb_desc_bounce(
if (unlikely(!bounce_pkt))
goto free;
bufferlist[0].iov_base = desc_bounce;
+
+ /*
+ * This assignment must be before hv_ringbuffer_write(), because as
+ * soon as hv_ringbuffer_write() returns, the channel callback may
+ * be running, and the callback needs request->bounce_pkt, which is
+ * assigned in this function. Note: if hv_ringbuffer_write() fails,
+ * *pbounce_pkt must be reset to NULL.
+ */
+ *pbounce_pkt = bounce_pkt;
+
ret = hv_ringbuffer_write(channel, bufferlist, 3, requestid);
free:
kfree(desc_bounce);
- if (unlikely(ret < 0))
+ if (unlikely(ret < 0)) {
hv_bounce_resources_release(channel, bounce_pkt);
- else
- *pbounce_pkt = bounce_pkt;
+ *pbounce_pkt = NULL;
+ }
return ret;
}