Added GenericRelationFilter
This commit is contained in:
parent
ab3bc96783
commit
b363d1117c
|
@ -57,6 +57,64 @@ class SmartSearchFilter(django_filters.CharFilter):
|
||||||
return search.apply(qs, cleaned)
|
return search.apply(qs, cleaned)
|
||||||
|
|
||||||
|
|
||||||
|
def get_generic_filter_query(value, relation_name, choices):
|
||||||
|
parts = value.split(":", 1)
|
||||||
|
type = parts[0]
|
||||||
|
try:
|
||||||
|
conf = choices[type]
|
||||||
|
except KeyError:
|
||||||
|
raise forms.ValidationError("Invalid type")
|
||||||
|
related_queryset = conf["queryset"]
|
||||||
|
related_model = related_queryset.model
|
||||||
|
filter_query = models.Q(
|
||||||
|
**{
|
||||||
|
"{}_content_type__app_label".format(
|
||||||
|
relation_name
|
||||||
|
): related_model._meta.app_label,
|
||||||
|
"{}_content_type__model".format(
|
||||||
|
relation_name
|
||||||
|
): related_model._meta.model_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if len(parts) > 1:
|
||||||
|
id_attr = conf.get("id_attr", "id")
|
||||||
|
id_field = conf.get("id_field", serializers.IntegerField(min_value=1))
|
||||||
|
try:
|
||||||
|
id_value = parts[1]
|
||||||
|
id_value = id_field.to_internal_value(id_value)
|
||||||
|
except (TypeError, KeyError, serializers.ValidationError):
|
||||||
|
raise forms.ValidationError("Invalid id")
|
||||||
|
query_getter = conf.get(
|
||||||
|
"get_query", lambda attr, value: models.Q(**{attr: value})
|
||||||
|
)
|
||||||
|
obj_query = query_getter(id_attr, id_value)
|
||||||
|
try:
|
||||||
|
obj = related_queryset.get(obj_query)
|
||||||
|
except related_queryset.model.DoesNotExist:
|
||||||
|
raise forms.ValidationError("Invalid object")
|
||||||
|
filter_query &= models.Q(**{"{}_id".format(relation_name): obj.id})
|
||||||
|
|
||||||
|
return filter_query
|
||||||
|
|
||||||
|
|
||||||
|
class GenericRelationFilter(django_filters.CharFilter):
|
||||||
|
def __init__(self, relation_name, choices, *args, **kwargs):
|
||||||
|
self.relation_name = relation_name
|
||||||
|
self.choices = choices
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def filter(self, qs, value):
|
||||||
|
if not value:
|
||||||
|
return qs
|
||||||
|
try:
|
||||||
|
filter_query = get_generic_filter_query(
|
||||||
|
value, relation_name=self.relation_name, choices=self.choices
|
||||||
|
)
|
||||||
|
except forms.ValidationError:
|
||||||
|
return qs.none()
|
||||||
|
return qs.filter(filter_query)
|
||||||
|
|
||||||
|
|
||||||
class GenericRelation(serializers.JSONField):
|
class GenericRelation(serializers.JSONField):
|
||||||
def __init__(self, choices, *args, **kwargs):
|
def __init__(self, choices, *args, **kwargs):
|
||||||
self.choices = choices
|
self.choices = choices
|
||||||
|
|
|
@ -67,3 +67,39 @@ def test_generic_relation_field_validation_error(payload, expected_error, factor
|
||||||
|
|
||||||
with pytest.raises(fields.serializers.ValidationError, match=expected_error):
|
with pytest.raises(fields.serializers.ValidationError, match=expected_error):
|
||||||
f.to_internal_value(payload)
|
f.to_internal_value(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generic_relation_filter_target_type(factories):
|
||||||
|
user = factories["users.User"]()
|
||||||
|
note = factories["moderation.Note"](target=user)
|
||||||
|
factories["moderation.Note"](target=factories["music.Artist"]())
|
||||||
|
f = fields.GenericRelationFilter(
|
||||||
|
"target",
|
||||||
|
{
|
||||||
|
"user": {
|
||||||
|
"queryset": user.__class__.objects.all(),
|
||||||
|
"id_attr": "username",
|
||||||
|
"id_field": fields.serializers.CharField(),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
qs = f.filter(note.__class__.objects.all(), "user")
|
||||||
|
assert list(qs) == [note]
|
||||||
|
|
||||||
|
|
||||||
|
def test_generic_relation_filter_target_type_and_id(factories):
|
||||||
|
user = factories["users.User"]()
|
||||||
|
note = factories["moderation.Note"](target=user)
|
||||||
|
factories["moderation.Note"](target=factories["users.User"]())
|
||||||
|
f = fields.GenericRelationFilter(
|
||||||
|
"target",
|
||||||
|
{
|
||||||
|
"user": {
|
||||||
|
"queryset": user.__class__.objects.all(),
|
||||||
|
"id_attr": "username",
|
||||||
|
"id_field": fields.serializers.CharField(),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
qs = f.filter(note.__class__.objects.all(), "user:{}".format(user.username))
|
||||||
|
assert list(qs) == [note]
|
||||||
|
|
Loading…
Reference in New Issue