diff mbox series

[1/2] REST: Handle regular form data requests for checks

Message ID 20190429170754.32644-2-dja@axtens.net
State Accepted
Headers show
Series Snowpatch REST fixes | expand

Commit Message

Daniel Axtens April 29, 2019, 5:07 p.m. UTC
08d1459a4a40 ("Add REST API validation using OpenAPI schema") moved
all API requests to JSON blobs rather than form data.

dc48fbce99ef ("REST: Handle JSON requests") attempted to change the
check serialiser to handle this. However, because both a JSON dict
and a QueryDict satisfy isinstance(data, dict), everything was handled
as JSON and the old style requests were broken.

Found in the process of debugging issues from the OzLabs PW & Snowpatch
crew - I'm not sure if they actually hit this one, but kudos to them
anyway as we wouldn't have found it without them.

Fixes: 08d1459a4a40 ("Add REST API validation using OpenAPI schema")
Fixes: dc48fbce99ef ("REST: Handle JSON requests")
Signed-off-by: Daniel Axtens <dja@axtens.net>

---

This will need to go back to stable.
---
 patchwork/api/check.py            |  7 ++--
 patchwork/tests/api/test_check.py | 67 +++++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 3 deletions(-)

Comments

Daniel Axtens April 29, 2019, 5:19 p.m. UTC | #1
Daniel Axtens <dja@axtens.net> writes:

> 08d1459a4a40 ("Add REST API validation using OpenAPI schema") moved
> all API requests to JSON blobs rather than form data.
>
> dc48fbce99ef ("REST: Handle JSON requests") attempted to change the
> check serialiser to handle this. However, because both a JSON dict
> and a QueryDict satisfy isinstance(data, dict), everything was handled
> as JSON and the old style requests were broken.
>
> Found in the process of debugging issues from the OzLabs PW & Snowpatch
> crew - I'm not sure if they actually hit this one, but kudos to them
> anyway as we wouldn't have found it without them.
>
> Fixes: 08d1459a4a40 ("Add REST API validation using OpenAPI schema")
> Fixes: dc48fbce99ef ("REST: Handle JSON requests")
> Signed-off-by: Daniel Axtens <dja@axtens.net>
>
> ---
>
> This will need to go back to stable.

This is because OzLabs needed to pick up dc48fbce99ef. I can't remember
why off the top of my head, but I'll send a stable fixes series myself
once we nail down the Ozlabs issues.

Regards,
Daniel

> ---
>  patchwork/api/check.py            |  7 ++--
>  patchwork/tests/api/test_check.py | 67 +++++++++++++++++++++++++++++++
>  2 files changed, 71 insertions(+), 3 deletions(-)
>
> diff --git a/patchwork/api/check.py b/patchwork/api/check.py
> index 1f9fe06866a2..4d2181d0a04b 100644
> --- a/patchwork/api/check.py
> +++ b/patchwork/api/check.py
> @@ -4,6 +4,7 @@
>  # SPDX-License-Identifier: GPL-2.0-or-later
>  
>  from django.http import Http404
> +from django.http.request import QueryDict
>  from django.shortcuts import get_object_or_404
>  from rest_framework.exceptions import PermissionDenied
>  from rest_framework.generics import ListCreateAPIView
> @@ -39,9 +40,7 @@ class CheckSerializer(HyperlinkedModelSerializer):
>              if label != data['state']:
>                  continue
>  
> -            if isinstance(data, dict):  # json request
> -                data['state'] = val
> -            else:  # form-data request
> +            if isinstance(data, QueryDict):  # form-data request
>                  # NOTE(stephenfin): 'data' is essentially 'request.POST', which
>                  # is immutable by default. However, there's no good reason for
>                  # this to be this way [1], so temporarily unset that mutability
> @@ -52,6 +51,8 @@ class CheckSerializer(HyperlinkedModelSerializer):
>                  data._mutable = True  # noqa
>                  data['state'] = val
>                  data._mutable = mutable  # noqa
> +            else:  # json request
> +                data['state'] = val
>  
>              break
>          return super(CheckSerializer, self).run_validation(data)
> diff --git a/patchwork/tests/api/test_check.py b/patchwork/tests/api/test_check.py
> index 0c10b94553d3..1cfdff6e757b 100644
> --- a/patchwork/tests/api/test_check.py
> +++ b/patchwork/tests/api/test_check.py
> @@ -18,6 +18,10 @@ from patchwork.tests.utils import create_user
>  
>  if settings.ENABLE_REST_API:
>      from rest_framework import status
> +    from rest_framework.test import APITestCase as BaseAPITestCase
> +else:
> +    # stub out APITestCase
> +    from django.test import TestCase as BaseAPITestCase
>  
>  
>  @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
> @@ -174,3 +178,66 @@ class TestCheckAPI(utils.APITestCase):
>  
>          resp = self.client.delete(self.api_url(check))
>          self.assertEqual(status.HTTP_405_METHOD_NOT_ALLOWED, resp.status_code)
> +
> +
> +@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
> +class TestCheckAPIMultipart(BaseAPITestCase):
> +    """Test a minimal subset of functionality where the data is passed as
> +    multipart form data rather than as a JSON blob.
> +
> +    We focus on the POST path exclusively and only on state validation:
> +    everything else should be handled in the JSON tests.
> +
> +    This is required due to the difference in handling JSON vs form-data in
> +    CheckSerializer's run_validation().
> +    """
> +    fixtures = ['default_tags']
> +
> +    def setUp(self):
> +        super(TestCheckAPIMultipart, self).setUp()
> +        project = create_project()
> +        self.user = create_maintainer(project)
> +        self.patch = create_patch(project=project)
> +
> +    def assertSerialized(self, check_obj, check_json):
> +        self.assertEqual(check_obj.id, check_json['id'])
> +        self.assertEqual(check_obj.get_state_display(), check_json['state'])
> +        self.assertEqual(check_obj.target_url, check_json['target_url'])
> +        self.assertEqual(check_obj.context, check_json['context'])
> +        self.assertEqual(check_obj.description, check_json['description'])
> +        self.assertEqual(check_obj.user.id, check_json['user']['id'])
> +
> +    def _test_create(self, user, state='success'):
> +        check = {
> +            'target_url': 'http://t.co',
> +            'description': 'description',
> +            'context': 'context',
> +        }
> +        if state is not None:
> +            check['state'] = state
> +
> +        self.client.force_authenticate(user=user)
> +        return self.client.post(
> +            reverse('api-check-list', args=[self.patch.id]),
> +            check)
> +
> +    def test_creates(self):
> +        """Create a set of checks.
> +        """
> +        resp = self._test_create(user=self.user)
> +        self.assertEqual(status.HTTP_201_CREATED, resp.status_code)
> +        self.assertEqual(1, Check.objects.all().count())
> +        self.assertSerialized(Check.objects.last(), resp.data)
> +
> +        resp = self._test_create(user=self.user, state='pending')
> +        self.assertEqual(status.HTTP_201_CREATED, resp.status_code)
> +        self.assertEqual(2, Check.objects.all().count())
> +        self.assertSerialized(Check.objects.last(), resp.data)
> +
> +        # you can also use the numeric ID of the state, the API explorer does
> +        resp = self._test_create(user=self.user, state=2)
> +        self.assertEqual(status.HTTP_201_CREATED, resp.status_code)
> +        self.assertEqual(3, Check.objects.all().count())
> +        # we check against the string version
> +        resp.data['state'] = 'warning'
> +        self.assertSerialized(Check.objects.last(), resp.data)
> -- 
> 2.19.1
Daniel Axtens April 30, 2019, 5:23 a.m. UTC | #2
Applied.

Daniel Axtens <dja@axtens.net> writes:

> 08d1459a4a40 ("Add REST API validation using OpenAPI schema") moved
> all API requests to JSON blobs rather than form data.
>
> dc48fbce99ef ("REST: Handle JSON requests") attempted to change the
> check serialiser to handle this. However, because both a JSON dict
> and a QueryDict satisfy isinstance(data, dict), everything was handled
> as JSON and the old style requests were broken.
>
> Found in the process of debugging issues from the OzLabs PW & Snowpatch
> crew - I'm not sure if they actually hit this one, but kudos to them
> anyway as we wouldn't have found it without them.
>
> Fixes: 08d1459a4a40 ("Add REST API validation using OpenAPI schema")
I dropped this fixes line as it's not especially accurate.

> Fixes: dc48fbce99ef ("REST: Handle JSON requests")
> Signed-off-by: Daniel Axtens <dja@axtens.net>
>
> ---
>
> This will need to go back to stable.
> ---
>  patchwork/api/check.py            |  7 ++--
>  patchwork/tests/api/test_check.py | 67 +++++++++++++++++++++++++++++++
>  2 files changed, 71 insertions(+), 3 deletions(-)
>
> diff --git a/patchwork/api/check.py b/patchwork/api/check.py
> index 1f9fe06866a2..4d2181d0a04b 100644
> --- a/patchwork/api/check.py
> +++ b/patchwork/api/check.py
> @@ -4,6 +4,7 @@
>  # SPDX-License-Identifier: GPL-2.0-or-later
>  
>  from django.http import Http404
> +from django.http.request import QueryDict
>  from django.shortcuts import get_object_or_404
>  from rest_framework.exceptions import PermissionDenied
>  from rest_framework.generics import ListCreateAPIView
> @@ -39,9 +40,7 @@ class CheckSerializer(HyperlinkedModelSerializer):
>              if label != data['state']:
>                  continue
>  
> -            if isinstance(data, dict):  # json request
> -                data['state'] = val
> -            else:  # form-data request
> +            if isinstance(data, QueryDict):  # form-data request
>                  # NOTE(stephenfin): 'data' is essentially 'request.POST', which
>                  # is immutable by default. However, there's no good reason for
>                  # this to be this way [1], so temporarily unset that mutability
> @@ -52,6 +51,8 @@ class CheckSerializer(HyperlinkedModelSerializer):
>                  data._mutable = True  # noqa
>                  data['state'] = val
>                  data._mutable = mutable  # noqa
> +            else:  # json request
> +                data['state'] = val
>  
>              break
>          return super(CheckSerializer, self).run_validation(data)
> diff --git a/patchwork/tests/api/test_check.py b/patchwork/tests/api/test_check.py
> index 0c10b94553d3..1cfdff6e757b 100644
> --- a/patchwork/tests/api/test_check.py
> +++ b/patchwork/tests/api/test_check.py
> @@ -18,6 +18,10 @@ from patchwork.tests.utils import create_user
>  
>  if settings.ENABLE_REST_API:
>      from rest_framework import status
> +    from rest_framework.test import APITestCase as BaseAPITestCase
> +else:
> +    # stub out APITestCase
> +    from django.test import TestCase as BaseAPITestCase
>  
>  
>  @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
> @@ -174,3 +178,66 @@ class TestCheckAPI(utils.APITestCase):
>  
>          resp = self.client.delete(self.api_url(check))
>          self.assertEqual(status.HTTP_405_METHOD_NOT_ALLOWED, resp.status_code)
> +
> +
> +@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
> +class TestCheckAPIMultipart(BaseAPITestCase):
> +    """Test a minimal subset of functionality where the data is passed as
> +    multipart form data rather than as a JSON blob.
> +
> +    We focus on the POST path exclusively and only on state validation:
> +    everything else should be handled in the JSON tests.
> +
> +    This is required due to the difference in handling JSON vs form-data in
> +    CheckSerializer's run_validation().
> +    """
> +    fixtures = ['default_tags']
> +
> +    def setUp(self):
> +        super(TestCheckAPIMultipart, self).setUp()
> +        project = create_project()
> +        self.user = create_maintainer(project)
> +        self.patch = create_patch(project=project)
> +
> +    def assertSerialized(self, check_obj, check_json):
> +        self.assertEqual(check_obj.id, check_json['id'])
> +        self.assertEqual(check_obj.get_state_display(), check_json['state'])
> +        self.assertEqual(check_obj.target_url, check_json['target_url'])
> +        self.assertEqual(check_obj.context, check_json['context'])
> +        self.assertEqual(check_obj.description, check_json['description'])
> +        self.assertEqual(check_obj.user.id, check_json['user']['id'])
> +
> +    def _test_create(self, user, state='success'):
> +        check = {
> +            'target_url': 'http://t.co',
> +            'description': 'description',
> +            'context': 'context',
> +        }
> +        if state is not None:
> +            check['state'] = state
> +
> +        self.client.force_authenticate(user=user)
> +        return self.client.post(
> +            reverse('api-check-list', args=[self.patch.id]),
> +            check)
> +
> +    def test_creates(self):
> +        """Create a set of checks.
> +        """
> +        resp = self._test_create(user=self.user)
> +        self.assertEqual(status.HTTP_201_CREATED, resp.status_code)
> +        self.assertEqual(1, Check.objects.all().count())
> +        self.assertSerialized(Check.objects.last(), resp.data)
> +
> +        resp = self._test_create(user=self.user, state='pending')
> +        self.assertEqual(status.HTTP_201_CREATED, resp.status_code)
> +        self.assertEqual(2, Check.objects.all().count())
> +        self.assertSerialized(Check.objects.last(), resp.data)
> +
> +        # you can also use the numeric ID of the state, the API explorer does
> +        resp = self._test_create(user=self.user, state=2)
> +        self.assertEqual(status.HTTP_201_CREATED, resp.status_code)
> +        self.assertEqual(3, Check.objects.all().count())
> +        # we check against the string version
> +        resp.data['state'] = 'warning'
> +        self.assertSerialized(Check.objects.last(), resp.data)
> -- 
> 2.19.1
diff mbox series

Patch

diff --git a/patchwork/api/check.py b/patchwork/api/check.py
index 1f9fe06866a2..4d2181d0a04b 100644
--- a/patchwork/api/check.py
+++ b/patchwork/api/check.py
@@ -4,6 +4,7 @@ 
 # SPDX-License-Identifier: GPL-2.0-or-later
 
 from django.http import Http404
+from django.http.request import QueryDict
 from django.shortcuts import get_object_or_404
 from rest_framework.exceptions import PermissionDenied
 from rest_framework.generics import ListCreateAPIView
@@ -39,9 +40,7 @@  class CheckSerializer(HyperlinkedModelSerializer):
             if label != data['state']:
                 continue
 
-            if isinstance(data, dict):  # json request
-                data['state'] = val
-            else:  # form-data request
+            if isinstance(data, QueryDict):  # form-data request
                 # NOTE(stephenfin): 'data' is essentially 'request.POST', which
                 # is immutable by default. However, there's no good reason for
                 # this to be this way [1], so temporarily unset that mutability
@@ -52,6 +51,8 @@  class CheckSerializer(HyperlinkedModelSerializer):
                 data._mutable = True  # noqa
                 data['state'] = val
                 data._mutable = mutable  # noqa
+            else:  # json request
+                data['state'] = val
 
             break
         return super(CheckSerializer, self).run_validation(data)
diff --git a/patchwork/tests/api/test_check.py b/patchwork/tests/api/test_check.py
index 0c10b94553d3..1cfdff6e757b 100644
--- a/patchwork/tests/api/test_check.py
+++ b/patchwork/tests/api/test_check.py
@@ -18,6 +18,10 @@  from patchwork.tests.utils import create_user
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
+    from rest_framework.test import APITestCase as BaseAPITestCase
+else:
+    # stub out APITestCase
+    from django.test import TestCase as BaseAPITestCase
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
@@ -174,3 +178,66 @@  class TestCheckAPI(utils.APITestCase):
 
         resp = self.client.delete(self.api_url(check))
         self.assertEqual(status.HTTP_405_METHOD_NOT_ALLOWED, resp.status_code)
+
+
+@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
+class TestCheckAPIMultipart(BaseAPITestCase):
+    """Test a minimal subset of functionality where the data is passed as
+    multipart form data rather than as a JSON blob.
+
+    We focus on the POST path exclusively and only on state validation:
+    everything else should be handled in the JSON tests.
+
+    This is required due to the difference in handling JSON vs form-data in
+    CheckSerializer's run_validation().
+    """
+    fixtures = ['default_tags']
+
+    def setUp(self):
+        super(TestCheckAPIMultipart, self).setUp()
+        project = create_project()
+        self.user = create_maintainer(project)
+        self.patch = create_patch(project=project)
+
+    def assertSerialized(self, check_obj, check_json):
+        self.assertEqual(check_obj.id, check_json['id'])
+        self.assertEqual(check_obj.get_state_display(), check_json['state'])
+        self.assertEqual(check_obj.target_url, check_json['target_url'])
+        self.assertEqual(check_obj.context, check_json['context'])
+        self.assertEqual(check_obj.description, check_json['description'])
+        self.assertEqual(check_obj.user.id, check_json['user']['id'])
+
+    def _test_create(self, user, state='success'):
+        check = {
+            'target_url': 'http://t.co',
+            'description': 'description',
+            'context': 'context',
+        }
+        if state is not None:
+            check['state'] = state
+
+        self.client.force_authenticate(user=user)
+        return self.client.post(
+            reverse('api-check-list', args=[self.patch.id]),
+            check)
+
+    def test_creates(self):
+        """Create a set of checks.
+        """
+        resp = self._test_create(user=self.user)
+        self.assertEqual(status.HTTP_201_CREATED, resp.status_code)
+        self.assertEqual(1, Check.objects.all().count())
+        self.assertSerialized(Check.objects.last(), resp.data)
+
+        resp = self._test_create(user=self.user, state='pending')
+        self.assertEqual(status.HTTP_201_CREATED, resp.status_code)
+        self.assertEqual(2, Check.objects.all().count())
+        self.assertSerialized(Check.objects.last(), resp.data)
+
+        # you can also use the numeric ID of the state, the API explorer does
+        resp = self._test_create(user=self.user, state=2)
+        self.assertEqual(status.HTTP_201_CREATED, resp.status_code)
+        self.assertEqual(3, Check.objects.all().count())
+        # we check against the string version
+        resp.data['state'] = 'warning'
+        self.assertSerialized(Check.objects.last(), resp.data)