diff --git a/wiki_groups/tests.py b/wiki_groups/tests.py index 7ce503c..662c669 100644 --- a/wiki_groups/tests.py +++ b/wiki_groups/tests.py @@ -1,3 +1,154 @@ +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.test import TestCase -# Create your tests here. +from wiki_groups.models import WikiGroup + +User = get_user_model() + + +def create_user(username): + return User.objects.create_user(username=username) + + +def create_group(name, with_users=None): + """Create a new WikiGroup, initialised with some new Users.""" + django_group = Group.objects.create(name=name) + wiki_group = WikiGroup.objects.create(django_group=django_group) + if with_users is not None: + for username in with_users: + u = create_user(username) + wiki_group.users.add(u) + return wiki_group + + +class TestPropagation(TestCase): + """Test that WikiGroup changes are correctly propagated to normal Django groups.""" + + def test_tree(self): + """Simple case: tree-shaped group structure.""" + a = create_group("group_a", with_users=["a1", "a2"]) + b = create_group("group_b", with_users=["b1", "b2"]) + c = create_group("group_c", with_users=["c1", "c2"]) + + a.includes_groups.set([b, c]) + + self.assertQuerysetEqual( + a.django_group.user_set.all(), map(repr, User.objects.all()), ordered=False, + ) + self.assertQuerysetEqual( + b.django_group.user_set.all(), + ["b1", "b2"], + ordered=False, + transform=lambda u: u.username, + ) + self.assertQuerysetEqual( + c.django_group.user_set.all(), + ["c1", "c2"], + ordered=False, + transform=lambda u: u.username, + ) + + def test_diamond(self): + """Simple case: diamond-shaped group structure.""" + a = create_group("group_a", with_users=["a1", "a2"]) + b = create_group("group_b", with_users=["b1", "b2"]) + c = create_group("group_c", with_users=["c1", "c2"]) + d = create_group("group_d", with_users=["d1", "d2"]) + + a.includes_groups.set([b, c]) + b.includes_groups.add(d) + c.includes_groups.add(d) + + self.assertQuerysetEqual( + a.django_group.user_set.all(), map(repr, User.objects.all()), ordered=False, + ) + self.assertQuerysetEqual( + b.django_group.user_set.all(), + ["b1", "b2", "d1", "d2"], + ordered=False, + transform=lambda u: u.username, + ) + self.assertQuerysetEqual( + c.django_group.user_set.all(), + ["c1", "c2", "d1", "d2"], + ordered=False, + transform=lambda u: u.username, + ) + self.assertQuerysetEqual( + d.django_group.user_set.all(), + ["d1", "d2"], + ordered=False, + transform=lambda u: u.username, + ) + + def test_removal(self): + """Test propagation of user removal.""" + + a = create_group("group_a", with_users=["a1", "a2"]) + b = create_group("group_b", with_users=["b1", "b2"]) + a.includes_groups.add(b) + + self.assertQuerysetEqual( + a.django_group.user_set.all(), + ["a1", "a2", "b1", "b2"], + ordered=False, + transform=lambda u: u.username, + ) + + b.users.remove(User.objects.get(username="b1")) + + self.assertQuerysetEqual( + a.django_group.user_set.all(), + ["a1", "a2", "b2"], + ordered=False, + transform=lambda u: u.username, + ) + self.assertQuerysetEqual( + b.django_group.user_set.all(), ["b2"], transform=lambda u: u.username + ) + + def test_update(self): + """Test structure updates propagation.""" + + a = create_group("group_a") + b = create_group("group_b") + c = create_group("group_c") + a.includes_groups.add(b) + b.includes_groups.add(c) + + # Update: set a's children to [c] + # Before update: a --> b --> c + # After update: b --> c <-- a + a.includes_groups.set([c]) + + self.assertQuerysetEqual(a.includes_groups.all(), [repr(c)]) + self.assertQuerysetEqual(b.includes_groups.all(), [repr(c)]) + self.assertQuerysetEqual(c.includes_groups.all(), []) + + +class TestCycleDetection(TestCase): + """Test the cycle detection procedure.""" + + def test_loop(self): + """Test loops (a --> a) detection.""" + + a = create_group("group_a") + + in_cycle = a.group_in_cycle(with_children=[a]) + self.assertEqual(in_cycle, a) + + in_cycle = a.group_in_cycle(with_children=[a, a, a]) + self.assertEqual(in_cycle, a) + + def test_trivial_cycle(self): + """Test trivial cycle detection (a --> b --> c --> a).""" + + a = create_group("group_a") + b = create_group("group_b") + c = create_group("group_c") + a.includes_groups.add(b) + b.includes_groups.add(c) + + in_cycle = c.group_in_cycle(with_children=[a]) + self.assertIsNotNone(in_cycle)