diff mbox series

[1/5] cifs: update receive_encrypted_standard to handle compounded responses

Message ID 20180808050749.26562-2-lsahlber@redhat.com
State New
Headers show
Series cifs: compounding | expand

Commit Message

Ronnie Sahlberg Aug. 8, 2018, 5:07 a.m. UTC
Signed-off-by: Ronnie Sahlberg <lsahlber@redhat.com>
---
 fs/cifs/cifsglob.h  |  5 +++-
 fs/cifs/connect.c   | 82 ++++++++++++++++++++++++++++++++---------------------
 fs/cifs/smb2ops.c   | 61 +++++++++++++++++++++++++++++++++------
 fs/cifs/transport.c |  2 --
 4 files changed, 107 insertions(+), 43 deletions(-)

Comments

Pavel Shilovsky Aug. 8, 2018, 11:02 p.m. UTC | #1
2018-08-07 22:07 GMT-07:00 Ronnie Sahlberg <lsahlber@redhat.com>:
> Signed-off-by: Ronnie Sahlberg <lsahlber@redhat.com>
> ---
>  fs/cifs/cifsglob.h  |  5 +++-
>  fs/cifs/connect.c   | 82 ++++++++++++++++++++++++++++++++---------------------
>  fs/cifs/smb2ops.c   | 61 +++++++++++++++++++++++++++++++++------
>  fs/cifs/transport.c |  2 --
>  4 files changed, 107 insertions(+), 43 deletions(-)
>
> diff --git a/fs/cifs/cifsglob.h b/fs/cifs/cifsglob.h
> index 41803d374da0..0c9ab62c3df4 100644
> --- a/fs/cifs/cifsglob.h
> +++ b/fs/cifs/cifsglob.h
> @@ -76,6 +76,9 @@
>  #define SMB_ECHO_INTERVAL_MAX 600
>  #define SMB_ECHO_INTERVAL_DEFAULT 60
>
> +/* maximum number of PDUs in one compound */
> +#define MAX_COMPOUND 5
> +
>  /*
>   * Default number of credits to keep available for SMB3.
>   * This value is chosen somewhat arbitrarily. The Windows client
> @@ -458,7 +461,7 @@ struct smb_version_operations {
>                                  struct smb_rqst *, struct smb_rqst *);
>         int (*is_transform_hdr)(void *buf);
>         int (*receive_transform)(struct TCP_Server_Info *,
> -                                struct mid_q_entry **);
> +                                struct mid_q_entry **, char **, int *);
>         enum securityEnum (*select_sectype)(struct TCP_Server_Info *,
>                             enum securityEnum);
>         int (*next_header)(char *);
> diff --git a/fs/cifs/connect.c b/fs/cifs/connect.c
> index d9bd10d295a9..20564ea71dd5 100644
> --- a/fs/cifs/connect.c
> +++ b/fs/cifs/connect.c
> @@ -850,13 +850,14 @@ cifs_handle_standard(struct TCP_Server_Info *server, struct mid_q_entry *mid)
>  static int
>  cifs_demultiplex_thread(void *p)
>  {
> -       int length;
> +       int i, num_mids, length;
>         struct TCP_Server_Info *server = p;
>         unsigned int pdu_length;
>         unsigned int next_offset;
>         char *buf = NULL;
>         struct task_struct *task_to_wake = NULL;
> -       struct mid_q_entry *mid_entry;
> +       struct mid_q_entry *mids[MAX_COMPOUND];
> +       char *bufs[MAX_COMPOUND];
>
>         current->flags |= PF_MEMALLOC;
>         cifs_dbg(FYI, "Demultiplex PID: %d\n", task_pid_nr(current));
> @@ -923,58 +924,75 @@ cifs_demultiplex_thread(void *p)
>                                 server->pdu_size = next_offset;
>                 }
>
> -               mid_entry = NULL;
> +               memset(mids, 0, sizeof(mids));
> +               memset(bufs, 0, sizeof(bufs));
> +               num_mids = 0;
> +
>                 if (server->ops->is_transform_hdr &&
>                     server->ops->receive_transform &&
>                     server->ops->is_transform_hdr(buf)) {
>                         length = server->ops->receive_transform(server,
> -                                                               &mid_entry);
> +                                                               mids,
> +                                                               bufs,
> +                                                               &num_mids);
>                 } else {
> -                       mid_entry = server->ops->find_mid(server, buf);
> +                       mids[0] = server->ops->find_mid(server, buf);
> +                       bufs[0] = buf;
> +                       if (mids[0])
> +                               num_mids = 1;
>
> -                       if (!mid_entry || !mid_entry->receive)
> -                               length = standard_receive3(server, mid_entry);
> +                       if (!mids[0] || !mids[0]->receive)
> +                               length = standard_receive3(server, mids[0]);
>                         else
> -                               length = mid_entry->receive(server, mid_entry);
> +                               length = mids[0]->receive(server, mids[0]);
>                 }
>
>                 if (length < 0) {
> -                       if (mid_entry)
> -                               cifs_mid_q_entry_release(mid_entry);
> +                       for (i = 0; i < num_mids; i++)
> +                               if (mids[i])
> +                                       cifs_mid_q_entry_release(mids[i]);
>                         continue;
>                 }
>
>                 if (server->large_buf)
>                         buf = server->bigbuf;
> +               //qqq
                    ^^^
typo

>
>                 server->lstrp = jiffies;
> -               if (mid_entry != NULL) {
> -                       mid_entry->resp_buf_size = server->pdu_size;
> -                       if ((mid_entry->mid_flags & MID_WAIT_CANCELLED) &&
> -                            mid_entry->mid_state == MID_RESPONSE_RECEIVED &&
> -                                       server->ops->handle_cancelled_mid)
> -                               server->ops->handle_cancelled_mid(
> -                                                       mid_entry->resp_buf,
> -                                                       server);
>
> -                       if (!mid_entry->multiRsp || mid_entry->multiEnd)
> -                               mid_entry->callback(mid_entry);
> +               for (i = 0; i < num_mids; i++) {
> +                       if (mids[i] != NULL) {
> +                               mids[i]->resp_buf_size = server->pdu_size;
> +                               if ((mids[i]->mid_flags & MID_WAIT_CANCELLED) &&
> +                                   mids[i]->mid_state == MID_RESPONSE_RECEIVED &&
> +                                   server->ops->handle_cancelled_mid)
> +                                       server->ops->handle_cancelled_mid(
> +                                                       mids[i]->resp_buf,
> +                                                       server);
>
> -                       cifs_mid_q_entry_release(mid_entry);
> -               } else if (server->ops->is_oplock_break &&
> -                          server->ops->is_oplock_break(buf, server)) {
> -                       cifs_dbg(FYI, "Received oplock break\n");
> -               } else {
> -                       cifs_dbg(VFS, "No task to wake, unknown frame received! NumMids %d\n",
> -                                atomic_read(&midCount));
> -                       cifs_dump_mem("Received Data is: ", buf,
> -                                     HEADER_SIZE(server));
> +                               if (!mids[i]->multiRsp || mids[i]->multiEnd)
> +                                       mids[i]->callback(mids[i]);
> +
> +                               cifs_mid_q_entry_release(mids[i]);
> +                       } else if (server->ops->is_oplock_break &&
> +                                  server->ops->is_oplock_break(bufs[i],
> +                                                               server)) {
> +                               cifs_dbg(FYI, "Received oplock break\n");
> +                       } else {
> +                               cifs_dbg(VFS, "No task to wake, unknown frame "
> +                                        "received! NumMids %d\n",
> +                                        atomic_read(&midCount));
> +                               cifs_dump_mem("Received Data is: ", bufs[i],
> +                                             HEADER_SIZE(server));
>  #ifdef CONFIG_CIFS_DEBUG2
> -                       if (server->ops->dump_detail)
> -                               server->ops->dump_detail(buf, server);
> -                       cifs_dump_mids(server);
> +                               if (server->ops->dump_detail)
> +                                       server->ops->dump_detail(bufs[i],
> +                                                                server);
> +                               cifs_dump_mids(server);
>  #endif /* CIFS_DEBUG2 */
> +                       }
>                 }
> +
>                 if (pdu_length > server->pdu_size) {
>                         if (!allocate_buffers(server))
>                                 continue;
> diff --git a/fs/cifs/smb2ops.c b/fs/cifs/smb2ops.c
> index ebc13ebebddf..5f3f27e33244 100644
> --- a/fs/cifs/smb2ops.c
> +++ b/fs/cifs/smb2ops.c
> @@ -2942,13 +2942,20 @@ receive_encrypted_read(struct TCP_Server_Info *server, struct mid_q_entry **mid)
>
>  static int
>  receive_encrypted_standard(struct TCP_Server_Info *server,
> -                          struct mid_q_entry **mid)
> +                          struct mid_q_entry **mids, char **bufs,
> +                          int *num_mids)
>  {
> -       int length;
> +       int ret, length;
>         char *buf = server->smallbuf;
> +       char *tmpbuf;
> +       struct smb2_sync_hdr *shdr;
>         unsigned int pdu_length = server->pdu_size;
>         unsigned int buf_size;
>         struct mid_q_entry *mid_entry;
> +       int next_is_large;
> +       char *next_buffer = NULL;
> +
> +       *num_mids = 0;
>
>         /* switch to large buffer if too big for a small one */
>         if (pdu_length > MAX_CIFS_SMALL_BUFFER_SIZE) {
> @@ -2969,24 +2976,61 @@ receive_encrypted_standard(struct TCP_Server_Info *server,
>         if (length)
>                 return length;
>
> +       next_is_large = server->large_buf;
> + one_more:
> +       shdr = (struct smb2_sync_hdr *)buf;
> +       if (shdr->NextCommand) {
> +               if (next_is_large) {
> +                       tmpbuf = server->bigbuf;
> +                       next_buffer = (char *)cifs_buf_get();
> +               } else {
> +                       tmpbuf = server->smallbuf;
> +                       next_buffer = (char *)cifs_small_buf_get();
> +               }
> +               memcpy(next_buffer,
> +                      tmpbuf + le32_to_cpu(shdr->NextCommand),
> +                      pdu_length - le32_to_cpu(shdr->NextCommand));
> +       }
> +
>         mid_entry = smb2_find_mid(server, buf);
>         if (mid_entry == NULL)
>                 cifs_dbg(FYI, "mid not found\n");
>         else {
>                 cifs_dbg(FYI, "mid found\n");
>                 mid_entry->decrypted = true;
> +               mid_entry->resp_buf_size = server->pdu_size;

server->pdu_size is set in the demultiplex thread to pdu_length, which
equals the size of the whole PDU. Without encryption, the demultiplex
thread updates server->pdu_size to the size of the particular chain
element being processed, but it doesn't seem that we update this value
in this function.

I think we need a local variable pdu_size that is updated the same way
the demultiplex thread updates server->pdu_size: set it to pdu_length
at the beginning of the "one_more" loop and then truncate it to the
offset of the next chain element.

>         }
>
> -       *mid = mid_entry;
> +       if (*num_mids >= MAX_COMPOUND) {
> +               cifs_dbg(VFS, "too many PDUs in compound\n");
> +               return -1;
> +       }
> +       bufs[*num_mids] = buf;
> +       mids[(*num_mids)++] = mid_entry;
>
>         if (mid_entry && mid_entry->handle)
> -               return mid_entry->handle(server, mid_entry);
> +               ret = mid_entry->handle(server, mid_entry);
> +       else
> +               ret = cifs_handle_standard(server, mid_entry);
> +
> +       if (ret == 0 && shdr->NextCommand) {
> +               pdu_length -= le32_to_cpu(shdr->NextCommand);
> +               server->large_buf = next_is_large;
> +               if (next_is_large) {
> +                       server->bigbuf = next_buffer;
> +               } else {
> +                       server->smallbuf = next_buffer;
> +               }
> +               buf += le32_to_cpu(shdr->NextCommand);
> +               goto one_more;
> +       }
>
> -       return cifs_handle_standard(server, mid_entry);
> +        return ret;
>  }
>
>  static int
> -smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
> +smb3_receive_transform(struct TCP_Server_Info *server,
> +                      struct mid_q_entry **mids, char **bufs, int *num_mids)
>  {
>         char *buf = server->smallbuf;
>         unsigned int pdu_length = server->pdu_size;
> @@ -3009,10 +3053,11 @@ smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
>                 return -ECONNABORTED;
>         }
>
> +       /* TODO: add support for compounds containing READ. */
>         if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server))
> -               return receive_encrypted_read(server, mid);
> +               return receive_encrypted_read(server, &mids[0]);
>
> -       return receive_encrypted_standard(server, mid);
> +       return receive_encrypted_standard(server, mids, bufs, num_mids);
>  }
>
>  int
> diff --git a/fs/cifs/transport.c b/fs/cifs/transport.c
> index c53c0908d4c6..78f96fa3d7d9 100644
> --- a/fs/cifs/transport.c
> +++ b/fs/cifs/transport.c
> @@ -383,8 +383,6 @@ __smb_send_rqst(struct TCP_Server_Info *server, int num_rqst,
>         return rc;
>  }
>
> -#define MAX_COMPOUND 5
> -
>  static int
>  smb_send_rqst(struct TCP_Server_Info *server, int num_rqst,
>               struct smb_rqst *rqst, int flags)
> --
> 2.13.3
>
> --
> To unsubscribe from this list: send the line "unsubscribe linux-cifs" in
> the body of a message to majordomo@vger.kernel.org
> More majordomo info at  http://vger.kernel.org/majordomo-info.html


--
Best regards,
Pavel Shilovsky
--
To unsubscribe from this list: send the line "unsubscribe linux-cifs" in
the body of a message to majordomo@vger.kernel.org
More majordomo info at  http://vger.kernel.org/majordomo-info.html
Steve French Aug. 8, 2018, 11:13 p.m. UTC | #2
I fixed the typo at least for this one and repushed (also added
Reviewed-by tags to the other ones which Pavel has reviewed).
On Wed, Aug 8, 2018 at 6:02 PM Pavel Shilovsky <piastryyy@gmail.com> wrote:
>
> 2018-08-07 22:07 GMT-07:00 Ronnie Sahlberg <lsahlber@redhat.com>:
> > Signed-off-by: Ronnie Sahlberg <lsahlber@redhat.com>
> > ---
> >  fs/cifs/cifsglob.h  |  5 +++-
> >  fs/cifs/connect.c   | 82 ++++++++++++++++++++++++++++++++---------------------
> >  fs/cifs/smb2ops.c   | 61 +++++++++++++++++++++++++++++++++------
> >  fs/cifs/transport.c |  2 --
> >  4 files changed, 107 insertions(+), 43 deletions(-)
> >
> > diff --git a/fs/cifs/cifsglob.h b/fs/cifs/cifsglob.h
> > index 41803d374da0..0c9ab62c3df4 100644
> > --- a/fs/cifs/cifsglob.h
> > +++ b/fs/cifs/cifsglob.h
> > @@ -76,6 +76,9 @@
> >  #define SMB_ECHO_INTERVAL_MAX 600
> >  #define SMB_ECHO_INTERVAL_DEFAULT 60
> >
> > +/* maximum number of PDUs in one compound */
> > +#define MAX_COMPOUND 5
> > +
> >  /*
> >   * Default number of credits to keep available for SMB3.
> >   * This value is chosen somewhat arbitrarily. The Windows client
> > @@ -458,7 +461,7 @@ struct smb_version_operations {
> >                                  struct smb_rqst *, struct smb_rqst *);
> >         int (*is_transform_hdr)(void *buf);
> >         int (*receive_transform)(struct TCP_Server_Info *,
> > -                                struct mid_q_entry **);
> > +                                struct mid_q_entry **, char **, int *);
> >         enum securityEnum (*select_sectype)(struct TCP_Server_Info *,
> >                             enum securityEnum);
> >         int (*next_header)(char *);
> > diff --git a/fs/cifs/connect.c b/fs/cifs/connect.c
> > index d9bd10d295a9..20564ea71dd5 100644
> > --- a/fs/cifs/connect.c
> > +++ b/fs/cifs/connect.c
> > @@ -850,13 +850,14 @@ cifs_handle_standard(struct TCP_Server_Info *server, struct mid_q_entry *mid)
> >  static int
> >  cifs_demultiplex_thread(void *p)
> >  {
> > -       int length;
> > +       int i, num_mids, length;
> >         struct TCP_Server_Info *server = p;
> >         unsigned int pdu_length;
> >         unsigned int next_offset;
> >         char *buf = NULL;
> >         struct task_struct *task_to_wake = NULL;
> > -       struct mid_q_entry *mid_entry;
> > +       struct mid_q_entry *mids[MAX_COMPOUND];
> > +       char *bufs[MAX_COMPOUND];
> >
> >         current->flags |= PF_MEMALLOC;
> >         cifs_dbg(FYI, "Demultiplex PID: %d\n", task_pid_nr(current));
> > @@ -923,58 +924,75 @@ cifs_demultiplex_thread(void *p)
> >                                 server->pdu_size = next_offset;
> >                 }
> >
> > -               mid_entry = NULL;
> > +               memset(mids, 0, sizeof(mids));
> > +               memset(bufs, 0, sizeof(bufs));
> > +               num_mids = 0;
> > +
> >                 if (server->ops->is_transform_hdr &&
> >                     server->ops->receive_transform &&
> >                     server->ops->is_transform_hdr(buf)) {
> >                         length = server->ops->receive_transform(server,
> > -                                                               &mid_entry);
> > +                                                               mids,
> > +                                                               bufs,
> > +                                                               &num_mids);
> >                 } else {
> > -                       mid_entry = server->ops->find_mid(server, buf);
> > +                       mids[0] = server->ops->find_mid(server, buf);
> > +                       bufs[0] = buf;
> > +                       if (mids[0])
> > +                               num_mids = 1;
> >
> > -                       if (!mid_entry || !mid_entry->receive)
> > -                               length = standard_receive3(server, mid_entry);
> > +                       if (!mids[0] || !mids[0]->receive)
> > +                               length = standard_receive3(server, mids[0]);
> >                         else
> > -                               length = mid_entry->receive(server, mid_entry);
> > +                               length = mids[0]->receive(server, mids[0]);
> >                 }
> >
> >                 if (length < 0) {
> > -                       if (mid_entry)
> > -                               cifs_mid_q_entry_release(mid_entry);
> > +                       for (i = 0; i < num_mids; i++)
> > +                               if (mids[i])
> > +                                       cifs_mid_q_entry_release(mids[i]);
> >                         continue;
> >                 }
> >
> >                 if (server->large_buf)
> >                         buf = server->bigbuf;
> > +               //qqq
>                     ^^^
> typo
>
> >
> >                 server->lstrp = jiffies;
> > -               if (mid_entry != NULL) {
> > -                       mid_entry->resp_buf_size = server->pdu_size;
> > -                       if ((mid_entry->mid_flags & MID_WAIT_CANCELLED) &&
> > -                            mid_entry->mid_state == MID_RESPONSE_RECEIVED &&
> > -                                       server->ops->handle_cancelled_mid)
> > -                               server->ops->handle_cancelled_mid(
> > -                                                       mid_entry->resp_buf,
> > -                                                       server);
> >
> > -                       if (!mid_entry->multiRsp || mid_entry->multiEnd)
> > -                               mid_entry->callback(mid_entry);
> > +               for (i = 0; i < num_mids; i++) {
> > +                       if (mids[i] != NULL) {
> > +                               mids[i]->resp_buf_size = server->pdu_size;
> > +                               if ((mids[i]->mid_flags & MID_WAIT_CANCELLED) &&
> > +                                   mids[i]->mid_state == MID_RESPONSE_RECEIVED &&
> > +                                   server->ops->handle_cancelled_mid)
> > +                                       server->ops->handle_cancelled_mid(
> > +                                                       mids[i]->resp_buf,
> > +                                                       server);
> >
> > -                       cifs_mid_q_entry_release(mid_entry);
> > -               } else if (server->ops->is_oplock_break &&
> > -                          server->ops->is_oplock_break(buf, server)) {
> > -                       cifs_dbg(FYI, "Received oplock break\n");
> > -               } else {
> > -                       cifs_dbg(VFS, "No task to wake, unknown frame received! NumMids %d\n",
> > -                                atomic_read(&midCount));
> > -                       cifs_dump_mem("Received Data is: ", buf,
> > -                                     HEADER_SIZE(server));
> > +                               if (!mids[i]->multiRsp || mids[i]->multiEnd)
> > +                                       mids[i]->callback(mids[i]);
> > +
> > +                               cifs_mid_q_entry_release(mids[i]);
> > +                       } else if (server->ops->is_oplock_break &&
> > +                                  server->ops->is_oplock_break(bufs[i],
> > +                                                               server)) {
> > +                               cifs_dbg(FYI, "Received oplock break\n");
> > +                       } else {
> > +                               cifs_dbg(VFS, "No task to wake, unknown frame "
> > +                                        "received! NumMids %d\n",
> > +                                        atomic_read(&midCount));
> > +                               cifs_dump_mem("Received Data is: ", bufs[i],
> > +                                             HEADER_SIZE(server));
> >  #ifdef CONFIG_CIFS_DEBUG2
> > -                       if (server->ops->dump_detail)
> > -                               server->ops->dump_detail(buf, server);
> > -                       cifs_dump_mids(server);
> > +                               if (server->ops->dump_detail)
> > +                                       server->ops->dump_detail(bufs[i],
> > +                                                                server);
> > +                               cifs_dump_mids(server);
> >  #endif /* CIFS_DEBUG2 */
> > +                       }
> >                 }
> > +
> >                 if (pdu_length > server->pdu_size) {
> >                         if (!allocate_buffers(server))
> >                                 continue;
> > diff --git a/fs/cifs/smb2ops.c b/fs/cifs/smb2ops.c
> > index ebc13ebebddf..5f3f27e33244 100644
> > --- a/fs/cifs/smb2ops.c
> > +++ b/fs/cifs/smb2ops.c
> > @@ -2942,13 +2942,20 @@ receive_encrypted_read(struct TCP_Server_Info *server, struct mid_q_entry **mid)
> >
> >  static int
> >  receive_encrypted_standard(struct TCP_Server_Info *server,
> > -                          struct mid_q_entry **mid)
> > +                          struct mid_q_entry **mids, char **bufs,
> > +                          int *num_mids)
> >  {
> > -       int length;
> > +       int ret, length;
> >         char *buf = server->smallbuf;
> > +       char *tmpbuf;
> > +       struct smb2_sync_hdr *shdr;
> >         unsigned int pdu_length = server->pdu_size;
> >         unsigned int buf_size;
> >         struct mid_q_entry *mid_entry;
> > +       int next_is_large;
> > +       char *next_buffer = NULL;
> > +
> > +       *num_mids = 0;
> >
> >         /* switch to large buffer if too big for a small one */
> >         if (pdu_length > MAX_CIFS_SMALL_BUFFER_SIZE) {
> > @@ -2969,24 +2976,61 @@ receive_encrypted_standard(struct TCP_Server_Info *server,
> >         if (length)
> >                 return length;
> >
> > +       next_is_large = server->large_buf;
> > + one_more:
> > +       shdr = (struct smb2_sync_hdr *)buf;
> > +       if (shdr->NextCommand) {
> > +               if (next_is_large) {
> > +                       tmpbuf = server->bigbuf;
> > +                       next_buffer = (char *)cifs_buf_get();
> > +               } else {
> > +                       tmpbuf = server->smallbuf;
> > +                       next_buffer = (char *)cifs_small_buf_get();
> > +               }
> > +               memcpy(next_buffer,
> > +                      tmpbuf + le32_to_cpu(shdr->NextCommand),
> > +                      pdu_length - le32_to_cpu(shdr->NextCommand));
> > +       }
> > +
> >         mid_entry = smb2_find_mid(server, buf);
> >         if (mid_entry == NULL)
> >                 cifs_dbg(FYI, "mid not found\n");
> >         else {
> >                 cifs_dbg(FYI, "mid found\n");
> >                 mid_entry->decrypted = true;
> > +               mid_entry->resp_buf_size = server->pdu_size;
>
> server->pdu_size is being set in demultiplex thread to pdu_length that
> equals to a size of a whole PDU. While without encryption demultiplex
> thread updates server->pdu_size by setting it to a size of a
> particular chain element, it doesn't seem that we update this value in
> this function
>
> I think we need a local variable pdu_size which is being updated the
> same way demultiplex thread updates server->pdu_size: set it to
> pdu_length at the beginning of the "one_more" loop and then truncate
> to the offset of the next chain element.
>
> >         }
> >
> > -       *mid = mid_entry;
> > +       if (*num_mids >= MAX_COMPOUND) {
> > +               cifs_dbg(VFS, "too many PDUs in compound\n");
> > +               return -1;
> > +       }
> > +       bufs[*num_mids] = buf;
> > +       mids[(*num_mids)++] = mid_entry;
> >
> >         if (mid_entry && mid_entry->handle)
> > -               return mid_entry->handle(server, mid_entry);
> > +               ret = mid_entry->handle(server, mid_entry);
> > +       else
> > +               ret = cifs_handle_standard(server, mid_entry);
> > +
> > +       if (ret == 0 && shdr->NextCommand) {
> > +               pdu_length -= le32_to_cpu(shdr->NextCommand);
> > +               server->large_buf = next_is_large;
> > +               if (next_is_large) {
> > +                       server->bigbuf = next_buffer;
> > +               } else {
> > +                       server->smallbuf = next_buffer;
> > +               }
> > +               buf += le32_to_cpu(shdr->NextCommand);
> > +               goto one_more;
> > +       }
> >
> > -       return cifs_handle_standard(server, mid_entry);
> > +        return ret;
> >  }
> >
> >  static int
> > -smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
> > +smb3_receive_transform(struct TCP_Server_Info *server,
> > +                      struct mid_q_entry **mids, char **bufs, int *num_mids)
> >  {
> >         char *buf = server->smallbuf;
> >         unsigned int pdu_length = server->pdu_size;
> > @@ -3009,10 +3053,11 @@ smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
> >                 return -ECONNABORTED;
> >         }
> >
> > +       /* TODO: add support for compounds containing READ. */
> >         if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server))
> > -               return receive_encrypted_read(server, mid);
> > +               return receive_encrypted_read(server, &mids[0]);
> >
> > -       return receive_encrypted_standard(server, mid);
> > +       return receive_encrypted_standard(server, mids, bufs, num_mids);
> >  }
> >
> >  int
> > diff --git a/fs/cifs/transport.c b/fs/cifs/transport.c
> > index c53c0908d4c6..78f96fa3d7d9 100644
> > --- a/fs/cifs/transport.c
> > +++ b/fs/cifs/transport.c
> > @@ -383,8 +383,6 @@ __smb_send_rqst(struct TCP_Server_Info *server, int num_rqst,
> >         return rc;
> >  }
> >
> > -#define MAX_COMPOUND 5
> > -
> >  static int
> >  smb_send_rqst(struct TCP_Server_Info *server, int num_rqst,
> >               struct smb_rqst *rqst, int flags)
> > --
> > 2.13.3
> >
> > --
> > To unsubscribe from this list: send the line "unsubscribe linux-cifs" in
> > the body of a message to majordomo@vger.kernel.org
> > More majordomo info at  http://vger.kernel.org/majordomo-info.html
>
>
> --
> Best regards,
> Pavel Shilovsky
ronnie sahlberg Aug. 8, 2018, 11:21 p.m. UTC | #3
On Thu, Aug 9, 2018 at 9:02 AM, Pavel Shilovsky <piastryyy@gmail.com> wrote:
> 2018-08-07 22:07 GMT-07:00 Ronnie Sahlberg <lsahlber@redhat.com>:
>> Signed-off-by: Ronnie Sahlberg <lsahlber@redhat.com>
>> ---
>>  fs/cifs/cifsglob.h  |  5 +++-
>>  fs/cifs/connect.c   | 82 ++++++++++++++++++++++++++++++++---------------------
>>  fs/cifs/smb2ops.c   | 61 +++++++++++++++++++++++++++++++++------
>>  fs/cifs/transport.c |  2 --
>>  4 files changed, 107 insertions(+), 43 deletions(-)
>>
>> diff --git a/fs/cifs/cifsglob.h b/fs/cifs/cifsglob.h
>> index 41803d374da0..0c9ab62c3df4 100644
>> --- a/fs/cifs/cifsglob.h
>> +++ b/fs/cifs/cifsglob.h
>> @@ -76,6 +76,9 @@
>>  #define SMB_ECHO_INTERVAL_MAX 600
>>  #define SMB_ECHO_INTERVAL_DEFAULT 60
>>
>> +/* maximum number of PDUs in one compound */
>> +#define MAX_COMPOUND 5
>> +
>>  /*
>>   * Default number of credits to keep available for SMB3.
>>   * This value is chosen somewhat arbitrarily. The Windows client
>> @@ -458,7 +461,7 @@ struct smb_version_operations {
>>                                  struct smb_rqst *, struct smb_rqst *);
>>         int (*is_transform_hdr)(void *buf);
>>         int (*receive_transform)(struct TCP_Server_Info *,
>> -                                struct mid_q_entry **);
>> +                                struct mid_q_entry **, char **, int *);
>>         enum securityEnum (*select_sectype)(struct TCP_Server_Info *,
>>                             enum securityEnum);
>>         int (*next_header)(char *);
>> diff --git a/fs/cifs/connect.c b/fs/cifs/connect.c
>> index d9bd10d295a9..20564ea71dd5 100644
>> --- a/fs/cifs/connect.c
>> +++ b/fs/cifs/connect.c
>> @@ -850,13 +850,14 @@ cifs_handle_standard(struct TCP_Server_Info *server, struct mid_q_entry *mid)
>>  static int
>>  cifs_demultiplex_thread(void *p)
>>  {
>> -       int length;
>> +       int i, num_mids, length;
>>         struct TCP_Server_Info *server = p;
>>         unsigned int pdu_length;
>>         unsigned int next_offset;
>>         char *buf = NULL;
>>         struct task_struct *task_to_wake = NULL;
>> -       struct mid_q_entry *mid_entry;
>> +       struct mid_q_entry *mids[MAX_COMPOUND];
>> +       char *bufs[MAX_COMPOUND];
>>
>>         current->flags |= PF_MEMALLOC;
>>         cifs_dbg(FYI, "Demultiplex PID: %d\n", task_pid_nr(current));
>> @@ -923,58 +924,75 @@ cifs_demultiplex_thread(void *p)
>>                                 server->pdu_size = next_offset;
>>                 }
>>
>> -               mid_entry = NULL;
>> +               memset(mids, 0, sizeof(mids));
>> +               memset(bufs, 0, sizeof(bufs));
>> +               num_mids = 0;
>> +
>>                 if (server->ops->is_transform_hdr &&
>>                     server->ops->receive_transform &&
>>                     server->ops->is_transform_hdr(buf)) {
>>                         length = server->ops->receive_transform(server,
>> -                                                               &mid_entry);
>> +                                                               mids,
>> +                                                               bufs,
>> +                                                               &num_mids);
>>                 } else {
>> -                       mid_entry = server->ops->find_mid(server, buf);
>> +                       mids[0] = server->ops->find_mid(server, buf);
>> +                       bufs[0] = buf;
>> +                       if (mids[0])
>> +                               num_mids = 1;
>>
>> -                       if (!mid_entry || !mid_entry->receive)
>> -                               length = standard_receive3(server, mid_entry);
>> +                       if (!mids[0] || !mids[0]->receive)
>> +                               length = standard_receive3(server, mids[0]);
>>                         else
>> -                               length = mid_entry->receive(server, mid_entry);
>> +                               length = mids[0]->receive(server, mids[0]);
>>                 }
>>
>>                 if (length < 0) {
>> -                       if (mid_entry)
>> -                               cifs_mid_q_entry_release(mid_entry);
>> +                       for (i = 0; i < num_mids; i++)
>> +                               if (mids[i])
>> +                                       cifs_mid_q_entry_release(mids[i]);
>>                         continue;
>>                 }
>>
>>                 if (server->large_buf)
>>                         buf = server->bigbuf;
>> +               //qqq
>                     ^^^
> typo
>
>>
>>                 server->lstrp = jiffies;
>> -               if (mid_entry != NULL) {
>> -                       mid_entry->resp_buf_size = server->pdu_size;
>> -                       if ((mid_entry->mid_flags & MID_WAIT_CANCELLED) &&
>> -                            mid_entry->mid_state == MID_RESPONSE_RECEIVED &&
>> -                                       server->ops->handle_cancelled_mid)
>> -                               server->ops->handle_cancelled_mid(
>> -                                                       mid_entry->resp_buf,
>> -                                                       server);
>>
>> -                       if (!mid_entry->multiRsp || mid_entry->multiEnd)
>> -                               mid_entry->callback(mid_entry);
>> +               for (i = 0; i < num_mids; i++) {
>> +                       if (mids[i] != NULL) {
>> +                               mids[i]->resp_buf_size = server->pdu_size;
>> +                               if ((mids[i]->mid_flags & MID_WAIT_CANCELLED) &&
>> +                                   mids[i]->mid_state == MID_RESPONSE_RECEIVED &&
>> +                                   server->ops->handle_cancelled_mid)
>> +                                       server->ops->handle_cancelled_mid(
>> +                                                       mids[i]->resp_buf,
>> +                                                       server);
>>
>> -                       cifs_mid_q_entry_release(mid_entry);
>> -               } else if (server->ops->is_oplock_break &&
>> -                          server->ops->is_oplock_break(buf, server)) {
>> -                       cifs_dbg(FYI, "Received oplock break\n");
>> -               } else {
>> -                       cifs_dbg(VFS, "No task to wake, unknown frame received! NumMids %d\n",
>> -                                atomic_read(&midCount));
>> -                       cifs_dump_mem("Received Data is: ", buf,
>> -                                     HEADER_SIZE(server));
>> +                               if (!mids[i]->multiRsp || mids[i]->multiEnd)
>> +                                       mids[i]->callback(mids[i]);
>> +
>> +                               cifs_mid_q_entry_release(mids[i]);
>> +                       } else if (server->ops->is_oplock_break &&
>> +                                  server->ops->is_oplock_break(bufs[i],
>> +                                                               server)) {
>> +                               cifs_dbg(FYI, "Received oplock break\n");
>> +                       } else {
>> +                               cifs_dbg(VFS, "No task to wake, unknown frame "
>> +                                        "received! NumMids %d\n",
>> +                                        atomic_read(&midCount));
>> +                               cifs_dump_mem("Received Data is: ", bufs[i],
>> +                                             HEADER_SIZE(server));
>>  #ifdef CONFIG_CIFS_DEBUG2
>> -                       if (server->ops->dump_detail)
>> -                               server->ops->dump_detail(buf, server);
>> -                       cifs_dump_mids(server);
>> +                               if (server->ops->dump_detail)
>> +                                       server->ops->dump_detail(bufs[i],
>> +                                                                server);
>> +                               cifs_dump_mids(server);
>>  #endif /* CIFS_DEBUG2 */
>> +                       }
>>                 }
>> +
>>                 if (pdu_length > server->pdu_size) {
>>                         if (!allocate_buffers(server))
>>                                 continue;
>> diff --git a/fs/cifs/smb2ops.c b/fs/cifs/smb2ops.c
>> index ebc13ebebddf..5f3f27e33244 100644
>> --- a/fs/cifs/smb2ops.c
>> +++ b/fs/cifs/smb2ops.c
>> @@ -2942,13 +2942,20 @@ receive_encrypted_read(struct TCP_Server_Info *server, struct mid_q_entry **mid)
>>
>>  static int
>>  receive_encrypted_standard(struct TCP_Server_Info *server,
>> -                          struct mid_q_entry **mid)
>> +                          struct mid_q_entry **mids, char **bufs,
>> +                          int *num_mids)
>>  {
>> -       int length;
>> +       int ret, length;
>>         char *buf = server->smallbuf;
>> +       char *tmpbuf;
>> +       struct smb2_sync_hdr *shdr;
>>         unsigned int pdu_length = server->pdu_size;
>>         unsigned int buf_size;
>>         struct mid_q_entry *mid_entry;
>> +       int next_is_large;
>> +       char *next_buffer = NULL;
>> +
>> +       *num_mids = 0;
>>
>>         /* switch to large buffer if too big for a small one */
>>         if (pdu_length > MAX_CIFS_SMALL_BUFFER_SIZE) {
>> @@ -2969,24 +2976,61 @@ receive_encrypted_standard(struct TCP_Server_Info *server,
>>         if (length)
>>                 return length;
>>
>> +       next_is_large = server->large_buf;
>> + one_more:
>> +       shdr = (struct smb2_sync_hdr *)buf;
>> +       if (shdr->NextCommand) {
>> +               if (next_is_large) {
>> +                       tmpbuf = server->bigbuf;
>> +                       next_buffer = (char *)cifs_buf_get();
>> +               } else {
>> +                       tmpbuf = server->smallbuf;
>> +                       next_buffer = (char *)cifs_small_buf_get();
>> +               }
>> +               memcpy(next_buffer,
>> +                      tmpbuf + le32_to_cpu(shdr->NextCommand),
>> +                      pdu_length - le32_to_cpu(shdr->NextCommand));
>> +       }
>> +
>>         mid_entry = smb2_find_mid(server, buf);
>>         if (mid_entry == NULL)
>>                 cifs_dbg(FYI, "mid not found\n");
>>         else {
>>                 cifs_dbg(FYI, "mid found\n");
>>                 mid_entry->decrypted = true;
>> +               mid_entry->resp_buf_size = server->pdu_size;
>
> server->pdu_size is set in the demultiplex thread to pdu_length, which
> equals the size of the whole PDU. Without encryption, the demultiplex
> thread updates server->pdu_size by setting it to the size of a
> particular chain element, but it doesn't seem that we update this value
> in this function.
>
> I think we need a local variable pdu_size that is updated the
> same way the demultiplex thread updates server->pdu_size: set it to
> pdu_length at the beginning of the "one_more" loop and then truncate
> it to the offset of the next chain element.

I think this is harmless, but you are right.
I will leave the current patch as is but send a follow-up patch to change this.

>
>>         }
>>
>> -       *mid = mid_entry;
>> +       if (*num_mids >= MAX_COMPOUND) {
>> +               cifs_dbg(VFS, "too many PDUs in compound\n");
>> +               return -1;
>> +       }
>> +       bufs[*num_mids] = buf;
>> +       mids[(*num_mids)++] = mid_entry;
>>
>>         if (mid_entry && mid_entry->handle)
>> -               return mid_entry->handle(server, mid_entry);
>> +               ret = mid_entry->handle(server, mid_entry);
>> +       else
>> +               ret = cifs_handle_standard(server, mid_entry);
>> +
>> +       if (ret == 0 && shdr->NextCommand) {
>> +               pdu_length -= le32_to_cpu(shdr->NextCommand);
>> +               server->large_buf = next_is_large;
>> +               if (next_is_large) {
>> +                       server->bigbuf = next_buffer;
>> +               } else {
>> +                       server->smallbuf = next_buffer;
>> +               }
>> +               buf += le32_to_cpu(shdr->NextCommand);
>> +               goto one_more;
>> +       }
>>
>> -       return cifs_handle_standard(server, mid_entry);
>> +        return ret;
>>  }
>>
>>  static int
>> -smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
>> +smb3_receive_transform(struct TCP_Server_Info *server,
>> +                      struct mid_q_entry **mids, char **bufs, int *num_mids)
>>  {
>>         char *buf = server->smallbuf;
>>         unsigned int pdu_length = server->pdu_size;
>> @@ -3009,10 +3053,11 @@ smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
>>                 return -ECONNABORTED;
>>         }
>>
>> +       /* TODO: add support for compounds containing READ. */
>>         if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server))
>> -               return receive_encrypted_read(server, mid);
>> +               return receive_encrypted_read(server, &mids[0]);
>>
>> -       return receive_encrypted_standard(server, mid);
>> +       return receive_encrypted_standard(server, mids, bufs, num_mids);
>>  }
>>
>>  int
>> diff --git a/fs/cifs/transport.c b/fs/cifs/transport.c
>> index c53c0908d4c6..78f96fa3d7d9 100644
>> --- a/fs/cifs/transport.c
>> +++ b/fs/cifs/transport.c
>> @@ -383,8 +383,6 @@ __smb_send_rqst(struct TCP_Server_Info *server, int num_rqst,
>>         return rc;
>>  }
>>
>> -#define MAX_COMPOUND 5
>> -
>>  static int
>>  smb_send_rqst(struct TCP_Server_Info *server, int num_rqst,
>>               struct smb_rqst *rqst, int flags)
>> --
>> 2.13.3
>>
>> --
>> To unsubscribe from this list: send the line "unsubscribe linux-cifs" in
>> the body of a message to majordomo@vger.kernel.org
>> More majordomo info at  http://vger.kernel.org/majordomo-info.html
>
>
> --
> Best regards,
> Pavel Shilovsky
> --
> To unsubscribe from this list: send the line "unsubscribe linux-cifs" in
> the body of a message to majordomo@vger.kernel.org
> More majordomo info at  http://vger.kernel.org/majordomo-info.html
--
To unsubscribe from this list: send the line "unsubscribe linux-cifs" in
the body of a message to majordomo@vger.kernel.org
More majordomo info at  http://vger.kernel.org/majordomo-info.html
diff mbox series

Patch

diff --git a/fs/cifs/cifsglob.h b/fs/cifs/cifsglob.h
index 41803d374da0..0c9ab62c3df4 100644
--- a/fs/cifs/cifsglob.h
+++ b/fs/cifs/cifsglob.h
@@ -76,6 +76,9 @@ 
 #define SMB_ECHO_INTERVAL_MAX 600
 #define SMB_ECHO_INTERVAL_DEFAULT 60
 
+/* maximum number of PDUs in one compound */
+#define MAX_COMPOUND 5
+
 /*
  * Default number of credits to keep available for SMB3.
  * This value is chosen somewhat arbitrarily. The Windows client
@@ -458,7 +461,7 @@  struct smb_version_operations {
 				 struct smb_rqst *, struct smb_rqst *);
 	int (*is_transform_hdr)(void *buf);
 	int (*receive_transform)(struct TCP_Server_Info *,
-				 struct mid_q_entry **);
+				 struct mid_q_entry **, char **, int *);
 	enum securityEnum (*select_sectype)(struct TCP_Server_Info *,
 			    enum securityEnum);
 	int (*next_header)(char *);
diff --git a/fs/cifs/connect.c b/fs/cifs/connect.c
index d9bd10d295a9..20564ea71dd5 100644
--- a/fs/cifs/connect.c
+++ b/fs/cifs/connect.c
@@ -850,13 +850,14 @@  cifs_handle_standard(struct TCP_Server_Info *server, struct mid_q_entry *mid)
 static int
 cifs_demultiplex_thread(void *p)
 {
-	int length;
+	int i, num_mids, length;
 	struct TCP_Server_Info *server = p;
 	unsigned int pdu_length;
 	unsigned int next_offset;
 	char *buf = NULL;
 	struct task_struct *task_to_wake = NULL;
-	struct mid_q_entry *mid_entry;
+	struct mid_q_entry *mids[MAX_COMPOUND];
+	char *bufs[MAX_COMPOUND];
 
 	current->flags |= PF_MEMALLOC;
 	cifs_dbg(FYI, "Demultiplex PID: %d\n", task_pid_nr(current));
@@ -923,58 +924,75 @@  cifs_demultiplex_thread(void *p)
 				server->pdu_size = next_offset;
 		}
 
-		mid_entry = NULL;
+		memset(mids, 0, sizeof(mids));
+		memset(bufs, 0, sizeof(bufs));
+		num_mids = 0;
+
 		if (server->ops->is_transform_hdr &&
 		    server->ops->receive_transform &&
 		    server->ops->is_transform_hdr(buf)) {
 			length = server->ops->receive_transform(server,
-								&mid_entry);
+								mids,
+								bufs,
+								&num_mids);
 		} else {
-			mid_entry = server->ops->find_mid(server, buf);
+			mids[0] = server->ops->find_mid(server, buf);
+			bufs[0] = buf;
+			if (mids[0])
+				num_mids = 1;
 
-			if (!mid_entry || !mid_entry->receive)
-				length = standard_receive3(server, mid_entry);
+			if (!mids[0] || !mids[0]->receive)
+				length = standard_receive3(server, mids[0]);
 			else
-				length = mid_entry->receive(server, mid_entry);
+				length = mids[0]->receive(server, mids[0]);
 		}
 
 		if (length < 0) {
-			if (mid_entry)
-				cifs_mid_q_entry_release(mid_entry);
+			for (i = 0; i < num_mids; i++)
+				if (mids[i])
+					cifs_mid_q_entry_release(mids[i]);
 			continue;
 		}
 
 		if (server->large_buf)
 			buf = server->bigbuf;
+		//qqq
 
 		server->lstrp = jiffies;
-		if (mid_entry != NULL) {
-			mid_entry->resp_buf_size = server->pdu_size;
-			if ((mid_entry->mid_flags & MID_WAIT_CANCELLED) &&
-			     mid_entry->mid_state == MID_RESPONSE_RECEIVED &&
-					server->ops->handle_cancelled_mid)
-				server->ops->handle_cancelled_mid(
-							mid_entry->resp_buf,
-							server);
 
-			if (!mid_entry->multiRsp || mid_entry->multiEnd)
-				mid_entry->callback(mid_entry);
+		for (i = 0; i < num_mids; i++) {
+			if (mids[i] != NULL) {
+				mids[i]->resp_buf_size = server->pdu_size;
+				if ((mids[i]->mid_flags & MID_WAIT_CANCELLED) &&
+				    mids[i]->mid_state == MID_RESPONSE_RECEIVED &&
+				    server->ops->handle_cancelled_mid)
+					server->ops->handle_cancelled_mid(
+							mids[i]->resp_buf,
+							server);
 
-			cifs_mid_q_entry_release(mid_entry);
-		} else if (server->ops->is_oplock_break &&
-			   server->ops->is_oplock_break(buf, server)) {
-			cifs_dbg(FYI, "Received oplock break\n");
-		} else {
-			cifs_dbg(VFS, "No task to wake, unknown frame received! NumMids %d\n",
-				 atomic_read(&midCount));
-			cifs_dump_mem("Received Data is: ", buf,
-				      HEADER_SIZE(server));
+				if (!mids[i]->multiRsp || mids[i]->multiEnd)
+					mids[i]->callback(mids[i]);
+
+				cifs_mid_q_entry_release(mids[i]);
+			} else if (server->ops->is_oplock_break &&
+				   server->ops->is_oplock_break(bufs[i],
+								server)) {
+				cifs_dbg(FYI, "Received oplock break\n");
+			} else {
+				cifs_dbg(VFS, "No task to wake, unknown frame "
+					 "received! NumMids %d\n",
+					 atomic_read(&midCount));
+				cifs_dump_mem("Received Data is: ", bufs[i],
+					      HEADER_SIZE(server));
 #ifdef CONFIG_CIFS_DEBUG2
-			if (server->ops->dump_detail)
-				server->ops->dump_detail(buf, server);
-			cifs_dump_mids(server);
+				if (server->ops->dump_detail)
+					server->ops->dump_detail(bufs[i],
+								 server);
+				cifs_dump_mids(server);
 #endif /* CIFS_DEBUG2 */
+			}
 		}
+
 		if (pdu_length > server->pdu_size) {
 			if (!allocate_buffers(server))
 				continue;
diff --git a/fs/cifs/smb2ops.c b/fs/cifs/smb2ops.c
index ebc13ebebddf..5f3f27e33244 100644
--- a/fs/cifs/smb2ops.c
+++ b/fs/cifs/smb2ops.c
@@ -2942,13 +2942,20 @@  receive_encrypted_read(struct TCP_Server_Info *server, struct mid_q_entry **mid)
 
 static int
 receive_encrypted_standard(struct TCP_Server_Info *server,
-			   struct mid_q_entry **mid)
+			   struct mid_q_entry **mids, char **bufs,
+			   int *num_mids)
 {
-	int length;
+	int ret, length;
 	char *buf = server->smallbuf;
+	char *tmpbuf;
+	struct smb2_sync_hdr *shdr;
 	unsigned int pdu_length = server->pdu_size;
 	unsigned int buf_size;
 	struct mid_q_entry *mid_entry;
+	int next_is_large;
+	char *next_buffer = NULL;
+
+	*num_mids = 0;
 
 	/* switch to large buffer if too big for a small one */
 	if (pdu_length > MAX_CIFS_SMALL_BUFFER_SIZE) {
@@ -2969,24 +2976,61 @@  receive_encrypted_standard(struct TCP_Server_Info *server,
 	if (length)
 		return length;
 
+	next_is_large = server->large_buf;
+ one_more:
+	shdr = (struct smb2_sync_hdr *)buf;
+	if (shdr->NextCommand) {
+		if (next_is_large) {
+			tmpbuf = server->bigbuf;
+			next_buffer = (char *)cifs_buf_get();
+		} else {
+			tmpbuf = server->smallbuf;
+			next_buffer = (char *)cifs_small_buf_get();
+		}
+		memcpy(next_buffer,
+		       tmpbuf + le32_to_cpu(shdr->NextCommand),
+		       pdu_length - le32_to_cpu(shdr->NextCommand));
+	}
+
 	mid_entry = smb2_find_mid(server, buf);
 	if (mid_entry == NULL)
 		cifs_dbg(FYI, "mid not found\n");
 	else {
 		cifs_dbg(FYI, "mid found\n");
 		mid_entry->decrypted = true;
+		mid_entry->resp_buf_size = server->pdu_size;
 	}
 
-	*mid = mid_entry;
+	if (*num_mids >= MAX_COMPOUND) {
+		cifs_dbg(VFS, "too many PDUs in compound\n");
+		return -1;
+	}
+	bufs[*num_mids] = buf;
+	mids[(*num_mids)++] = mid_entry;
 
 	if (mid_entry && mid_entry->handle)
-		return mid_entry->handle(server, mid_entry);
+		ret = mid_entry->handle(server, mid_entry);
+	else
+		ret = cifs_handle_standard(server, mid_entry);
+
+	if (ret == 0 && shdr->NextCommand) {
+		pdu_length -= le32_to_cpu(shdr->NextCommand);
+		server->large_buf = next_is_large;
+		if (next_is_large) {
+			server->bigbuf = next_buffer;
+		} else {
+			server->smallbuf = next_buffer;
+		}
+		buf += le32_to_cpu(shdr->NextCommand);
+		goto one_more;
+	}
 
-	return cifs_handle_standard(server, mid_entry);
+        return ret;
 }
 
 static int
-smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
+smb3_receive_transform(struct TCP_Server_Info *server,
+		       struct mid_q_entry **mids, char **bufs, int *num_mids)
 {
 	char *buf = server->smallbuf;
 	unsigned int pdu_length = server->pdu_size;
@@ -3009,10 +3053,11 @@  smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
 		return -ECONNABORTED;
 	}
 
+	/* TODO: add support for compounds containing READ. */
 	if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server))
-		return receive_encrypted_read(server, mid);
+		return receive_encrypted_read(server, &mids[0]);
 
-	return receive_encrypted_standard(server, mid);
+	return receive_encrypted_standard(server, mids, bufs, num_mids);
 }
 
 int
diff --git a/fs/cifs/transport.c b/fs/cifs/transport.c
index c53c0908d4c6..78f96fa3d7d9 100644
--- a/fs/cifs/transport.c
+++ b/fs/cifs/transport.c
@@ -383,8 +383,6 @@  __smb_send_rqst(struct TCP_Server_Info *server, int num_rqst,
 	return rc;
 }
 
-#define MAX_COMPOUND 5
-
 static int
 smb_send_rqst(struct TCP_Server_Info *server, int num_rqst,
 	      struct smb_rqst *rqst, int flags)