From b363d1117c0526d866183cdfc41837de9e3306c5 Mon Sep 17 00:00:00 2001 From: Eliot Berriot Date: Thu, 29 Aug 2019 12:20:30 +0200 Subject: [PATCH] Added GenericRelationFilter --- api/funkwhale_api/common/fields.py | 58 ++++++++++++++++++++++++++++++ api/tests/common/test_fields.py | 36 +++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/api/funkwhale_api/common/fields.py b/api/funkwhale_api/common/fields.py index f1a3e77bd..f206b626e 100644 --- a/api/funkwhale_api/common/fields.py +++ b/api/funkwhale_api/common/fields.py @@ -57,6 +57,64 @@ class SmartSearchFilter(django_filters.CharFilter): return search.apply(qs, cleaned) +def get_generic_filter_query(value, relation_name, choices): + parts = value.split(":", 1) + type = parts[0] + try: + conf = choices[type] + except KeyError: + raise forms.ValidationError("Invalid type") + related_queryset = conf["queryset"] + related_model = related_queryset.model + filter_query = models.Q( + **{ + "{}_content_type__app_label".format( + relation_name + ): related_model._meta.app_label, + "{}_content_type__model".format( + relation_name + ): related_model._meta.model_name, + } + ) + if len(parts) > 1: + id_attr = conf.get("id_attr", "id") + id_field = conf.get("id_field", serializers.IntegerField(min_value=1)) + try: + id_value = parts[1] + id_value = id_field.to_internal_value(id_value) + except (TypeError, KeyError, serializers.ValidationError): + raise forms.ValidationError("Invalid id") + query_getter = conf.get( + "get_query", lambda attr, value: models.Q(**{attr: value}) + ) + obj_query = query_getter(id_attr, id_value) + try: + obj = related_queryset.get(obj_query) + except related_queryset.model.DoesNotExist: + raise forms.ValidationError("Invalid object") + filter_query &= models.Q(**{"{}_id".format(relation_name): obj.id}) + + return filter_query + + +class GenericRelationFilter(django_filters.CharFilter): + def __init__(self, relation_name, choices, *args, **kwargs): + self.relation_name = relation_name + self.choices = choices + super().__init__(*args, **kwargs) + + def filter(self, qs, value): + if not value: + return qs + try: + filter_query = get_generic_filter_query( + value, relation_name=self.relation_name, choices=self.choices + ) + except forms.ValidationError: + return qs.none() + return qs.filter(filter_query) + + class GenericRelation(serializers.JSONField): def __init__(self, choices, *args, **kwargs): self.choices = choices diff --git a/api/tests/common/test_fields.py b/api/tests/common/test_fields.py index 21e85b700..2cc07f1b2 100644 --- a/api/tests/common/test_fields.py +++ b/api/tests/common/test_fields.py @@ -67,3 +67,39 @@ def test_generic_relation_field_validation_error(payload, expected_error, factor with pytest.raises(fields.serializers.ValidationError, match=expected_error): f.to_internal_value(payload) + + +def test_generic_relation_filter_target_type(factories): + user = factories["users.User"]() + note = factories["moderation.Note"](target=user) + factories["moderation.Note"](target=factories["music.Artist"]()) + f = fields.GenericRelationFilter( + "target", + { + "user": { + "queryset": user.__class__.objects.all(), + "id_attr": "username", + "id_field": fields.serializers.CharField(), + } + }, + ) + qs = f.filter(note.__class__.objects.all(), "user") + assert list(qs) == [note] + + +def test_generic_relation_filter_target_type_and_id(factories): + user = factories["users.User"]() + note = factories["moderation.Note"](target=user) + factories["moderation.Note"](target=factories["users.User"]()) + f = fields.GenericRelationFilter( + "target", + { + "user": { + "queryset": user.__class__.objects.all(), + "id_attr": "username", + "id_field": fields.serializers.CharField(), + } + }, + ) + qs = f.filter(note.__class__.objects.all(), "user:{}".format(user.username)) + assert list(qs) == [note]