Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
django-cas-server
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Valentin Samir
django-cas-server
Commits
bab79c4d
Commit
bab79c4d
authored
Jun 27, 2016
by
Valentin Samir
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
More unit tests (essentially for the login view) and some docstrings
parent
7db31578
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
341 additions
and
61 deletions
+341
-61
.coveragerc
.coveragerc
+1
-0
Makefile
Makefile
+2
-1
README.rst
README.rst
+2
-1
cas_server/default_settings.py
cas_server/default_settings.py
+6
-1
cas_server/tests.py
cas_server/tests.py
+285
-48
cas_server/views.py
cas_server/views.py
+22
-9
settings_tests.py
settings_tests.py
+1
-1
urls_tests.py
urls_tests.py
+22
-0
No files found.
.coveragerc
View file @
bab79c4d
...
...
@@ -5,3 +5,4 @@ exclude_lines =
def __unicode__
raise AssertionError
raise NotImplementedError
if six.PY3:
Makefile
View file @
bab79c4d
...
...
@@ -49,8 +49,9 @@ coverage: test_venv
test_venv/bin/pip
install
coverage
test_venv/bin/coverage run
--source
=
'cas_server'
--omit
=
'cas_server/migrations*'
run_tests
test_venv/bin/coverage html
test_venv/bin/coverage xml
rm
htmlcov/coverage_html.js
# I am really pissed off by those keybord shortcuts
coverage_codacy
:
coverage
test_venv/bin/coverage xml
test_venv/bin/pip
install
codacy-coverage
test_venv/bin/python-codacy-coverage
-r
coverage.xml
README.rst
View file @
bab79c4d
...
...
@@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``.
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']}``.
Authentication backend
...
...
cas_server/default_settings.py
View file @
bab79c4d
...
...
@@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test')
setting_default
(
'CAS_TEST_PASSWORD'
,
'test'
)
setting_default
(
'CAS_TEST_ATTRIBUTES'
,
{
'nom'
:
'Nymous'
,
'prenom'
:
'Ano'
,
'email'
:
'anonymous@example.net'
}
{
'nom'
:
'Nymous'
,
'prenom'
:
'Ano'
,
'email'
:
'anonymous@example.net'
,
'alias'
:
[
'demo1'
,
'demo2'
]
}
)
cas_server/tests.py
View file @
bab79c4d
...
...
@@ -3,36 +3,49 @@ from .default_settings import settings
from
django.test
import
TestCase
from
django.test
import
Client
import
re
import
six
import
random
from
lxml
import
etree
from
six.moves
import
range
from
cas_server
import
models
from
cas_server
import
utils
def
get_login_page_params
():
client
=
Client
()
response
=
client
.
get
(
'/login'
)
form
=
response
.
context
[
"form"
]
def
copy_form
(
form
):
"""Copy form value into a dict"""
params
=
{}
for
field
in
form
:
if
field
.
value
():
params
[
field
.
name
]
=
field
.
value
()
else
:
params
[
field
.
name
]
=
""
return
params
def
get_login_page_params
(
client
=
None
):
"""Return a client and the POST params for the client to login"""
if
client
is
None
:
client
=
Client
()
response
=
client
.
get
(
'/login'
)
params
=
copy_form
(
response
.
context
[
"form"
])
return
client
,
params
def
get_auth_client
():
def
get_auth_client
(
**
update
):
"""return a authenticated client"""
client
,
params
=
get_login_page_params
()
params
[
"username"
]
=
settings
.
CAS_TEST_USER
params
[
"password"
]
=
settings
.
CAS_TEST_PASSWORD
params
.
update
(
update
)
client
.
post
(
'/login'
,
params
)
return
client
def
get_user_ticket_request
(
service
):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client
=
get_auth_client
()
response
=
client
.
get
(
"/login"
,
{
"service"
:
service
})
ticket_value
=
response
[
'Location'
].
split
(
'ticket='
)[
-
1
]
...
...
@@ -45,6 +58,7 @@ def get_user_ticket_request(service):
def
get_pgt
():
"""return a dict contening a service, user and PGT ticket for this service"""
(
host
,
port
)
=
utils
.
PGTUrlHandler
.
run
()[
1
:
3
]
service
=
"http://%s:%s"
%
(
host
,
port
)
...
...
@@ -110,7 +124,7 @@ class CheckPasswordCase(TestCase):
self
.
assertTrue
(
utils
.
check_password
(
"hex_md5"
,
self
.
password1
,
hashed_password1
,
"utf8"
))
self
.
assertFalse
(
utils
.
check_password
(
"hex_md5"
,
self
.
password2
,
hashed_password1
,
"utf8"
))
def
test_h
o
x_sha512
(
self
):
def
test_h
e
x_sha512
(
self
):
"""test the hex_sha512 auth method"""
hashed_password1
=
utils
.
hashlib
.
sha512
(
self
.
password1
).
hexdigest
()
...
...
@@ -123,29 +137,83 @@ class CheckPasswordCase(TestCase):
class
LoginTestCase
(
TestCase
):
"""Tests for the login view"""
def
setUp
(
self
):
"""
Prepare the test context:
* set the auth class to 'cas_server.auth.TestAuthUser'
* create a service pattern for https://www.example.com/**
* Set the service pattern to return all user attributes
"""
settings
.
CAS_AUTH_CLASS
=
'cas_server.auth.TestAuthUser'
# For general purpose testing
self
.
service_pattern
=
models
.
ServicePattern
.
objects
.
create
(
name
=
"example"
,
pattern
=
"^https://www\.example\.com(/.*)?$"
,
)
models
.
ReplaceAttributName
.
objects
.
create
(
name
=
"*"
,
service_pattern
=
self
.
service_pattern
)
def
test_login_view_post_goodpass_goodlt
(
self
):
client
,
params
=
get_login_page_params
()
params
[
"username"
]
=
settings
.
CAS_TEST_USER
params
[
"password"
]
=
settings
.
CAS_TEST_PASSWORD
# For testing the restrict_users attributes
self
.
service_pattern_restrict_user_fail
=
models
.
ServicePattern
.
objects
.
create
(
name
=
"restrict_user_fail"
,
pattern
=
"^https://restrict_user_fail\.example\.com(/.*)?$"
,
restrict_users
=
True
,
)
self
.
service_pattern_restrict_user_success
=
models
.
ServicePattern
.
objects
.
create
(
name
=
"restrict_user_success"
,
pattern
=
"^https://restrict_user_success\.example\.com(/.*)?$"
,
restrict_users
=
True
,
)
models
.
Username
.
objects
.
create
(
value
=
settings
.
CAS_TEST_USER
,
service_pattern
=
self
.
service_pattern_restrict_user_success
)
response
=
client
.
post
(
'/login'
,
params
)
# For testing the user attributes filtering conditions
self
.
service_pattern_filter_fail
=
models
.
ServicePattern
.
objects
.
create
(
name
=
"filter_fail"
,
pattern
=
"^https://filter_fail\.example\.com(/.*)?$"
,
)
models
.
FilterAttributValue
.
objects
.
create
(
attribut
=
"right"
,
pattern
=
"^admin$"
,
service_pattern
=
self
.
service_pattern_filter_fail
)
self
.
service_pattern_filter_success
=
models
.
ServicePattern
.
objects
.
create
(
name
=
"filter_success"
,
pattern
=
"^https://filter_success\.example\.com(/.*)?$"
,
)
models
.
FilterAttributValue
.
objects
.
create
(
attribut
=
"email"
,
pattern
=
"^%s$"
%
re
.
escape
(
settings
.
CAS_TEST_ATTRIBUTES
[
'email'
]),
service_pattern
=
self
.
service_pattern_filter_success
)
self
.
assertEqual
(
response
.
status_code
,
200
)
# For testing the user_field attributes
self
.
service_pattern_field_needed_fail
=
models
.
ServicePattern
.
objects
.
create
(
name
=
"field_needed_fail"
,
pattern
=
"^https://field_needed_fail\.example\.com(/.*)?$"
,
user_field
=
"uid"
)
self
.
service_pattern_field_needed_success
=
models
.
ServicePattern
.
objects
.
create
(
name
=
"field_needed_success"
,
pattern
=
"^https://field_needed_success\.example\.com(/.*)?$"
,
user_field
=
"nom"
)
def
assert_logged
(
self
,
client
,
response
,
warn
=
False
,
code
=
200
):
"""Assertions testing that client is well authenticated"""
self
.
assertEqual
(
response
.
status_code
,
code
)
self
.
assertTrue
(
(
b
"You have successfully logged into "
b
"the Central Authentication Service"
)
in
response
.
content
)
self
.
assertTrue
(
client
.
session
[
"username"
]
==
settings
.
CAS_TEST_USER
)
self
.
assertTrue
(
client
.
session
[
"warn"
]
is
warn
)
self
.
assertTrue
(
client
.
session
[
"authenticated"
]
is
True
)
self
.
assertTrue
(
models
.
User
.
objects
.
get
(
...
...
@@ -154,7 +222,59 @@ class LoginTestCase(TestCase):
)
)
def
assert_login_failed
(
self
,
client
,
response
,
code
=
200
):
"""Assertions testing a failed login attempt"""
self
.
assertEqual
(
response
.
status_code
,
code
)
self
.
assertFalse
(
(
b
"You have successfully logged into "
b
"the Central Authentication Service"
)
in
response
.
content
)
self
.
assertTrue
(
client
.
session
.
get
(
"username"
)
is
None
)
self
.
assertTrue
(
client
.
session
.
get
(
"warn"
)
is
None
)
self
.
assertTrue
(
client
.
session
.
get
(
"authenticated"
)
is
None
)
def
test_login_view_post_goodpass_goodlt
(
self
):
"""Test a successul login"""
client
,
params
=
get_login_page_params
()
params
[
"username"
]
=
settings
.
CAS_TEST_USER
params
[
"password"
]
=
settings
.
CAS_TEST_PASSWORD
self
.
assertTrue
(
params
[
'lt'
]
in
client
.
session
[
'lt'
])
response
=
client
.
post
(
'/login'
,
params
)
self
.
assert_logged
(
client
,
response
)
# LoginTicket conssumed
self
.
assertTrue
(
params
[
'lt'
]
not
in
client
.
session
[
'lt'
])
def
test_login_view_post_goodpass_goodlt_warn
(
self
):
"""Test a successul login requesting to be warned before creating services tickets"""
client
,
params
=
get_login_page_params
()
params
[
"username"
]
=
settings
.
CAS_TEST_USER
params
[
"password"
]
=
settings
.
CAS_TEST_PASSWORD
params
[
"warn"
]
=
"on"
response
=
client
.
post
(
'/login'
,
params
)
self
.
assert_logged
(
client
,
response
,
warn
=
True
)
def
test_lt_max
(
self
):
"""Check we only keep the last 100 Login Ticket for a user"""
client
,
params
=
get_login_page_params
()
current_lt
=
params
[
"lt"
]
i_in_test
=
random
.
randint
(
0
,
100
)
i_not_in_test
=
random
.
randint
(
100
,
150
)
for
i
in
range
(
150
):
if
i
==
i_in_test
:
self
.
assertTrue
(
current_lt
in
client
.
session
[
'lt'
])
if
i
==
i_not_in_test
:
self
.
assertTrue
(
current_lt
not
in
client
.
session
[
'lt'
])
self
.
assertTrue
(
len
(
client
.
session
[
'lt'
])
<=
100
)
client
,
params
=
get_login_page_params
(
client
)
self
.
assertTrue
(
len
(
client
.
session
[
'lt'
])
<=
100
)
def
test_login_view_post_badlt
(
self
):
"""Login attempt with a bad LoginTicket"""
client
,
params
=
get_login_page_params
()
params
[
"username"
]
=
settings
.
CAS_TEST_USER
params
[
"password"
]
=
settings
.
CAS_TEST_PASSWORD
...
...
@@ -162,38 +282,39 @@ class LoginTestCase(TestCase):
response
=
client
.
post
(
'/login'
,
params
)
self
.
assert
Equal
(
response
.
status_code
,
200
)
self
.
assert
_login_failed
(
client
,
response
)
self
.
assertTrue
(
b
"Invalid login ticket"
in
response
.
content
)
self
.
assertFalse
(
(
b
"You have successfully logged into "
b
"the Central Authentication Service"
)
in
response
.
content
)
def
test_login_view_post_badpass_good_lt
(
self
):
"""Login attempt with a bad password"""
client
,
params
=
get_login_page_params
()
params
[
"username"
]
=
settings
.
CAS_TEST_USER
params
[
"password"
]
=
"test2"
response
=
client
.
post
(
'/login'
,
params
)
self
.
assert
Equal
(
response
.
status_code
,
200
)
self
.
assert
_login_failed
(
client
,
response
)
self
.
assertTrue
(
(
b
"The credentials you provided cannot be "
b
"determined to be authentic"
)
in
response
.
content
)
self
.
assertFalse
(
(
b
"You have successfully logged into "
b
"the Central Authentication Service"
)
in
response
.
content
def
assert_ticket_attributes
(
self
,
client
,
ticket_value
):
"""check the ticket attributes in the db"""
user
=
models
.
User
.
objects
.
get
(
username
=
settings
.
CAS_TEST_USER
,
session_key
=
client
.
session
.
session_key
)
self
.
assertTrue
(
user
)
ticket
=
models
.
ServiceTicket
.
objects
.
get
(
value
=
ticket_value
)
self
.
assertEqual
(
ticket
.
user
,
user
)
self
.
assertEqual
(
ticket
.
attributs
,
settings
.
CAS_TEST_ATTRIBUTES
)
self
.
assertEqual
(
ticket
.
validate
,
False
)
self
.
assertEqual
(
ticket
.
service_pattern
,
self
.
service_pattern
)
def
test_view_login_get_auth_allowed_service
(
self
):
client
=
get_auth_client
()
response
=
client
.
get
(
"/login?service=https://www.example.com"
)
def
assert_service_ticket
(
self
,
client
,
response
):
"""check that a ticket is well emited when requested on a allowed service"""
self
.
assertEqual
(
response
.
status_code
,
302
)
self
.
assertTrue
(
response
.
has_header
(
'Location'
))
self
.
assertTrue
(
...
...
@@ -203,23 +324,125 @@ class LoginTestCase(TestCase):
)
ticket_value
=
response
[
'Location'
].
split
(
'ticket='
)[
-
1
]
user
=
models
.
User
.
objects
.
get
(
username
=
settings
.
CAS_TEST_USER
,
session_key
=
client
.
session
.
session_key
self
.
assert_ticket_attributes
(
client
,
ticket_value
)
def
test_view_login_get_allowed_service
(
self
):
"""Request a ticket for an allowed service by an unauthenticated client"""
client
=
Client
()
response
=
client
.
get
(
"/login?service=https://www.example.com"
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertTrue
(
(
"Authentication required by service "
"example (https://www.example.com)"
)
in
response
.
content
)
self
.
assertTrue
(
user
)
ticket
=
models
.
ServiceTicket
.
objects
.
get
(
value
=
ticket_value
)
self
.
assertEqual
(
ticket
.
user
,
user
)
self
.
assertEqual
(
ticket
.
attributs
,
settings
.
CAS_TEST_ATTRIBUTES
)
self
.
assertEqual
(
ticket
.
validate
,
False
)
self
.
assertEqual
(
ticket
.
service_pattern
,
self
.
service_pattern
)
def
test_view_login_get_denied_service
(
self
):
"""Request a ticket for an denied service by an unauthenticated client"""
client
=
Client
()
response
=
client
.
get
(
"/login?service=https://www.example.net"
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertTrue
(
"Service https://www.example.net non allowed"
in
response
.
content
)
def
test_view_login_get_auth_allowed_service
(
self
):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client
=
get_auth_client
()
response
=
client
.
get
(
"/login?service=https://www.example.com"
)
self
.
assert_service_ticket
(
client
,
response
)
def
test_view_login_get_auth_allowed_service_warn
(
self
):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client
=
get_auth_client
(
warn
=
"on"
)
response
=
client
.
get
(
"/login?service=https://www.example.com"
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertTrue
(
(
"Authentication has been required by service "
"example (https://www.example.com)"
)
in
response
.
content
)
params
=
copy_form
(
response
.
context
[
"form"
])
response
=
client
.
post
(
"/login"
,
params
)
self
.
assert_service_ticket
(
client
,
response
)
def
test_view_login_get_auth_denied_service
(
self
):
"""Request a ticket for a not allowed service by an authenticated client"""
client
=
get_auth_client
()
response
=
client
.
get
(
"/login?service=https://www.example.org"
)
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertTrue
(
b
"Service https://www.example.org non allowed"
in
response
.
content
)
def
test_user_logged_not_in_db
(
self
):
"""If the user is logged but has been delete from the database, it should be logged out"""
client
=
get_auth_client
()
models
.
User
.
objects
.
get
(
username
=
settings
.
CAS_TEST_USER
,
session_key
=
client
.
session
.
session_key
).
delete
()
response
=
client
.
get
(
"/login"
)
self
.
assert_login_failed
(
client
,
response
,
code
=
302
)
self
.
assertEqual
(
response
[
"Location"
],
"/login?"
)
def
test_service_restrict_user
(
self
):
"""Testing the restric user capability fro a service"""
service
=
"https://restrict_user_fail.example.com"
client
=
get_auth_client
()
response
=
client
.
get
(
"/login"
,
{
'service'
:
service
})
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertTrue
(
"Username non allowed"
in
response
.
content
)
service
=
"https://restrict_user_success.example.com"
response
=
client
.
get
(
"/login"
,
{
'service'
:
service
})
self
.
assertEqual
(
response
.
status_code
,
302
)
self
.
assertTrue
(
response
[
"Location"
].
startswith
(
"%s?ticket="
%
service
))
def
test_service_filter
(
self
):
"""Test the filtering on user attributes"""
service
=
"https://filter_fail.example.com"
client
=
get_auth_client
()
response
=
client
.
get
(
"/login"
,
{
'service'
:
service
})
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertTrue
(
"User charateristics non allowed"
in
response
.
content
)
service
=
"https://filter_success.example.com"
response
=
client
.
get
(
"/login"
,
{
'service'
:
service
})
self
.
assertEqual
(
response
.
status_code
,
302
)
self
.
assertTrue
(
response
[
"Location"
].
startswith
(
"%s?ticket="
%
service
))
def
test_service_user_field
(
self
):
"""Test using a user attribute as username: case on if the attribute exists or not"""
service
=
"https://field_needed_fail.example.com"
client
=
get_auth_client
()
response
=
client
.
get
(
"/login"
,
{
'service'
:
service
})
self
.
assertEqual
(
response
.
status_code
,
200
)
self
.
assertTrue
(
"The attribut uid is needed to use that service"
in
response
.
content
)
service
=
"https://field_needed_success.example.com"
response
=
client
.
get
(
"/login"
,
{
'service'
:
service
})
self
.
assertEqual
(
response
.
status_code
,
302
)
self
.
assertTrue
(
response
[
"Location"
].
startswith
(
"%s?ticket="
%
service
))
def
test_gateway
(
self
):
"""test gateway parameter"""
# First with an authenticated client that fail to get a ticket for a service
service
=
"https://restrict_user_fail.example.com"
client
=
get_auth_client
()
response
=
client
.
get
(
"/login"
,
{
'service'
:
service
,
'gateway'
:
'on'
})
self
.
assertEqual
(
response
.
status_code
,
302
)
self
.
assertEqual
(
response
[
"Location"
],
service
)
# second for an user not yet authenticated on a valid service
client
=
Client
()
response
=
client
.
get
(
'/login'
,
{
'service'
:
service
,
'gateway'
:
'on'
})
self
.
assertEqual
(
response
.
status_code
,
302
)
self
.
assertEqual
(
response
[
"Location"
],
service
)
class
LogoutTestCase
(
TestCase
):
...
...
@@ -454,17 +677,24 @@ class ValidateServiceTestCase(TestCase):
namespaces
=
{
'cas'
:
"http://www.yale.edu/tp/cas"
}
)
self
.
assertEqual
(
len
(
attributes
),
1
)
attrs1
=
{}
attrs1
=
set
()
for
attr
in
attributes
[
0
]:
attrs1
[
attr
.
tag
[
len
(
"http://www.yale.edu/tp/cas"
)
+
2
:]]
=
attr
.
text
attrs1
.
add
((
attr
.
tag
[
len
(
"http://www.yale.edu/tp/cas"
)
+
2
:],
attr
.
text
))
attributes
=
root
.
xpath
(
"//cas:attribute"
,
namespaces
=
{
'cas'
:
"http://www.yale.edu/tp/cas"
})
self
.
assertEqual
(
len
(
attributes
),
len
(
attrs1
))
attrs2
=
{}
attrs2
=
set
()
for
attr
in
attributes
:
attrs2
[
attr
.
attrib
[
'name'
]]
=
attr
.
attrib
[
'value'
]
attrs2
.
add
((
attr
.
attrib
[
'name'
],
attr
.
attrib
[
'value'
]))
original
=
set
()
for
key
,
value
in
settings
.
CAS_TEST_ATTRIBUTES
.
items
():
if
isinstance
(
value
,
list
):
for
v
in
value
:
original
.
add
((
key
,
v
))
else
:
original
.
add
((
key
,
value
))
self
.
assertEqual
(
attrs1
,
attrs2
)
self
.
assertEqual
(
attrs1
,
settings
.
CAS_TEST_ATTRIBUTES
)
self
.
assertEqual
(
attrs1
,
original
)
def
test_validate_service_view_badservice
(
self
):
ticket
=
get_user_ticket_request
(
self
.
service
)[
1
]
...
...
@@ -623,17 +853,24 @@ class ProxyTestCase(TestCase):
namespaces
=
{
'cas'
:
"http://www.yale.edu/tp/cas"
}
)
self
.
assertEqual
(
len
(
attributes
),
1
)
attrs1
=
{}
attrs1
=
set
()
for
attr
in
attributes
[
0
]:
attrs1
[
attr
.
tag
[
len
(
"http://www.yale.edu/tp/cas"
)
+
2
:]]
=
attr
.
text
attrs1
.
add
((
attr
.
tag
[
len
(
"http://www.yale.edu/tp/cas"
)
+
2
:],
attr
.
text
))
attributes
=
root
.
xpath
(
"//cas:attribute"
,
namespaces
=
{
'cas'
:
"http://www.yale.edu/tp/cas"
})
self
.
assertEqual
(
len
(
attributes
),
len
(
attrs1
))
attrs2
=
{}
attrs2
=
set
()
for
attr
in
attributes
:
attrs2
[
attr
.
attrib
[
'name'
]]
=
attr
.
attrib
[
'value'
]
attrs2
.
add
((
attr
.
attrib
[
'name'
],
attr
.
attrib
[
'value'
]))
original
=
set
()
for
key
,
value
in
settings
.
CAS_TEST_ATTRIBUTES
.
items
():
if
isinstance
(
value
,
list
):
for
v
in
value
:
original
.
add
((
key
,
v
))
else
:
original
.
add
((
key
,
value
))
self
.
assertEqual
(
attrs1
,
attrs2
)
self
.
assertEqual
(
attrs1
,
settings
.
CAS_TEST_ATTRIBUTES
)
self
.
assertEqual
(
attrs1
,
original
)
def
test_validate_proxy_bad
(
self
):
params
=
get_pgt
()
...
...
cas_server/views.py
View file @
bab79c4d
...
...
@@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin):
service
=
None
def
init_get
(
self
,
request
):
"""Initialize GET received parameters"""
self
.
request
=
request
self
.
service
=
request
.
GET
.
get
(
'service'
)
self
.
url
=
request
.
GET
.
get
(
'url'
)
...
...
@@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin):
USER_NOT_AUTHENTICATED
=
6
def
init_post
(
self
,
request
):
"""Initialize POST received parameters"""
self
.
request
=
request
self
.
service
=
request
.
POST
.
get
(
'service'
)
self
.
renew
=
bool
(
request
.
POST
.
get
(
'renew'
)
and
request
.
POST
[
'renew'
]
!=
"False"
)
...
...
@@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin):
if
request
.
POST
.
get
(
'warned'
)
and
request
.
POST
[
'warned'
]
!=
"False"
:
self
.
warned
=
True
def
gen_lt
(
self
):
"""Generate a new LoginTicket and add it to the list of valid LT for the user"""
self
.
request
.
session
[
'lt'
]
=
self
.
request
.
session
.
get
(
'lt'
,
[])
+
[
utils
.
gen_lt
()]
if
len
(
self
.
request
.
session
[
'lt'
])
>
100
:
self
.
request
.
session
[
'lt'
]
=
self
.
request
.
session
[
'lt'
][
-
100
:]
def
check_lt
(
self
):
"""Check is the POSTed LoginTicket is valid, if yes invalide it"""
# save LT for later check
lt_valid
=
self
.
request
.
session
.
get
(
'lt'
,
[])
lt_send
=
self
.
request
.
POST
.
get
(
'lt'
)
# generate a new LT (by posting the LT has been consumed)
self
.
request
.
session
[
'lt'
]
=
self
.
request
.
session
.
get
(
'lt'
,
[])
+
[
utils
.
gen_lt
()]
if
len
(
self
.
request
.
session
[
'lt'
])
>
100
:
self
.
request
.
session
[
'lt'
]
=
self
.
request
.
session
[
'lt'
][
-
100
:]
self
.
gen_lt
()
# check if send LT is valid
if
lt_valid
is
None
or
lt_send
not
in
lt_valid
:
return
False
...
...
@@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin):
username
=
self
.
request
.
session
[
'username'
],
session_key
=
self
.
request
.
session
.
session_key
)
self
.
user
.
save
()
self
.
user
.
save
()
# pragma: no cover (should not happend)
except
models
.
User
.
DoesNotExist
:
self
.
user
=
models
.
User
.
objects
.
create
(
username
=
self
.
request
.
session
[
'username'
],
...
...
@@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin):
elif
ret
==
self
.
USER_ALREADY_LOGGED
:
pass
else
:
raise
EnvironmentError
(
"invalid output for LoginView.process_post"
)
raise
EnvironmentError
(
"invalid output for LoginView.process_post"
)
# pragma: no cover
return
self
.
common
()
def
process_post
(
self
):
"""
Analyse the POST request:
* check that the LoginTicket is valid
* check that the user sumited credentials are valid
"""
if
not
self
.
check_lt
():
values
=
self
.
request
.
POST
.
copy
()
# if not set a new LT and fail
...
...
@@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin):
return
self
.
USER_ALREADY_LOGGED
def
init_get
(
self
,
request
):
"""Initialize GET received parameters"""
self
.
request
=
request
self
.
service
=
request
.
GET
.
get
(
'service'
)
self
.
renew
=
bool
(
request
.
GET
.
get
(
'renew'
)
and
request
.
GET
[
'renew'
]
!=
"False"
)
...
...
@@ -294,15 +306,16 @@ class LoginView(View, LogoutMixin):
return
self
.
common
()
def
process_get
(
self
):
# generate a new LT if none is present
self
.
request
.
session
[
'lt'
]
=
self
.
request
.
session
.
get
(
'lt'
,
[])
+
[
utils
.
gen_lt
()]
"""Analyse the GET request"""
# generate a new LT
self
.
gen_lt
()
if
not
self
.
request
.
session
.
get
(
"authenticated"
)
or
self
.
renew
:
self
.
init_form
()
return
self
.
USER_NOT_AUTHENTICATED
return
self
.
USER_AUTHENTICATED
def
init_form
(
self
,
values
=
None
):
"""Initialization of the good form depending of POST and GET parameters"""
self
.
form
=
forms
.
UserCredential
(
values
,
initial
=
{
...
...
settings_tests.py
View file @
bab79c4d
...
...
@@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
'django.middleware.locale.LocaleMiddleware'
,
]
ROOT_URLCONF
=
'
cas_server.url
s'
ROOT_URLCONF
=
'
urls_test
s'