Added API endpoint to insert multiple tracks into playlist
This commit is contained in:
parent
1729c4f83e
commit
f8b15a3f48
|
@ -1,8 +1,10 @@
|
||||||
from django import forms
|
from django.conf import settings
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
|
from rest_framework import exceptions
|
||||||
|
|
||||||
from funkwhale_api.common import fields
|
from funkwhale_api.common import fields
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,17 +42,15 @@ class Playlist(models.Model):
|
||||||
index = total
|
index = total
|
||||||
|
|
||||||
if index > total:
|
if index > total:
|
||||||
raise forms.ValidationError('Index is not continuous')
|
raise exceptions.ValidationError('Index is not continuous')
|
||||||
|
|
||||||
if index < 0:
|
if index < 0:
|
||||||
raise forms.ValidationError('Index must be zero or positive')
|
raise exceptions.ValidationError('Index must be zero or positive')
|
||||||
|
|
||||||
if move:
|
if move:
|
||||||
# we remove the index temporarily, to avoid integrity errors
|
# we remove the index temporarily, to avoid integrity errors
|
||||||
plt.index = None
|
plt.index = None
|
||||||
plt.save(update_fields=['index'])
|
plt.save(update_fields=['index'])
|
||||||
|
|
||||||
if move:
|
|
||||||
if index > old_index:
|
if index > old_index:
|
||||||
# new index is higher than current, we decrement previous tracks
|
# new index is higher than current, we decrement previous tracks
|
||||||
to_update = existing.filter(
|
to_update = existing.filter(
|
||||||
|
@ -58,8 +58,7 @@ class Playlist(models.Model):
|
||||||
to_update.update(index=models.F('index') - 1)
|
to_update.update(index=models.F('index') - 1)
|
||||||
if index < old_index:
|
if index < old_index:
|
||||||
# new index is lower than current, we increment next tracks
|
# new index is lower than current, we increment next tracks
|
||||||
to_update = existing.filter(
|
to_update = existing.filter(index__lt=old_index, index__gte=index)
|
||||||
index__lt=old_index, index__gte=index)
|
|
||||||
to_update.update(index=models.F('index') + 1)
|
to_update.update(index=models.F('index') + 1)
|
||||||
else:
|
else:
|
||||||
to_update = existing.filter(index__gte=index)
|
to_update = existing.filter(index__gte=index)
|
||||||
|
@ -77,6 +76,24 @@ class Playlist(models.Model):
|
||||||
to_update = existing.filter(index__gt=index)
|
to_update = existing.filter(index__gt=index)
|
||||||
return to_update.update(index=models.F('index') - 1)
|
return to_update.update(index=models.F('index') - 1)
|
||||||
|
|
||||||
|
@transaction.atomic
|
||||||
|
def insert_many(self, tracks):
|
||||||
|
existing = self.playlist_tracks.select_for_update()
|
||||||
|
now = timezone.now()
|
||||||
|
total = existing.filter(index__isnull=False).count()
|
||||||
|
if existing.count() + len(tracks) > settings.PLAYLISTS_MAX_TRACKS:
|
||||||
|
raise exceptions.ValidationError(
|
||||||
|
'Playlist would reach the maximum of {} tracks'.format(
|
||||||
|
settings.PLAYLISTS_MAX_TRACKS))
|
||||||
|
self.save(update_fields=['modification_date'])
|
||||||
|
start = total
|
||||||
|
plts = [
|
||||||
|
PlaylistTrack(
|
||||||
|
creation_date=now, playlist=self, track=track, index=start+i)
|
||||||
|
for i, track in enumerate(tracks)
|
||||||
|
]
|
||||||
|
return PlaylistTrack.objects.bulk_create(plts)
|
||||||
|
|
||||||
|
|
||||||
class PlaylistTrack(models.Model):
|
class PlaylistTrack(models.Model):
|
||||||
track = models.ForeignKey(
|
track = models.ForeignKey(
|
||||||
|
|
|
@ -3,8 +3,9 @@ from django.db import transaction
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
from taggit.models import Tag
|
from taggit.models import Tag
|
||||||
|
|
||||||
|
from funkwhale_api.music.models import Track
|
||||||
from funkwhale_api.music.serializers import TrackSerializerNested
|
from funkwhale_api.music.serializers import TrackSerializerNested
|
||||||
|
from funkwhale_api.users.serializers import UserBasicSerializer
|
||||||
from . import models
|
from . import models
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,20 +62,34 @@ class PlaylistTrackWriteSerializer(serializers.ModelSerializer):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class PlaylistWriteSerializer(serializers.ModelSerializer):
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = models.Playlist
|
||||||
|
fields = [
|
||||||
|
'id',
|
||||||
|
'name',
|
||||||
|
'privacy_level',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class PlaylistSerializer(serializers.ModelSerializer):
|
class PlaylistSerializer(serializers.ModelSerializer):
|
||||||
tracks_count = serializers.SerializerMethodField()
|
tracks_count = serializers.SerializerMethodField()
|
||||||
|
user = UserBasicSerializer()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = models.Playlist
|
model = models.Playlist
|
||||||
fields = (
|
fields = (
|
||||||
'id',
|
'id',
|
||||||
'name',
|
'name',
|
||||||
|
'user',
|
||||||
'tracks_count',
|
'tracks_count',
|
||||||
'privacy_level',
|
'privacy_level',
|
||||||
'creation_date',
|
'creation_date',
|
||||||
'modification_date')
|
'modification_date')
|
||||||
read_only_fields = [
|
read_only_fields = [
|
||||||
'id',
|
'id',
|
||||||
|
'user',
|
||||||
'modification_date',
|
'modification_date',
|
||||||
'creation_date',]
|
'creation_date',]
|
||||||
|
|
||||||
|
@ -84,3 +99,8 @@ class PlaylistSerializer(serializers.ModelSerializer):
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# no annotation?
|
# no annotation?
|
||||||
return obj.playlist_tracks.count()
|
return obj.playlist_tracks.count()
|
||||||
|
|
||||||
|
|
||||||
|
class PlaylistAddManySerializer(serializers.Serializer):
|
||||||
|
tracks = serializers.PrimaryKeyRelatedField(
|
||||||
|
many=True, queryset=Track.objects.for_nested_serialization())
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from django.db.models import Count
|
from django.db.models import Count
|
||||||
|
from django.db import transaction
|
||||||
|
|
||||||
|
from rest_framework import exceptions
|
||||||
from rest_framework import generics, mixins, viewsets
|
from rest_framework import generics, mixins, viewsets
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.decorators import detail_route
|
from rest_framework.decorators import detail_route
|
||||||
|
@ -25,7 +27,7 @@ class PlaylistViewSet(
|
||||||
|
|
||||||
serializer_class = serializers.PlaylistSerializer
|
serializer_class = serializers.PlaylistSerializer
|
||||||
queryset = (
|
queryset = (
|
||||||
models.Playlist.objects.all()
|
models.Playlist.objects.all().select_related('user')
|
||||||
.annotate(tracks_count=Count('playlist_tracks'))
|
.annotate(tracks_count=Count('playlist_tracks'))
|
||||||
)
|
)
|
||||||
permission_classes = [
|
permission_classes = [
|
||||||
|
@ -36,6 +38,11 @@ class PlaylistViewSet(
|
||||||
owner_checks = ['write']
|
owner_checks = ['write']
|
||||||
filter_class = filters.PlaylistFilter
|
filter_class = filters.PlaylistFilter
|
||||||
|
|
||||||
|
def get_serializer_class(self):
|
||||||
|
if self.request.method in ['PUT', 'PATCH', 'DELETE', 'POST']:
|
||||||
|
return serializers.PlaylistWriteSerializer
|
||||||
|
return self.serializer_class
|
||||||
|
|
||||||
@detail_route(methods=['get'])
|
@detail_route(methods=['get'])
|
||||||
def tracks(self, request, *args, **kwargs):
|
def tracks(self, request, *args, **kwargs):
|
||||||
playlist = self.get_object()
|
playlist = self.get_object()
|
||||||
|
@ -47,6 +54,24 @@ class PlaylistViewSet(
|
||||||
}
|
}
|
||||||
return Response(data, status=200)
|
return Response(data, status=200)
|
||||||
|
|
||||||
|
@detail_route(methods=['post'])
|
||||||
|
@transaction.atomic
|
||||||
|
def add(self, request, *args, **kwargs):
|
||||||
|
playlist = self.get_object()
|
||||||
|
serializer = serializers.PlaylistAddManySerializer(data=request.data)
|
||||||
|
serializer.is_valid(raise_exception=True)
|
||||||
|
try:
|
||||||
|
plts = playlist.insert_many(serializer.validated_data['tracks'])
|
||||||
|
except exceptions.ValidationError as e:
|
||||||
|
payload = {'playlist': e.detail}
|
||||||
|
return Response(payload, status=400)
|
||||||
|
serializer = serializers.PlaylistTrackSerializer(plts, many=True)
|
||||||
|
data = {
|
||||||
|
'count': len(plts),
|
||||||
|
'results': serializer.data
|
||||||
|
}
|
||||||
|
return Response(data, status=201)
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
return self.queryset.filter(
|
return self.queryset.filter(
|
||||||
fields.privacy_level_query(self.request.user))
|
fields.privacy_level_query(self.request.user))
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from django import forms
|
from rest_framework import exceptions
|
||||||
|
|
||||||
|
|
||||||
def test_can_insert_plt(factories):
|
def test_can_insert_plt(factories):
|
||||||
|
@ -79,14 +79,14 @@ def test_can_insert_and_move_last_to_0(factories):
|
||||||
def test_cannot_insert_at_wrong_index(factories):
|
def test_cannot_insert_at_wrong_index(factories):
|
||||||
plt = factories['playlists.PlaylistTrack']()
|
plt = factories['playlists.PlaylistTrack']()
|
||||||
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
|
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
|
||||||
with pytest.raises(forms.ValidationError):
|
with pytest.raises(exceptions.ValidationError):
|
||||||
plt.playlist.insert(new, 2)
|
plt.playlist.insert(new, 2)
|
||||||
|
|
||||||
|
|
||||||
def test_cannot_insert_at_negative_index(factories):
|
def test_cannot_insert_at_negative_index(factories):
|
||||||
plt = factories['playlists.PlaylistTrack']()
|
plt = factories['playlists.PlaylistTrack']()
|
||||||
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
|
new = factories['playlists.PlaylistTrack'](playlist=plt.playlist)
|
||||||
with pytest.raises(forms.ValidationError):
|
with pytest.raises(exceptions.ValidationError):
|
||||||
plt.playlist.insert(new, -1)
|
plt.playlist.insert(new, -1)
|
||||||
|
|
||||||
|
|
||||||
|
@ -103,3 +103,24 @@ def test_remove_update_indexes(factories):
|
||||||
|
|
||||||
assert first.index == 0
|
assert first.index == 0
|
||||||
assert third.index == 1
|
assert third.index == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_insert_many(factories):
|
||||||
|
playlist = factories['playlists.Playlist']()
|
||||||
|
existing = factories['playlists.PlaylistTrack'](playlist=playlist, index=0)
|
||||||
|
tracks = factories['music.Track'].create_batch(size=3)
|
||||||
|
plts = playlist.insert_many(tracks)
|
||||||
|
for i, plt in enumerate(plts):
|
||||||
|
assert plt.index == i + 1
|
||||||
|
assert plt.track == tracks[i]
|
||||||
|
assert plt.playlist == playlist
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_many_honor_max_tracks(factories, settings):
|
||||||
|
settings.PLAYLISTS_MAX_TRACKS = 4
|
||||||
|
playlist = factories['playlists.Playlist']()
|
||||||
|
plts = factories['playlists.PlaylistTrack'].create_batch(
|
||||||
|
size=2, playlist=playlist)
|
||||||
|
track = factories['music.Track']()
|
||||||
|
with pytest.raises(exceptions.ValidationError):
|
||||||
|
playlist.insert_many([track, track, track])
|
||||||
|
|
|
@ -153,3 +153,20 @@ def test_can_list_tracks_from_playlist(
|
||||||
|
|
||||||
assert response.data['count'] == 1
|
assert response.data['count'] == 1
|
||||||
assert response.data['results'][0] == serialized_plt
|
assert response.data['results'][0] == serialized_plt
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_add_multiple_tracks_at_once_via_api(
|
||||||
|
factories, mocker, logged_in_api_client):
|
||||||
|
playlist = factories['playlists.Playlist'](user=logged_in_api_client.user)
|
||||||
|
tracks = factories['music.Track'].create_batch(size=5)
|
||||||
|
track_ids = [t.id for t in tracks]
|
||||||
|
mocker.spy(playlist, 'insert_many')
|
||||||
|
url = reverse('api:v1:playlists-add', kwargs={'pk': playlist.pk})
|
||||||
|
response = logged_in_api_client.post(url, {'tracks': track_ids})
|
||||||
|
|
||||||
|
assert response.status_code == 201
|
||||||
|
assert playlist.playlist_tracks.count() == len(track_ids)
|
||||||
|
|
||||||
|
for plt in playlist.playlist_tracks.order_by('index'):
|
||||||
|
assert response.data['results'][plt.index]['id'] == plt.id
|
||||||
|
assert plt.track == tracks[plt.index]
|
||||||
|
|
Loading…
Reference in New Issue