Commit c0d85501 authored by Valentin Samir's avatar Valentin Samir

Add some tests using tox

parent 39557d19
*.pyc *.pyc
*.egg-info
bootstrap3 bootstrap3
cas/ cas/
db.sqlite3 db.sqlite3
manage.py manage.py
.tox
language: python
python:
- "2.7"
env:
global:
- PIP_DOWNLOAD_CACHE=$HOME/.pip_cache
matrix:
- TOX_ENV=py27-django17
- TOX_ENV=py27-django18
- TOX_ENV=flake8
cache:
directories:
- $HOME/.pip-cache/
install:
- "travis_retry pip install setuptools --upgrade"
- "pip install tox"
script:
- tox -e $TOX_ENV
after_script:
- cat .tox/$TOX_ENV/log/*.log
...@@ -27,26 +27,14 @@ class UserCredential(forms.Form): ...@@ -27,26 +27,14 @@ class UserCredential(forms.Form):
method = forms.CharField(widget=forms.HiddenInput(), required=False) method = forms.CharField(widget=forms.HiddenInput(), required=False)
warn = forms.BooleanField(label=_('warn'), required=False) warn = forms.BooleanField(label=_('warn'), required=False)
def __init__(self, request, *args, **kwargs): def __init__(self, *args, **kwargs):
self.request = request
super(UserCredential, self).__init__(*args, **kwargs) super(UserCredential, self).__init__(*args, **kwargs)
def clean(self): def clean(self):
cleaned_data = super(UserCredential, self).clean() cleaned_data = super(UserCredential, self).clean()
auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username")) auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username"))
if auth.test_password(cleaned_data.get("password")): if auth.test_password(cleaned_data.get("password")):
try: cleaned_data["username"] = auth.username
user = models.User.objects.get(
username=auth.username,
session_key=self.request.session.session_key
)
user.save()
except models.User.DoesNotExist:
user = models.User.objects.create(
username=auth.username,
session_key=self.request.session.session_key
)
user.save()
else: else:
raise forms.ValidationError(_(u"Bad user")) raise forms.ValidationError(_(u"Bad user"))
......
...@@ -89,11 +89,14 @@ class LogoutView(View, LogoutMixin): ...@@ -89,11 +89,14 @@ class LogoutView(View, LogoutMixin):
request = None request = None
service = None service = None
def get(self, request, *args, **kwargs): def init_get(self, request):
"""methode called on GET request on this view"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.url = request.GET.get('url') self.url = request.GET.get('url')
def get(self, request, *args, **kwargs):
"""methode called on GET request on this view"""
self.init_get(request)
self.logout() self.logout()
# if service is set, redirect to service after logout # if service is set, redirect to service after logout
if self.service: if self.service:
...@@ -105,6 +108,7 @@ class LogoutView(View, LogoutMixin): ...@@ -105,6 +108,7 @@ class LogoutView(View, LogoutMixin):
# else redirect to login page # else redirect to login page
else: else:
if settings.CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT: if settings.CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT:
messages.add_message(request, messages.SUCCESS, _(u'Successfully logout')) messages.add_message(request, messages.SUCCESS, _(u'Successfully logout'))
return redirect("cas_server:login") return redirect("cas_server:login")
else: else:
...@@ -129,67 +133,110 @@ class LoginView(View, LogoutMixin): ...@@ -129,67 +133,110 @@ class LoginView(View, LogoutMixin):
renewed = False renewed = False
warned = False warned = False
def post(self, request, *args, **kwargs): INVALID_LOGIN_TICKET = 1
"""methode called on POST request on this view""" USER_LOGIN_OK = 2
USER_LOGIN_FAILURE = 3
USER_ALREADY_LOGGED = 4
USER_AUTHENTICATED = 5
USER_NOT_AUTHENTICATED = 6
def init_post(self, request):
self.request = request self.request = request
self.service = request.POST.get('service') self.service = request.POST.get('service')
self.renew = True if request.POST.get('renew') else False self.renew = True if request.POST.get('renew') else False
self.gateway = request.POST.get('gateway') self.gateway = request.POST.get('gateway')
self.method = request.POST.get('method') self.method = request.POST.get('method')
def check_lt(self):
# save LT for later check # save LT for later check
lt_valid = request.session.get('lt') lt_valid = self.request.session.get('lt')
lt_send = request.POST.get('lt') lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed) # generate a new LT (by posting the LT has been consumed)
request.session['lt'] = utils.gen_lt() self.request.session['lt'] = utils.gen_lt()
# check if send LT is valid # check if send LT is valid
if lt_valid is None or lt_valid != lt_send: if lt_valid is None or lt_valid != lt_send:
return False
else:
return True
def post(self, request, *args, **kwargs):
"""methode called on POST request on this view"""
self.init_post(request)
ret = self.process_post()
if ret == self.INVALID_LOGIN_TICKET:
messages.add_message( messages.add_message(
self.request, self.request,
messages.ERROR, messages.ERROR,
_(u"Invalid login ticket") _(u"Invalid login ticket")
) )
values = request.POST.copy() elif ret == self.USER_LOGIN_OK:
# if not set a new LT and fail try:
values['lt'] = request.session['lt']
self.init_form(values)
elif not request.session.get("authenticated") or self.renew:
self.init_form(request.POST)
if self.form.is_valid():
self.user = models.User.objects.get( self.user = models.User.objects.get(
username=self.form.cleaned_data['username'], username=self.request.session['username'],
session_key=self.request.session.session_key session_key=self.request.session.session_key
) )
request.session.set_expiry(0) self.user.save()
request.session["username"] = self.form.cleaned_data['username'] except models.User.DoesNotExist:
request.session["warn"] = True if self.form.cleaned_data.get("warn") else False self.user = models.User.objects.create(
request.session["authenticated"] = True username=self.request.session['username'],
session_key=self.request.session.session_key
)
self.user.save()
elif ret == self.USER_LOGIN_FAILURE: # bad user login
self.logout()
elif ret == self.USER_ALREADY_LOGGED:
pass
else:
raise EnvironmentError("invalid output for LoginView.process_post")
return self.common()
def process_post(self, pytest=False):
if not self.check_lt():
values = self.request.POST.copy()
# if not set a new LT and fail
values['lt'] = self.request.session['lt']
self.init_form(values)
return self.INVALID_LOGIN_TICKET
elif not self.request.session.get("authenticated") or self.renew:
self.init_form(self.request.POST)
if self.form.is_valid():
self.request.session.set_expiry(0)
self.request.session["username"] = self.form.cleaned_data['username']
self.request.session["warn"] = True if self.form.cleaned_data.get("warn") else False
self.request.session["authenticated"] = True
self.renewed = True self.renewed = True
self.warned = True self.warned = True
return self.USER_LOGIN_OK
else: else:
self.logout() return self.USER_LOGIN_FAILURE
return self.common() else:
return self.USER_ALREADY_LOGGED
def get(self, request, *args, **kwargs): def init_get(self, request):
"""methode called on GET request on this view"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.renew = True if request.GET.get('renew') else False self.renew = True if request.GET.get('renew') else False
self.gateway = request.GET.get('gateway') self.gateway = request.GET.get('gateway')
self.method = request.GET.get('method') self.method = request.GET.get('method')
def get(self, request, *args, **kwargs):
"""methode called on GET request on this view"""
self.init_get(request)
self.process_get()
return self.common()
def process_get(self):
# generate a new LT if none is present # generate a new LT if none is present
request.session['lt'] = request.session.get('lt', utils.gen_lt()) self.request.session['lt'] = self.request.session.get('lt', utils.gen_lt())
if not request.session.get("authenticated") or self.renew: if not self.request.session.get("authenticated") or self.renew:
self.init_form() self.init_form()
return self.common() return self.USER_NOT_AUTHENTICATED
return self.USER_AUTHENTICATED
def init_form(self, values=None): def init_form(self, values=None):
self.form = forms.UserCredential( self.form = forms.UserCredential(
self.request,
values, values,
initial={ initial={
'service': self.service, 'service': self.service,
...@@ -345,7 +392,6 @@ class Auth(View): ...@@ -345,7 +392,6 @@ class Auth(View):
if not username or not password or not service: if not username or not password or not service:
return HttpResponse("no\n", content_type="text/plain") return HttpResponse("no\n", content_type="text/plain")
form = forms.UserCredential( form = forms.UserCredential(
request,
request.POST, request.POST,
initial={ initial={
'service': service, 'service': service,
...@@ -355,10 +401,17 @@ class Auth(View): ...@@ -355,10 +401,17 @@ class Auth(View):
) )
if form.is_valid(): if form.is_valid():
try: try:
user = models.User.objects.get( try:
username=form.cleaned_data['username'], user = models.User.objects.get(
session_key=request.session.session_key username=form.cleaned_data['username'],
) session_key=request.session.session_key
)
except models.User.DoesNotExist:
user = models.User.objects.create(
username=form.cleaned_data['username'],
session_key=request.session.session_key
)
user.save()
# is the service allowed # is the service allowed
service_pattern = ServicePattern.validate(service) service_pattern = ServicePattern.validate(service)
# is the current user allowed on this service # is the current user allowed on this service
......
tox==1.8.1
pytest==2.6.4
pytest-django==2.7.0
pytest-pythonpath==0.3
requests>=2.4
django-picklefield>=0.3.1
requests_futures>=0.9.5
django-bootstrap3>=5.4
lxml>=3.4
from cas_server import models
class DummyUserManager(object):
def __init__(self, username, session_key):
self.username = username
self.session_key = session_key
def get(self, username=None, session_key=None):
if username == self.username and session_key == self.session_key:
return models.User(username=username, session_key=session_key)
else:
raise models.User.DoesNotExist()
class DummyTicketManager(object):
def __init__(self, ticket_class, service, ticket):
self.ticket_class = ticket_class
self.service = service
self.ticket = ticket
def create(self, **kwargs):
for field in models.ServiceTicket._meta.fields:
field.allow_unsaved_instance_assignment = True
return self.ticket_class(**kwargs)
def filter(self, *args, **kwargs):
return DummyQuerySet()
def get(self, **kwargs):
if 'value' in kwargs:
if kwargs['value'] != self.ticket:
raise self.ticket_class.DoesNotExist()
else:
kwargs['value'] = self.ticket
if 'service' in kwargs:
if kwargs['service'] != self.service:
raise self.ticket_class.DoesNotExist()
else:
kwargs['service'] = self.service
if not 'user' in kwargs:
kwargs['user'] = models.User(username="test")
for field in models.ServiceTicket._meta.fields:
field.allow_unsaved_instance_assignment = True
for key in kwargs.keys():
if '__' in key:
del kwargs[key]
kwargs['attributs'] = {'mail': 'test@example.com'}
kwargs['service_pattern'] = models.ServicePattern()
return self.ticket_class(**kwargs)
class DummySession(dict):
session_key = "test_session"
def set_expiry(self, int):
pass
class DummyQuerySet(set):
pass
import django
from django.conf import settings
from django.contrib import messages
settings.configure()
settings.STATIC_URL = "/static/"
settings.DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': '/dev/null',
}
}
settings.INSTALLED_APPS = (
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'bootstrap3',
'cas_server',
)
settings.ROOT_URLCONF = "/"
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
try:
django.setup()
except AttributeError:
pass
messages.add_message = lambda x,y,z:None
from __future__ import absolute_import
from .init import *
from django.test import RequestFactory
import os
import pytest
from lxml import etree
from cas_server.views import ValidateService
from cas_server import models
from .dummy import *
@pytest.mark.django_db
def test_validate_service_view_ok():
factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServiceTicket.save = lambda x:None
validate = ValidateService()
validate.allow_proxy_ticket = False
response = validate.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(users) == 1
assert users[0].text == "test"
attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(attributes) == 1
attrs = {}
for attr in attributes[0]:
attrs[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text
assert 'mail' in attrs
assert attrs['mail'] == 'test@example.com'
@pytest.mark.django_db
def test_validate_service_view_badservice():
factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example2.com', "ST-random")
models.ServiceTicket.save = lambda x:None
validate = ValidateService()
validate.allow_proxy_ticket = False
response = validate.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(error) == 1
assert error[0].attrib['code'] == 'INVALID_SERVICE'
@pytest.mark.django_db
def test_validate_service_view_badticket():
factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random2")
models.ServiceTicket.save = lambda x:None
validate = ValidateService()
validate.allow_proxy_ticket = False
response = validate.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(error) == 1
assert error[0].attrib['code'] == 'INVALID_TICKET'
from __future__ import absolute_import
from .init import *
from django.test import RequestFactory
import os
import pytest
from cas_server.views import Auth
from cas_server import models
from .dummy import *
settings.CAS_AUTH_SHARED_SECRET = "test"
@pytest.mark.django_db
def test_auth_view_goodpass():
factory = RequestFactory()
request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
request.session = DummySession()
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
auth = Auth()
response = auth.post(request)
assert response.status_code == 200
assert response.content == "yes\n"
def test_auth_view_badpass():
factory = RequestFactory()
request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
request.session = DummySession()
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
auth = Auth()
response = auth.post(request)
assert response.status_code == 200
assert response.content == "no\n"
from __future__ import absolute_import
from .init import *
from django.test import RequestFactory
import os
import pytest
from cas_server.views import LoginView
from cas_server import models
from .dummy import *
def test_login_view_post_goodpass_goodlt():
factory = RequestFactory()
request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random'})
request.session = DummySession()
request.session['lt'] = 'LT-random'
request.session["username"] = os.urandom(20)
request.session["warn"] = os.urandom(20)
login = LoginView()
login.init_post(request)
ret = login.process_post(pytest=True)
assert ret == LoginView.USER_LOGIN_OK
assert request.session.get("authenticated") == True
assert request.session.get("username") == "test"
assert request.session.get("warn") == False
def test_login_view_post_badlt():
factory = RequestFactory()
request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random1'})
request.session = DummySession()
request.session['lt'] = 'LT-random2'
authenticated = os.urandom(20)
username = os.urandom(20)
warn = os.urandom(20)
request.session["authenticated"] = authenticated
request.session["username"] = username
request.session["warn"] = warn
login = LoginView()
login.init_post(request)
ret = login.process_post(pytest=True)
assert ret == LoginView.INVALID_LOGIN_TICKET
assert request.session.get("authenticated") == authenticated
assert request.session.get("username") == username
assert request.session.get("warn") == warn
def test_login_view_post_badpass_good_lt():
factory = RequestFactory()
request = factory.post('/login', {'username':'test', 'password':'badpassword', 'lt':'LT-random'})
request.session = DummySession()
request.session['lt'] = 'LT-random'
login = LoginView()
login.init_post(request)
ret = login.process_post()
assert ret == LoginView.USER_LOGIN_FAILURE
assert not request.session.get("authenticated")
assert not request.session.get("username")
assert not request.session.get("warn")
def test_view_login_get_unauth():
factory = RequestFactory()
request = factory.post('/login')
request.session = DummySession()
login = LoginView()
login.init_get(request)
ret = login.process_get()
assert ret == LoginView.USER_NOT_AUTHENTICATED
login = LoginView()
response = login.get(request)
assert response.status_code == 200
@pytest.mark.django_db
def test_view_login_get_auth():
factory = RequestFactory()
request = factory.post('/login')
request.session = DummySession()
request.session["authenticated"] = True
request.session["username"] = "test"
request.session["warn"] = False
login = LoginView()
login.init_get(request)
ret = login.process_get()
assert ret == LoginView.USER_AUTHENTICATED
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
login = LoginView()
response = login.get(request)
assert response.status_code == 200
@pytest.mark.django_db
def test_view_login_get_auth_service():
factory = RequestFactory()
request = factory.post('/login?service=https://www.example.com')
request.session = DummySession()
request.session["authenticated"] = True
request.session["username"] = "test"
request.session["warn"] = False
login = LoginView()
login.init_get(request)
ret = login.process_get()
assert ret == LoginView.USER_AUTHENTICATED
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.User.save = lambda x:None
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
models.ServiceTicket.save = lambda x:None
login = LoginView()
response = login.get(request)
assert response.status_code == 302
assert response['Location'].startswith('https://www.example.com?ticket=ST-')
@pytest.mark.django_db
def test_view_login_get_auth_service_warn():
factory = RequestFactory()
request = factory.post('/login?service=https://www.example.com')
request.session = DummySession()
request.session["authenticated"] = True
request.session["username"] = "test"
request.session["warn"] = True