diff --git a/api/funkwhale_api/music/filters.py b/api/funkwhale_api/music/filters.py index d8e68d731..44763b966 100644 --- a/api/funkwhale_api/music/filters.py +++ b/api/funkwhale_api/music/filters.py @@ -9,9 +9,20 @@ from . import models from . import utils +def filter_tags(queryset, name, value): + non_empty_tags = [v.lower() for v in value if v] + for tag in non_empty_tags: + queryset = queryset.filter(tagged_items__tag__name=tag).distinct() + return queryset + + +TAG_FILTER = common_filters.MultipleQueryFilter(method=filter_tags) + + class ArtistFilter(moderation_filters.HiddenContentFilterSet): q = fields.SearchFilter(search_fields=["name"]) playable = filters.BooleanFilter(field_name="_", method="filter_playable") + tag = TAG_FILTER class Meta: model = models.Artist @@ -29,7 +40,7 @@ class ArtistFilter(moderation_filters.HiddenContentFilterSet): class TrackFilter(moderation_filters.HiddenContentFilterSet): q = fields.SearchFilter(search_fields=["title", "album__title", "artist__name"]) playable = filters.BooleanFilter(field_name="_", method="filter_playable") - tag = common_filters.MultipleQueryFilter(method="filter_tags") + tag = TAG_FILTER id = common_filters.MultipleQueryFilter(coerce=int) class Meta: @@ -48,12 +59,6 @@ class TrackFilter(moderation_filters.HiddenContentFilterSet): actor = utils.get_actor_from_request(self.request) return queryset.playable_by(actor, value) - def filter_tags(self, queryset, name, value): - non_empty_tags = [v.lower() for v in value if v] - for tag in non_empty_tags: - queryset = queryset.filter(tagged_items__tag__name=tag).distinct() - return queryset - class UploadFilter(filters.FilterSet): library = filters.CharFilter("library__uuid") @@ -101,6 +106,7 @@ class UploadFilter(filters.FilterSet): class AlbumFilter(moderation_filters.HiddenContentFilterSet): playable = filters.BooleanFilter(field_name="_", method="filter_playable") q = fields.SearchFilter(search_fields=["title", "artist__name"]) + tag = TAG_FILTER class Meta: model = models.Album diff --git a/api/tests/music/test_filters.py b/api/tests/music/test_filters.py index 0076f3c93..f3ff13e77 100644 --- a/api/tests/music/test_filters.py +++ b/api/tests/music/test_filters.py @@ -1,3 +1,5 @@ +import pytest + from funkwhale_api.music import filters from funkwhale_api.music import models @@ -54,28 +56,54 @@ def test_artist_filter_track_album_artist(factories, mocker, queryset_equal_list assert filterset.qs == [hidden_track] +@pytest.mark.parametrize( + "factory_name, filterset_class", + [ + ("music.Track", filters.TrackFilter), + ("music.Artist", filters.TrackFilter), + ("music.Album", filters.TrackFilter), + ], +) def test_track_filter_tag_single( - factories, mocker, queryset_equal_list, anonymous_user + factory_name, + filterset_class, + factories, + mocker, + queryset_equal_list, + anonymous_user, ): - factories["music.Track"]() + factories[factory_name]() # tag name partially match the query, so this shouldn't match - factories["music.Track"](set_tags=["TestTag1"]) - tagged = factories["music.Track"](set_tags=["TestTag"]) - qs = models.Track.objects.all() - filterset = filters.TrackFilter( + factories[factory_name](set_tags=["TestTag1"]) + tagged = factories[factory_name](set_tags=["TestTag"]) + qs = tagged.__class__.objects.all() + filterset = filterset_class( {"tag": "testTaG"}, request=mocker.Mock(user=anonymous_user), queryset=qs ) assert filterset.qs == [tagged] +@pytest.mark.parametrize( + "factory_name, filterset_class", + [ + ("music.Track", filters.TrackFilter), + ("music.Artist", filters.ArtistFilter), + ("music.Album", filters.AlbumFilter), + ], +) def test_track_filter_tag_multiple( - factories, mocker, queryset_equal_list, anonymous_user + factory_name, + filterset_class, + factories, + mocker, + queryset_equal_list, + anonymous_user, ): - factories["music.Track"](set_tags=["TestTag1"]) - tagged = factories["music.Track"](set_tags=["TestTag1", "TestTag2"]) - qs = models.Track.objects.all() - filterset = filters.TrackFilter( + factories[factory_name](set_tags=["TestTag1"]) + tagged = factories[factory_name](set_tags=["TestTag1", "TestTag2"]) + qs = tagged.__class__.objects.all() + filterset = filterset_class( {"tag": ["testTaG1", "TestTag2"]}, request=mocker.Mock(user=anonymous_user), queryset=qs,