See #432: enforce a maximum number of tags per entity

This commit is contained in:
Eliot Berriot 2019-07-15 11:53:58 +02:00
parent bd271c8ead
commit 1b34ae2335
No known key found for this signature in database
GPG Key ID: DD6965E2476E5C27
5 changed files with 57 additions and 0 deletions

View File

@ -716,3 +716,6 @@ ACTOR_KEY_ROTATION_DELAY = env.int("ACTOR_KEY_ROTATION_DELAY", default=3600 * 48
SUBSONIC_DEFAULT_TRANSCODING_FORMAT = ( SUBSONIC_DEFAULT_TRANSCODING_FORMAT = (
env("SUBSONIC_DEFAULT_TRANSCODING_FORMAT", default="mp3") or None env("SUBSONIC_DEFAULT_TRANSCODING_FORMAT", default="mp3") or None
) )
# extra tags will be ignored
TAGS_MAX_BY_OBJ = env.int("TAGS_MAX_BY_OBJ", default=30)

View File

@ -1,5 +1,6 @@
import re import re
from django.conf import settings
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.fields import CICharField from django.contrib.postgres.fields import CICharField
@ -65,6 +66,9 @@ def add_tags(obj, *tags):
@transaction.atomic @transaction.atomic
def set_tags(obj, *tags): def set_tags(obj, *tags):
# we ignore any extra tags if the length of the list is higher
# than our accepted size
tags = tags[: settings.TAGS_MAX_BY_OBJ]
tags = set(tags) tags = set(tags)
existing = set( existing = set(
TaggedItem.objects.for_content_object(obj).values_list("tag__name", flat=True) TaggedItem.objects.for_content_object(obj).values_list("tag__name", flat=True)

View File

@ -1,5 +1,7 @@
from rest_framework import serializers from rest_framework import serializers
from django.conf import settings
from . import models from . import models
@ -15,3 +17,18 @@ class TagNameField(serializers.CharField):
if not models.TAG_REGEX.match(value): if not models.TAG_REGEX.match(value):
raise serializers.ValidationError('Invalid tag "{}"'.format(value)) raise serializers.ValidationError('Invalid tag "{}"'.format(value))
return value return value
class TagsListField(serializers.ListField):
def __init__(self, *args, **kwargs):
kwargs.setdefault("min_length", 0)
kwargs.setdefault("child", TagNameField())
super().__init__(*args, **kwargs)
def to_internal_value(self, value):
value = super().to_internal_value(value)
if not value:
return value
# we ignore any extra tags if the length of the list is higher
# than our accepted size
return value[: settings.TAGS_MAX_BY_OBJ]

View File

@ -53,6 +53,24 @@ def test_set_tags(factories, existing, given, expected):
assert match.content_object == obj assert match.content_object == obj
@pytest.mark.parametrize(
"max, tags, expected",
[
(5, ["hello", "world"], ["hello", "world"]),
# we truncate extra tags
(1, ["hello", "world"], ["hello"]),
(2, ["hello", "world", "foo"], ["hello", "world"]),
],
)
def test_set_tags_honor_TAGS_MAX_BY_OBJ(factories, max, tags, expected, settings):
settings.TAGS_MAX_BY_OBJ = max
obj = factories["music.Artist"]()
models.set_tags(obj, *tags)
assert sorted(obj.tagged_items.values_list("tag__name", flat=True)) == expected
@pytest.mark.parametrize("factory_name", ["music.Track", "music.Album", "music.Artist"]) @pytest.mark.parametrize("factory_name", ["music.Track", "music.Album", "music.Artist"])
def test_models_that_support_tags(factories, factory_name): def test_models_that_support_tags(factories, factory_name):
tags = ["tag1", "tag2"] tags = ["tag1", "tag2"]

View File

@ -29,3 +29,18 @@ def test_tag_name_field_validation(name):
field = serializers.TagNameField() field = serializers.TagNameField()
with pytest.raises(serializers.serializers.ValidationError): with pytest.raises(serializers.serializers.ValidationError):
field.to_internal_value(name) field.to_internal_value(name)
@pytest.mark.parametrize(
"max, tags, expected",
[
(5, ["hello", "world"], ["hello", "world"]),
# we truncate extra tags
(1, ["hello", "world"], ["hello"]),
(2, ["hello", "world", "foo"], ["hello", "world"]),
],
)
def test_tags_list_field_honor_TAGS_MAX_BY_OBJ(max, tags, expected, settings):
settings.TAGS_MAX_BY_OBJ = max
field = serializers.TagsListField()
assert field.to_internal_value(tags) == expected