From 32daed3524afcf51fd7a56395535b135c69c6527 Mon Sep 17 00:00:00 2001 From: Petitminion Date: Tue, 27 Jun 2023 14:56:31 +0200 Subject: [PATCH] resolve code review 2 --- .../migrations/0008_auto_20230613_2123.py | 19 ----- api/funkwhale_api/radios/models.py | 3 +- api/funkwhale_api/radios/radios.py | 80 ++++++++----------- api/funkwhale_api/radios/views.py | 28 +++---- api/tests/radios/test_radios.py | 32 ++++---- 5 files changed, 60 insertions(+), 102 deletions(-) delete mode 100644 api/funkwhale_api/radios/migrations/0008_auto_20230613_2123.py diff --git a/api/funkwhale_api/radios/migrations/0008_auto_20230613_2123.py b/api/funkwhale_api/radios/migrations/0008_auto_20230613_2123.py deleted file mode 100644 index c72175b24..000000000 --- a/api/funkwhale_api/radios/migrations/0008_auto_20230613_2123.py +++ /dev/null @@ -1,19 +0,0 @@ -# Generated by Django 3.2.18 on 2023-06-13 21:23 - -import django.core.serializers.json -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('radios', '0007_merge_20220715_0801'), - ] - - operations = [ - migrations.AddField( - model_name='radiosessiontrack', - name='played', - field=models.BooleanField(default=False), - ), - ] diff --git a/api/funkwhale_api/radios/models.py b/api/funkwhale_api/radios/models.py index 489471179..de20969fc 100644 --- a/api/funkwhale_api/radios/models.py +++ b/api/funkwhale_api/radios/models.py @@ -73,7 +73,7 @@ class RadioSession(models.Model): radio_session_tracks = [] for i, track in enumerate(tracks): radio_session_track = RadioSessionTrack( - track=track, session=self, position=next_position + i, played=True + track=track, session=self, position=next_position + i ) radio_session_tracks.append(radio_session_track) @@ -96,7 +96,6 @@ class RadioSessionTrack(models.Model): track = models.ForeignKey( Track, related_name="radio_session_tracks", on_delete=models.CASCADE ) - played = models.BooleanField(default=False) class Meta: ordering = ("session", "position") diff --git a/api/funkwhale_api/radios/radios.py b/api/funkwhale_api/radios/radios.py index 31ba4f7e5..81d683bd2 100644 --- a/api/funkwhale_api/radios/radios.py +++ b/api/funkwhale_api/radios/radios.py @@ -29,14 +29,14 @@ class SimpleRadio: def clean(self, instance): return - def pick( + def pick_v1( self, choices: List[int], previous_choices: Optional[List[int]] = None ) -> int: if previous_choices: choices = list(set(choices).difference(set(previous_choices))) return random.sample(choices, 1)[0] - def pick_many(self, choices: List[int], quantity: int) -> int: + def pick_many_v1(self, choices: List[int], quantity: int) -> int: return random.sample(list(set(choices)), quantity) def weighted_pick( @@ -64,22 +64,19 @@ class SessionRadio(SimpleRadio): return self.session def get_queryset(self, **kwargs): - qs = ( - Track.objects.all() - .with_playable_uploads(actor=None) - .select_related("artist", "album__artist", "attributed_to") - ) - if not self.session: - return qs - if not self.session.user: - return qs - - qs = ( - Track.objects.all() - .with_playable_uploads(self.session.user.actor) - .select_related("artist", "album__artist", "attributed_to") - ) + if not self.session or not self.session.user: + return ( + Track.objects.all() + .with_playable_uploads(actor=None) + .select_related("artist", "album__artist", "attributed_to") + ) + else: + qs = ( + Track.objects.all() + .with_playable_uploads(self.session.user.actor) + .select_related("artist", "album__artist", "attributed_to") + ) query = moderation_filters.get_filtered_content_query( config=moderation_filters.USER_FILTER_CONFIG["TRACK"], @@ -90,18 +87,6 @@ class SessionRadio(SimpleRadio): def get_queryset_kwargs(self): return {} - def get_choices(self, **kwargs): - kwargs.update(self.get_queryset_kwargs()) - queryset = self.get_queryset(**kwargs) - if self.session: - queryset = self.filter_from_session(queryset) - if kwargs.pop("filter_playable", True): - queryset = queryset.playable_by( - self.session.user.actor if self.session.user else None - ) - queryset = self.filter_queryset(queryset) - return queryset - def filter_queryset(self, queryset): return queryset @@ -112,12 +97,24 @@ class SessionRadio(SimpleRadio): queryset = queryset.exclude(pk__in=already_played) return queryset - def pick(self, **kwargs): - return self.pick_many(quantity=1, **kwargs)[0] + def get_choices_v1(self, **kwargs): + kwargs.update(self.get_queryset_kwargs()) + queryset = self.get_queryset(**kwargs) + if self.session: + queryset = self.filter_from_session(queryset) + if kwargs.pop("filter_playable", True): + queryset = queryset.playable_by( + self.session.user.actor if self.session.user else None + ) + queryset = self.filter_queryset(queryset) + return queryset - def pick_many(self, quantity, **kwargs): - choices = self.get_choices(**kwargs) - picked_choices = super().pick_many(choices=choices, quantity=quantity) + def pick_v1(self, **kwargs): + return self.pick_many_v1(quantity=1, **kwargs)[0] + + def pick_many_v1(self, quantity, **kwargs): + choices = self.get_choices_v1(**kwargs) + picked_choices = super().pick_many_v1(choices=choices, quantity=quantity) if self.session: self.session.add(picked_choices) return picked_choices @@ -135,7 +132,8 @@ class SessionRadio(SimpleRadio): # get the queryset and apply filters kwargs.update(self.get_queryset_kwargs()) queryset = self.get_queryset(**kwargs) - queryset = self.filter_already_played_from_session(queryset) + queryset = self.filter_from_session(queryset) + if kwargs["filter_playable"] is True: queryset = queryset.playable_by( self.session.user.actor if self.session.user else None @@ -152,7 +150,7 @@ class SessionRadio(SimpleRadio): # evaluate the queryset to save it in cache if cached_evaluated_radio_tracks is not None: - radio_tracks = [t for t in radio_tracks] + radio_tracks = list(radio_tracks) radio_tracks.extend(cached_evaluated_radio_tracks) logger.info( f"Setting redis cache for radio generation with radio id {self.session.id}" @@ -164,16 +162,6 @@ class SessionRadio(SimpleRadio): return sliced_queryset - def filter_already_played_from_session(self, queryset): - if already_played := self.session.session_tracks.filter( - played=True - ).values_list("track", flat=True): - logger.debug("Filtering already played track " + str(already_played)) - queryset = queryset.exclude(pk__in=already_played) - else: - logger.debug("No track already played") - return queryset - def get_choices_v2(self, quantity, **kwargs): if cache.get(f"radiosessiontracks{self.session.id}"): cached_radio_tracks = pickle.loads( diff --git a/api/funkwhale_api/radios/views.py b/api/funkwhale_api/radios/views.py index 4d221b619..7882135fd 100644 --- a/api/funkwhale_api/radios/views.py +++ b/api/funkwhale_api/radios/views.py @@ -10,7 +10,6 @@ from rest_framework.response import Response from funkwhale_api.common import permissions as common_permissions from funkwhale_api.music import utils as music_utils from funkwhale_api.music.serializers import TrackSerializer -from funkwhale_api.radios.models import RadioSessionTrack from funkwhale_api.users.oauth import permissions as oauth_permissions from . import filters, filtersets, models, serializers @@ -137,15 +136,13 @@ class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewS session = serializer.validated_data["session"] if not request.user.is_authenticated and not request.session.session_key: self.request.session.create() - try: - assert (request.user == session.user) or ( - request.session.session_key == session.session_key - and session.session_key - ) - except AssertionError: + if not request.user == session.user or not ( + request.session.session_key == session.session_key and session.session_key + ): return Response(status=status.HTTP_403_FORBIDDEN) + try: - session.radio.pick() + session.radio.pick_v1() except ValueError: return Response( "Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND @@ -192,12 +189,10 @@ class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet ) if not request.user.is_authenticated and not request.session.session_key: self.request.session.create() - try: - assert (request.user == session.user) or ( - request.session.session_key == session.session_key - and session.session_key - ) - except AssertionError: + + if not request.user == session.user or not ( + request.session.session_key == session.session_key and session.session_key + ): return Response(status=status.HTTP_403_FORBIDDEN) try: @@ -220,11 +215,6 @@ class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet serializer.is_valid() headers = self.get_success_headers(serializer.data) - # mark the RadioTracks has played - for radiotrack in batch: - radiotrack.played = True - RadioSessionTrack.objects.bulk_update(batch, ["played"]) - # delete the tracks we sent from the cache new_cached_radiotracks = evaluated_radio_tracks[count:] cache.set( diff --git a/api/tests/radios/test_radios.py b/api/tests/radios/test_radios.py index 9a8b04477..2c1fcb290 100644 --- a/api/tests/radios/test_radios.py +++ b/api/tests/radios/test_radios.py @@ -17,13 +17,13 @@ def test_can_pick_track_from_choices(): radio = radios.SimpleRadio() - first_pick = radio.pick(choices=choices) + first_pick = radio.pick_v1(choices=choices) assert first_pick in choices previous_choices = [first_pick] for remaining_choice in choices: - pick = radio.pick(choices=choices, previous_choices=previous_choices) + pick = radio.pick_v1(choices=choices, previous_choices=previous_choices) assert pick in set(choices).difference(set(previous_choices)) @@ -62,14 +62,14 @@ def test_session_radio_excludes_previous_picks(factories): radio.start_session(user) for i in range(5): - pick = radio.pick(user=user, filter_playable=False) + pick = radio.pick_v1(user=user, filter_playable=False) assert pick in tracks assert pick not in previous_choices previous_choices.append(pick) with pytest.raises(ValueError): # no more picks available - radio.pick(user=user, filter_playable=False) + radio.pick_v1(user=user, filter_playable=False) def test_can_get_choices_for_favorites_radio(factories): @@ -80,7 +80,7 @@ def test_can_get_choices_for_favorites_radio(factories): TrackFavorite.add(track=random.choice(tracks), user=user) radio = radios.FavoritesRadio() - choices = radio.get_choices(user=user) + choices = radio.get_choices_v1(user=user) assert choices.count() == user.track_favorites.all().count() @@ -88,7 +88,7 @@ def test_can_get_choices_for_favorites_radio(factories): assert favorite.track in choices for i in range(5): - pick = radio.pick(user=user) + pick = radio.pick_v1(user=user) assert pick in choices @@ -101,7 +101,7 @@ def test_can_get_choices_for_custom_radio(factories): session = factories["radios.CustomRadioSession"]( custom_radio__config=[{"type": "artist", "ids": [artist.pk]}] ) - choices = session.radio.get_choices(filter_playable=False) + choices = session.radio.get_choices_v1(filter_playable=False) expected = [t.pk for t in tracks] assert list(choices.values_list("id", flat=True)) == expected @@ -141,7 +141,7 @@ def test_can_use_radio_session_to_filter_choices(factories): session = radio.start_session(user) for i in range(10): - radio.pick(filter_playable=False) + radio.pick_v1(filter_playable=False) # ensure 10 different tracks have been suggested tracks_id = [ @@ -243,7 +243,7 @@ def test_can_start_artist_radio(factories): session = radio.start_session(user, related_object=artist) assert session.radio_type == "artist" for i in range(5): - assert radio.pick(filter_playable=False) in good_tracks + assert radio.pick_v1(filter_playable=False) in good_tracks def test_can_start_tag_radio(factories): @@ -261,7 +261,7 @@ def test_can_start_tag_radio(factories): assert session.radio_type == "tag" for i in range(3): - assert radio.pick(filter_playable=False) in good_tracks + assert radio.pick_v1(filter_playable=False) in good_tracks def test_can_start_actor_content_radio(factories): @@ -280,7 +280,7 @@ def test_can_start_actor_content_radio(factories): assert session.radio_type == "actor-content" for i in range(3): - assert radio.pick() in good_tracks + assert radio.pick_v1() in good_tracks def test_can_start_actor_content_radio_from_api( @@ -316,7 +316,7 @@ def test_can_start_library_radio(factories): assert session.radio_type == "library" for i in range(3): - assert radio.pick(filter_playable=False) in good_tracks + assert radio.pick_v1(filter_playable=False) in good_tracks def test_can_start_library_radio_from_api(logged_in_api_client, preferences, factories): @@ -362,7 +362,7 @@ def test_can_start_less_listened_radio(factories): radio.start_session(user) for i in range(5): - assert radio.pick(filter_playable=False) in good_tracks + assert radio.pick_v1(filter_playable=False) in good_tracks def test_similar_radio_track(factories): @@ -379,7 +379,7 @@ def test_similar_radio_track(factories): expected_next = factories["music.Track"]() factories["history.Listening"](track=expected_next, user=l1.user) - assert radio.pick(filter_playable=False) == expected_next + assert radio.pick_v1(filter_playable=False) == expected_next def test_session_radio_get_queryset_ignore_filtered_track_artist( @@ -420,7 +420,7 @@ def test_get_choices_for_custom_radio_exclude_artist(factories): {"type": "artist", "ids": [excluded_artist.pk], "not": True}, ] ) - choices = session.radio.get_choices(filter_playable=False) + choices = session.radio.get_choices_v1(filter_playable=False) expected = [u.track.pk for u in included_uploads] assert list(choices.values_list("id", flat=True)) == expected @@ -438,7 +438,7 @@ def test_get_choices_for_custom_radio_exclude_tag(factories): {"type": "tag", "names": ["rock"], "not": True}, ] ) - choices = session.radio.get_choices(filter_playable=False) + choices = session.radio.get_choices_v1(filter_playable=False) expected = [u.track.pk for u in included_uploads] assert list(choices.values_list("id", flat=True)) == expected