Commit 90daf3d2 authored by Valentin Samir's avatar Valentin Samir

Add unit tests for when CAS_FEDERATE is True

Also fix some unicode related bugs
parent fcd906ca
......@@ -171,7 +171,7 @@ class CASFederateAuth(AuthUser):
def attributs(self):
"""return a dict of user attributes"""
if not self.user:
if not self.user: # pragma: no cover (should not happen)
return {}
else:
return self.user.attributs
......@@ -14,7 +14,6 @@ from django.conf import settings
from django.contrib.staticfiles.templatetags.staticfiles import static
import re
import six
def setting_default(name, default_value):
......@@ -112,13 +111,10 @@ except AttributeError:
key = settings.CAS_FEDERATE_PROVIDERS[key][2].lower()
else:
key = key.lower()
if isinstance(key, six.string_types) or isinstance(key, six.text_type):
return tuple(
int(num) if num else alpha
for num, alpha in __cas_federate_providers_list_sort.tokenize(key)
)
else:
return key
return tuple(
int(num) if num else alpha
for num, alpha in __cas_federate_providers_list_sort.tokenize(key)
)
__cas_federate_providers_list_sort.tokenize = re.compile(r'(\d+)|(\D+)').findall
__CAS_FEDERATE_PROVIDERS_LIST.sort(key=__cas_federate_providers_list_sort)
......
......@@ -15,6 +15,7 @@ from .cas import CASClient
from .models import FederatedUser, FederateSLO, User
from importlib import import_module
from six.moves import urllib
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
......@@ -27,7 +28,7 @@ class CASFederateValidateUser(object):
def __init__(self, provider, service_url):
self.provider = provider
if provider in settings.CAS_FEDERATE_PROVIDERS:
if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be True)
(server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider][:2]
self.client = CASClient(
service_url=service_url,
......@@ -44,9 +45,12 @@ class CASFederateValidateUser(object):
def verify_ticket(self, ticket):
"""test `password` agains the user"""
if self.client is None:
if self.client is None: # pragma: no cover (should not happen)
return False
try:
username, attributs = self.client.verify_ticket(ticket)[:2]
except urllib.error.URLError:
return False
username, attributs = self.client.verify_ticket(ticket)[:2]
if username is not None:
if attributs is None:
attributs = {}
......@@ -83,23 +87,20 @@ class CASFederateValidateUser(object):
def clean_sessions(self, logout_request):
try:
slos = self.client.get_saml_slos(logout_request)
except NameError:
slos = self.client.get_saml_slos(logout_request) or []
except NameError: # pragma: no cover (should not happen)
slos = []
for slo in slos:
try:
for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
session = SessionStore(session_key=federate_slo.session_key)
session.flush()
try:
user = User.objects.get(
username=federate_slo.username,
session_key=federate_slo.session_key
)
user.logout()
user.delete()
except User.DoesNotExist:
pass
federate_slo.delete()
except FederateSLO.DoesNotExist:
pass
for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
session = SessionStore(session_key=federate_slo.session_key)
session.flush()
try:
user = User.objects.get(
username=federate_slo.username,
session_key=federate_slo.session_key
)
user.logout()
user.delete()
except User.DoesNotExist: # pragma: no cover (should not happen)
pass
federate_slo.delete()
......@@ -31,6 +31,8 @@ class WarnForm(forms.Form):
class FederateSelect(forms.Form):
provider = forms.ChoiceField(
label=_('Identity provider'),
# with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS
# this is usefull to use the override_settings decorator in tests
choices=[
(
p,
......@@ -88,8 +90,12 @@ class FederateUserCredential(UserCredential):
user = models.FederatedUser.objects.get(username=username, provider=provider)
user.ticket = ""
user.save()
except models.FederatedUser.DoesNotExist:
raise
# should not happed as is the FederatedUser do not exists, super should
# raise before a ValidationError("bad user")
except models.FederatedUser.DoesNotExist: # pragma: no cover (should not happend)
raise forms.ValidationError(
_(u"User not found in the temporary database, please try to reconnect")
)
return cleaned_data
......
from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _
from django.utils import timezone
from datetime import timedelta
from ... import models
from ...default_settings import settings
class Command(BaseCommand):
......@@ -13,11 +9,5 @@ class Command(BaseCommand):
help = _(u"Clean old federated users")
def handle(self, *args, **options):
federated_users = models.FederatedUser.objects.filter(
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
)
known_users = {user.username for user in models.User.objects.all()}
for user in federated_users:
if not ('%s@%s' % (user.username, user.provider)) in known_users:
user.delete()
models.FederatedUser.clean_old_entries()
models.FederateSLO.clean_deleted_sessions()
......@@ -46,6 +46,16 @@ class FederatedUser(models.Model):
def __unicode__(self):
return u"%s@%s" % (self.username, self.provider)
@classmethod
def clean_old_entries(cls):
federated_users = cls.objects.filter(
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
)
known_users = {user.username for user in User.objects.all()}
for user in federated_users:
if not ('%s@%s' % (user.username, user.provider)) in known_users:
user.delete()
class FederateSLO(models.Model):
class Meta:
......@@ -54,11 +64,6 @@ class FederateSLO(models.Model):
session_key = models.CharField(max_length=40, blank=True, null=True)
ticket = models.CharField(max_length=255)
@property
def provider(self):
component = self.username.split("@")
return component[-1]
@classmethod
def clean_deleted_sessions(cls):
for federate_slo in cls.objects.all():
......@@ -76,6 +81,14 @@ class User(models.Model):
username = models.CharField(max_length=30)
date = models.DateTimeField(auto_now=True)
def delete(self, *args, **kwargs):
if settings.CAS_FEDERATE:
FederateSLO.objects.filter(
username=self.username,
session_key=self.session_key
).delete()
super(User, self).delete(*args, **kwargs)
@classmethod
def clean_old_entries(cls):
"""Remove users inactive since more that SESSION_COOKIE_AGE"""
......
......@@ -191,3 +191,50 @@ class UserModels(object):
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
class CanLogin(object):
"""Assertion about login"""
def assert_logged(
self, client, response, warn=False,
code=200, username=settings.CAS_TEST_USER
):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication
self.assertIn(
(
b"You have successfully logged into "
b"the Central Authentication Service"
),
response.content
)
# these session variables a set if usccessfully authenticated
self.assertEqual(client.session["username"], username)
self.assertIs(client.session["warn"], warn)
self.assertIs(client.session["authenticated"], True)
# on successfull authentication, a corresponding user object is created
self.assertTrue(
models.User.objects.get(
username=username,
session_key=client.session.session_key
)
)
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication, so it should not
# appear
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
# if authentication has failed, these session variables should not be set
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
This diff is collapsed.
......@@ -12,20 +12,81 @@
"""Tests module for models"""
from cas_server.default_settings import settings
from django.test import TestCase
from django.test import TestCase, Client
from django.test.utils import override_settings
from django.utils import timezone
from datetime import timedelta
from importlib import import_module
from cas_server import models
from cas_server import models, utils
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
class FederatedUserTestCase(TestCase, UserModels):
"""test for the federated user model"""
def test_clean_old_entries(self):
"""tests for clean_old_entries that should delete federated user no longer used"""
client = Client()
client.get("/login")
models.FederatedUser.objects.create(
username="test1", provider="example.com", attributs={}, ticket=""
)
models.FederatedUser.objects.create(
username="test2", provider="example.com", attributs={}, ticket=""
)
models.FederatedUser.objects.all().update(
last_update=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT + 10))
)
models.FederatedUser.objects.create(
username="test3", provider="example.com", attributs={}, ticket=""
)
models.User.objects.create(
username="test1@example.com", session_key=client.session.session_key
)
models.FederatedUser.clean_old_entries()
self.assertEqual(len(models.FederatedUser.objects.all()), 2)
with self.assertRaises(models.FederatedUser.DoesNotExist):
models.FederatedUser.objects.get(username="test2")
class FederateSLOTestCase(TestCase, UserModels):
"""test for the federated SLO model"""
def test_clean_deleted_sessions(self):
"""
tests for clean_deleted_sessions that should delete object for which matching session
do not exists anymore
"""
client1 = Client()
client2 = Client()
client1.get("/login")
client2.get("/login")
session = client2.session
session['authenticated'] = True
try:
session.save()
except AttributeError:
pass
models.FederateSLO.objects.create(
username="test1@example.com",
session_key=client1.session.session_key,
ticket=utils.gen_st()
)
models.FederateSLO.objects.create(
username="test2@example.com",
session_key=client2.session.session_key,
ticket=utils.gen_st()
)
self.assertEqual(len(models.FederateSLO.objects.all()), 2)
models.FederateSLO.clean_deleted_sessions()
self.assertEqual(len(models.FederateSLO.objects.all()), 1)
with self.assertRaises(models.FederateSLO.DoesNotExist):
models.FederateSLO.objects.get(username="test1@example.com")
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class UserTestCase(TestCase, UserModels):
"""tests for the user models"""
......
......@@ -10,7 +10,7 @@
#
# (c) 2016 Valentin Samir
"""Tests module for utils"""
from django.test import TestCase
from django.test import TestCase, RequestFactory
import six
......@@ -189,3 +189,22 @@ class UtilsTestCase(TestCase):
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known
def test_get_current_url(self):
"""test the function get_current_url"""
factory = RequestFactory()
request = factory.get('/truc/muche?test=1')
self.assertEqual(utils.get_current_url(request), 'http://testserver/truc/muche?test=1')
self.assertEqual(
utils.get_current_url(request, ignore_params={'test'}),
'http://testserver/truc/muche'
)
def test_get_tuple(self):
"""test the function get_tuple"""
test_tuple = (1, 2, 3)
for index, value in enumerate(test_tuple):
self.assertEqual(utils.get_tuple(test_tuple, index), value)
self.assertEqual(utils.get_tuple(test_tuple, 3), None)
self.assertEqual(utils.get_tuple(test_tuple, 3, 'toto'), 'toto')
self.assertEqual(utils.get_tuple(None, 3), None)
# *- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
......@@ -36,57 +36,17 @@ from cas_server.tests.utils import (
HttpParamsHandler,
Http404Handler
)
from cas_server.tests.mixin import BaseServicePattern, XmlContent
from cas_server.tests.mixin import BaseServicePattern, XmlContent, CanLogin
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class LoginTestCase(TestCase, BaseServicePattern):
class LoginTestCase(TestCase, BaseServicePattern, CanLogin):
"""Tests for the login view"""
def setUp(self):
"""Prepare the test context:"""
# we prepare a bunch a service url and service patterns for tests
self.setup_service_patterns()
def assert_logged(self, client, response, warn=False, code=200):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
# these session variables a set if usccessfully authenticated
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client.session["warn"] is warn)
self.assertTrue(client.session["authenticated"] is True)
# on successfull authentication, a corresponding user object is created
self.assertTrue(
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
)
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication, so it should not
# appear
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
# if authentication has failed, these session variables should not be set
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
def test_login_view_post_goodpass_goodlt(self):
"""Test a successul login"""
# we get a client who fetch a frist time the login page and the login form default
......
......@@ -13,14 +13,33 @@
from cas_server.default_settings import settings
from django.test import Client
from django.template import loader, Context
from django.utils import timezone
import cgi
import six
from threading import Thread
from lxml import etree
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, parse_qsl
from datetime import timedelta
from cas_server import models
from cas_server import utils
def return_unicode(string, charset):
if not isinstance(string, six.text_type):
return string.decode(charset)
else:
return string
def return_bytes(string, charset):
if isinstance(string, six.text_type):
return string.encode(charset)
else:
return string
def copy_form(form):
......@@ -149,10 +168,10 @@ class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
return
@classmethod
def run(cls):
def run(cls, port=0):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls)
httpd = server_class(("127.0.0.1", port), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
......@@ -178,3 +197,143 @@ class Http404Handler(HttpParamsHandler):
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()
class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
def test_params(self):
if (
self.server.ticket is not None and
self.params.get("service").encode("ascii") == self.server.service and
self.params.get("ticket").encode("ascii") == self.server.ticket
):
self.server.ticket = None
print("good")
return True
else:
print("bad (%r, %r) != (%r, %r)" % (
self.params.get("service").encode("ascii"),
self.params.get("ticket").encode("ascii"),
self.server.service,
self.server.ticket
))
return False
def send_headers(self, code, content_type):
self.send_response(200)
self.send_header("Content-type", content_type)
self.end_headers()
def do_GET(self):
url = urlparse(self.path)
self.params = dict(parse_qsl(url.query))
if url.path == "/validate":
self.send_headers(200, "text/plain; charset=utf-8")
if self.test_params():
self.wfile.write(b"yes\n" + self.server.username + b"\n")
self.server.ticket = None
else:
self.wfile.write(b"no\n")
elif url.path in {
'/serviceValidate', '/serviceValidate',
'/p3/serviceValidate', '/p3/proxyValidate'
}:
self.send_headers(200, "text/xml; charset=utf-8")
if self.test_params():
t = loader.get_template('cas_server/serviceValidate.xml')
c = Context({
'username': self.server.username,
'attributes': self.server.attributes
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
t = loader.get_template('cas_server/serviceValidateError.xml')
c = Context({
'code': 'BAD_SERVICE_TICKET',
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
self.return_404()
def do_POST(self):
url = urlparse(self.path)
self.params = dict(parse_qsl(url.query))
if url.path == "/samlValidate":
self.send_headers(200, "text/xml; charset=utf-8")
length = int(self.headers.get('content-length'))
root = etree.fromstring(self.rfile.read(length))
auth_req = root.getchildren()[1].getchildren()[0]
ticket = auth_req.getchildren()[0].text.encode("ascii")
if (
self.server.ticket is not None and
self.params.get("TARGET").encode("ascii") == self.server.service and
ticket == self.server.ticket
):
self.server.ticket = None
t = loader.get_template('cas_server/samlValidate.xml')
c = Context({
'IssueInstant': timezone.now().isoformat(),
'expireInstant': (timezone.now() + timedelta(seconds=60)).isoformat(),
'Recipient': self.server.service,
'ResponseID': utils.gen_saml_id(),
'username': self.server.username,
'attributes': self.server.attributes,
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
t = loader.get_template('cas_server/samlValidateError.xml')
c = Context({
'IssueInstant': timezone.now().isoformat(),
'ResponseID': utils.gen_saml_id(),
'code': 'BAD_SERVICE_TICKET',
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
self.return_404()
def return_404(self):
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write("not found")
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls, service, ticket, username, attributes, port=0):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", port), cls)
httpd.service = service
httpd.ticket = ticket
httpd.username = username
httpd.attributes = attributes
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
def logout_request(ticket):
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
</samlp:LogoutRequest>""" % \
{
'id': utils.gen_saml_id(),
'datetime': timezone.now().isoformat(),
'ticket': ticket
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment