From 3c7d9b5ac85b8e6ddf6a1bc88977422aa9b405fc Mon Sep 17 00:00:00 2001 From: JuniorJPDJ Date: Tue, 8 Nov 2022 08:53:32 +0000 Subject: [PATCH] perf(radio/pick): speedup radio track picking code NOCHANGELOG --- api/funkwhale_api/radios/radios.py | 20 ++++++++++++++------ api/tests/radios/test_radios.py | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/api/funkwhale_api/radios/radios.py b/api/funkwhale_api/radios/radios.py index 4b333b966..33512980d 100644 --- a/api/funkwhale_api/radios/radios.py +++ b/api/funkwhale_api/radios/radios.py @@ -1,6 +1,7 @@ import datetime import logging import random +from typing import Optional, List, Tuple from django.core.exceptions import ValidationError from django.db import connection @@ -25,14 +26,21 @@ class SimpleRadio(object): def clean(self, instance): return - def pick(self, choices, previous_choices=[]): - possible_choices = [x for x in choices if x not in previous_choices] - return random.sample(possible_choices, 1)[0] + def pick( + self, choices: List[int], previous_choices: Optional[List[int]] = None + ) -> int: + if previous_choices: + choices = list(set(choices).difference(set(previous_choices))) + return random.sample(choices, 1)[0] - def pick_many(self, choices, quantity): - return random.sample(list(choices), quantity) + def pick_many(self, choices: List[int], quantity: int) -> int: + return random.sample(list(set(choices)), quantity) - def weighted_pick(self, choices, previous_choices=[]): + def weighted_pick( + self, + choices: List[Tuple[int, int]], + previous_choices: Optional[List[int]] = None, + ) -> int: total = sum(weight for c, weight in choices) r = random.uniform(0, total) upto = 0 diff --git a/api/tests/radios/test_radios.py b/api/tests/radios/test_radios.py index 0e2b47f60..2745bb368 100644 --- a/api/tests/radios/test_radios.py +++ b/api/tests/radios/test_radios.py @@ -21,7 +21,7 @@ def test_can_pick_track_from_choices(): previous_choices = [first_pick] for remaining_choice in choices: pick = radio.pick(choices=choices, previous_choices=previous_choices) - assert pick in set(choices).difference(previous_choices) + assert pick in set(choices).difference(set(previous_choices)) def test_can_pick_by_weight():