Added GenericRelationFilter

This commit is contained in:
Eliot Berriot 2019-08-29 12:20:30 +02:00
parent ab3bc96783
commit b363d1117c
No known key found for this signature in database
GPG Key ID: DD6965E2476E5C27
2 changed files with 94 additions and 0 deletions

View File

@ -57,6 +57,64 @@ class SmartSearchFilter(django_filters.CharFilter):
return search.apply(qs, cleaned)
def get_generic_filter_query(value, relation_name, choices):
parts = value.split(":", 1)
type = parts[0]
try:
conf = choices[type]
except KeyError:
raise forms.ValidationError("Invalid type")
related_queryset = conf["queryset"]
related_model = related_queryset.model
filter_query = models.Q(
**{
"{}_content_type__app_label".format(
relation_name
): related_model._meta.app_label,
"{}_content_type__model".format(
relation_name
): related_model._meta.model_name,
}
)
if len(parts) > 1:
id_attr = conf.get("id_attr", "id")
id_field = conf.get("id_field", serializers.IntegerField(min_value=1))
try:
id_value = parts[1]
id_value = id_field.to_internal_value(id_value)
except (TypeError, KeyError, serializers.ValidationError):
raise forms.ValidationError("Invalid id")
query_getter = conf.get(
"get_query", lambda attr, value: models.Q(**{attr: value})
)
obj_query = query_getter(id_attr, id_value)
try:
obj = related_queryset.get(obj_query)
except related_queryset.model.DoesNotExist:
raise forms.ValidationError("Invalid object")
filter_query &= models.Q(**{"{}_id".format(relation_name): obj.id})
return filter_query
class GenericRelationFilter(django_filters.CharFilter):
def __init__(self, relation_name, choices, *args, **kwargs):
self.relation_name = relation_name
self.choices = choices
super().__init__(*args, **kwargs)
def filter(self, qs, value):
if not value:
return qs
try:
filter_query = get_generic_filter_query(
value, relation_name=self.relation_name, choices=self.choices
)
except forms.ValidationError:
return qs.none()
return qs.filter(filter_query)
class GenericRelation(serializers.JSONField):
def __init__(self, choices, *args, **kwargs):
self.choices = choices

View File

@ -67,3 +67,39 @@ def test_generic_relation_field_validation_error(payload, expected_error, factor
with pytest.raises(fields.serializers.ValidationError, match=expected_error):
f.to_internal_value(payload)
def test_generic_relation_filter_target_type(factories):
user = factories["users.User"]()
note = factories["moderation.Note"](target=user)
factories["moderation.Note"](target=factories["music.Artist"]())
f = fields.GenericRelationFilter(
"target",
{
"user": {
"queryset": user.__class__.objects.all(),
"id_attr": "username",
"id_field": fields.serializers.CharField(),
}
},
)
qs = f.filter(note.__class__.objects.all(), "user")
assert list(qs) == [note]
def test_generic_relation_filter_target_type_and_id(factories):
user = factories["users.User"]()
note = factories["moderation.Note"](target=user)
factories["moderation.Note"](target=factories["users.User"]())
f = fields.GenericRelationFilter(
"target",
{
"user": {
"queryset": user.__class__.objects.all(),
"id_attr": "username",
"id_field": fields.serializers.CharField(),
}
},
)
qs = f.filter(note.__class__.objects.all(), "user:{}".format(user.username))
assert list(qs) == [note]