Commit fcd906ca authored by Valentin Samir's avatar Valentin Samir
Browse files

Tweak the cas client lib to always return unicode

hence, the behaviour is consistent between python2 and python3
parent 2f1b3862
......@@ -21,6 +21,7 @@
# This file is originated from https://github.com/python-cas/python-cas
# at commit ec1f2d4779625229398547b9234d0e9e874a2c9a
import six
from six.moves.urllib import parse as urllib_parse
from six.moves.urllib import request as urllib_request
from six.moves.urllib.request import Request
......@@ -32,6 +33,15 @@ class CASError(ValueError):
pass
class ReturnUnicode(object):
@staticmethod
def unicode(string, charset):
if not isinstance(string, six.text_type):
return string.decode(charset)
else:
return string
class SingleLogoutMixin(object):
@classmethod
def get_saml_slos(cls, logout_request):
......@@ -124,7 +134,7 @@ class CASClientBase(object):
raise CASError("Bad http code %s" % response.code)
class CASClientV1(CASClientBase):
class CASClientV1(CASClientBase, ReturnUnicode):
"""CAS Client Version 1"""
logout_redirect_param_name = 'url'
......@@ -140,15 +150,21 @@ class CASClientV1(CASClientBase):
page = urllib_request.urlopen(url)
try:
verified = page.readline().strip()
if verified == 'yes':
return page.readline().strip(), None, None
if verified == b'yes':
content_type = page.info().get('Content-type')
if "charset=" in content_type:
charset = content_type.split("charset=")[-1]
else:
charset = "ascii"
user = self.unicode(page.readline().strip(), charset)
return user, None, None
else:
return None, None, None
finally:
page.close()
class CASClientV2(CASClientBase):
class CASClientV2(CASClientBase, ReturnUnicode):
"""CAS Client Version 2"""
url_suffix = 'serviceValidate'
......@@ -161,8 +177,8 @@ class CASClientV2(CASClientBase):
def verify_ticket(self, ticket):
"""Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes"""
response = self.get_verification_response(ticket)
return self.verify_response(response)
(response, charset) = self.get_verification_response(ticket)
return self.verify_response(response, charset)
def get_verification_response(self, ticket):
params = [('ticket', ticket), ('service', self.service_url)]
......@@ -172,37 +188,42 @@ class CASClientV2(CASClientBase):
url = base_url + '?' + urllib_parse.urlencode(params)
page = urllib_request.urlopen(url)
try:
return page.read()
content_type = page.info().get('Content-type')
if "charset=" in content_type:
charset = content_type.split("charset=")[-1]
else:
charset = "ascii"
return (page.read(), charset)
finally:
page.close()
@classmethod
def parse_attributes_xml_element(cls, element):
def parse_attributes_xml_element(cls, element, charset):
attributes = dict()
for attribute in element:
tag = attribute.tag.split("}").pop()
tag = cls.self.unicode(attribute.tag, charset).split(u"}").pop()
if tag in attributes:
if isinstance(attributes[tag], list):
attributes[tag].append(attribute.text)
attributes[tag].append(cls.unicode(attribute.text, charset))
else:
attributes[tag] = [attributes[tag]]
attributes[tag].append(attribute.text)
attributes[tag].append(cls.unicode(attribute.text, charset))
else:
if tag == 'attraStyle':
if tag == u'attraStyle':
pass
else:
attributes[tag] = attribute.text
attributes[tag] = cls.unicode(attribute.text, charset)
return attributes
@classmethod
def verify_response(cls, response):
user, attributes, pgtiou = cls.parse_response_xml(response)
def verify_response(cls, response, charset):
user, attributes, pgtiou = cls.parse_response_xml(response, charset)
if len(attributes) == 0:
attributes = None
return user, attributes, pgtiou
@classmethod
def parse_response_xml(cls, response):
def parse_response_xml(cls, response, charset):
try:
from xml.etree import ElementTree
except ImportError:
......@@ -216,11 +237,11 @@ class CASClientV2(CASClientBase):
if tree[0].tag.endswith('authenticationSuccess'):
for element in tree[0]:
if element.tag.endswith('user'):
user = element.text
user = cls.unicode(element.text, charset)
elif element.tag.endswith('proxyGrantingTicket'):
pgtiou = element.text
pgtiou = cls.unicode(element.text, charset)
elif element.tag.endswith('attributes'):
attributes = cls.parse_attributes_xml_element(element)
attributes = cls.parse_attributes_xml_element(element, charset)
return user, attributes, pgtiou
......@@ -230,23 +251,23 @@ class CASClientV3(CASClientV2, SingleLogoutMixin):
logout_redirect_param_name = 'service'
@classmethod
def parse_attributes_xml_element(cls, element):
def parse_attributes_xml_element(cls, element, charset):
attributes = dict()
for attribute in element:
tag = attribute.tag.split("}").pop()
tag = cls.unicode(attribute.tag, charset).split(u"}").pop()
if tag in attributes:
if isinstance(attributes[tag], list):
attributes[tag].append(attribute.text)
attributes[tag].append(cls.unicode(attribute.text, charset))
else:
attributes[tag] = [attributes[tag]]
attributes[tag].append(attribute.text)
attributes[tag].append(cls.unicode(attribute.text, charset))
else:
attributes[tag] = attribute.text
attributes[tag] = cls.unicode(attribute.text, charset)
return attributes
@classmethod
def verify_response(cls, response):
return cls.parse_response_xml(response)
def verify_response(cls, response, charset):
return cls.parse_response_xml(response, charset)
SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:'
......@@ -284,6 +305,11 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
from elementtree import ElementTree
page = self.fetch_saml_validation(ticket)
content_type = page.info().get('Content-type')
if "charset=" in content_type:
charset = content_type.split("charset=")[-1]
else:
charset = "ascii"
try:
user = None
......@@ -296,21 +322,25 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
# User is validated
name_identifier = tree.find('.//' + SAML_1_0_ASSERTION_NS + 'NameIdentifier')
if name_identifier is not None:
user = name_identifier.text
user = self.unicode(name_identifier.text, charset)
attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute')
for at in attrs:
if self.username_attribute in list(at.attrib.values()):
user = at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text
attributes['uid'] = user
user = self.unicode(
at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text,
charset
)
attributes[u'uid'] = user
values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue')
key = self.unicode(at.attrib['AttributeName'], charset)
if len(values) > 1:
values_array = []
for v in values:
values_array.append(v.text)
attributes[at.attrib['AttributeName']] = values_array
values_array.append(self.unicode(v.text, charset))
attributes[key] = values_array
else:
attributes[at.attrib['AttributeName']] = values[0].text
attributes[key] = self.unicode(values[0].text, charset)
return user, attributes, None
finally:
page.close()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment