resolve code review 2

This commit is contained in:
Petitminion 2023-06-27 14:56:31 +02:00
parent 66e49cb12e
commit 32daed3524
5 changed files with 60 additions and 102 deletions

View File

@ -1,19 +0,0 @@
# Generated by Django 3.2.18 on 2023-06-13 21:23
import django.core.serializers.json
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('radios', '0007_merge_20220715_0801'),
]
operations = [
migrations.AddField(
model_name='radiosessiontrack',
name='played',
field=models.BooleanField(default=False),
),
]

View File

@ -73,7 +73,7 @@ class RadioSession(models.Model):
radio_session_tracks = []
for i, track in enumerate(tracks):
radio_session_track = RadioSessionTrack(
track=track, session=self, position=next_position + i, played=True
track=track, session=self, position=next_position + i
)
radio_session_tracks.append(radio_session_track)
@ -96,7 +96,6 @@ class RadioSessionTrack(models.Model):
track = models.ForeignKey(
Track, related_name="radio_session_tracks", on_delete=models.CASCADE
)
played = models.BooleanField(default=False)
class Meta:
ordering = ("session", "position")

View File

@ -29,14 +29,14 @@ class SimpleRadio:
def clean(self, instance):
return
def pick(
def pick_v1(
self, choices: List[int], previous_choices: Optional[List[int]] = None
) -> int:
if previous_choices:
choices = list(set(choices).difference(set(previous_choices)))
return random.sample(choices, 1)[0]
def pick_many(self, choices: List[int], quantity: int) -> int:
def pick_many_v1(self, choices: List[int], quantity: int) -> int:
return random.sample(list(set(choices)), quantity)
def weighted_pick(
@ -64,22 +64,19 @@ class SessionRadio(SimpleRadio):
return self.session
def get_queryset(self, **kwargs):
qs = (
Track.objects.all()
.with_playable_uploads(actor=None)
.select_related("artist", "album__artist", "attributed_to")
)
if not self.session:
return qs
if not self.session.user:
return qs
qs = (
Track.objects.all()
.with_playable_uploads(self.session.user.actor)
.select_related("artist", "album__artist", "attributed_to")
)
if not self.session or not self.session.user:
return (
Track.objects.all()
.with_playable_uploads(actor=None)
.select_related("artist", "album__artist", "attributed_to")
)
else:
qs = (
Track.objects.all()
.with_playable_uploads(self.session.user.actor)
.select_related("artist", "album__artist", "attributed_to")
)
query = moderation_filters.get_filtered_content_query(
config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
@ -90,18 +87,6 @@ class SessionRadio(SimpleRadio):
def get_queryset_kwargs(self):
return {}
def get_choices(self, **kwargs):
kwargs.update(self.get_queryset_kwargs())
queryset = self.get_queryset(**kwargs)
if self.session:
queryset = self.filter_from_session(queryset)
if kwargs.pop("filter_playable", True):
queryset = queryset.playable_by(
self.session.user.actor if self.session.user else None
)
queryset = self.filter_queryset(queryset)
return queryset
def filter_queryset(self, queryset):
return queryset
@ -112,12 +97,24 @@ class SessionRadio(SimpleRadio):
queryset = queryset.exclude(pk__in=already_played)
return queryset
def pick(self, **kwargs):
return self.pick_many(quantity=1, **kwargs)[0]
def get_choices_v1(self, **kwargs):
kwargs.update(self.get_queryset_kwargs())
queryset = self.get_queryset(**kwargs)
if self.session:
queryset = self.filter_from_session(queryset)
if kwargs.pop("filter_playable", True):
queryset = queryset.playable_by(
self.session.user.actor if self.session.user else None
)
queryset = self.filter_queryset(queryset)
return queryset
def pick_many(self, quantity, **kwargs):
choices = self.get_choices(**kwargs)
picked_choices = super().pick_many(choices=choices, quantity=quantity)
def pick_v1(self, **kwargs):
return self.pick_many_v1(quantity=1, **kwargs)[0]
def pick_many_v1(self, quantity, **kwargs):
choices = self.get_choices_v1(**kwargs)
picked_choices = super().pick_many_v1(choices=choices, quantity=quantity)
if self.session:
self.session.add(picked_choices)
return picked_choices
@ -135,7 +132,8 @@ class SessionRadio(SimpleRadio):
# get the queryset and apply filters
kwargs.update(self.get_queryset_kwargs())
queryset = self.get_queryset(**kwargs)
queryset = self.filter_already_played_from_session(queryset)
queryset = self.filter_from_session(queryset)
if kwargs["filter_playable"] is True:
queryset = queryset.playable_by(
self.session.user.actor if self.session.user else None
@ -152,7 +150,7 @@ class SessionRadio(SimpleRadio):
# evaluate the queryset to save it in cache
if cached_evaluated_radio_tracks is not None:
radio_tracks = [t for t in radio_tracks]
radio_tracks = list(radio_tracks)
radio_tracks.extend(cached_evaluated_radio_tracks)
logger.info(
f"Setting redis cache for radio generation with radio id {self.session.id}"
@ -164,16 +162,6 @@ class SessionRadio(SimpleRadio):
return sliced_queryset
def filter_already_played_from_session(self, queryset):
if already_played := self.session.session_tracks.filter(
played=True
).values_list("track", flat=True):
logger.debug("Filtering already played track " + str(already_played))
queryset = queryset.exclude(pk__in=already_played)
else:
logger.debug("No track already played")
return queryset
def get_choices_v2(self, quantity, **kwargs):
if cache.get(f"radiosessiontracks{self.session.id}"):
cached_radio_tracks = pickle.loads(

View File

@ -10,7 +10,6 @@ from rest_framework.response import Response
from funkwhale_api.common import permissions as common_permissions
from funkwhale_api.music import utils as music_utils
from funkwhale_api.music.serializers import TrackSerializer
from funkwhale_api.radios.models import RadioSessionTrack
from funkwhale_api.users.oauth import permissions as oauth_permissions
from . import filters, filtersets, models, serializers
@ -137,15 +136,13 @@ class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewS
session = serializer.validated_data["session"]
if not request.user.is_authenticated and not request.session.session_key:
self.request.session.create()
try:
assert (request.user == session.user) or (
request.session.session_key == session.session_key
and session.session_key
)
except AssertionError:
if not request.user == session.user or not (
request.session.session_key == session.session_key and session.session_key
):
return Response(status=status.HTTP_403_FORBIDDEN)
try:
session.radio.pick()
session.radio.pick_v1()
except ValueError:
return Response(
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
@ -192,12 +189,10 @@ class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet
)
if not request.user.is_authenticated and not request.session.session_key:
self.request.session.create()
try:
assert (request.user == session.user) or (
request.session.session_key == session.session_key
and session.session_key
)
except AssertionError:
if not request.user == session.user or not (
request.session.session_key == session.session_key and session.session_key
):
return Response(status=status.HTTP_403_FORBIDDEN)
try:
@ -220,11 +215,6 @@ class RadioSessionTracksViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet
serializer.is_valid()
headers = self.get_success_headers(serializer.data)
# mark the RadioTracks has played
for radiotrack in batch:
radiotrack.played = True
RadioSessionTrack.objects.bulk_update(batch, ["played"])
# delete the tracks we sent from the cache
new_cached_radiotracks = evaluated_radio_tracks[count:]
cache.set(

View File

@ -17,13 +17,13 @@ def test_can_pick_track_from_choices():
radio = radios.SimpleRadio()
first_pick = radio.pick(choices=choices)
first_pick = radio.pick_v1(choices=choices)
assert first_pick in choices
previous_choices = [first_pick]
for remaining_choice in choices:
pick = radio.pick(choices=choices, previous_choices=previous_choices)
pick = radio.pick_v1(choices=choices, previous_choices=previous_choices)
assert pick in set(choices).difference(set(previous_choices))
@ -62,14 +62,14 @@ def test_session_radio_excludes_previous_picks(factories):
radio.start_session(user)
for i in range(5):
pick = radio.pick(user=user, filter_playable=False)
pick = radio.pick_v1(user=user, filter_playable=False)
assert pick in tracks
assert pick not in previous_choices
previous_choices.append(pick)
with pytest.raises(ValueError):
# no more picks available
radio.pick(user=user, filter_playable=False)
radio.pick_v1(user=user, filter_playable=False)
def test_can_get_choices_for_favorites_radio(factories):
@ -80,7 +80,7 @@ def test_can_get_choices_for_favorites_radio(factories):
TrackFavorite.add(track=random.choice(tracks), user=user)
radio = radios.FavoritesRadio()
choices = radio.get_choices(user=user)
choices = radio.get_choices_v1(user=user)
assert choices.count() == user.track_favorites.all().count()
@ -88,7 +88,7 @@ def test_can_get_choices_for_favorites_radio(factories):
assert favorite.track in choices
for i in range(5):
pick = radio.pick(user=user)
pick = radio.pick_v1(user=user)
assert pick in choices
@ -101,7 +101,7 @@ def test_can_get_choices_for_custom_radio(factories):
session = factories["radios.CustomRadioSession"](
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
)
choices = session.radio.get_choices(filter_playable=False)
choices = session.radio.get_choices_v1(filter_playable=False)
expected = [t.pk for t in tracks]
assert list(choices.values_list("id", flat=True)) == expected
@ -141,7 +141,7 @@ def test_can_use_radio_session_to_filter_choices(factories):
session = radio.start_session(user)
for i in range(10):
radio.pick(filter_playable=False)
radio.pick_v1(filter_playable=False)
# ensure 10 different tracks have been suggested
tracks_id = [
@ -243,7 +243,7 @@ def test_can_start_artist_radio(factories):
session = radio.start_session(user, related_object=artist)
assert session.radio_type == "artist"
for i in range(5):
assert radio.pick(filter_playable=False) in good_tracks
assert radio.pick_v1(filter_playable=False) in good_tracks
def test_can_start_tag_radio(factories):
@ -261,7 +261,7 @@ def test_can_start_tag_radio(factories):
assert session.radio_type == "tag"
for i in range(3):
assert radio.pick(filter_playable=False) in good_tracks
assert radio.pick_v1(filter_playable=False) in good_tracks
def test_can_start_actor_content_radio(factories):
@ -280,7 +280,7 @@ def test_can_start_actor_content_radio(factories):
assert session.radio_type == "actor-content"
for i in range(3):
assert radio.pick() in good_tracks
assert radio.pick_v1() in good_tracks
def test_can_start_actor_content_radio_from_api(
@ -316,7 +316,7 @@ def test_can_start_library_radio(factories):
assert session.radio_type == "library"
for i in range(3):
assert radio.pick(filter_playable=False) in good_tracks
assert radio.pick_v1(filter_playable=False) in good_tracks
def test_can_start_library_radio_from_api(logged_in_api_client, preferences, factories):
@ -362,7 +362,7 @@ def test_can_start_less_listened_radio(factories):
radio.start_session(user)
for i in range(5):
assert radio.pick(filter_playable=False) in good_tracks
assert radio.pick_v1(filter_playable=False) in good_tracks
def test_similar_radio_track(factories):
@ -379,7 +379,7 @@ def test_similar_radio_track(factories):
expected_next = factories["music.Track"]()
factories["history.Listening"](track=expected_next, user=l1.user)
assert radio.pick(filter_playable=False) == expected_next
assert radio.pick_v1(filter_playable=False) == expected_next
def test_session_radio_get_queryset_ignore_filtered_track_artist(
@ -420,7 +420,7 @@ def test_get_choices_for_custom_radio_exclude_artist(factories):
{"type": "artist", "ids": [excluded_artist.pk], "not": True},
]
)
choices = session.radio.get_choices(filter_playable=False)
choices = session.radio.get_choices_v1(filter_playable=False)
expected = [u.track.pk for u in included_uploads]
assert list(choices.values_list("id", flat=True)) == expected
@ -438,7 +438,7 @@ def test_get_choices_for_custom_radio_exclude_tag(factories):
{"type": "tag", "names": ["rock"], "not": True},
]
)
choices = session.radio.get_choices(filter_playable=False)
choices = session.radio.get_choices_v1(filter_playable=False)
expected = [u.track.pk for u in included_uploads]
assert list(choices.values_list("id", flat=True)) == expected