resolve code review 2

This commit is contained in:
Petitminion 2023-06-27 14:56:31 +02:00
parent 66e49cb12e
commit 32daed3524
5 changed files with 60 additions and 102 deletions

View File

@ -1,19 +0,0 @@
# Generated by Django 3.2.18 on 2023-06-13 21:23
import django.core.serializers.json
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('radios', '0007_merge_20220715_0801'),
]
operations = [
migrations.AddField(
model_name='radiosessiontrack',
name='played',
field=models.BooleanField(default=False),
),
]

View File

@ -73,7 +73,7 @@ class RadioSession(models.Model):
radio_session_tracks = [] radio_session_tracks = []
for i, track in enumerate(tracks): for i, track in enumerate(tracks):
radio_session_track = RadioSessionTrack( radio_session_track = RadioSessionTrack(
track=track, session=self, position=next_position + i, played=True track=track, session=self, position=next_position + i
) )
radio_session_tracks.append(radio_session_track) radio_session_tracks.append(radio_session_track)
@ -96,7 +96,6 @@ class RadioSessionTrack(models.Model):
track = models.ForeignKey( track = models.ForeignKey(
Track, related_name="radio_session_tracks", on_delete=models.CASCADE Track, related_name="radio_session_tracks", on_delete=models.CASCADE
) )
played = models.BooleanField(default=False)
class Meta: class Meta:
ordering = ("session", "position") ordering = ("session", "position")

View File

@ -29,14 +29,14 @@ class SimpleRadio:
def clean(self, instance): def clean(self, instance):
return return
def pick( def pick_v1(
self, choices: List[int], previous_choices: Optional[List[int]] = None self, choices: List[int], previous_choices: Optional[List[int]] = None
) -> int: ) -> int:
if previous_choices: if previous_choices:
choices = list(set(choices).difference(set(previous_choices))) choices = list(set(choices).difference(set(previous_choices)))
return random.sample(choices, 1)[0] return random.sample(choices, 1)[0]
def pick_many(self, choices: List[int], quantity: int) -> int: def pick_many_v1(self, choices: List[int], quantity: int) -> int:
return random.sample(list(set(choices)), quantity) return random.sample(list(set(choices)), quantity)
def weighted_pick( def weighted_pick(
@ -64,22 +64,19 @@ class SessionRadio(SimpleRadio):
return self.session return self.session
def get_queryset(self, **kwargs): def get_queryset(self, **kwargs):
qs = (
Track.objects.all()
.with_playable_uploads(actor=None)
.select_related("artist", "album__artist", "attributed_to")
)
if not self.session: if not self.session or not self.session.user:
return qs return (
if not self.session.user: Track.objects.all()
return qs .with_playable_uploads(actor=None)
.select_related("artist", "album__artist", "attributed_to")
qs = ( )
Track.objects.all() else:
.with_playable_uploads(self.session.user.actor) qs = (
.select_related("artist", "album__artist", "attributed_to") Track.objects.all()
) .with_playable_uploads(self.session.user.actor)
.select_related("artist", "album__artist", "attributed_to")
)
query = moderation_filters.get_filtered_content_query( query = moderation_filters.get_filtered_content_query(
config=moderation_filters.USER_FILTER_CONFIG["TRACK"], config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
@ -90,18 +87,6 @@ class SessionRadio(SimpleRadio):
def get_queryset_kwargs(self): def get_queryset_kwargs(self):
return {} return {}
def get_choices(self, **kwargs):
kwargs.update(self.get_queryset_kwargs())
queryset = self.get_queryset(**kwargs)
if self.session:
queryset = self.filter_from_session(queryset)
if kwargs.pop("filter_playable", True):
queryset = queryset.playable_by(
self.session.user.actor if self.session.user else None
)
queryset = self.filter_queryset(queryset)
return queryset
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
return queryset return queryset
@ -112,12 +97,24 @@ class SessionRadio(SimpleRadio):
queryset = queryset.exclude(pk__in=already_played) queryset = queryset.exclude(pk__in=already_played)
return queryset return queryset
def pick(self, **kwargs): def get_choices_v1(self, **kwargs):
return self.pick_many(quantity=1, **kwargs)[0] kwargs.update(self.get_queryset_kwargs())
queryset = self.get_queryset(**kwargs)
if self.session:
queryset = self.filter_from_session(queryset)
if kwargs.pop("filter_playable", True):
queryset = queryset.playable_by(
self.session.user.actor if self.session.user else None
)
queryset = self.filter_queryset(queryset)
return queryset
def pick_many(self, quantity, **kwargs): def pick_v1(self, **kwargs):
choices = self.get_choices(**kwargs) return self.pick_many_v1(quantity=1, **kwargs)[0]
picked_choices = super().pick_many(choices=choices, quantity=quantity)
def pick_many_v1(self, quantity, **kwargs):
choices = self.get_choices_v1(**kwargs)
picked_choices = super().pick_many_v1(choices=choices, quantity=quantity)
if self.session: if self.session:
self.session.add(picked_choices) self.session.add(picked_choices)
return picked_choices return picked_choices
@ -135,7 +132,8 @@ class SessionRadio(SimpleRadio):
# get the queryset and apply filters # get the queryset and apply filters
kwargs.update(self.get_queryset_kwargs()) kwargs.update(self.get_queryset_kwargs())
queryset = self.get_queryset(**kwargs) queryset = self.get_queryset(**kwargs)
queryset = self.filter_already_played_from_session(queryset) queryset = self.filter_from_session(queryset)
if kwargs["filter_playable"] is True: if kwargs["filter_playable"] is True:
queryset = queryset.playable_by( queryset = queryset.playable_by(
self.session.user.actor if self.session.user else None self.session.user.actor if self.session.user else None
@ -152,7 +150,7 @@ class SessionRadio(SimpleRadio):
# evaluate the queryset to save it in cache # evaluate the queryset to save it in cache
if cached_evaluated_radio_tracks is not None: if cached_evaluated_radio_tracks is not None:
radio_tracks = [t for t in radio_tracks] radio_tracks = list(radio_tracks)
radio_tracks.extend(cached_evaluated_radio_tracks) radio_tracks.extend(cached_evaluated_radio_tracks)
logger.info( logger.info(
f"Setting redis cache for radio generation with radio id {self.session.id}" f"Setting redis cache for radio generation with radio id {self.session.id}"
@ -164,16 +162,6 @@ class SessionRadio(SimpleRadio):
return sliced_queryset return sliced_queryset
def filter_already_played_from_session(self, queryset):
if already_played := self.session.session_tracks.filter(
played=True
).values_list("track", flat=True):
logger.debug("Filtering already played track " + str(already_played))
queryset = queryset.exclude(pk__in=already_played)
else:
logger.debug("No track already played")
return queryset
def get_choices_v2(self, quantity, **kwargs): def get_choices_v2(self, quantity, **kwargs):
if cache.get(f"radiosessiontracks{self.session.id}"): if cache.get(f"radiosessiontracks{self.session.id}"):
cached_radio_tracks = pickle.loads( cached_radio_tracks = pickle.loads(

View File

@ -10,7 +10,6 @@ from rest_framework.response import Response
from funkwhale_api.common import permissions as common_permissions from funkwhale_api.common import permissions as common_permissions
from funkwhale_api.music import utils as music_utils from funkwhale_api.music import utils as music_utils
from funkwhale_api.music.serializers import TrackSerializer from funkwhale_api.music.serializers import TrackSerializer
from funkwhale_api.radios.models import RadioSessionTrack
from funkwhale_api.users.oauth import permissions as oauth_permissions from funkwhale_api.users.oauth import permissions as oauth_permissions
from . import filters, filtersets, models, serializers from . import filters, filtersets, models, serializers
@ -137,15 +136,13 @@ class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewS
session = serializer.validated_data["session"] session = serializer.validated_data["session"]
if not request.user.is_authenticated and not request.session.session_key: if not request.user.is_authenticated and not request.session.session_key:
self.request.session.create() self.request.session.create()
try: if not request.user == session.user or not (
assert (request.user == session.user) or ( request.session.session_key == session.session_key and session.session_key
request.session.session_key == session.session_key ):
and session.session_key
)
except AssertionError:
return Response(status=status.HTTP_403_FORBIDDEN) return Response(status=status.HTTP_403_FORBIDDEN)
try: try:
session.radio.pick() session.radio.pick_v1()
except ValueError: except ValueError:
return Response( return Response(
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND "Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
@ -192,12 +189,10 @@ class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet
) )
if not request.user.is_authenticated and not request.session.session_key: if not request.user.is_authenticated and not request.session.session_key:
self.request.session.create() self.request.session.create()
try:
assert (request.user == session.user) or ( if not request.user == session.user or not (
request.session.session_key == session.session_key request.session.session_key == session.session_key and session.session_key
and session.session_key ):
)
except AssertionError:
return Response(status=status.HTTP_403_FORBIDDEN) return Response(status=status.HTTP_403_FORBIDDEN)
try: try:
@ -220,11 +215,6 @@ class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet
serializer.is_valid() serializer.is_valid()
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
# mark the RadioTracks has played
for radiotrack in batch:
radiotrack.played = True
RadioSessionTrack.objects.bulk_update(batch, ["played"])
# delete the tracks we sent from the cache # delete the tracks we sent from the cache
new_cached_radiotracks = evaluated_radio_tracks[count:] new_cached_radiotracks = evaluated_radio_tracks[count:]
cache.set( cache.set(

View File

@ -17,13 +17,13 @@ def test_can_pick_track_from_choices():
radio = radios.SimpleRadio() radio = radios.SimpleRadio()
first_pick = radio.pick(choices=choices) first_pick = radio.pick_v1(choices=choices)
assert first_pick in choices assert first_pick in choices
previous_choices = [first_pick] previous_choices = [first_pick]
for remaining_choice in choices: for remaining_choice in choices:
pick = radio.pick(choices=choices, previous_choices=previous_choices) pick = radio.pick_v1(choices=choices, previous_choices=previous_choices)
assert pick in set(choices).difference(set(previous_choices)) assert pick in set(choices).difference(set(previous_choices))
@ -62,14 +62,14 @@ def test_session_radio_excludes_previous_picks(factories):
radio.start_session(user) radio.start_session(user)
for i in range(5): for i in range(5):
pick = radio.pick(user=user, filter_playable=False) pick = radio.pick_v1(user=user, filter_playable=False)
assert pick in tracks assert pick in tracks
assert pick not in previous_choices assert pick not in previous_choices
previous_choices.append(pick) previous_choices.append(pick)
with pytest.raises(ValueError): with pytest.raises(ValueError):
# no more picks available # no more picks available
radio.pick(user=user, filter_playable=False) radio.pick_v1(user=user, filter_playable=False)
def test_can_get_choices_for_favorites_radio(factories): def test_can_get_choices_for_favorites_radio(factories):
@ -80,7 +80,7 @@ def test_can_get_choices_for_favorites_radio(factories):
TrackFavorite.add(track=random.choice(tracks), user=user) TrackFavorite.add(track=random.choice(tracks), user=user)
radio = radios.FavoritesRadio() radio = radios.FavoritesRadio()
choices = radio.get_choices(user=user) choices = radio.get_choices_v1(user=user)
assert choices.count() == user.track_favorites.all().count() assert choices.count() == user.track_favorites.all().count()
@ -88,7 +88,7 @@ def test_can_get_choices_for_favorites_radio(factories):
assert favorite.track in choices assert favorite.track in choices
for i in range(5): for i in range(5):
pick = radio.pick(user=user) pick = radio.pick_v1(user=user)
assert pick in choices assert pick in choices
@ -101,7 +101,7 @@ def test_can_get_choices_for_custom_radio(factories):
session = factories["radios.CustomRadioSession"]( session = factories["radios.CustomRadioSession"](
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}] custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
) )
choices = session.radio.get_choices(filter_playable=False) choices = session.radio.get_choices_v1(filter_playable=False)
expected = [t.pk for t in tracks] expected = [t.pk for t in tracks]
assert list(choices.values_list("id", flat=True)) == expected assert list(choices.values_list("id", flat=True)) == expected
@ -141,7 +141,7 @@ def test_can_use_radio_session_to_filter_choices(factories):
session = radio.start_session(user) session = radio.start_session(user)
for i in range(10): for i in range(10):
radio.pick(filter_playable=False) radio.pick_v1(filter_playable=False)
# ensure 10 different tracks have been suggested # ensure 10 different tracks have been suggested
tracks_id = [ tracks_id = [
@ -243,7 +243,7 @@ def test_can_start_artist_radio(factories):
session = radio.start_session(user, related_object=artist) session = radio.start_session(user, related_object=artist)
assert session.radio_type == "artist" assert session.radio_type == "artist"
for i in range(5): for i in range(5):
assert radio.pick(filter_playable=False) in good_tracks assert radio.pick_v1(filter_playable=False) in good_tracks
def test_can_start_tag_radio(factories): def test_can_start_tag_radio(factories):
@ -261,7 +261,7 @@ def test_can_start_tag_radio(factories):
assert session.radio_type == "tag" assert session.radio_type == "tag"
for i in range(3): for i in range(3):
assert radio.pick(filter_playable=False) in good_tracks assert radio.pick_v1(filter_playable=False) in good_tracks
def test_can_start_actor_content_radio(factories): def test_can_start_actor_content_radio(factories):
@ -280,7 +280,7 @@ def test_can_start_actor_content_radio(factories):
assert session.radio_type == "actor-content" assert session.radio_type == "actor-content"
for i in range(3): for i in range(3):
assert radio.pick() in good_tracks assert radio.pick_v1() in good_tracks
def test_can_start_actor_content_radio_from_api( def test_can_start_actor_content_radio_from_api(
@ -316,7 +316,7 @@ def test_can_start_library_radio(factories):
assert session.radio_type == "library" assert session.radio_type == "library"
for i in range(3): for i in range(3):
assert radio.pick(filter_playable=False) in good_tracks assert radio.pick_v1(filter_playable=False) in good_tracks
def test_can_start_library_radio_from_api(logged_in_api_client, preferences, factories): def test_can_start_library_radio_from_api(logged_in_api_client, preferences, factories):
@ -362,7 +362,7 @@ def test_can_start_less_listened_radio(factories):
radio.start_session(user) radio.start_session(user)
for i in range(5): for i in range(5):
assert radio.pick(filter_playable=False) in good_tracks assert radio.pick_v1(filter_playable=False) in good_tracks
def test_similar_radio_track(factories): def test_similar_radio_track(factories):
@ -379,7 +379,7 @@ def test_similar_radio_track(factories):
expected_next = factories["music.Track"]() expected_next = factories["music.Track"]()
factories["history.Listening"](track=expected_next, user=l1.user) factories["history.Listening"](track=expected_next, user=l1.user)
assert radio.pick(filter_playable=False) == expected_next assert radio.pick_v1(filter_playable=False) == expected_next
def test_session_radio_get_queryset_ignore_filtered_track_artist( def test_session_radio_get_queryset_ignore_filtered_track_artist(
@ -420,7 +420,7 @@ def test_get_choices_for_custom_radio_exclude_artist(factories):
{"type": "artist", "ids": [excluded_artist.pk], "not": True}, {"type": "artist", "ids": [excluded_artist.pk], "not": True},
] ]
) )
choices = session.radio.get_choices(filter_playable=False) choices = session.radio.get_choices_v1(filter_playable=False)
expected = [u.track.pk for u in included_uploads] expected = [u.track.pk for u in included_uploads]
assert list(choices.values_list("id", flat=True)) == expected assert list(choices.values_list("id", flat=True)) == expected
@ -438,7 +438,7 @@ def test_get_choices_for_custom_radio_exclude_tag(factories):
{"type": "tag", "names": ["rock"], "not": True}, {"type": "tag", "names": ["rock"], "not": True},
] ]
) )
choices = session.radio.get_choices(filter_playable=False) choices = session.radio.get_choices_v1(filter_playable=False)
expected = [u.track.pk for u in included_uploads] expected = [u.track.pk for u in included_uploads]
assert list(choices.values_list("id", flat=True)) == expected assert list(choices.values_list("id", flat=True)) == expected