See #432: enforce a maximum number of tags per entity
This commit is contained in:
parent
bd271c8ead
commit
1b34ae2335
|
@ -716,3 +716,6 @@ ACTOR_KEY_ROTATION_DELAY = env.int("ACTOR_KEY_ROTATION_DELAY", default=3600 * 48
|
||||||
SUBSONIC_DEFAULT_TRANSCODING_FORMAT = (
|
SUBSONIC_DEFAULT_TRANSCODING_FORMAT = (
|
||||||
env("SUBSONIC_DEFAULT_TRANSCODING_FORMAT", default="mp3") or None
|
env("SUBSONIC_DEFAULT_TRANSCODING_FORMAT", default="mp3") or None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# extra tags will be ignored
|
||||||
|
TAGS_MAX_BY_OBJ = env.int("TAGS_MAX_BY_OBJ", default=30)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.contrib.postgres.fields import CICharField
|
from django.contrib.postgres.fields import CICharField
|
||||||
|
@ -65,6 +66,9 @@ def add_tags(obj, *tags):
|
||||||
|
|
||||||
@transaction.atomic
|
@transaction.atomic
|
||||||
def set_tags(obj, *tags):
|
def set_tags(obj, *tags):
|
||||||
|
# we ignore any extra tags if the length of the list is higher
|
||||||
|
# than our accepted size
|
||||||
|
tags = tags[: settings.TAGS_MAX_BY_OBJ]
|
||||||
tags = set(tags)
|
tags = set(tags)
|
||||||
existing = set(
|
existing = set(
|
||||||
TaggedItem.objects.for_content_object(obj).values_list("tag__name", flat=True)
|
TaggedItem.objects.for_content_object(obj).values_list("tag__name", flat=True)
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
from . import models
|
from . import models
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,3 +17,18 @@ class TagNameField(serializers.CharField):
|
||||||
if not models.TAG_REGEX.match(value):
|
if not models.TAG_REGEX.match(value):
|
||||||
raise serializers.ValidationError('Invalid tag "{}"'.format(value))
|
raise serializers.ValidationError('Invalid tag "{}"'.format(value))
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class TagsListField(serializers.ListField):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
kwargs.setdefault("min_length", 0)
|
||||||
|
kwargs.setdefault("child", TagNameField())
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def to_internal_value(self, value):
|
||||||
|
value = super().to_internal_value(value)
|
||||||
|
if not value:
|
||||||
|
return value
|
||||||
|
# we ignore any extra tags if the length of the list is higher
|
||||||
|
# than our accepted size
|
||||||
|
return value[: settings.TAGS_MAX_BY_OBJ]
|
||||||
|
|
|
@ -53,6 +53,24 @@ def test_set_tags(factories, existing, given, expected):
|
||||||
assert match.content_object == obj
|
assert match.content_object == obj
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"max, tags, expected",
|
||||||
|
[
|
||||||
|
(5, ["hello", "world"], ["hello", "world"]),
|
||||||
|
# we truncate extra tags
|
||||||
|
(1, ["hello", "world"], ["hello"]),
|
||||||
|
(2, ["hello", "world", "foo"], ["hello", "world"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_set_tags_honor_TAGS_MAX_BY_OBJ(factories, max, tags, expected, settings):
|
||||||
|
settings.TAGS_MAX_BY_OBJ = max
|
||||||
|
obj = factories["music.Artist"]()
|
||||||
|
|
||||||
|
models.set_tags(obj, *tags)
|
||||||
|
|
||||||
|
assert sorted(obj.tagged_items.values_list("tag__name", flat=True)) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("factory_name", ["music.Track", "music.Album", "music.Artist"])
|
@pytest.mark.parametrize("factory_name", ["music.Track", "music.Album", "music.Artist"])
|
||||||
def test_models_that_support_tags(factories, factory_name):
|
def test_models_that_support_tags(factories, factory_name):
|
||||||
tags = ["tag1", "tag2"]
|
tags = ["tag1", "tag2"]
|
||||||
|
|
|
@ -29,3 +29,18 @@ def test_tag_name_field_validation(name):
|
||||||
field = serializers.TagNameField()
|
field = serializers.TagNameField()
|
||||||
with pytest.raises(serializers.serializers.ValidationError):
|
with pytest.raises(serializers.serializers.ValidationError):
|
||||||
field.to_internal_value(name)
|
field.to_internal_value(name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"max, tags, expected",
|
||||||
|
[
|
||||||
|
(5, ["hello", "world"], ["hello", "world"]),
|
||||||
|
# we truncate extra tags
|
||||||
|
(1, ["hello", "world"], ["hello"]),
|
||||||
|
(2, ["hello", "world", "foo"], ["hello", "world"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_tags_list_field_honor_TAGS_MAX_BY_OBJ(max, tags, expected, settings):
|
||||||
|
settings.TAGS_MAX_BY_OBJ = max
|
||||||
|
field = serializers.TagsListField()
|
||||||
|
assert field.to_internal_value(tags) == expected
|
||||||
|
|
Loading…
Reference in New Issue