Commit c7c5151a authored by Valentin Samir's avatar Valentin Samir

Tests comments and move http server handlers from cas_server.utils to cas_server.tests.utils

parent 3ada10b3
......@@ -9,8 +9,7 @@ from datetime import timedelta
from importlib import import_module
from cas_server import models
from cas_server import utils
from cas_server.tests.utils import get_auth_client
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
......@@ -125,22 +124,32 @@ class TicketTestCase(TestCase, UserModels, BaseServicePattern):
def test_clean_old_service_ticket(self):
"""test tickets clean_old_entries"""
# ge an authenticated client
client = get_auth_client()
# get the user associated to the client
user = self.get_user(client)
# generate a ticket for that client, waiting for validation
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
# generate another ticket for those validation time has expired
self.get_ticket(
user, models.ServiceTicket,
self.service, self.service_pattern, validity_expired=True
)
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
# generate a ticket with SLO having timeout reach
self.get_ticket(
user, models.ServiceTicket,
service, self.service_pattern, timeout_expired=True,
validate=True, single_log_out=True
)
# there should be 3 tickets in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
# we call the clean_old_entries method that should delete validated non SLO ticket and
# expired non validated ticket and send SLO for SLO expired ticket before deleting then
models.ServiceTicket.clean_old_entries()
params = httpd.PARAMS
# we successfully got a SLO request
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
# only 1 ticket remain in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)
This diff is collapsed.
......@@ -3,10 +3,13 @@ from cas_server.default_settings import settings
from django.test import Client
import cgi
from threading import Thread
from lxml import etree
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, parse_qsl
from cas_server import models
from cas_server import utils
def copy_form(form):
......@@ -70,7 +73,7 @@ def get_validated_ticket(service):
def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)[:2]
......@@ -100,3 +103,67 @@ def get_proxy_ticket(service):
proxy_ticket = proxy_ticket[0].text
ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
return ticket
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple http server that return 200 on GET or POST
and store GET or POST parameters. Used in unit tests
"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(200)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
url = urlparse(self.path)
params = dict(parse_qsl(url.query))
self.server.PARAMS = params
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
if ctype == 'multipart/form-data':
postvars = cgi.parse_multipart(self.rfile, pdict)
elif ctype == 'application/x-www-form-urlencoded':
length = int(self.headers.get('content-length'))
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
else:
postvars = {}
self.server.PARAMS = postvars
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
class Http404Handler(HttpParamsHandler):
"""A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"error 404 not found")
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()
......@@ -23,10 +23,8 @@ import hashlib
import crypt
import base64
import six
import cgi
from threading import Thread
from importlib import import_module
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
......@@ -151,70 +149,6 @@ def gen_saml_id():
return _gen_ticket('_')
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple http server that return 200 on GET or POST
and store GET or POST parameters. Used in unit tests
"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(200)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
url = urlparse(self.path)
params = dict(parse_qsl(url.query))
self.server.PARAMS = params
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
if ctype == 'multipart/form-data':
postvars = cgi.parse_multipart(self.rfile, pdict)
elif ctype == 'application/x-www-form-urlencoded':
length = int(self.headers.get('content-length'))
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
else:
postvars = {}
self.server.PARAMS = postvars
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
class Http404Handler(HttpParamsHandler):
"""A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"error 404 not found")
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()
class LdapHashUserPassword(object):
"""Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment