491 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			491 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
| import datetime
 | |
| import hashlib
 | |
| import logging
 | |
| import os
 | |
| import shutil
 | |
| import uuid
 | |
| import xml.etree.ElementTree as ET
 | |
| from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
 | |
| 
 | |
| import bleach.sanitizer
 | |
| import markdown
 | |
| from django import urls
 | |
| from django.conf import settings
 | |
| from django.core.files.base import ContentFile
 | |
| from django.db import models, transaction
 | |
| from django.http import request
 | |
| from django.utils import timezone
 | |
| from django.utils.deconstruct import deconstructible
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
def batch(iterable, n=1):
    """
    Yield successive lists of at most ``n`` items taken from ``iterable``.

    Any iterable is accepted (it is wrapped with ``iter()``), not only
    iterators.

    For backward compatibility, the last list yielded is always the one during
    which the input ran out: it holds fewer than ``n`` items, and is empty when
    the input length is an exact multiple of ``n`` or when the input is empty.
    """
    iterator = iter(iterable)
    has_entries = True
    while has_entries:
        current = []
        for _ in range(n):
            try:
                current.append(next(iterator))
            except StopIteration:
                has_entries = False
                # no need to poll an exhausted iterator for the remaining slots
                break
        yield current
 | |
| 
 | |
| 
 | |
def rename_file(instance, field_name, new_name, allow_missing_file=False):
    """
    Move the file behind ``instance.<field_name>`` to ``new_name`` plus the
    current extension, point the field name at it and save the instance.

    A missing source file raises ``FileNotFoundError`` unless
    ``allow_missing_file`` is true, in which case the move is skipped and
    reported on stdout. Returns the new name, extension included.
    """
    file_field = getattr(instance, field_name)
    extension = os.path.splitext(file_field.name)[1]
    target = f"{new_name}{extension}"

    try:
        shutil.move(file_field.path, target)
    except FileNotFoundError:
        if allow_missing_file:
            print("Skipped missing file", file_field.path)
        else:
            raise

    # the new field name keeps the directory part of the previous one
    file_field.name = os.path.join(os.path.dirname(file_field.name), target)
    instance.save()
    return target
 | |
| 
 | |
| 
 | |
def on_commit(f, *args, **kwargs):
    """
    Register ``f(*args, **kwargs)`` with ``transaction.on_commit`` and return
    whatever that registration call returns.
    """

    def run():
        return f(*args, **kwargs)

    return transaction.on_commit(run)
 | |
| 
 | |
| 
 | |
def set_query_parameter(url, **kwargs):
    """Return ``url`` with every keyword argument set as a query parameter.

    Existing values for a given parameter name are replaced; the query string
    is rebuilt from the parsed parameters.

    >>> set_query_parameter('http://example.com?foo=bar&biz=baz', foo='stuff')
    'http://example.com?foo=stuff&biz=baz'
    """
    parts = urlsplit(url)
    params = parse_qs(parts.query)
    # each parameter is stored as a one-item list, the shape parse_qs produces
    params.update({name: [value] for name, value in kwargs.items()})
    rebuilt_query = urlencode(params, doseq=True)
    return urlunsplit(
        (parts.scheme, parts.netloc, parts.path, rebuilt_query, parts.fragment)
    )
 | |
| 
 | |
| 
 | |
@deconstructible
class ChunkedPath:
    """
    Callable building paths of the form ``root/xx/yy/zz/<name>``, where
    ``xx``, ``yy`` and ``zz`` are the first three 2-character chunks of a
    random UUID4.

    With ``preserve_file_name`` the sanitized file name is kept as ``<name>``;
    otherwise ``<name>`` is the rest of the UUID followed by the lowercased
    extension of the given file name.
    """

    def __init__(self, root, preserve_file_name=True):
        self.root = root
        self.preserve_file_name = preserve_file_name

    def sanitize_filename(self, filename):
        """Replace slashes so the file name cannot add path segments."""
        return filename.replace("/", "-")

    def __call__(self, instance, filename):
        # str.replace returns a new string: the sanitized value must be kept,
        # otherwise the raw file name ends up in the path
        filename = self.sanitize_filename(filename)
        uid = str(uuid.uuid4())
        chunk_size = 2
        chunks = [uid[i : i + chunk_size] for i in range(0, len(uid), chunk_size)]
        if self.preserve_file_name:
            parts = chunks[:3] + [filename]
        else:
            ext = os.path.splitext(filename)[1][1:].lower()
            new_filename = "".join(chunks[3:]) + f".{ext}"
            parts = chunks[:3] + [new_filename]
        return os.path.join(self.root, *parts)
 | |
| 
 | |
| 
 | |
def chunk_queryset(source_qs, chunk_size):
    """
    Yield the rows of ``source_qs`` as successive lists of at most
    ``chunk_size`` items, ordered on the primary key column.

    Each page after the first is fetched with a ``pk__gt`` filter on the last
    primary key seen. When the queryset exposes a non-empty ``_fields``
    (``values()`` usage), rows are read as mappings and ``"pk"`` must be one of
    those fields, otherwise ``ValueError`` is raised.

    From https://github.com/peopledoc/django-chunkator/blob/master/chunkator/__init__.py
    """
    # `_fields` may be missing or None depending on the Django version; it is
    # only truthy when `values()` was used
    uses_values = hasattr(source_qs, "_fields") and source_qs._fields
    if uses_values and "pk" not in source_qs._fields:
        raise ValueError("The values() call must include the `pk` field")

    # order on the pk column name (`attname`): for ForeignKeys this is
    # `model_id` and not `model`, which bypasses the default ordering of the
    # related model
    ordered_qs = source_qs.order_by(source_qs.model._meta.pk.attname)

    current_qs = ordered_qs
    last_pk = None
    while True:
        if last_pk:
            current_qs = ordered_qs.filter(pk__gt=last_pk)
        rows = list(current_qs[:chunk_size])
        row_count = len(rows)
        if row_count == 0:
            return

        tail = rows[-1]
        # rows are mappings when values() was used, model instances otherwise
        last_pk = tail["pk"] if uses_values else tail.pk

        yield rows

        # a short page means there is nothing left to fetch
        if row_count < chunk_size:
            return
 | |
| 
 | |
| 
 | |
def join_url(start, end):
    """
    Concatenate ``start`` and ``end`` with exactly one ``/`` between them.

    ``end`` is returned untouched when it already is an ``http://`` or
    ``https://`` URL.
    """
    if end.startswith(("http://", "https://")):
        # already a full URL, joining makes no sense
        return end

    start_has_slash = start.endswith("/")
    end_has_slash = end.startswith("/")

    if start_has_slash and end_has_slash:
        # drop the duplicated separator
        return start + end[1:]
    if not (start_has_slash or end_has_slash):
        # add the missing separator
        return f"{start}/{end}"
    return start + end
 | |
| 
 | |
| 
 | |
def media_url(path):
    """
    Build the URL for ``path``: joined to ``settings.MEDIA_URL`` when that
    setting is an ``http://`` or ``https://`` URL, otherwise produced by the
    federation ``full_url`` helper.
    """
    base_url = settings.MEDIA_URL
    if base_url.startswith("http://") or base_url.startswith("https://"):
        return join_url(base_url, path)

    # imported inside the function, as in the original layout
    from funkwhale_api.federation import utils as federation_utils

    return federation_utils.full_url(path)
 | |
| 
 | |
| 
 | |
def spa_reverse(name, args=None, kwargs=None):
    """
    Reverse the URL ``name`` against the URLconf set in
    ``settings.SPA_URLCONF``.

    ``args`` and ``kwargs`` are forwarded to ``urls.reverse``. When omitted, a
    fresh empty list/dict is used for each call instead of a mutable default
    value shared between calls.
    """
    return urls.reverse(
        name,
        urlconf=settings.SPA_URLCONF,
        args=[] if args is None else args,
        kwargs={} if kwargs is None else kwargs,
    )
 | |
| 
 | |
| 
 | |
def spa_resolve(path):
    """Resolve ``path`` against the URLconf set in ``settings.SPA_URLCONF``."""
    spa_urlconf = settings.SPA_URLCONF
    return urls.resolve(path, urlconf=spa_urlconf)
 | |
| 
 | |
| 
 | |
def parse_meta(html):
    """
    Return one dict per ``meta`` or ``link`` element found in ``html``, made of
    a ``"tag"`` key holding the element name plus the element attributes.
    """
    # dirty but this is only for testing so we don't really care:
    # an XML declaration is prepended so the markup can be parsed as XML
    root = ET.fromstring('<?xml version="1.0"?>' + html)

    results = []
    for element in root.iter():
        if element.tag in ("meta", "link"):
            entry = {"tag": element.tag}
            entry.update(element.items())
            results.append(entry)
    return results
 | |
| 
 | |
| 
 | |
def order_for_search(qs, field):
    """
    Order ``qs`` by the length of ``field``, then by ``pk``.

    When searching, it's often more useful to have short results first.
    """
    field_length = models.functions.Length(field)
    with_size = qs.annotate(__size=field_length)
    return with_size.order_by("__size", "pk")
 | |
| 
 | |
| 
 | |
def recursive_getattr(obj, key, permissive=False):
    """
    Follow the dotted path ``key`` (such as ``user.name``) through ``obj``.

    At each step the current value is read with ``.get()`` when it has such a
    method, with ``getattr()`` otherwise. Given ``{'user': {'name': 'Bob'}}``
    and ``user.name``, the result is ``'Bob'``.

    ``None`` is returned as soon as a step yields ``None``. A ``TypeError`` or
    ``AttributeError`` raised by a step is re-raised, unless ``permissive`` is
    true, in which case ``None`` is returned.
    """
    current = obj
    for part in key.split("."):
        try:
            current = (
                current.get(part) if hasattr(current, "get") else getattr(current, part)
            )
        except (TypeError, AttributeError):
            if permissive:
                return None
            raise
        if current is None:
            return None

    return current
 | |
| 
 | |
| 
 | |
def replace_prefix(queryset, field, old, new):
    """
    Update, in a single query, every object of ``queryset`` whose ``field``
    starts with ``old`` so that this prefix is replaced by ``new``.

    This is especially useful to find/update bad federation ids on a whole
    table, replacing for instance:

    http://wrongprotocolanddomain/path

    by

    https://goodprotocolanddomain/path

    Returns the value returned by the queryset ``update()`` call.
    """
    prefix_lookup = {f"{field}__startswith": old}
    # the part located after the old prefix...
    remainder = models.functions.Substr(
        field, len(old) + 1, output_field=models.CharField()
    )
    # ...is concatenated to the new prefix
    replacement = models.functions.Concat(models.Value(new), remainder)
    return queryset.filter(**prefix_lookup).update(**{field: replacement})
 | |
| 
 | |
| 
 | |
def concat_dicts(*dicts):
    """
    Merge ``dicts`` from left to right into a new dict and return it.

    When a key is present several times, the value from the last dict wins.
    The given dicts are not modified.
    """
    merged = {}
    for mapping in dicts:
        merged.update(mapping)
    return merged
 | |
| 
 | |
| 
 | |
def get_updated_fields(conf, data, obj):
    """
    Return the values from ``data`` that differ from the matching attributes
    of ``obj``, keyed by attribute name.

    ``conf`` lists the fields to check: either a string, used both as the key
    in ``data`` and the attribute on ``obj``, or a ``(data_key, obj_attribute)``
    pair. Keys missing from ``data`` are skipped. When ``obj.pk`` is falsy, no
    comparison is made and every value found in ``data`` is returned.
    """
    pairs = [(entry, entry) if isinstance(entry, str) else entry for entry in conf]

    changes = {}
    for data_field, obj_field in pairs:
        try:
            incoming = data[data_field]
        except KeyError:
            continue
        if not obj.pk or getattr(obj, obj_field) != incoming:
            changes[obj_field] = incoming

    return changes
 | |
| 
 | |
| 
 | |
def join_queries_or(left, right):
    """Return ``left | right``, or ``right`` alone when ``left`` is falsy."""
    return left | right if left else right
 | |
| 
 | |
| 
 | |
# Markdown instance built once at import time and shared by render_markdown()
MARKDOWN_RENDERER = markdown.Markdown(extensions=settings.MARKDOWN_EXTENSIONS)


def render_markdown(text):
    """
    Convert ``text`` with the shared ``MARKDOWN_RENDERER`` and return the
    result of the conversion.
    """
    # The instance is reused for every call: Python-Markdown documents that
    # its state may need a reset between conversions, depending on the
    # extensions in use, so that one document cannot affect the next one.
    MARKDOWN_RENDERER.reset()
    return MARKDOWN_RENDERER.convert(text)
 | |
| 
 | |
| 
 | |
# Tags allowed by the default cleaner
SAFE_TAGS = [
    "p", "a", "abbr", "acronym", "b", "blockquote", "br",
    "code", "em", "i", "li", "ol", "strong", "ul",
]  # fmt: skip
# NOTE: the lowercase "l" in this name is kept on purpose, the name is
# referenced with this exact spelling further down in this module
HTMl_CLEANER = bleach.sanitizer.Cleaner(tags=SAFE_TAGS, strip=True)

# Same as above with headings and sectioning tags on top of SAFE_TAGS, and an
# explicit list of allowed attributes
HTML_PERMISSIVE_CLEANER = bleach.sanitizer.Cleaner(
    tags=[*SAFE_TAGS, "h1", "h2", "h3", "h4", "h5", "h6", "div", "section", "article"],
    attributes=["class", "rel", "alt", "title", "href"],
    strip=True,
)

# support for additional tlds
# cf https://github.com/mozilla/bleach/issues/367#issuecomment-384631867
ALL_TLDS = set(settings.LINKIFIER_SUPPORTED_TLDS + bleach.linkifier.TLDS)
# the URL pattern is built from the merged TLDs, in reverse sorted order
URL_RE = bleach.linkifier.build_url_re(tlds=sorted(ALL_TLDS, reverse=True))
HTML_LINKER = bleach.linkifier.Linker(url_re=URL_RE)
 | |
| 
 | |
| 
 | |
def clean_html(html, permissive=False):
    """
    Return ``html`` cleaned by ``HTML_PERMISSIVE_CLEANER`` when ``permissive``
    is true, by ``HTMl_CLEANER`` otherwise.
    """
    cleaner = HTML_PERMISSIVE_CLEANER if permissive else HTMl_CLEANER
    return cleaner.clean(html)
 | |
| 
 | |
| 
 | |
def render_html(text, content_type, permissive=False):
    """
    Turn ``text`` into linkified, cleaned HTML with newlines removed.

    ``text`` is used as is when ``content_type`` is ``"text/html"``; for
    ``"text/markdown"`` and any other content type it is rendered as markdown
    first. The result then goes through ``HTML_LINKER.linkify`` and
    ``clean_html`` (``permissive`` is forwarded to the latter).

    Returns an empty string when ``text`` is falsy.
    """
    if not text:
        return ""
    if content_type == "text/html":
        # already HTML: no markdown pass, which would only be thrown away
        rendered = text
    else:
        rendered = render_markdown(text)
    rendered = HTML_LINKER.linkify(rendered)
    return clean_html(rendered, permissive=permissive).strip().replace("\n", "")
 | |
| 
 | |
| 
 | |
def render_plain_text(html):
    """
    Return ``html`` passed through ``bleach.clean`` with an empty list of
    allowed tags and ``strip=True``, or an empty string when ``html`` is falsy.
    """
    return bleach.clean(html, tags=[], strip=True) if html else ""
 | |
| 
 | |
| 
 | |
def same_content(old, text=None, content_type=None):
    """
    Tell whether ``old.text`` and ``old.content_type`` both equal the given
    ``text`` and ``content_type``.
    """
    text_matches = old.text == text
    return text_matches and old.content_type == content_type
 | |
| 
 | |
| 
 | |
@transaction.atomic
def attach_content(obj, field, content_data):
    """
    Replace the content object referenced by ``obj.<field>`` with a new one
    built from ``content_data`` (``text`` and ``content_type`` keys).

    - When a content object is already attached and matches ``content_data``,
      it is returned untouched.
    - Otherwise the attached object, if any, is deleted and the attribute is
      set to ``None`` on ``obj``.
    - When ``content_data`` is empty, nothing is created and ``None`` is
      returned.

    The text is truncated to ``models.CONTENT_TEXT_MAX_LENGTH`` and ``obj`` is
    saved with ``update_fields=[field]``. Returns the new content object.
    """
    from . import models

    content_data = content_data or {}

    if getattr(obj, f"{field}_id"):
        current = getattr(obj, field)
        if same_content(current, **content_data):
            # optimization to avoid a delete/save if possible
            return current
        current.delete()
        setattr(obj, field, None)

    if not content_data:
        return None

    truncated_text = content_data["text"][: models.CONTENT_TEXT_MAX_LENGTH]
    content_obj = models.Content.objects.create(
        text=truncated_text,
        content_type=content_data["content_type"],
    )
    setattr(obj, field, content_obj)
    obj.save(update_fields=[field])
    return content_obj
 | |
| 
 | |
| 
 | |
@transaction.atomic
def attach_file(obj, field, file_data, fetch=False):
    """
    Replace the attachment referenced by ``obj.<field>``.

    Any attachment already referenced is deleted first. ``file_data`` is then
    handled as follows:

    - falsy: nothing else is done and ``None`` is returned;
    - a ``models.Attachment`` instance: it is attached as is;
    - a dict: a new attachment is built from its ``mimetype`` key and either
      its ``url`` key or its ``content`` key. When the attachment has no file
      and ``fetch`` is true, ``tasks.fetch_remote_attachment`` is called; if
      that call raises, the error is logged and ``None`` is attached instead.

    ``obj`` is saved with ``update_fields=[field]``. Returns the attachment
    set on ``obj``, which is ``None`` after a failed fetch.
    """
    from . import models, tasks

    existing = getattr(obj, f"{field}_id")
    if existing:
        getattr(obj, field).delete()

    if not file_data:
        return

    if isinstance(file_data, models.Attachment):
        attachment = file_data
    else:
        extensions = {"image/jpeg": "jpg", "image/png": "png", "image/gif": "gif"}
        # unknown mimetypes get a "jpg" extension
        extension = extensions.get(file_data["mimetype"], "jpg")
        attachment = models.Attachment(mimetype=file_data["mimetype"])
        name_fields = ["uuid", "full_username", "pk"]
        # first truthy identifier found on obj (IndexError when there is none);
        # the loop variable has its own name so it cannot be confused with the
        # `field` parameter used just below
        name = [
            getattr(obj, name_field)
            for name_field in name_fields
            if getattr(obj, name_field, None)
        ][0]
        filename = f"{field}-{name}.{extension}"
        if "url" in file_data:
            attachment.url = file_data["url"]
        else:
            f = ContentFile(file_data["content"])
            attachment.file.save(filename, f, save=False)

        if not attachment.file and fetch:
            try:
                tasks.fetch_remote_attachment(attachment, filename=filename, save=False)
            except Exception as e:
                # best effort: a failed download results in no attachment
                # (Logger.warn is a deprecated alias of Logger.warning)
                logger.warning(
                    "Cannot download attachment at url %s: %s", attachment.url, e
                )
                attachment = None

        if attachment:
            attachment.save()

    setattr(obj, field, attachment)
    obj.save(update_fields=[field])
    return attachment
 | |
| 
 | |
| 
 | |
def get_mimetype_from_ext(path):
    """
    Return the image mimetype matching the text found after the last dot of
    ``path`` (case-insensitive), or ``None`` when it is not a known extension.
    """
    extension = path.lower().rpartition(".")[2]
    mimetypes_by_extension = {
        "jpeg": "image/jpeg",
        "jpg": "image/jpeg",
        "png": "image/png",
        "gif": "image/gif",
    }
    return mimetypes_by_extension.get(extension)
 | |
| 
 | |
| 
 | |
def get_audio_mimetype(mt):
    """
    Return ``"audio/mpeg"`` for its known aliases (``audio/x-mp3`` and
    ``audio/mpeg3``), and ``mt`` unchanged for anything else.
    """
    canonical = {"audio/x-mp3": "audio/mpeg", "audio/mpeg3": "audio/mpeg"}
    try:
        return canonical[mt]
    except KeyError:
        return mt
 | |
| 
 | |
| 
 | |
def update_modification_date(obj, field="modification_date", date=None):
    """
    Set ``obj.<field>`` to ``date`` (``timezone.now()`` when not given), both
    on the instance and through a queryset ``update()`` filtered on
    ``obj.pk``.

    The update only happens when the current value is not ``None`` and is
    more than ``IGNORE_DELAY`` seconds older than ``date``. ``date`` is
    returned in every case.
    """
    IGNORE_DELAY = 60
    previous = getattr(obj, field)
    date = date or timezone.now()

    if previous is None:
        return date

    threshold = date - datetime.timedelta(seconds=IGNORE_DELAY)
    if previous < threshold:
        setattr(obj, field, date)
        obj.__class__.objects.filter(pk=obj.pk).update(**{field: date})

    return date
 | |
| 
 | |
| 
 | |
def monkey_patch_request_build_absolute_uri():
    """
    Replace ``HttpRequest.scheme`` and ``HttpRequest.get_host`` with versions
    that return ``settings.FUNKWHALE_PROTOCOL`` and
    ``settings.FUNKWHALE_HOSTNAME`` whenever
    ``settings.IGNORE_FORWARDED_HOST_AND_PROTO`` is true, and fall back to the
    built-in implementations otherwise.

    Since FUNKWHALE_HOSTNAME and PROTOCOL are hardcoded in settings, django's
    multisite logic, which can break when reverse proxies aren't configured
    properly, can be overridden.
    """
    http_request = request.HttpRequest
    # keep references to the built-in implementations for the fallback path
    builtin_scheme = http_request.scheme
    builtin_get_host = http_request.get_host

    def scheme(self):
        if not settings.IGNORE_FORWARDED_HOST_AND_PROTO:
            # `scheme` is a property on the class: call its getter directly
            return builtin_scheme.fget(self)
        return settings.FUNKWHALE_PROTOCOL

    def get_host(self):
        if not settings.IGNORE_FORWARDED_HOST_AND_PROTO:
            return builtin_get_host(self)
        return settings.FUNKWHALE_HOSTNAME

    http_request.scheme = property(scheme)
    http_request.get_host = get_host
 | |
| 
 | |
| 
 | |
def get_file_hash(file, algo=None, chunk_size=None, full_read=False):
    """
    Return ``"<algo>:<hexdigest>"`` for ``file``, read from its beginning.

    ``algo`` names a ``hashlib`` constructor and defaults to
    ``settings.HASHING_ALGORITHM``; ``chunk_size`` defaults to
    ``settings.HASHING_CHUNK_SIZE``. With ``full_read`` the whole file is
    hashed, ``chunk_size`` bytes at a time; otherwise only the first
    ``chunk_size`` bytes are hashed.
    """
    algo = algo or settings.HASHING_ALGORITHM
    chunk_size = chunk_size or settings.HASHING_CHUNK_SIZE
    digest = getattr(hashlib, algo)()
    file.seek(0)

    if not full_read:
        # sometimes, it's useful to only hash the beginning of the file, e.g
        # to avoid a lot of I/O when crawling large libraries
        digest.update(file.read(chunk_size))
    else:
        while True:
            block = file.read(chunk_size)
            if block == b"":
                break
            digest.update(block)

    return f"{algo}:{digest.hexdigest()}"
 |