Cache radio queryset. New api endpoint for radio tracks : api/v2/radios/sessions/$sessionid/tracks?count=$count
This commit is contained in:
parent
04acd056e6
commit
4ad806b8e9
|
@ -10,6 +10,10 @@ v2_patterns += [
|
||||||
r"^instance/",
|
r"^instance/",
|
||||||
include(("funkwhale_api.instance.urls", "instance"), namespace="instance"),
|
include(("funkwhale_api.instance.urls", "instance"), namespace="instance"),
|
||||||
),
|
),
|
||||||
|
url(
|
||||||
|
r"^radios/",
|
||||||
|
include(("funkwhale_api.radios.urls_v2", "radios"), namespace="radios"),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
urlpatterns = [url("", include((v2_patterns, "v2"), namespace="v2"))]
|
urlpatterns = [url("", include((v2_patterns, "v2"), namespace="v2"))]
|
||||||
|
|
|
@ -54,10 +54,6 @@ class RadioSession(models.Model):
|
||||||
CONFIG_VERSION = 0
|
CONFIG_VERSION = 0
|
||||||
config = JSONField(encoder=DjangoJSONEncoder, blank=True, null=True)
|
config = JSONField(encoder=DjangoJSONEncoder, blank=True, null=True)
|
||||||
|
|
||||||
def save(self, **kwargs):
|
|
||||||
self.radio.clean(self)
|
|
||||||
super().save(**kwargs)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next_position(self):
|
def next_position(self):
|
||||||
next_position = 1
|
next_position = 1
|
||||||
|
@ -68,16 +64,24 @@ class RadioSession(models.Model):
|
||||||
|
|
||||||
return next_position
|
return next_position
|
||||||
|
|
||||||
def add(self, track):
|
def add(self, tracks):
|
||||||
new_session_track = RadioSessionTrack.objects.create(
|
next_position = self.next_position
|
||||||
track=track, session=self, position=self.next_position
|
radio_session_tracks = []
|
||||||
)
|
for i, track in enumerate(tracks):
|
||||||
|
radio_session_track = RadioSessionTrack(
|
||||||
|
track=track, session=self, position=next_position + i
|
||||||
|
)
|
||||||
|
radio_session_tracks.append(radio_session_track)
|
||||||
|
|
||||||
return new_session_track
|
new_session_tracks = RadioSessionTrack.objects.bulk_create(radio_session_tracks)
|
||||||
|
|
||||||
@property
|
return new_session_tracks
|
||||||
def radio(self):
|
|
||||||
from .registries import registry
|
def radio(self, api_version):
|
||||||
|
if api_version == 2:
|
||||||
|
from .registries_v2 import registry
|
||||||
|
else:
|
||||||
|
from .registries import registry
|
||||||
|
|
||||||
return registry[self.radio_type](session=self)
|
return registry[self.radio_type](session=self)
|
||||||
|
|
||||||
|
|
|
@ -13,10 +13,9 @@ from funkwhale_api.federation import fields as federation_fields
|
||||||
from funkwhale_api.federation import models as federation_models
|
from funkwhale_api.federation import models as federation_models
|
||||||
from funkwhale_api.moderation import filters as moderation_filters
|
from funkwhale_api.moderation import filters as moderation_filters
|
||||||
from funkwhale_api.music.models import Artist, Library, Track, Upload
|
from funkwhale_api.music.models import Artist, Library, Track, Upload
|
||||||
from funkwhale_api.radios import lb_recommendations
|
|
||||||
from funkwhale_api.tags.models import Tag
|
from funkwhale_api.tags.models import Tag
|
||||||
|
|
||||||
from . import filters, models
|
from . import filters, lb_recommendations, models
|
||||||
from .registries import registry
|
from .registries import registry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -63,11 +62,19 @@ class SessionRadio(SimpleRadio):
|
||||||
return self.session
|
return self.session
|
||||||
|
|
||||||
def get_queryset(self, **kwargs):
|
def get_queryset(self, **kwargs):
|
||||||
qs = Track.objects.all()
|
if not self.session or not self.session.user:
|
||||||
if not self.session:
|
return (
|
||||||
return qs
|
Track.objects.all()
|
||||||
if not self.session.user:
|
.with_playable_uploads(actor=None)
|
||||||
return qs
|
.select_related("artist", "album__artist", "attributed_to")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qs = (
|
||||||
|
Track.objects.all()
|
||||||
|
.with_playable_uploads(self.session.user.actor)
|
||||||
|
.select_related("artist", "album__artist", "attributed_to")
|
||||||
|
)
|
||||||
|
|
||||||
query = moderation_filters.get_filtered_content_query(
|
query = moderation_filters.get_filtered_content_query(
|
||||||
config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
|
config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
|
||||||
user=self.session.user,
|
user=self.session.user,
|
||||||
|
@ -77,6 +84,16 @@ class SessionRadio(SimpleRadio):
|
||||||
def get_queryset_kwargs(self):
|
def get_queryset_kwargs(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def filter_queryset(self, queryset):
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
def filter_from_session(self, queryset):
|
||||||
|
already_played = self.session.session_tracks.all().values_list(
|
||||||
|
"track", flat=True
|
||||||
|
)
|
||||||
|
queryset = queryset.exclude(pk__in=already_played)
|
||||||
|
return queryset
|
||||||
|
|
||||||
def get_choices(self, **kwargs):
|
def get_choices(self, **kwargs):
|
||||||
kwargs.update(self.get_queryset_kwargs())
|
kwargs.update(self.get_queryset_kwargs())
|
||||||
queryset = self.get_queryset(**kwargs)
|
queryset = self.get_queryset(**kwargs)
|
||||||
|
@ -89,16 +106,6 @@ class SessionRadio(SimpleRadio):
|
||||||
queryset = self.filter_queryset(queryset)
|
queryset = self.filter_queryset(queryset)
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
def filter_queryset(self, queryset):
|
|
||||||
return queryset
|
|
||||||
|
|
||||||
def filter_from_session(self, queryset):
|
|
||||||
already_played = self.session.session_tracks.all().values_list(
|
|
||||||
"track", flat=True
|
|
||||||
)
|
|
||||||
queryset = queryset.exclude(pk__in=already_played)
|
|
||||||
return queryset
|
|
||||||
|
|
||||||
def pick(self, **kwargs):
|
def pick(self, **kwargs):
|
||||||
return self.pick_many(quantity=1, **kwargs)[0]
|
return self.pick_many(quantity=1, **kwargs)[0]
|
||||||
|
|
||||||
|
@ -106,8 +113,7 @@ class SessionRadio(SimpleRadio):
|
||||||
choices = self.get_choices(**kwargs)
|
choices = self.get_choices(**kwargs)
|
||||||
picked_choices = super().pick_many(choices=choices, quantity=quantity)
|
picked_choices = super().pick_many(choices=choices, quantity=quantity)
|
||||||
if self.session:
|
if self.session:
|
||||||
for choice in picked_choices:
|
self.session.add(picked_choices)
|
||||||
self.session.add(choice)
|
|
||||||
return picked_choices
|
return picked_choices
|
||||||
|
|
||||||
def validate_session(self, data, **context):
|
def validate_session(self, data, **context):
|
||||||
|
@ -191,7 +197,9 @@ class CustomMultiple(SessionRadio):
|
||||||
|
|
||||||
def validate_session(self, data, **context):
|
def validate_session(self, data, **context):
|
||||||
data = super().validate_session(data, **context)
|
data = super().validate_session(data, **context)
|
||||||
if data.get("config") is None:
|
try:
|
||||||
|
data["config"] is not None
|
||||||
|
except KeyError:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
"You must provide a configuration for this radio"
|
"You must provide a configuration for this radio"
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,510 @@
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
|
from django.core.exceptions import ValidationError
|
||||||
|
from django.db import connection
|
||||||
|
from django.db.models import Q
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from funkwhale_api.federation import fields as federation_fields
|
||||||
|
from funkwhale_api.federation import models as federation_models
|
||||||
|
from funkwhale_api.moderation import filters as moderation_filters
|
||||||
|
from funkwhale_api.music.models import Artist, Library, Track, Upload
|
||||||
|
from funkwhale_api.tags.models import Tag
|
||||||
|
|
||||||
|
from . import filters, lb_recommendations, models
|
||||||
|
from .registries_v2 import registry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleRadio:
|
||||||
|
related_object_field = None
|
||||||
|
|
||||||
|
def clean(self, instance):
|
||||||
|
return
|
||||||
|
|
||||||
|
def weighted_pick(
|
||||||
|
self,
|
||||||
|
choices: List[Tuple[int, int]],
|
||||||
|
previous_choices: Optional[List[int]] = None,
|
||||||
|
) -> int:
|
||||||
|
total = sum(weight for c, weight in choices)
|
||||||
|
r = random.uniform(0, total)
|
||||||
|
upto = 0
|
||||||
|
for choice, weight in choices:
|
||||||
|
if upto + weight >= r:
|
||||||
|
return choice
|
||||||
|
upto += weight
|
||||||
|
|
||||||
|
|
||||||
|
class SessionRadio(SimpleRadio):
|
||||||
|
def __init__(self, session=None):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
def start_session(self, user, **kwargs):
|
||||||
|
self.session = models.RadioSession.objects.create(
|
||||||
|
user=user, radio_type=self.radio_type, **kwargs
|
||||||
|
)
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
actor = None
|
||||||
|
try:
|
||||||
|
actor = self.session.user.actor
|
||||||
|
except KeyError:
|
||||||
|
pass # Maybe logging would be helpful
|
||||||
|
|
||||||
|
qs = (
|
||||||
|
Track.objects.all()
|
||||||
|
.with_playable_uploads(actor=actor)
|
||||||
|
.select_related("artist", "album__artist", "attributed_to")
|
||||||
|
)
|
||||||
|
|
||||||
|
query = moderation_filters.get_filtered_content_query(
|
||||||
|
config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
|
||||||
|
user=self.session.user,
|
||||||
|
)
|
||||||
|
return qs.exclude(query)
|
||||||
|
|
||||||
|
def get_queryset_kwargs(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def filter_queryset(self, queryset):
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
def filter_from_session(self, queryset):
|
||||||
|
already_played = self.session.session_tracks.all().values_list(
|
||||||
|
"track", flat=True
|
||||||
|
)
|
||||||
|
queryset = queryset.exclude(pk__in=already_played)
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
def cache_batch_radio_track(self, **kwargs):
|
||||||
|
BATCH_SIZE = 100
|
||||||
|
# get cached RadioTracks if any
|
||||||
|
try:
|
||||||
|
cached_evaluated_radio_tracks = pickle.loads(
|
||||||
|
cache.get(f"radiotracks{self.session.id}")
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
cached_evaluated_radio_tracks = None
|
||||||
|
|
||||||
|
# get the queryset and apply filters
|
||||||
|
kwargs.update(self.get_queryset_kwargs())
|
||||||
|
queryset = self.get_queryset(**kwargs)
|
||||||
|
queryset = self.filter_from_session(queryset)
|
||||||
|
|
||||||
|
if kwargs["filter_playable"] is True:
|
||||||
|
queryset = queryset.playable_by(
|
||||||
|
self.session.user.actor if self.session.user else None
|
||||||
|
)
|
||||||
|
queryset = self.filter_queryset(queryset)
|
||||||
|
|
||||||
|
# select a random batch of the qs
|
||||||
|
sliced_queryset = queryset.order_by("?")[:BATCH_SIZE]
|
||||||
|
if len(sliced_queryset) <= 0 and not cached_evaluated_radio_tracks:
|
||||||
|
raise ValueError("No more radio candidates")
|
||||||
|
|
||||||
|
# create the radio session tracks into db in bulk
|
||||||
|
self.session.add(sliced_queryset)
|
||||||
|
|
||||||
|
# evaluate the queryset to save it in cache
|
||||||
|
radio_tracks = list(sliced_queryset)
|
||||||
|
|
||||||
|
if cached_evaluated_radio_tracks is not None:
|
||||||
|
radio_tracks.extend(cached_evaluated_radio_tracks)
|
||||||
|
logger.info(
|
||||||
|
f"Setting redis cache for radio generation with radio id {self.session.id}"
|
||||||
|
)
|
||||||
|
cache.set(f"radiotracks{self.session.id}", pickle.dumps(radio_tracks), 3600)
|
||||||
|
cache.set(f"radioqueryset{self.session.id}", sliced_queryset, 3600)
|
||||||
|
|
||||||
|
return sliced_queryset
|
||||||
|
|
||||||
|
def get_choices(self, quantity, **kwargs):
|
||||||
|
if cache.get(f"radiotracks{self.session.id}"):
|
||||||
|
cached_radio_tracks = pickle.loads(
|
||||||
|
cache.get(f"radiotracks{self.session.id}")
|
||||||
|
)
|
||||||
|
logger.info("Using redis cache for radio generation")
|
||||||
|
radio_tracks = cached_radio_tracks
|
||||||
|
if len(radio_tracks) < quantity:
|
||||||
|
logger.info(
|
||||||
|
"Not enough radio tracks in cache. Trying to generate new cache"
|
||||||
|
)
|
||||||
|
sliced_queryset = self.cache_batch_radio_track(**kwargs)
|
||||||
|
sliced_queryset = cache.get(f"radioqueryset{self.session.id}")
|
||||||
|
else:
|
||||||
|
sliced_queryset = self.cache_batch_radio_track(**kwargs)
|
||||||
|
|
||||||
|
return sliced_queryset[:quantity]
|
||||||
|
|
||||||
|
def pick_many(self, quantity, **kwargs):
|
||||||
|
if self.session:
|
||||||
|
sliced_queryset = self.get_choices(quantity=quantity, **kwargs)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"No radio session. Can't track user playback. Won't cache queryset results"
|
||||||
|
)
|
||||||
|
sliced_queryset = self.get_choices(quantity=quantity, **kwargs)
|
||||||
|
|
||||||
|
return sliced_queryset
|
||||||
|
|
||||||
|
def validate_session(self, data, **context):
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="random")
|
||||||
|
class RandomRadio(SessionRadio):
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
return qs.filter(artist__content_category="music").order_by("?")
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="random_library")
|
||||||
|
class RandomLibraryRadio(SessionRadio):
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
|
||||||
|
"id", flat=True
|
||||||
|
)
|
||||||
|
query = Q(artist__content_category="music") & Q(pk__in=tracks_ids)
|
||||||
|
return qs.filter(query).order_by("?")
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="favorites")
|
||||||
|
class FavoritesRadio(SessionRadio):
|
||||||
|
def get_queryset_kwargs(self):
|
||||||
|
kwargs = super().get_queryset_kwargs()
|
||||||
|
if self.session:
|
||||||
|
kwargs["user"] = self.session.user
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
track_ids = kwargs["user"].track_favorites.all().values_list("track", flat=True)
|
||||||
|
return qs.filter(pk__in=track_ids, artist__content_category="music")
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="custom")
|
||||||
|
class CustomRadio(SessionRadio):
|
||||||
|
def get_queryset_kwargs(self):
|
||||||
|
kwargs = super().get_queryset_kwargs()
|
||||||
|
kwargs["user"] = self.session.user
|
||||||
|
kwargs["custom_radio"] = self.session.custom_radio
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
return filters.run(kwargs["custom_radio"].config, candidates=qs)
|
||||||
|
|
||||||
|
def validate_session(self, data, **context):
|
||||||
|
data = super().validate_session(data, **context)
|
||||||
|
try:
|
||||||
|
user = data["user"]
|
||||||
|
except KeyError:
|
||||||
|
user = context.get("user")
|
||||||
|
try:
|
||||||
|
assert data["custom_radio"].user == user or data["custom_radio"].is_public
|
||||||
|
except KeyError:
|
||||||
|
raise serializers.ValidationError("You must provide a custom radio")
|
||||||
|
except AssertionError:
|
||||||
|
raise serializers.ValidationError("You don't have access to this radio")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="custom_multiple")
|
||||||
|
class CustomMultiple(SessionRadio):
|
||||||
|
"""
|
||||||
|
Receive a vuejs generated config and use it to launch a radio session
|
||||||
|
"""
|
||||||
|
|
||||||
|
config = serializers.JSONField(required=True)
|
||||||
|
|
||||||
|
def get_config(self, data):
|
||||||
|
return data["config"]
|
||||||
|
|
||||||
|
def get_queryset_kwargs(self):
|
||||||
|
kwargs = super().get_queryset_kwargs()
|
||||||
|
kwargs["config"] = self.session.config
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def validate_session(self, data, **context):
|
||||||
|
data = super().validate_session(data, **context)
|
||||||
|
try:
|
||||||
|
data["config"] is not None
|
||||||
|
except KeyError:
|
||||||
|
raise serializers.ValidationError(
|
||||||
|
"You must provide a configuration for this radio"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
return filters.run([kwargs["config"]], candidates=qs)
|
||||||
|
|
||||||
|
|
||||||
|
class RelatedObjectRadio(SessionRadio):
|
||||||
|
"""Abstract radio related to an object (tag, artist, user...)"""
|
||||||
|
|
||||||
|
related_object_field = serializers.IntegerField(required=True)
|
||||||
|
|
||||||
|
def clean(self, instance):
|
||||||
|
super().clean(instance)
|
||||||
|
if not instance.related_object:
|
||||||
|
raise ValidationError(
|
||||||
|
"Cannot start RelatedObjectRadio without related object"
|
||||||
|
)
|
||||||
|
if not isinstance(instance.related_object, self.model):
|
||||||
|
raise ValidationError("Trying to start radio with bad related object")
|
||||||
|
|
||||||
|
def get_related_object(self, pk):
|
||||||
|
return self.model.objects.get(pk=pk)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="tag")
|
||||||
|
class TagRadio(RelatedObjectRadio):
|
||||||
|
model = Tag
|
||||||
|
related_object_field = serializers.CharField(required=True)
|
||||||
|
|
||||||
|
def get_related_object(self, name):
|
||||||
|
return self.model.objects.get(name=name)
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
query = (
|
||||||
|
Q(tagged_items__tag=self.session.related_object)
|
||||||
|
| Q(artist__tagged_items__tag=self.session.related_object)
|
||||||
|
| Q(album__tagged_items__tag=self.session.related_object)
|
||||||
|
)
|
||||||
|
return qs.filter(query)
|
||||||
|
|
||||||
|
def get_related_object_id_repr(self, obj):
|
||||||
|
return obj.name
|
||||||
|
|
||||||
|
|
||||||
|
def weighted_choice(choices):
|
||||||
|
total = sum(w for c, w in choices)
|
||||||
|
r = random.uniform(0, total)
|
||||||
|
upto = 0
|
||||||
|
for c, w in choices:
|
||||||
|
if upto + w >= r:
|
||||||
|
return c
|
||||||
|
upto += w
|
||||||
|
assert False, "Shouldn't get here"
|
||||||
|
|
||||||
|
|
||||||
|
class NextNotFound(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="similar")
|
||||||
|
class SimilarRadio(RelatedObjectRadio):
|
||||||
|
model = Track
|
||||||
|
|
||||||
|
def filter_queryset(self, queryset):
|
||||||
|
queryset = super().filter_queryset(queryset)
|
||||||
|
seeds = list(
|
||||||
|
self.session.session_tracks.all()
|
||||||
|
.values_list("track_id", flat=True)
|
||||||
|
.order_by("-id")[:3]
|
||||||
|
) + [self.session.related_object.pk]
|
||||||
|
for seed in seeds:
|
||||||
|
try:
|
||||||
|
return queryset.filter(pk=self.find_next_id(queryset, seed))
|
||||||
|
except NextNotFound:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return queryset.none()
|
||||||
|
|
||||||
|
def find_next_id(self, queryset, seed):
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
query = """
|
||||||
|
SELECT next, count(next) AS c
|
||||||
|
FROM (
|
||||||
|
SELECT
|
||||||
|
track_id,
|
||||||
|
creation_date,
|
||||||
|
LEAD(track_id) OVER (
|
||||||
|
PARTITION by user_id order by creation_date asc
|
||||||
|
) AS next
|
||||||
|
FROM history_listening
|
||||||
|
INNER JOIN users_user ON (users_user.id = user_id)
|
||||||
|
WHERE users_user.privacy_level = 'instance' OR users_user.privacy_level = 'everyone' OR user_id = %s
|
||||||
|
ORDER BY creation_date ASC
|
||||||
|
) t WHERE track_id = %s AND next != %s GROUP BY next ORDER BY c DESC;
|
||||||
|
"""
|
||||||
|
cursor.execute(query, [self.session.user_id, seed, seed])
|
||||||
|
next_candidates = list(cursor.fetchall())
|
||||||
|
|
||||||
|
if not next_candidates:
|
||||||
|
raise NextNotFound()
|
||||||
|
|
||||||
|
matching_tracks = list(
|
||||||
|
queryset.filter(pk__in=[c[0] for c in next_candidates]).values_list(
|
||||||
|
"id", flat=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
next_candidates = [n for n in next_candidates if n[0] in matching_tracks]
|
||||||
|
if not next_candidates:
|
||||||
|
raise NextNotFound()
|
||||||
|
return random.choice([c[0] for c in next_candidates])
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="artist")
|
||||||
|
class ArtistRadio(RelatedObjectRadio):
|
||||||
|
model = Artist
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
return qs.filter(artist=self.session.related_object)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="less-listened")
|
||||||
|
class LessListenedRadio(SessionRadio):
|
||||||
|
def clean(self, instance):
|
||||||
|
instance.related_object = instance.user
|
||||||
|
super().clean(instance)
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
listened = self.session.user.listenings.all().values_list("track", flat=True)
|
||||||
|
return (
|
||||||
|
qs.filter(artist__content_category="music")
|
||||||
|
.exclude(pk__in=listened)
|
||||||
|
.order_by("?")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="less-listened_library")
|
||||||
|
class LessListenedLibraryRadio(SessionRadio):
|
||||||
|
def clean(self, instance):
|
||||||
|
instance.related_object = instance.user
|
||||||
|
super().clean(instance)
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
listened = self.session.user.listenings.all().values_list("track", flat=True)
|
||||||
|
tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
|
||||||
|
"id", flat=True
|
||||||
|
)
|
||||||
|
query = Q(artist__content_category="music") & Q(pk__in=tracks_ids)
|
||||||
|
return qs.filter(query).exclude(pk__in=listened).order_by("?")
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="actor-content")
|
||||||
|
class ActorContentRadio(RelatedObjectRadio):
|
||||||
|
"""
|
||||||
|
Play content from given actor libraries
|
||||||
|
"""
|
||||||
|
|
||||||
|
model = federation_models.Actor
|
||||||
|
related_object_field = federation_fields.ActorRelatedField(required=True)
|
||||||
|
|
||||||
|
def get_related_object(self, value):
|
||||||
|
return value
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
actor_uploads = Upload.objects.filter(
|
||||||
|
library__actor=self.session.related_object,
|
||||||
|
)
|
||||||
|
return qs.filter(pk__in=actor_uploads.values("track"))
|
||||||
|
|
||||||
|
def get_related_object_id_repr(self, obj):
|
||||||
|
return obj.full_username
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="library")
|
||||||
|
class LibraryRadio(RelatedObjectRadio):
|
||||||
|
"""
|
||||||
|
Play content from a given library
|
||||||
|
"""
|
||||||
|
|
||||||
|
model = Library
|
||||||
|
related_object_field = serializers.UUIDField(required=True)
|
||||||
|
|
||||||
|
def get_related_object(self, value):
|
||||||
|
return Library.objects.get(uuid=value)
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
actor_uploads = Upload.objects.filter(
|
||||||
|
library=self.session.related_object,
|
||||||
|
)
|
||||||
|
return qs.filter(pk__in=actor_uploads.values("track"))
|
||||||
|
|
||||||
|
def get_related_object_id_repr(self, obj):
|
||||||
|
return obj.uuid
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register(name="recently-added")
|
||||||
|
class RecentlyAdded(SessionRadio):
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
date = datetime.date.today() - datetime.timedelta(days=30)
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
return qs.filter(
|
||||||
|
Q(artist__content_category="music"),
|
||||||
|
Q(creation_date__gt=date),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Use this to experiment on the custom multiple radio with troi
|
||||||
|
@registry.register(name="troi")
|
||||||
|
class Troi(SessionRadio):
|
||||||
|
"""
|
||||||
|
Receive a vuejs generated config and use it to launch a troi radio session.
|
||||||
|
The config data should follow :
|
||||||
|
{"patch": "troi_patch_name", "troi_arg1":"troi_arg_1", "troi_arg2": ...}
|
||||||
|
Validation of the config (args) is done by troi during track fetch.
|
||||||
|
Funkwhale only checks if the patch is implemented
|
||||||
|
"""
|
||||||
|
|
||||||
|
config = serializers.JSONField(required=True)
|
||||||
|
|
||||||
|
def append_lb_config(self, data):
|
||||||
|
if self.session.user.settings is None:
|
||||||
|
logger.warning(
|
||||||
|
"No lb_user_name set in user settings. Some troi patches will fail"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
elif self.session.user.settings.get("lb_user_name") is None:
|
||||||
|
logger.warning(
|
||||||
|
"No lb_user_name set in user settings. Some troi patches will fail"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data["user_name"] = self.session.user.settings["lb_user_name"]
|
||||||
|
|
||||||
|
if self.session.user.settings.get("lb_user_token") is None:
|
||||||
|
logger.warning(
|
||||||
|
"No lb_user_token set in user settings. Some troi patch will fail"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data["user_token"] = self.session.user.settings["lb_user_token"]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_queryset_kwargs(self):
|
||||||
|
kwargs = super().get_queryset_kwargs()
|
||||||
|
kwargs["config"] = self.session.config
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def validate_session(self, data, **context):
|
||||||
|
data = super().validate_session(data, **context)
|
||||||
|
if data.get("config") is None:
|
||||||
|
raise serializers.ValidationError(
|
||||||
|
"You must provide a configuration for this radio"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_queryset(self, **kwargs):
|
||||||
|
qs = super().get_queryset(**kwargs)
|
||||||
|
config = self.append_lb_config(json.loads(kwargs["config"]))
|
||||||
|
|
||||||
|
return lb_recommendations.run(config, candidates=qs)
|
|
@ -0,0 +1,10 @@
|
||||||
|
import persisting_theory
|
||||||
|
|
||||||
|
|
||||||
|
class RadioRegistry_v2(persisting_theory.Registry):
|
||||||
|
def prepare_name(self, data, name=None):
|
||||||
|
setattr(data, "radio_type", name)
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
registry = RadioRegistry_v2()
|
|
@ -40,9 +40,11 @@ class RadioSerializer(serializers.ModelSerializer):
|
||||||
|
|
||||||
|
|
||||||
class RadioSessionTrackSerializerCreate(serializers.ModelSerializer):
|
class RadioSessionTrackSerializerCreate(serializers.ModelSerializer):
|
||||||
|
count = serializers.IntegerField(required=False, allow_null=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = models.RadioSessionTrack
|
model = models.RadioSessionTrack
|
||||||
fields = ("session",)
|
fields = ("session", "count")
|
||||||
|
|
||||||
|
|
||||||
class RadioSessionTrackSerializer(serializers.ModelSerializer):
|
class RadioSessionTrackSerializer(serializers.ModelSerializer):
|
||||||
|
|
|
@ -5,7 +5,7 @@ from . import views
|
||||||
router = routers.OptionalSlashRouter()
|
router = routers.OptionalSlashRouter()
|
||||||
router.register(r"sessions", views.RadioSessionViewSet, "sessions")
|
router.register(r"sessions", views.RadioSessionViewSet, "sessions")
|
||||||
router.register(r"radios", views.RadioViewSet, "radios")
|
router.register(r"radios", views.RadioViewSet, "radios")
|
||||||
router.register(r"tracks", views.RadioSessionTrackViewSet, "tracks")
|
router.register(r"tracks", views.V1_RadioSessionTrackViewSet, "tracks")
|
||||||
|
|
||||||
|
|
||||||
urlpatterns = router.urls
|
urlpatterns = router.urls
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
from funkwhale_api.common import routers
|
||||||
|
|
||||||
|
from . import views
|
||||||
|
|
||||||
|
router = routers.OptionalSlashRouter()
|
||||||
|
|
||||||
|
router.register(r"sessions", views.V2_RadioSessionViewSet, "sessions")
|
||||||
|
|
||||||
|
|
||||||
|
urlpatterns = router.urls
|
|
@ -1,3 +1,6 @@
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from drf_spectacular.utils import extend_schema
|
from drf_spectacular.utils import extend_schema
|
||||||
from rest_framework import mixins, status, viewsets
|
from rest_framework import mixins, status, viewsets
|
||||||
|
@ -121,7 +124,7 @@ class RadioSessionViewSet(
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
|
class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
|
||||||
serializer_class = serializers.RadioSessionTrackSerializer
|
serializer_class = serializers.RadioSessionTrackSerializer
|
||||||
queryset = models.RadioSessionTrack.objects.all()
|
queryset = models.RadioSessionTrack.objects.all()
|
||||||
permission_classes = []
|
permission_classes = []
|
||||||
|
@ -133,21 +136,19 @@ class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet)
|
||||||
session = serializer.validated_data["session"]
|
session = serializer.validated_data["session"]
|
||||||
if not request.user.is_authenticated and not request.session.session_key:
|
if not request.user.is_authenticated and not request.session.session_key:
|
||||||
self.request.session.create()
|
self.request.session.create()
|
||||||
try:
|
if not request.user == session.user or (
|
||||||
assert (request.user == session.user) or (
|
not request.session.session_key == session.session_key
|
||||||
request.session.session_key == session.session_key
|
and not session.session_key
|
||||||
and session.session_key
|
):
|
||||||
)
|
|
||||||
except AssertionError:
|
|
||||||
return Response(status=status.HTTP_403_FORBIDDEN)
|
return Response(status=status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session.radio.pick()
|
session.radio(api_version=1).pick()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return Response(
|
return Response(
|
||||||
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
|
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
|
||||||
)
|
)
|
||||||
session_track = session.session_tracks.all().latest("id")
|
session_track = session.session_tracks.all().latest("id")
|
||||||
# self.perform_create(serializer)
|
|
||||||
# dirty override here, since we use a different serializer for creation and detail
|
# dirty override here, since we use a different serializer for creation and detail
|
||||||
serializer = self.serializer_class(
|
serializer = self.serializer_class(
|
||||||
instance=session_track, context=self.get_serializer_context()
|
instance=session_track, context=self.get_serializer_context()
|
||||||
|
@ -161,3 +162,99 @@ class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet)
|
||||||
if self.action == "create":
|
if self.action == "create":
|
||||||
return serializers.RadioSessionTrackSerializerCreate
|
return serializers.RadioSessionTrackSerializerCreate
|
||||||
return super().get_serializer_class(*args, **kwargs)
|
return super().get_serializer_class(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class V2_RadioSessionViewSet(
|
||||||
|
mixins.CreateModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet
|
||||||
|
):
|
||||||
|
"""Returns a list of RadioSessions"""
|
||||||
|
|
||||||
|
serializer_class = serializers.RadioSessionSerializer
|
||||||
|
queryset = models.RadioSession.objects.all()
|
||||||
|
permission_classes = []
|
||||||
|
|
||||||
|
@action(detail=True, serializer_class=serializers.RadioSessionTrackSerializerCreate)
|
||||||
|
def tracks(self, request, pk, *args, **kwargs):
|
||||||
|
data = {"session": pk}
|
||||||
|
data["count"] = (
|
||||||
|
request.query_params["count"]
|
||||||
|
if "count" in request.query_params.keys()
|
||||||
|
else 1
|
||||||
|
)
|
||||||
|
serializer = serializers.RadioSessionTrackSerializerCreate(data=data)
|
||||||
|
serializer.is_valid(raise_exception=True)
|
||||||
|
session = serializer.validated_data["session"]
|
||||||
|
|
||||||
|
count = int(data["count"])
|
||||||
|
# this is used for test purpose.
|
||||||
|
filter_playable = (
|
||||||
|
request.query_params["filter_playable"]
|
||||||
|
if "filter_playable" in request.query_params.keys()
|
||||||
|
else True
|
||||||
|
)
|
||||||
|
if not request.user.is_authenticated and not request.session.session_key:
|
||||||
|
self.request.session.create()
|
||||||
|
|
||||||
|
if not request.user == session.user or (
|
||||||
|
not request.session.session_key == session.session_key
|
||||||
|
and not session.session_key
|
||||||
|
):
|
||||||
|
return Response(status=status.HTTP_403_FORBIDDEN)
|
||||||
|
try:
|
||||||
|
from . import radios_v2 # noqa
|
||||||
|
|
||||||
|
session.radio(api_version=2).pick_many(
|
||||||
|
count, filter_playable=filter_playable
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return Response(
|
||||||
|
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
# dirty override here, since we use a different serializer for creation and detail
|
||||||
|
evaluated_radio_tracks = pickle.loads(cache.get(f"radiotracks{session.id}"))
|
||||||
|
batch = evaluated_radio_tracks[:count]
|
||||||
|
serializer = TrackSerializer(
|
||||||
|
data=batch,
|
||||||
|
many="true",
|
||||||
|
)
|
||||||
|
serializer.is_valid()
|
||||||
|
|
||||||
|
# delete the tracks we sent from the cache
|
||||||
|
new_cached_radiotracks = evaluated_radio_tracks[count:]
|
||||||
|
cache.set(f"radiotracks{session.id}", pickle.dumps(new_cached_radiotracks))
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
serializer.data,
|
||||||
|
status=status.HTTP_201_CREATED,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
queryset = super().get_queryset()
|
||||||
|
if self.request.user.is_authenticated:
|
||||||
|
return queryset.filter(
|
||||||
|
Q(user=self.request.user)
|
||||||
|
| Q(session_key=self.request.session.session_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
return queryset.filter(session_key=self.request.session.session_key).exclude(
|
||||||
|
session_key=None
|
||||||
|
)
|
||||||
|
|
||||||
|
def perform_create(self, serializer):
|
||||||
|
if (
|
||||||
|
not self.request.user.is_authenticated
|
||||||
|
and not self.request.session.session_key
|
||||||
|
):
|
||||||
|
self.request.session.create()
|
||||||
|
return serializer.save(
|
||||||
|
user=self.request.user if self.request.user.is_authenticated else None,
|
||||||
|
session_key=self.request.session.session_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_serializer_context(self):
|
||||||
|
context = super().get_serializer_context()
|
||||||
|
context["user"] = (
|
||||||
|
self.request.user if self.request.user.is_authenticated else None
|
||||||
|
)
|
||||||
|
return context
|
||||||
|
|
|
@ -2,8 +2,8 @@ import json
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from django.core.exceptions import ValidationError
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
from rest_framework.exceptions import ValidationError
|
||||||
|
|
||||||
from funkwhale_api.favorites.models import TrackFavorite
|
from funkwhale_api.favorites.models import TrackFavorite
|
||||||
from funkwhale_api.radios import models, radios, serializers
|
from funkwhale_api.radios import models, radios, serializers
|
||||||
|
@ -98,7 +98,7 @@ def test_can_get_choices_for_custom_radio(factories):
|
||||||
session = factories["radios.CustomRadioSession"](
|
session = factories["radios.CustomRadioSession"](
|
||||||
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
|
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
|
||||||
)
|
)
|
||||||
choices = session.radio.get_choices(filter_playable=False)
|
choices = session.radio(api_version=1).get_choices(filter_playable=False)
|
||||||
|
|
||||||
expected = [t.pk for t in tracks]
|
expected = [t.pk for t in tracks]
|
||||||
assert list(choices.values_list("id", flat=True)) == expected
|
assert list(choices.values_list("id", flat=True)) == expected
|
||||||
|
@ -191,16 +191,17 @@ def test_can_get_track_for_session_from_api(factories, logged_in_api_client):
|
||||||
|
|
||||||
|
|
||||||
def test_related_object_radio_validate_related_object(factories):
|
def test_related_object_radio_validate_related_object(factories):
|
||||||
user = factories["users.User"]()
|
|
||||||
# cannot start without related object
|
# cannot start without related object
|
||||||
radio = radios.ArtistRadio()
|
radio = {"radio_type": "tag"}
|
||||||
|
serializer = serializers.RadioSessionSerializer()
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
radio.start_session(user)
|
serializer.validate(data=radio)
|
||||||
|
|
||||||
# cannot start with bad related object type
|
# cannot start with bad related object type
|
||||||
radio = radios.ArtistRadio()
|
radio = {"radio_type": "tag", "related_object": "whatever"}
|
||||||
|
serializer = serializers.RadioSessionSerializer()
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
radio.start_session(user, related_object=user)
|
serializer.validate(data=radio)
|
||||||
|
|
||||||
|
|
||||||
def test_can_start_artist_radio(factories):
|
def test_can_start_artist_radio(factories):
|
||||||
|
@ -391,7 +392,7 @@ def test_get_choices_for_custom_radio_exclude_artist(factories):
|
||||||
{"type": "artist", "ids": [excluded_artist.pk], "not": True},
|
{"type": "artist", "ids": [excluded_artist.pk], "not": True},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
choices = session.radio.get_choices(filter_playable=False)
|
choices = session.radio(api_version=1).get_choices(filter_playable=False)
|
||||||
|
|
||||||
expected = [u.track.pk for u in included_uploads]
|
expected = [u.track.pk for u in included_uploads]
|
||||||
assert list(choices.values_list("id", flat=True)) == expected
|
assert list(choices.values_list("id", flat=True)) == expected
|
||||||
|
@ -409,7 +410,7 @@ def test_get_choices_for_custom_radio_exclude_tag(factories):
|
||||||
{"type": "tag", "names": ["rock"], "not": True},
|
{"type": "tag", "names": ["rock"], "not": True},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
choices = session.radio.get_choices(filter_playable=False)
|
choices = session.radio(api_version=1).get_choices(filter_playable=False)
|
||||||
|
|
||||||
expected = [u.track.pk for u in included_uploads]
|
expected = [u.track.pk for u in included_uploads]
|
||||||
assert list(choices.values_list("id", flat=True)) == expected
|
assert list(choices.values_list("id", flat=True)) == expected
|
||||||
|
@ -429,28 +430,3 @@ def test_can_start_custom_multiple_radio_from_api(api_client, factories):
|
||||||
format="json",
|
format="json",
|
||||||
)
|
)
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|
||||||
|
|
||||||
def test_can_start_periodic_jams_troi_radio_from_api(api_client, factories):
|
|
||||||
factories["music.Track"].create_batch(5)
|
|
||||||
url = reverse("api:v1:radios:sessions-list")
|
|
||||||
config = {"patch": "periodic-jams", "type": "daily-jams"}
|
|
||||||
response = api_client.post(
|
|
||||||
url,
|
|
||||||
{"radio_type": "troi", "config": config},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
assert response.status_code == 201
|
|
||||||
|
|
||||||
|
|
||||||
# to do : send error to api ?
|
|
||||||
def test_can_catch_troi_radio_error(api_client, factories):
|
|
||||||
factories["music.Track"].create_batch(5)
|
|
||||||
url = reverse("api:v1:radios:sessions-list")
|
|
||||||
config = {"patch": "periodic-jams", "type": "not_existing_type"}
|
|
||||||
response = api_client.post(
|
|
||||||
url,
|
|
||||||
{"radio_type": "troi", "config": config},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
assert response.status_code == 201
|
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
|
from django.urls import reverse
|
||||||
|
|
||||||
|
from funkwhale_api.favorites.models import TrackFavorite
|
||||||
|
from funkwhale_api.radios import models, radios_v2
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_get_track_for_session_from_api_v2(factories, logged_in_api_client):
|
||||||
|
actor = logged_in_api_client.user.create_actor()
|
||||||
|
track = factories["music.Upload"](
|
||||||
|
library__actor=actor, import_status="finished"
|
||||||
|
).track
|
||||||
|
url = reverse("api:v2:radios:sessions-list")
|
||||||
|
response = logged_in_api_client.post(url, {"radio_type": "random"})
|
||||||
|
session = models.RadioSession.objects.latest("id")
|
||||||
|
|
||||||
|
url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk})
|
||||||
|
response = logged_in_api_client.get(url, {"session": session.pk})
|
||||||
|
data = json.loads(response.content.decode("utf-8"))
|
||||||
|
|
||||||
|
assert data[0]["id"] == track.pk
|
||||||
|
|
||||||
|
next_track = factories["music.Upload"](
|
||||||
|
library__actor=actor, import_status="finished"
|
||||||
|
).track
|
||||||
|
response = logged_in_api_client.get(url, {"session": session.pk})
|
||||||
|
data = json.loads(response.content.decode("utf-8"))
|
||||||
|
|
||||||
|
assert data[0]["id"] == next_track.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_use_radio_session_to_filter_choices_v2(factories):
|
||||||
|
factories["music.Upload"].create_batch(10)
|
||||||
|
user = factories["users.User"]()
|
||||||
|
radio = radios_v2.RandomRadio()
|
||||||
|
session = radio.start_session(user)
|
||||||
|
|
||||||
|
radio.pick_many(quantity=10, filter_playable=False)
|
||||||
|
|
||||||
|
# ensure 10 different tracks have been suggested
|
||||||
|
tracks_id = [
|
||||||
|
session_track.track.pk for session_track in session.session_tracks.all()
|
||||||
|
]
|
||||||
|
assert len(set(tracks_id)) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_session_radio_excludes_previous_picks_v2(factories, logged_in_api_client):
|
||||||
|
tracks = factories["music.Track"].create_batch(5)
|
||||||
|
url = reverse("api:v2:radios:sessions-list")
|
||||||
|
response = logged_in_api_client.post(url, {"radio_type": "random"})
|
||||||
|
session = models.RadioSession.objects.latest("id")
|
||||||
|
url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk})
|
||||||
|
|
||||||
|
previous_choices = []
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
response = logged_in_api_client.get(
|
||||||
|
url, {"session": session.pk, "filter_playable": False}
|
||||||
|
)
|
||||||
|
pick = json.loads(response.content.decode("utf-8"))
|
||||||
|
assert pick[0]["title"] not in previous_choices
|
||||||
|
assert pick[0]["title"] in [t.title for t in tracks]
|
||||||
|
previous_choices.append(pick[0]["title"])
|
||||||
|
|
||||||
|
response = logged_in_api_client.get(url, {"session": session.pk})
|
||||||
|
assert (
|
||||||
|
json.loads(response.content.decode("utf-8"))
|
||||||
|
== "Radio doesn't have more candidates"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_get_choices_for_favorites_radio_v2(factories):
|
||||||
|
files = factories["music.Upload"].create_batch(10)
|
||||||
|
tracks = [f.track for f in files]
|
||||||
|
user = factories["users.User"]()
|
||||||
|
for i in range(5):
|
||||||
|
TrackFavorite.add(track=random.choice(tracks), user=user)
|
||||||
|
|
||||||
|
radio = radios_v2.FavoritesRadio()
|
||||||
|
session = radio.start_session(user=user)
|
||||||
|
choices = session.radio(api_version=2).get_choices(
|
||||||
|
quantity=100, filter_playable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(choices) == user.track_favorites.all().count()
|
||||||
|
|
||||||
|
for favorite in user.track_favorites.all():
|
||||||
|
assert favorite.track in choices
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_get_choices_for_custom_radio_v2(factories):
|
||||||
|
artist = factories["music.Artist"]()
|
||||||
|
files = factories["music.Upload"].create_batch(5, track__artist=artist)
|
||||||
|
tracks = [f.track for f in files]
|
||||||
|
factories["music.Upload"].create_batch(5)
|
||||||
|
|
||||||
|
session = factories["radios.CustomRadioSession"](
|
||||||
|
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
|
||||||
|
)
|
||||||
|
choices = session.radio(api_version=2).get_choices(
|
||||||
|
quantity=1, filter_playable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = [t.pk for t in tracks]
|
||||||
|
for t in choices:
|
||||||
|
assert t.id in expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_cache_radio_track(factories):
|
||||||
|
uploads = factories["music.Track"].create_batch(10)
|
||||||
|
user = factories["users.User"]()
|
||||||
|
radio = radios_v2.RandomRadio()
|
||||||
|
session = radio.start_session(user)
|
||||||
|
picked = session.radio(api_version=2).pick_many(quantity=1, filter_playable=False)
|
||||||
|
assert len(picked) == 1
|
||||||
|
for t in pickle.loads(cache.get(f"radiotracks{session.id}")):
|
||||||
|
assert t in uploads
|
||||||
|
|
||||||
|
|
||||||
|
def test_regenerate_cache_if_not_enought_tracks_in_it(
|
||||||
|
factories, caplog, logged_in_api_client
|
||||||
|
):
|
||||||
|
logger = logging.getLogger("funkwhale_api.radios.radios_v2")
|
||||||
|
caplog.set_level(logging.INFO)
|
||||||
|
logger.addHandler(caplog.handler)
|
||||||
|
|
||||||
|
factories["music.Track"].create_batch(10)
|
||||||
|
factories["users.User"]()
|
||||||
|
url = reverse("api:v2:radios:sessions-list")
|
||||||
|
response = logged_in_api_client.post(url, {"radio_type": "random"})
|
||||||
|
session = models.RadioSession.objects.latest("id")
|
||||||
|
url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk})
|
||||||
|
logged_in_api_client.get(url, {"count": 9, "filter_playable": False})
|
||||||
|
response = logged_in_api_client.get(url, {"count": 10, "filter_playable": False})
|
||||||
|
pick = json.loads(response.content.decode("utf-8"))
|
||||||
|
assert (
|
||||||
|
"Not enough radio tracks in cache. Trying to generate new cache" in caplog.text
|
||||||
|
)
|
||||||
|
assert len(pick) == 1
|
|
@ -0,0 +1 @@
|
||||||
|
Cache radio queryset into redis. New radio track endpoint for api v2 is /api/v2/radios/sessions/{radiosessionid}/tracks (#2135)
|
|
@ -98,6 +98,8 @@ services:
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
image: typesense/typesense:0.24.0
|
image: typesense/typesense:0.24.0
|
||||||
|
networks:
|
||||||
|
- internal
|
||||||
volumes:
|
volumes:
|
||||||
- ./typesense/data:/data
|
- ./typesense/data:/data
|
||||||
command: --data-dir /data --enable-cors
|
command: --data-dir /data --enable-cors
|
||||||
|
|
Loading…
Reference in New Issue