Commit 90daf3d2 authored by Valentin Samir's avatar Valentin Samir

Add unit tests for when CAS_FEDERATE is True

Also fix some unicode related bugs
parent fcd906ca
......@@ -171,7 +171,7 @@ class CASFederateAuth(AuthUser):
def attributs(self):
"""return a dict of user attributes"""
if not self.user:
if not self.user: # pragma: no cover (should not happen)
return {}
else:
return self.user.attributs
......@@ -14,7 +14,6 @@ from django.conf import settings
from django.contrib.staticfiles.templatetags.staticfiles import static
import re
import six
def setting_default(name, default_value):
......@@ -112,13 +111,10 @@ except AttributeError:
key = settings.CAS_FEDERATE_PROVIDERS[key][2].lower()
else:
key = key.lower()
if isinstance(key, six.string_types) or isinstance(key, six.text_type):
return tuple(
int(num) if num else alpha
for num, alpha in __cas_federate_providers_list_sort.tokenize(key)
)
else:
return key
return tuple(
int(num) if num else alpha
for num, alpha in __cas_federate_providers_list_sort.tokenize(key)
)
__cas_federate_providers_list_sort.tokenize = re.compile(r'(\d+)|(\D+)').findall
__CAS_FEDERATE_PROVIDERS_LIST.sort(key=__cas_federate_providers_list_sort)
......
......@@ -15,6 +15,7 @@ from .cas import CASClient
from .models import FederatedUser, FederateSLO, User
from importlib import import_module
from six.moves import urllib
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
......@@ -27,7 +28,7 @@ class CASFederateValidateUser(object):
def __init__(self, provider, service_url):
self.provider = provider
if provider in settings.CAS_FEDERATE_PROVIDERS:
if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be True)
(server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider][:2]
self.client = CASClient(
service_url=service_url,
......@@ -44,9 +45,12 @@ class CASFederateValidateUser(object):
def verify_ticket(self, ticket):
"""test `password` agains the user"""
if self.client is None:
if self.client is None: # pragma: no cover (should not happen)
return False
try:
username, attributs = self.client.verify_ticket(ticket)[:2]
except urllib.error.URLError:
return False
username, attributs = self.client.verify_ticket(ticket)[:2]
if username is not None:
if attributs is None:
attributs = {}
......@@ -83,23 +87,20 @@ class CASFederateValidateUser(object):
def clean_sessions(self, logout_request):
try:
slos = self.client.get_saml_slos(logout_request)
except NameError:
slos = self.client.get_saml_slos(logout_request) or []
except NameError: # pragma: no cover (should not happen)
slos = []
for slo in slos:
try:
for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
session = SessionStore(session_key=federate_slo.session_key)
session.flush()
try:
user = User.objects.get(
username=federate_slo.username,
session_key=federate_slo.session_key
)
user.logout()
user.delete()
except User.DoesNotExist:
pass
federate_slo.delete()
except FederateSLO.DoesNotExist:
pass
for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
session = SessionStore(session_key=federate_slo.session_key)
session.flush()
try:
user = User.objects.get(
username=federate_slo.username,
session_key=federate_slo.session_key
)
user.logout()
user.delete()
except User.DoesNotExist: # pragma: no cover (should not happen)
pass
federate_slo.delete()
......@@ -31,6 +31,8 @@ class WarnForm(forms.Form):
class FederateSelect(forms.Form):
provider = forms.ChoiceField(
label=_('Identity provider'),
# with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS
# this is usefull to use the override_settings decorator in tests
choices=[
(
p,
......@@ -88,8 +90,12 @@ class FederateUserCredential(UserCredential):
user = models.FederatedUser.objects.get(username=username, provider=provider)
user.ticket = ""
user.save()
except models.FederatedUser.DoesNotExist:
raise
# should not happed as is the FederatedUser do not exists, super should
# raise before a ValidationError("bad user")
except models.FederatedUser.DoesNotExist: # pragma: no cover (should not happend)
raise forms.ValidationError(
_(u"User not found in the temporary database, please try to reconnect")
)
return cleaned_data
......
from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _
from django.utils import timezone
from datetime import timedelta
from ... import models
from ...default_settings import settings
class Command(BaseCommand):
......@@ -13,11 +9,5 @@ class Command(BaseCommand):
help = _(u"Clean old federated users")
def handle(self, *args, **options):
federated_users = models.FederatedUser.objects.filter(
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
)
known_users = {user.username for user in models.User.objects.all()}
for user in federated_users:
if not ('%s@%s' % (user.username, user.provider)) in known_users:
user.delete()
models.FederatedUser.clean_old_entries()
models.FederateSLO.clean_deleted_sessions()
......@@ -46,6 +46,16 @@ class FederatedUser(models.Model):
def __unicode__(self):
return u"%s@%s" % (self.username, self.provider)
@classmethod
def clean_old_entries(cls):
federated_users = cls.objects.filter(
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
)
known_users = {user.username for user in User.objects.all()}
for user in federated_users:
if not ('%s@%s' % (user.username, user.provider)) in known_users:
user.delete()
class FederateSLO(models.Model):
class Meta:
......@@ -54,11 +64,6 @@ class FederateSLO(models.Model):
session_key = models.CharField(max_length=40, blank=True, null=True)
ticket = models.CharField(max_length=255)
@property
def provider(self):
component = self.username.split("@")
return component[-1]
@classmethod
def clean_deleted_sessions(cls):
for federate_slo in cls.objects.all():
......@@ -76,6 +81,14 @@ class User(models.Model):
username = models.CharField(max_length=30)
date = models.DateTimeField(auto_now=True)
def delete(self, *args, **kwargs):
if settings.CAS_FEDERATE:
FederateSLO.objects.filter(
username=self.username,
session_key=self.session_key
).delete()
super(User, self).delete(*args, **kwargs)
@classmethod
def clean_old_entries(cls):
"""Remove users inactive since more that SESSION_COOKIE_AGE"""
......
......@@ -191,3 +191,50 @@ class UserModels(object):
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
class CanLogin(object):
"""Assertion about login"""
def assert_logged(
self, client, response, warn=False,
code=200, username=settings.CAS_TEST_USER
):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication
self.assertIn(
(
b"You have successfully logged into "
b"the Central Authentication Service"
),
response.content
)
# these session variables a set if usccessfully authenticated
self.assertEqual(client.session["username"], username)
self.assertIs(client.session["warn"], warn)
self.assertIs(client.session["authenticated"], True)
# on successfull authentication, a corresponding user object is created
self.assertTrue(
models.User.objects.get(
username=username,
session_key=client.session.session_key
)
)
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication, so it should not
# appear
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
# if authentication has failed, these session variables should not be set
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
# -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""tests for the CAS federate mode"""
from cas_server import default_settings
from cas_server.default_settings import settings
import django
from django.test import TestCase, Client
from django.test.utils import override_settings
from six.moves import reload_module
from cas_server import utils, forms
from cas_server.tests.mixin import BaseServicePattern, CanLogin
from cas_server.tests import utils as tests_utils
PROVIDERS = {
"example.com": ("http://127.0.0.1:8080", 1, "Example dot com"),
"example.org": ("http://127.0.0.1:8081", 2, "Example dot org"),
"example.net": ("http://127.0.0.1:8082", 3, "Example dot net"),
"example.test": ("http://127.0.0.1:8083", 'CAS_2_SAML_1_0'),
}
PROVIDERS_LIST = list(PROVIDERS.keys())
PROVIDERS_LIST.sort()
def getaddrinfo_mock(name, port, *args, **kwargs):
return [(2, 1, 6, '', ('127.0.0.1', 80))]
@override_settings(
CAS_FEDERATE=True,
CAS_FEDERATE_PROVIDERS=PROVIDERS,
CAS_FEDERATE_PROVIDERS_LIST=PROVIDERS_LIST,
CAS_AUTH_CLASS="cas_server.auth.CASFederateAuth",
# test with a non ascii username
CAS_TEST_USER=u"dédé"
)
class FederateAuthLoginLogoutTestCase(TestCase, BaseServicePattern, CanLogin):
"""tests for the views login logout and federate then the federated mode is enabled"""
def setUp(self):
"""Prepare the test context"""
self.setup_service_patterns()
reload_module(forms)
def test_default_settings(self):
"""default settings should populated some default variable then CAS_FEDERATE is True"""
provider_list = settings.CAS_FEDERATE_PROVIDERS_LIST
del settings.CAS_FEDERATE_PROVIDERS_LIST
del settings.CAS_AUTH_CLASS
reload_module(default_settings)
self.assertEqual(settings.CAS_FEDERATE_PROVIDERS_LIST, provider_list)
self.assertEqual(settings.CAS_AUTH_CLASS, "cas_server.auth.CASFederateAuth")
def test_login_get_provider(self):
"""some assertion about the login page in federated mode"""
client = Client()
response = client.get("/login")
self.assertEqual(response.status_code, 200)
for key, value in settings.CAS_FEDERATE_PROVIDERS.items():
self.assertTrue('<option value="%s">%s</option>' % (
key,
utils.get_tuple(value, 2, key)
) in response.content.decode("utf-8"))
self.assertEqual(response.context['post_url'], '/federate')
def test_login_post_provider(self, remember=False):
"""test a successful login wrokflow"""
tickets = []
# choose the example.com provider
for (provider, cas_port) in [
("example.com", 8080), ("example.org", 8081),
("example.net", 8082), ("example.test", 8083)
]:
# get a bare client
client = Client()
# fetch the login page
response = client.get("/login")
# in federated mode, we shoudl POST do /federate on the login page
self.assertEqual(response.context['post_url'], '/federate')
# get current form parameter
params = tests_utils.copy_form(response.context["form"])
params['provider'] = provider
if remember:
params['remember'] = 'on'
# post the choosed provider
response = client.post('/federate', params)
# we are redirected to the provider CAS client url
self.assertEqual(response.status_code, 302)
if remember:
self.assertEqual(response["Location"], '%s/federate/%s?remember=on' % (
'http://testserver' if django.VERSION < (1, 9) else "",
provider
))
else:
self.assertEqual(response["Location"], '%s/federate/%s' % (
'http://testserver' if django.VERSION < (1, 9) else "",
provider
))
# let's follow the redirect
response = client.get('/federate/%s' % provider)
# we are redirected to the provider CAS for authentication
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
"%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
settings.CAS_FEDERATE_PROVIDERS[provider][0],
provider
)
)
# let's generate a ticket
ticket = utils.gen_st()
# we lauch a dummy CAS server that only validate once for the service
# http://testserver/federate/example.com with `ticket`
tests_utils.DummyCAS.run(
("http://testserver/federate/%s" % provider).encode("ascii"),
ticket.encode("ascii"),
settings.CAS_TEST_USER.encode("utf8"),
[],
cas_port
)
# we normally provide a good ticket and should be redirected to /login as the ticket
# get successfully validated again the dummy CAS
response = client.get('/federate/%s' % provider, {'ticket': ticket})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
# follow the redirect
response = client.get("/login")
# we should get a page with a from with all widget hidden that auto POST to /login using
# javascript. If javascript is disabled, a "connect" button is showed
self.assertTrue(response.context['auto_submit'])
self.assertEqual(response.context['post_url'], '/login')
params = tests_utils.copy_form(response.context["form"])
# POST ge prefiled from parameters
response = client.post("/login", params)
# the user should now being authenticated using username test@`provider`
self.assert_logged(
client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
)
tickets.append((provider, ticket, client))
# try to get a ticket
response = client.get("/login", {'service': self.service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % self.service))
return tickets
def test_login_twice(self):
"""Test that user id db is used for the second login (cf coverage)"""
self.test_login_post_provider()
self.test_login_post_provider()
@override_settings(CAS_FEDERATE=False)
def test_auth_federate_false(self):
"""federated view should redirect to /login then CAS_FEDERATE is False"""
provider = "example.com"
client = Client()
response = client.get("/federate/%s" % provider)
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
response = client.post("%s/federate/%s" % (
'http://testserver' if django.VERSION < (1, 9) else "",
provider
))
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
def test_auth_federate_errors(self):
"""
The federated view should redirect to /login if the provider is unknown or not provided,
try to fetch a new ticket if the provided ticket validation fail
(network error or bad ticket)
"""
return
good_provider = "example.com"
bad_provider = "exemple.fr"
client = Client()
response = client.get("/federate/%s" % bad_provider)
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
# test CAS not avaible
response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()})
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
"%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
settings.CAS_FEDERATE_PROVIDERS[good_provider][0],
good_provider
)
)
# test CAS avaible but bad ticket
tests_utils.DummyCAS.run(
("http://testserver/federate/%s" % good_provider).encode("ascii"),
utils.gen_st().encode("ascii"),
settings.CAS_TEST_USER.encode("utf-8"),
[],
8080
)
response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()})
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
"%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
settings.CAS_FEDERATE_PROVIDERS[good_provider][0],
good_provider
)
)
response = client.post("/federate")
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
def test_auth_federate_slo(self):
"""test that SLO receive from backend CAS log out the users"""
# get tickets and connected clients
tickets = self.test_login_post_provider()
for (provider, ticket, client) in tickets:
# SLO for an unkown ticket should do nothing
response = client.post(
"/federate/%s" % provider,
{'logoutRequest': tests_utils.logout_request(utils.gen_st())}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok")
# Bad SLO format should do nothing
response = client.post(
"/federate/%s" % provider,
{'logoutRequest': ""}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok")
# Bad SLO format should do nothing
response = client.post(
"/federate/%s" % provider,
{'logoutRequest': "<root></root>"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok")
response = client.get("/login")
self.assert_logged(
client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
)
# SLO for a previously logged ticket should log out the user if CAS version is
# 3 or 'CAS_2_SAML_1_0'
response = client.post(
"/federate/%s" % provider,
{'logoutRequest': tests_utils.logout_request(ticket)}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok")
response = client.get("/login")
if settings.CAS_FEDERATE_PROVIDERS[provider][1] in {3, 'CAS_2_SAML_1_0'}: # support SLO
self.assert_login_failed(client, response)
else:
self.assert_logged(
client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
)
def test_federate_logout(self):
"""
test the logout function: the user should be log out
and redirected to his CAS logout page
"""
# get tickets and connected clients
tickets = self.test_login_post_provider()
for (provider, _, client) in tickets:
response = client.get("/logout")
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
"%s/logout" % settings.CAS_FEDERATE_PROVIDERS[provider][0]
)
response = client.get("/login")
self.assert_login_failed(client, response)
def test_remember_provider(self):
"""
If the user check remember, next login should not offer the chose of the backend CAS
and use the one store in the cookie
"""
tickets = self.test_login_post_provider(remember=True)
for (provider, _, client) in tickets:
client.get("/logout")
response = client.get("/login")
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/federate/%s" % (
'http://testserver' if django.VERSION < (1, 9) else "",
provider
))
def test_login_bad_ticket(self):
"""
Try login with a bad ticket:
login should fail and the main login page should be displayed to the user
"""
provider = "example.com"
# get a bare client
client = Client()
session = client.session
session["federate_username"] = '%s@%s' % (settings.CAS_TEST_USER, provider)
session["federate_ticket"] = utils.gen_st()
try:
session.save()
response = client.get("/login")
# we should get a page with a from with all widget hidden that auto POST to /login using
# javascript. If javascript is disabled, a "connect" button is showed
self.assertTrue(response.context['auto_submit'])
self.assertEqual(response.context['post_url'], '/login')
params = tests_utils.copy_form(response.context["form"])
# POST, as (username, ticket) are not valid, we should get the federate login page
response = client.post("/login", params)
self.assertEqual(response.status_code, 200)
for key, value in settings.CAS_FEDERATE_PROVIDERS.items():
self.assertTrue('<option value="%s">%s</option>' % (
key,
utils.get_tuple(value, 2, key)
) in response.content.decode("utf-8"))
self.assertEqual(response.context['post_url'], '/federate')
except AttributeError:
pass
......@@ -12,20 +12,81 @@
"""Tests module for models"""
from cas_server.default_settings import settings
from django.test import TestCase
from django.test import TestCase, Client
from django.test.utils import override_settings
from django.utils import timezone
from datetime import timedelta
from importlib import import_module
from cas_server import models
from cas_server import models, utils
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
class FederatedUserTestCase(TestCase, UserModels):
"""test for the federated user model"""
def test_clean_old_entries(self):
"""tests for clean_old_entries that should delete federated user no longer used"""
client = Client()
client.get("/login")
models.FederatedUser.objects.create(
username="test1", provider="example.com", attributs={}, ticket=""
)
models.FederatedUser.objects.create(