Commit bab79c4d authored by Valentin Samir's avatar Valentin Samir

More unit tests (essentially for the login view) and some docstrings

parent 7db31578
...@@ -5,3 +5,4 @@ exclude_lines = ...@@ -5,3 +5,4 @@ exclude_lines =
def __unicode__ def __unicode__
raise AssertionError raise AssertionError
raise NotImplementedError raise NotImplementedError
if six.PY3:
...@@ -49,8 +49,9 @@ coverage: test_venv ...@@ -49,8 +49,9 @@ coverage: test_venv
test_venv/bin/pip install coverage test_venv/bin/pip install coverage
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
test_venv/bin/coverage html test_venv/bin/coverage html
test_venv/bin/coverage xml rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts
coverage_codacy: coverage coverage_codacy: coverage
test_venv/bin/coverage xml
test_venv/bin/pip install codacy-coverage test_venv/bin/pip install codacy-coverage
test_venv/bin/python-codacy-coverage -r coverage.xml test_venv/bin/python-codacy-coverage -r coverage.xml
...@@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac ...@@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``. * ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``. * ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is * ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``. ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']}``.
Authentication backend Authentication backend
......
...@@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test') ...@@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test')
setting_default('CAS_TEST_PASSWORD', 'test') setting_default('CAS_TEST_PASSWORD', 'test')
setting_default( setting_default(
'CAS_TEST_ATTRIBUTES', 'CAS_TEST_ATTRIBUTES',
{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} {
'nom': 'Nymous',
'prenom': 'Ano',
'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']
}
) )
...@@ -3,36 +3,49 @@ from .default_settings import settings ...@@ -3,36 +3,49 @@ from .default_settings import settings
from django.test import TestCase from django.test import TestCase
from django.test import Client from django.test import Client
import re
import six import six
import random
from lxml import etree from lxml import etree
from six.moves import range
from cas_server import models from cas_server import models
from cas_server import utils from cas_server import utils
def get_login_page_params(): def copy_form(form):
client = Client() """Copy form value into a dict"""
response = client.get('/login')
form = response.context["form"]
params = {} params = {}
for field in form: for field in form:
if field.value(): if field.value():
params[field.name] = field.value() params[field.name] = field.value()
else: else:
params[field.name] = "" params[field.name] = ""
return params
def get_login_page_params(client=None):
"""Return a client and the POST params for the client to login"""
if client is None:
client = Client()
response = client.get('/login')
params = copy_form(response.context["form"])
return client, params return client, params
def get_auth_client(): def get_auth_client(**update):
"""return a authenticated client"""
client, params = get_login_page_params() client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD params["password"] = settings.CAS_TEST_PASSWORD
params.update(update)
client.post('/login', params) client.post('/login', params)
return client return client
def get_user_ticket_request(service): def get_user_ticket_request(service):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client = get_auth_client() client = get_auth_client()
response = client.get("/login", {"service": service}) response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1] ticket_value = response['Location'].split('ticket=')[-1]
...@@ -45,6 +58,7 @@ def get_user_ticket_request(service): ...@@ -45,6 +58,7 @@ def get_user_ticket_request(service):
def get_pgt(): def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(host, port) = utils.PGTUrlHandler.run()[1:3] (host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port) service = "http://%s:%s" % (host, port)
...@@ -110,7 +124,7 @@ class CheckPasswordCase(TestCase): ...@@ -110,7 +124,7 @@ class CheckPasswordCase(TestCase):
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
def test_hox_sha512(self): def test_hex_sha512(self):
"""test the hex_sha512 auth method""" """test the hex_sha512 auth method"""
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
...@@ -123,29 +137,83 @@ class CheckPasswordCase(TestCase): ...@@ -123,29 +137,83 @@ class CheckPasswordCase(TestCase):
class LoginTestCase(TestCase): class LoginTestCase(TestCase):
"""Tests for the login view"""
def setUp(self): def setUp(self):
"""
Prepare the test context:
* set the auth class to 'cas_server.auth.TestAuthUser'
* create a service pattern for https://www.example.com/**
* Set the service pattern to return all user attributes
"""
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
# For general purpose testing
self.service_pattern = models.ServicePattern.objects.create( self.service_pattern = models.ServicePattern.objects.create(
name="example", name="example",
pattern="^https://www\.example\.com(/.*)?$", pattern="^https://www\.example\.com(/.*)?$",
) )
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_login_view_post_goodpass_goodlt(self): # For testing the restrict_users attributes
client, params = get_login_page_params() self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
params["username"] = settings.CAS_TEST_USER name="restrict_user_fail",
params["password"] = settings.CAS_TEST_PASSWORD pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
restrict_users=True,
)
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
name="restrict_user_success",
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
restrict_users=True,
)
models.Username.objects.create(
value=settings.CAS_TEST_USER,
service_pattern=self.service_pattern_restrict_user_success
)
response = client.post('/login', params) # For testing the user attributes filtering conditions
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
name="filter_fail",
pattern="^https://filter_fail\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="right",
pattern="^admin$",
service_pattern=self.service_pattern_filter_fail
)
self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success",
pattern="^https://filter_success\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="email",
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
service_pattern=self.service_pattern_filter_success
)
self.assertEqual(response.status_code, 200) # For testing the user_field attributes
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
name="field_needed_fail",
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
user_field="uid"
)
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success",
pattern="^https://field_needed_success\.example\.com(/.*)?$",
user_field="nom"
)
def assert_logged(self, client, response, warn=False, code=200):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
self.assertTrue( self.assertTrue(
( (
b"You have successfully logged into " b"You have successfully logged into "
b"the Central Authentication Service" b"the Central Authentication Service"
) in response.content ) in response.content
) )
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client.session["warn"] is warn)
self.assertTrue(client.session["authenticated"] is True)
self.assertTrue( self.assertTrue(
models.User.objects.get( models.User.objects.get(
...@@ -154,7 +222,59 @@ class LoginTestCase(TestCase): ...@@ -154,7 +222,59 @@ class LoginTestCase(TestCase):
) )
) )
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
def test_login_view_post_goodpass_goodlt(self):
"""Test a successul login"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
self.assertTrue(params['lt'] in client.session['lt'])
response = client.post('/login', params)
self.assert_logged(client, response)
# LoginTicket conssumed
self.assertTrue(params['lt'] not in client.session['lt'])
def test_login_view_post_goodpass_goodlt_warn(self):
"""Test a successul login requesting to be warned before creating services tickets"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["warn"] = "on"
response = client.post('/login', params)
self.assert_logged(client, response, warn=True)
def test_lt_max(self):
"""Check we only keep the last 100 Login Ticket for a user"""
client, params = get_login_page_params()
current_lt = params["lt"]
i_in_test = random.randint(0, 100)
i_not_in_test = random.randint(100, 150)
for i in range(150):
if i == i_in_test:
self.assertTrue(current_lt in client.session['lt'])
if i == i_not_in_test:
self.assertTrue(current_lt not in client.session['lt'])
self.assertTrue(len(client.session['lt']) <= 100)
client, params = get_login_page_params(client)
self.assertTrue(len(client.session['lt']) <= 100)
def test_login_view_post_badlt(self): def test_login_view_post_badlt(self):
"""Login attempt with a bad LoginTicket"""
client, params = get_login_page_params() client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD params["password"] = settings.CAS_TEST_PASSWORD
...@@ -162,38 +282,39 @@ class LoginTestCase(TestCase): ...@@ -162,38 +282,39 @@ class LoginTestCase(TestCase):
response = client.post('/login', params) response = client.post('/login', params)
self.assertEqual(response.status_code, 200) self.assert_login_failed(client, response)
self.assertTrue(b"Invalid login ticket" in response.content) self.assertTrue(b"Invalid login ticket" in response.content)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_login_view_post_badpass_good_lt(self): def test_login_view_post_badpass_good_lt(self):
"""Login attempt with a bad password"""
client, params = get_login_page_params() client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER params["username"] = settings.CAS_TEST_USER
params["password"] = "test2" params["password"] = "test2"
response = client.post('/login', params) response = client.post('/login', params)
self.assertEqual(response.status_code, 200) self.assert_login_failed(client, response)
self.assertTrue( self.assertTrue(
( (
b"The credentials you provided cannot be " b"The credentials you provided cannot be "
b"determined to be authentic" b"determined to be authentic"
) in response.content ) in response.content
) )
self.assertFalse(
( def assert_ticket_attributes(self, client, ticket_value):
b"You have successfully logged into " """check the ticket attributes in the db"""
b"the Central Authentication Service" user = models.User.objects.get(
) in response.content username=settings.CAS_TEST_USER,
session_key=client.session.session_key
) )
self.assertTrue(user)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
self.assertEqual(ticket.user, user)
self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
self.assertEqual(ticket.validate, False)
self.assertEqual(ticket.service_pattern, self.service_pattern)
def test_view_login_get_auth_allowed_service(self): def assert_service_ticket(self, client, response):
client = get_auth_client() """check that a ticket is well emited when requested on a allowed service"""
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header('Location')) self.assertTrue(response.has_header('Location'))
self.assertTrue( self.assertTrue(
...@@ -203,23 +324,125 @@ class LoginTestCase(TestCase): ...@@ -203,23 +324,125 @@ class LoginTestCase(TestCase):
) )
ticket_value = response['Location'].split('ticket=')[-1] ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get( self.assert_ticket_attributes(client, ticket_value)
username=settings.CAS_TEST_USER,
session_key=client.session.session_key def test_view_login_get_allowed_service(self):
"""Request a ticket for an allowed service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
"Authentication required by service "
"example (https://www.example.com)"
) in response.content
) )
self.assertTrue(user)
ticket = models.ServiceTicket.objects.get(value=ticket_value) def test_view_login_get_denied_service(self):
self.assertEqual(ticket.user, user) """Request a ticket for an denied service by an unauthenticated client"""
self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES) client = Client()
self.assertEqual(ticket.validate, False) response = client.get("/login?service=https://www.example.net")
self.assertEqual(ticket.service_pattern, self.service_pattern) self.assertEqual(response.status_code, 200)
self.assertTrue("Service https://www.example.net non allowed" in response.content)
def test_view_login_get_auth_allowed_service(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client()
response = client.get("/login?service=https://www.example.com")
self.assert_service_ticket(client, response)
def test_view_login_get_auth_allowed_service_warn(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client(warn="on")
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
"Authentication has been required by service "
"example (https://www.example.com)"
) in response.content
)
params = copy_form(response.context["form"])
response = client.post("/login", params)
self.assert_service_ticket(client, response)
def test_view_login_get_auth_denied_service(self): def test_view_login_get_auth_denied_service(self):
"""Request a ticket for a not allowed service by an authenticated client"""
client = get_auth_client() client = get_auth_client()
response = client.get("/login?service=https://www.example.org") response = client.get("/login?service=https://www.example.org")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.org non allowed" in response.content) self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
def test_user_logged_not_in_db(self):
"""If the user is logged but has been delete from the database, it should be logged out"""
client = get_auth_client()
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).delete()
response = client.get("/login")
self.assert_login_failed(client, response, code=302)
self.assertEqual(response["Location"], "/login?")
def test_service_restrict_user(self):
"""Testing the restric user capability fro a service"""
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue("Username non allowed" in response.content)
service = "https://restrict_user_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_filter(self):
"""Test the filtering on user attributes"""
service = "https://filter_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue("User charateristics non allowed" in response.content)
service = "https://filter_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_user_field(self):
"""Test using a user attribute as username: case on if the attribute exists or not"""
service = "https://field_needed_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue("The attribut uid is needed to use that service" in response.content)
service = "https://field_needed_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_gateway(self):
"""test gateway parameter"""
# First with an authenticated client that fail to get a ticket for a service
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
# second for an user not yet authenticated on a valid service
client = Client()
response = client.get('/login', {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
class LogoutTestCase(TestCase): class LogoutTestCase(TestCase):
...@@ -454,17 +677,24 @@ class ValidateServiceTestCase(TestCase): ...@@ -454,17 +677,24 @@ class ValidateServiceTestCase(TestCase):
namespaces={'cas': "http://www.yale.edu/tp/cas"} namespaces={'cas': "http://www.yale.edu/tp/cas"}
) )
self.assertEqual(len(attributes), 1) self.assertEqual(len(attributes), 1)
attrs1 = {} attrs1 = set()
for attr in attributes[0]: for attr in attributes[0]:
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1)) self.assertEqual(len(attributes), len(attrs1))
attrs2 = {} attrs2 = set()
for attr in attributes: for attr in attributes:
attrs2[attr.attrib['name']] = attr.attrib['value'] attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for v in value:
original.add((key, v))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2) self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) self.assertEqual(attrs1, original)
def test_validate_service_view_badservice(self): def test_validate_service_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1] ticket = get_user_ticket_request(self.service)[1]
...@@ -623,17 +853,24 @@ class ProxyTestCase(TestCase): ...@@ -623,17 +853,24 @@ class ProxyTestCase(TestCase):
namespaces={'cas': "http://www.yale.edu/tp/cas"} namespaces={'cas': "http://www.yale.edu/tp/cas"}
) )
self.assertEqual(len(attributes), 1) self.assertEqual(len(attributes), 1)
attrs1 = {} attrs1 = set()
for attr in attributes[0]: for attr in attributes[0]:
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1)) self.assertEqual(len(attributes), len(attrs1))
attrs2 = {} attrs2 = set()
for attr in attributes: for attr in attributes:
attrs2[attr.attrib['name']] = attr.attrib['value'] attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for v in value:
original.add((key, v))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2) self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) self.assertEqual(attrs1, original)
def test_validate_proxy_bad(self): def test_validate_proxy_bad(self):
params = get_pgt() params = get_pgt()
......
...@@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin): ...@@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin):
service = None service = None
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.url = request.GET.get('url') self.url = request.GET.get('url')
...@@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin): ...@@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin):
USER_NOT_AUTHENTICATED = 6 USER_NOT_AUTHENTICATED = 6
def init_post(self, request): def init_post(self, request):
"""Initialize POST received parameters"""
self.request = request self.request = request
self.service = request.POST.get('service') self.service = request.POST.get('service')
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
...@@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin): ...@@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin):
if request.POST.get('warned') and request.POST['warned'] != "False": if request.POST.get('warned') and request.POST['warned'] != "False":
self.warned = True self.warned = True
def gen_lt(self):
"""Generate a new LoginTicket and add it to the list of valid LT for the user"""
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:]
def check_lt(self): def check_lt(self):
"""Check is the POSTed LoginTicket is valid, if yes invalide it"""
# save LT for later check # save LT for later check
lt_valid = self.request.session.get('lt', []) lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt') lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed) # generate a new LT (by posting the LT has been consumed)
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] self.gen_lt()
if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:]
# check if send LT is valid # check if send LT is valid
if lt_valid is None or lt_send not in lt_valid: if lt_valid is None or lt_send not in lt_valid:
return False return False
...@@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin): ...@@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin):
username=self.request.session['username'], username=self.request.session['username'],
session_key=self.request.session.session_key session_key=self.request.session.session_key
) )
self.user.save() self.user.save() # pragma: no cover (should not happend)
except models.User.DoesNotExist: except models.User.DoesNotExist:
self.user = models.User.objects.create( self.user = models.User.objects.create(
username=self.request.session['username'], username=self.request.session['username'],
...@@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin): ...@@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin):
elif ret == self.USER_ALREADY_LOGGED: elif ret == self.USER_ALREADY_LOGGED:
pass pass
else: else:
raise EnvironmentError("invalid output for LoginView.process_post") raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover
return self.common() return self.common()
def process_post(self): def process_post(self):
"""
Analyse the POST request:
* check that the LoginTicket is valid
* check that the user sumited credentials are valid
"""
if not self.check_lt(): if not self.check_lt():
values = self.request.POST.copy() values = self.request.POST.copy()
# if not set a new LT and fail # if not set a new LT and fail
...@@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin): ...@@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin):
return self.USER_ALREADY_LOGGED return self.USER_ALREADY_LOGGED
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request