154 lines
4.8 KiB
Python
154 lines
4.8 KiB
Python
from django.contrib.auth import get_user_model
|
|
from django.contrib.auth.models import Group
|
|
from django.test import TestCase
|
|
|
|
from wiki_groups.models import WikiGroup
|
|
|
|
# Resolve the user model through Django's lookup instead of importing a
# concrete class, so the helpers below use whatever model the project has
# configured.
User = get_user_model()
|
|
|
|
|
|
def create_user(username):
    """Create and return a new user identified by ``username``.

    Only the username is passed to the manager's ``create_user``; no other
    field is supplied here.
    """
    new_user = User.objects.create_user(username=username)
    return new_user
|
|
|
|
|
|
def create_group(name, with_users=None):
    """Create a WikiGroup backed by a fresh Django ``Group`` called ``name``.

    ``with_users`` is an optional iterable of usernames. When given, a new
    user is created for each username and added to the WikiGroup's ``users``.
    Returns the new WikiGroup.
    """
    backing_group = Group.objects.create(name=name)
    group = WikiGroup.objects.create(django_group=backing_group)

    if with_users is None:
        return group

    # Members are added one at a time, each right after its user is created.
    for member_name in with_users:
        member = create_user(member_name)
        group.users.add(member)
    return group
|
|
|
|
|
|
class TestPropagation(TestCase):
    """Test that WikiGroup changes are correctly propagated to normal Django groups."""

    def _assert_usernames(self, wiki_group, expected):
        """Assert which users sit in the Django group behind ``wiki_group``.

        ``expected`` is an iterable of usernames; order is ignored.

        The transform is always passed explicitly. The implicit ``repr``
        transform of ``assertQuerysetEqual`` was deprecated in Django 3.2 and
        removed in 4.1, so assertions relying on it break on newer versions.
        """
        self.assertQuerysetEqual(
            wiki_group.django_group.user_set.all(),
            expected,
            ordered=False,
            transform=lambda u: u.username,
        )

    def test_tree(self):
        """Simple case: tree-shaped group structure."""
        a = create_group("group_a", with_users=["a1", "a2"])
        b = create_group("group_b", with_users=["b1", "b2"])
        c = create_group("group_c", with_users=["c1", "c2"])

        a.includes_groups.set([b, c])

        # a includes both other groups, so it must contain every user created.
        self._assert_usernames(a, [u.username for u in User.objects.all()])
        self._assert_usernames(b, ["b1", "b2"])
        self._assert_usernames(c, ["c1", "c2"])

    def test_diamond(self):
        """Simple case: diamond-shaped group structure."""
        a = create_group("group_a", with_users=["a1", "a2"])
        b = create_group("group_b", with_users=["b1", "b2"])
        c = create_group("group_c", with_users=["c1", "c2"])
        d = create_group("group_d", with_users=["d1", "d2"])

        # a --> b --> d and a --> c --> d
        a.includes_groups.set([b, c])
        b.includes_groups.add(d)
        c.includes_groups.add(d)

        # d is reachable from a through two paths; a must still end up with
        # exactly the set of all users.
        self._assert_usernames(a, [u.username for u in User.objects.all()])
        self._assert_usernames(b, ["b1", "b2", "d1", "d2"])
        self._assert_usernames(c, ["c1", "c2", "d1", "d2"])
        self._assert_usernames(d, ["d1", "d2"])

    def test_removal(self):
        """Test propagation of user removal."""
        a = create_group("group_a", with_users=["a1", "a2"])
        b = create_group("group_b", with_users=["b1", "b2"])
        a.includes_groups.add(b)

        self._assert_usernames(a, ["a1", "a2", "b1", "b2"])

        b.users.remove(User.objects.get(username="b1"))

        # The removal from b must also be reflected in the including group a.
        self._assert_usernames(a, ["a1", "a2", "b2"])
        self._assert_usernames(b, ["b2"])

    def test_update(self):
        """Test structure updates propagation."""
        a = create_group("group_a")
        b = create_group("group_b")
        c = create_group("group_c")
        a.includes_groups.add(b)
        b.includes_groups.add(c)

        # Update: set a's children to [c]
        # Before update: a --> b --> c
        # After update: b --> c <-- a
        a.includes_groups.set([c])

        # transform=repr is spelled out rather than relying on the default,
        # which newer Django versions no longer apply.
        self.assertQuerysetEqual(a.includes_groups.all(), [repr(c)], transform=repr)
        self.assertQuerysetEqual(b.includes_groups.all(), [repr(c)], transform=repr)
        self.assertQuerysetEqual(c.includes_groups.all(), [], transform=repr)
|
|
|
|
|
|
class TestCycleDetection(TestCase):
    """Exercise ``WikiGroup.group_in_cycle``, the cycle detection procedure."""

    def test_loop(self):
        """A group proposed as its own child (a --> a) is reported as a cycle."""
        group = create_group("group_a")

        # Once as the only child, then repeated several times: both calls
        # must return the group itself.
        for children in ([group], [group, group, group]):
            culprit = group.group_in_cycle(with_children=children)
            self.assertEqual(culprit, group)

    def test_trivial_cycle(self):
        """Closing the chain a --> b --> c with c --> a is reported as a cycle."""
        first = create_group("group_a")
        second = create_group("group_b")
        third = create_group("group_c")
        first.includes_groups.add(second)
        second.includes_groups.add(third)

        # Only checks that some group is returned, not which one.
        culprit = third.group_in_cycle(with_children=[first])
        self.assertIsNotNone(culprit)
|