Merge branch '432-tags-mutation' into 'develop'

See #432: API for tags

See merge request funkwhale/funkwhale!830
This commit is contained in:
Eliot Berriot 2019-07-18 09:53:42 +02:00
commit 03a470deaf
17 changed files with 261 additions and 40 deletions

View File

@ -716,3 +716,6 @@ ACTOR_KEY_ROTATION_DELAY = env.int("ACTOR_KEY_ROTATION_DELAY", default=3600 * 48
SUBSONIC_DEFAULT_TRANSCODING_FORMAT = ( SUBSONIC_DEFAULT_TRANSCODING_FORMAT = (
env("SUBSONIC_DEFAULT_TRANSCODING_FORMAT", default="mp3") or None env("SUBSONIC_DEFAULT_TRANSCODING_FORMAT", default="mp3") or None
) )
# extra tags will be ignored
TAGS_MAX_BY_OBJ = env.int("TAGS_MAX_BY_OBJ", default=30)

View File

@ -86,6 +86,7 @@ class MutationSerializer(serializers.Serializer):
class UpdateMutationSerializer(serializers.ModelSerializer, MutationSerializer): class UpdateMutationSerializer(serializers.ModelSerializer, MutationSerializer):
serialized_relations = {} serialized_relations = {}
previous_state_handlers = {}
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# we force partial mode, because update mutations are partial # we force partial mode, because update mutations are partial
@ -139,16 +140,20 @@ class UpdateMutationSerializer(serializers.ModelSerializer, MutationSerializer):
return get_update_previous_state( return get_update_previous_state(
obj, obj,
*list(validated_data.keys()), *list(validated_data.keys()),
serialized_relations=self.serialized_relations serialized_relations=self.serialized_relations,
handlers=self.previous_state_handlers,
) )
def get_update_previous_state(obj, *fields, serialized_relations={}): def get_update_previous_state(obj, *fields, serialized_relations={}, handlers={}):
if not fields: if not fields:
raise ValueError("You need to provide at least one field") raise ValueError("You need to provide at least one field")
state = {} state = {}
for field in fields: for field in fields:
if field in handlers:
state[field] = handlers[field](obj)
continue
value = getattr(obj, field) value = getattr(obj, field)
if isinstance(value, models.Model): if isinstance(value, models.Model):
# we store the related object id and repr for better UX # we store the related object id and repr for better UX

View File

@ -9,9 +9,20 @@ from . import models
from . import utils from . import utils
def filter_tags(queryset, name, value):
non_empty_tags = [v.lower() for v in value if v]
for tag in non_empty_tags:
queryset = queryset.filter(tagged_items__tag__name=tag).distinct()
return queryset
TAG_FILTER = common_filters.MultipleQueryFilter(method=filter_tags)
class ArtistFilter(moderation_filters.HiddenContentFilterSet): class ArtistFilter(moderation_filters.HiddenContentFilterSet):
q = fields.SearchFilter(search_fields=["name"]) q = fields.SearchFilter(search_fields=["name"])
playable = filters.BooleanFilter(field_name="_", method="filter_playable") playable = filters.BooleanFilter(field_name="_", method="filter_playable")
tag = TAG_FILTER
class Meta: class Meta:
model = models.Artist model = models.Artist
@ -29,7 +40,7 @@ class ArtistFilter(moderation_filters.HiddenContentFilterSet):
class TrackFilter(moderation_filters.HiddenContentFilterSet): class TrackFilter(moderation_filters.HiddenContentFilterSet):
q = fields.SearchFilter(search_fields=["title", "album__title", "artist__name"]) q = fields.SearchFilter(search_fields=["title", "album__title", "artist__name"])
playable = filters.BooleanFilter(field_name="_", method="filter_playable") playable = filters.BooleanFilter(field_name="_", method="filter_playable")
tag = common_filters.MultipleQueryFilter(method="filter_tags") tag = TAG_FILTER
id = common_filters.MultipleQueryFilter(coerce=int) id = common_filters.MultipleQueryFilter(coerce=int)
class Meta: class Meta:
@ -48,12 +59,6 @@ class TrackFilter(moderation_filters.HiddenContentFilterSet):
actor = utils.get_actor_from_request(self.request) actor = utils.get_actor_from_request(self.request)
return queryset.playable_by(actor, value) return queryset.playable_by(actor, value)
def filter_tags(self, queryset, name, value):
non_empty_tags = [v.lower() for v in value if v]
for tag in non_empty_tags:
queryset = queryset.filter(tagged_items__tag__name=tag).distinct()
return queryset
class UploadFilter(filters.FilterSet): class UploadFilter(filters.FilterSet):
library = filters.CharFilter("library__uuid") library = filters.CharFilter("library__uuid")
@ -101,6 +106,7 @@ class UploadFilter(filters.FilterSet):
class AlbumFilter(moderation_filters.HiddenContentFilterSet): class AlbumFilter(moderation_filters.HiddenContentFilterSet):
playable = filters.BooleanFilter(field_name="_", method="filter_playable") playable = filters.BooleanFilter(field_name="_", method="filter_playable")
q = fields.SearchFilter(search_fields=["title", "artist__name"]) q = fields.SearchFilter(search_fields=["title", "artist__name"])
tag = TAG_FILTER
class Meta: class Meta:
model = models.Album model = models.Album

View File

@ -2,7 +2,6 @@ import base64
import datetime import datetime
import logging import logging
import pendulum import pendulum
import re
import mutagen._util import mutagen._util
import mutagen.oggtheora import mutagen.oggtheora
@ -12,6 +11,8 @@ import mutagen.flac
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import Mapping from rest_framework.compat import Mapping
from funkwhale_api.tags import models as tags_models
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
NODEFAULT = object() NODEFAULT = object()
# default title used when imported tracks miss the `Album` tag, see #122 # default title used when imported tracks miss the `Album` tag, see #122
@ -491,9 +492,6 @@ class PermissiveDateField(serializers.CharField):
return None return None
TAG_REGEX = re.compile(r"^((\w+)([\d_]*))$")
def extract_tags_from_genre(string): def extract_tags_from_genre(string):
tags = [] tags = []
delimiter = "@@@@@" delimiter = "@@@@@"
@ -511,7 +509,7 @@ def extract_tags_from_genre(string):
if not tag: if not tag:
continue continue
final_tag = "" final_tag = ""
if not TAG_REGEX.match(tag.replace(" ", "")): if not tags_models.TAG_REGEX.match(tag.replace(" ", "")):
# the string contains some non words chars ($, €, etc.), right now # the string contains some non words chars ($, €, etc.), right now
# we simply skip such tags # we simply skip such tags
continue continue

View File

@ -1,5 +1,7 @@
from funkwhale_api.common import mutations from funkwhale_api.common import mutations
from funkwhale_api.federation import routes from funkwhale_api.federation import routes
from funkwhale_api.tags import models as tags_models
from funkwhale_api.tags import serializers as tags_serializers
from . import models from . import models
@ -12,17 +14,32 @@ def can_approve(obj, actor):
return obj.is_local and actor.user and actor.user.get_permissions()["library"] return obj.is_local and actor.user and actor.user.get_permissions()["library"]
class TagMutation(mutations.UpdateMutationSerializer):
tags = tags_serializers.TagsListField()
previous_state_handlers = {
"tags": lambda obj: list(
sorted(obj.tagged_items.values_list("tag__name", flat=True))
)
}
def update(self, instance, validated_data):
tags = validated_data.pop("tags", [])
r = super().update(instance, validated_data)
tags_models.set_tags(instance, *tags)
return r
@mutations.registry.connect( @mutations.registry.connect(
"update", "update",
models.Track, models.Track,
perm_checkers={"suggest": can_suggest, "approve": can_approve}, perm_checkers={"suggest": can_suggest, "approve": can_approve},
) )
class TrackMutationSerializer(mutations.UpdateMutationSerializer): class TrackMutationSerializer(TagMutation):
serialized_relations = {"license": "code"} serialized_relations = {"license": "code"}
class Meta: class Meta:
model = models.Track model = models.Track
fields = ["license", "title", "position", "copyright"] fields = ["license", "title", "position", "copyright", "tags"]
def post_apply(self, obj, validated_data): def post_apply(self, obj, validated_data):
routes.outbox.dispatch( routes.outbox.dispatch(
@ -35,10 +52,10 @@ class TrackMutationSerializer(mutations.UpdateMutationSerializer):
models.Artist, models.Artist,
perm_checkers={"suggest": can_suggest, "approve": can_approve}, perm_checkers={"suggest": can_suggest, "approve": can_approve},
) )
class ArtistMutationSerializer(mutations.UpdateMutationSerializer): class ArtistMutationSerializer(TagMutation):
class Meta: class Meta:
model = models.Artist model = models.Artist
fields = ["name"] fields = ["name", "tags"]
def post_apply(self, obj, validated_data): def post_apply(self, obj, validated_data):
routes.outbox.dispatch( routes.outbox.dispatch(
@ -51,10 +68,10 @@ class ArtistMutationSerializer(mutations.UpdateMutationSerializer):
models.Album, models.Album,
perm_checkers={"suggest": can_suggest, "approve": can_approve}, perm_checkers={"suggest": can_suggest, "approve": can_approve},
) )
class AlbumMutationSerializer(mutations.UpdateMutationSerializer): class AlbumMutationSerializer(TagMutation):
class Meta: class Meta:
model = models.Album model = models.Album
fields = ["title", "release_date"] fields = ["title", "release_date", "tags"]
def post_apply(self, obj, validated_data): def post_apply(self, obj, validated_data):
routes.outbox.dispatch( routes.outbox.dispatch(

View File

@ -67,10 +67,24 @@ class ArtistAlbumSerializer(serializers.ModelSerializer):
class ArtistWithAlbumsSerializer(serializers.ModelSerializer): class ArtistWithAlbumsSerializer(serializers.ModelSerializer):
albums = ArtistAlbumSerializer(many=True, read_only=True) albums = ArtistAlbumSerializer(many=True, read_only=True)
tags = serializers.SerializerMethodField()
class Meta: class Meta:
model = models.Artist model = models.Artist
fields = ("id", "fid", "mbid", "name", "creation_date", "albums", "is_local") fields = (
"id",
"fid",
"mbid",
"name",
"creation_date",
"albums",
"is_local",
"tags",
)
def get_tags(self, obj):
tagged_items = getattr(obj, "_prefetched_tagged_items", [])
return [ti.tag.name for ti in tagged_items]
class ArtistSimpleSerializer(serializers.ModelSerializer): class ArtistSimpleSerializer(serializers.ModelSerializer):
@ -124,6 +138,7 @@ class AlbumSerializer(serializers.ModelSerializer):
artist = ArtistSimpleSerializer(read_only=True) artist = ArtistSimpleSerializer(read_only=True)
cover = cover_field cover = cover_field
is_playable = serializers.SerializerMethodField() is_playable = serializers.SerializerMethodField()
tags = serializers.SerializerMethodField()
class Meta: class Meta:
model = models.Album model = models.Album
@ -139,6 +154,7 @@ class AlbumSerializer(serializers.ModelSerializer):
"creation_date", "creation_date",
"is_playable", "is_playable",
"is_local", "is_local",
"tags",
) )
def get_tracks(self, o): def get_tracks(self, o):
@ -153,6 +169,10 @@ class AlbumSerializer(serializers.ModelSerializer):
except AttributeError: except AttributeError:
return None return None
def get_tags(self, obj):
tagged_items = getattr(obj, "_prefetched_tagged_items", [])
return [ti.tag.name for ti in tagged_items]
class TrackAlbumSerializer(serializers.ModelSerializer): class TrackAlbumSerializer(serializers.ModelSerializer):
artist = ArtistSimpleSerializer(read_only=True) artist = ArtistSimpleSerializer(read_only=True)
@ -192,6 +212,7 @@ class TrackSerializer(serializers.ModelSerializer):
album = TrackAlbumSerializer(read_only=True) album = TrackAlbumSerializer(read_only=True)
uploads = serializers.SerializerMethodField() uploads = serializers.SerializerMethodField()
listen_url = serializers.SerializerMethodField() listen_url = serializers.SerializerMethodField()
tags = serializers.SerializerMethodField()
class Meta: class Meta:
model = models.Track model = models.Track
@ -210,6 +231,7 @@ class TrackSerializer(serializers.ModelSerializer):
"copyright", "copyright",
"license", "license",
"is_local", "is_local",
"tags",
) )
def get_listen_url(self, obj): def get_listen_url(self, obj):
@ -219,6 +241,10 @@ class TrackSerializer(serializers.ModelSerializer):
uploads = getattr(obj, "playable_uploads", []) uploads = getattr(obj, "playable_uploads", [])
return TrackUploadSerializer(uploads, many=True).data return TrackUploadSerializer(uploads, many=True).data
def get_tags(self, obj):
tagged_items = getattr(obj, "_prefetched_tagged_items", [])
return [ti.tag.name for ti in tagged_items]
@common_serializers.track_fields_for_update("name", "description", "privacy_level") @common_serializers.track_fields_for_update("name", "description", "privacy_level")
class LibraryForOwnerSerializer(serializers.ModelSerializer): class LibraryForOwnerSerializer(serializers.ModelSerializer):

View File

@ -23,13 +23,19 @@ from funkwhale_api.federation import actors
from funkwhale_api.federation import api_serializers as federation_api_serializers from funkwhale_api.federation import api_serializers as federation_api_serializers
from funkwhale_api.federation import decorators as federation_decorators from funkwhale_api.federation import decorators as federation_decorators
from funkwhale_api.federation import routes from funkwhale_api.federation import routes
from funkwhale_api.tags.models import Tag from funkwhale_api.tags.models import Tag, TaggedItem
from funkwhale_api.users.oauth import permissions as oauth_permissions from funkwhale_api.users.oauth import permissions as oauth_permissions
from . import filters, licenses, models, serializers, tasks, utils from . import filters, licenses, models, serializers, tasks, utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TAG_PREFETCH = Prefetch(
"tagged_items",
queryset=TaggedItem.objects.all().select_related().order_by("tag__name"),
to_attr="_prefetched_tagged_items",
)
def get_libraries(filter_uploads): def get_libraries(filter_uploads):
def libraries(self, request, *args, **kwargs): def libraries(self, request, *args, **kwargs):
@ -71,7 +77,9 @@ class ArtistViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelV
albums = albums.annotate_playable_by_actor( albums = albums.annotate_playable_by_actor(
utils.get_actor_from_request(self.request) utils.get_actor_from_request(self.request)
) )
return queryset.prefetch_related(Prefetch("albums", queryset=albums)) return queryset.prefetch_related(
Prefetch("albums", queryset=albums), TAG_PREFETCH
)
libraries = action(methods=["get"], detail=True)( libraries = action(methods=["get"], detail=True)(
get_libraries( get_libraries(
@ -103,7 +111,9 @@ class AlbumViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelVi
.with_playable_uploads(utils.get_actor_from_request(self.request)) .with_playable_uploads(utils.get_actor_from_request(self.request))
.order_for_album() .order_for_album()
) )
qs = queryset.prefetch_related(Prefetch("tracks", queryset=tracks)) qs = queryset.prefetch_related(
Prefetch("tracks", queryset=tracks), TAG_PREFETCH
)
return qs return qs
libraries = action(methods=["get"], detail=True)( libraries = action(methods=["get"], detail=True)(
@ -206,7 +216,7 @@ class TrackViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelVi
queryset = queryset.with_playable_uploads( queryset = queryset.with_playable_uploads(
utils.get_actor_from_request(self.request) utils.get_actor_from_request(self.request)
) )
return queryset return queryset.prefetch_related(TAG_PREFETCH)
libraries = action(methods=["get"], detail=True)( libraries = action(methods=["get"], detail=True)(
get_libraries(filter_uploads=lambda o, uploads: uploads.filter(track=o)) get_libraries(filter_uploads=lambda o, uploads: uploads.filter(track=o))

View File

@ -1,3 +1,6 @@
import re
from django.conf import settings
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.fields import CICharField from django.contrib.postgres.fields import CICharField
@ -8,6 +11,9 @@ from django.utils import timezone
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
TAG_REGEX = re.compile(r"^((\w+)([\d_]*))$")
class Tag(models.Model): class Tag(models.Model):
name = CICharField(max_length=100, unique=True) name = CICharField(max_length=100, unique=True)
creation_date = models.DateTimeField(default=timezone.now) creation_date = models.DateTimeField(default=timezone.now)
@ -60,6 +66,9 @@ def add_tags(obj, *tags):
@transaction.atomic @transaction.atomic
def set_tags(obj, *tags): def set_tags(obj, *tags):
# we ignore any extra tags if the length of the list is higher
# than our accepted size
tags = tags[: settings.TAGS_MAX_BY_OBJ]
tags = set(tags) tags = set(tags)
existing = set( existing = set(
TaggedItem.objects.for_content_object(obj).values_list("tag__name", flat=True) TaggedItem.objects.for_content_object(obj).values_list("tag__name", flat=True)

View File

@ -1,5 +1,7 @@
from rest_framework import serializers from rest_framework import serializers
from django.conf import settings
from . import models from . import models
@ -7,3 +9,26 @@ class TagSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = models.Tag model = models.Tag
fields = ["name", "creation_date"] fields = ["name", "creation_date"]
class TagNameField(serializers.CharField):
def to_internal_value(self, value):
value = super().to_internal_value(value)
if not models.TAG_REGEX.match(value):
raise serializers.ValidationError('Invalid tag "{}"'.format(value))
return value
class TagsListField(serializers.ListField):
def __init__(self, *args, **kwargs):
kwargs.setdefault("min_length", 0)
kwargs.setdefault("child", TagNameField())
super().__init__(*args, **kwargs)
def to_internal_value(self, value):
value = super().to_internal_value(value)
if not value:
return value
# we ignore any extra tags if the length of the list is higher
# than our accepted size
return value[: settings.TAGS_MAX_BY_OBJ]

View File

@ -51,7 +51,7 @@ def test_apply_update_mutation(factories, mutations_registry, mocker):
) )
assert previous_state == get_update_previous_state.return_value assert previous_state == get_update_previous_state.return_value
get_update_previous_state.assert_called_once_with( get_update_previous_state.assert_called_once_with(
user, "username", serialized_relations={} user, "username", serialized_relations={}, handlers={}
) )
user.refresh_from_db() user.refresh_from_db()

View File

@ -1,3 +1,5 @@
import pytest
from funkwhale_api.music import filters from funkwhale_api.music import filters
from funkwhale_api.music import models from funkwhale_api.music import models
@ -54,28 +56,54 @@ def test_artist_filter_track_album_artist(factories, mocker, queryset_equal_list
assert filterset.qs == [hidden_track] assert filterset.qs == [hidden_track]
@pytest.mark.parametrize(
"factory_name, filterset_class",
[
("music.Track", filters.TrackFilter),
("music.Artist", filters.TrackFilter),
("music.Album", filters.TrackFilter),
],
)
def test_track_filter_tag_single( def test_track_filter_tag_single(
factories, mocker, queryset_equal_list, anonymous_user factory_name,
filterset_class,
factories,
mocker,
queryset_equal_list,
anonymous_user,
): ):
factories["music.Track"]() factories[factory_name]()
# tag name partially match the query, so this shouldn't match # tag name partially match the query, so this shouldn't match
factories["music.Track"](set_tags=["TestTag1"]) factories[factory_name](set_tags=["TestTag1"])
tagged = factories["music.Track"](set_tags=["TestTag"]) tagged = factories[factory_name](set_tags=["TestTag"])
qs = models.Track.objects.all() qs = tagged.__class__.objects.all()
filterset = filters.TrackFilter( filterset = filterset_class(
{"tag": "testTaG"}, request=mocker.Mock(user=anonymous_user), queryset=qs {"tag": "testTaG"}, request=mocker.Mock(user=anonymous_user), queryset=qs
) )
assert filterset.qs == [tagged] assert filterset.qs == [tagged]
@pytest.mark.parametrize(
"factory_name, filterset_class",
[
("music.Track", filters.TrackFilter),
("music.Artist", filters.ArtistFilter),
("music.Album", filters.AlbumFilter),
],
)
def test_track_filter_tag_multiple( def test_track_filter_tag_multiple(
factories, mocker, queryset_equal_list, anonymous_user factory_name,
filterset_class,
factories,
mocker,
queryset_equal_list,
anonymous_user,
): ):
factories["music.Track"](set_tags=["TestTag1"]) factories[factory_name](set_tags=["TestTag1"])
tagged = factories["music.Track"](set_tags=["TestTag1", "TestTag2"]) tagged = factories[factory_name](set_tags=["TestTag1", "TestTag2"])
qs = models.Track.objects.all() qs = tagged.__class__.objects.all()
filterset = filters.TrackFilter( filterset = filterset_class(
{"tag": ["testTaG1", "TestTag2"]}, {"tag": ["testTaG1", "TestTag2"]},
request=mocker.Mock(user=anonymous_user), request=mocker.Mock(user=anonymous_user),
queryset=qs, queryset=qs,

View File

@ -2,6 +2,7 @@ import datetime
import pytest import pytest
from funkwhale_api.music import licenses from funkwhale_api.music import licenses
from funkwhale_api.tags import models as tags_models
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -117,3 +118,25 @@ def test_track_mutation_apply_outbox(factories, mocker):
dispatch.assert_called_once_with( dispatch.assert_called_once_with(
{"type": "Update", "object": {"type": "Track"}}, context={"track": track} {"type": "Update", "object": {"type": "Track"}}, context={"track": track}
) )
@pytest.mark.parametrize("factory_name", ["music.Artist", "music.Album", "music.Track"])
def test_mutation_set_tags(factory_name, factories, now, mocker):
tags = ["tag1", "tag2"]
dispatch = mocker.patch("funkwhale_api.federation.routes.outbox.dispatch")
set_tags = mocker.spy(tags_models, "set_tags")
obj = factories[factory_name]()
assert obj.tagged_items.all().count() == 0
mutation = factories["common.Mutation"](
type="update", target=obj, payload={"tags": tags}
)
mutation.apply()
obj.refresh_from_db()
assert sorted(obj.tagged_items.all().values_list("tag__name", flat=True)) == tags
set_tags.assert_called_once_with(obj, *tags)
obj_type = factory_name.lstrip("music.")
dispatch.assert_called_once_with(
{"type": "Update", "object": {"type": obj_type}},
context={obj_type.lower(): obj},
)

View File

@ -69,6 +69,7 @@ def test_artist_with_albums_serializer(factories, to_api_date):
"is_local": artist.is_local, "is_local": artist.is_local,
"creation_date": to_api_date(artist.creation_date), "creation_date": to_api_date(artist.creation_date),
"albums": [serializers.ArtistAlbumSerializer(album).data], "albums": [serializers.ArtistAlbumSerializer(album).data],
"tags": [],
} }
serializer = serializers.ArtistWithAlbumsSerializer(artist) serializer = serializers.ArtistWithAlbumsSerializer(artist)
assert serializer.data == expected assert serializer.data == expected
@ -175,6 +176,7 @@ def test_album_serializer(factories, to_api_date):
"release_date": to_api_date(album.release_date), "release_date": to_api_date(album.release_date),
"tracks": serializers.AlbumTrackSerializer([track2, track1], many=True).data, "tracks": serializers.AlbumTrackSerializer([track2, track1], many=True).data,
"is_local": album.is_local, "is_local": album.is_local,
"tags": [],
} }
serializer = serializers.AlbumSerializer(album) serializer = serializers.AlbumSerializer(album)
@ -202,6 +204,7 @@ def test_track_serializer(factories, to_api_date):
"license": upload.track.license.code, "license": upload.track.license.code,
"copyright": upload.track.copyright, "copyright": upload.track.copyright,
"is_local": upload.track.is_local, "is_local": upload.track.is_local,
"tags": [],
} }
serializer = serializers.TrackSerializer(track) serializer = serializers.TrackSerializer(track)
assert serializer.data == expected assert serializer.data == expected

View File

@ -16,8 +16,11 @@ DATA_DIR = os.path.dirname(os.path.abspath(__file__))
def test_artist_list_serializer(api_request, factories, logged_in_api_client): def test_artist_list_serializer(api_request, factories, logged_in_api_client):
tags = ["tag1", "tag2"]
track = factories["music.Upload"]( track = factories["music.Upload"](
library__privacy_level="everyone", import_status="finished" library__privacy_level="everyone",
import_status="finished",
track__album__artist__set_tags=tags,
).track ).track
artist = track.artist artist = track.artist
request = api_request.get("/") request = api_request.get("/")
@ -27,8 +30,10 @@ def test_artist_list_serializer(api_request, factories, logged_in_api_client):
) )
expected = {"count": 1, "next": None, "previous": None, "results": serializer.data} expected = {"count": 1, "next": None, "previous": None, "results": serializer.data}
for artist in serializer.data: for artist in serializer.data:
artist["tags"] = tags
for album in artist["albums"]: for album in artist["albums"]:
album["is_playable"] = True album["is_playable"] = True
url = reverse("api:v1:artists-list") url = reverse("api:v1:artists-list")
response = logged_in_api_client.get(url) response = logged_in_api_client.get(url)
@ -37,8 +42,11 @@ def test_artist_list_serializer(api_request, factories, logged_in_api_client):
def test_album_list_serializer(api_request, factories, logged_in_api_client): def test_album_list_serializer(api_request, factories, logged_in_api_client):
tags = ["tag1", "tag2"]
track = factories["music.Upload"]( track = factories["music.Upload"](
library__privacy_level="everyone", import_status="finished" library__privacy_level="everyone",
import_status="finished",
track__album__set_tags=tags,
).track ).track
album = track.album album = track.album
request = api_request.get("/") request = api_request.get("/")
@ -47,6 +55,8 @@ def test_album_list_serializer(api_request, factories, logged_in_api_client):
qs, many=True, context={"request": request} qs, many=True, context={"request": request}
) )
expected = {"count": 1, "next": None, "previous": None, "results": serializer.data} expected = {"count": 1, "next": None, "previous": None, "results": serializer.data}
for album in serializer.data:
album["tags"] = tags
url = reverse("api:v1:albums-list") url = reverse("api:v1:albums-list")
response = logged_in_api_client.get(url) response = logged_in_api_client.get(url)
@ -55,8 +65,11 @@ def test_album_list_serializer(api_request, factories, logged_in_api_client):
def test_track_list_serializer(api_request, factories, logged_in_api_client): def test_track_list_serializer(api_request, factories, logged_in_api_client):
tags = ["tag1", "tag2"]
track = factories["music.Upload"]( track = factories["music.Upload"](
library__privacy_level="everyone", import_status="finished" library__privacy_level="everyone",
import_status="finished",
track__set_tags=tags,
).track ).track
request = api_request.get("/") request = api_request.get("/")
qs = track.__class__.objects.with_playable_uploads(None) qs = track.__class__.objects.with_playable_uploads(None)
@ -64,6 +77,8 @@ def test_track_list_serializer(api_request, factories, logged_in_api_client):
qs, many=True, context={"request": request} qs, many=True, context={"request": request}
) )
expected = {"count": 1, "next": None, "previous": None, "results": serializer.data} expected = {"count": 1, "next": None, "previous": None, "results": serializer.data}
for track in serializer.data:
track["tags"] = tags
url = reverse("api:v1:tracks-list") url = reverse("api:v1:tracks-list")
response = logged_in_api_client.get(url) response = logged_in_api_client.get(url)

View File

@ -53,6 +53,24 @@ def test_set_tags(factories, existing, given, expected):
assert match.content_object == obj assert match.content_object == obj
@pytest.mark.parametrize(
"max, tags, expected",
[
(5, ["hello", "world"], ["hello", "world"]),
# we truncate extra tags
(1, ["hello", "world"], ["hello"]),
(2, ["hello", "world", "foo"], ["hello", "world"]),
],
)
def test_set_tags_honor_TAGS_MAX_BY_OBJ(factories, max, tags, expected, settings):
settings.TAGS_MAX_BY_OBJ = max
obj = factories["music.Artist"]()
models.set_tags(obj, *tags)
assert sorted(obj.tagged_items.values_list("tag__name", flat=True)) == expected
@pytest.mark.parametrize("factory_name", ["music.Track", "music.Album", "music.Artist"]) @pytest.mark.parametrize("factory_name", ["music.Track", "music.Album", "music.Artist"])
def test_models_that_support_tags(factories, factory_name): def test_models_that_support_tags(factories, factory_name):
tags = ["tag1", "tag2"] tags = ["tag1", "tag2"]

View File

@ -1,3 +1,5 @@
import pytest
from funkwhale_api.tags import serializers from funkwhale_api.tags import serializers
@ -12,3 +14,33 @@ def test_tag_serializer(factories):
} }
assert serializer.data == expected assert serializer.data == expected
@pytest.mark.parametrize(
"name",
[
"",
"invalid because spaces",
"invalid-because-dashes",
"invalidbecausenonbreakingspaces",
],
)
def test_tag_name_field_validation(name):
field = serializers.TagNameField()
with pytest.raises(serializers.serializers.ValidationError):
field.to_internal_value(name)
@pytest.mark.parametrize(
"max, tags, expected",
[
(5, ["hello", "world"], ["hello", "world"]),
# we truncate extra tags
(1, ["hello", "world"], ["hello"]),
(2, ["hello", "world", "foo"], ["hello", "world"]),
],
)
def test_tags_list_field_honor_TAGS_MAX_BY_OBJ(max, tags, expected, settings):
settings.TAGS_MAX_BY_OBJ = max
field = serializers.TagsListField()
assert field.to_internal_value(tags) == expected

View File

@ -75,6 +75,9 @@ http {
location /front-server/ { location /front-server/ {
proxy_pass http://funkwhale-front/; proxy_pass http://funkwhale-front/;
} }
location /sockjs-node/ {
proxy_pass http://funkwhale-front/sockjs-node/;
}
location / { location / {
include /etc/nginx/funkwhale_proxy.conf; include /etc/nginx/funkwhale_proxy.conf;