Added API endpoint to insert multiple tracks into playlist

This commit is contained in:
Eliot Berriot 2018-03-20 19:56:42 +01:00
parent 1729c4f83e
commit f8b15a3f48
No known key found for this signature in database
GPG Key ID: DD6965E2476E5C27
5 changed files with 112 additions and 12 deletions

View File

@ -1,8 +1,10 @@
from django import forms from django.conf import settings
from django.db import models from django.db import models
from django.db import transaction from django.db import transaction
from django.utils import timezone from django.utils import timezone
from rest_framework import exceptions
from funkwhale_api.common import fields from funkwhale_api.common import fields
@ -40,17 +42,15 @@ class Playlist(models.Model):
index = total index = total
if index > total: if index > total:
raise forms.ValidationError('Index is not continuous') raise exceptions.ValidationError('Index is not continuous')
if index < 0: if index < 0:
raise forms.ValidationError('Index must be zero or positive') raise exceptions.ValidationError('Index must be zero or positive')
if move: if move:
# we remove the index temporarily, to avoid integrity errors # we remove the index temporarily, to avoid integrity errors
plt.index = None plt.index = None
plt.save(update_fields=['index']) plt.save(update_fields=['index'])
if move:
if index > old_index: if index > old_index:
# new index is higher than current, we decrement previous tracks # new index is higher than current, we decrement previous tracks
to_update = existing.filter( to_update = existing.filter(
@ -58,8 +58,7 @@ class Playlist(models.Model):
to_update.update(index=models.F('index') - 1) to_update.update(index=models.F('index') - 1)
if index < old_index: if index < old_index:
# new index is lower than current, we increment next tracks # new index is lower than current, we increment next tracks
to_update = existing.filter( to_update = existing.filter(index__lt=old_index, index__gte=index)
index__lt=old_index, index__gte=index)
to_update.update(index=models.F('index') + 1) to_update.update(index=models.F('index') + 1)
else: else:
to_update = existing.filter(index__gte=index) to_update = existing.filter(index__gte=index)
@ -77,6 +76,24 @@ class Playlist(models.Model):
to_update = existing.filter(index__gt=index) to_update = existing.filter(index__gt=index)
return to_update.update(index=models.F('index') - 1) return to_update.update(index=models.F('index') - 1)
@transaction.atomic
def insert_many(self, tracks):
existing = self.playlist_tracks.select_for_update()
now = timezone.now()
total = existing.filter(index__isnull=False).count()
if existing.count() + len(tracks) > settings.PLAYLISTS_MAX_TRACKS:
raise exceptions.ValidationError(
'Playlist would reach the maximum of {} tracks'.format(
settings.PLAYLISTS_MAX_TRACKS))
self.save(update_fields=['modification_date'])
start = total
plts = [
PlaylistTrack(
creation_date=now, playlist=self, track=track, index=start+i)
for i, track in enumerate(tracks)
]
return PlaylistTrack.objects.bulk_create(plts)
class PlaylistTrack(models.Model): class PlaylistTrack(models.Model):
track = models.ForeignKey( track = models.ForeignKey(

View File

@ -3,8 +3,9 @@ from django.db import transaction
from rest_framework import serializers from rest_framework import serializers
from taggit.models import Tag from taggit.models import Tag
from funkwhale_api.music.models import Track
from funkwhale_api.music.serializers import TrackSerializerNested from funkwhale_api.music.serializers import TrackSerializerNested
from funkwhale_api.users.serializers import UserBasicSerializer
from . import models from . import models
@ -61,20 +62,34 @@ class PlaylistTrackWriteSerializer(serializers.ModelSerializer):
return [] return []
class PlaylistWriteSerializer(serializers.ModelSerializer):
class Meta:
model = models.Playlist
fields = [
'id',
'name',
'privacy_level',
]
class PlaylistSerializer(serializers.ModelSerializer): class PlaylistSerializer(serializers.ModelSerializer):
tracks_count = serializers.SerializerMethodField() tracks_count = serializers.SerializerMethodField()
user = UserBasicSerializer()
class Meta: class Meta:
model = models.Playlist model = models.Playlist
fields = ( fields = (
'id', 'id',
'name', 'name',
'user',
'tracks_count', 'tracks_count',
'privacy_level', 'privacy_level',
'creation_date', 'creation_date',
'modification_date') 'modification_date')
read_only_fields = [ read_only_fields = [
'id', 'id',
'user',
'modification_date', 'modification_date',
'creation_date',] 'creation_date',]
@ -84,3 +99,8 @@ class PlaylistSerializer(serializers.ModelSerializer):
except AttributeError: except AttributeError:
# no annotation? # no annotation?
return obj.playlist_tracks.count() return obj.playlist_tracks.count()
class PlaylistAddManySerializer(serializers.Serializer):
tracks = serializers.PrimaryKeyRelatedField(
many=True, queryset=Track.objects.for_nested_serialization())

View File

@ -1,5 +1,7 @@
from django.db.models import Count from django.db.models import Count
from django.db import transaction
from rest_framework import exceptions
from rest_framework import generics, mixins, viewsets from rest_framework import generics, mixins, viewsets
from rest_framework import status from rest_framework import status
from rest_framework.decorators import detail_route from rest_framework.decorators import detail_route
@ -25,7 +27,7 @@ class PlaylistViewSet(
serializer_class = serializers.PlaylistSerializer serializer_class = serializers.PlaylistSerializer
queryset = ( queryset = (
models.Playlist.objects.all() models.Playlist.objects.all().select_related('user')
.annotate(tracks_count=Count('playlist_tracks')) .annotate(tracks_count=Count('playlist_tracks'))
) )
permission_classes = [ permission_classes = [
@ -36,6 +38,11 @@ class PlaylistViewSet(
owner_checks = ['write'] owner_checks = ['write']
filter_class = filters.PlaylistFilter filter_class = filters.PlaylistFilter
def get_serializer_class(self):
if self.request.method in ['PUT', 'PATCH', 'DELETE', 'POST']:
return serializers.PlaylistWriteSerializer
return self.serializer_class
@detail_route(methods=['get']) @detail_route(methods=['get'])
def tracks(self, request, *args, **kwargs): def tracks(self, request, *args, **kwargs):
playlist = self.get_object() playlist = self.get_object()
@ -47,6 +54,24 @@ class PlaylistViewSet(
} }
return Response(data, status=200) return Response(data, status=200)
@detail_route(methods=['post'])
@transaction.atomic
def add(self, request, *args, **kwargs):
playlist = self.get_object()
serializer = serializers.PlaylistAddManySerializer(data=request.data)
serializer.is_valid(raise_exception=True)
try:
plts = playlist.insert_many(serializer.validated_data['tracks'])
except exceptions.ValidationError as e:
payload = {'playlist': e.detail}
return Response(payload, status=400)
serializer = serializers.PlaylistTrackSerializer(plts, many=True)
data = {
'count': len(plts),
'results': serializer.data
}
return Response(data, status=201)
def get_queryset(self): def get_queryset(self):
return self.queryset.filter( return self.queryset.filter(
fields.privacy_level_query(self.request.user)) fields.privacy_level_query(self.request.user))

View File

@ -1,6 +1,6 @@
import pytest import pytest
from django import forms from rest_framework import exceptions
def test_can_insert_plt(factories): def test_can_insert_plt(factories):
@ -79,14 +79,14 @@ def test_can_insert_and_move_last_to_0(factories):
def test_cannot_insert_at_wrong_index(factories): def test_cannot_insert_at_wrong_index(factories):
plt = factories['playlists.PlaylistTrack']() plt = factories['playlists.PlaylistTrack']()
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist) new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
with pytest.raises(forms.ValidationError): with pytest.raises(exceptions.ValidationError):
plt.playlist.insert(new, 2) plt.playlist.insert(new, 2)
def test_cannot_insert_at_negative_index(factories): def test_cannot_insert_at_negative_index(factories):
plt = factories['playlists.PlaylistTrack']() plt = factories['playlists.PlaylistTrack']()
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist) new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
with pytest.raises(forms.ValidationError): with pytest.raises(exceptions.ValidationError):
plt.playlist.insert(new, -1) plt.playlist.insert(new, -1)
@ -103,3 +103,24 @@ def test_remove_update_indexes(factories):
assert first.index == 0 assert first.index == 0
assert third.index == 1 assert third.index == 1
def test_can_insert_many(factories):
playlist = factories['playlists.Playlist']()
existing = factories['playlists.PlaylistTrack'](playlist=playlist, index=0)
tracks = factories['music.Track'].create_batch(size=3)
plts = playlist.insert_many(tracks)
for i, plt in enumerate(plts):
assert plt.index == i + 1
assert plt.track == tracks[i]
assert plt.playlist == playlist
def test_insert_many_honor_max_tracks(factories, settings):
settings.PLAYLISTS_MAX_TRACKS = 4
playlist = factories['playlists.Playlist']()
plts = factories['playlists.PlaylistTrack'].create_batch(
size=2, playlist=playlist)
track = factories['music.Track']()
with pytest.raises(exceptions.ValidationError):
playlist.insert_many([track, track, track])

View File

@ -153,3 +153,20 @@ def test_can_list_tracks_from_playlist(
assert response.data['count'] == 1 assert response.data['count'] == 1
assert response.data['results'][0] == serialized_plt assert response.data['results'][0] == serialized_plt
def test_can_add_multiple_tracks_at_once_via_api(
factories, mocker, logged_in_api_client):
playlist = factories['playlists.Playlist'](user=logged_in_api_client.user)
tracks = factories['music.Track'].create_batch(size=5)
track_ids = [t.id for t in tracks]
mocker.spy(playlist, 'insert_many')
url = reverse('api:v1:playlists-add', kwargs={'pk': playlist.pk})
response = logged_in_api_client.post(url, {'tracks': track_ids})
assert response.status_code == 201
assert playlist.playlist_tracks.count() == len(track_ids)
for plt in playlist.playlist_tracks.order_by('index'):
assert response.data['results'][plt.index]['id'] == plt.id
assert plt.track == tracks[plt.index]