Commit bab79c4d authored by Valentin Samir's avatar Valentin Samir

More unit tests (essentially for the login view) and some docstrings

parent 7db31578
...@@ -5,3 +5,4 @@ exclude_lines = ...@@ -5,3 +5,4 @@ exclude_lines =
def __unicode__ def __unicode__
raise AssertionError raise AssertionError
raise NotImplementedError raise NotImplementedError
if six.PY3:
...@@ -49,8 +49,9 @@ coverage: test_venv ...@@ -49,8 +49,9 @@ coverage: test_venv
test_venv/bin/pip install coverage test_venv/bin/pip install coverage
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
test_venv/bin/coverage html test_venv/bin/coverage html
test_venv/bin/coverage xml rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts
coverage_codacy: coverage coverage_codacy: coverage
test_venv/bin/coverage xml
test_venv/bin/pip install codacy-coverage test_venv/bin/pip install codacy-coverage
test_venv/bin/python-codacy-coverage -r coverage.xml test_venv/bin/python-codacy-coverage -r coverage.xml
...@@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac ...@@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``. * ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``. * ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is * ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``. ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']}``.
Authentication backend Authentication backend
......
...@@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test') ...@@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test')
setting_default('CAS_TEST_PASSWORD', 'test') setting_default('CAS_TEST_PASSWORD', 'test')
setting_default( setting_default(
'CAS_TEST_ATTRIBUTES', 'CAS_TEST_ATTRIBUTES',
{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} {
'nom': 'Nymous',
'prenom': 'Ano',
'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']
}
) )
This diff is collapsed.
...@@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin): ...@@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin):
service = None service = None
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.url = request.GET.get('url') self.url = request.GET.get('url')
...@@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin): ...@@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin):
USER_NOT_AUTHENTICATED = 6 USER_NOT_AUTHENTICATED = 6
def init_post(self, request): def init_post(self, request):
"""Initialize POST received parameters"""
self.request = request self.request = request
self.service = request.POST.get('service') self.service = request.POST.get('service')
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
...@@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin): ...@@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin):
if request.POST.get('warned') and request.POST['warned'] != "False": if request.POST.get('warned') and request.POST['warned'] != "False":
self.warned = True self.warned = True
def gen_lt(self):
"""Generate a new LoginTicket and add it to the list of valid LT for the user"""
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:]
def check_lt(self): def check_lt(self):
"""Check is the POSTed LoginTicket is valid, if yes invalide it"""
# save LT for later check # save LT for later check
lt_valid = self.request.session.get('lt', []) lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt') lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed) # generate a new LT (by posting the LT has been consumed)
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] self.gen_lt()
if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:]
# check if send LT is valid # check if send LT is valid
if lt_valid is None or lt_send not in lt_valid: if lt_valid is None or lt_send not in lt_valid:
return False return False
...@@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin): ...@@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin):
username=self.request.session['username'], username=self.request.session['username'],
session_key=self.request.session.session_key session_key=self.request.session.session_key
) )
self.user.save() self.user.save() # pragma: no cover (should not happend)
except models.User.DoesNotExist: except models.User.DoesNotExist:
self.user = models.User.objects.create( self.user = models.User.objects.create(
username=self.request.session['username'], username=self.request.session['username'],
...@@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin): ...@@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin):
elif ret == self.USER_ALREADY_LOGGED: elif ret == self.USER_ALREADY_LOGGED:
pass pass
else: else:
raise EnvironmentError("invalid output for LoginView.process_post") raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover
return self.common() return self.common()
def process_post(self): def process_post(self):
"""
Analyse the POST request:
* check that the LoginTicket is valid
* check that the user sumited credentials are valid
"""
if not self.check_lt(): if not self.check_lt():
values = self.request.POST.copy() values = self.request.POST.copy()
# if not set a new LT and fail # if not set a new LT and fail
...@@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin): ...@@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin):
return self.USER_ALREADY_LOGGED return self.USER_ALREADY_LOGGED
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
...@@ -294,15 +306,16 @@ class LoginView(View, LogoutMixin): ...@@ -294,15 +306,16 @@ class LoginView(View, LogoutMixin):
return self.common() return self.common()
def process_get(self): def process_get(self):
# generate a new LT if none is present """Analyse the GET request"""
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] # generate a new LT
self.gen_lt()
if not self.request.session.get("authenticated") or self.renew: if not self.request.session.get("authenticated") or self.renew:
self.init_form() self.init_form()
return self.USER_NOT_AUTHENTICATED return self.USER_NOT_AUTHENTICATED
return self.USER_AUTHENTICATED return self.USER_AUTHENTICATED
def init_form(self, values=None): def init_form(self, values=None):
"""Initialization of the good form depending of POST and GET parameters"""
self.form = forms.UserCredential( self.form = forms.UserCredential(
values, values,
initial={ initial={
......
...@@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [ ...@@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
'django.middleware.locale.LocaleMiddleware', 'django.middleware.locale.LocaleMiddleware',
] ]
ROOT_URLCONF = 'cas_server.urls' ROOT_URLCONF = 'urls_tests'
# Database # Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases # https://docs.djangoproject.com/en/1.9/ref/settings/#databases
......
"""cas URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/1.9/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.conf.urls import url, include, include
2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
"""
from django.conf.urls import url, include
from django.contrib import admin
urlpatterns = [
url(r'^admin/', admin.site.urls),
url(r'^', include('cas_server.urls', namespace='cas_server')),
]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment