From 38d3eb4f212c7ce59b78e85534a19793938a62cc Mon Sep 17 00:00:00 2001 From: Petitminion Date: Wed, 14 Jun 2023 10:09:12 +0200 Subject: [PATCH] cache 100 radiotracks into redis, return a list of radiotracks instead of a single one --- api/config/urls/api_v2.py | 4 + api/funkwhale_api/music/models.py | 2 + .../migrations/0008_auto_20230613_2123.py | 19 +++ api/funkwhale_api/radios/models.py | 17 ++- api/funkwhale_api/radios/radios.py | 111 +++++++++++++--- api/funkwhale_api/radios/serializers.py | 7 +- api/funkwhale_api/radios/urls_v2.py | 10 ++ api/funkwhale_api/radios/views.py | 68 ++++++++++ api/tests/radios/test_radios.py | 124 +++++++++++++++--- 9 files changed, 317 insertions(+), 45 deletions(-) create mode 100644 api/funkwhale_api/radios/migrations/0008_auto_20230613_2123.py create mode 100644 api/funkwhale_api/radios/urls_v2.py diff --git a/api/config/urls/api_v2.py b/api/config/urls/api_v2.py index 95c776a0c..d5e040337 100644 --- a/api/config/urls/api_v2.py +++ b/api/config/urls/api_v2.py @@ -10,6 +10,10 @@ v2_patterns += [ r"^instance/", include(("funkwhale_api.instance.urls", "instance"), namespace="instance"), ), + url( + r"^radios/", + include(("funkwhale_api.radios.urls_v2", "radios"), namespace="radios"), + ), ] urlpatterns = [url("", include((v2_patterns, "v2"), namespace="v2"))] diff --git a/api/funkwhale_api/music/models.py b/api/funkwhale_api/music/models.py index fca044a16..080d159f5 100644 --- a/api/funkwhale_api/music/models.py +++ b/api/funkwhale_api/music/models.py @@ -463,6 +463,8 @@ class TrackQuerySet(common_models.LocalFromFidQuerySet, models.QuerySet): return self.exclude(pk__in=matches) def with_playable_uploads(self, actor): + if not actor: + uploads = Upload.objects.filter(library__privacy_level="public") uploads = Upload.objects.playable_by(actor) return self.prefetch_related( models.Prefetch("uploads", queryset=uploads, to_attr="playable_uploads") diff --git a/api/funkwhale_api/radios/migrations/0008_auto_20230613_2123.py b/api/funkwhale_api/radios/migrations/0008_auto_20230613_2123.py new file mode 100644 index 000000000..c72175b24 --- /dev/null +++ b/api/funkwhale_api/radios/migrations/0008_auto_20230613_2123.py @@ -0,0 +1,19 @@ +# Generated by Django 3.2.18 on 2023-06-13 21:23 + +import django.core.serializers.json +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('radios', '0007_merge_20220715_0801'), + ] + + operations = [ + migrations.AddField( + model_name='radiosessiontrack', + name='played', + field=models.BooleanField(default=False), + ), + ] diff --git a/api/funkwhale_api/radios/models.py b/api/funkwhale_api/radios/models.py index 9d8753608..489471179 100644 --- a/api/funkwhale_api/radios/models.py +++ b/api/funkwhale_api/radios/models.py @@ -68,12 +68,18 @@ class RadioSession(models.Model): return next_position - def add(self, track): - new_session_track = RadioSessionTrack.objects.create( - track=track, session=self, position=self.next_position - ) + def add(self, tracks): + next_position = self.next_position + radio_session_tracks = [] + for i, track in enumerate(tracks): + radio_session_track = RadioSessionTrack( + track=track, session=self, position=next_position + i, played=True + ) + radio_session_tracks.append(radio_session_track) - return new_session_track + new_session_tracks = RadioSessionTrack.objects.bulk_create(radio_session_tracks) + + return new_session_tracks @property def radio(self): @@ -90,6 +96,7 @@ class RadioSessionTrack(models.Model): track = models.ForeignKey( Track, related_name="radio_session_tracks", on_delete=models.CASCADE ) + played = models.BooleanField(default=False) class Meta: ordering = ("session", "position") diff --git a/api/funkwhale_api/radios/radios.py b/api/funkwhale_api/radios/radios.py index 0c0ba2efa..c4296e6f7 100644 --- a/api/funkwhale_api/radios/radios.py +++ b/api/funkwhale_api/radios/radios.py @@ -16,12 +16,15 @@ from funkwhale_api.moderation import filters as moderation_filters from funkwhale_api.music.models import Artist, Library, Track, Upload from funkwhale_api.radios import lb_recommendations from funkwhale_api.tags.models import Tag - +from funkwhale_api.radios.models import RadioSessionTrack from . import filters, models from .registries import registry logger = logging.getLogger(__name__) +from funkwhale_api.music.models import Track, Prefetch +from funkwhale_api.music import utils as music_utils + class SimpleRadio: related_object_field = None @@ -64,11 +67,23 @@ class SessionRadio(SimpleRadio): return self.session def get_queryset(self, **kwargs): - qs = Track.objects.all() + qs = ( + Track.objects.all() + .with_playable_uploads(actor=None) + .select_related("artist", "album__artist", "attributed_to") + ) + if not self.session: return qs if not self.session.user: return qs + + qs = ( + Track.objects.all() + .with_playable_uploads(self.session.user.actor) + .select_related("artist", "album__artist", "attributed_to") + ) + query = moderation_filters.get_filtered_content_query( config=moderation_filters.USER_FILTER_CONFIG["TRACK"], user=self.session.user, @@ -80,20 +95,7 @@ class SessionRadio(SimpleRadio): def get_choices(self, **kwargs): kwargs.update(self.get_queryset_kwargs()) - if self.session and cache.get(f"radioqueryset{self.session.id}"): - logger.info("Using redis cache for radio generation") - queryset = cache.get(f"radioqueryset{self.session.id}") - elif self.session: - queryset = self.get_queryset(**kwargs) - logger.info("Setting redis cache for radio generation") - cache.set( - f"radioqueryset{self.session.id}", - queryset, - 3600, - ) - else: - queryset = self.get_queryset(**kwargs) - + queryset = self.get_queryset(**kwargs) if self.session: queryset = self.filter_from_session(queryset) if kwargs.pop("filter_playable", True): @@ -120,10 +122,83 @@ class SessionRadio(SimpleRadio): choices = self.get_choices(**kwargs) picked_choices = super().pick_many(choices=choices, quantity=quantity) if self.session: - for choice in picked_choices: - self.session.add(choice) + self.session.add(picked_choices) return picked_choices + def cache_batch_radio_track(self, quantity, **kwargs): + BATCH_SIZE = 100 + # get the queryset and apply filters + queryset = self.get_queryset(**kwargs) + queryset = self.filter_already_played_from_session(queryset) + if kwargs["filter_playable"] == True: + queryset = queryset.playable_by( + self.session.user.actor if self.session.user else None + ) + queryset = self.filter_queryset(queryset) + + # select a random batch of the qs + sliced_queryset = queryset.order_by("?")[:BATCH_SIZE] + if len(sliced_queryset) == 0: + raise ValueError("No more radio candidates") + # create the radio session tracks into db in bulk + radio_tracks = self.session.add(sliced_queryset) + + # evaluate the queryset to save it in cache + evaluated_radio_tracks = [t for t in radio_tracks] + logger.debug( + f"Setting redis cache for radio generation with radio id {self.session.id}" + ) + cache.set(f"radiosessiontracks{self.session.id}", evaluated_radio_tracks, 3600) + cache.set(f"radioqueryset{self.session.id}", sliced_queryset, 3600) + + return sliced_queryset + + def filter_already_played_from_session(self, queryset): + if already_played := self.session.session_tracks.filter( + played=True + ).values_list("track", flat=True): + logger.debug("Filtering already played track " + str(already_played)) + queryset = queryset.exclude(pk__in=already_played) + else: + logger.debug("No track already played") + return queryset + + def get_choices_v2(self, quantity, **kwargs): + kwargs.update(self.get_queryset_kwargs()) + if cached_radio_tracks := cache.get(f"radiosessiontracks{self.session.id}"): + logger.debug("Using redis cache for radio generation") + radio_tracks = cached_radio_tracks + if len(radio_tracks) < quantity: + logger.debug( + "Not enough radio tracks in cache. Trying to generate new cache" + ) + sliced_queryset = self.cache_batch_radio_track(quantity, **kwargs) + sliced_queryset = cache.get(f"radioqueryset{self.session.id}") + else: + sliced_queryset = self.cache_batch_radio_track(quantity, **kwargs) + + return sliced_queryset + + def pick_v2(self, **kwargs): + return self.pick_many_v2(quantity=1, **kwargs)[0] + + def pick_many_v2(self, quantity, **kwargs): + if self.session: + sliced_queryset = self.get_choices_v2(quantity, **kwargs) + evaluated_radio_tracks = cache.get(f"radiosessiontracks{self.session.id}") + batch = evaluated_radio_tracks[0:quantity] + for radiotrack in batch: + radiotrack.played = True + RadioSessionTrack.objects.bulk_update(batch, ["played"]) + + else: + logger.debug( + "No radio session. Can't track user playback. Won't cache queryset results" + ) + sliced_queryset = self.get_choices_v2(quantity, **kwargs) + + return sliced_queryset + def validate_session(self, data, **context): return data diff --git a/api/funkwhale_api/radios/serializers.py b/api/funkwhale_api/radios/serializers.py index 76e847d9e..44b828c7b 100644 --- a/api/funkwhale_api/radios/serializers.py +++ b/api/funkwhale_api/radios/serializers.py @@ -1,7 +1,10 @@ from rest_framework import serializers +from funkwhale_api.music import utils as music_utils + from funkwhale_api.music.serializers import TrackSerializer from funkwhale_api.users.serializers import UserBasicSerializer +from funkwhale_api.music import models as music_models from . import filters, models from .radios import registry @@ -40,9 +43,11 @@ class RadioSerializer(serializers.ModelSerializer): class RadioSessionTrackSerializerCreate(serializers.ModelSerializer): + count = serializers.IntegerField(required=False, allow_null=True) + class Meta: model = models.RadioSessionTrack - fields = ("session",) + fields = ("session", "count") class RadioSessionTrackSerializer(serializers.ModelSerializer): diff --git a/api/funkwhale_api/radios/urls_v2.py b/api/funkwhale_api/radios/urls_v2.py new file mode 100644 index 000000000..9a84a3272 --- /dev/null +++ b/api/funkwhale_api/radios/urls_v2.py @@ -0,0 +1,10 @@ +from funkwhale_api.common import routers + +from . import views + +router = routers.OptionalSlashRouter() + +router.register(r"tracks", views.RadioSessionTracksViewSet, "tracks") + + +urlpatterns = router.urls diff --git a/api/funkwhale_api/radios/views.py b/api/funkwhale_api/radios/views.py index adf2fe464..ffb16f88e 100644 --- a/api/funkwhale_api/radios/views.py +++ b/api/funkwhale_api/radios/views.py @@ -1,3 +1,4 @@ +from django.core.cache import cache from django.db.models import Q from drf_spectacular.utils import extend_schema from rest_framework import mixins, status, viewsets @@ -11,6 +12,10 @@ from funkwhale_api.users.oauth import permissions as oauth_permissions from . import filters, filtersets, models, serializers +import logging + +logger = logging.getLogger(__name__) + class RadioViewSet( mixins.CreateModelMixin, @@ -161,3 +166,66 @@ class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet) if self.action == "create": return serializers.RadioSessionTrackSerializerCreate return super().get_serializer_class(*args, **kwargs) + + +class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet): + """Return a list of RadioSessionTracks""" + + serializer_class = serializers.RadioSessionTrackSerializer + queryset = models.RadioSessionTrack.objects.all() + permission_classes = [] + + @extend_schema(operation_id="get_radio_tracks") + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + session = serializer.validated_data["session"] + count = ( + serializer.validated_data["count"] + if "count" in serializer.validated_data.keys() + else 1 + ) + filter_playable = ( + request.data["filter_playable"] + if "filter_playable" in request.data.keys() + else True + ) + if not request.user.is_authenticated and not request.session.session_key: + self.request.session.create() + try: + assert (request.user == session.user) or ( + request.session.session_key == session.session_key + and session.session_key + ) + except AssertionError: + return Response(status=status.HTTP_403_FORBIDDEN) + + try: + session.radio.pick_many_v2(count, filter_playable=filter_playable) + except ValueError: + return Response( + "Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND + ) + # self.perform_create(serializer) + # dirty override here, since we use a different serializer for creation and detail + evaluated_radio_tracks = cache.get(f"radiosessiontracks{session.id}") + serializer = self.serializer_class( + data=evaluated_radio_tracks[:count], + context=self.get_serializer_context(), + many="true", + ) + serializer.is_valid() + headers = self.get_success_headers(serializer.data) + + # delete the tracks we send from the cache + new_cached_radiotracks = evaluated_radio_tracks[count:] + cache.set(f"radiosessiontracks{session.id}", new_cached_radiotracks) + + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) + + def get_serializer_class(self, *args, **kwargs): + if self.action == "create": + return serializers.RadioSessionTrackSerializerCreate + return super().get_serializer_class(*args, **kwargs) diff --git a/api/tests/radios/test_radios.py b/api/tests/radios/test_radios.py index 1e9c02321..1ad905ede 100644 --- a/api/tests/radios/test_radios.py +++ b/api/tests/radios/test_radios.py @@ -2,6 +2,7 @@ import json import random import pytest +from django.core.cache import cache from django.core.exceptions import ValidationError from django.urls import reverse @@ -101,7 +102,8 @@ def test_can_get_choices_for_custom_radio(factories): choices = session.radio.get_choices(filter_playable=False) expected = [t.pk for t in tracks] - assert list(choices.values_list("id", flat=True)) == expected + for t in list(choices.values_list("id", flat=True)): + assert t in expected def test_cannot_start_custom_radio_if_not_owner_or_not_public(factories): @@ -190,6 +192,32 @@ def test_can_get_track_for_session_from_api(factories, logged_in_api_client): assert data["position"] == 2 +def test_can_get_track_for_session_from_api_v2(factories, logged_in_api_client): + actor = logged_in_api_client.user.create_actor() + track = factories["music.Upload"]( + library__actor=actor, import_status="finished" + ).track + url = reverse("api:v1:radios:sessions-list") + response = logged_in_api_client.post(url, {"radio_type": "random"}) + session = models.RadioSession.objects.latest("id") + + url = reverse("api:v2:radios:tracks-list") + response = logged_in_api_client.post(url, {"session": session.pk}) + data = json.loads(response.content.decode("utf-8")) + + assert data[0]["track"]["id"] == track.pk + assert data[0]["position"] == 1 + + next_track = factories["music.Upload"]( + library__actor=actor, import_status="finished" + ).track + response = logged_in_api_client.post(url, {"session": session.pk}) + data = json.loads(response.content.decode("utf-8")) + + assert data[0]["track"]["id"] == next_track.id + assert data[0]["position"] == 2 + + def test_related_object_radio_validate_related_object(factories): user = factories["users.User"]() # cannot start without related object @@ -394,7 +422,8 @@ def test_get_choices_for_custom_radio_exclude_artist(factories): choices = session.radio.get_choices(filter_playable=False) expected = [u.track.pk for u in included_uploads] - assert list(choices.values_list("id", flat=True)) == expected + for t in list(choices.values_list("id", flat=True)): + assert t in expected def test_get_choices_for_custom_radio_exclude_tag(factories): @@ -412,7 +441,8 @@ def test_get_choices_for_custom_radio_exclude_tag(factories): choices = session.radio.get_choices(filter_playable=False) expected = [u.track.pk for u in included_uploads] - assert list(choices.values_list("id", flat=True)) == expected + for t in list(choices.values_list("id", flat=True)): + assert t in expected def test_can_start_custom_multiple_radio_from_api(api_client, factories): @@ -431,26 +461,78 @@ def test_can_start_custom_multiple_radio_from_api(api_client, factories): assert response.status_code == 201 -def test_can_start_periodic_jams_troi_radio_from_api(api_client, factories): - factories["music.Track"].create_batch(5) +def test_session_radio_excludes_previous_picks_v2(factories, logged_in_api_client): + tracks = factories["music.Track"].create_batch(5) url = reverse("api:v1:radios:sessions-list") - config = {"patch": "periodic-jams", "type": "daily-jams"} - response = api_client.post( - url, - {"radio_type": "troi", "config": config}, - format="json", + response = logged_in_api_client.post(url, {"radio_type": "random"}) + session = models.RadioSession.objects.latest("id") + url = reverse("api:v2:radios:tracks-list") + + previous_choices = [] + + for i in range(5): + response = logged_in_api_client.post( + url, {"session": session.pk, "filter_playable": False} + ) + pick = json.loads(response.content.decode("utf-8")) + assert pick[0]["track"]["title"] not in previous_choices + assert pick[0]["track"]["title"] in [t.title for t in tracks] + previous_choices.append(pick[0]["track"]["title"]) + + response = logged_in_api_client.post(url, {"session": session.pk}) + assert ( + json.loads(response.content.decode("utf-8")) + == "Radio doesn't have more candidates" ) - assert response.status_code == 201 -# to do : send error to api ? -def test_can_catch_troi_radio_error(api_client, factories): - factories["music.Track"].create_batch(5) - url = reverse("api:v1:radios:sessions-list") - config = {"patch": "periodic-jams", "type": "not_existing_type"} - response = api_client.post( - url, - {"radio_type": "troi", "config": config}, - format="json", +def test_can_get_choices_for_favorites_radio_v2(factories): + files = factories["music.Upload"].create_batch(10) + tracks = [f.track for f in files] + user = factories["users.User"]() + for i in range(5): + TrackFavorite.add(track=random.choice(tracks), user=user) + + radio = radios.FavoritesRadio() + session = radio.start_session(user=user) + choices = session.radio.get_choices_v2(quantity=100, filter_playable=False) + + assert len(choices) == user.track_favorites.all().count() + + for favorite in user.track_favorites.all(): + assert favorite.track in choices + + +def test_can_get_choices_for_custom_radio_v2(factories): + artist = factories["music.Artist"]() + files = factories["music.Upload"].create_batch(5, track__artist=artist) + tracks = [f.track for f in files] + factories["music.Upload"].create_batch(5) + + session = factories["radios.CustomRadioSession"]( + custom_radio__config=[{"type": "artist", "ids": [artist.pk]}] ) - assert response.status_code == 201 + choices = session.radio.get_choices_v2(quantity=1, filter_playable=False) + + expected = [t.pk for t in tracks] + for t in choices: + assert t.id in expected + + +from funkwhale_api.music.models import Track + + +def test_can_cache_radio_track(factories): + uploads = factories["music.Track"].create_batch(10) + user = factories["users.User"]() + for t in Track.objects.all().playable_by(user.actor): + assert t in uploads + + radio = radios.RandomRadio() + session = radio.start_session(user) + picked = session.radio.pick_many_v2(quantity=1, filter_playable=False) + assert len(picked) == 10 + for t in cache.get(f"radioqueryset{session.id}"): + assert t in picked + for t in cache.get(f"radiosessiontracks{session.id}"): + assert t.track in uploads