diff --git a/api/funkwhale_api/audio/factories.py b/api/funkwhale_api/audio/factories.py index dabaa5114..ddf4ac938 100644 --- a/api/funkwhale_api/audio/factories.py +++ b/api/funkwhale_api/audio/factories.py @@ -33,3 +33,14 @@ class ChannelFactory(NoUpdateOnCreate, factory.django.DjangoModelFactory): ), artist__local=True, ) + + +@registry.register(name="audio.Subscription") +class SubscriptionFactory(NoUpdateOnCreate, factory.django.DjangoModelFactory): + uuid = factory.Faker("uuid4") + approved = True + target = factory.LazyAttribute(lambda o: ChannelFactory().actor) + actor = factory.SubFactory(federation_factories.ActorFactory) + + class Meta: + model = "federation.Follow" diff --git a/api/funkwhale_api/audio/filters.py b/api/funkwhale_api/audio/filters.py index 02776e032..9a8bec6cd 100644 --- a/api/funkwhale_api/audio/filters.py +++ b/api/funkwhale_api/audio/filters.py @@ -1,3 +1,5 @@ +from django.db.models import Q + import django_filters from funkwhale_api.common import fields @@ -23,12 +25,30 @@ class ChannelFilter(moderation_filters.HiddenContentFilterSet): ) tag = TAG_FILTER scope = common_filters.ActorScopeFilter(actor_field="attributed_to", distinct=True) + subscribed = django_filters.BooleanFilter( + field_name="_", method="filter_subscribed" + ) class Meta: model = models.Channel - fields = ["q", "scope", "tag"] + fields = ["q", "scope", "tag", "subscribed"] hidden_content_fields_mapping = moderation_filters.USER_FILTER_CONFIG["CHANNEL"] + def filter_subscribed(self, queryset, name, value): + if not self.request.user.is_authenticated: + return queryset.none() + + emitted_follows = self.request.user.actor.emitted_follows.exclude( + target__channel__isnull=True + ) + + query = Q(actor__in=emitted_follows.values_list("target", flat=True)) + + if value is True: + return queryset.filter(query) + else: + return queryset.exclude(query) + class IncludeChannelsFilterSet(django_filters.FilterSet): """ diff --git a/api/funkwhale_api/audio/serializers.py b/api/funkwhale_api/audio/serializers.py index 6b7bc00ef..a946df9a9 100644 --- a/api/funkwhale_api/audio/serializers.py +++ b/api/funkwhale_api/audio/serializers.py @@ -96,3 +96,15 @@ class ChannelSerializer(serializers.ModelSerializer): def get_artist(self, obj): return music_serializers.serialize_artist_simple(obj.artist) + + +class SubscriptionSerializer(serializers.Serializer): + approved = serializers.BooleanField(read_only=True) + fid = serializers.URLField(read_only=True) + uuid = serializers.UUIDField(read_only=True) + creation_date = serializers.DateTimeField(read_only=True) + + def to_representation(self, obj): + data = super().to_representation(obj) + data["channel"] = ChannelSerializer(obj.target.channel).data + return data diff --git a/api/funkwhale_api/audio/views.py b/api/funkwhale_api/audio/views.py index 9e77d043b..a473dbbe8 100644 --- a/api/funkwhale_api/audio/views.py +++ b/api/funkwhale_api/audio/views.py @@ -1,6 +1,12 @@ -from rest_framework import exceptions, mixins, viewsets +from rest_framework import decorators +from rest_framework import exceptions +from rest_framework import mixins +from rest_framework import permissions as rest_permissions +from rest_framework import response +from rest_framework import viewsets from django import http +from django.db.utils import IntegrityError from funkwhale_api.common import permissions from funkwhale_api.common import preferences @@ -52,3 +58,33 @@ class ChannelViewSet( def perform_create(self, serializer): return serializer.save(attributed_to=self.request.user.actor) + + @decorators.action( + detail=True, + methods=["post"], + permission_classes=[rest_permissions.IsAuthenticated], + ) + def subscribe(self, request, *args, **kwargs): + object = self.get_object() + try: + subscription = object.actor.received_follows.create( + approved=True, actor=request.user.actor, + ) + except IntegrityError: + # there's already a subscription for this actor/channel + subscription = object.actor.received_follows.filter( + actor=request.user.actor + ).get() + + data = serializers.SubscriptionSerializer(subscription).data + return response.Response(data, status=201) + + @decorators.action( + detail=True, + methods=["post", "delete"], + permission_classes=[rest_permissions.IsAuthenticated], + ) + def unsubscribe(self, request, *args, **kwargs): + object = self.get_object() + request.user.actor.emitted_follows.filter(target=object.actor).delete() + return response.Response(status=204) diff --git a/api/tests/audio/test_filters.py b/api/tests/audio/test_filters.py new file mode 100644 index 000000000..c0cb9caa4 --- /dev/null +++ b/api/tests/audio/test_filters.py @@ -0,0 +1,32 @@ +from funkwhale_api.audio import filters +from funkwhale_api.audio import models + + +def test_channel_filter_subscribed_true(factories, mocker, queryset_equal_list): + user = factories["users.User"](with_actor=True) + channel = factories["audio.Channel"]() + other_channel = factories["audio.Channel"]() + factories["audio.Subscription"](target=channel.actor, actor=user.actor) + factories["audio.Subscription"](target=other_channel.actor) + + qs = models.Channel.objects.all() + filterset = filters.ChannelFilter( + {"subscribed": "true"}, request=mocker.Mock(user=user), queryset=qs + ) + + assert filterset.qs == [channel] + + +def test_channel_filter_subscribed_false(factories, mocker, queryset_equal_list): + user = factories["users.User"](with_actor=True) + channel = factories["audio.Channel"]() + other_channel = factories["audio.Channel"]() + factories["audio.Subscription"](target=channel.actor, actor=user.actor) + factories["audio.Subscription"](target=other_channel.actor) + + qs = models.Channel.objects.all() + filterset = filters.ChannelFilter( + {"subscribed": "false"}, request=mocker.Mock(user=user), queryset=qs + ) + + assert filterset.qs == [other_channel] diff --git a/api/tests/audio/test_serializers.py b/api/tests/audio/test_serializers.py index 61beda82a..b431e8e96 100644 --- a/api/tests/audio/test_serializers.py +++ b/api/tests/audio/test_serializers.py @@ -88,3 +88,16 @@ def test_channel_serializer_representation(factories, to_api_date): ).data assert serializers.ChannelSerializer(channel).data == expected + + +def test_subscription_serializer(factories, to_api_date): + subscription = factories["audio.Subscription"]() + expected = { + "channel": serializers.ChannelSerializer(subscription.target.channel).data, + "uuid": str(subscription.uuid), + "creation_date": to_api_date(subscription.creation_date), + "approved": subscription.approved, + "fid": subscription.fid, + } + + assert serializers.SubscriptionSerializer(subscription).data == expected diff --git a/api/tests/audio/test_views.py b/api/tests/audio/test_views.py index 1ef650989..f8cf456a6 100644 --- a/api/tests/audio/test_views.py +++ b/api/tests/audio/test_views.py @@ -126,3 +126,34 @@ def test_channel_views_disabled_via_feature_flag( url = reverse(url_name) response = logged_in_api_client.get(url) assert response.status_code == 405 + + +def test_channel_subscribe(factories, logged_in_api_client): + actor = logged_in_api_client.user.create_actor() + channel = factories["audio.Channel"](artist__description=None) + url = reverse("api:v1:channels-subscribe", kwargs={"uuid": channel.uuid}) + + response = logged_in_api_client.post(url) + + assert response.status_code == 201 + + subscription = actor.emitted_follows.select_related( + "target__channel__artist__description" + ).latest("id") + expected = serializers.SubscriptionSerializer(subscription).data + assert response.data == expected + assert subscription.target == channel.actor + + +def test_channel_unsubscribe(factories, logged_in_api_client): + actor = logged_in_api_client.user.create_actor() + channel = factories["audio.Channel"]() + subscription = factories["audio.Subscription"](target=channel.actor, actor=actor) + url = reverse("api:v1:channels-unsubscribe", kwargs={"uuid": channel.uuid}) + + response = logged_in_api_client.post(url) + + assert response.status_code == 204 + + with pytest.raises(subscription.DoesNotExist): + subscription.refresh_from_db()