Cache radio queryset. New api endpoint for radio tracks : api/v2/radios/sessions/$sessionid/tracks?count=$count
This commit is contained in:
parent
04acd056e6
commit
4ad806b8e9
|
@ -10,6 +10,10 @@ v2_patterns += [
|
|||
r"^instance/",
|
||||
include(("funkwhale_api.instance.urls", "instance"), namespace="instance"),
|
||||
),
|
||||
url(
|
||||
r"^radios/",
|
||||
include(("funkwhale_api.radios.urls_v2", "radios"), namespace="radios"),
|
||||
),
|
||||
]
|
||||
|
||||
urlpatterns = [url("", include((v2_patterns, "v2"), namespace="v2"))]
|
||||
|
|
|
@ -54,10 +54,6 @@ class RadioSession(models.Model):
|
|||
CONFIG_VERSION = 0
|
||||
config = JSONField(encoder=DjangoJSONEncoder, blank=True, null=True)
|
||||
|
||||
def save(self, **kwargs):
|
||||
self.radio.clean(self)
|
||||
super().save(**kwargs)
|
||||
|
||||
@property
|
||||
def next_position(self):
|
||||
next_position = 1
|
||||
|
@ -68,16 +64,24 @@ class RadioSession(models.Model):
|
|||
|
||||
return next_position
|
||||
|
||||
def add(self, track):
|
||||
new_session_track = RadioSessionTrack.objects.create(
|
||||
track=track, session=self, position=self.next_position
|
||||
)
|
||||
def add(self, tracks):
|
||||
next_position = self.next_position
|
||||
radio_session_tracks = []
|
||||
for i, track in enumerate(tracks):
|
||||
radio_session_track = RadioSessionTrack(
|
||||
track=track, session=self, position=next_position + i
|
||||
)
|
||||
radio_session_tracks.append(radio_session_track)
|
||||
|
||||
return new_session_track
|
||||
new_session_tracks = RadioSessionTrack.objects.bulk_create(radio_session_tracks)
|
||||
|
||||
@property
|
||||
def radio(self):
|
||||
from .registries import registry
|
||||
return new_session_tracks
|
||||
|
||||
def radio(self, api_version):
|
||||
if api_version == 2:
|
||||
from .registries_v2 import registry
|
||||
else:
|
||||
from .registries import registry
|
||||
|
||||
return registry[self.radio_type](session=self)
|
||||
|
||||
|
|
|
@ -13,10 +13,9 @@ from funkwhale_api.federation import fields as federation_fields
|
|||
from funkwhale_api.federation import models as federation_models
|
||||
from funkwhale_api.moderation import filters as moderation_filters
|
||||
from funkwhale_api.music.models import Artist, Library, Track, Upload
|
||||
from funkwhale_api.radios import lb_recommendations
|
||||
from funkwhale_api.tags.models import Tag
|
||||
|
||||
from . import filters, models
|
||||
from . import filters, lb_recommendations, models
|
||||
from .registries import registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -63,11 +62,19 @@ class SessionRadio(SimpleRadio):
|
|||
return self.session
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = Track.objects.all()
|
||||
if not self.session:
|
||||
return qs
|
||||
if not self.session.user:
|
||||
return qs
|
||||
if not self.session or not self.session.user:
|
||||
return (
|
||||
Track.objects.all()
|
||||
.with_playable_uploads(actor=None)
|
||||
.select_related("artist", "album__artist", "attributed_to")
|
||||
)
|
||||
else:
|
||||
qs = (
|
||||
Track.objects.all()
|
||||
.with_playable_uploads(self.session.user.actor)
|
||||
.select_related("artist", "album__artist", "attributed_to")
|
||||
)
|
||||
|
||||
query = moderation_filters.get_filtered_content_query(
|
||||
config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
|
||||
user=self.session.user,
|
||||
|
@ -77,6 +84,16 @@ class SessionRadio(SimpleRadio):
|
|||
def get_queryset_kwargs(self):
|
||||
return {}
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
return queryset
|
||||
|
||||
def filter_from_session(self, queryset):
|
||||
already_played = self.session.session_tracks.all().values_list(
|
||||
"track", flat=True
|
||||
)
|
||||
queryset = queryset.exclude(pk__in=already_played)
|
||||
return queryset
|
||||
|
||||
def get_choices(self, **kwargs):
|
||||
kwargs.update(self.get_queryset_kwargs())
|
||||
queryset = self.get_queryset(**kwargs)
|
||||
|
@ -89,16 +106,6 @@ class SessionRadio(SimpleRadio):
|
|||
queryset = self.filter_queryset(queryset)
|
||||
return queryset
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
return queryset
|
||||
|
||||
def filter_from_session(self, queryset):
|
||||
already_played = self.session.session_tracks.all().values_list(
|
||||
"track", flat=True
|
||||
)
|
||||
queryset = queryset.exclude(pk__in=already_played)
|
||||
return queryset
|
||||
|
||||
def pick(self, **kwargs):
|
||||
return self.pick_many(quantity=1, **kwargs)[0]
|
||||
|
||||
|
@ -106,8 +113,7 @@ class SessionRadio(SimpleRadio):
|
|||
choices = self.get_choices(**kwargs)
|
||||
picked_choices = super().pick_many(choices=choices, quantity=quantity)
|
||||
if self.session:
|
||||
for choice in picked_choices:
|
||||
self.session.add(choice)
|
||||
self.session.add(picked_choices)
|
||||
return picked_choices
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
|
@ -191,7 +197,9 @@ class CustomMultiple(SessionRadio):
|
|||
|
||||
def validate_session(self, data, **context):
|
||||
data = super().validate_session(data, **context)
|
||||
if data.get("config") is None:
|
||||
try:
|
||||
data["config"] is not None
|
||||
except KeyError:
|
||||
raise serializers.ValidationError(
|
||||
"You must provide a configuration for this radio"
|
||||
)
|
||||
|
|
|
@ -0,0 +1,510 @@
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import connection
|
||||
from django.db.models import Q
|
||||
from rest_framework import serializers
|
||||
|
||||
from funkwhale_api.federation import fields as federation_fields
|
||||
from funkwhale_api.federation import models as federation_models
|
||||
from funkwhale_api.moderation import filters as moderation_filters
|
||||
from funkwhale_api.music.models import Artist, Library, Track, Upload
|
||||
from funkwhale_api.tags.models import Tag
|
||||
|
||||
from . import filters, lb_recommendations, models
|
||||
from .registries_v2 import registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleRadio:
|
||||
related_object_field = None
|
||||
|
||||
def clean(self, instance):
|
||||
return
|
||||
|
||||
def weighted_pick(
|
||||
self,
|
||||
choices: List[Tuple[int, int]],
|
||||
previous_choices: Optional[List[int]] = None,
|
||||
) -> int:
|
||||
total = sum(weight for c, weight in choices)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for choice, weight in choices:
|
||||
if upto + weight >= r:
|
||||
return choice
|
||||
upto += weight
|
||||
|
||||
|
||||
class SessionRadio(SimpleRadio):
|
||||
def __init__(self, session=None):
|
||||
self.session = session
|
||||
|
||||
def start_session(self, user, **kwargs):
|
||||
self.session = models.RadioSession.objects.create(
|
||||
user=user, radio_type=self.radio_type, **kwargs
|
||||
)
|
||||
return self.session
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
actor = None
|
||||
try:
|
||||
actor = self.session.user.actor
|
||||
except KeyError:
|
||||
pass # Maybe logging would be helpful
|
||||
|
||||
qs = (
|
||||
Track.objects.all()
|
||||
.with_playable_uploads(actor=actor)
|
||||
.select_related("artist", "album__artist", "attributed_to")
|
||||
)
|
||||
|
||||
query = moderation_filters.get_filtered_content_query(
|
||||
config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
|
||||
user=self.session.user,
|
||||
)
|
||||
return qs.exclude(query)
|
||||
|
||||
def get_queryset_kwargs(self):
|
||||
return {}
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
return queryset
|
||||
|
||||
def filter_from_session(self, queryset):
|
||||
already_played = self.session.session_tracks.all().values_list(
|
||||
"track", flat=True
|
||||
)
|
||||
queryset = queryset.exclude(pk__in=already_played)
|
||||
return queryset
|
||||
|
||||
def cache_batch_radio_track(self, **kwargs):
|
||||
BATCH_SIZE = 100
|
||||
# get cached RadioTracks if any
|
||||
try:
|
||||
cached_evaluated_radio_tracks = pickle.loads(
|
||||
cache.get(f"radiotracks{self.session.id}")
|
||||
)
|
||||
except TypeError:
|
||||
cached_evaluated_radio_tracks = None
|
||||
|
||||
# get the queryset and apply filters
|
||||
kwargs.update(self.get_queryset_kwargs())
|
||||
queryset = self.get_queryset(**kwargs)
|
||||
queryset = self.filter_from_session(queryset)
|
||||
|
||||
if kwargs["filter_playable"] is True:
|
||||
queryset = queryset.playable_by(
|
||||
self.session.user.actor if self.session.user else None
|
||||
)
|
||||
queryset = self.filter_queryset(queryset)
|
||||
|
||||
# select a random batch of the qs
|
||||
sliced_queryset = queryset.order_by("?")[:BATCH_SIZE]
|
||||
if len(sliced_queryset) <= 0 and not cached_evaluated_radio_tracks:
|
||||
raise ValueError("No more radio candidates")
|
||||
|
||||
# create the radio session tracks into db in bulk
|
||||
self.session.add(sliced_queryset)
|
||||
|
||||
# evaluate the queryset to save it in cache
|
||||
radio_tracks = list(sliced_queryset)
|
||||
|
||||
if cached_evaluated_radio_tracks is not None:
|
||||
radio_tracks.extend(cached_evaluated_radio_tracks)
|
||||
logger.info(
|
||||
f"Setting redis cache for radio generation with radio id {self.session.id}"
|
||||
)
|
||||
cache.set(f"radiotracks{self.session.id}", pickle.dumps(radio_tracks), 3600)
|
||||
cache.set(f"radioqueryset{self.session.id}", sliced_queryset, 3600)
|
||||
|
||||
return sliced_queryset
|
||||
|
||||
def get_choices(self, quantity, **kwargs):
|
||||
if cache.get(f"radiotracks{self.session.id}"):
|
||||
cached_radio_tracks = pickle.loads(
|
||||
cache.get(f"radiotracks{self.session.id}")
|
||||
)
|
||||
logger.info("Using redis cache for radio generation")
|
||||
radio_tracks = cached_radio_tracks
|
||||
if len(radio_tracks) < quantity:
|
||||
logger.info(
|
||||
"Not enough radio tracks in cache. Trying to generate new cache"
|
||||
)
|
||||
sliced_queryset = self.cache_batch_radio_track(**kwargs)
|
||||
sliced_queryset = cache.get(f"radioqueryset{self.session.id}")
|
||||
else:
|
||||
sliced_queryset = self.cache_batch_radio_track(**kwargs)
|
||||
|
||||
return sliced_queryset[:quantity]
|
||||
|
||||
def pick_many(self, quantity, **kwargs):
|
||||
if self.session:
|
||||
sliced_queryset = self.get_choices(quantity=quantity, **kwargs)
|
||||
else:
|
||||
logger.info(
|
||||
"No radio session. Can't track user playback. Won't cache queryset results"
|
||||
)
|
||||
sliced_queryset = self.get_choices(quantity=quantity, **kwargs)
|
||||
|
||||
return sliced_queryset
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
return data
|
||||
|
||||
|
||||
@registry.register(name="random")
|
||||
class RandomRadio(SessionRadio):
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return qs.filter(artist__content_category="music").order_by("?")
|
||||
|
||||
|
||||
@registry.register(name="random_library")
|
||||
class RandomLibraryRadio(SessionRadio):
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
query = Q(artist__content_category="music") & Q(pk__in=tracks_ids)
|
||||
return qs.filter(query).order_by("?")
|
||||
|
||||
|
||||
@registry.register(name="favorites")
|
||||
class FavoritesRadio(SessionRadio):
|
||||
def get_queryset_kwargs(self):
|
||||
kwargs = super().get_queryset_kwargs()
|
||||
if self.session:
|
||||
kwargs["user"] = self.session.user
|
||||
return kwargs
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
track_ids = kwargs["user"].track_favorites.all().values_list("track", flat=True)
|
||||
return qs.filter(pk__in=track_ids, artist__content_category="music")
|
||||
|
||||
|
||||
@registry.register(name="custom")
|
||||
class CustomRadio(SessionRadio):
|
||||
def get_queryset_kwargs(self):
|
||||
kwargs = super().get_queryset_kwargs()
|
||||
kwargs["user"] = self.session.user
|
||||
kwargs["custom_radio"] = self.session.custom_radio
|
||||
return kwargs
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return filters.run(kwargs["custom_radio"].config, candidates=qs)
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
data = super().validate_session(data, **context)
|
||||
try:
|
||||
user = data["user"]
|
||||
except KeyError:
|
||||
user = context.get("user")
|
||||
try:
|
||||
assert data["custom_radio"].user == user or data["custom_radio"].is_public
|
||||
except KeyError:
|
||||
raise serializers.ValidationError("You must provide a custom radio")
|
||||
except AssertionError:
|
||||
raise serializers.ValidationError("You don't have access to this radio")
|
||||
return data
|
||||
|
||||
|
||||
@registry.register(name="custom_multiple")
|
||||
class CustomMultiple(SessionRadio):
|
||||
"""
|
||||
Receive a vuejs generated config and use it to launch a radio session
|
||||
"""
|
||||
|
||||
config = serializers.JSONField(required=True)
|
||||
|
||||
def get_config(self, data):
|
||||
return data["config"]
|
||||
|
||||
def get_queryset_kwargs(self):
|
||||
kwargs = super().get_queryset_kwargs()
|
||||
kwargs["config"] = self.session.config
|
||||
return kwargs
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
data = super().validate_session(data, **context)
|
||||
try:
|
||||
data["config"] is not None
|
||||
except KeyError:
|
||||
raise serializers.ValidationError(
|
||||
"You must provide a configuration for this radio"
|
||||
)
|
||||
return data
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return filters.run([kwargs["config"]], candidates=qs)
|
||||
|
||||
|
||||
class RelatedObjectRadio(SessionRadio):
|
||||
"""Abstract radio related to an object (tag, artist, user...)"""
|
||||
|
||||
related_object_field = serializers.IntegerField(required=True)
|
||||
|
||||
def clean(self, instance):
|
||||
super().clean(instance)
|
||||
if not instance.related_object:
|
||||
raise ValidationError(
|
||||
"Cannot start RelatedObjectRadio without related object"
|
||||
)
|
||||
if not isinstance(instance.related_object, self.model):
|
||||
raise ValidationError("Trying to start radio with bad related object")
|
||||
|
||||
def get_related_object(self, pk):
|
||||
return self.model.objects.get(pk=pk)
|
||||
|
||||
|
||||
@registry.register(name="tag")
|
||||
class TagRadio(RelatedObjectRadio):
|
||||
model = Tag
|
||||
related_object_field = serializers.CharField(required=True)
|
||||
|
||||
def get_related_object(self, name):
|
||||
return self.model.objects.get(name=name)
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
query = (
|
||||
Q(tagged_items__tag=self.session.related_object)
|
||||
| Q(artist__tagged_items__tag=self.session.related_object)
|
||||
| Q(album__tagged_items__tag=self.session.related_object)
|
||||
)
|
||||
return qs.filter(query)
|
||||
|
||||
def get_related_object_id_repr(self, obj):
|
||||
return obj.name
|
||||
|
||||
|
||||
def weighted_choice(choices):
|
||||
total = sum(w for c, w in choices)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for c, w in choices:
|
||||
if upto + w >= r:
|
||||
return c
|
||||
upto += w
|
||||
assert False, "Shouldn't get here"
|
||||
|
||||
|
||||
class NextNotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register(name="similar")
|
||||
class SimilarRadio(RelatedObjectRadio):
|
||||
model = Track
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
queryset = super().filter_queryset(queryset)
|
||||
seeds = list(
|
||||
self.session.session_tracks.all()
|
||||
.values_list("track_id", flat=True)
|
||||
.order_by("-id")[:3]
|
||||
) + [self.session.related_object.pk]
|
||||
for seed in seeds:
|
||||
try:
|
||||
return queryset.filter(pk=self.find_next_id(queryset, seed))
|
||||
except NextNotFound:
|
||||
continue
|
||||
|
||||
return queryset.none()
|
||||
|
||||
def find_next_id(self, queryset, seed):
|
||||
with connection.cursor() as cursor:
|
||||
query = """
|
||||
SELECT next, count(next) AS c
|
||||
FROM (
|
||||
SELECT
|
||||
track_id,
|
||||
creation_date,
|
||||
LEAD(track_id) OVER (
|
||||
PARTITION by user_id order by creation_date asc
|
||||
) AS next
|
||||
FROM history_listening
|
||||
INNER JOIN users_user ON (users_user.id = user_id)
|
||||
WHERE users_user.privacy_level = 'instance' OR users_user.privacy_level = 'everyone' OR user_id = %s
|
||||
ORDER BY creation_date ASC
|
||||
) t WHERE track_id = %s AND next != %s GROUP BY next ORDER BY c DESC;
|
||||
"""
|
||||
cursor.execute(query, [self.session.user_id, seed, seed])
|
||||
next_candidates = list(cursor.fetchall())
|
||||
|
||||
if not next_candidates:
|
||||
raise NextNotFound()
|
||||
|
||||
matching_tracks = list(
|
||||
queryset.filter(pk__in=[c[0] for c in next_candidates]).values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
)
|
||||
next_candidates = [n for n in next_candidates if n[0] in matching_tracks]
|
||||
if not next_candidates:
|
||||
raise NextNotFound()
|
||||
return random.choice([c[0] for c in next_candidates])
|
||||
|
||||
|
||||
@registry.register(name="artist")
|
||||
class ArtistRadio(RelatedObjectRadio):
|
||||
model = Artist
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return qs.filter(artist=self.session.related_object)
|
||||
|
||||
|
||||
@registry.register(name="less-listened")
|
||||
class LessListenedRadio(SessionRadio):
|
||||
def clean(self, instance):
|
||||
instance.related_object = instance.user
|
||||
super().clean(instance)
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
listened = self.session.user.listenings.all().values_list("track", flat=True)
|
||||
return (
|
||||
qs.filter(artist__content_category="music")
|
||||
.exclude(pk__in=listened)
|
||||
.order_by("?")
|
||||
)
|
||||
|
||||
|
||||
@registry.register(name="less-listened_library")
|
||||
class LessListenedLibraryRadio(SessionRadio):
|
||||
def clean(self, instance):
|
||||
instance.related_object = instance.user
|
||||
super().clean(instance)
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
listened = self.session.user.listenings.all().values_list("track", flat=True)
|
||||
tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
query = Q(artist__content_category="music") & Q(pk__in=tracks_ids)
|
||||
return qs.filter(query).exclude(pk__in=listened).order_by("?")
|
||||
|
||||
|
||||
@registry.register(name="actor-content")
|
||||
class ActorContentRadio(RelatedObjectRadio):
|
||||
"""
|
||||
Play content from given actor libraries
|
||||
"""
|
||||
|
||||
model = federation_models.Actor
|
||||
related_object_field = federation_fields.ActorRelatedField(required=True)
|
||||
|
||||
def get_related_object(self, value):
|
||||
return value
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
actor_uploads = Upload.objects.filter(
|
||||
library__actor=self.session.related_object,
|
||||
)
|
||||
return qs.filter(pk__in=actor_uploads.values("track"))
|
||||
|
||||
def get_related_object_id_repr(self, obj):
|
||||
return obj.full_username
|
||||
|
||||
|
||||
@registry.register(name="library")
|
||||
class LibraryRadio(RelatedObjectRadio):
|
||||
"""
|
||||
Play content from a given library
|
||||
"""
|
||||
|
||||
model = Library
|
||||
related_object_field = serializers.UUIDField(required=True)
|
||||
|
||||
def get_related_object(self, value):
|
||||
return Library.objects.get(uuid=value)
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
actor_uploads = Upload.objects.filter(
|
||||
library=self.session.related_object,
|
||||
)
|
||||
return qs.filter(pk__in=actor_uploads.values("track"))
|
||||
|
||||
def get_related_object_id_repr(self, obj):
|
||||
return obj.uuid
|
||||
|
||||
|
||||
@registry.register(name="recently-added")
|
||||
class RecentlyAdded(SessionRadio):
|
||||
def get_queryset(self, **kwargs):
|
||||
date = datetime.date.today() - datetime.timedelta(days=30)
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return qs.filter(
|
||||
Q(artist__content_category="music"),
|
||||
Q(creation_date__gt=date),
|
||||
)
|
||||
|
||||
|
||||
# Use this to experiment on the custom multiple radio with troi
|
||||
@registry.register(name="troi")
|
||||
class Troi(SessionRadio):
|
||||
"""
|
||||
Receive a vuejs generated config and use it to launch a troi radio session.
|
||||
The config data should follow :
|
||||
{"patch": "troi_patch_name", "troi_arg1":"troi_arg_1", "troi_arg2": ...}
|
||||
Validation of the config (args) is done by troi during track fetch.
|
||||
Funkwhale only checks if the patch is implemented
|
||||
"""
|
||||
|
||||
config = serializers.JSONField(required=True)
|
||||
|
||||
def append_lb_config(self, data):
|
||||
if self.session.user.settings is None:
|
||||
logger.warning(
|
||||
"No lb_user_name set in user settings. Some troi patches will fail"
|
||||
)
|
||||
return data
|
||||
elif self.session.user.settings.get("lb_user_name") is None:
|
||||
logger.warning(
|
||||
"No lb_user_name set in user settings. Some troi patches will fail"
|
||||
)
|
||||
else:
|
||||
data["user_name"] = self.session.user.settings["lb_user_name"]
|
||||
|
||||
if self.session.user.settings.get("lb_user_token") is None:
|
||||
logger.warning(
|
||||
"No lb_user_token set in user settings. Some troi patch will fail"
|
||||
)
|
||||
else:
|
||||
data["user_token"] = self.session.user.settings["lb_user_token"]
|
||||
|
||||
return data
|
||||
|
||||
def get_queryset_kwargs(self):
|
||||
kwargs = super().get_queryset_kwargs()
|
||||
kwargs["config"] = self.session.config
|
||||
return kwargs
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
data = super().validate_session(data, **context)
|
||||
if data.get("config") is None:
|
||||
raise serializers.ValidationError(
|
||||
"You must provide a configuration for this radio"
|
||||
)
|
||||
return data
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
config = self.append_lb_config(json.loads(kwargs["config"]))
|
||||
|
||||
return lb_recommendations.run(config, candidates=qs)
|
|
@ -0,0 +1,10 @@
|
|||
import persisting_theory
|
||||
|
||||
|
||||
class RadioRegistry_v2(persisting_theory.Registry):
|
||||
def prepare_name(self, data, name=None):
|
||||
setattr(data, "radio_type", name)
|
||||
return name
|
||||
|
||||
|
||||
registry = RadioRegistry_v2()
|
|
@ -40,9 +40,11 @@ class RadioSerializer(serializers.ModelSerializer):
|
|||
|
||||
|
||||
class RadioSessionTrackSerializerCreate(serializers.ModelSerializer):
|
||||
count = serializers.IntegerField(required=False, allow_null=True)
|
||||
|
||||
class Meta:
|
||||
model = models.RadioSessionTrack
|
||||
fields = ("session",)
|
||||
fields = ("session", "count")
|
||||
|
||||
|
||||
class RadioSessionTrackSerializer(serializers.ModelSerializer):
|
||||
|
|
|
@ -5,7 +5,7 @@ from . import views
|
|||
router = routers.OptionalSlashRouter()
|
||||
router.register(r"sessions", views.RadioSessionViewSet, "sessions")
|
||||
router.register(r"radios", views.RadioViewSet, "radios")
|
||||
router.register(r"tracks", views.RadioSessionTrackViewSet, "tracks")
|
||||
router.register(r"tracks", views.V1_RadioSessionTrackViewSet, "tracks")
|
||||
|
||||
|
||||
urlpatterns = router.urls
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
from funkwhale_api.common import routers
|
||||
|
||||
from . import views
|
||||
|
||||
router = routers.OptionalSlashRouter()
|
||||
|
||||
router.register(r"sessions", views.V2_RadioSessionViewSet, "sessions")
|
||||
|
||||
|
||||
urlpatterns = router.urls
|
|
@ -1,3 +1,6 @@
|
|||
import pickle
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Q
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework import mixins, status, viewsets
|
||||
|
@ -121,7 +124,7 @@ class RadioSessionViewSet(
|
|||
return context
|
||||
|
||||
|
||||
class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
|
||||
class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
|
||||
serializer_class = serializers.RadioSessionTrackSerializer
|
||||
queryset = models.RadioSessionTrack.objects.all()
|
||||
permission_classes = []
|
||||
|
@ -133,21 +136,19 @@ class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet)
|
|||
session = serializer.validated_data["session"]
|
||||
if not request.user.is_authenticated and not request.session.session_key:
|
||||
self.request.session.create()
|
||||
try:
|
||||
assert (request.user == session.user) or (
|
||||
request.session.session_key == session.session_key
|
||||
and session.session_key
|
||||
)
|
||||
except AssertionError:
|
||||
if not request.user == session.user or (
|
||||
not request.session.session_key == session.session_key
|
||||
and not session.session_key
|
||||
):
|
||||
return Response(status=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
try:
|
||||
session.radio.pick()
|
||||
session.radio(api_version=1).pick()
|
||||
except ValueError:
|
||||
return Response(
|
||||
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
session_track = session.session_tracks.all().latest("id")
|
||||
# self.perform_create(serializer)
|
||||
# dirty override here, since we use a different serializer for creation and detail
|
||||
serializer = self.serializer_class(
|
||||
instance=session_track, context=self.get_serializer_context()
|
||||
|
@ -161,3 +162,99 @@ class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet)
|
|||
if self.action == "create":
|
||||
return serializers.RadioSessionTrackSerializerCreate
|
||||
return super().get_serializer_class(*args, **kwargs)
|
||||
|
||||
|
||||
class V2_RadioSessionViewSet(
|
||||
mixins.CreateModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet
|
||||
):
|
||||
"""Returns a list of RadioSessions"""
|
||||
|
||||
serializer_class = serializers.RadioSessionSerializer
|
||||
queryset = models.RadioSession.objects.all()
|
||||
permission_classes = []
|
||||
|
||||
@action(detail=True, serializer_class=serializers.RadioSessionTrackSerializerCreate)
|
||||
def tracks(self, request, pk, *args, **kwargs):
|
||||
data = {"session": pk}
|
||||
data["count"] = (
|
||||
request.query_params["count"]
|
||||
if "count" in request.query_params.keys()
|
||||
else 1
|
||||
)
|
||||
serializer = serializers.RadioSessionTrackSerializerCreate(data=data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
session = serializer.validated_data["session"]
|
||||
|
||||
count = int(data["count"])
|
||||
# this is used for test purpose.
|
||||
filter_playable = (
|
||||
request.query_params["filter_playable"]
|
||||
if "filter_playable" in request.query_params.keys()
|
||||
else True
|
||||
)
|
||||
if not request.user.is_authenticated and not request.session.session_key:
|
||||
self.request.session.create()
|
||||
|
||||
if not request.user == session.user or (
|
||||
not request.session.session_key == session.session_key
|
||||
and not session.session_key
|
||||
):
|
||||
return Response(status=status.HTTP_403_FORBIDDEN)
|
||||
try:
|
||||
from . import radios_v2 # noqa
|
||||
|
||||
session.radio(api_version=2).pick_many(
|
||||
count, filter_playable=filter_playable
|
||||
)
|
||||
except ValueError:
|
||||
return Response(
|
||||
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
# dirty override here, since we use a different serializer for creation and detail
|
||||
evaluated_radio_tracks = pickle.loads(cache.get(f"radiotracks{session.id}"))
|
||||
batch = evaluated_radio_tracks[:count]
|
||||
serializer = TrackSerializer(
|
||||
data=batch,
|
||||
many="true",
|
||||
)
|
||||
serializer.is_valid()
|
||||
|
||||
# delete the tracks we sent from the cache
|
||||
new_cached_radiotracks = evaluated_radio_tracks[count:]
|
||||
cache.set(f"radiotracks{session.id}", pickle.dumps(new_cached_radiotracks))
|
||||
|
||||
return Response(
|
||||
serializer.data,
|
||||
status=status.HTTP_201_CREATED,
|
||||
)
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset()
|
||||
if self.request.user.is_authenticated:
|
||||
return queryset.filter(
|
||||
Q(user=self.request.user)
|
||||
| Q(session_key=self.request.session.session_key)
|
||||
)
|
||||
|
||||
return queryset.filter(session_key=self.request.session.session_key).exclude(
|
||||
session_key=None
|
||||
)
|
||||
|
||||
def perform_create(self, serializer):
|
||||
if (
|
||||
not self.request.user.is_authenticated
|
||||
and not self.request.session.session_key
|
||||
):
|
||||
self.request.session.create()
|
||||
return serializer.save(
|
||||
user=self.request.user if self.request.user.is_authenticated else None,
|
||||
session_key=self.request.session.session_key,
|
||||
)
|
||||
|
||||
def get_serializer_context(self):
|
||||
context = super().get_serializer_context()
|
||||
context["user"] = (
|
||||
self.request.user if self.request.user.is_authenticated else None
|
||||
)
|
||||
return context
|
||||
|
|
|
@ -2,8 +2,8 @@ import json
|
|||
import random
|
||||
|
||||
import pytest
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.urls import reverse
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from funkwhale_api.favorites.models import TrackFavorite
|
||||
from funkwhale_api.radios import models, radios, serializers
|
||||
|
@ -98,7 +98,7 @@ def test_can_get_choices_for_custom_radio(factories):
|
|||
session = factories["radios.CustomRadioSession"](
|
||||
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
|
||||
)
|
||||
choices = session.radio.get_choices(filter_playable=False)
|
||||
choices = session.radio(api_version=1).get_choices(filter_playable=False)
|
||||
|
||||
expected = [t.pk for t in tracks]
|
||||
assert list(choices.values_list("id", flat=True)) == expected
|
||||
|
@ -191,16 +191,17 @@ def test_can_get_track_for_session_from_api(factories, logged_in_api_client):
|
|||
|
||||
|
||||
def test_related_object_radio_validate_related_object(factories):
|
||||
user = factories["users.User"]()
|
||||
# cannot start without related object
|
||||
radio = radios.ArtistRadio()
|
||||
radio = {"radio_type": "tag"}
|
||||
serializer = serializers.RadioSessionSerializer()
|
||||
with pytest.raises(ValidationError):
|
||||
radio.start_session(user)
|
||||
serializer.validate(data=radio)
|
||||
|
||||
# cannot start with bad related object type
|
||||
radio = radios.ArtistRadio()
|
||||
radio = {"radio_type": "tag", "related_object": "whatever"}
|
||||
serializer = serializers.RadioSessionSerializer()
|
||||
with pytest.raises(ValidationError):
|
||||
radio.start_session(user, related_object=user)
|
||||
serializer.validate(data=radio)
|
||||
|
||||
|
||||
def test_can_start_artist_radio(factories):
|
||||
|
@ -391,7 +392,7 @@ def test_get_choices_for_custom_radio_exclude_artist(factories):
|
|||
{"type": "artist", "ids": [excluded_artist.pk], "not": True},
|
||||
]
|
||||
)
|
||||
choices = session.radio.get_choices(filter_playable=False)
|
||||
choices = session.radio(api_version=1).get_choices(filter_playable=False)
|
||||
|
||||
expected = [u.track.pk for u in included_uploads]
|
||||
assert list(choices.values_list("id", flat=True)) == expected
|
||||
|
@ -409,7 +410,7 @@ def test_get_choices_for_custom_radio_exclude_tag(factories):
|
|||
{"type": "tag", "names": ["rock"], "not": True},
|
||||
]
|
||||
)
|
||||
choices = session.radio.get_choices(filter_playable=False)
|
||||
choices = session.radio(api_version=1).get_choices(filter_playable=False)
|
||||
|
||||
expected = [u.track.pk for u in included_uploads]
|
||||
assert list(choices.values_list("id", flat=True)) == expected
|
||||
|
@ -429,28 +430,3 @@ def test_can_start_custom_multiple_radio_from_api(api_client, factories):
|
|||
format="json",
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
|
||||
def test_can_start_periodic_jams_troi_radio_from_api(api_client, factories):
|
||||
factories["music.Track"].create_batch(5)
|
||||
url = reverse("api:v1:radios:sessions-list")
|
||||
config = {"patch": "periodic-jams", "type": "daily-jams"}
|
||||
response = api_client.post(
|
||||
url,
|
||||
{"radio_type": "troi", "config": config},
|
||||
format="json",
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
|
||||
# to do : send error to api ?
|
||||
def test_can_catch_troi_radio_error(api_client, factories):
|
||||
factories["music.Track"].create_batch(5)
|
||||
url = reverse("api:v1:radios:sessions-list")
|
||||
config = {"patch": "periodic-jams", "type": "not_existing_type"}
|
||||
response = api_client.post(
|
||||
url,
|
||||
{"radio_type": "troi", "config": config},
|
||||
format="json",
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.urls import reverse
|
||||
|
||||
from funkwhale_api.favorites.models import TrackFavorite
|
||||
from funkwhale_api.radios import models, radios_v2
|
||||
|
||||
|
||||
def test_can_get_track_for_session_from_api_v2(factories, logged_in_api_client):
|
||||
actor = logged_in_api_client.user.create_actor()
|
||||
track = factories["music.Upload"](
|
||||
library__actor=actor, import_status="finished"
|
||||
).track
|
||||
url = reverse("api:v2:radios:sessions-list")
|
||||
response = logged_in_api_client.post(url, {"radio_type": "random"})
|
||||
session = models.RadioSession.objects.latest("id")
|
||||
|
||||
url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk})
|
||||
response = logged_in_api_client.get(url, {"session": session.pk})
|
||||
data = json.loads(response.content.decode("utf-8"))
|
||||
|
||||
assert data[0]["id"] == track.pk
|
||||
|
||||
next_track = factories["music.Upload"](
|
||||
library__actor=actor, import_status="finished"
|
||||
).track
|
||||
response = logged_in_api_client.get(url, {"session": session.pk})
|
||||
data = json.loads(response.content.decode("utf-8"))
|
||||
|
||||
assert data[0]["id"] == next_track.id
|
||||
|
||||
|
||||
def test_can_use_radio_session_to_filter_choices_v2(factories):
|
||||
factories["music.Upload"].create_batch(10)
|
||||
user = factories["users.User"]()
|
||||
radio = radios_v2.RandomRadio()
|
||||
session = radio.start_session(user)
|
||||
|
||||
radio.pick_many(quantity=10, filter_playable=False)
|
||||
|
||||
# ensure 10 different tracks have been suggested
|
||||
tracks_id = [
|
||||
session_track.track.pk for session_track in session.session_tracks.all()
|
||||
]
|
||||
assert len(set(tracks_id)) == 10
|
||||
|
||||
|
||||
def test_session_radio_excludes_previous_picks_v2(factories, logged_in_api_client):
|
||||
tracks = factories["music.Track"].create_batch(5)
|
||||
url = reverse("api:v2:radios:sessions-list")
|
||||
response = logged_in_api_client.post(url, {"radio_type": "random"})
|
||||
session = models.RadioSession.objects.latest("id")
|
||||
url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk})
|
||||
|
||||
previous_choices = []
|
||||
|
||||
for i in range(5):
|
||||
response = logged_in_api_client.get(
|
||||
url, {"session": session.pk, "filter_playable": False}
|
||||
)
|
||||
pick = json.loads(response.content.decode("utf-8"))
|
||||
assert pick[0]["title"] not in previous_choices
|
||||
assert pick[0]["title"] in [t.title for t in tracks]
|
||||
previous_choices.append(pick[0]["title"])
|
||||
|
||||
response = logged_in_api_client.get(url, {"session": session.pk})
|
||||
assert (
|
||||
json.loads(response.content.decode("utf-8"))
|
||||
== "Radio doesn't have more candidates"
|
||||
)
|
||||
|
||||
|
||||
def test_can_get_choices_for_favorites_radio_v2(factories):
|
||||
files = factories["music.Upload"].create_batch(10)
|
||||
tracks = [f.track for f in files]
|
||||
user = factories["users.User"]()
|
||||
for i in range(5):
|
||||
TrackFavorite.add(track=random.choice(tracks), user=user)
|
||||
|
||||
radio = radios_v2.FavoritesRadio()
|
||||
session = radio.start_session(user=user)
|
||||
choices = session.radio(api_version=2).get_choices(
|
||||
quantity=100, filter_playable=False
|
||||
)
|
||||
|
||||
assert len(choices) == user.track_favorites.all().count()
|
||||
|
||||
for favorite in user.track_favorites.all():
|
||||
assert favorite.track in choices
|
||||
|
||||
|
||||
def test_can_get_choices_for_custom_radio_v2(factories):
|
||||
artist = factories["music.Artist"]()
|
||||
files = factories["music.Upload"].create_batch(5, track__artist=artist)
|
||||
tracks = [f.track for f in files]
|
||||
factories["music.Upload"].create_batch(5)
|
||||
|
||||
session = factories["radios.CustomRadioSession"](
|
||||
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
|
||||
)
|
||||
choices = session.radio(api_version=2).get_choices(
|
||||
quantity=1, filter_playable=False
|
||||
)
|
||||
|
||||
expected = [t.pk for t in tracks]
|
||||
for t in choices:
|
||||
assert t.id in expected
|
||||
|
||||
|
||||
def test_can_cache_radio_track(factories):
|
||||
uploads = factories["music.Track"].create_batch(10)
|
||||
user = factories["users.User"]()
|
||||
radio = radios_v2.RandomRadio()
|
||||
session = radio.start_session(user)
|
||||
picked = session.radio(api_version=2).pick_many(quantity=1, filter_playable=False)
|
||||
assert len(picked) == 1
|
||||
for t in pickle.loads(cache.get(f"radiotracks{session.id}")):
|
||||
assert t in uploads
|
||||
|
||||
|
||||
def test_regenerate_cache_if_not_enought_tracks_in_it(
|
||||
factories, caplog, logged_in_api_client
|
||||
):
|
||||
logger = logging.getLogger("funkwhale_api.radios.radios_v2")
|
||||
caplog.set_level(logging.INFO)
|
||||
logger.addHandler(caplog.handler)
|
||||
|
||||
factories["music.Track"].create_batch(10)
|
||||
factories["users.User"]()
|
||||
url = reverse("api:v2:radios:sessions-list")
|
||||
response = logged_in_api_client.post(url, {"radio_type": "random"})
|
||||
session = models.RadioSession.objects.latest("id")
|
||||
url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk})
|
||||
logged_in_api_client.get(url, {"count": 9, "filter_playable": False})
|
||||
response = logged_in_api_client.get(url, {"count": 10, "filter_playable": False})
|
||||
pick = json.loads(response.content.decode("utf-8"))
|
||||
assert (
|
||||
"Not enough radio tracks in cache. Trying to generate new cache" in caplog.text
|
||||
)
|
||||
assert len(pick) == 1
|
|
@ -0,0 +1 @@
|
|||
Cache radio queryset into redis. New radio track endpoint for api v2 is /api/v2/radios/sessions/{radiosessionid}/tracks (#2135)
|
|
@ -98,6 +98,8 @@ services:
|
|||
env_file:
|
||||
- .env
|
||||
image: typesense/typesense:0.24.0
|
||||
networks:
|
||||
- internal
|
||||
volumes:
|
||||
- ./typesense/data:/data
|
||||
command: --data-dir /data --enable-cors
|
||||
|
|
Loading…
Reference in New Issue