diff --git a/allauth_ens/providers/clipper/provider.py b/allauth_ens/providers/clipper/provider.py index d14bee2..f0e23c8 100644 --- a/allauth_ens/providers/clipper/provider.py +++ b/allauth_ens/providers/clipper/provider.py @@ -1,11 +1,10 @@ # -*- coding: utf-8 -*- -import ldap - from allauth.account.models import EmailAddress from allauth.socialaccount.providers.base import ProviderAccount + from allauth_cas.providers import CASProvider -from django.conf import settings +from .utils import get_names class ClipperAccount(ProviderAccount): @@ -22,39 +21,10 @@ class ClipperProvider(CASProvider): return '{}@clipper.ens.fr'.format(uid.strip().lower()) def extract_common_fields(self, data): - def get_names(clipper): - assert clipper.isalnum() - try: - ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, - ldap.OPT_X_TLS_NEVER) - l = ldap.initialize("ldaps://ldap.spi.ens.fr:636") - l.set_option(ldap.OPT_REFERRALS, 0) - l.set_option(ldap.OPT_PROTOCOL_VERSION, 3) - l.set_option(ldap.OPT_X_TLS, ldap.OPT_X_TLS_DEMAND) - l.set_option(ldap.OPT_X_TLS_DEMAND, True) - l.set_option(ldap.OPT_DEBUG_LEVEL, 255) - l.set_option(ldap.OPT_NETWORK_TIMEOUT, 10) - l.set_option(ldap.OPT_TIMEOUT, 10) - - info = l.search_s('dc=spi,dc=ens,dc=fr', - ldap.SCOPE_SUBTREE, - ('(uid=%s)' % (clipper,)), - [str("cn"), ]) - - if len(info) > 0: - fullname = info[0][1].get('cn', [''])[0].decode("utf-8") - first_name, last_name = fullname.split(' ', 1) - return first_name, last_name - - except ldap.LDAPError: - pass - - return '', '' - common = super(ClipperProvider, self).extract_common_fields(data) fn, ln = get_names(common['username']) common['email'] = self.extract_email(data) - common['name'] = fn + common['first_name'] = fn common['last_name'] = ln return common diff --git a/allauth_ens/providers/clipper/tests.py b/allauth_ens/providers/clipper/tests.py index 657982c..9538985 100644 --- a/allauth_ens/providers/clipper/tests.py +++ b/allauth_ens/providers/clipper/tests.py @@ -1,15 +1,18 @@ +# -*- coding: utf-8 -*- from django.contrib.auth import get_user_model from allauth_cas.test.testcases import CASTestCase, CASViewTestCase +try: + from unittest import mock +except ImportError: + import mock + User = get_user_model() class ClipperProviderTests(CASTestCase): - def setUp(self): - self.u = User.objects.create_user('user', 'user@mail.net', 'user') - def test_auto_signup(self): self.client_cas_login( self.client, provider_id='clipper', username='clipper_uid') @@ -52,3 +55,71 @@ class ClipperViewsTests(CASViewTestCase): r, expected, fetch_redirect_response=False, ) + + +class ClipperLDAPTests(CASTestCase): + + def setUp(self): + self.mock_ldap_conn = mock.Mock() + self.mock_ldap_conn.search_s = mock.Mock(return_value=[]) + + patch_get_ldap_conn = mock.patch( + 'allauth_ens.providers.clipper.utils.get_ldap_connection', + return_value=self.mock_ldap_conn, + ) + patch_get_ldap_conn.start() + self.addCleanup(patch_get_ldap_conn.stop) + + def set_returned_fullname(self, fullname): + try: + bfullname = bytes(fullname, 'utf-8') + except TypeError: + bfullname = bytes(fullname) + self.mock_ldap_conn.search_s.return_value = [[None, {'cn': bfullname}]] + + def test_ok(self): + self.set_returned_fullname('abc def ghi') + + self.client_cas_login( + self.client, provider_id='clipper', username='theclipper') + + u = User.objects.get(username='theclipper') + self.assertEqual(u.first_name, 'abc') + self.assertEqual(u.last_name, 'def ghi') + + def test_short_fullname(self): + self.set_returned_fullname('abc') + + self.client_cas_login( + self.client, provider_id='clipper', username='theclipper') + + u = User.objects.get(username='theclipper') + self.assertEqual(u.first_name, 'abc') + self.assertEqual(u.last_name, '') + + def test_bad_uid(self): + self.client_cas_login( + self.client, provider_id='clipper', username='the_clipper') + + self.mock_ldap_conn.search_s.assert_not_called() + u = User.objects.get(username='the_clipper') + self.assertEqual(u.first_name, '') + self.assertEqual(u.last_name, '') + + def test_no_result(self): + self.client_cas_login( + self.client, provider_id='clipper', username='theclipper') + + u = User.objects.get(username='theclipper') + self.assertEqual(u.first_name, '') + self.assertEqual(u.last_name, '') + + def test_no_cn(self): + self.mock_ldap_conn.search_s.return_value = [[None, {}]] + + self.client_cas_login( + self.client, provider_id='clipper', username='theclipper') + + u = User.objects.get(username='theclipper') + self.assertEqual(u.first_name, '') + self.assertEqual(u.last_name, '') diff --git a/allauth_ens/providers/clipper/utils.py b/allauth_ens/providers/clipper/utils.py new file mode 100644 index 0000000..8b439b6 --- /dev/null +++ b/allauth_ens/providers/clipper/utils.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +import ldap + + +def get_ldap_connection(): + """ + Returns a connection ready-to-use with the LDAP server of clipper users. + """ + ldap_conn = ldap.initialize("ldaps://ldap.spi.ens.fr:636") + + ldap_conn.set_option(ldap.OPT_REFERRALS, 0) + ldap_conn.set_option(ldap.OPT_PROTOCOL_VERSION, 3) + ldap_conn.set_option(ldap.OPT_X_TLS, ldap.OPT_X_TLS_DEMAND) + ldap_conn.set_option(ldap.OPT_X_TLS_DEMAND, True) + ldap_conn.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER) + ldap_conn.set_option(ldap.OPT_DEBUG_LEVEL, 255) + ldap_conn.set_option(ldap.OPT_NETWORK_TIMEOUT, 5) + ldap_conn.set_option(ldap.OPT_TIMEOUT, 10) + + return ldap_conn + + +def get_names(clipper): + """ + Queries the LDAP server of clipper users to retrieve the names associated + with a clipper. + + Parameters + ---------- + clipper : str + A clipper (unique identifier for the LDAP server). + + Returns + ------- + (first_name, last_name) : tuple of str + The LDAP answers with a single string, the first name is before the + first space character, while the last name is the remaining string. + + If clipper contains non-alphanumeric characters or the server doesn't + return any results, both values are set to the empty string. + + """ + default = '', '' + + if not clipper.isalnum(): + return default + + try: + ldap_conn = get_ldap_connection() + results = ldap_conn.search_s( + 'dc=spi,dc=ens,dc=fr', + ldap.SCOPE_SUBTREE, + ('(uid=%s)' % (clipper,)), + [str("cn"), ], + ) + + if len(results) > 0: + data = results[0][1] + if 'cn' in data: + fullname = data['cn'].decode('utf-8') + names = fullname.split(' ', 1) + return names[0], names[1] if len(names) == 2 else '' + except ldap.LDAPError: + pass + + return default diff --git a/setup.py b/setup.py index 6b8231b..e13b942 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,6 @@ setup( 'django-allauth', 'django-allauth-cas>=1.0.0b2,<1.1', 'django-widget-tweaks', - 'python-ldap', + 'python-ldap>=3.0', ], )