527 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			527 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
| import datetime
 | |
| import json
 | |
| import logging
 | |
| import pickle
 | |
| import random
 | |
| from typing import List, Optional, Tuple
 | |
| 
 | |
| from django.core.cache import cache
 | |
| from django.core.exceptions import ValidationError
 | |
| from django.db import connection
 | |
| from django.db.models import Q
 | |
| from rest_framework import serializers
 | |
| 
 | |
| from funkwhale_api.federation import fields as federation_fields
 | |
| from funkwhale_api.federation import models as federation_models
 | |
| from funkwhale_api.moderation import filters as moderation_filters
 | |
| from funkwhale_api.music.models import Artist, Library, Track, Upload
 | |
| from funkwhale_api.tags.models import Tag
 | |
| 
 | |
| from . import filters, lb_recommendations, models
 | |
| from .registries_v2 import registry
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
class SimpleRadio:
    """Minimal radio base: no related object, no-op validation, weighted draw."""

    # Serializer field describing the related object; ``None`` here.
    related_object_field = None

    def clean(self, instance):
        """Validation hook; this base implementation accepts any instance."""
        return None

    def weighted_pick(
        self,
        choices: List[Tuple[int, int]],
        previous_choices: Optional[List[int]] = None,
    ) -> int:
        """Draw one choice at random, proportionally to its weight.

        ``choices`` is a list of ``(choice, weight)`` pairs. ``previous_choices``
        is accepted but not consulted. When no cumulative weight reaches the
        drawn threshold (e.g. ``choices`` is empty) the result is ``None``.
        """
        threshold = random.uniform(0, sum(weight for _, weight in choices))
        running = 0
        for candidate, weight in choices:
            running += weight
            if running >= threshold:
                return candidate
        return None
 | |
| 
 | |
| 
 | |
class SessionRadio(SimpleRadio):
    """Radio bound to a radio session, serving cached random batches of tracks.

    Subclasses shape the candidate pool by overriding ``get_queryset``,
    ``get_queryset_kwargs`` and ``filter_queryset``.
    """

    def __init__(self, session=None):
        self.session = session

    def start_session(self, user, **kwargs):
        """Create a ``RadioSession`` for ``user``, keep it on ``self`` and return it.

        ``self.radio_type`` is not defined in this class; it has to be provided
        elsewhere (presumably set at registration time -- TODO confirm).
        """
        self.session = models.RadioSession.objects.create(
            user=user, radio_type=self.radio_type, **kwargs
        )
        return self.session

    def get_queryset(self, **kwargs):
        """Return the base track candidates for the session user.

        Tracks go through ``with_playable_uploads`` for the session actor, and
        everything matching the user's moderation filter query is excluded.
        """
        actor = None
        try:
            actor = self.session.user.actor
        except KeyError:
            # NOTE(review): attribute access would normally raise
            # AttributeError rather than KeyError -- confirm which failure this
            # handler is meant to cover before widening it.
            pass  # Maybe logging would be helpful

        qs = (
            Track.objects.all()
            .with_playable_uploads(actor=actor)
            .prefetch_related(
                "artist_credit__artist", "album__artist_credit__artist", "attributed_to"
            )
        )

        query = moderation_filters.get_filtered_content_query(
            config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
            user=self.session.user,
        )
        return qs.exclude(query)

    def get_queryset_kwargs(self):
        """Return extra keyword arguments forwarded to ``get_queryset``."""
        return {}

    def filter_queryset(self, queryset):
        """Final filtering hook; the base implementation changes nothing."""
        return queryset

    def filter_from_session(self, queryset):
        """Drop tracks that already belong to this session."""
        already_played = self.session.session_tracks.all().values_list(
            "track", flat=True
        )
        return queryset.exclude(pk__in=already_played)

    def cache_batch_radio_track(self, **kwargs):
        """Pick a fresh random batch, record it on the session and cache it.

        ``kwargs`` must contain ``filter_playable``. Two cache entries are
        written with a 3600 second timeout: the pickled list of tracks
        (new batch followed by any previously cached tracks) and the sliced
        queryset itself.

        Returns the sliced queryset of the new batch.
        Raises ``ValueError`` when the batch is empty and nothing was cached.
        """
        BATCH_SIZE = 100
        tracks_key = f"radiotracks{self.session.id}"

        # get cached RadioTracks if any; ``pickle.loads(None)`` raises
        # TypeError, which is how a cache miss is detected here.
        try:
            cached_evaluated_radio_tracks = pickle.loads(cache.get(tracks_key))
        except TypeError:
            cached_evaluated_radio_tracks = None

        # get the queryset and apply filters
        kwargs.update(self.get_queryset_kwargs())
        queryset = self.get_queryset(**kwargs)
        queryset = self.filter_from_session(queryset)

        if kwargs["filter_playable"] is True:
            queryset = queryset.playable_by(
                self.session.user.actor if self.session.user else None
            )
        queryset = self.filter_queryset(queryset)

        # select a random batch of the qs
        sliced_queryset = queryset.order_by("?")[:BATCH_SIZE]
        if len(sliced_queryset) == 0 and not cached_evaluated_radio_tracks:
            raise ValueError("No more radio candidates")

        # create the radio session tracks into db in bulk
        self.session.add(sliced_queryset)

        # evaluate the queryset to save it in cache
        radio_tracks = list(sliced_queryset)

        if cached_evaluated_radio_tracks is not None:
            radio_tracks.extend(cached_evaluated_radio_tracks)
        logger.info(
            f"Setting redis cache for radio generation with radio id {self.session.id}"
        )
        cache.set(tracks_key, pickle.dumps(radio_tracks), 3600)
        cache.set(f"radioqueryset{self.session.id}", sliced_queryset, 3600)

        return sliced_queryset

    def get_choices(self, quantity, **kwargs):
        """Return up to ``quantity`` tracks, reusing the cached batch when possible.

        The tracks cache entry is read exactly once, so it cannot expire
        between an existence check and its use. A new batch is generated when
        the cache is empty, holds fewer than ``quantity`` tracks, or when the
        cached queryset entry is missing.
        """
        cached_payload = cache.get(f"radiotracks{self.session.id}")
        if not cached_payload:
            return self.cache_batch_radio_track(**kwargs)[:quantity]

        radio_tracks = pickle.loads(cached_payload)
        logger.info("Using redis cache for radio generation")
        if len(radio_tracks) < quantity:
            logger.info(
                "Not enough radio tracks in cache. Trying to generate new cache"
            )
            # The freshly generated batch is what was just written to the
            # cache, so use it directly instead of reading it back.
            sliced_queryset = self.cache_batch_radio_track(**kwargs)
        else:
            sliced_queryset = cache.get(f"radioqueryset{self.session.id}")
            if sliced_queryset is None:
                # The two entries are stored separately; if only the queryset
                # entry is gone, rebuild instead of slicing ``None``.
                sliced_queryset = self.cache_batch_radio_track(**kwargs)

        return sliced_queryset[:quantity]

    def pick_many(self, quantity, **kwargs):
        """Return ``quantity`` choices, logging when no session is attached."""
        if not self.session:
            logger.info(
                "No radio session. Can't track user playback. Won't cache queryset results"
            )
        return self.get_choices(quantity=quantity, **kwargs)

    def validate_session(self, data, **context):
        """Validation hook for session creation data; returns ``data`` unchanged."""
        return data
 | |
| 
 | |
| 
 | |
@registry.register(name="random")
class RandomRadio(SessionRadio):
    """Radio registered as ``random``: music-category tracks in random order."""

    def get_queryset(self, **kwargs):
        """Restrict the base candidates to ``music`` artists and shuffle them."""
        candidates = super().get_queryset(**kwargs)
        music_only = candidates.filter(
            artist_credit__artist__content_category="music"
        )
        return music_only.order_by("?")
 | |
| 
 | |
| 
 | |
@registry.register(name="random_library")
class RandomLibraryRadio(SessionRadio):
    """Radio registered as ``random_library``: random music attributed to the user."""

    def get_queryset(self, **kwargs):
        """Keep music tracks listed in the session actor's ``attributed_tracks``."""
        candidates = super().get_queryset(**kwargs)
        session_actor = self.session.user.actor
        attributed_ids = session_actor.attributed_tracks.all().values_list(
            "id", flat=True
        )
        is_music = Q(artist_credit__artist__content_category="music")
        is_attributed = Q(pk__in=attributed_ids)
        return candidates.filter(is_music & is_attributed).order_by("?")
 | |
| 
 | |
| 
 | |
@registry.register(name="favorites")
class FavoritesRadio(SessionRadio):
    """Radio registered as ``favorites``: music tracks favorited by the user."""

    def get_queryset_kwargs(self):
        """Add the session user under ``user`` when a session is attached."""
        extra = super().get_queryset_kwargs()
        if not self.session:
            return extra
        extra["user"] = self.session.user
        return extra

    def get_queryset(self, **kwargs):
        """Keep music tracks found in the ``track_favorites`` of ``kwargs['user']``."""
        candidates = super().get_queryset(**kwargs)
        favorites = kwargs["user"].actor.track_favorites.all()
        favorite_track_ids = favorites.values_list("track", flat=True)
        return candidates.filter(
            pk__in=favorite_track_ids,
            artist_credit__artist__content_category="music",
        )
 | |
| 
 | |
| 
 | |
@registry.register(name="custom")
class CustomRadio(SessionRadio):
    """Radio registered as ``custom``: candidates come from a stored radio config."""

    def get_queryset_kwargs(self):
        """Forward the session user and its custom radio to ``get_queryset``."""
        kwargs = super().get_queryset_kwargs()
        kwargs["user"] = self.session.user
        kwargs["custom_radio"] = self.session.custom_radio
        return kwargs

    def get_queryset(self, **kwargs):
        """Run the custom radio's ``config`` through ``filters.run`` on the base candidates."""
        qs = super().get_queryset(**kwargs)
        return filters.run(kwargs["custom_radio"].config, candidates=qs)

    def validate_session(self, data, **context):
        """Check that a custom radio is given and that the user may use it.

        The user is taken from ``data['user']``, falling back to
        ``context['user']``. Access is granted when the radio belongs to that
        user or is public.

        Raises ``serializers.ValidationError`` when ``custom_radio`` is missing
        or not accessible. The checks are explicit conditionals rather than
        ``assert`` statements so they still run under ``python -O``.
        """
        data = super().validate_session(data, **context)
        try:
            user = data["user"]
        except KeyError:
            user = context.get("user")

        try:
            custom_radio = data["custom_radio"]
        except KeyError:
            raise serializers.ValidationError("You must provide a custom radio")

        if not (custom_radio.user == user or custom_radio.is_public):
            raise serializers.ValidationError("You don't have access to this radio")
        return data
 | |
| 
 | |
| 
 | |
@registry.register(name="custom_multiple")
class CustomMultiple(SessionRadio):
    """
    Receive a vuejs generated config and use it to launch a radio session
    """

    # Serializer field describing the required JSON configuration.
    config = serializers.JSONField(required=True)

    def get_config(self, data):
        """Return the ``config`` entry of ``data``."""
        return data["config"]

    def get_queryset_kwargs(self):
        """Forward the session's ``config`` to ``get_queryset``."""
        kwargs = super().get_queryset_kwargs()
        kwargs["config"] = self.session.config
        return kwargs

    def validate_session(self, data, **context):
        """Require a non-null ``config`` entry in ``data``.

        Raises ``serializers.ValidationError`` when ``config`` is missing or
        ``None``. A missing key and an explicit ``None`` are treated alike,
        matching the check performed by the ``troi`` radio.
        """
        data = super().validate_session(data, **context)
        if data.get("config") is None:
            raise serializers.ValidationError(
                "You must provide a configuration for this radio"
            )
        return data

    def get_queryset(self, **kwargs):
        """Run the config, wrapped in a one-item list, through ``filters.run``."""
        qs = super().get_queryset(**kwargs)
        return filters.run([kwargs["config"]], candidates=qs)
 | |
| 
 | |
| 
 | |
class RelatedObjectRadio(SessionRadio):
    """Abstract radio related to an object (tag, artist, user...)"""

    # Serializer field used for the related object's identifier.
    related_object_field = serializers.IntegerField(required=True)

    def clean(self, instance):
        """Require ``instance.related_object`` to be set and of type ``self.model``.

        ``self.model`` is expected to be defined by concrete subclasses.
        Raises ``ValidationError`` when either condition fails.
        """
        super().clean(instance)
        related = instance.related_object
        if not related:
            raise ValidationError(
                "Cannot start RelatedObjectRadio without related object"
            )
        if isinstance(related, self.model):
            return
        raise ValidationError("Trying to start radio with bad related object")

    def get_related_object(self, pk):
        """Fetch the ``self.model`` row whose primary key is ``pk``."""
        manager = self.model.objects
        return manager.get(pk=pk)
 | |
| 
 | |
| 
 | |
@registry.register(name="tag")
class TagRadio(RelatedObjectRadio):
    """Radio registered as ``tag``: tracks tagged directly or via artist/album."""

    model = Tag
    # Tags are identified by name, hence a character field.
    related_object_field = serializers.CharField(required=True)

    def get_related_object(self, name):
        """Fetch the tag whose ``name`` matches."""
        return self.model.objects.get(name=name)

    def get_queryset(self, **kwargs):
        """Keep candidates whose track, artist or album carries the session tag."""
        candidates = super().get_queryset(**kwargs)
        tag = self.session.related_object
        on_track = Q(tagged_items__tag=tag)
        # NOTE(review): other radios in this file traverse
        # ``artist_credit__artist``; confirm ``artist`` is still a valid lookup
        # on the track model.
        on_artist = Q(artist__tagged_items__tag=tag)
        on_album = Q(album__tagged_items__tag=tag)
        return candidates.filter(on_track | on_artist | on_album)

    def get_related_object_id_repr(self, obj):
        """Represent the tag by its name."""
        return obj.name
 | |
| 
 | |
| 
 | |
def weighted_choice(choices):
    """Return one choice drawn at random, proportionally to its weight.

    ``choices`` is an iterable of ``(choice, weight)`` pairs that can be
    iterated twice.

    Raises ``AssertionError`` when no choice could be selected, for example
    when ``choices`` is empty. The exception is raised explicitly rather than
    through an ``assert`` statement, which would be stripped under
    ``python -O`` and let the function return ``None`` silently.
    """
    total = sum(weight for _, weight in choices)
    threshold = random.uniform(0, total)
    upto = 0
    for choice, weight in choices:
        if upto + weight >= threshold:
            return choice
        upto += weight
    raise AssertionError("Shouldn't get here")
 | |
| 
 | |
| 
 | |
class NextNotFound(Exception):
    """Raised when no follow-up track can be found for a seed track."""
 | |
| 
 | |
| 
 | |
@registry.register(name="similar")
class SimilarRadio(RelatedObjectRadio):
    """Radio registered as ``similar``: picks tracks that followed a seed in listening history."""

    model = Track

    def filter_queryset(self, queryset):
        """Narrow ``queryset`` to a single track that followed one of the seeds.

        Seeds are the three most recent session tracks, then the related
        track. The first seed yielding a candidate wins; an empty queryset is
        returned when none does.
        """
        queryset = super().filter_queryset(queryset)
        seeds = list(
            self.session.session_tracks.all()
            .values_list("track_id", flat=True)
            .order_by("-id")[:3]
        ) + [self.session.related_object.pk]
        for seed in seeds:
            try:
                return queryset.filter(pk=self.find_next_id(queryset, seed))
            except NextNotFound:
                continue

        return queryset.none()

    def find_next_id(self, queryset, seed):
        """Return the id of a track that was listened to right after ``seed``.

        The SQL pairs each listening with the next one by the same actor
        (``LEAD`` window) and keeps the rows where ``seed`` was followed by a
        different track. Candidates not present in ``queryset`` are dropped,
        then one of the remaining ids is chosen uniformly at random.

        Raises ``NextNotFound`` when no candidate remains.
        """
        with connection.cursor() as cursor:
            query = """
            SELECT next, count(next) AS c
            FROM (
                SELECT
                    history_listening.track_id,
                    history_listening.creation_date,
                    LEAD(history_listening.track_id) OVER (
                        PARTITION BY history_listening.actor_id ORDER BY history_listening.creation_date ASC
                    ) AS next
                FROM history_listening
                INNER JOIN federation_actor ON federation_actor.id = history_listening.actor_id
                INNER JOIN users_user ON users_user.actor_id = federation_actor.id
                WHERE users_user.privacy_level = 'instance' OR users_user.privacy_level = 'everyone' \
                    OR history_listening.actor_id = %s
                ORDER BY history_listening.creation_date ASC
            ) t WHERE track_id = %s AND next != %s GROUP BY next ORDER BY c DESC;
            """
            # NOTE(review): the first placeholder is compared against
            # ``history_listening.actor_id`` but is bound to the session's
            # ``user_id`` -- confirm these identifiers are interchangeable.
            cursor.execute(query, [self.session.user_id, seed, seed])
            next_candidates = list(cursor.fetchall())

        if not next_candidates:
            raise NextNotFound()

        candidate_ids = [row[0] for row in next_candidates]
        # A set keeps the membership test below constant-time per candidate.
        matching_tracks = set(
            queryset.filter(pk__in=candidate_ids).values_list("id", flat=True)
        )
        eligible_ids = [pk for pk in candidate_ids if pk in matching_tracks]
        if not eligible_ids:
            raise NextNotFound()
        return random.choice(eligible_ids)
 | |
| 
 | |
| 
 | |
@registry.register(name="artist")
class ArtistRadio(RelatedObjectRadio):
    """Radio registered as ``artist``: tracks credited to the related artist."""

    model = Artist

    def get_queryset(self, **kwargs):
        """Keep candidates whose artist credit points at the session's related object."""
        candidates = super().get_queryset(**kwargs)
        artist = self.session.related_object
        return candidates.filter(artist_credit__artist=artist)
 | |
| 
 | |
| 
 | |
@registry.register(name="less-listened")
class LessListenedRadio(SessionRadio):
    """Radio registered as ``less-listened``: music the user has no listening for."""

    def clean(self, instance):
        """Set the instance's user as its related object, then run parent validation."""
        instance.related_object = instance.user
        super().clean(instance)

    def get_queryset(self, **kwargs):
        """Return shuffled music candidates, minus tracks in the actor's listenings."""
        candidates = super().get_queryset(**kwargs)
        session_actor = self.session.user.actor
        listened_ids = session_actor.listenings.all().values_list("track", flat=True)
        music_only = candidates.filter(
            artist_credit__artist__content_category="music"
        )
        return music_only.exclude(pk__in=listened_ids).order_by("?")
 | |
| 
 | |
| 
 | |
@registry.register(name="less-listened_library")
class LessListenedLibraryRadio(SessionRadio):
    """Radio registered as ``less-listened_library``: unlistened music attributed to the user."""

    def clean(self, instance):
        """Set the instance's user as its related object, then run parent validation."""
        instance.related_object = instance.user
        super().clean(instance)

    def get_queryset(self, **kwargs):
        """Return shuffled music from ``attributed_tracks``, minus listened tracks."""
        candidates = super().get_queryset(**kwargs)
        listened_ids = self.session.user.actor.listenings.all().values_list(
            "track", flat=True
        )
        attributed_ids = self.session.user.actor.attributed_tracks.all().values_list(
            "id", flat=True
        )
        is_music = Q(artist_credit__artist__content_category="music")
        is_attributed = Q(pk__in=attributed_ids)
        selected = candidates.filter(is_music & is_attributed)
        return selected.exclude(pk__in=listened_ids).order_by("?")
 | |
| 
 | |
| 
 | |
@registry.register(name="actor-content")
class ActorContentRadio(RelatedObjectRadio):
    """
    Play content from given actor libraries
    """

    model = federation_models.Actor
    # The related actor is resolved by the serializer field itself.
    related_object_field = federation_fields.ActorRelatedField(required=True)

    def get_related_object(self, value):
        """Return ``value`` as is; no lookup is performed here."""
        return value

    def get_queryset(self, **kwargs):
        """Keep candidates having an upload in a library owned by the related actor."""
        candidates = super().get_queryset(**kwargs)
        owner = self.session.related_object
        uploaded_tracks = Upload.objects.filter(library__actor=owner).values("track")
        return candidates.filter(pk__in=uploaded_tracks)

    def get_related_object_id_repr(self, obj):
        """Represent the actor by its ``full_username``."""
        return obj.full_username
 | |
| 
 | |
| 
 | |
@registry.register(name="library")
class LibraryRadio(RelatedObjectRadio):
    """
    Play content from a given library
    """

    model = Library
    # Libraries are identified by UUID.
    related_object_field = serializers.UUIDField(required=True)

    def get_related_object(self, value):
        """Fetch the library whose ``uuid`` equals ``value``."""
        libraries = Library.objects
        return libraries.get(uuid=value)

    def get_queryset(self, **kwargs):
        """Keep candidates having an upload in the related library."""
        candidates = super().get_queryset(**kwargs)
        library = self.session.related_object
        uploaded_tracks = Upload.objects.filter(library=library).values("track")
        return candidates.filter(pk__in=uploaded_tracks)

    def get_related_object_id_repr(self, obj):
        """Represent the library by its ``uuid``."""
        return obj.uuid
 | |
| 
 | |
| 
 | |
@registry.register(name="recently-added")
class RecentlyAdded(SessionRadio):
    """Radio registered as ``recently-added``: music created in the last 30 days."""

    def get_queryset(self, **kwargs):
        """Keep music candidates whose ``creation_date`` is after today minus 30 days."""
        cutoff = datetime.date.today() - datetime.timedelta(days=30)
        candidates = super().get_queryset(**kwargs)
        is_music = Q(artist_credit__artist__content_category="music")
        is_recent = Q(creation_date__gt=cutoff)
        return candidates.filter(is_music, is_recent)
 | |
| 
 | |
| 
 | |
# Use this to experiment on the custom multiple radio with troi
@registry.register(name="troi")
class Troi(SessionRadio):
    """
    Receive a vuejs generated config and use it to launch a troi radio session.
    The config data should follow :
    {"patch": "troi_patch_name", "troi_arg1":"troi_arg_1", "troi_arg2": ...}
    Validation of the config (args) is done by troi during track fetch.
    Funkwhale only checks if the patch is implemented
    """

    # Serializer field describing the required JSON configuration.
    config = serializers.JSONField(required=True)

    def append_lb_config(self, data):
        """Copy ``lb_user_name`` / ``lb_user_token`` from the user settings into ``data``.

        Each setting that is absent only triggers a warning. When the user has
        no settings at all, ``data`` is returned untouched after one warning.
        """
        user_settings = self.session.user.settings
        if user_settings is None:
            logger.warning(
                "No lb_user_name set in user settings. Some troi patches will fail"
            )
            return data

        if user_settings.get("lb_user_name") is not None:
            data["user_name"] = user_settings["lb_user_name"]
        else:
            logger.warning(
                "No lb_user_name set in user settings. Some troi patches will fail"
            )

        if user_settings.get("lb_user_token") is not None:
            data["user_token"] = user_settings["lb_user_token"]
        else:
            logger.warning(
                "No lb_user_token set in user settings. Some troi patch will fail"
            )

        return data

    def get_queryset_kwargs(self):
        """Forward the session's ``config`` to ``get_queryset``."""
        extra = super().get_queryset_kwargs()
        extra["config"] = self.session.config
        return extra

    def validate_session(self, data, **context):
        """Reject session data whose ``config`` is missing or ``None``."""
        data = super().validate_session(data, **context)
        if data.get("config") is not None:
            return data
        raise serializers.ValidationError(
            "You must provide a configuration for this radio"
        )

    def get_queryset(self, **kwargs):
        """Decode the JSON config, add the user's settings and run ``lb_recommendations``."""
        candidates = super().get_queryset(**kwargs)
        decoded = json.loads(kwargs["config"])
        troi_config = self.append_lb_config(decoded)
        return lb_recommendations.run(troi_config, candidates=candidates)
 |