Added GenericRelationFilter
This commit is contained in:
parent
ab3bc96783
commit
b363d1117c
|
@ -57,6 +57,64 @@ class SmartSearchFilter(django_filters.CharFilter):
|
|||
return search.apply(qs, cleaned)
|
||||
|
||||
|
||||
def get_generic_filter_query(value, relation_name, choices):
|
||||
parts = value.split(":", 1)
|
||||
type = parts[0]
|
||||
try:
|
||||
conf = choices[type]
|
||||
except KeyError:
|
||||
raise forms.ValidationError("Invalid type")
|
||||
related_queryset = conf["queryset"]
|
||||
related_model = related_queryset.model
|
||||
filter_query = models.Q(
|
||||
**{
|
||||
"{}_content_type__app_label".format(
|
||||
relation_name
|
||||
): related_model._meta.app_label,
|
||||
"{}_content_type__model".format(
|
||||
relation_name
|
||||
): related_model._meta.model_name,
|
||||
}
|
||||
)
|
||||
if len(parts) > 1:
|
||||
id_attr = conf.get("id_attr", "id")
|
||||
id_field = conf.get("id_field", serializers.IntegerField(min_value=1))
|
||||
try:
|
||||
id_value = parts[1]
|
||||
id_value = id_field.to_internal_value(id_value)
|
||||
except (TypeError, KeyError, serializers.ValidationError):
|
||||
raise forms.ValidationError("Invalid id")
|
||||
query_getter = conf.get(
|
||||
"get_query", lambda attr, value: models.Q(**{attr: value})
|
||||
)
|
||||
obj_query = query_getter(id_attr, id_value)
|
||||
try:
|
||||
obj = related_queryset.get(obj_query)
|
||||
except related_queryset.model.DoesNotExist:
|
||||
raise forms.ValidationError("Invalid object")
|
||||
filter_query &= models.Q(**{"{}_id".format(relation_name): obj.id})
|
||||
|
||||
return filter_query
|
||||
|
||||
|
||||
class GenericRelationFilter(django_filters.CharFilter):
|
||||
def __init__(self, relation_name, choices, *args, **kwargs):
|
||||
self.relation_name = relation_name
|
||||
self.choices = choices
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def filter(self, qs, value):
|
||||
if not value:
|
||||
return qs
|
||||
try:
|
||||
filter_query = get_generic_filter_query(
|
||||
value, relation_name=self.relation_name, choices=self.choices
|
||||
)
|
||||
except forms.ValidationError:
|
||||
return qs.none()
|
||||
return qs.filter(filter_query)
|
||||
|
||||
|
||||
class GenericRelation(serializers.JSONField):
|
||||
def __init__(self, choices, *args, **kwargs):
|
||||
self.choices = choices
|
||||
|
|
|
@ -67,3 +67,39 @@ def test_generic_relation_field_validation_error(payload, expected_error, factor
|
|||
|
||||
with pytest.raises(fields.serializers.ValidationError, match=expected_error):
|
||||
f.to_internal_value(payload)
|
||||
|
||||
|
||||
def test_generic_relation_filter_target_type(factories):
|
||||
user = factories["users.User"]()
|
||||
note = factories["moderation.Note"](target=user)
|
||||
factories["moderation.Note"](target=factories["music.Artist"]())
|
||||
f = fields.GenericRelationFilter(
|
||||
"target",
|
||||
{
|
||||
"user": {
|
||||
"queryset": user.__class__.objects.all(),
|
||||
"id_attr": "username",
|
||||
"id_field": fields.serializers.CharField(),
|
||||
}
|
||||
},
|
||||
)
|
||||
qs = f.filter(note.__class__.objects.all(), "user")
|
||||
assert list(qs) == [note]
|
||||
|
||||
|
||||
def test_generic_relation_filter_target_type_and_id(factories):
|
||||
user = factories["users.User"]()
|
||||
note = factories["moderation.Note"](target=user)
|
||||
factories["moderation.Note"](target=factories["users.User"]())
|
||||
f = fields.GenericRelationFilter(
|
||||
"target",
|
||||
{
|
||||
"user": {
|
||||
"queryset": user.__class__.objects.all(),
|
||||
"id_attr": "username",
|
||||
"id_field": fields.serializers.CharField(),
|
||||
}
|
||||
},
|
||||
)
|
||||
qs = f.filter(note.__class__.objects.all(), "user:{}".format(user.username))
|
||||
assert list(qs) == [note]
|
||||
|
|
Loading…
Reference in New Issue