get instead of post

This commit is contained in:
Petitminion 2023-06-29 14:22:04 +02:00
parent 4b65bcba95
commit 064accf288
2 changed files with 22 additions and 20 deletions

View File

@ -136,8 +136,9 @@ class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewS
session = serializer.validated_data["session"] session = serializer.validated_data["session"]
if not request.user.is_authenticated and not request.session.session_key: if not request.user.is_authenticated and not request.session.session_key:
self.request.session.create() self.request.session.create()
if not request.user == session.user or not ( if not request.user == session.user or (
request.session.session_key == session.session_key and session.session_key not request.session.session_key == session.session_key
and not session.session_key
): ):
return Response(status=status.HTTP_403_FORBIDDEN) return Response(status=status.HTTP_403_FORBIDDEN)
@ -164,18 +165,19 @@ class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewS
return super().get_serializer_class(*args, **kwargs) return super().get_serializer_class(*args, **kwargs)
class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet): class RadioSessionTracksViewSet(mixins.ListModelMixin, viewsets.GenericViewSet):
"""Return a list of RadioSessionTracks""" """Return a list of RadioSessionTracks"""
serializer_class = serializers.RadioSessionTrackSerializer serializer_class = serializers.RadioSessionTrackSerializer
queryset = models.RadioSessionTrack.objects.all() queryset = models.RadioSessionTrack.objects.all()
permission_classes = [] permission_classes = []
@extend_schema(operation_id="get_radio_tracks") @extend_schema(operation_id="get_radio_tracks_get")
def create(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.query_params)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
session = serializer.validated_data["session"] session = serializer.validated_data["session"]
count = ( count = (
serializer.validated_data["count"] serializer.validated_data["count"]
if "count" in serializer.validated_data.keys() if "count" in serializer.validated_data.keys()
@ -183,18 +185,18 @@ class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet
) )
# this is used for test purpose. # this is used for test purpose.
filter_playable = ( filter_playable = (
request.data["filter_playable"] request.query_params["filter_playable"]
if "filter_playable" in request.data.keys() if "filter_playable" in request.query_params.keys()
else True else True
) )
if not request.user.is_authenticated and not request.session.session_key: if not request.user.is_authenticated and not request.session.session_key:
self.request.session.create() self.request.session.create()
if not request.user == session.user or not ( if not request.user == session.user or (
request.session.session_key == session.session_key and session.session_key not request.session.session_key == session.session_key
and not session.session_key
): ):
return Response(status=status.HTTP_403_FORBIDDEN) return Response(status=status.HTTP_403_FORBIDDEN)
try: try:
session.radio.pick_many_v2(count, filter_playable=filter_playable) session.radio.pick_many_v2(count, filter_playable=filter_playable)
except ValueError: except ValueError:
@ -213,7 +215,6 @@ class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet
many="true", many="true",
) )
serializer.is_valid() serializer.is_valid()
headers = self.get_success_headers(serializer.data)
# delete the tracks we sent from the cache # delete the tracks we sent from the cache
new_cached_radiotracks = evaluated_radio_tracks[count:] new_cached_radiotracks = evaluated_radio_tracks[count:]
@ -222,10 +223,11 @@ class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet
) )
return Response( return Response(
serializer.data, status=status.HTTP_201_CREATED, headers=headers serializer.data,
status=status.HTTP_201_CREATED,
) )
def get_serializer_class(self, *args, **kwargs): def get_serializer_class(self, *args, **kwargs):
if self.action == "create": if self.action == "list":
return serializers.RadioSessionTrackSerializerCreate return serializers.RadioSessionTrackSerializerCreate
return super().get_serializer_class(*args, **kwargs) return super().get_serializer_class(*args, **kwargs)

View File

@ -203,7 +203,7 @@ def test_can_get_track_for_session_from_api_v2(factories, logged_in_api_client):
session = models.RadioSession.objects.latest("id") session = models.RadioSession.objects.latest("id")
url = reverse("api:v2:radios:tracks-list") url = reverse("api:v2:radios:tracks-list")
response = logged_in_api_client.post(url, {"session": session.pk}) response = logged_in_api_client.get(url, {"session": session.pk})
data = json.loads(response.content.decode("utf-8")) data = json.loads(response.content.decode("utf-8"))
assert data[0]["track"]["id"] == track.pk assert data[0]["track"]["id"] == track.pk
@ -212,7 +212,7 @@ def test_can_get_track_for_session_from_api_v2(factories, logged_in_api_client):
next_track = factories["music.Upload"]( next_track = factories["music.Upload"](
library__actor=actor, import_status="finished" library__actor=actor, import_status="finished"
).track ).track
response = logged_in_api_client.post(url, {"session": session.pk}) response = logged_in_api_client.get(url, {"session": session.pk})
data = json.loads(response.content.decode("utf-8")) data = json.loads(response.content.decode("utf-8"))
assert data[0]["track"]["id"] == next_track.id assert data[0]["track"]["id"] == next_track.id
@ -470,7 +470,7 @@ def test_session_radio_excludes_previous_picks_v2(factories, logged_in_api_clien
previous_choices = [] previous_choices = []
for i in range(5): for i in range(5):
response = logged_in_api_client.post( response = logged_in_api_client.get(
url, {"session": session.pk, "filter_playable": False} url, {"session": session.pk, "filter_playable": False}
) )
pick = json.loads(response.content.decode("utf-8")) pick = json.loads(response.content.decode("utf-8"))
@ -478,7 +478,7 @@ def test_session_radio_excludes_previous_picks_v2(factories, logged_in_api_clien
assert pick[0]["track"]["title"] in [t.title for t in tracks] assert pick[0]["track"]["title"] in [t.title for t in tracks]
previous_choices.append(pick[0]["track"]["title"]) previous_choices.append(pick[0]["track"]["title"])
response = logged_in_api_client.post(url, {"session": session.pk}) response = logged_in_api_client.get(url, {"session": session.pk})
assert ( assert (
json.loads(response.content.decode("utf-8")) json.loads(response.content.decode("utf-8"))
== "Radio doesn't have more candidates" == "Radio doesn't have more candidates"
@ -542,10 +542,10 @@ def test_regenerate_cache_if_not_enought_tracks_in_it(
response = logged_in_api_client.post(url, {"radio_type": "random"}) response = logged_in_api_client.post(url, {"radio_type": "random"})
session = models.RadioSession.objects.latest("id") session = models.RadioSession.objects.latest("id")
url = reverse("api:v2:radios:tracks-list") url = reverse("api:v2:radios:tracks-list")
logged_in_api_client.post( logged_in_api_client.get(
url, {"session": session.pk, "count": 9, "filter_playable": False} url, {"session": session.pk, "count": 9, "filter_playable": False}
) )
response = logged_in_api_client.post( response = logged_in_api_client.get(
url, {"session": session.pk, "count": 10, "filter_playable": False} url, {"session": session.pk, "count": 10, "filter_playable": False}
) )
pick = json.loads(response.content.decode("utf-8")) pick = json.loads(response.content.decode("utf-8"))