Commit efdd97ec authored by Valentin Samir's avatar Valentin Samir
Browse files

Test for CAS federation

parent 3a57ad08
*.pyc
*.egg-info
*.swp
build/
bootstrap3
......
......@@ -12,6 +12,9 @@
"""Some authentication classes for the CAS"""
from django.conf import settings
from django.contrib.auth import get_user_model
from django.utils import timezone
from datetime import timedelta
try:
import MySQLdb
import MySQLdb.cursors
......@@ -19,6 +22,8 @@ try:
except ImportError:
MySQLdb = None
from .models import FederatedUser
class AuthUser(object):
def __init__(self, username):
......@@ -140,3 +145,37 @@ class DjangoAuthUser(AuthUser):
for field in self.user._meta.fields:
attr[field.attname] = getattr(self.user, field.attname)
return attr
class CASFederateAuth(AuthUser):
user = None
def __init__(self, username):
component = username.split('@')
username = '@'.join(component[:-1])
provider = component[-1]
try:
self.user = FederatedUser.objects.get(username=username, provider=provider)
super(CASFederateAuth, self).__init__(
"%s@%s" % (self.user.username, self.user.provider)
)
except FederatedUser.DoesNotExist:
super(CASFederateAuth, self).__init__("%s@%s" % (username, provider))
def test_password(self, ticket):
"""test `password` agains the user"""
if not self.user or not self.user.ticket:
return False
else:
return (
ticket == self.user.ticket and
self.user.last_update >
(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
)
def attributs(self):
"""return a dict of user attributes"""
if not self.user:
return {}
else:
return self.user.attributs
from six.moves.urllib import parse as urllib_parse
from six.moves.urllib import request as urllib_request
from six.moves.urllib.request import Request
from uuid import uuid4
import datetime
class CASError(ValueError):
pass
class SingleLogoutMixin(object):
@classmethod
def get_saml_slos(cls, logout_request):
"""returns saml logout ticket info"""
from lxml import etree
try:
root = etree.fromstring(logout_request)
return root.xpath(
"//samlp:SessionIndex",
namespaces={'samlp': "urn:oasis:names:tc:SAML:2.0:protocol"})
except etree.XMLSyntaxError:
pass
class CASClient(object):
def __new__(self, *args, **kwargs):
version = kwargs.pop('version')
if version in (1, '1'):
return CASClientV1(*args, **kwargs)
elif version in (2, '2'):
return CASClientV2(*args, **kwargs)
elif version in (3, '3'):
return CASClientV3(*args, **kwargs)
elif version == 'CAS_2_SAML_1_0':
return CASClientWithSAMLV1(*args, **kwargs)
raise ValueError('Unsupported CAS_VERSION %r' % version)
class CASClientBase(object):
logout_redirect_param_name = 'service'
def __init__(self, service_url=None, server_url=None,
extra_login_params=None, renew=False,
username_attribute=None):
self.service_url = service_url
self.server_url = server_url
self.extra_login_params = extra_login_params or {}
self.renew = renew
self.username_attribute = username_attribute
pass
def verify_ticket(self, ticket):
"""must return a triple"""
raise NotImplementedError()
def get_login_url(self):
"""Generates CAS login URL"""
params = {'service': self.service_url}
if self.renew:
params.update({'renew': 'true'})
params.update(self.extra_login_params)
url = urllib_parse.urljoin(self.server_url, 'login')
query = urllib_parse.urlencode(params)
return url + '?' + query
def get_logout_url(self, redirect_url=None):
"""Generates CAS logout URL"""
url = urllib_parse.urljoin(self.server_url, 'logout')
if redirect_url:
params = {self.logout_redirect_param_name: redirect_url}
url += '?' + urllib_parse.urlencode(params)
return url
def get_proxy_url(self, pgt):
"""Returns proxy url, given the proxy granting ticket"""
params = urllib_parse.urlencode({'pgt': pgt, 'targetService': self.service_url})
return "%s/proxy?%s" % (self.server_url, params)
def get_proxy_ticket(self, pgt):
"""Returns proxy ticket given the proxy granting ticket"""
response = urllib_request.urlopen(self.get_proxy_url(pgt))
if response.code == 200:
from lxml import etree
root = etree.fromstring(response.read())
tickets = root.xpath(
"//cas:proxyTicket",
namespaces={"cas": "http://www.yale.edu/tp/cas"}
)
if len(tickets) == 1:
return tickets[0].text
errors = root.xpath(
"//cas:authenticationFailure",
namespaces={"cas": "http://www.yale.edu/tp/cas"}
)
if len(errors) == 1:
raise CASError(errors[0].attrib['code'], errors[0].text)
raise CASError("Bad http code %s" % response.code)
class CASClientV1(CASClientBase):
"""CAS Client Version 1"""
logout_redirect_param_name = 'url'
def verify_ticket(self, ticket):
"""Verifies CAS 1.0 authentication ticket.
Returns username on success and None on failure.
"""
params = [('ticket', ticket), ('service', self.service)]
url = (urllib_parse.urljoin(self.server_url, 'validate') + '?' +
urllib_parse.urlencode(params))
page = urllib_request.urlopen(url)
try:
verified = page.readline().strip()
if verified == 'yes':
return page.readline().strip(), None, None
else:
return None, None, None
finally:
page.close()
class CASClientV2(CASClientBase):
"""CAS Client Version 2"""
url_suffix = 'serviceValidate'
logout_redirect_param_name = 'url'
def __init__(self, proxy_callback=None, *args, **kwargs):
"""proxy_callback is for V2 and V3 so V3 is subclass of V2"""
self.proxy_callback = proxy_callback
super(CASClientV2, self).__init__(*args, **kwargs)
def verify_ticket(self, ticket):
"""Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes"""
response = self.get_verification_response(ticket)
return self.verify_response(response)
def get_verification_response(self, ticket):
params = [('ticket', ticket), ('service', self.service_url)]
if self.proxy_callback:
params.append(('pgtUrl', self.proxy_callback))
base_url = urllib_parse.urljoin(self.server_url, self.url_suffix)
url = base_url + '?' + urllib_parse.urlencode(params)
page = urllib_request.urlopen(url)
try:
return page.read()
finally:
page.close()
@classmethod
def parse_attributes_xml_element(cls, element):
attributes = dict()
for attribute in element:
tag = attribute.tag.split("}").pop()
if tag in attributes:
if isinstance(attributes[tag], list):
attributes[tag].append(attribute.text)
else:
attributes[tag] = [attributes[tag]]
attributes[tag].append(attribute.text)
else:
if tag == 'attraStyle':
pass
else:
attributes[tag] = attribute.text
return attributes
@classmethod
def verify_response(cls, response):
user, attributes, pgtiou = cls.parse_response_xml(response)
if len(attributes) == 0:
attributes = None
return user, attributes, pgtiou
@classmethod
def parse_response_xml(cls, response):
try:
from xml.etree import ElementTree
except ImportError:
from elementtree import ElementTree
user = None
attributes = {}
pgtiou = None
tree = ElementTree.fromstring(response)
if tree[0].tag.endswith('authenticationSuccess'):
for element in tree[0]:
if element.tag.endswith('user'):
user = element.text
elif element.tag.endswith('proxyGrantingTicket'):
pgtiou = element.text
elif element.tag.endswith('attributes'):
attributes = cls.parse_attributes_xml_element(element)
return user, attributes, pgtiou
class CASClientV3(CASClientV2, SingleLogoutMixin):
"""CAS Client Version 3"""
url_suffix = 'serviceValidate'
logout_redirect_param_name = 'service'
@classmethod
def parse_attributes_xml_element(cls, element):
attributes = dict()
for attribute in element:
tag = attribute.tag.split("}").pop()
if tag in attributes:
if isinstance(attributes[tag], list):
attributes[tag].append(attribute.text)
else:
attributes[tag] = [attributes[tag]]
attributes[tag].append(attribute.text)
else:
attributes[tag] = attribute.text
return attributes
@classmethod
def verify_response(cls, response):
return cls.parse_response_xml(response)
SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:'
SAML_1_0_PROTOCOL_NS = '{' + SAML_1_0_NS + 'protocol' + '}'
SAML_1_0_ASSERTION_NS = '{' + SAML_1_0_NS + 'assertion' + '}'
SAML_ASSERTION_TEMPLATE = """<?xml version="1.0" encoding="UTF-8"?>
<SOAP-ENV:Envelope xmlns:SOAP-ENV="http://schemas.xmlsoap.org/soap/envelope/">
<SOAP-ENV:Header/>
<SOAP-ENV:Body>
<samlp:Request xmlns:samlp="urn:oasis:names:tc:SAML:1.0:protocol"
MajorVersion="1"
MinorVersion="1"
RequestID="{request_id}"
IssueInstant="{timestamp}">
<samlp:AssertionArtifact>{ticket}</samlp:AssertionArtifact></samlp:Request>
</SOAP-ENV:Body>
</SOAP-ENV:Envelope>"""
class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
"""CASClient 3.0+ with SAML"""
def verify_ticket(self, ticket, **kwargs):
"""Verifies CAS 3.0+ XML-based authentication ticket and returns extended attributes.
@date: 2011-11-30
@author: Carlos Gonzalez Vila <carlewis@gmail.com>
Returns username and attributes on success and None,None on failure.
"""
try:
from xml.etree import ElementTree
except ImportError:
from elementtree import ElementTree
page = self.fetch_saml_validation(ticket)
try:
user = None
attributes = {}
response = page.read()
tree = ElementTree.fromstring(response)
# Find the authentication status
success = tree.find('.//' + SAML_1_0_PROTOCOL_NS + 'StatusCode')
if success is not None and success.attrib['Value'].endswith(':Success'):
# User is validated
attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute')
for at in attrs:
if self.username_attribute in list(at.attrib.values()):
user = at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text
attributes['uid'] = user
values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue')
if len(values) > 1:
values_array = []
for v in values:
values_array.append(v.text)
attributes[at.attrib['AttributeName']] = values_array
else:
attributes[at.attrib['AttributeName']] = values[0].text
return user, attributes, None
finally:
page.close()
def fetch_saml_validation(self, ticket):
# We do the SAML validation
headers = {
'soapaction': 'http://www.oasis-open.org/committees/security',
'cache-control': 'no-cache',
'pragma': 'no-cache',
'accept': 'text/xml',
'connection': 'keep-alive',
'content-type': 'text/xml; charset=utf-8',
}
params = [('TARGET', self.service_url)]
saml_validate_url = urllib_parse.urljoin(
self.server_url, 'samlValidate',
)
request = Request(
saml_validate_url + '?' + urllib_parse.urlencode(params),
self.get_saml_assertion(ticket),
headers,
)
return urllib_request.urlopen(request)
@classmethod
def get_saml_assertion(cls, ticket):
"""
http://www.jasig.org/cas/protocol#samlvalidate-cas-3.0
SAML request values:
RequestID [REQUIRED]:
unique identifier for the request
IssueInstant [REQUIRED]:
timestamp of the request
samlp:AssertionArtifact [REQUIRED]:
the valid CAS Service Ticket obtained as a response parameter at login.
"""
# RequestID [REQUIRED] - unique identifier for the request
request_id = uuid4()
# e.g. 2014-06-02T09:21:03.071189
timestamp = datetime.datetime.now().isoformat()
return SAML_ASSERTION_TEMPLATE.format(
request_id=request_id,
timestamp=timestamp,
ticket=ticket,
).encode('utf8')
......@@ -18,6 +18,7 @@ def setting_default(name, default_value):
setattr(settings, name, value)
setting_default('CAS_LOGIN_TEMPLATE', 'cas_server/login.html')
setting_default('CAS_FEDERATE_TEMPLATE', 'cas_server/federate.html')
setting_default('CAS_WARN_TEMPLATE', 'cas_server/warn.html')
setting_default('CAS_LOGGED_TEMPLATE', 'cas_server/logged.html')
setting_default('CAS_LOGOUT_TEMPLATE', 'cas_server/logout.html')
......@@ -70,3 +71,14 @@ setting_default('CAS_SQL_DBCHARSET', 'utf8')
setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS '
'password, users.* FROM users WHERE user = %s')
setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain
setting_default('CAS_FEDERATE', False)
# A dict of "provider name" -> (provider CAS server url, CAS version)
setting_default('CAS_FEDERATE_PROVIDERS', {})
if settings.CAS_FEDERATE:
settings.CAS_AUTH_CLASS = "cas_server.auth.CASFederateAuth"
CAS_FEDERATE_PROVIDERS_LIST = settings.CAS_FEDERATE_PROVIDERS.keys()
CAS_FEDERATE_PROVIDERS_LIST.sort()
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2015 Valentin Samir
from .default_settings import settings
from .cas import CASClient
from .models import FederatedUser
class CASFederateValidateUser(object):
username = None
attributs = {}
client = None
def __init__(self, provider, service_url):
self.provider = provider
if provider in settings.CAS_FEDERATE_PROVIDERS:
(server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider]
self.client = CASClient(
service_url=service_url,
version=version,
server_url=server_url,
extra_login_params={"provider": provider},
renew=False,
)
def get_login_url(self):
return self.client.get_login_url() if self.client is not None else False
def get_logout_url(self, redirect_url=None):
return self.client.get_logout_url(redirect_url) if self.client is not None else False
def verify_ticket(self, ticket):
"""test `password` agains the user"""
if self.client is None:
return False
username, attributs, pgtiou = self.client.verify_ticket(ticket)
if username is not None:
attributs["provider"] = self.provider
self.username = username
self.attributs = attributs
try:
user = FederatedUser.objects.get(
username=username,
provider=self.provider
)
user.attributs = attributs
user.ticket = ticket
user.save()
except FederatedUser.DoesNotExist:
user = FederatedUser.objects.create(
username=username,
provider=self.provider,
attributs=attributs,
ticket=ticket
)
user.save()
return True
else:
return False
......@@ -9,7 +9,7 @@
#
# (c) 2015 Valentin Samir
"""forms for the app"""
from .default_settings import settings
from .default_settings import settings, CAS_FEDERATE_PROVIDERS_LIST
from django import forms
from django.utils.translation import ugettext_lazy as _
......@@ -27,6 +27,17 @@ class WarnForm(forms.Form):
lt = forms.CharField(widget=forms.HiddenInput(), required=False)
class FederateSelect(forms.Form):
provider = forms.ChoiceField(
label=_('Identity provider'),
choices=[(p, p) for p in CAS_FEDERATE_PROVIDERS_LIST]
)
service = forms.CharField(label=_('service'), widget=forms.HiddenInput(), required=False)
method = forms.CharField(widget=forms.HiddenInput(), required=False)
remember = forms.BooleanField(label=_('Remember the identity provider'), required=False)
warn = forms.BooleanField(label=_('warn'), required=False)
class UserCredential(forms.Form):
"""Form used on the login page to retrive user credentials"""
username = forms.CharField(label=_('login'))
......@@ -46,6 +57,31 @@ class UserCredential(forms.Form):
cleaned_data["username"] = auth.username
else:
raise forms.ValidationError(_(u"Bad user"))
return cleaned_data
class FederateUserCredential(UserCredential):
"""Form used on the login page to retrive user credentials"""
username = forms.CharField(widget=forms.HiddenInput())
service = forms.CharField(widget=forms.HiddenInput(), required=False)
password = forms.CharField(widget=forms.HiddenInput())
ticket = forms.CharField(widget=forms.HiddenInput())
lt = forms.CharField(widget=forms.HiddenInput(), required=False)
method = forms.CharField(widget=forms.HiddenInput(), required=False)
warn = forms.BooleanField(widget=forms.HiddenInput(), required=False)
def clean(self):
cleaned_data = super(FederateUserCredential, self).clean()
try:
component = cleaned_data["username"].split('@')
username = '@'.join(component[:-1])
provider = component[-1]
user = models.FederatedUser.objects.get(username=username, provider=provider)
user.ticket = ""
user.save()
except models.FederatedUser.DoesNotExist:
raise
return cleaned_data
class TicketForm(forms.ModelForm):
......
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2016-06-16 10:18
from __future__ import unicode_literals
from django.db import migrations, models
import picklefield.fields
class Migration(migrations.Migration):
dependencies = [
('cas_server', '0004_auto_20151218_1032'),
]
operations = [
migrations.CreateModel(
name='FederatedUser',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('username', models.CharField(max_length=124)),
('provider', models.CharField(max_length=124)),
('attributs', picklefield.fields.PickledObjectField(editable=False)),
('ticket', models.CharField(max_length=255)),
('last_update', models.DateTimeField(auto_now=True)),
],
),
migrations.AlterUniqueTogether(
name='federateduser',
unique_together=set([('username', 'provider')]),
),
]
......@@ -35,6 +35,16 @@ SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
logger = logging.getLogger(__name__)
class FederatedUser(models.Model):
class Meta:
unique_together = ("username", "provider")
username = models.CharField(max_length=124)
provider = models.CharField(max_length=124)
attributs = PickledObjectField()
ticket = models.CharField(max_length=255)
last_update = models.DateTimeField(auto_now=True)
class User(models.Model):
"""A user logged into the CAS"""
class Meta:
......