diff --git a/api/funkwhale_api/radios/urls_v2.py b/api/funkwhale_api/radios/urls_v2.py index 55d1ad9aa..bac76f998 100644 --- a/api/funkwhale_api/radios/urls_v2.py +++ b/api/funkwhale_api/radios/urls_v2.py @@ -4,7 +4,7 @@ from . import views router = routers.OptionalSlashRouter() -router.register(r"sessions", views.V2_RadioSessionViewSet, "tracks") +router.register(r"sessions", views.V2_RadioSessionViewSet, "sessions") urlpatterns = router.urls diff --git a/api/funkwhale_api/radios/views.py b/api/funkwhale_api/radios/views.py index 5426ff10b..e89513472 100644 --- a/api/funkwhale_api/radios/views.py +++ b/api/funkwhale_api/radios/views.py @@ -165,17 +165,19 @@ class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewS return super().get_serializer_class(*args, **kwargs) -class V2_RadioSessionViewSet(mixins.ListModelMixin, viewsets.GenericViewSet): +class V2_RadioSessionViewSet( + mixins.CreateModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet +): """Returns a list of RadioSessions""" serializer_class = serializers.RadioSessionSerializer queryset = models.RadioSession.objects.all() permission_classes = [] - @action(detail=True, serializer_class=serializers.RadioSessionTrackSerializer) - def tracks(self, request, *args, **kwargs): - """Returns tracks for the given radio session""" - serializer = self.get_serializer(data=request.query_params) + @action(detail=True, serializer_class=serializers.RadioSessionTrackSerializerCreate) + def tracks(self, request, pk, *args, **kwargs): + data = request.query_params + serializer = serializers.RadioSessionTrackSerializerCreate(data=data) serializer.is_valid(raise_exception=True) session = serializer.validated_data["session"] @@ -210,9 +212,8 @@ class V2_RadioSessionViewSet(mixins.ListModelMixin, viewsets.GenericViewSet): cache.get(f"radiosessiontracks{session.id}") ) batch = evaluated_radio_tracks[:count] - serializer = self.serializer_class( + serializer = serializers.RadioSessionTrackSerializer( data=batch, - context=self.get_serializer_context(), many="true", ) serializer.is_valid() @@ -227,8 +228,3 @@ class V2_RadioSessionViewSet(mixins.ListModelMixin, viewsets.GenericViewSet): serializer.data, status=status.HTTP_201_CREATED, ) - - def get_serializer_class(self, *args, **kwargs): - if self.action == "list": - return serializers.RadioSessionTrackSerializerCreate - return super().get_serializer_class(*args, **kwargs) diff --git a/api/tests/radios/test_radios.py b/api/tests/radios/test_radios.py index 45c3767d0..59e91f883 100644 --- a/api/tests/radios/test_radios.py +++ b/api/tests/radios/test_radios.py @@ -150,6 +150,21 @@ def test_can_use_radio_session_to_filter_choices(factories): assert len(set(tracks_id)) == 10 +def test_can_use_radio_session_to_filter_choices_v2(factories): + factories["music.Upload"].create_batch(10) + user = factories["users.User"]() + radio = radios.RandomRadio() + session = radio.start_session(user) + + radio.pick_many_v2(quantity=10, filter_playable=False) + + # ensure 10 different tracks have been suggested + tracks_id = [ + session_track.track.pk for session_track in session.session_tracks.all() + ] + assert len(set(tracks_id)) == 10 + + def test_can_restore_radio_from_previous_session(factories): user = factories["users.User"]() radio = radios.RandomRadio() @@ -202,7 +217,7 @@ def test_can_get_track_for_session_from_api_v2(factories, logged_in_api_client): response = logged_in_api_client.post(url, {"radio_type": "random"}) session = models.RadioSession.objects.latest("id") - url = reverse("api:v2:radios:tracks-list") + url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk}) response = logged_in_api_client.get(url, {"session": session.pk}) data = json.loads(response.content.decode("utf-8")) @@ -465,7 +480,7 @@ def test_session_radio_excludes_previous_picks_v2(factories, logged_in_api_clien url = reverse("api:v1:radios:sessions-list") response = logged_in_api_client.post(url, {"radio_type": "random"}) session = models.RadioSession.objects.latest("id") - url = reverse("api:v2:radios:tracks-list") + url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk}) previous_choices = [] @@ -541,7 +556,7 @@ def test_regenerate_cache_if_not_enought_tracks_in_it( url = reverse("api:v1:radios:sessions-list") response = logged_in_api_client.post(url, {"radio_type": "random"}) session = models.RadioSession.objects.latest("id") - url = reverse("api:v2:radios:tracks-list") + url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk}) logged_in_api_client.get( url, {"session": session.pk, "count": 9, "filter_playable": False} )