Revert "Revert "Fix #994: use PostgreSQL full-text-search""

This reverts commit 7b0db234e2.
This commit is contained in:
Eliot Berriot 2019-12-18 11:26:59 +01:00
parent c2cb510eb9
commit 57949c02c1
14 changed files with 368 additions and 38 deletions

View File

@ -942,3 +942,5 @@ MODERATION_EMAIL_NOTIFICATIONS_ENABLED = env.bool(
# Delay in days after signup before we show the "support us" messages # Delay in days after signup before we show the "support us" messages
INSTANCE_SUPPORT_MESSAGE_DELAY = env.int("INSTANCE_SUPPORT_MESSAGE_DELAY", default=15) INSTANCE_SUPPORT_MESSAGE_DELAY = env.int("INSTANCE_SUPPORT_MESSAGE_DELAY", default=15)
FUNKWHALE_SUPPORT_MESSAGE_DELAY = env.int("FUNKWHALE_SUPPORT_MESSAGE_DELAY", default=15) FUNKWHALE_SUPPORT_MESSAGE_DELAY = env.int("FUNKWHALE_SUPPORT_MESSAGE_DELAY", default=15)
# XXX Stable release: remove
USE_FULL_TEXT_SEARCH = env.bool("USE_FULL_TEXT_SEARCH", default=True)

View File

@ -1,5 +1,6 @@
import django_filters import django_filters
from django import forms from django import forms
from django.conf import settings
from django.core.serializers.json import DjangoJSONEncoder from django.core.serializers.json import DjangoJSONEncoder
from django.db import models from django.db import models
@ -33,12 +34,18 @@ def privacy_level_query(user, lookup_field="privacy_level", user_field="user"):
class SearchFilter(django_filters.CharFilter): class SearchFilter(django_filters.CharFilter):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.search_fields = kwargs.pop("search_fields") self.search_fields = kwargs.pop("search_fields")
self.fts_search_fields = kwargs.pop("fts_search_fields", [])
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def filter(self, qs, value): def filter(self, qs, value):
if not value: if not value:
return qs return qs
query = search.get_query(value, self.search_fields) if settings.USE_FULL_TEXT_SEARCH and self.fts_search_fields:
query = search.get_fts_query(
value, self.fts_search_fields, model=self.parent.Meta.model
)
else:
query = search.get_query(value, self.search_fields)
return qs.filter(query) return qs.filter(query)

View File

@ -1,7 +1,10 @@
import re import re
from django.contrib.postgres.search import SearchQuery
from django.db.models import Q from django.db.models import Q
from . import utils
QUERY_REGEX = re.compile(r'(((?P<key>\w+):)?(?P<value>"[^"]+"|[\S]+))') QUERY_REGEX = re.compile(r'(((?P<key>\w+):)?(?P<value>"[^"]+"|[\S]+))')
@ -56,6 +59,41 @@ def get_query(query_string, search_fields):
return query return query
def get_fts_query(query_string, fts_fields=["body_text"], model=None):
if query_string.startswith('"') and query_string.endswith('"'):
# we pass the query directly to the FTS engine
query_string = query_string[1:-1]
else:
parts = query_string.replace(":", "").split(" ")
parts = ["{}:*".format(p) for p in parts if p]
if not parts:
return Q(pk=None)
query_string = "&".join(parts)
if not fts_fields or not query_string.strip():
return Q(pk=None)
query = None
for field in fts_fields:
if "__" in field and model:
# When we have a nested lookup, we switch to a subquery for enhanced performance
fk_field_name, lookup = (
field.split("__")[0],
"__".join(field.split("__")[1:]),
)
fk_field = model._meta.get_field(fk_field_name)
related_model = fk_field.related_model
subquery = related_model.objects.filter(
**{lookup: SearchQuery(query_string, search_type="raw")}
).values_list("pk", flat=True)
new_query = Q(**{"{}__in".format(fk_field_name): list(subquery)})
else:
new_query = Q(**{field: SearchQuery(query_string, search_type="raw")})
query = utils.join_queries_or(query, new_query)
return query
def filter_tokens(tokens, valid): def filter_tokens(tokens, valid):
return [t for t in tokens if t["key"] in valid] return [t for t in tokens if t["key"] in valid]

View File

@ -234,3 +234,10 @@ def get_updated_fields(conf, data, obj):
final_data[obj_field] = data_value final_data[obj_field] = data_value
return final_data return final_data
def join_queries_or(left, right):
if left:
return left | right
else:
return right

View File

@ -511,13 +511,6 @@ def prepare_deliveries_and_inbox_items(recipient_list, type, allowed_domains=Non
return inbox_items, deliveries, urls return inbox_items, deliveries, urls
def join_queries_or(left, right):
if left:
return left | right
else:
return right
def get_actors_from_audience(urls): def get_actors_from_audience(urls):
""" """
Given a list of urls such as [ Given a list of urls such as [
@ -539,22 +532,24 @@ def get_actors_from_audience(urls):
if url == PUBLIC_ADDRESS: if url == PUBLIC_ADDRESS:
continue continue
queries["actors"].append(url) queries["actors"].append(url)
queries["followed"] = join_queries_or( queries["followed"] = funkwhale_utils.join_queries_or(
queries["followed"], Q(target__followers_url=url) queries["followed"], Q(target__followers_url=url)
) )
final_query = None final_query = None
if queries["actors"]: if queries["actors"]:
final_query = join_queries_or(final_query, Q(fid__in=queries["actors"])) final_query = funkwhale_utils.join_queries_or(
final_query, Q(fid__in=queries["actors"])
)
if queries["followed"]: if queries["followed"]:
actor_follows = models.Follow.objects.filter(queries["followed"], approved=True) actor_follows = models.Follow.objects.filter(queries["followed"], approved=True)
final_query = join_queries_or( final_query = funkwhale_utils.join_queries_or(
final_query, Q(pk__in=actor_follows.values_list("actor", flat=True)) final_query, Q(pk__in=actor_follows.values_list("actor", flat=True))
) )
library_follows = models.LibraryFollow.objects.filter( library_follows = models.LibraryFollow.objects.filter(
queries["followed"], approved=True queries["followed"], approved=True
) )
final_query = join_queries_or( final_query = funkwhale_utils.join_queries_or(
final_query, Q(pk__in=library_follows.values_list("actor", flat=True)) final_query, Q(pk__in=library_follows.values_list("actor", flat=True))
) )
if not final_query: if not final_query:

View File

@ -23,7 +23,7 @@ TAG_FILTER = common_filters.MultipleQueryFilter(method=filter_tags)
class ArtistFilter( class ArtistFilter(
audio_filters.IncludeChannelsFilterSet, moderation_filters.HiddenContentFilterSet audio_filters.IncludeChannelsFilterSet, moderation_filters.HiddenContentFilterSet
): ):
q = fields.SearchFilter(search_fields=["name"]) q = fields.SearchFilter(search_fields=["name"], fts_search_fields=["body_text"])
playable = filters.BooleanFilter(field_name="_", method="filter_playable") playable = filters.BooleanFilter(field_name="_", method="filter_playable")
tag = TAG_FILTER tag = TAG_FILTER
scope = common_filters.ActorScopeFilter( scope = common_filters.ActorScopeFilter(
@ -49,7 +49,10 @@ class ArtistFilter(
class TrackFilter( class TrackFilter(
audio_filters.IncludeChannelsFilterSet, moderation_filters.HiddenContentFilterSet audio_filters.IncludeChannelsFilterSet, moderation_filters.HiddenContentFilterSet
): ):
q = fields.SearchFilter(search_fields=["title", "album__title", "artist__name"]) q = fields.SearchFilter(
search_fields=["title", "album__title", "artist__name"],
fts_search_fields=["body_text", "artist__body_text", "album__body_text"],
)
playable = filters.BooleanFilter(field_name="_", method="filter_playable") playable = filters.BooleanFilter(field_name="_", method="filter_playable")
tag = TAG_FILTER tag = TAG_FILTER
id = common_filters.MultipleQueryFilter(coerce=int) id = common_filters.MultipleQueryFilter(coerce=int)
@ -127,7 +130,10 @@ class AlbumFilter(
audio_filters.IncludeChannelsFilterSet, moderation_filters.HiddenContentFilterSet audio_filters.IncludeChannelsFilterSet, moderation_filters.HiddenContentFilterSet
): ):
playable = filters.BooleanFilter(field_name="_", method="filter_playable") playable = filters.BooleanFilter(field_name="_", method="filter_playable")
q = fields.SearchFilter(search_fields=["title", "artist__name"]) q = fields.SearchFilter(
search_fields=["title", "artist__name"],
fts_search_fields=["body_text", "artist__body_text"],
)
tag = TAG_FILTER tag = TAG_FILTER
scope = common_filters.ActorScopeFilter( scope = common_filters.ActorScopeFilter(
actor_field="tracks__uploads__library__actor", distinct=True actor_field="tracks__uploads__library__actor", distinct=True

View File

@ -0,0 +1,109 @@
# Generated by Django 2.2.7 on 2019-12-16 15:06
import django.contrib.postgres.search
import django.contrib.postgres.indexes
from django.db import migrations, models
import django.db.models.deletion
from django.db import connection
FIELDS = {
"music.Artist": {
"fields": [
'name',
],
"trigger_name": "music_artist_update_body_text"
},
"music.Track": {
"fields": ['title', 'copyright'],
"trigger_name": "music_track_update_body_text"
},
"music.Album": {
"fields": ['title'],
"trigger_name": "music_album_update_body_text"
},
}
def populate_body_text(apps, schema_editor):
for label, search_config in FIELDS.items():
model = apps.get_model(*label.split('.'))
print('Populating search index for {}'.format(model.__name__))
vector = django.contrib.postgres.search.SearchVector(*search_config['fields'])
model.objects.update(body_text=vector)
def rewind(apps, schema_editor):
pass
def setup_triggers(apps, schema_editor):
cursor = connection.cursor()
for label, search_config in FIELDS.items():
model = apps.get_model(*label.split('.'))
table = model._meta.db_table
print('Creating database trigger {} on {}'.format(search_config['trigger_name'], table))
sql = """
CREATE TRIGGER {trigger_name}
BEFORE INSERT OR UPDATE
ON {table}
FOR EACH ROW
EXECUTE PROCEDURE
tsvector_update_trigger(body_text, 'pg_catalog.english', {fields})
""".format(
trigger_name=search_config['trigger_name'],
table=table,
fields=', '.join(search_config['fields']),
)
print(sql)
cursor.execute(sql)
def rewind_triggers(apps, schema_editor):
cursor = connection.cursor()
for label, search_config in FIELDS.items():
model = apps.get_model(*label.split('.'))
table = model._meta.db_table
print('Dropping database trigger {} on {}'.format(search_config['trigger_name'], table))
sql = """
DROP TRIGGER IF EXISTS {trigger_name} ON {table}
""".format(
trigger_name=search_config['trigger_name'],
table=table,
)
cursor.execute(sql)
class Migration(migrations.Migration):
dependencies = [
('music', '0043_album_cover_attachment'),
]
operations = [
migrations.AddField(
model_name='album',
name='body_text',
field=django.contrib.postgres.search.SearchVectorField(blank=True),
),
migrations.AddField(
model_name='artist',
name='body_text',
field=django.contrib.postgres.search.SearchVectorField(blank=True),
),
migrations.AddField(
model_name='track',
name='body_text',
field=django.contrib.postgres.search.SearchVectorField(blank=True),
),
migrations.AddIndex(
model_name='album',
index=django.contrib.postgres.indexes.GinIndex(fields=['body_text'], name='music_album_body_te_0ec97a_gin'),
),
migrations.AddIndex(
model_name='artist',
index=django.contrib.postgres.indexes.GinIndex(fields=['body_text'], name='music_artis_body_te_5c408d_gin'),
),
migrations.AddIndex(
model_name='track',
index=django.contrib.postgres.indexes.GinIndex(fields=['body_text'], name='music_track_body_te_da0a66_gin'),
),
migrations.RunPython(setup_triggers, rewind_triggers),
migrations.RunPython(populate_body_text, rewind),
]

View File

@ -11,6 +11,8 @@ import pydub
from django.conf import settings from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation from django.contrib.contenttypes.fields import GenericRelation
from django.contrib.postgres.fields import JSONField from django.contrib.postgres.fields import JSONField
from django.contrib.postgres.search import SearchVectorField
from django.contrib.postgres.indexes import GinIndex
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
from django.core.serializers.json import DjangoJSONEncoder from django.core.serializers.json import DjangoJSONEncoder
@ -19,7 +21,6 @@ from django.db.models.signals import post_save, pre_save
from django.dispatch import receiver from django.dispatch import receiver
from django.urls import reverse from django.urls import reverse
from django.utils import timezone from django.utils import timezone
from versatileimagefield.fields import VersatileImageField from versatileimagefield.fields import VersatileImageField
from funkwhale_api import musicbrainz from funkwhale_api import musicbrainz
@ -56,10 +57,14 @@ class APIModelMixin(models.Model):
api_includes = [] api_includes = []
creation_date = models.DateTimeField(default=timezone.now, db_index=True) creation_date = models.DateTimeField(default=timezone.now, db_index=True)
import_hooks = [] import_hooks = []
body_text = SearchVectorField(blank=True)
class Meta: class Meta:
abstract = True abstract = True
ordering = ["-creation_date"] ordering = ["-creation_date"]
indexes = [
GinIndex(fields=["body_text"]),
]
@classmethod @classmethod
def get_or_create_from_api(cls, mbid): def get_or_create_from_api(cls, mbid):
@ -171,7 +176,12 @@ class ArtistQuerySet(common_models.LocalFromFidQuerySet, models.QuerySet):
def with_albums(self): def with_albums(self):
return self.prefetch_related( return self.prefetch_related(
models.Prefetch("albums", queryset=Album.objects.with_tracks_count()) models.Prefetch(
"albums",
queryset=Album.objects.with_tracks_count().select_related(
"attachment_cover", "attributed_to"
),
)
) )
def annotate_playable_by_actor(self, actor): def annotate_playable_by_actor(self, actor):
@ -524,6 +534,9 @@ class Track(APIModelMixin):
class Meta: class Meta:
ordering = ["album", "disc_number", "position"] ordering = ["album", "disc_number", "position"]
indexes = [
GinIndex(fields=["body_text"]),
]
def __str__(self): def __str__(self):
return self.title return self.title

View File

@ -4,7 +4,9 @@ import magic
import mutagen import mutagen
import pydub import pydub
from funkwhale_api.common.search import normalize_query, get_query # noqa from funkwhale_api.common.search import get_fts_query # noqa
from funkwhale_api.common.search import get_query # noqa
from funkwhale_api.common.search import normalize_query # noqa
def guess_mimetype(f): def guess_mimetype(f):

View File

@ -6,6 +6,7 @@ import urllib.parse
from django.conf import settings from django.conf import settings
from django.db import transaction from django.db import transaction
from django.db.models import Count, Prefetch, Sum, F, Q from django.db.models import Count, Prefetch, Sum, F, Q
import django.db.utils
from django.utils import timezone from django.utils import timezone
from rest_framework import mixins from rest_framework import mixins
@ -606,20 +607,30 @@ class Search(views.APIView):
anonymous_policy = "setting" anonymous_policy = "setting"
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
query = request.GET["query"] query = request.GET.get("query", request.GET.get("q", "")) or ""
results = { query = query.strip()
# 'tags': serializers.TagSerializer(self.get_tags(query), many=True).data, if not query:
"artists": serializers.ArtistWithAlbumsSerializer( return Response({"detail": "empty query"}, status=400)
self.get_artists(query), many=True try:
).data, results = {
"tracks": serializers.TrackSerializer( # 'tags': serializers.TagSerializer(self.get_tags(query), many=True).data,
self.get_tracks(query), many=True "artists": serializers.ArtistWithAlbumsSerializer(
).data, self.get_artists(query), many=True
"albums": serializers.AlbumSerializer( ).data,
self.get_albums(query), many=True "tracks": serializers.TrackSerializer(
).data, self.get_tracks(query), many=True
"tags": TagSerializer(self.get_tags(query), many=True).data, ).data,
} "albums": serializers.AlbumSerializer(
self.get_albums(query), many=True
).data,
"tags": TagSerializer(self.get_tags(query), many=True).data,
}
except django.db.utils.ProgrammingError as e:
if "in tsquery:" in str(e):
return Response({"detail": "Invalid query"}, status=400)
else:
raise
return Response(results, status=200) return Response(results, status=200)
def get_tracks(self, query): def get_tracks(self, query):
@ -629,28 +640,58 @@ class Search(views.APIView):
"album__title__unaccent", "album__title__unaccent",
"artist__name__unaccent", "artist__name__unaccent",
] ]
query_obj = utils.get_query(query, search_fields) if settings.USE_FULL_TEXT_SEARCH:
query_obj = utils.get_fts_query(
query,
fts_fields=["body_text", "album__body_text", "artist__body_text"],
model=models.Track,
)
else:
query_obj = utils.get_query(query, search_fields)
qs = ( qs = (
models.Track.objects.all() models.Track.objects.all()
.filter(query_obj) .filter(query_obj)
.prefetch_related("artist", "album__artist") .prefetch_related(
"artist",
"attributed_to",
Prefetch(
"album",
queryset=models.Album.objects.select_related(
"artist", "attachment_cover", "attributed_to"
),
),
)
) )
return common_utils.order_for_search(qs, "title")[: self.max_results] return common_utils.order_for_search(qs, "title")[: self.max_results]
def get_albums(self, query): def get_albums(self, query):
search_fields = ["mbid", "title__unaccent", "artist__name__unaccent"] search_fields = ["mbid", "title__unaccent", "artist__name__unaccent"]
query_obj = utils.get_query(query, search_fields) if settings.USE_FULL_TEXT_SEARCH:
query_obj = utils.get_fts_query(
query, fts_fields=["body_text", "artist__body_text"], model=models.Album
)
else:
query_obj = utils.get_query(query, search_fields)
qs = ( qs = (
models.Album.objects.all() models.Album.objects.all()
.filter(query_obj) .filter(query_obj)
.prefetch_related("tracks__artist", "artist", "attributed_to") .select_related("artist", "attachment_cover", "attributed_to")
.prefetch_related("tracks__artist")
) )
return common_utils.order_for_search(qs, "title")[: self.max_results] return common_utils.order_for_search(qs, "title")[: self.max_results]
def get_artists(self, query): def get_artists(self, query):
search_fields = ["mbid", "name__unaccent"] search_fields = ["mbid", "name__unaccent"]
query_obj = utils.get_query(query, search_fields) if settings.USE_FULL_TEXT_SEARCH:
qs = models.Artist.objects.all().filter(query_obj).with_albums() query_obj = utils.get_fts_query(query, model=models.Artist)
else:
query_obj = utils.get_query(query, search_fields)
qs = (
models.Artist.objects.all()
.filter(query_obj)
.with_albums()
.select_related("attributed_to")
)
return common_utils.order_for_search(qs, "name")[: self.max_results] return common_utils.order_for_search(qs, "name")[: self.max_results]
def get_tags(self, query): def get_tags(self, query):

View File

@ -0,0 +1,50 @@
import pytest
from django.db import connection
@pytest.mark.parametrize(
"factory_name,fields",
[
("music.Artist", ["name"]),
("music.Album", ["title"]),
("music.Track", ["title"]),
],
)
def test_body_text_trigger_creation(factory_name, fields, factories):
obj = factories[factory_name]()
obj.refresh_from_db()
cursor = connection.cursor()
sql = """
SELECT to_tsvector('{indexed_text}')
""".format(
indexed_text=" ".join([getattr(obj, f) for f in fields if getattr(obj, f)]),
)
cursor.execute(sql)
assert cursor.fetchone()[0] == obj.body_text
@pytest.mark.parametrize(
"factory_name,fields",
[
("music.Artist", ["name"]),
("music.Album", ["title"]),
("music.Track", ["title"]),
],
)
def test_body_text_trigger_updaten(factory_name, fields, factories, faker):
obj = factories[factory_name]()
for field in fields:
setattr(obj, field, faker.sentence())
obj.save()
obj.refresh_from_db()
cursor = connection.cursor()
sql = """
SELECT to_tsvector('{indexed_text}')
""".format(
indexed_text=" ".join([getattr(obj, f) for f in fields if getattr(obj, f)]),
)
cursor.execute(sql)
assert cursor.fetchone()[0] == obj.body_text

View File

@ -1193,3 +1193,48 @@ def test_get_upload_audio_metadata(logged_in_api_client, factories):
assert response.status_code == 200 assert response.status_code == 200
assert serializer.is_valid(raise_exception=True) is True assert serializer.is_valid(raise_exception=True) is True
assert response.data == serializer.validated_data assert response.data == serializer.validated_data
@pytest.mark.parametrize("use_fts", [True, False])
def test_search_get(use_fts, settings, logged_in_api_client, factories):
settings.USE_FULL_TEXT_SEARCH = use_fts
artist = factories["music.Artist"](name="Foo Fighters")
album = factories["music.Album"](title="Foo Bar")
track = factories["music.Track"](title="Foo Baz")
tag = factories["tags.Tag"](name="Foo")
factories["music.Track"]()
factories["tags.Tag"]()
url = reverse("api:v1:search")
expected = {
"artists": [serializers.ArtistWithAlbumsSerializer(artist).data],
"albums": [serializers.AlbumSerializer(album).data],
"tracks": [serializers.TrackSerializer(track).data],
"tags": [views.TagSerializer(tag).data],
}
response = logged_in_api_client.get(url, {"q": "foo"})
assert response.status_code == 200
assert response.data == expected
def test_search_get_fts_advanced(settings, logged_in_api_client, factories):
settings.USE_FULL_TEXT_SEARCH = True
artist1 = factories["music.Artist"](name="Foo Bighters")
artist2 = factories["music.Artist"](name="Bar Fighter")
factories["music.Artist"]()
url = reverse("api:v1:search")
expected = {
"artists": serializers.ArtistWithAlbumsSerializer(
[artist2, artist1], many=True
).data,
"albums": [],
"tracks": [],
"tags": [],
}
response = logged_in_api_client.get(url, {"q": '"foo | bar"'})
assert response.status_code == 200
assert response.data == expected

View File

@ -0,0 +1 @@
Replaced our slow research logic by PostgreSQL full-text search (#994)

View File

@ -6,6 +6,20 @@ Next release notes
Those release notes refer to the current development branch and are reset Those release notes refer to the current development branch and are reset
after each release. after each release.
Improved search performance
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Our search engine went through a full rewrite to make it faster. This new engine is enabled
by default when using the search bar, or when searching for artists, albums and tracks. It leverages
PostgreSQL full-text search capabilities.
During our tests, we observed huge performance improvements after the switch, by an order of
magnitude. This should be especially perceptible on pods with large databases, more modest hardware
or hard drives.
We plan to remove the old engine in an upcoming release. In the meantime, if anything goes wrong,
you can switch back by setting ``USE_FULL_TEXT_SEARCH=false`` in your ``.env`` file.
User management through the server CLI User management through the server CLI
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^