add v2 interface for radios
This commit is contained in:
parent
379672f409
commit
eb3600066e
|
@ -55,7 +55,7 @@ class RadioSession(models.Model):
|
|||
config = JSONField(encoder=DjangoJSONEncoder, blank=True, null=True)
|
||||
|
||||
def save(self, **kwargs):
|
||||
self.radio.clean(self)
|
||||
# self.radio.clean(self)
|
||||
super().save(**kwargs)
|
||||
|
||||
@property
|
||||
|
@ -81,11 +81,11 @@ class RadioSession(models.Model):
|
|||
|
||||
return new_session_tracks
|
||||
|
||||
@property
|
||||
def radio(self):
|
||||
def radio(self, api_version):
|
||||
from .registries import registry
|
||||
|
||||
return registry[self.radio_type](session=self)
|
||||
radio_type = self.radio_type + api_version
|
||||
return registry[radio_type](session=self)
|
||||
|
||||
|
||||
class RadioSessionTrack(models.Model):
|
||||
|
|
|
@ -0,0 +1,511 @@
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import connection
|
||||
from django.db.models import Q
|
||||
from rest_framework import serializers
|
||||
|
||||
from funkwhale_api.federation import fields as federation_fields
|
||||
from funkwhale_api.federation import models as federation_models
|
||||
from funkwhale_api.moderation import filters as moderation_filters
|
||||
from funkwhale_api.music.models import Artist, Library, Track, Upload
|
||||
from funkwhale_api.tags.models import Tag
|
||||
|
||||
from . import filters, lb_recommendations, models
|
||||
from .registries import registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleRadio:
|
||||
related_object_field = None
|
||||
|
||||
def clean(self, instance):
|
||||
return
|
||||
|
||||
def weighted_pick(
|
||||
self,
|
||||
choices: List[Tuple[int, int]],
|
||||
previous_choices: Optional[List[int]] = None,
|
||||
) -> int:
|
||||
total = sum(weight for c, weight in choices)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for choice, weight in choices:
|
||||
if upto + weight >= r:
|
||||
return choice
|
||||
upto += weight
|
||||
|
||||
|
||||
class SessionRadio(SimpleRadio):
|
||||
def __init__(self, session=None):
|
||||
self.session = session
|
||||
|
||||
def start_session(self, user, **kwargs):
|
||||
self.session = models.RadioSession.objects.create(
|
||||
user=user, radio_type=self.radio_type, **kwargs
|
||||
)
|
||||
return self.session
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
if not self.session or not self.session.user:
|
||||
return (
|
||||
Track.objects.all()
|
||||
.with_playable_uploads(actor=None)
|
||||
.select_related("artist", "album__artist", "attributed_to")
|
||||
)
|
||||
else:
|
||||
qs = (
|
||||
Track.objects.all()
|
||||
.with_playable_uploads(self.session.user.actor)
|
||||
.select_related("artist", "album__artist", "attributed_to")
|
||||
)
|
||||
|
||||
query = moderation_filters.get_filtered_content_query(
|
||||
config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
|
||||
user=self.session.user,
|
||||
)
|
||||
return qs.exclude(query)
|
||||
|
||||
def get_queryset_kwargs(self):
|
||||
return {}
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
return queryset
|
||||
|
||||
def filter_from_session(self, queryset):
|
||||
already_played = self.session.session_tracks.all().values_list(
|
||||
"track", flat=True
|
||||
)
|
||||
queryset = queryset.exclude(pk__in=already_played)
|
||||
return queryset
|
||||
|
||||
def cache_batch_radio_track(self, **kwargs):
|
||||
BATCH_SIZE = 100
|
||||
# get cached RadioTracks if any
|
||||
try:
|
||||
cached_evaluated_radio_tracks = pickle.loads(
|
||||
cache.get(f"radiotracks{self.session.id}")
|
||||
)
|
||||
except TypeError:
|
||||
cached_evaluated_radio_tracks = None
|
||||
|
||||
# get the queryset and apply filters
|
||||
kwargs.update(self.get_queryset_kwargs())
|
||||
queryset = self.get_queryset(**kwargs)
|
||||
queryset = self.filter_from_session(queryset)
|
||||
|
||||
if kwargs["filter_playable"] is True:
|
||||
queryset = queryset.playable_by(
|
||||
self.session.user.actor if self.session.user else None
|
||||
)
|
||||
queryset = self.filter_queryset(queryset)
|
||||
|
||||
# select a random batch of the qs
|
||||
sliced_queryset = queryset.order_by("?")[:BATCH_SIZE]
|
||||
if len(sliced_queryset) <= 0 and not cached_evaluated_radio_tracks:
|
||||
raise ValueError("No more radio candidates")
|
||||
|
||||
# create the radio session tracks into db in bulk
|
||||
self.session.add(sliced_queryset)
|
||||
|
||||
# evaluate the queryset to save it in cache
|
||||
radio_tracks = list(sliced_queryset)
|
||||
|
||||
if cached_evaluated_radio_tracks is not None:
|
||||
radio_tracks.extend(cached_evaluated_radio_tracks)
|
||||
logger.info(
|
||||
f"Setting redis cache for radio generation with radio id {self.session.id}"
|
||||
)
|
||||
cache.set(f"radiotracks{self.session.id}", pickle.dumps(radio_tracks), 3600)
|
||||
cache.set(f"radioqueryset{self.session.id}", sliced_queryset, 3600)
|
||||
|
||||
return sliced_queryset
|
||||
|
||||
def get_choices_v2(self, quantity, **kwargs):
|
||||
if cache.get(f"radiotracks{self.session.id}"):
|
||||
cached_radio_tracks = pickle.loads(
|
||||
cache.get(f"radiotracks{self.session.id}")
|
||||
)
|
||||
logger.info("Using redis cache for radio generation")
|
||||
radio_tracks = cached_radio_tracks
|
||||
if len(radio_tracks) < quantity:
|
||||
logger.info(
|
||||
"Not enough radio tracks in cache. Trying to generate new cache"
|
||||
)
|
||||
sliced_queryset = self.cache_batch_radio_track(**kwargs)
|
||||
sliced_queryset = cache.get(f"radioqueryset{self.session.id}")
|
||||
else:
|
||||
sliced_queryset = self.cache_batch_radio_track(**kwargs)
|
||||
|
||||
return sliced_queryset[:quantity]
|
||||
|
||||
def pick_many_v2(self, quantity, **kwargs):
|
||||
if self.session:
|
||||
sliced_queryset = self.get_choices_v2(quantity=quantity, **kwargs)
|
||||
else:
|
||||
logger.info(
|
||||
"No radio session. Can't track user playback. Won't cache queryset results"
|
||||
)
|
||||
sliced_queryset = self.get_choices_v2(quantity=quantity, **kwargs)
|
||||
|
||||
return sliced_queryset
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
return data
|
||||
|
||||
|
||||
@registry.register(name="random_v2")
|
||||
class RandomRadio(SessionRadio):
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return qs.filter(artist__content_category="music").order_by("?")
|
||||
|
||||
|
||||
@registry.register(name="random_library_v2")
|
||||
class RandomLibraryRadio(SessionRadio):
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
query = Q(artist__content_category="music") & Q(pk__in=tracks_ids)
|
||||
return qs.filter(query).order_by("?")
|
||||
|
||||
|
||||
@registry.register(name="favorites_v2")
|
||||
class FavoritesRadio(SessionRadio):
|
||||
def get_queryset_kwargs(self):
|
||||
kwargs = super().get_queryset_kwargs()
|
||||
if self.session:
|
||||
kwargs["user"] = self.session.user
|
||||
return kwargs
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
track_ids = kwargs["user"].track_favorites.all().values_list("track", flat=True)
|
||||
return qs.filter(pk__in=track_ids, artist__content_category="music")
|
||||
|
||||
|
||||
@registry.register(name="custom_v2")
|
||||
class CustomRadio(SessionRadio):
|
||||
def get_queryset_kwargs(self):
|
||||
kwargs = super().get_queryset_kwargs()
|
||||
kwargs["user"] = self.session.user
|
||||
kwargs["custom_radio"] = self.session.custom_radio
|
||||
return kwargs
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return filters.run(kwargs["custom_radio"].config, candidates=qs)
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
data = super().validate_session(data, **context)
|
||||
try:
|
||||
user = data["user"]
|
||||
except KeyError:
|
||||
user = context.get("user")
|
||||
try:
|
||||
assert data["custom_radio"].user == user or data["custom_radio"].is_public
|
||||
except KeyError:
|
||||
raise serializers.ValidationError("You must provide a custom radio")
|
||||
except AssertionError:
|
||||
raise serializers.ValidationError("You don't have access to this radio")
|
||||
return data
|
||||
|
||||
|
||||
@registry.register(name="custom_multiple_v2")
|
||||
class CustomMultiple(SessionRadio):
|
||||
"""
|
||||
Receive a vuejs generated config and use it to launch a radio session
|
||||
"""
|
||||
|
||||
config = serializers.JSONField(required=True)
|
||||
|
||||
def get_config(self, data):
|
||||
return data["config"]
|
||||
|
||||
def get_queryset_kwargs(self):
|
||||
kwargs = super().get_queryset_kwargs()
|
||||
kwargs["config"] = self.session.config
|
||||
return kwargs
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
data = super().validate_session(data, **context)
|
||||
try:
|
||||
data["config"] is not None
|
||||
except KeyError:
|
||||
raise serializers.ValidationError(
|
||||
"You must provide a configuration for this radio"
|
||||
)
|
||||
return data
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return filters.run([kwargs["config"]], candidates=qs)
|
||||
|
||||
|
||||
class RelatedObjectRadio(SessionRadio):
|
||||
"""Abstract radio related to an object (tag, artist, user...)"""
|
||||
|
||||
related_object_field = serializers.IntegerField(required=True)
|
||||
|
||||
def clean(self, instance):
|
||||
super().clean(instance)
|
||||
if not instance.related_object:
|
||||
raise ValidationError(
|
||||
"Cannot start RelatedObjectRadio without related object"
|
||||
)
|
||||
if not isinstance(instance.related_object, self.model):
|
||||
raise ValidationError("Trying to start radio with bad related object")
|
||||
|
||||
def get_related_object(self, pk):
|
||||
return self.model.objects.get(pk=pk)
|
||||
|
||||
|
||||
@registry.register(name="tag_v2")
|
||||
class TagRadio(RelatedObjectRadio):
|
||||
model = Tag
|
||||
related_object_field = serializers.CharField(required=True)
|
||||
|
||||
def get_related_object(self, name):
|
||||
return self.model.objects.get(name=name)
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
query = (
|
||||
Q(tagged_items__tag=self.session.related_object)
|
||||
| Q(artist__tagged_items__tag=self.session.related_object)
|
||||
| Q(album__tagged_items__tag=self.session.related_object)
|
||||
)
|
||||
return qs.filter(query)
|
||||
|
||||
def get_related_object_id_repr(self, obj):
|
||||
return obj.name
|
||||
|
||||
|
||||
def weighted_choice(choices):
|
||||
total = sum(w for c, w in choices)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for c, w in choices:
|
||||
if upto + w >= r:
|
||||
return c
|
||||
upto += w
|
||||
assert False, "Shouldn't get here"
|
||||
|
||||
|
||||
class NextNotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register(name="similar_v2")
|
||||
class SimilarRadio(RelatedObjectRadio):
|
||||
model = Track
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
queryset = super().filter_queryset(queryset)
|
||||
seeds = list(
|
||||
self.session.session_tracks.all()
|
||||
.values_list("track_id", flat=True)
|
||||
.order_by("-id")[:3]
|
||||
) + [self.session.related_object.pk]
|
||||
for seed in seeds:
|
||||
try:
|
||||
return queryset.filter(pk=self.find_next_id(queryset, seed))
|
||||
except NextNotFound:
|
||||
continue
|
||||
|
||||
return queryset.none()
|
||||
|
||||
def find_next_id(self, queryset, seed):
|
||||
with connection.cursor() as cursor:
|
||||
query = """
|
||||
SELECT next, count(next) AS c
|
||||
FROM (
|
||||
SELECT
|
||||
track_id,
|
||||
creation_date,
|
||||
LEAD(track_id) OVER (
|
||||
PARTITION by user_id order by creation_date asc
|
||||
) AS next
|
||||
FROM history_listening
|
||||
INNER JOIN users_user ON (users_user.id = user_id)
|
||||
WHERE users_user.privacy_level = 'instance' OR users_user.privacy_level = 'everyone' OR user_id = %s
|
||||
ORDER BY creation_date ASC
|
||||
) t WHERE track_id = %s AND next != %s GROUP BY next ORDER BY c DESC;
|
||||
"""
|
||||
cursor.execute(query, [self.session.user_id, seed, seed])
|
||||
next_candidates = list(cursor.fetchall())
|
||||
|
||||
if not next_candidates:
|
||||
raise NextNotFound()
|
||||
|
||||
matching_tracks = list(
|
||||
queryset.filter(pk__in=[c[0] for c in next_candidates]).values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
)
|
||||
next_candidates = [n for n in next_candidates if n[0] in matching_tracks]
|
||||
if not next_candidates:
|
||||
raise NextNotFound()
|
||||
return random.choice([c[0] for c in next_candidates])
|
||||
|
||||
|
||||
@registry.register(name="artist_v2")
|
||||
class ArtistRadio(RelatedObjectRadio):
|
||||
model = Artist
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return qs.filter(artist=self.session.related_object)
|
||||
|
||||
|
||||
@registry.register(name="less-listened_v2")
|
||||
class LessListenedRadio(SessionRadio):
|
||||
def clean(self, instance):
|
||||
instance.related_object = instance.user
|
||||
super().clean(instance)
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
listened = self.session.user.listenings.all().values_list("track", flat=True)
|
||||
return (
|
||||
qs.filter(artist__content_category="music")
|
||||
.exclude(pk__in=listened)
|
||||
.order_by("?")
|
||||
)
|
||||
|
||||
|
||||
@registry.register(name="less-listened_library_v2")
|
||||
class LessListenedLibraryRadio(SessionRadio):
|
||||
def clean(self, instance):
|
||||
instance.related_object = instance.user
|
||||
super().clean(instance)
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
listened = self.session.user.listenings.all().values_list("track", flat=True)
|
||||
tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
query = Q(artist__content_category="music") & Q(pk__in=tracks_ids)
|
||||
return qs.filter(query).exclude(pk__in=listened).order_by("?")
|
||||
|
||||
|
||||
@registry.register(name="actor-content_v2")
|
||||
class ActorContentRadio(RelatedObjectRadio):
|
||||
"""
|
||||
Play content from given actor libraries
|
||||
"""
|
||||
|
||||
model = federation_models.Actor
|
||||
related_object_field = federation_fields.ActorRelatedField(required=True)
|
||||
|
||||
def get_related_object(self, value):
|
||||
return value
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
actor_uploads = Upload.objects.filter(
|
||||
library__actor=self.session.related_object,
|
||||
)
|
||||
return qs.filter(pk__in=actor_uploads.values("track"))
|
||||
|
||||
def get_related_object_id_repr(self, obj):
|
||||
return obj.full_username
|
||||
|
||||
|
||||
@registry.register(name="library_v2")
|
||||
class LibraryRadio(RelatedObjectRadio):
|
||||
"""
|
||||
Play content from a given library
|
||||
"""
|
||||
|
||||
model = Library
|
||||
related_object_field = serializers.UUIDField(required=True)
|
||||
|
||||
def get_related_object(self, value):
|
||||
return Library.objects.get(uuid=value)
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
actor_uploads = Upload.objects.filter(
|
||||
library=self.session.related_object,
|
||||
)
|
||||
return qs.filter(pk__in=actor_uploads.values("track"))
|
||||
|
||||
def get_related_object_id_repr(self, obj):
|
||||
return obj.uuid
|
||||
|
||||
|
||||
@registry.register(name="recently-added_v2")
|
||||
class RecentlyAdded(SessionRadio):
|
||||
def get_queryset(self, **kwargs):
|
||||
date = datetime.date.today() - datetime.timedelta(days=30)
|
||||
qs = super().get_queryset(**kwargs)
|
||||
return qs.filter(
|
||||
Q(artist__content_category="music"),
|
||||
Q(creation_date__gt=date),
|
||||
)
|
||||
|
||||
|
||||
# Use this to experiment on the custom multiple radio with troi
|
||||
@registry.register(name="troi_v2")
|
||||
class Troi(SessionRadio):
|
||||
"""
|
||||
Receive a vuejs generated config and use it to launch a troi radio session.
|
||||
The config data should follow :
|
||||
{"patch": "troi_patch_name", "troi_arg1":"troi_arg_1", "troi_arg2": ...}
|
||||
Validation of the config (args) is done by troi during track fetch.
|
||||
Funkwhale only checks if the patch is implemented
|
||||
"""
|
||||
|
||||
config = serializers.JSONField(required=True)
|
||||
|
||||
def append_lb_config(self, data):
|
||||
if self.session.user.settings is None:
|
||||
logger.warning(
|
||||
"No lb_user_name set in user settings. Some troi patches will fail"
|
||||
)
|
||||
return data
|
||||
elif self.session.user.settings.get("lb_user_name") is None:
|
||||
logger.warning(
|
||||
"No lb_user_name set in user settings. Some troi patches will fail"
|
||||
)
|
||||
else:
|
||||
data["user_name"] = self.session.user.settings["lb_user_name"]
|
||||
|
||||
if self.session.user.settings.get("lb_user_token") is None:
|
||||
logger.warning(
|
||||
"No lb_user_token set in user settings. Some troi patch will fail"
|
||||
)
|
||||
else:
|
||||
data["user_token"] = self.session.user.settings["lb_user_token"]
|
||||
|
||||
return data
|
||||
|
||||
def get_queryset_kwargs(self):
|
||||
kwargs = super().get_queryset_kwargs()
|
||||
kwargs["config"] = self.session.config
|
||||
return kwargs
|
||||
|
||||
def validate_session(self, data, **context):
|
||||
data = super().validate_session(data, **context)
|
||||
if data.get("config") is None:
|
||||
raise serializers.ValidationError(
|
||||
"You must provide a configuration for this radio"
|
||||
)
|
||||
return data
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = super().get_queryset(**kwargs)
|
||||
config = self.append_lb_config(json.loads(kwargs["config"]))
|
||||
|
||||
return lb_recommendations.run(config, candidates=qs)
|
|
@ -206,7 +206,14 @@ class V2_RadioSessionViewSet(
|
|||
):
|
||||
return Response(status=status.HTTP_403_FORBIDDEN)
|
||||
try:
|
||||
session.radio.pick_many_v2(count, filter_playable=filter_playable)
|
||||
# needed for for registeries, and we need to use it for linter
|
||||
from . import radio_v2
|
||||
|
||||
radio_v2.datetime()
|
||||
|
||||
session.radio(api_version="_v2").pick_many_v2(
|
||||
count, filter_playable=filter_playable
|
||||
)
|
||||
except ValueError:
|
||||
return Response(
|
||||
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
|
||||
|
|
Loading…
Reference in New Issue