Added application token for easier auth
This commit is contained in:
parent
0dfe633d65
commit
f2e5969c44
|
@ -870,6 +870,7 @@ REST_FRAMEWORK = {
|
||||||
),
|
),
|
||||||
"DEFAULT_AUTHENTICATION_CLASSES": (
|
"DEFAULT_AUTHENTICATION_CLASSES": (
|
||||||
"funkwhale_api.common.authentication.OAuth2Authentication",
|
"funkwhale_api.common.authentication.OAuth2Authentication",
|
||||||
|
"funkwhale_api.common.authentication.ApplicationTokenAuthentication",
|
||||||
"funkwhale_api.common.authentication.JSONWebTokenAuthenticationQS",
|
"funkwhale_api.common.authentication.JSONWebTokenAuthenticationQS",
|
||||||
"funkwhale_api.common.authentication.BearerTokenHeaderAuth",
|
"funkwhale_api.common.authentication.BearerTokenHeaderAuth",
|
||||||
"funkwhale_api.common.authentication.JSONWebTokenAuthentication",
|
"funkwhale_api.common.authentication.JSONWebTokenAuthentication",
|
||||||
|
|
|
@ -12,6 +12,8 @@ from rest_framework import exceptions
|
||||||
from rest_framework_jwt import authentication
|
from rest_framework_jwt import authentication
|
||||||
from rest_framework_jwt.settings import api_settings
|
from rest_framework_jwt.settings import api_settings
|
||||||
|
|
||||||
|
from funkwhale_api.users import models as users_models
|
||||||
|
|
||||||
|
|
||||||
def should_verify_email(user):
|
def should_verify_email(user):
|
||||||
if user.is_superuser:
|
if user.is_superuser:
|
||||||
|
@ -46,6 +48,36 @@ class OAuth2Authentication(BaseOAuth2Authentication):
|
||||||
resend_confirmation_email(request, e.user)
|
resend_confirmation_email(request, e.user)
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationTokenAuthentication(object):
|
||||||
|
def authenticate(self, request):
|
||||||
|
try:
|
||||||
|
header = request.headers["Authorization"]
|
||||||
|
except KeyError:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "Bearer" not in header:
|
||||||
|
return
|
||||||
|
|
||||||
|
token = header.split()[-1].strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
application = users_models.Application.objects.exclude(user=None).get(
|
||||||
|
token=token
|
||||||
|
)
|
||||||
|
except users_models.Application.DoesNotExist:
|
||||||
|
return
|
||||||
|
user = users_models.User.objects.all().for_auth().get(id=application.user_id)
|
||||||
|
if not user.is_active:
|
||||||
|
msg = _("User account is disabled.")
|
||||||
|
raise exceptions.AuthenticationFailed(msg)
|
||||||
|
|
||||||
|
if should_verify_email(user):
|
||||||
|
raise UnverifiedEmail(user)
|
||||||
|
|
||||||
|
request.scopes = application.scope.split()
|
||||||
|
return user, None
|
||||||
|
|
||||||
|
|
||||||
class BaseJsonWebTokenAuth(object):
|
class BaseJsonWebTokenAuth(object):
|
||||||
def authenticate(self, request):
|
def authenticate(self, request):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -129,6 +129,7 @@ class SuperUserFactory(UserFactory):
|
||||||
class ApplicationFactory(factory.django.DjangoModelFactory):
|
class ApplicationFactory(factory.django.DjangoModelFactory):
|
||||||
name = factory.Faker("name")
|
name = factory.Faker("name")
|
||||||
redirect_uris = factory.Faker("url")
|
redirect_uris = factory.Faker("url")
|
||||||
|
token = factory.Faker("uuid4")
|
||||||
client_type = models.Application.CLIENT_CONFIDENTIAL
|
client_type = models.Application.CLIENT_CONFIDENTIAL
|
||||||
authorization_grant_type = models.Application.GRANT_AUTHORIZATION_CODE
|
authorization_grant_type = models.Application.GRANT_AUTHORIZATION_CODE
|
||||||
scope = "read"
|
scope = "read"
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
# Generated by Django 3.0.8 on 2020-08-19 08:58
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('users', '0019_auto_20200718_0741'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='application',
|
||||||
|
name='token',
|
||||||
|
field=models.CharField(blank=True, max_length=50, null=True, unique=True),
|
||||||
|
),
|
||||||
|
]
|
|
@ -31,8 +31,8 @@ from funkwhale_api.federation import models as federation_models
|
||||||
from funkwhale_api.federation import utils as federation_utils
|
from funkwhale_api.federation import utils as federation_utils
|
||||||
|
|
||||||
|
|
||||||
def get_token():
|
def get_token(length=15):
|
||||||
return binascii.b2a_hex(os.urandom(15)).decode("utf-8")
|
return binascii.b2a_hex(os.urandom(length)).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
PERMISSIONS_CONFIGURATION = {
|
PERMISSIONS_CONFIGURATION = {
|
||||||
|
@ -350,6 +350,7 @@ class Invitation(models.Model):
|
||||||
|
|
||||||
class Application(oauth2_models.AbstractApplication):
|
class Application(oauth2_models.AbstractApplication):
|
||||||
scope = models.TextField(blank=True)
|
scope = models.TextField(blank=True)
|
||||||
|
token = models.CharField(max_length=50, blank=True, null=True, unique=True)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def normalized_scopes(self):
|
def normalized_scopes(self):
|
||||||
|
|
|
@ -10,6 +10,12 @@ class ApplicationSerializer(serializers.ModelSerializer):
|
||||||
model = models.Application
|
model = models.Application
|
||||||
fields = ["client_id", "name", "scopes", "created", "updated"]
|
fields = ["client_id", "name", "scopes", "created", "updated"]
|
||||||
|
|
||||||
|
def to_representation(self, obj):
|
||||||
|
repr = super().to_representation(obj)
|
||||||
|
if obj.user_id:
|
||||||
|
repr["token"] = obj.token
|
||||||
|
return repr
|
||||||
|
|
||||||
|
|
||||||
class CreateApplicationSerializer(serializers.ModelSerializer):
|
class CreateApplicationSerializer(serializers.ModelSerializer):
|
||||||
name = serializers.CharField(required=True, max_length=255)
|
name = serializers.CharField(required=True, max_length=255)
|
||||||
|
@ -27,3 +33,9 @@ class CreateApplicationSerializer(serializers.ModelSerializer):
|
||||||
"redirect_uris",
|
"redirect_uris",
|
||||||
]
|
]
|
||||||
read_only_fields = ["client_id", "client_secret", "created", "updated"]
|
read_only_fields = ["client_id", "client_secret", "created", "updated"]
|
||||||
|
|
||||||
|
def to_representation(self, obj):
|
||||||
|
repr = super().to_representation(obj)
|
||||||
|
if obj.user_id:
|
||||||
|
repr["token"] = obj.token
|
||||||
|
return repr
|
||||||
|
|
|
@ -4,7 +4,8 @@ import urllib.parse
|
||||||
from django import http
|
from django import http
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from rest_framework import mixins, permissions, views, viewsets
|
from rest_framework import mixins, permissions, response, views, viewsets
|
||||||
|
from rest_framework.decorators import action
|
||||||
|
|
||||||
from oauth2_provider import exceptions as oauth2_exceptions
|
from oauth2_provider import exceptions as oauth2_exceptions
|
||||||
from oauth2_provider import views as oauth_views
|
from oauth2_provider import views as oauth_views
|
||||||
|
@ -32,6 +33,7 @@ class ApplicationViewSet(
|
||||||
"destroy": "write:security",
|
"destroy": "write:security",
|
||||||
"update": "write:security",
|
"update": "write:security",
|
||||||
"partial_update": "write:security",
|
"partial_update": "write:security",
|
||||||
|
"refresh_token": "write:security",
|
||||||
"list": "read:security",
|
"list": "read:security",
|
||||||
}
|
}
|
||||||
lookup_field = "client_id"
|
lookup_field = "client_id"
|
||||||
|
@ -54,6 +56,7 @@ class ApplicationViewSet(
|
||||||
client_type=models.Application.CLIENT_CONFIDENTIAL,
|
client_type=models.Application.CLIENT_CONFIDENTIAL,
|
||||||
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
|
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
|
||||||
user=self.request.user if self.request.user.is_authenticated else None,
|
user=self.request.user if self.request.user.is_authenticated else None,
|
||||||
|
token=models.get_token(15) if self.request.user.is_authenticated else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_serializer(self, *args, **kwargs):
|
def get_serializer(self, *args, **kwargs):
|
||||||
|
@ -70,10 +73,31 @@ class ApplicationViewSet(
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
qs = super().get_queryset()
|
qs = super().get_queryset()
|
||||||
if self.action in ["list", "destroy", "update", "partial_update"]:
|
if self.action in [
|
||||||
|
"list",
|
||||||
|
"destroy",
|
||||||
|
"update",
|
||||||
|
"partial_update",
|
||||||
|
"refresh_token",
|
||||||
|
]:
|
||||||
qs = qs.filter(user=self.request.user)
|
qs = qs.filter(user=self.request.user)
|
||||||
return qs
|
return qs
|
||||||
|
|
||||||
|
@action(
|
||||||
|
detail=True,
|
||||||
|
methods=["post"],
|
||||||
|
url_name="refresh_token",
|
||||||
|
url_path="refresh-token",
|
||||||
|
)
|
||||||
|
def refresh_token(self, request, *args, **kwargs):
|
||||||
|
app = self.get_object()
|
||||||
|
if not app.user_id or request.user != app.user:
|
||||||
|
return response.Response(status=404)
|
||||||
|
app.token = models.get_token(15)
|
||||||
|
app.save(update_fields=["token"])
|
||||||
|
serializer = serializers.CreateApplicationSerializer(app)
|
||||||
|
return response.Response(serializer.data, status=200)
|
||||||
|
|
||||||
|
|
||||||
class GrantViewSet(
|
class GrantViewSet(
|
||||||
mixins.RetrieveModelMixin,
|
mixins.RetrieveModelMixin,
|
||||||
|
|
|
@ -60,3 +60,13 @@ def test_json_webtoken_auth_verify_email_validity(
|
||||||
auth.authenticate(request)
|
auth.authenticate(request)
|
||||||
|
|
||||||
should_verify.assert_called_once_with(user)
|
should_verify.assert_called_once_with(user)
|
||||||
|
|
||||||
|
|
||||||
|
def test_app_token_authentication(factories, api_request):
|
||||||
|
user = factories["users.User"]()
|
||||||
|
app = factories["users.Application"](user=user, scope="read write")
|
||||||
|
request = api_request.get("/", HTTP_AUTHORIZATION="Bearer {}".format(app.token))
|
||||||
|
|
||||||
|
auth = authentication.ApplicationTokenAuthentication()
|
||||||
|
assert auth.authenticate(request)[0] == app.user
|
||||||
|
assert request.scopes == ["read", "write"]
|
||||||
|
|
|
@ -47,6 +47,8 @@ def test_apps_post_logged_in_user(logged_in_api_client, db):
|
||||||
assert response.data == serializers.CreateApplicationSerializer(app).data
|
assert response.data == serializers.CreateApplicationSerializer(app).data
|
||||||
assert app.scope == "read write:profile"
|
assert app.scope == "read write:profile"
|
||||||
assert app.user == logged_in_api_client.user
|
assert app.user == logged_in_api_client.user
|
||||||
|
assert app.token is not None
|
||||||
|
assert response.data["token"] == app.token
|
||||||
|
|
||||||
|
|
||||||
def test_apps_list_anonymous(api_client, db):
|
def test_apps_list_anonymous(api_client, db):
|
||||||
|
@ -120,6 +122,31 @@ def test_apps_get_owner(preferences, logged_in_api_client, factories):
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.data == serializers.CreateApplicationSerializer(app).data
|
assert response.data == serializers.CreateApplicationSerializer(app).data
|
||||||
|
assert response.data["token"] == app.token
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_refresh_token(preferences, logged_in_api_client, factories):
|
||||||
|
app = factories["users.Application"](user=logged_in_api_client.user)
|
||||||
|
old_token = app.token
|
||||||
|
url = reverse(
|
||||||
|
"api:v1:oauth:apps-refresh_token", kwargs={"client_id": app.client_id}
|
||||||
|
)
|
||||||
|
response = logged_in_api_client.post(url)
|
||||||
|
|
||||||
|
app.refresh_from_db()
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.data == serializers.CreateApplicationSerializer(app).data
|
||||||
|
assert app.token != old_token
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_refresh_token_not_owner(preferences, logged_in_api_client, factories):
|
||||||
|
app = factories["users.Application"]()
|
||||||
|
url = reverse(
|
||||||
|
"api:v1:oauth:apps-refresh_token", kwargs={"client_id": app.client_id}
|
||||||
|
)
|
||||||
|
response = logged_in_api_client.post(url)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
def test_authorize_view_post(logged_in_client, factories):
|
def test_authorize_view_post(logged_in_client, factories):
|
||||||
|
|
Loading…
Reference in New Issue