import datetime
import logging
import pickle
import random
from typing import List, Optional, Tuple

from django.core.cache import cache
from django.core.exceptions import ValidationError
from django.db import connection
from django.db.models import Q
from rest_framework import serializers

from funkwhale_api.federation import fields as federation_fields
from funkwhale_api.federation import models as federation_models
from funkwhale_api.moderation import filters as moderation_filters
from funkwhale_api.music.models import Artist, Library, Track, Upload
from funkwhale_api.tags.models import Tag

from . import filters, lb_recommendations, models
from .registries_v2 import registry

logger = logging.getLogger(__name__)


class SimpleRadio:
    """Base radio: no related object, no extra validation, weighted picking."""

    related_object_field = None

    def clean(self, instance):
        """Validate a radio session instance; the base class accepts everything."""
        return

    def weighted_pick(
        self,
        choices: List[Tuple[int, int]],
        previous_choices: Optional[List[int]] = None,
    ) -> int:
        """Pick one choice at random, with probability proportional to its weight.

        :param choices: ``(choice, weight)`` pairs.
        :param previous_choices: accepted for interface compatibility; unused.
        :returns: the selected choice; falls through to ``None`` when
            *choices* is empty.
        """
        total = sum(weight for c, weight in choices)
        r = random.uniform(0, total)
        upto = 0
        for choice, weight in choices:
            if upto + weight >= r:
                return choice
            upto += weight


class SessionRadio(SimpleRadio):
    """Radio bound to a ``RadioSession``.

    Candidates are drawn in random batches of ``BATCH_SIZE`` tracks. Each batch
    is recorded on the session and cached per session id for
    ``CACHE_TIMEOUT`` seconds.
    """

    # Number of random candidates fetched per cache refill.
    BATCH_SIZE = 100
    # Lifetime (seconds) of the per-session cache entries.
    CACHE_TIMEOUT = 3600

    def __init__(self, session=None):
        self.session = session

    def start_session(self, user, **kwargs):
        """Create and attach a new ``RadioSession`` of this radio's type."""
        self.session = models.RadioSession.objects.create(
            user=user, radio_type=self.radio_type, **kwargs
        )
        return self.session

    def get_queryset(self, **kwargs):
        """Return playable tracks, minus content filtered by the session user."""
        actor = None
        try:
            actor = self.session.user.actor
        except AttributeError:
            # Attribute access raises AttributeError (never KeyError), e.g.
            # when session.user is None, which cache_batch_radio_track also
            # allows for. Fall back to actor=None.
            pass
        qs = (
            Track.objects.all()
            .with_playable_uploads(actor=actor)
            .prefetch_related(
                "artist_credit__artist", "album__artist_credit__artist", "attributed_to"
            )
        )
        query = moderation_filters.get_filtered_content_query(
            config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
            user=self.session.user,
        )
        return qs.exclude(query)

    def get_queryset_kwargs(self):
        """Extra kwargs merged into those passed to ``get_queryset``."""
        return {}

    def filter_queryset(self, queryset):
        """Hook for subclasses to narrow candidates; identity by default."""
        return queryset

    def filter_from_session(self, queryset):
        """Exclude tracks already recorded in this session's ``session_tracks``."""
        already_played = self.session.session_tracks.all().values_list(
            "track", flat=True
        )
        return queryset.exclude(pk__in=already_played)

    def _get_cached_radio_tracks(self):
        """Return the track list cached for this session, or ``None``.

        NOTE: ``pickle.loads`` is only acceptable because, within this module,
        this key is written solely by ``cache_batch_radio_track``. Never point
        it at data an external party can write.
        """
        raw = cache.get(f"radiotracks{self.session.id}")
        if raw is None:
            return None
        return pickle.loads(raw)

    def cache_batch_radio_track(self, **kwargs):
        """Fetch a new random batch of candidates, record it and cache it.

        :param kwargs: must contain ``filter_playable``. The kwargs are merged
            with ``get_queryset_kwargs()`` and forwarded to ``get_queryset``.
        :returns: the (evaluated) sliced queryset of the new batch.
        :raises ValueError: when the batch is empty and nothing is cached.
        """
        cached_evaluated_radio_tracks = self._get_cached_radio_tracks()

        kwargs.update(self.get_queryset_kwargs())
        queryset = self.get_queryset(**kwargs)
        queryset = self.filter_from_session(queryset)
        if kwargs["filter_playable"] is True:
            queryset = queryset.playable_by(
                self.session.user.actor if self.session.user else None
            )
        queryset = self.filter_queryset(queryset)

        # len() evaluates the random slice once; later uses reuse its result cache.
        sliced_queryset = queryset.order_by("?")[: self.BATCH_SIZE]
        if len(sliced_queryset) <= 0 and not cached_evaluated_radio_tracks:
            raise ValueError("No more radio candidates")

        # Record the batch on the session so filter_from_session skips it later.
        self.session.add(sliced_queryset)

        radio_tracks = list(sliced_queryset)
        if cached_evaluated_radio_tracks is not None:
            radio_tracks.extend(cached_evaluated_radio_tracks)
        logger.info(
            "Setting redis cache for radio generation with radio id %s",
            self.session.id,
        )
        cache.set(
            f"radiotracks{self.session.id}",
            pickle.dumps(radio_tracks),
            self.CACHE_TIMEOUT,
        )
        cache.set(
            f"radioqueryset{self.session.id}", sliced_queryset, self.CACHE_TIMEOUT
        )
        return sliced_queryset

    def get_choices(self, quantity, **kwargs):
        """Return up to *quantity* candidate tracks for the session.

        Reuses the cached batch when present. If the cached track list holds
        fewer than *quantity* tracks, the cache is refilled first. If the cached
        queryset entry is missing, a new batch is generated instead of slicing
        ``None``.
        """
        # Single cache read: reading the key twice could race with its expiry.
        cached_radio_tracks = self._get_cached_radio_tracks()
        if cached_radio_tracks is None:
            sliced_queryset = self.cache_batch_radio_track(**kwargs)
        else:
            logger.info("Using redis cache for radio generation")
            if len(cached_radio_tracks) < quantity:
                logger.info(
                    "Not enough radio tracks in cache. Trying to generate new cache"
                )
                self.cache_batch_radio_track(**kwargs)
            sliced_queryset = cache.get(f"radioqueryset{self.session.id}")
            if sliced_queryset is None:
                # The queryset entry can be evicted independently of the track list.
                sliced_queryset = self.cache_batch_radio_track(**kwargs)
        return sliced_queryset[:quantity]

    def pick_many(self, quantity, **kwargs):
        """Return up to *quantity* tracks (see ``get_choices``)."""
        if not self.session:
            logger.info(
                "No radio session. Can't track user playback. Won't cache queryset results"
            )
        return self.get_choices(quantity=quantity, **kwargs)

    def validate_session(self, data, **context):
        """Validate session creation data; the base class accepts it as-is."""
        return data


@registry.register(name="random")
class RandomRadio(SessionRadio):
    """Random music tracks."""

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        return qs.filter(artist_credit__artist__content_category="music").order_by("?")


@registry.register(name="random_library")
class RandomLibraryRadio(SessionRadio):
    """Random music tracks among those attributed to the session user's actor."""

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
            "id", flat=True
        )
        query = Q(artist_credit__artist__content_category="music") & Q(
            pk__in=tracks_ids
        )
        return qs.filter(query).order_by("?")


@registry.register(name="favorites")
class FavoritesRadio(SessionRadio):
    """Music tracks favorited by the session user."""

    def get_queryset_kwargs(self):
        kwargs = super().get_queryset_kwargs()
        if self.session:
            kwargs["user"] = self.session.user
        return kwargs

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        track_ids = (
            kwargs["user"].actor.track_favorites.all().values_list("track", flat=True)
        )
        return qs.filter(
            pk__in=track_ids, artist_credit__artist__content_category="music"
        )


@registry.register(name="custom")
class CustomRadio(SessionRadio):
    """Tracks selected by a stored custom radio's filter config."""

    def get_queryset_kwargs(self):
        kwargs = super().get_queryset_kwargs()
        kwargs["user"] = self.session.user
        kwargs["custom_radio"] = self.session.custom_radio
        return kwargs

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        return filters.run(kwargs["custom_radio"].config, candidates=qs)

    def validate_session(self, data, **context):
        """Require a custom radio that is public or owned by the requesting user.

        :raises serializers.ValidationError: if ``custom_radio`` is missing or
            the user may not access it.
        """
        data = super().validate_session(data, **context)
        try:
            user = data["user"]
        except KeyError:
            user = context.get("user")
        try:
            custom_radio = data["custom_radio"]
        except KeyError:
            raise serializers.ValidationError("You must provide a custom radio")
        # Explicit check rather than ``assert``: asserts are stripped under -O,
        # which would silently disable this access control.
        if not (custom_radio.user == user or custom_radio.is_public):
            raise serializers.ValidationError("You don't have access to this radio")
        return data


@registry.register(name="custom_multiple")
class CustomMultiple(SessionRadio):
    """
    Receive a vuejs generated config and use it to launch a radio session
    """

    config = serializers.JSONField(required=True)

    def get_config(self, data):
        return data["config"]

    def get_queryset_kwargs(self):
        kwargs = super().get_queryset_kwargs()
        kwargs["config"] = self.session.config
        return kwargs

    def validate_session(self, data, **context):
        """Require a non-null ``config`` entry.

        :raises serializers.ValidationError: if ``config`` is missing or None.
        """
        data = super().validate_session(data, **context)
        if data.get("config") is None:
            raise serializers.ValidationError(
                "You must provide a configuration for this radio"
            )
        return data

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        return filters.run([kwargs["config"]], candidates=qs)


class RelatedObjectRadio(SessionRadio):
    """Abstract radio related to an object (tag, artist, user...)"""

    related_object_field = serializers.IntegerField(required=True)

    def clean(self, instance):
        """Ensure the session has a related object of type ``self.model``.

        :raises ValidationError: if it is missing or of the wrong type.
        """
        super().clean(instance)
        if not instance.related_object:
            raise ValidationError(
                "Cannot start RelatedObjectRadio without related object"
            )
        if not isinstance(instance.related_object, self.model):
            raise ValidationError("Trying to start radio with bad related object")

    def get_related_object(self, pk):
        return self.model.objects.get(pk=pk)


@registry.register(name="tag")
class TagRadio(RelatedObjectRadio):
    """Tracks tagged with the tag, or whose artist or album carries it."""

    model = Tag
    related_object_field = serializers.CharField(required=True)

    def get_related_object(self, name):
        return self.model.objects.get(name=name)

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        tag = self.session.related_object
        # Artists are reached through artist_credit, like every other query here.
        query = (
            Q(tagged_items__tag=tag)
            | Q(artist_credit__artist__tagged_items__tag=tag)
            | Q(album__tagged_items__tag=tag)
        )
        return qs.filter(query)

    def get_related_object_id_repr(self, obj):
        return obj.name


def weighted_choice(choices):
    """Return a choice from ``(choice, weight)`` pairs, weighted at random.

    :raises AssertionError: if no choice was selected (e.g. empty *choices*).
    """
    total = sum(w for c, w in choices)
    r = random.uniform(0, total)
    upto = 0
    for c, w in choices:
        if upto + w >= r:
            return c
        upto += w
    # Explicit raise instead of ``assert False`` so the guard survives -O.
    raise AssertionError("Shouldn't get here")


class NextNotFound(Exception):
    pass


@registry.register(name="similar")
class SimilarRadio(RelatedObjectRadio):
    """Pick tracks that listeners most often played right after recent seeds."""

    model = Track

    def filter_queryset(self, queryset):
        """Narrow to one "next" track after the 3 latest session tracks or the seed.

        Returns an empty queryset when no seed yields a candidate.
        """
        queryset = super().filter_queryset(queryset)
        seeds = list(
            self.session.session_tracks.all()
            .values_list("track_id", flat=True)
            .order_by("-id")[:3]
        ) + [self.session.related_object.pk]
        for seed in seeds:
            try:
                return queryset.filter(pk=self.find_next_id(queryset, seed))
            except NextNotFound:
                continue
        return queryset.none()

    def find_next_id(self, queryset, seed):
        """Return a random id among tracks listened to right after *seed*.

        Listening history is taken from users with privacy 'instance' or
        'everyone', plus the session user's own actor.

        :raises NextNotFound: if no follower of *seed* is in *queryset*.
        """
        with connection.cursor() as cursor:
            query = """
            SELECT next, count(next) AS c
            FROM (
                SELECT
                    history_listening.track_id,
                    history_listening.creation_date,
                    LEAD(history_listening.track_id) OVER (
                        PARTITION BY history_listening.actor_id
                        ORDER BY history_listening.creation_date ASC
                    ) AS next
                FROM history_listening
                INNER JOIN federation_actor
                    ON federation_actor.id = history_listening.actor_id
                INNER JOIN users_user
                    ON users_user.actor_id = federation_actor.id
                WHERE users_user.privacy_level = 'instance'
                    OR users_user.privacy_level = 'everyone'
                    OR history_listening.actor_id = %s
                ORDER BY history_listening.creation_date ASC
            ) t
            WHERE track_id = %s AND next != %s
            GROUP BY next
            ORDER BY c DESC;
            """
            cursor.execute(query, [self.session.user_id, seed, seed])
            next_candidates = list(cursor.fetchall())

        if not next_candidates:
            raise NextNotFound()

        matching_tracks = set(
            queryset.filter(pk__in=[c[0] for c in next_candidates]).values_list(
                "id", flat=True
            )
        )
        next_candidates = [n for n in next_candidates if n[0] in matching_tracks]
        if not next_candidates:
            raise NextNotFound()
        return random.choice([c[0] for c in next_candidates])


@registry.register(name="artist")
class ArtistRadio(RelatedObjectRadio):
    """Tracks credited to the related artist."""

    model = Artist

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        return qs.filter(artist_credit__artist=self.session.related_object)


@registry.register(name="less-listened")
class LessListenedRadio(SessionRadio):
    """Music tracks the session user's actor has no listening for."""

    def clean(self, instance):
        instance.related_object = instance.user
        super().clean(instance)

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        listened = self.session.user.actor.listenings.all().values_list(
            "track", flat=True
        )
        return (
            qs.filter(artist_credit__artist__content_category="music")
            .exclude(pk__in=listened)
            .order_by("?")
        )


@registry.register(name="less-listened_library")
class LessListenedLibraryRadio(SessionRadio):
    """Unlistened music tracks among those attributed to the user's actor."""

    def clean(self, instance):
        instance.related_object = instance.user
        super().clean(instance)

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        listened = self.session.user.actor.listenings.all().values_list(
            "track", flat=True
        )
        tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
            "id", flat=True
        )
        query = Q(artist_credit__artist__content_category="music") & Q(
            pk__in=tracks_ids
        )
        return qs.filter(query).exclude(pk__in=listened).order_by("?")


@registry.register(name="actor-content")
class ActorContentRadio(RelatedObjectRadio):
    """
    Play content from given actor libraries
    """

    model = federation_models.Actor
    related_object_field = federation_fields.ActorRelatedField(required=True)

    def get_related_object(self, value):
        return value

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        actor_uploads = Upload.objects.filter(
            library__actor=self.session.related_object,
        )
        return qs.filter(pk__in=actor_uploads.values("track"))

    def get_related_object_id_repr(self, obj):
        return obj.full_username
@registry.register(name="library")
class LibraryRadio(RelatedObjectRadio):
    """
    Play content from a given library
    """

    model = Library
    related_object_field = serializers.UUIDField(required=True)

    def get_related_object(self, value):
        return Library.objects.get(uuid=value)

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        actor_uploads = Upload.objects.filter(
            library=self.session.related_object,
        )
        return qs.filter(pk__in=actor_uploads.values("track"))

    def get_related_object_id_repr(self, obj):
        return obj.uuid


@registry.register(name="recently-added")
class RecentlyAdded(SessionRadio):
    """Music tracks created during the last 30 days."""

    def get_queryset(self, **kwargs):
        from django.utils import timezone

        # Aware datetime: comparing the creation_date datetime field against a naive
        # date makes Django warn under time zone support.
        since = timezone.now() - datetime.timedelta(days=30)
        qs = super().get_queryset(**kwargs)
        return qs.filter(
            Q(artist_credit__artist__content_category="music"),
            Q(creation_date__gt=since),
        )


# Use this to experiment on the custom multiple radio with troi
@registry.register(name="troi")
class Troi(SessionRadio):
    """
    Receive a vuejs generated config and use it to launch a troi radio session.
    The config data should look like:
    {"patch": "troi_patch_name", "troi_arg1": "troi_arg_1", "troi_arg2": ...}
    Validation of the config (args) is done by troi during track fetch.
    Funkwhale only checks if the patch is implemented.
    """

    config = serializers.JSONField(required=True)

    def append_lb_config(self, data):
        """Return *data* with the user's ListenBrainz credentials added.

        ``lb_user_name`` and ``lb_user_token`` are read from the session user's
        settings and stored under ``user_name`` / ``user_token``. A warning is
        logged for each missing value. If the user has no settings, *data* is
        returned unchanged.

        The credentials are written into a shallow copy. ``get_queryset``
        passes ``self.session.config`` here, and writing in place would put the
        token into the session's own config object.
        """
        settings = self.session.user.settings
        if settings is None:
            logger.warning(
                "No lb_user_name set in user settings. Some troi patches will fail"
            )
            return data

        data = dict(data)
        if settings.get("lb_user_name") is None:
            logger.warning(
                "No lb_user_name set in user settings. Some troi patches will fail"
            )
        else:
            data["user_name"] = settings["lb_user_name"]

        if settings.get("lb_user_token") is None:
            logger.warning(
                "No lb_user_token set in user settings. Some troi patch will fail"
            )
        else:
            data["user_token"] = settings["lb_user_token"]
        return data

    def get_queryset_kwargs(self):
        kwargs = super().get_queryset_kwargs()
        kwargs["config"] = self.session.config
        return kwargs

    def validate_session(self, data, **context):
        """Require a non-null ``config``; raises ``ValidationError`` otherwise."""
        data = super().validate_session(data, **context)
        if data.get("config") is None:
            raise serializers.ValidationError(
                "You must provide a configuration for this radio"
            )
        return data

    def get_queryset(self, **kwargs):
        qs = super().get_queryset(**kwargs)
        config = self.append_lb_config(kwargs["config"])
        return lb_recommendations.run(config, candidates=qs)