diff mbox

[PULL,09/14] virtio-scsi: start preparing for any_layout

Message ID 1403098014-1522-10-git-send-email-pbonzini@redhat.com
State New
Headers show

Commit Message

Paolo Bonzini June 18, 2014, 1:26 p.m. UTC
- Introduce virtio_scsi_init_req and virtio_scsi_free_req

- rename qemu_sgl_init_external to qemu_sgl_concat

- move virtio_scsi_parse_req from virtio_scsi_pop_req to callers
  and add header length checks to virtio_scsi_parse_req.

Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
---
 hw/scsi/virtio-scsi.c | 121 ++++++++++++++++++++++++++++++--------------------
 1 file changed, 72 insertions(+), 49 deletions(-)
diff mbox

Patch

diff --git a/hw/scsi/virtio-scsi.c b/hw/scsi/virtio-scsi.c
index b39880a..f013e35 100644
--- a/hw/scsi/virtio-scsi.c
+++ b/hw/scsi/virtio-scsi.c
@@ -15,6 +15,7 @@ 
 
 #include "hw/virtio/virtio-scsi.h"
 #include "qemu/error-report.h"
+#include "qemu/iov.h"
 #include <hw/scsi/scsi.h>
 #include <block/scsi.h>
 #include <hw/virtio/virtio-bus.h>
@@ -56,18 +57,35 @@  static inline SCSIDevice *virtio_scsi_device_find(VirtIOSCSI *s, uint8_t *lun)
     return scsi_device_find(&s->bus, 0, lun[1], virtio_scsi_get_lun(lun));
 }
 
+static VirtIOSCSIReq *virtio_scsi_init_req(VirtIOSCSI *s, VirtQueue *vq)
+{
+    VirtIOSCSIReq *req;
+    req = g_malloc(sizeof(*req));
+
+    req->vq = vq;
+    req->dev = s;
+    req->sreq = NULL;
+    qemu_sglist_init(&req->qsgl, DEVICE(s), 8, &address_space_memory);
+    return req;
+}
+
+static void virtio_scsi_free_req(VirtIOSCSIReq *req)
+{
+    qemu_sglist_destroy(&req->qsgl);
+    g_free(req);
+}
+
 static void virtio_scsi_complete_req(VirtIOSCSIReq *req)
 {
     VirtIOSCSI *s = req->dev;
     VirtQueue *vq = req->vq;
     VirtIODevice *vdev = VIRTIO_DEVICE(s);
     virtqueue_push(vq, &req->elem, req->qsgl.size + req->elem.in_sg[0].iov_len);
-    qemu_sglist_destroy(&req->qsgl);
     if (req->sreq) {
         req->sreq->hba_private = NULL;
         scsi_req_unref(req->sreq);
     }
-    g_free(req);
+    virtio_scsi_free_req(req);
     virtio_notify(vdev, vq);
 }
 
@@ -77,50 +95,55 @@  static void virtio_scsi_bad_req(void)
     exit(1);
 }
 
-static void qemu_sgl_init_external(VirtIOSCSIReq *req, struct iovec *sg,
+static void qemu_sgl_concat(VirtIOSCSIReq *req, struct iovec *sg,
                                    hwaddr *addr, int num)
 {
     QEMUSGList *qsgl = &req->qsgl;
 
-    qemu_sglist_init(qsgl, DEVICE(req->dev), num, &address_space_memory);
     while (num--) {
         qemu_sglist_add(qsgl, *(addr++), (sg++)->iov_len);
     }
 }
 
-static void virtio_scsi_parse_req(VirtIOSCSI *s, VirtQueue *vq,
-                                  VirtIOSCSIReq *req)
+static int virtio_scsi_parse_req(VirtIOSCSIReq *req,
+                                 unsigned req_size, unsigned resp_size)
 {
-    assert(req->elem.in_num);
-    req->vq = vq;
-    req->dev = s;
-    req->sreq = NULL;
+    if (req->elem.in_num == 0) {
+        return -EINVAL;
+    }
+
+    if (req->elem.out_sg[0].iov_len < req_size) {
+        return -EINVAL;
+    }
     if (req->elem.out_num) {
         req->req.buf = req->elem.out_sg[0].iov_base;
     }
+
+    if (req->elem.in_sg[0].iov_len < resp_size) {
+        return -EINVAL;
+    }
     req->resp.buf = req->elem.in_sg[0].iov_base;
 
     if (req->elem.out_num > 1) {
-        qemu_sgl_init_external(req, &req->elem.out_sg[1],
-                               &req->elem.out_addr[1],
-                               req->elem.out_num - 1);
+        qemu_sgl_concat(req, &req->elem.out_sg[1],
+                        &req->elem.out_addr[1],
+                        req->elem.out_num - 1);
     } else {
-        qemu_sgl_init_external(req, &req->elem.in_sg[1],
-                               &req->elem.in_addr[1],
-                               req->elem.in_num - 1);
+        qemu_sgl_concat(req, &req->elem.in_sg[1],
+                        &req->elem.in_addr[1],
+                        req->elem.in_num - 1);
     }
+
+    return 0;
 }
 
 static VirtIOSCSIReq *virtio_scsi_pop_req(VirtIOSCSI *s, VirtQueue *vq)
 {
-    VirtIOSCSIReq *req;
-    req = g_malloc(sizeof(*req));
+    VirtIOSCSIReq *req = virtio_scsi_init_req(s, vq);
     if (!virtqueue_pop(vq, &req->elem)) {
-        g_free(req);
+        virtio_scsi_free_req(req);
         return NULL;
     }
-
-    virtio_scsi_parse_req(s, vq, req);
     return req;
 }
 
@@ -143,9 +166,9 @@  static void *virtio_scsi_load_request(QEMUFile *f, SCSIRequest *sreq)
     VirtIOSCSIReq *req;
     uint32_t n;
 
-    req = g_malloc(sizeof(*req));
     qemu_get_be32s(f, &n);
     assert(n < vs->conf.num_queues);
+    req = virtio_scsi_init_req(s, vs->cmd_vqs[n]);
     qemu_get_buffer(f, (unsigned char *)&req->elem, sizeof(req->elem));
     /* TODO: add a way for SCSIBusInfo's load_request to fail,
      * and fail migration instead of asserting here.
@@ -156,7 +179,12 @@  static void *virtio_scsi_load_request(QEMUFile *f, SCSIRequest *sreq)
 #endif
     assert(req->elem.in_num <= ARRAY_SIZE(req->elem.in_sg));
     assert(req->elem.out_num <= ARRAY_SIZE(req->elem.out_sg));
-    virtio_scsi_parse_req(s, vs->cmd_vqs[n], req);
+
+    if (virtio_scsi_parse_req(req, sizeof(VirtIOSCSICmdReq) + vs->cdb_size,
+                              sizeof(VirtIOSCSICmdResp) + vs->sense_size) < 0) {
+        error_report("invalid SCSI request migration data");
+        exit(1);
+    }
 
     scsi_req_ref(sreq);
     req->sreq = sreq;
@@ -281,29 +309,29 @@  static void virtio_scsi_handle_ctrl(VirtIODevice *vdev, VirtQueue *vq)
     VirtIOSCSIReq *req;
 
     while ((req = virtio_scsi_pop_req(s, vq))) {
-        int out_size, in_size;
-        if (req->elem.out_num < 1 || req->elem.in_num < 1) {
+        int type;
+
+        if (iov_to_buf(req->elem.out_sg, req->elem.out_num, 0,
+                       &type, sizeof(type)) < sizeof(type)) {
             virtio_scsi_bad_req();
-            continue;
-        }
 
-        out_size = req->elem.out_sg[0].iov_len;
-        in_size = req->elem.in_sg[0].iov_len;
-        if (req->req.tmf->type == VIRTIO_SCSI_T_TMF) {
-            if (out_size < sizeof(VirtIOSCSICtrlTMFReq) ||
-                in_size < sizeof(VirtIOSCSICtrlTMFResp)) {
+        } else if (req->req.tmf->type == VIRTIO_SCSI_T_TMF) {
+            if (virtio_scsi_parse_req(req, sizeof(VirtIOSCSICtrlTMFReq),
+                                      sizeof(VirtIOSCSICtrlTMFResp)) < 0) {
                 virtio_scsi_bad_req();
+            } else {
+                virtio_scsi_do_tmf(s, req);
             }
-            virtio_scsi_do_tmf(s, req);
 
         } else if (req->req.tmf->type == VIRTIO_SCSI_T_AN_QUERY ||
                    req->req.tmf->type == VIRTIO_SCSI_T_AN_SUBSCRIBE) {
-            if (out_size < sizeof(VirtIOSCSICtrlANReq) ||
-                in_size < sizeof(VirtIOSCSICtrlANResp)) {
+            if (virtio_scsi_parse_req(req, sizeof(VirtIOSCSICtrlANReq),
+                                      sizeof(VirtIOSCSICtrlANResp)) < 0) {
                 virtio_scsi_bad_req();
+            } else {
+                req->resp.an->event_actual = 0;
+                req->resp.an->response = VIRTIO_SCSI_S_OK;
             }
-            req->resp.an->event_actual = 0;
-            req->resp.an->response = VIRTIO_SCSI_S_OK;
         }
         virtio_scsi_complete_req(req);
     }
@@ -373,23 +401,18 @@  static void virtio_scsi_handle_cmd(VirtIODevice *vdev, VirtQueue *vq)
 
     while ((req = virtio_scsi_pop_req(s, vq))) {
         SCSIDevice *d;
-        int out_size, in_size;
-        if (req->elem.out_num < 1 || req->elem.in_num < 1) {
-            virtio_scsi_bad_req();
-        }
-
-        out_size = req->elem.out_sg[0].iov_len;
-        in_size = req->elem.in_sg[0].iov_len;
-        if (out_size < sizeof(VirtIOSCSICmdReq) + vs->cdb_size ||
-            in_size < sizeof(VirtIOSCSICmdResp) + vs->sense_size) {
-            virtio_scsi_bad_req();
-        }
-
+        int rc;
         if (req->elem.out_num > 1 && req->elem.in_num > 1) {
             virtio_scsi_fail_cmd_req(req);
             continue;
         }
 
+        rc = virtio_scsi_parse_req(req, sizeof(VirtIOSCSICmdReq) + vs->cdb_size,
+                                   sizeof(VirtIOSCSICmdResp) + vs->sense_size);
+        if (rc < 0) {
+            virtio_scsi_bad_req();
+        }
+
         d = virtio_scsi_device_find(s, req->req.cmd->lun);
         if (!d) {
             req->resp.cmd->response = VIRTIO_SCSI_S_BAD_TARGET;