Commit d673e4d5 authored by erdnaxe's avatar erdnaxe 🎇 Committed by Hugo LEVY-FALK

Clean up of API code

Automatic clean up that checks Python 2.7 compatibility, switch some
methods to static and rearrange code.
parent ee2ee0ad
......@@ -26,8 +26,8 @@ done.
"""
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.contrib.auth.models import Permission
from django.contrib.contenttypes.models import ContentType
from django.utils.translation import ugettext_lazy as _
......
......@@ -26,12 +26,14 @@ import datetime
from django.conf import settings
from django.utils.translation import ugettext_lazy as _
from rest_framework.authentication import TokenAuthentication
from rest_framework import exceptions
from rest_framework.authentication import TokenAuthentication
class ExpiringTokenAuthentication(TokenAuthentication):
"""Authenticate a user if the provided token is valid and not expired.
"""
def authenticate_credentials(self, key):
"""See base class. Add the verification the token is not expired.
"""
......@@ -46,4 +48,4 @@ class ExpiringTokenAuthentication(TokenAuthentication):
if token.created < utc_now - token_duration:
raise exceptions.AuthenticationFailed(_('Token has expired'))
return (token.user, token)
return token.user, token
......@@ -24,8 +24,6 @@
from rest_framework import permissions, exceptions
from re2o.acl import can_create, can_edit, can_delete, can_view_all
from . import acl
......@@ -57,14 +55,14 @@ def _get_param_in_view(view, param_name):
AssertionError: None of the getter function or the attribute are
defined in the view.
"""
assert hasattr(view, 'get_'+param_name) \
or getattr(view, param_name, None) is not None, (
assert hasattr(view, 'get_' + param_name) \
or getattr(view, param_name, None) is not None, (
'cannot apply {} on a view that does not set '
'`.{}` or have a `.get_{}()` method.'
).format(self.__class__.__name__, param_name, param_name)
if hasattr(view, 'get_'+param_name):
param = getattr(view, 'get_'+param_name)()
if hasattr(view, 'get_' + param_name):
param = getattr(view, 'get_' + param_name)()
assert param is not None, (
'{}.get_{}() returned None'
).format(view.__class__.__name__, param_name)
......@@ -80,7 +78,8 @@ class ACLPermission(permissions.BasePermission):
See the wiki for the syntax of this attribute.
"""
def get_required_permissions(self, method, view):
@staticmethod
def get_required_permissions(method, view):
"""Build the list of permissions required for the request to be
accepted.
......@@ -153,15 +152,15 @@ class AutodetectACLPermission(permissions.BasePermission):
'OPTIONS': [can_see_api, lambda model: model.can_view_all],
'HEAD': [can_see_api, lambda model: model.can_view_all],
'POST': [can_see_api, lambda model: model.can_create],
'PUT': [], # No restrictions, apply to objects
'PATCH': [], # No restrictions, apply to objects
'PUT': [], # No restrictions, apply to objects
'PATCH': [], # No restrictions, apply to objects
'DELETE': [], # No restrictions, apply to objects
}
perms_obj_map = {
'GET': [can_see_api, lambda obj: obj.can_view],
'OPTIONS': [can_see_api, lambda obj: obj.can_view],
'HEAD': [can_see_api, lambda obj: obj.can_view],
'POST': [], # No restrictions, apply to models
'POST': [], # No restrictions, apply to models
'PUT': [can_see_api, lambda obj: obj.can_edit],
'PATCH': [can_see_api, lambda obj: obj.can_edit],
'DELETE': [can_see_api, lambda obj: obj.can_delete],
......@@ -209,7 +208,8 @@ class AutodetectACLPermission(permissions.BasePermission):
return [perm(obj) for perm in self.perms_obj_map[method]]
def _queryset(self, view):
@staticmethod
def _queryset(view):
return _get_param_in_view(view, 'queryset')
def has_permission(self, request, view):
......@@ -282,4 +282,3 @@ class AutodetectACLPermission(permissions.BasePermission):
return False
return True
......@@ -24,12 +24,12 @@
from collections import OrderedDict
from django.conf.urls import url, include
from django.conf.urls import url
from django.core.urlresolvers import NoReverseMatch
from rest_framework import views
from rest_framework.routers import DefaultRouter
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.routers import DefaultRouter
from rest_framework.schemas import SchemaGenerator
from rest_framework.settings import api_settings
......@@ -64,7 +64,8 @@ class AllViewsRouter(DefaultRouter):
name = self.get_default_name(pattern)
self.view_registry.append((pattern, view, name))
def get_default_name(self, pattern):
@staticmethod
def get_default_name(pattern):
"""Returns the name to use for the route if none was specified.
Args:
......@@ -113,7 +114,8 @@ class AllViewsRouter(DefaultRouter):
_ignore_model_permissions = True
renderer_classes = view_renderers
def get(self, request, *args, **kwargs):
@staticmethod
def get(request, *args, **kwargs):
if request.accepted_renderer.media_type in schema_media_types:
# Return a schema response.
schema = schema_generator.get_schema(request)
......
This diff is collapsed.
......@@ -48,4 +48,4 @@ API_APPS = (
)
# The expiration time for an authentication token
API_TOKEN_DURATION = 86400 # 24 hours
API_TOKEN_DURATION = 86400 # 24 hours
......@@ -21,10 +21,11 @@
"""Defines the test suite for the API
"""
import json
import datetime
from rest_framework.test import APITestCase
import json
from requests import codes
from rest_framework.test import APITestCase
import cotisations.models as cotisations
import machines.models as machines
......@@ -33,7 +34,7 @@ import topologie.models as topologie
import users.models as users
class APIEndpointsTestCase(APITestCase):
class APIEndpointsTestCase(APITestCase):
"""Test case to test that all endpoints are reachable with respects to
authentication and permission checks.
......@@ -148,10 +149,10 @@ class APIEndpointsTestCase(APITestCase):
'/api/users/club/',
# 4th user to be create (stduser, superuser, users_adherent_1,
# users_club_1)
'/api/users/club/4/',
'/api/users/club/4/',
'/api/users/listright/',
# TODO: Merge !145
# '/api/users/listright/1/',
# TODO: Merge !145
# '/api/users/listright/1/',
'/api/users/school/',
'/api/users/school/1/',
'/api/users/serviceuser/',
......@@ -215,7 +216,7 @@ class APIEndpointsTestCase(APITestCase):
'/api/users/user/4242/',
'/api/users/whitelist/4242/',
]
stduser = None
superuser = None
......@@ -363,7 +364,7 @@ class APIEndpointsTestCase(APITestCase):
machine=cls.machines_machine_1, # Dep machines.Machine
type=cls.machines_machinetype_1, # Dep machines.MachineType
details="machines Interface 1",
#port_lists=[cls.machines_ouvertureportlist_1] # Dep machines.OuverturePortList
# port_lists=[cls.machines_ouvertureportlist_1] # Dep machines.OuverturePortList
)
cls.machines_domain_1 = machines.Domain.objects.create(
interface_parent=cls.machines_interface_1, # Dep machines.Interface
......@@ -525,14 +526,14 @@ class APIEndpointsTestCase(APITestCase):
uid_number=21103,
rezo_rez_uid=21103
)
# Need merge of MR145 to work
# TODO: Merge !145
# cls.users_listright_1 = users.ListRight.objects.create(
# unix_name="userslistright",
# gid=601,
# critical=False,
# details="userslistright"
# )
# Need merge of MR145 to work
# TODO: Merge !145
# cls.users_listright_1 = users.ListRight.objects.create(
# unix_name="userslistright",
# gid=601,
# critical=False,
# details="userslistright"
# )
cls.users_serviceuser_1 = users.ServiceUser.objects.create(
password="password",
last_login=datetime.datetime.now(datetime.timezone.utc),
......@@ -663,7 +664,7 @@ class APIEndpointsTestCase(APITestCase):
AssertionError: An endpoint did not have a 200 status code.
"""
self.client.force_authenticate(user=self.superuser)
urls = self.no_auth_endpoints + self.auth_no_perm_endpoints + \
self.auth_perm_endpoints
......@@ -676,6 +677,7 @@ class APIEndpointsTestCase(APITestCase):
formats=[None, 'json', 'api'],
assert_more=assert_more)
class APIPaginationTestCase(APITestCase):
"""Test case to check that the pagination is used on all endpoints that
should use it.
......@@ -756,7 +758,7 @@ class APIPaginationTestCase(APITestCase):
@classmethod
def tearDownClass(cls):
cls.superuser.delete()
super().tearDownClass()
super(APIPaginationTestCase, self).tearDownClass()
def test_pagination(self):
"""Tests that every endpoint is using the pagination correctly.
......@@ -776,4 +778,3 @@ class APIPaginationTestCase(APITestCase):
assert 'previous' in res_json.keys()
assert 'results' in res_json.keys()
assert not len('results') > 100
......@@ -32,7 +32,6 @@ from django.conf.urls import url, include
from . import views
from .routers import AllViewsRouter
router = AllViewsRouter()
# COTISATIONS
router.register_viewset(r'cotisations/facture', views.FactureViewSet)
......@@ -121,7 +120,6 @@ router.register_view(r'mailing/club', views.ClubMailingView),
# TOKEN AUTHENTICATION
router.register_view(r'token-auth', views.ObtainExpiringAuthToken)
urlpatterns = [
url(r'^', include(router.urls)),
]
......@@ -30,10 +30,10 @@ import datetime
from django.conf import settings
from django.db.models import Q
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework import viewsets, generics, views
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.response import Response
from rest_framework import viewsets, generics, views
import cotisations.models as cotisations
import machines.models as machines
......@@ -41,7 +41,6 @@ import preferences.models as preferences
import topologie.models as topologie
import users.models as users
from re2o.utils import all_active_interfaces, all_has_access
from . import serializers
from .pagination import PageSizedPagination
from .permissions import ACLPermission
......@@ -164,6 +163,7 @@ class TxtViewSet(viewsets.ReadOnlyModelViewSet):
queryset = machines.Txt.objects.all()
serializer_class = serializers.TxtSerializer
class DNameViewSet(viewsets.ReadOnlyModelViewSet):
"""Exposes list and details of `machines.models.DName` objects.
"""
......@@ -256,8 +256,8 @@ class RoleViewSet(viewsets.ReadOnlyModelViewSet):
class OptionalUserView(generics.RetrieveAPIView):
"""Exposes details of `preferences.models.` settings.
"""
permission_classes = (ACLPermission, )
perms_map = {'GET' : [preferences.OptionalUser.can_view_all]}
permission_classes = (ACLPermission,)
perms_map = {'GET': [preferences.OptionalUser.can_view_all]}
serializer_class = serializers.OptionalUserSerializer
def get_object(self):
......@@ -267,8 +267,8 @@ class OptionalUserView(generics.RetrieveAPIView):
class OptionalMachineView(generics.RetrieveAPIView):
"""Exposes details of `preferences.models.OptionalMachine` settings.
"""
permission_classes = (ACLPermission, )
perms_map = {'GET' : [preferences.OptionalMachine.can_view_all]}
permission_classes = (ACLPermission,)
perms_map = {'GET': [preferences.OptionalMachine.can_view_all]}
serializer_class = serializers.OptionalMachineSerializer
def get_object(self):
......@@ -278,8 +278,8 @@ class OptionalMachineView(generics.RetrieveAPIView):
class OptionalTopologieView(generics.RetrieveAPIView):
"""Exposes details of `preferences.models.OptionalTopologie` settings.
"""
permission_classes = (ACLPermission, )
perms_map = {'GET' : [preferences.OptionalTopologie.can_view_all]}
permission_classes = (ACLPermission,)
perms_map = {'GET': [preferences.OptionalTopologie.can_view_all]}
serializer_class = serializers.OptionalTopologieSerializer
def get_object(self):
......@@ -289,8 +289,8 @@ class OptionalTopologieView(generics.RetrieveAPIView):
class GeneralOptionView(generics.RetrieveAPIView):
"""Exposes details of `preferences.models.GeneralOption` settings.
"""
permission_classes = (ACLPermission, )
perms_map = {'GET' : [preferences.GeneralOption.can_view_all]}
permission_classes = (ACLPermission,)
perms_map = {'GET': [preferences.GeneralOption.can_view_all]}
serializer_class = serializers.GeneralOptionSerializer
def get_object(self):
......@@ -307,8 +307,8 @@ class HomeServiceViewSet(viewsets.ReadOnlyModelViewSet):
class AssoOptionView(generics.RetrieveAPIView):
"""Exposes details of `preferences.models.AssoOption` settings.
"""
permission_classes = (ACLPermission, )
perms_map = {'GET' : [preferences.AssoOption.can_view_all]}
permission_classes = (ACLPermission,)
perms_map = {'GET': [preferences.AssoOption.can_view_all]}
serializer_class = serializers.AssoOptionSerializer
def get_object(self):
......@@ -318,8 +318,8 @@ class AssoOptionView(generics.RetrieveAPIView):
class HomeOptionView(generics.RetrieveAPIView):
"""Exposes details of `preferences.models.HomeOption` settings.
"""
permission_classes = (ACLPermission, )
perms_map = {'GET' : [preferences.HomeOption.can_view_all]}
permission_classes = (ACLPermission,)
perms_map = {'GET': [preferences.HomeOption.can_view_all]}
serializer_class = serializers.HomeOptionSerializer
def get_object(self):
......@@ -329,8 +329,8 @@ class HomeOptionView(generics.RetrieveAPIView):
class MailMessageOptionView(generics.RetrieveAPIView):
"""Exposes details of `preferences.models.MailMessageOption` settings.
"""
permission_classes = (ACLPermission, )
perms_map = {'GET' : [preferences.MailMessageOption.can_view_all]}
permission_classes = (ACLPermission,)
perms_map = {'GET': [preferences.MailMessageOption.can_view_all]}
serializer_class = serializers.MailMessageOptionSerializer
def get_object(self):
......@@ -424,6 +424,7 @@ class PortProfileViewSet(viewsets.ReadOnlyModelViewSet):
queryset = topologie.PortProfile.objects.all()
serializer_class = serializers.PortProfileSerializer
# USER
......@@ -433,12 +434,14 @@ class UserViewSet(viewsets.ReadOnlyModelViewSet):
queryset = users.User.objects.all()
serializer_class = serializers.UserSerializer
class HomeCreationViewSet(viewsets.ReadOnlyModelViewSet):
"""Exposes infos of `users.models.Users` objects to create homes.
"""
queryset = users.User.objects.exclude(Q(state=users.User.STATE_DISABLED) | Q(state=users.User.STATE_NOT_YET_ACTIVE))
serializer_class = serializers.HomeCreationSerializer
class ClubViewSet(viewsets.ReadOnlyModelViewSet):
"""Exposes list and details of `users.models.Club` objects.
"""
......@@ -503,7 +506,7 @@ class EMailAddressViewSet(viewsets.ReadOnlyModelViewSet):
def get_queryset(self):
if preferences.OptionalUser.get_cached_value(
'local_email_accounts_enabled'):
'local_email_accounts_enabled'):
return (users.EMailAddress.objects
.filter(user__local_email_enabled=True))
else:
......@@ -567,7 +570,7 @@ class LocalEmailUsersView(generics.ListAPIView):
def get_queryset(self):
if preferences.OptionalUser.get_cached_value(
'local_email_accounts_enabled'):
'local_email_accounts_enabled'):
return (users.User.objects
.filter(local_email_enabled=True))
else:
......@@ -585,16 +588,18 @@ class HostMacIpView(generics.ListAPIView):
serializer_class = serializers.HostMacIpSerializer
#Firewall
# Firewall
class SubnetPortsOpenView(generics.ListAPIView):
queryset = machines.IpType.objects.all()
serializer_class = serializers.SubnetPortsOpenSerializer
class InterfacePortsOpenView(generics.ListAPIView):
queryset = machines.Interface.objects.filter(port_lists__isnull=False).distinct()
serializer_class = serializers.InterfacePortsOpenSerializer
# DNS
......@@ -612,6 +617,7 @@ class DNSZonesView(generics.ListAPIView):
.all())
serializer_class = serializers.DNSZonesSerializer
class DNSReverseZonesView(generics.ListAPIView):
"""Exposes the detailed information about each extension (hostnames,
IPs, DNS records, etc.) in order to build the DNS zone files.
......@@ -620,8 +626,6 @@ class DNSReverseZonesView(generics.ListAPIView):
serializer_class = serializers.DNSReverseZonesSerializer
# MAILING
......@@ -630,8 +634,8 @@ class StandardMailingView(views.APIView):
order to building the corresponding mailing lists.
"""
pagination_class = PageSizedPagination
permission_classes = (ACLPermission, )
perms_map = {'GET' : [users.User.can_view_all]}
permission_classes = (ACLPermission,)
perms_map = {'GET': [users.User.can_view_all]}
def get(self, request, format=None):
adherents_data = serializers.MailingMemberSerializer(all_has_access(), many=True).data
......@@ -659,6 +663,7 @@ class ObtainExpiringAuthToken(ObtainAuthToken):
`rest_framework.auth_token.views.ObtainAuthToken` view except that the
expiration time is send along with the token as an addtional information.
"""
def post(self, request, *args, **kwargs):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment