diff --git a/docs/source/autodoc/server/camcops_server/_index.rst b/docs/source/autodoc/server/camcops_server/_index.rst
index 817aaf333..a652e61d7 100644
--- a/docs/source/autodoc/server/camcops_server/_index.rst
+++ b/docs/source/autodoc/server/camcops_server/_index.rst
@@ -198,6 +198,7 @@ server/camcops_server
cc_modules/cc_taskschedule.py.rst
cc_modules/cc_taskschedulereports.py.rst
cc_modules/cc_testfactories.py.rst
+ cc_modules/cc_testproviders.py.rst
cc_modules/cc_text.py.rst
cc_modules/cc_tracker.py.rst
cc_modules/cc_trackerhelpers.py.rst
@@ -649,5 +650,6 @@ server/camcops_server
templates/test/test_template_filters.mako.rst
templates/test/testpage.mako.rst
tools/fetch_snomed_codes.py.rst
+ tools/generate_task_factories.py.rst
tools/print_latest_github_version.py.rst
tools/run_server_self_tests.py.rst
diff --git a/docs/source/autodoc/server/camcops_server/cc_modules/cc_testproviders.py.rst b/docs/source/autodoc/server/camcops_server/cc_modules/cc_testproviders.py.rst
new file mode 100644
index 000000000..d45db5afa
--- /dev/null
+++ b/docs/source/autodoc/server/camcops_server/cc_modules/cc_testproviders.py.rst
@@ -0,0 +1,29 @@
+.. docs/source/autodoc/server/camcops_server/cc_modules/cc_testproviders.py.rst
+
+.. THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
+
+
+.. Copyright (C) 2012, University of Cambridge, Department of Psychiatry.
+ Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
+ .
+ This file is part of CamCOPS.
+ .
+ CamCOPS is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+ .
+ CamCOPS is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+ .
+ You should have received a copy of the GNU General Public License
+ along with CamCOPS. If not, see .
+
+
+camcops_server.cc_modules.cc_testproviders
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. automodule:: camcops_server.cc_modules.cc_testproviders
+ :members:
diff --git a/docs/source/autodoc/server/camcops_server/tools/generate_task_factories.py.rst b/docs/source/autodoc/server/camcops_server/tools/generate_task_factories.py.rst
new file mode 100644
index 000000000..963fc05a7
--- /dev/null
+++ b/docs/source/autodoc/server/camcops_server/tools/generate_task_factories.py.rst
@@ -0,0 +1,29 @@
+.. docs/source/autodoc/server/camcops_server/tools/generate_task_factories.py.rst
+
+.. THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
+
+
+.. Copyright (C) 2012, University of Cambridge, Department of Psychiatry.
+ Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
+ .
+ This file is part of CamCOPS.
+ .
+ CamCOPS is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+ .
+ CamCOPS is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+ .
+ You should have received a copy of the GNU General Public License
+ along with CamCOPS. If not, see .
+
+
+camcops_server.tools.generate_task_factories
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. automodule:: camcops_server.tools.generate_task_factories
+ :members:
diff --git a/docs/source/developer/server_testing.rst b/docs/source/developer/server_testing.rst
index f8b1a242a..ae65e1cde 100644
--- a/docs/source/developer/server_testing.rst
+++ b/docs/source/developer/server_testing.rst
@@ -20,7 +20,7 @@
.. _pytest: https://docs.pytest.org/en/stable/
-
+.. _Factory Boy: https://factoryboy.readthedocs.io/en/stable/
Testing the server code
=======================
@@ -32,12 +32,26 @@ with the filename of the module appended with ``_tests.py``. So the module
``camcops_server/cc_modules/cc_patient.py`` is tested in
``camcops_server/cc_modules/tests/cc_patient_tests.py``.
-Test classes should end in ``Tests`` e.g. ``PatientTests``. Tests that require
-an empty database should inherit from ``DemoRequestTestCase``. Tests that
-require the demonstration database should inherit from
-``DemoDatabaseTestCase``. See ``camcops_server/cc_modules/cc_unittest``. Tests
-that do not require a database can just inherit from the standard python
-``unittest.TestCase``
+Test classes should end in ``Tests`` e.g. ``PatientTests``. A number of
+``unittest.TestCase`` subclasses are defined in
+``camcops_server/cc_modules/cc_unittest``.
+
+- Tests that require an empty database and a request object should inherit from
+ ``DemoRequestTestCase``.
+
+- Tests that require a minimal database setup (system user, superuser set on the
+ request object, group administrator and a server device) should inherit from
+ ``BasicDatabaseTestCase``.
+
+- Tests that require the demonstration database, which has a patient and two
+ instances of each type of task should inherit from ``DemoDatabaseTestCase``.
+
+- Tests that do not require a database
+ can just inherit from the standard python ``unittest.TestCase``.
+
+Use `Factory Boy`_ test factories to create test instances of SQLAlchemy
+database models. See ``camcops_server/cc_modules/cc_testfactories.py`` and
+``camcops_server/tasks/tests/factories.py``.
.. _run_all_server_tests:
diff --git a/server/camcops_server/cc_modules/cc_membership.py b/server/camcops_server/cc_modules/cc_membership.py
index b05e879d4..3d6de0f16 100644
--- a/server/camcops_server/cc_modules/cc_membership.py
+++ b/server/camcops_server/cc_modules/cc_membership.py
@@ -156,10 +156,6 @@ class UserGroupMembership(Base):
group = relationship("Group", back_populates="user_group_memberships")
user = relationship("User", back_populates="user_group_memberships")
- def __init__(self, user_id: int, group_id: int):
- self.user_id = user_id
- self.group_id = group_id
-
@classmethod
def get_ugm_by_id(
cls, dbsession: SqlASession, ugm_id: Optional[int]
diff --git a/server/camcops_server/cc_modules/cc_request.py b/server/camcops_server/cc_modules/cc_request.py
index 4977733aa..30506a864 100644
--- a/server/camcops_server/cc_modules/cc_request.py
+++ b/server/camcops_server/cc_modules/cc_request.py
@@ -2540,8 +2540,5 @@ def get_unittest_request(
req.set_get_params(params)
req._debugging_db_session = dbsession
- user = User()
- user.superuser = True
- req._debugging_user = user
return req
diff --git a/server/camcops_server/cc_modules/cc_testfactories.py b/server/camcops_server/cc_modules/cc_testfactories.py
index 635533445..d499c8a82 100644
--- a/server/camcops_server/cc_modules/cc_testfactories.py
+++ b/server/camcops_server/cc_modules/cc_testfactories.py
@@ -27,19 +27,27 @@
"""
+from typing import cast, Optional, TYPE_CHECKING
+
from cardinal_pythonlib.datetimefunc import (
convert_datetime_to_utc,
format_datetime,
)
import factory
+from faker import Faker
import pendulum
-from camcops_server.cc_modules.cc_constants import DateFormat
+from camcops_server.cc_modules.cc_blob import Blob
+from camcops_server.cc_modules.cc_constants import DateFormat, ERA_NOW
from camcops_server.cc_modules.cc_device import Device
from camcops_server.cc_modules.cc_email import Email
from camcops_server.cc_modules.cc_group import Group
+from camcops_server.cc_modules.cc_idnumdef import IdNumDefinition
+from camcops_server.cc_modules.cc_ipuse import IpUse
from camcops_server.cc_modules.cc_membership import UserGroupMembership
from camcops_server.cc_modules.cc_patient import Patient
+from camcops_server.cc_modules.cc_patientidnum import PatientIdNum
+from camcops_server.cc_modules.cc_testproviders import register_all_providers
from camcops_server.cc_modules.cc_taskschedule import (
PatientTaskSchedule,
PatientTaskScheduleEmail,
@@ -48,89 +56,282 @@
)
from camcops_server.cc_modules.cc_user import User
+if TYPE_CHECKING:
+ from factory.builder import Resolver
+ from camcops_server.cc_modules.cc_request import CamcopsRequest
+
+
+# Avoid any ID clashes with objects not created with factories
+ID_OFFSET = 1000
+RIO_ID_OFFSET = 10000
+STUDY_ID_OFFSET = 5000
+
+
+class Fake:
+ # Factory Boy has its own interface to Faker (factory.Faker()). This
+ # takes a function to be called at object generation time and as far as I
+ # can tell this doesn't support being able to create fake data based on
+ # other fake attributes such as notes for a patient. You can work
+ # around this by adding a lot of logic to the factories. To me it makes
+ # sense to keep the factories simple and do as much as possible of the
+ # content generation in the providers. So we call Faker directly instead.
+ en_gb = Faker("en_GB") # For UK postcodes, phone numbers etc
+ en_us = Faker("en_US") # en_GB gives Lorem ipsum for pad words.
+
+
+register_all_providers(Fake.en_gb)
+
# sqlalchemy_session gets poked in by DemoRequestCase.setUp()
class BaseFactory(factory.alchemy.SQLAlchemyModelFactory):
- pass
+ class Meta:
+ sqlalchemy_session_persistence = "commit"
class DeviceFactory(BaseFactory):
class Meta:
model = Device
- id = factory.Sequence(lambda n: n)
- name = factory.Sequence(lambda n: f"Test device {n}")
+ id = factory.Sequence(lambda n: n + ID_OFFSET)
+ name = factory.Sequence(lambda n: f"test-device-{n + ID_OFFSET}")
+
+
+class IpUseFactory(BaseFactory):
+ class Meta:
+ model = IpUse
+
+ clinical = factory.LazyFunction(Fake.en_gb.pybool)
+ commercial = factory.LazyFunction(Fake.en_gb.pybool)
+ educational = factory.LazyFunction(Fake.en_gb.pybool)
+ research = factory.LazyFunction(Fake.en_gb.pybool)
class GroupFactory(BaseFactory):
class Meta:
model = Group
- id = factory.Sequence(lambda n: n)
- name = factory.Sequence(lambda n: f"Group {n}")
+ id = factory.Sequence(lambda n: n + ID_OFFSET)
+ name = factory.Sequence(lambda n: f"Group {n + ID_OFFSET}")
+ ip_use = factory.SubFactory(IpUseFactory)
+
+
+class AnyIdNumGroupFactory(GroupFactory):
+ upload_policy = "sex and anyidnum"
+ finalize_policy = "sex and anyidnum"
class UserFactory(BaseFactory):
class Meta:
model = User
- id = factory.Sequence(lambda n: n)
username = factory.Sequence(lambda n: f"user{n}")
hashedpw = ""
+ @factory.post_generation
+ def password(
+ obj: User,
+ create: bool,
+ password: Optional[str],
+ request: "CamcopsRequest" = None,
+ **kwargs,
+ ) -> None:
+ if not create:
+ return
+
+ if password is None:
+ return
+
+ assert request is not None
+
+ obj.set_password(request, password)
+
class GenericTabletRecordFactory(BaseFactory):
+ class Meta:
+ exclude = ("default_iso_datetime",)
+ abstract = True
+
default_iso_datetime = "1970-01-01T12:00"
+ _pk = factory.Sequence(lambda n: n + ID_OFFSET)
_device = factory.SubFactory(DeviceFactory)
- _group = factory.SubFactory(GroupFactory)
+ _group = factory.SubFactory(AnyIdNumGroupFactory)
_adding_user = factory.SubFactory(UserFactory)
@factory.lazy_attribute
- def _when_added_exact(self) -> pendulum.DateTime:
- return pendulum.parse(self.default_iso_datetime)
+ def _when_added_exact(obj: "Resolver") -> pendulum.DateTime:
+ datetime = cast(
+ pendulum.DateTime, pendulum.parse(obj.default_iso_datetime)
+ )
+
+ return datetime
@factory.lazy_attribute
- def _when_added_batch_utc(self) -> pendulum.DateTime:
- era_time = pendulum.parse(self.default_iso_datetime)
+ def _when_added_batch_utc(obj: "Resolver") -> pendulum.DateTime:
+ era_time = pendulum.parse(obj.default_iso_datetime)
return convert_datetime_to_utc(era_time)
@factory.lazy_attribute
- def _era(self) -> str:
- era_time = pendulum.parse(self.default_iso_datetime)
+ def _era(obj: "Resolver") -> str:
+ era_time = pendulum.parse(obj.default_iso_datetime)
return format_datetime(era_time, DateFormat.ISO8601)
@factory.lazy_attribute
- def _current(self) -> bool:
+ def _current(obj: "Resolver") -> bool:
# _current = True gets ignored for some reason
return True
- class Meta:
- exclude = ("default_iso_datetime",)
- abstract = True
-
class PatientFactory(GenericTabletRecordFactory):
class Meta:
model = Patient
- id = factory.Sequence(lambda n: n)
+ id = factory.Sequence(lambda n: n + ID_OFFSET)
+ sex = factory.LazyFunction(Fake.en_gb.sex)
+ dob = factory.LazyFunction(Fake.en_gb.consistent_date_of_birth)
+ address = factory.LazyFunction(Fake.en_gb.address)
+ gp = factory.LazyFunction(Fake.en_gb.name)
+ other = factory.LazyFunction(Fake.en_us.paragraph)
+ email = factory.LazyFunction(Fake.en_gb.email)
+
+ @factory.lazy_attribute
+ def forename(obj: "Resolver") -> str:
+ return Fake.en_gb.forename(obj.sex)
+
+ surname = factory.LazyFunction(Fake.en_gb.last_name)
class ServerCreatedPatientFactory(PatientFactory):
@factory.lazy_attribute
- def _device(self) -> Device:
- # Should have been created in BasicDatabaseTestCase.setUp
+ def _device(obj: "Resolver") -> Device:
+ # May have been created in BasicDatabaseTestCase.setUp
return Device.get_server_device(
ServerCreatedPatientFactory._meta.sqlalchemy_session
)
+ @factory.lazy_attribute
+ def _era(obj: "Resolver") -> str:
+ return ERA_NOW
+
+
+class IdNumDefinitionFactory(BaseFactory):
+ class Meta:
+ model = IdNumDefinition
+
+ which_idnum = factory.Sequence(lambda n: n + ID_OFFSET)
+
+
+class NHSIdNumDefinitionFactory(IdNumDefinitionFactory):
+ description = "NHS number"
+ short_description = "NHS#"
+ hl7_assigning_authority = "NHS"
+ hl7_id_type = "NHSN"
+
+
+class StudyIdNumDefinitionFactory(IdNumDefinitionFactory):
+ description = "Study number"
+ short_description = "Study"
+
+
+class RioIdNumDefinitionFactory(IdNumDefinitionFactory):
+ description = "RiO number"
+ short_description = "RiO"
+ hl7_assigning_authority = "CPFT"
+ hl7_id_type = "CPRiO"
+
+
+class PatientIdNumFactory(GenericTabletRecordFactory):
+ class Meta:
+ model = PatientIdNum
+
+ id = factory.Sequence(lambda n: n + ID_OFFSET)
+ patient = factory.SubFactory(PatientFactory)
+ patient_id = factory.SelfAttribute("patient.id")
+ _group = factory.SelfAttribute("patient._group")
+ _device = factory.SelfAttribute("patient._device")
+
+
+class NHSPatientIdNumFactory(PatientIdNumFactory):
+ class Meta:
+ exclude = PatientIdNumFactory._meta.exclude + ("idnum",)
+
+ idnum = factory.SubFactory(NHSIdNumDefinitionFactory)
+
+ which_idnum = factory.SelfAttribute("idnum.which_idnum")
+ idnum_value = factory.LazyFunction(Fake.en_gb.nhs_number)
+
+
+class RioPatientIdNumFactory(PatientIdNumFactory):
+ class Meta:
+ exclude = PatientIdNumFactory._meta.exclude + ("idnum",)
+
+ idnum = factory.SubFactory(RioIdNumDefinitionFactory)
+
+ which_idnum = factory.SelfAttribute("idnum.which_idnum")
+ idnum_value = factory.Sequence(lambda n: n + RIO_ID_OFFSET)
+
+
+class StudyPatientIdNumFactory(PatientIdNumFactory):
+ class Meta:
+ exclude = PatientIdNumFactory._meta.exclude + ("idnum",)
+
+ idnum = factory.SubFactory(StudyIdNumDefinitionFactory)
+
+ which_idnum = factory.SelfAttribute("idnum.which_idnum")
+ idnum_value = factory.Sequence(lambda n: n + STUDY_ID_OFFSET)
+
+
+class ServerCreatedPatientIdNumFactory(PatientIdNumFactory):
+ patient = factory.SubFactory(ServerCreatedPatientFactory)
+
+ @factory.lazy_attribute
+ def _device(obj: "Resolver") -> Device:
+ # Should have been created in BasicDatabaseTestCase.setUp
+ return Device.get_server_device(
+ ServerCreatedPatientIdNumFactory._meta.sqlalchemy_session
+ )
+
+ @factory.lazy_attribute
+ def _era(obj: "Resolver") -> str:
+ return ERA_NOW
+
+
+class ServerCreatedNHSPatientIdNumFactory(
+ ServerCreatedPatientIdNumFactory, NHSPatientIdNumFactory
+):
+ class Meta:
+ exclude = (
+ ServerCreatedPatientIdNumFactory._meta.exclude
+ + NHSPatientIdNumFactory._meta.exclude
+ )
+
+
+class ServerCreatedRioPatientIdNumFactory(
+ ServerCreatedPatientIdNumFactory, RioPatientIdNumFactory
+):
+ class Meta:
+ exclude = (
+ ServerCreatedPatientIdNumFactory._meta.exclude
+ + RioPatientIdNumFactory._meta.exclude
+ )
+
+
+class ServerCreatedStudyPatientIdNumFactory(
+ ServerCreatedPatientIdNumFactory, StudyPatientIdNumFactory
+):
+ class Meta:
+ exclude = (
+ ServerCreatedPatientIdNumFactory._meta.exclude
+ + StudyPatientIdNumFactory._meta.exclude
+ )
+
class TaskScheduleFactory(BaseFactory):
class Meta:
model = TaskSchedule
group = factory.SubFactory(GroupFactory)
+ name = factory.Sequence(lambda n: f"Schedule {n + ID_OFFSET}")
class TaskScheduleItemFactory(BaseFactory):
@@ -158,37 +359,45 @@ class EmailFactory(BaseFactory):
class Meta:
model = Email
+ # Although sent and sent_at_utc are columns, they are not keyword
+ # arguments to Email's constructor so they are populated after the object
+ # has been created. For some reason 'sent' needs to be set explicitly
+ # when creating the factory even though the default should be False. Might
+ # be a SQLite thing.
@factory.post_generation
def sent_at_utc(
- self, create: bool, sent_at_utc: pendulum.DateTime, **kwargs
+ obj: Email, create: bool, sent_at_utc: pendulum.DateTime, **kwargs
) -> None:
if not create:
return
- self.sent_at_utc = sent_at_utc
+ obj.sent_at_utc = sent_at_utc
@factory.post_generation
- def sent(self, create: bool, sent: bool, **kwargs) -> None:
+ def sent(obj: Email, create: bool, sent: bool, **kwargs) -> None:
if not create:
return
- self.sent = sent
+ obj.sent = sent
class PatientTaskScheduleEmailFactory(BaseFactory):
class Meta:
model = PatientTaskScheduleEmail
+ patient_task_schedule = factory.SubFactory(
+ PatientTaskScheduleFactory,
+ )
+ email = factory.SubFactory(EmailFactory, sent=True)
+
class UserGroupMembershipFactory(BaseFactory):
class Meta:
model = UserGroupMembership
- @factory.post_generation
- def may_run_reports(
- self, create: bool, may_run_reports: bool, **kwargs
- ) -> None:
- if not create:
- return
- self.may_run_reports = may_run_reports
+class BlobFactory(GenericTabletRecordFactory):
+ class Meta:
+ model = Blob
+
+ id = factory.Sequence(lambda n: n + ID_OFFSET)
diff --git a/server/camcops_server/cc_modules/cc_testproviders.py b/server/camcops_server/cc_modules/cc_testproviders.py
new file mode 100644
index 000000000..32513dc2a
--- /dev/null
+++ b/server/camcops_server/cc_modules/cc_testproviders.py
@@ -0,0 +1,154 @@
+"""
+camcops_server/cc_modules/cc_testproviders.py
+
+===============================================================================
+
+ Copyright (C) 2012, University of Cambridge, Department of Psychiatry.
+ Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
+
+ This file is part of CamCOPS.
+
+ CamCOPS is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ CamCOPS is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with CamCOPS. If not, see .
+
+===============================================================================
+
+**Faker test data providers.**
+
+There may be some interest in a Faker Medical community provider if we felt it
+was worth the effort.
+
+https://github.com/joke2k/faker/issues/1142
+
+See also duplicate functionality in CRATE: crate_anon/testing/providers.py
+
+"""
+
+import datetime
+
+from cardinal_pythonlib.nhs import generate_random_nhs_number
+from faker import Faker
+from faker.providers import BaseProvider
+import pendulum
+from pendulum import DateTime as Pendulum
+from typing import Any, List
+
+
+class NhsNumberProvider(BaseProvider):
+ def nhs_number(self) -> str:
+ return generate_random_nhs_number()
+
+
+class ChoiceProvider(BaseProvider):
+ def random_choice(self, choices: List, **kwargs) -> Any:
+ """
+ Given a list of choices return a random value
+ """
+ choices = self.generator.random.choices(choices, **kwargs)
+
+ return choices[0]
+
+
+# No one is born after this
+_max_birth_datetime = Pendulum(year=2000, month=1, day=1, hour=9)
+
+
+class ConsistentDateOfBirthProvider(BaseProvider):
+ """
+ Faker date_of_birth calculates from the current time so gives different
+ results on different days.
+ """
+
+ def consistent_date_of_birth(self) -> datetime.datetime:
+ return self.generator.date_between_dates(
+ date_start=pendulum.date(1900, 1, 1),
+ date_end=_max_birth_datetime,
+ )
+
+
+class ForenameProvider(BaseProvider):
+ """
+ Return a forename given the sex of the person
+ """
+
+ def forename(self, sex: str) -> str:
+ if sex == "M":
+ return self.generator.first_name_male()
+
+ if sex == "F":
+ return self.generator.first_name_female()
+
+ return self.generator.first_name()[:1]
+
+
+class HeightProvider(BaseProvider):
+ def height_m(self) -> float:
+ """
+ Return a random patient height in metres
+ """
+
+ return float(self.generator.random_int(min=145, max=191) / 100.0)
+
+
+class MassProvider(BaseProvider):
+ def mass_kg(self) -> float:
+ """
+ Return a random patient mass in kilograms
+ """
+
+ return float(self.generator.random_int(min=400, max=1000) / 10.0)
+
+
+class SexProvider(ChoiceProvider):
+ """
+ Return a random sex, with realistic distribution.
+ """
+
+ def sex(self) -> str:
+ return self.random_choice(["M", "F", "X"], weights=[49.8, 49.8, 0.4])
+
+
+class ValidPhoneNumberProvider(BaseProvider):
+ """
+ Return a random mobile phone number
+ """
+
+ # The default Faker phone_number provider for en_GB uses
+ # https://www.ofcom.org.uk/phones-telecoms-and-internet/information-for-industry/numbering/numbers-for-drama # noqa: E501
+ # 07700 900000 to 900999 reserved for TV and Radio drama purposes
+ # but unfortunately the phonenumbers library considers these invalid.
+ def valid_phone_number(self) -> str:
+ number = self.generator.random_int(min=7000000000, max=7999999999)
+
+ return f"+44{number}"
+
+
+class WaistProvider(BaseProvider):
+ """
+ Return a random waist circumference in centimetres
+ """
+
+ def waist_cm(self) -> float:
+ return float(self.generator.random_int(min=40, max=130))
+
+
+def register_all_providers(fake: Faker) -> None:
+ fake.add_provider(ChoiceProvider)
+ fake.add_provider(ConsistentDateOfBirthProvider)
+ fake.add_provider(ForenameProvider)
+ fake.add_provider(HeightProvider)
+ fake.add_provider(MassProvider)
+ fake.add_provider(NhsNumberProvider)
+ fake.add_provider(ValidPhoneNumberProvider)
+ fake.add_provider(WaistProvider)
+ fake.add_provider(SexProvider)
diff --git a/server/camcops_server/cc_modules/cc_unittest.py b/server/camcops_server/cc_modules/cc_unittest.py
index 6ed2e9c54..d20b781b9 100644
--- a/server/camcops_server/cc_modules/cc_unittest.py
+++ b/server/camcops_server/cc_modules/cc_unittest.py
@@ -29,48 +29,44 @@
import base64
import copy
+from faker import Faker
import logging
import os
+import random
import sqlite3
-from typing import Any, List, Type, TYPE_CHECKING
-import unittest
+from typing import Any, Dict, List, Type, TYPE_CHECKING
+from unittest import mock, TestCase
from cardinal_pythonlib.classes import all_subclasses
from cardinal_pythonlib.dbfunc import get_fieldnames_from_cursor
from cardinal_pythonlib.httpconst import MimeType
from cardinal_pythonlib.logs import BraceStyleAdapter
-import pendulum
import pytest
from sqlalchemy.engine.base import Engine
from camcops_server.cc_modules.cc_baseconstants import ENVVAR_CONFIG_FILE
-from camcops_server.cc_modules.cc_constants import ERA_NOW
from camcops_server.cc_modules.cc_device import Device
from camcops_server.cc_modules.cc_exportrecipient import ExportRecipient
-from camcops_server.cc_modules.cc_group import Group
-from camcops_server.cc_modules.cc_idnumdef import IdNumDefinition
-from camcops_server.cc_modules.cc_ipuse import IpUse
from camcops_server.cc_modules.cc_request import (
CamcopsRequest,
get_unittest_request,
)
from camcops_server.cc_modules.cc_sqlalchemy import sql_from_sqlite_database
+from camcops_server.cc_modules.cc_task import Task, TaskHasPatientMixin
from camcops_server.cc_modules.cc_user import User
-from camcops_server.cc_modules.cc_membership import UserGroupMembership
from camcops_server.cc_modules.cc_testfactories import (
BaseFactory,
- DeviceFactory,
GroupFactory,
+ NHSPatientIdNumFactory,
+ PatientFactory,
+ RioPatientIdNumFactory,
UserFactory,
+ UserGroupMembershipFactory,
)
-from camcops_server.cc_modules.cc_version import CAMCOPS_SERVER_VERSION
+from camcops_server.tasks.tests import factories as task_factories
if TYPE_CHECKING:
from sqlalchemy.orm import Session
- from camcops_server.cc_modules.cc_db import GenericTabletRecordMixin
- from camcops_server.cc_modules.cc_patient import Patient
- from camcops_server.cc_modules.cc_patientidnum import PatientIdNum
- from camcops_server.cc_modules.cc_task import Task
log = BraceStyleAdapter(logging.getLogger(__name__))
@@ -91,7 +87,15 @@
# =============================================================================
-class ExtendedTestCase(unittest.TestCase):
+class ExtendedTestCase(TestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+
+ # Arbitrary seed
+ Faker.seed(1234)
+ random.seed(1234)
+
"""
A subclass of :class:`unittest.TestCase` that provides some additional
functionality.
@@ -135,6 +139,8 @@ class DemoRequestTestCase(ExtendedTestCase):
db_filename: str
def setUp(self) -> None:
+ super().setUp()
+
for factory in all_subclasses(BaseFactory):
factory._meta.sqlalchemy_session = self.dbsession
@@ -148,7 +154,7 @@ def setUp(self) -> None:
# afterwards.
self.old_config = copy.copy(self.req.config)
- self.req.matched_route = unittest.mock.Mock()
+ self.req.matched_route = mock.Mock()
self.recipdef = ExportRecipient()
def tearDown(self) -> None:
@@ -215,246 +221,29 @@ def dump_table(
class BasicDatabaseTestCase(DemoRequestTestCase):
"""
- Test case that sets up some useful database records for testing:
- ID numbers, user, group, devices etc and has helper methods for
- creating patients and tasks
+ Test case that sets up some minimal database records for testing.
"""
def setUp(self) -> None:
super().setUp()
- self.set_era("2010-07-07T13:40+0100")
-
- # Set up groups, users, etc.
- # ... ID number definitions
- idnum_type_nhs = 1
- idnum_type_rio = 2
- idnum_type_study = 3
- self.nhs_iddef = IdNumDefinition(
- which_idnum=idnum_type_nhs,
- description="NHS number",
- short_description="NHS#",
- hl7_assigning_authority="NHS",
- hl7_id_type="NHSN",
- )
- self.dbsession.add(self.nhs_iddef)
- self.rio_iddef = IdNumDefinition(
- which_idnum=idnum_type_rio,
- description="RiO number",
- short_description="RiO",
- hl7_assigning_authority="CPFT",
- hl7_id_type="CPRiO",
- )
- self.dbsession.add(self.rio_iddef)
- self.study_iddef = IdNumDefinition(
- which_idnum=idnum_type_study,
- description="Study number",
- short_description="Study",
- )
- self.dbsession.add(self.study_iddef)
- # ... group
- self.group = Group()
- self.group.name = "testgroup"
- self.group.description = "Test group"
- self.group.upload_policy = "sex AND anyidnum"
- self.group.finalize_policy = "sex AND idnum1"
- self.group.ip_use = IpUse()
- self.dbsession.add(self.group)
- self.dbsession.flush() # sets PK fields
- GroupFactory.reset_sequence(self.group.id + 1)
-
- # ... users
-
- self.user = User.get_system_user(self.dbsession)
- self.user.upload_group_id = self.group.id
- self.req._debugging_user = self.user # improve our debugging user
-
- # ... devices
- self.server_device = Device.get_server_device(self.dbsession)
- DeviceFactory.reset_sequence(self.server_device.id + 1)
- self.other_device = DeviceFactory(
- name="other_device",
- friendly_name="Test device that may upload",
- registered_by_user=self.user,
- when_registered_utc=self.era_time_utc,
- camcops_version=CAMCOPS_SERVER_VERSION,
- )
- # ... export recipient definition (the minimum)
- self.recipdef.primary_idnum = idnum_type_nhs
-
- self.dbsession.flush() # sets PK fields
- UserFactory.reset_sequence(self.user.id + 1)
-
- self.create_tasks()
-
- def set_era(self, iso_datetime: str) -> None:
- from cardinal_pythonlib.datetimefunc import (
- convert_datetime_to_utc,
- format_datetime,
- )
- from camcops_server.cc_modules.cc_constants import DateFormat
-
- self.era_time = pendulum.parse(iso_datetime)
- self.era_time_utc = convert_datetime_to_utc(self.era_time)
- self.era = format_datetime(self.era_time, DateFormat.ISO8601)
-
- def create_patient_with_two_idnums(self) -> "Patient":
- from camcops_server.cc_modules.cc_patient import Patient
- from camcops_server.cc_modules.cc_patientidnum import PatientIdNum
-
- # Populate database with two of everything
- patient = Patient()
- patient.id = 1
- self.apply_standard_db_fields(patient)
- patient.forename = "Forename1"
- patient.surname = "Surname1"
- patient.dob = pendulum.parse("1950-01-01")
- self.dbsession.add(patient)
- patient_idnum1 = PatientIdNum()
- patient_idnum1.id = 1
- self.apply_standard_db_fields(patient_idnum1)
- patient_idnum1.patient_id = patient.id
- patient_idnum1.which_idnum = self.nhs_iddef.which_idnum
- patient_idnum1.idnum_value = 333
- self.dbsession.add(patient_idnum1)
- patient_idnum2 = PatientIdNum()
- patient_idnum2.id = 2
- self.apply_standard_db_fields(patient_idnum2)
- patient_idnum2.patient_id = patient.id
- patient_idnum2.which_idnum = self.rio_iddef.which_idnum
- patient_idnum2.idnum_value = 444
- self.dbsession.add(patient_idnum2)
- self.dbsession.commit()
-
- return patient
+ self.group = GroupFactory()
+ self.groupadmin = UserFactory()
- def create_patient_with_one_idnum(self) -> "Patient":
- from camcops_server.cc_modules.cc_patient import Patient
+ self.superuser = UserFactory(superuser=True)
- patient = Patient()
- patient.id = 2
- self.apply_standard_db_fields(patient)
- patient.forename = "Forename2"
- patient.surname = "Surname2"
- patient.dob = pendulum.parse("1975-12-12")
- self.dbsession.add(patient)
-
- self.create_patient_idnum(
- id=3,
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=555,
+ UserGroupMembershipFactory(
+ group_id=self.group.id, user_id=self.groupadmin.id, groupadmin=True
)
- return patient
-
- def create_patient_idnum(
- self, as_server_patient: bool = False, **kwargs: Any
- ) -> "PatientIdNum":
- from camcops_server.cc_modules.cc_patientidnum import PatientIdNum
-
- patient_idnum = PatientIdNum()
- self.apply_standard_db_fields(patient_idnum, era_now=as_server_patient)
-
- for key, value in kwargs.items():
- setattr(patient_idnum, key, value)
-
- if "id" not in kwargs:
- patient_idnum.save_with_next_available_id(
- self.req, patient_idnum._device_id
- )
- else:
- self.dbsession.add(patient_idnum)
-
- self.dbsession.commit()
-
- return patient_idnum
-
- def create_patient(
- self, as_server_patient: bool = False, **kwargs: Any
- ) -> "Patient":
- from camcops_server.cc_modules.cc_patient import Patient
+ self.system_user = User.get_system_user(self.dbsession)
+ self.system_user.upload_group_id = self.group.id
- patient = Patient()
- self.apply_standard_db_fields(patient, era_now=as_server_patient)
-
- for key, value in kwargs.items():
- setattr(patient, key, value)
-
- if "id" not in kwargs:
- patient.save_with_next_available_id(self.req, patient._device_id)
- else:
- self.dbsession.add(patient)
+ self.req._debugging_user = self.superuser # improve our debugging user
+ self.server_device = Device.get_server_device(self.dbsession)
self.dbsession.commit()
- return patient
-
- def create_tasks(self) -> None:
- # Override in subclass
- pass
-
- def apply_standard_task_fields(self, task: "Task") -> None:
- """
- Writes some default values to an SQLAlchemy ORM object representing
- a task.
- """
- self.apply_standard_db_fields(task)
- task.when_created = self.era_time
-
- def apply_standard_db_fields(
- self, obj: "GenericTabletRecordMixin", era_now: bool = False
- ) -> None:
- """
- Writes some default values to an SQLAlchemy ORM object representing a
- record uploaded from a client (tablet) device.
-
- Though we use the server device ID.
- """
- obj._device_id = self.server_device.id
- obj._era = ERA_NOW if era_now else self.era
- obj._group_id = self.group.id
- obj._current = True
- obj._adding_user_id = self.user.id
- obj._when_added_batch_utc = self.era_time_utc
-
- def create_user(self, **kwargs) -> User:
- user = User()
- user.hashedpw = ""
-
- for key, value in kwargs.items():
- setattr(user, key, value)
-
- self.dbsession.add(user)
-
- return user
-
- def create_group(self, name: str, **kwargs) -> Group:
- group = Group()
- group.name = name
-
- for key, value in kwargs.items():
- setattr(group, key, value)
-
- self.dbsession.add(group)
-
- return group
-
- def create_membership(
- self, user: User, group: Group, **kwargs
- ) -> UserGroupMembership:
- ugm = UserGroupMembership(user_id=user.id, group_id=group.id)
-
- for key, value in kwargs.items():
- setattr(ugm, key, value)
-
- self.dbsession.add(ugm)
-
- return ugm
-
- def tearDown(self) -> None:
- pass
-
class DemoDatabaseTestCase(BasicDatabaseTestCase):
"""
@@ -462,43 +251,39 @@ class DemoDatabaseTestCase(BasicDatabaseTestCase):
each type
"""
- def create_tasks(self) -> None:
- from camcops_server.cc_modules.cc_blob import Blob
- from camcops_server.tasks.photo import Photo
- from camcops_server.cc_modules.cc_task import Task
+ def setUp(self) -> None:
+ super().setUp()
- patient_with_two_idnums = self.create_patient_with_two_idnums()
- patient_with_one_idnum = self.create_patient_with_one_idnum()
+ self.demo_database_group = GroupFactory()
- for cls in Task.all_subclasses_by_tablename():
- t1 = cls()
- t1.id = 1
- self.apply_standard_task_fields(t1)
- if t1.has_patient:
- t1.patient_id = patient_with_two_idnums.id
-
- if isinstance(t1, Photo):
- b = Blob()
- b.id = 1
- self.apply_standard_db_fields(b)
- b.tablename = t1.tablename
- b.tablepk = t1.id
- b.fieldname = "photo_blobid"
- b.filename = "some_picture.png"
- b.mimetype = MimeType.PNG
- b.image_rotation_deg_cw = 0
- b.theblob = DEMO_PNG_BYTES
- self.dbsession.add(b)
-
- t1.photo_blobid = b.id
-
- self.dbsession.add(t1)
-
- t2 = cls()
- t2.id = 2
- self.apply_standard_task_fields(t2)
- if t2.has_patient:
- t2.patient_id = patient_with_one_idnum.id
- self.dbsession.add(t2)
+ patient_with_two_idnums = PatientFactory(
+ _group=self.demo_database_group
+ )
+ NHSPatientIdNumFactory(patient=patient_with_two_idnums)
+ RioPatientIdNumFactory(patient=patient_with_two_idnums)
- self.dbsession.commit()
+ patient_with_one_idnum = PatientFactory(
+ _group=self.demo_database_group
+ )
+ NHSPatientIdNumFactory(patient=patient_with_one_idnum)
+
+ for cls in Task.all_subclasses_by_tablename():
+ factory_class = getattr(task_factories, f"{cls.__name__}Factory")
+
+ t1_kwargs: Dict[str, Any] = dict(_group=self.demo_database_group)
+ t2_kwargs = t1_kwargs
+ if issubclass(cls, TaskHasPatientMixin):
+ t1_kwargs.update(patient=patient_with_two_idnums)
+ t2_kwargs.update(patient=patient_with_one_idnum)
+
+ if cls.__name__ == "Photo":
+ t1_kwargs.update(
+ create_blob__fieldname="photo_blobid",
+ create_blob__filename="some_picture.png",
+ create_blob__mimetype=MimeType.PNG,
+ create_blob__image_rotation_deg_cw=0,
+ create_blob__theblob=DEMO_PNG_BYTES,
+ )
+
+ factory_class(**t1_kwargs)
+ factory_class(**t2_kwargs)
diff --git a/server/camcops_server/cc_modules/tests/cc_fhir_tests.py b/server/camcops_server/cc_modules/tests/cc_fhir_tests.py
index 69475a630..6b5af11fa 100644
--- a/server/camcops_server/cc_modules/tests/cc_fhir_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_fhir_tests.py
@@ -32,11 +32,10 @@
from unittest import mock
from cardinal_pythonlib.httpconst import HttpMethod
-from cardinal_pythonlib.nhs import generate_random_nhs_number
import pendulum
from requests.exceptions import HTTPError
-from camcops_server.cc_modules.cc_constants import FHIRConst as Fc, JSON_INDENT
+from camcops_server.cc_modules.cc_constants import FHIRConst as Fc
from camcops_server.cc_modules.cc_exportmodels import (
ExportedTask,
ExportedTaskFhir,
@@ -53,36 +52,31 @@
FhirTaskExporter,
)
from camcops_server.cc_modules.cc_pyramid import Routes
-from camcops_server.cc_modules.cc_unittest import DemoDatabaseTestCase
+from camcops_server.cc_modules.cc_testfactories import (
+ NHSPatientIdNumFactory,
+ PatientFactory,
+ RioPatientIdNumFactory,
+)
+
+from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase
from camcops_server.cc_modules.cc_version_string import (
CAMCOPS_SERVER_VERSION_STRING,
)
-from camcops_server.tasks.apeqpt import Apeqpt
-from camcops_server.tasks.bmi import Bmi
-from camcops_server.tasks.gad7 import Gad7
-from camcops_server.tasks.diagnosis import (
- DiagnosisIcd10,
- DiagnosisIcd10Item,
- DiagnosisIcd9CM,
- DiagnosisIcd9CMItem,
+from camcops_server.tasks.tests.factories import (
+ ApeqptFactory,
+ BmiFactory,
+ DiagnosisIcd10Factory,
+ DiagnosisIcd10ItemFactory,
+ DiagnosisIcd9CMFactory,
+ DiagnosisIcd9CMItemFactory,
+ Gad7Factory,
+ Phq9Factory,
)
-from camcops_server.tasks.phq9 import Phq9
log = logging.getLogger()
-# =============================================================================
-# Constants
-# =============================================================================
-
-TEST_NHS_NUMBER = generate_random_nhs_number()
-TEST_RIO_NUMBER = 12345
-TEST_FORENAME = "Gwendolyn"
-TEST_SURNAME = "Ryann"
-TEST_SEX = "F"
-
-
# =============================================================================
# Helper classes
# =============================================================================
@@ -100,13 +94,12 @@ def __init__(self, response_json: Dict):
)
-class FhirExportTestCase(DemoDatabaseTestCase):
+class FhirExportTestCase(DemoRequestTestCase):
def setUp(self) -> None:
super().setUp()
recipientinfo = ExportRecipientInfo()
self.recipient = ExportRecipient(recipientinfo)
- self.recipient.primary_idnum = self.rio_iddef.which_idnum
self.recipient.fhir_api_url = "https://www.example.com/fhir"
# auto increment doesn't work for BigInteger with SQLite
@@ -116,20 +109,16 @@ def setUp(self) -> None:
self.camcops_root_url = self.req.route_url(Routes.HOME).rstrip("/")
# ... no trailing slash
- def create_fhir_patient(self) -> None:
- self.patient = self.create_patient(
- forename=TEST_FORENAME, surname=TEST_SURNAME, sex=TEST_SEX
- )
- self.patient_nhs = self.create_patient_idnum(
- patient_id=self.patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER,
- )
- self.patient_rio = self.create_patient_idnum(
- patient_id=self.patient.id,
- which_idnum=self.rio_iddef.which_idnum,
- idnum_value=TEST_RIO_NUMBER,
- )
+
+class FhirExportPatientTestCase(FhirExportTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.patient = PatientFactory()
+ self.patient_nhs_idnum = NHSPatientIdNumFactory(patient=self.patient)
+ self.patient_rio_idnum = RioPatientIdNumFactory(patient=self.patient)
+
+ self.recipient.primary_idnum = self.patient_rio_idnum.which_idnum
# =============================================================================
@@ -137,27 +126,23 @@ def create_fhir_patient(self) -> None:
# =============================================================================
-class FhirTaskExporterPhq9Tests(FhirExportTestCase):
- def create_tasks(self) -> None:
- self.create_fhir_patient()
-
- self.task = Phq9()
- self.apply_standard_task_fields(self.task)
- self.task.q1 = 0
- self.task.q2 = 1
- self.task.q3 = 2
- self.task.q4 = 3
- self.task.q5 = 0
- self.task.q6 = 1
- self.task.q7 = 2
- self.task.q8 = 3
- self.task.q9 = 0
- self.task.q10 = 3
- self.task.patient_id = self.patient.id
- self.task.save_with_next_available_id(
- self.req, self.patient._device_id
+class FhirTaskExporterPhq9Tests(FhirExportPatientTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.task = Phq9Factory(
+ patient=self.patient,
+ q1=0,
+ q2=1,
+ q3=2,
+ q4=3,
+ q5=0,
+ q6=1,
+ q7=2,
+ q8=3,
+ q9=0,
+ q10=3,
)
- self.dbsession.commit()
def test_patient_exported(self) -> None:
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
@@ -185,7 +170,7 @@ def test_patient_exported(self) -> None:
self.assertEqual(patient[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_PATIENT)
identifier = patient[Fc.IDENTIFIER]
- idnum_value = self.patient_rio.idnum_value
+ idnum_value = self.patient_rio_idnum.idnum_value
patient_id = self.patient.get_fhir_identifier(self.req, self.recipient)
@@ -346,7 +331,7 @@ def test_questionnaire_response_exported(self) -> None:
subject = response[Fc.SUBJECT]
identifier = subject[Fc.IDENTIFIER]
self.assertEqual(subject[Fc.TYPE], Fc.RESOURCE_TYPE_PATIENT)
- idnum_value = self.patient_rio.idnum_value
+ idnum_value = self.patient_rio_idnum.idnum_value
patient_id = self.patient.get_fhir_identifier(self.req, self.recipient)
if isinstance(identifier, list):
@@ -534,18 +519,17 @@ def test_raises_for_missing_api_url(self) -> None:
class FhirTaskExporterAnonymousTests(FhirExportTestCase):
- def create_tasks(self) -> None:
- self.task = Apeqpt()
- self.apply_standard_task_fields(self.task)
- self.task.q_datetime = pendulum.now()
- self.task.q1_choice = 0
- self.task.q2_choice = 1
- self.task.q3_choice = 2
- self.task.q1_satisfaction = 3
- self.task.q2_satisfaction = "Service experience"
-
- self.task.save_with_next_available_id(self.req, self.server_device.id)
- self.dbsession.commit()
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.task = ApeqptFactory(
+ q_datetime=pendulum.now(),
+ q1_choice=0,
+ q2_choice=1,
+ q3_choice=2,
+ q1_satisfaction=3,
+ q2_satisfaction="Service experience",
+ )
def test_questionnaire_exported(self) -> None:
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
@@ -764,45 +748,73 @@ def test_questionnaire_response_exported(self) -> None:
# =============================================================================
-class FhirTaskExporterBMITests(FhirExportTestCase):
- def create_tasks(self) -> None:
- self.create_fhir_patient()
+class FhirTaskExporterBMITests(FhirExportPatientTestCase):
+ def setUp(self) -> None:
+ super().setUp()
- self.task = Bmi()
- self.apply_standard_task_fields(self.task)
- self.task.mass_kg = 70
- self.task.height_m = 1.8
- self.task.waist_cm = 82
- self.task.patient_id = self.patient.id
- self.task.save_with_next_available_id(
- self.req, self.patient._device_id
- )
- self.dbsession.commit()
+ self.task = BmiFactory(patient=self.patient)
def test_observations(self) -> None:
bundle = self.task.get_fhir_bundle(
self.req, self.recipient, skip_docs_if_other_content=True
)
- bundle_str = json.dumps(bundle.as_json(), indent=JSON_INDENT)
- log.debug(f"Bundle:\n{bundle_str}")
- # The test is that it doesn't crash.
+ bundle_json = bundle.as_json()
+
+ height_entry = bundle_json[Fc.ENTRY][3]
+ mass_entry = bundle_json[Fc.ENTRY][4]
+ bmi_entry = bundle_json[Fc.ENTRY][5]
+ waist_entry = bundle_json[Fc.ENTRY][6]
-class FhirTaskExporterDiagnosisIcd10Tests(FhirExportTestCase):
- def create_tasks(self) -> None:
- self.create_fhir_patient()
+ height_resource = height_entry[Fc.RESOURCE]
+ mass_resource = mass_entry[Fc.RESOURCE]
+ bmi_resource = bmi_entry[Fc.RESOURCE]
+ waist_resource = waist_entry[Fc.RESOURCE]
- self.task = DiagnosisIcd10()
- self.apply_standard_task_fields(self.task)
- self.task.patient_id = self.patient.id
- self.task.save_with_next_available_id(
- self.req, self.patient._device_id
+ self.assertEqual(
+ height_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_OBSERVATION
+ )
+ self.assertEqual(
+ height_resource[Fc.VALUE_QUANTITY][Fc.VALUE], self.task.height_m
+ )
+
+ self.assertEqual(
+ mass_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_OBSERVATION
+ )
+ self.assertAlmostEqual(
+ mass_resource[Fc.VALUE_QUANTITY][Fc.VALUE],
+ self.task.mass_kg,
+ places=2,
+ )
+
+ self.assertEqual(
+ bmi_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_OBSERVATION
+ )
+ self.assertAlmostEqual(
+ bmi_resource[Fc.VALUE_QUANTITY][Fc.VALUE],
+ self.task.bmi(),
+ places=2,
+ )
+
+ self.assertEqual(
+ waist_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_OBSERVATION
+ )
+ self.assertAlmostEqual(
+ waist_resource[Fc.VALUE_QUANTITY][Fc.VALUE],
+ self.task.waist_cm,
+ places=2,
)
- self.dbsession.commit()
+
+
+class FhirTaskExporterDiagnosisIcd10Tests(FhirExportPatientTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.task = DiagnosisIcd10Factory(patient=self.patient)
# noinspection PyArgumentList
- item1 = DiagnosisIcd10Item(
- diagnosis_icd10_id=self.task.id,
+ self.item1 = DiagnosisIcd10ItemFactory(
+ diagnosis_icd10=self.task,
seqnum=1,
code="F33.30",
description="Recurrent depressive disorder, current episode "
@@ -810,89 +822,143 @@ def create_tasks(self) -> None:
"with mood-congruent psychotic symptoms",
comment="Cotard's syndrome",
)
- self.apply_standard_db_fields(item1)
- item1.save_with_next_available_id(self.req, self.task._device_id)
# noinspection PyArgumentList
- item2 = DiagnosisIcd10Item(
- diagnosis_icd10_id=self.task.id,
+ self.item2 = DiagnosisIcd10ItemFactory(
+ diagnosis_icd10=self.task,
seqnum=2,
code="F43.1",
description="Post-traumatic stress disorder",
)
- self.apply_standard_db_fields(item2)
- item2.save_with_next_available_id(self.req, self.task._device_id)
def test_observations(self) -> None:
bundle = self.task.get_fhir_bundle(
self.req, self.recipient, skip_docs_if_other_content=True
)
- bundle_str = json.dumps(bundle.as_json(), indent=JSON_INDENT)
- log.debug(f"Bundle:\n{bundle_str}")
- # The test is that it doesn't crash.
+ bundle_json = bundle.as_json()
-class FhirTaskExporterDiagnosisIcd9CMTests(FhirExportTestCase):
- def create_tasks(self) -> None:
- self.create_fhir_patient()
+ cotard_resource = bundle_json[Fc.ENTRY][4][Fc.RESOURCE]
+ ptsd_resource = bundle_json[Fc.ENTRY][5][Fc.RESOURCE]
- self.task = DiagnosisIcd9CM()
- self.apply_standard_task_fields(self.task)
- self.task.patient_id = self.patient.id
- self.task.save_with_next_available_id(
- self.req, self.patient._device_id
+ self.assertEqual(
+ cotard_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_CONDITION
)
- self.dbsession.commit()
+ self.assertEqual(
+ cotard_resource[Fc.CODE][Fc.CODING][0][Fc.CODE], "F33.30"
+ )
+ self.assertIn(
+ "Cotard's syndrome",
+ cotard_resource[Fc.CODE][Fc.CODING][0][Fc.DISPLAY],
+ )
+ self.assertIn(
+ "Recurrent depressive",
+ cotard_resource[Fc.CODE][Fc.CODING][0][Fc.DISPLAY],
+ )
+
+ self.assertEqual(
+ ptsd_resource[Fc.CODE][Fc.CODING][0][Fc.CODE], "F43.1"
+ )
+
+
+class FhirTaskExporterDiagnosisIcd9CMTests(FhirExportPatientTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.task = DiagnosisIcd9CMFactory(patient=self.patient)
# noinspection PyArgumentList
- item1 = DiagnosisIcd9CMItem(
- diagnosis_icd9cm_id=self.task.id,
+ self.item1 = DiagnosisIcd9CMItemFactory(
+ diagnosis_icd9cm=self.task,
seqnum=1,
code="290.4",
description="Vascular dementia",
comment="or perhaps mixed dementia",
)
- self.apply_standard_db_fields(item1)
- item1.save_with_next_available_id(self.req, self.task._device_id)
# noinspection PyArgumentList
- item2 = DiagnosisIcd9CMItem(
- diagnosis_icd9cm_id=self.task.id,
+ self.item2 = DiagnosisIcd9CMItemFactory(
+ diagnosis_icd9cm=self.task,
seqnum=2,
code="303.0",
description="Acute alcoholic intoxication",
)
- self.apply_standard_db_fields(item2)
- item2.save_with_next_available_id(self.req, self.task._device_id)
def test_observations(self) -> None:
bundle = self.task.get_fhir_bundle(
self.req, self.recipient, skip_docs_if_other_content=True
)
- bundle_str = json.dumps(bundle.as_json(), indent=JSON_INDENT)
- log.debug(f"Bundle:\n{bundle_str}")
- # The test is that it doesn't crash.
+ bundle_json = bundle.as_json()
+ dementia_resource = bundle_json[Fc.ENTRY][4][Fc.RESOURCE]
+ intoxication_resource = bundle_json[Fc.ENTRY][5][Fc.RESOURCE]
+
+ self.assertEqual(
+ dementia_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_CONDITION
+ )
+ self.assertEqual(
+ dementia_resource[Fc.CODE][Fc.CODING][0][Fc.CODE], "290.4"
+ )
+ self.assertIn(
+ "Vascular dementia",
+ dementia_resource[Fc.CODE][Fc.CODING][0][Fc.DISPLAY],
+ )
+ self.assertIn(
+ "or perhaps mixed dementia",
+ dementia_resource[Fc.CODE][Fc.CODING][0][Fc.DISPLAY],
+ )
+
+ self.assertEqual(
+ intoxication_resource[Fc.CODE][Fc.CODING][0][Fc.CODE], "303.0"
+ )
-class FhirTaskExporterGad7Tests(FhirExportTestCase):
+class FhirTaskExporterGad7Tests(FhirExportPatientTestCase):
"""
The GAD7 is a standard questionnaire that we don't provide any special
- FHIR support for; we rely on autodiscovery.
+ FHIR support for; we rely on autodiscovery. This is essentially a high
+ level test for _fhir_autodiscover() in cc_task.py, albeit not a
+ particularly thorough one.
"""
- def create_tasks(self) -> None:
- self.create_fhir_patient()
+ def setUp(self) -> None:
+ super().setUp()
- self.task = Gad7()
- self.apply_standard_task_fields(self.task)
- self.task.patient_id = self.patient.id
- self.task.save_with_next_available_id(
- self.req, self.patient._device_id
+ self.task = Gad7Factory(
+ patient=self.patient,
+ q1=0,
+ q2=1,
+ q3=2,
+ q4=3,
+ q5=0,
+ q6=1,
+ q7=2,
)
- self.dbsession.commit()
def test_observations(self) -> None:
bundle = self.task.get_fhir_bundle(
self.req, self.recipient, skip_docs_if_other_content=True
)
- bundle_str = json.dumps(bundle.as_json(), indent=JSON_INDENT)
- log.critical(f"Bundle:\n{bundle_str}")
- # The test is that it doesn't crash.
+ bundle_json = bundle.as_json()
+ questions = bundle_json[Fc.ENTRY][1][Fc.RESOURCE][Fc.ITEM]
+ answers = bundle_json[Fc.ENTRY][2][Fc.RESOURCE][Fc.ITEM]
+
+ # Question text
+ self.assertIn(
+ "1. Feeling nervous, anxious or on edge", questions[0][Fc.TEXT]
+ )
+ # Comment string
+ self.assertIn(
+ "Q1, nervous/anxious/on edge (0 not at all - 3 nearly every day)",
+ questions[0][Fc.TEXT],
+ )
+
+ self.assertIn(
+ "1. Feeling nervous, anxious or on edge", answers[0][Fc.TEXT]
+ )
+ self.assertIn(
+ "Q1, nervous/anxious/on edge (0 not at all - 3 nearly every day)",
+ answers[0][Fc.TEXT],
+ )
+
+ self.assertEqual(answers[0][Fc.ANSWER][0][Fc.VALUE_INTEGER], 0)
+ self.assertEqual(answers[1][Fc.ANSWER][0][Fc.VALUE_INTEGER], 1)
+ self.assertEqual(answers[2][Fc.ANSWER][0][Fc.VALUE_INTEGER], 2)
+ self.assertEqual(answers[3][Fc.ANSWER][0][Fc.VALUE_INTEGER], 3)
diff --git a/server/camcops_server/cc_modules/tests/cc_forms_tests.py b/server/camcops_server/cc_modules/tests/cc_forms_tests.py
index 293c281b7..42ae682b0 100644
--- a/server/camcops_server/cc_modules/tests/cc_forms_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_forms_tests.py
@@ -54,33 +54,22 @@
)
from camcops_server.cc_modules.cc_ipuse import IpContexts
from camcops_server.cc_modules.cc_pyramid import ViewParam
-from camcops_server.cc_modules.cc_taskschedule import TaskSchedule
+from camcops_server.cc_modules.cc_testfactories import (
+ Fake,
+ GroupFactory,
+ TaskScheduleFactory,
+ UserFactory,
+ UserGroupMembershipFactory,
+)
from camcops_server.cc_modules.cc_unittest import (
BasicDatabaseTestCase,
- DemoDatabaseTestCase,
DemoRequestTestCase,
)
-TEST_PHONE_NUMBER = "+{ctry}{tel}".format(
- ctry=phonenumbers.PhoneMetadata.metadata_for_region("GB").country_code,
- tel=phonenumbers.PhoneMetadata.metadata_for_region(
- "GB"
- ).personal_number.example_number,
-) # see webview_tests.py
-
log = logging.getLogger(__name__)
-# =============================================================================
-# Unit tests
-# =============================================================================
-
-
class SchemaTestCase(DemoRequestTestCase):
- """
- Unit tests.
- """
-
def serialize_deserialize(
self, schema: Schema, appstruct: Dict[str, Any]
) -> None:
@@ -113,7 +102,7 @@ def test_serialize_deserialize(self) -> None:
self.serialize_deserialize(schema, appstruct)
-class TaskScheduleSchemaTests(DemoDatabaseTestCase):
+class TaskScheduleSchemaTests(BasicDatabaseTestCase):
def test_invalid_for_bad_template_placeholder(self) -> None:
schema = TaskScheduleSchema().bind(request=self.req)
cstruct = {
@@ -252,23 +241,19 @@ def test_invalid_for_negative_due_from(self) -> None:
self.assertIn("must be zero or more days", cm.exception.messages()[0])
-class TaskScheduleItemSchemaIpTests(BasicDatabaseTestCase):
- def setUp(self) -> None:
- super().setUp()
-
- self.schedule = TaskSchedule()
- self.schedule.group_id = self.group.id
- self.dbsession.add(self.schedule)
- self.dbsession.commit()
-
+class TaskScheduleItemSchemaIpTests(DemoRequestTestCase):
def test_invalid_for_commercial_mismatch(self) -> None:
- self.group.ip_use.commercial = True
- self.dbsession.add(self.group)
- self.dbsession.commit()
+ group = GroupFactory(
+ ip_use__clinical=False,
+ ip_use__commercial=True,
+ ip_use__educational=False,
+ ip_use__research=False,
+ )
+ schedule = TaskScheduleFactory(group=group)
schema = TaskScheduleItemSchema().bind(request=self.req)
appstruct = {
- ViewParam.SCHEDULE_ID: self.schedule.id,
+ ViewParam.SCHEDULE_ID: schedule.id,
ViewParam.TABLE_NAME: "mfi20",
ViewParam.CLINICIAN_CONFIRMATION: False,
ViewParam.DUE_FROM: Duration(days=0),
@@ -282,13 +267,17 @@ def test_invalid_for_commercial_mismatch(self) -> None:
self.assertIn("prohibits commercial", cm.exception.messages()[0])
def test_invalid_for_clinical_mismatch(self) -> None:
- self.group.ip_use.clinical = True
- self.dbsession.add(self.group)
- self.dbsession.commit()
+ group = GroupFactory(
+ ip_use__clinical=True,
+ ip_use__commercial=False,
+ ip_use__educational=False,
+ ip_use__research=False,
+ )
+ schedule = TaskScheduleFactory(group=group)
schema = TaskScheduleItemSchema().bind(request=self.req)
appstruct = {
- ViewParam.SCHEDULE_ID: self.schedule.id,
+ ViewParam.SCHEDULE_ID: schedule.id,
ViewParam.TABLE_NAME: "mfi20",
ViewParam.CLINICIAN_CONFIRMATION: False,
ViewParam.DUE_FROM: Duration(days=0),
@@ -302,13 +291,17 @@ def test_invalid_for_clinical_mismatch(self) -> None:
self.assertIn("prohibits clinical", cm.exception.messages()[0])
def test_invalid_for_educational_mismatch(self) -> None:
- self.group.ip_use.educational = True
- self.dbsession.add(self.group)
- self.dbsession.commit()
+ group = GroupFactory(
+ ip_use__clinical=False,
+ ip_use__commercial=False,
+ ip_use__educational=True,
+ ip_use__research=False,
+ )
+ schedule = TaskScheduleFactory(group=group)
schema = TaskScheduleItemSchema().bind(request=self.req)
appstruct = {
- ViewParam.SCHEDULE_ID: self.schedule.id,
+ ViewParam.SCHEDULE_ID: schedule.id,
ViewParam.TABLE_NAME: "mfi20",
ViewParam.CLINICIAN_CONFIRMATION: True,
ViewParam.DUE_FROM: Duration(days=0),
@@ -328,13 +321,17 @@ def test_invalid_for_educational_mismatch(self) -> None:
self.assertIn("prohibits educational", cm.exception.messages()[0])
def test_invalid_for_research_mismatch(self) -> None:
- self.group.ip_use.research = True
- self.dbsession.add(self.group)
- self.dbsession.commit()
+ group = GroupFactory(
+ ip_use__clinical=False,
+ ip_use__commercial=False,
+ ip_use__educational=False,
+ ip_use__research=True,
+ )
+ schedule = TaskScheduleFactory(group=group)
schema = TaskScheduleItemSchema().bind(request=self.req)
appstruct = {
- ViewParam.SCHEDULE_ID: self.schedule.id,
+ ViewParam.SCHEDULE_ID: schedule.id,
ViewParam.TABLE_NAME: "moca",
ViewParam.CLINICIAN_CONFIRMATION: True,
ViewParam.DUE_FROM: Duration(days=0),
@@ -348,13 +345,12 @@ def test_invalid_for_research_mismatch(self) -> None:
self.assertIn("prohibits research", cm.exception.messages()[0])
def test_invalid_for_missing_ip_use(self) -> None:
- self.group.ip_use = None
- self.dbsession.add(self.group)
- self.dbsession.commit()
+ group = GroupFactory(ip_use=None)
+ schedule = TaskScheduleFactory(group=group)
schema = TaskScheduleItemSchema().bind(request=self.req)
appstruct = {
- ViewParam.SCHEDULE_ID: self.schedule.id,
+ ViewParam.SCHEDULE_ID: schedule.id,
ViewParam.TABLE_NAME: "moca",
ViewParam.CLINICIAN_CONFIRMATION: True,
ViewParam.DUE_FROM: Duration(days=0),
@@ -366,7 +362,7 @@ def test_invalid_for_missing_ip_use(self) -> None:
schema.deserialize(cstruct)
self.assertIn(
- f"The group '{self.group.name}' has no intellectual property "
+ f"The group '{group.name}' has no intellectual property "
f"settings",
cm.exception.messages()[0],
)
@@ -713,25 +709,17 @@ def test_deserialize_not_a_json_object_fails_validation(self) -> None:
self.assertEqual(cm.exception.value, "[{}]")
-class TaskScheduleSelectorTests(BasicDatabaseTestCase):
+class TaskScheduleSelectorTests(DemoRequestTestCase):
def test_displays_only_users_schedules(self) -> None:
- user = self.create_user(username="regular_user")
- my_group = self.create_group("mygroup")
- not_my_group = self.create_group("notmygroup")
- self.dbsession.flush()
-
- self.create_membership(user, my_group, may_manage_patients=True)
-
- my_schedule = TaskSchedule()
- my_schedule.group_id = my_group.id
- my_schedule.name = "My group's schedule"
- self.dbsession.add(my_schedule)
+ user = UserFactory()
+ my_group = GroupFactory()
+ not_my_group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=my_group.id, may_manage_patients=True
+ )
- not_my_schedule = TaskSchedule()
- not_my_schedule.group_id = not_my_group.id
- not_my_schedule.name = "Not my group's schedule"
- self.dbsession.add(not_my_schedule)
- self.dbsession.commit()
+ my_schedule = TaskScheduleFactory(group=my_group)
+ not_my_schedule = TaskScheduleFactory(group=not_my_group)
self.req._debugging_user = user
@@ -994,9 +982,8 @@ def test_raises_for_invalid_parsable_number(self) -> None:
self.assertIn("Invalid phone number", cm.exception.messages()[0])
def test_returns_valid_phone_number(self) -> None:
- phone_number = self.phone_type.deserialize(
- self.node, TEST_PHONE_NUMBER
- )
+ phone_number_str = Fake.en_gb.valid_phone_number()
+ phone_number = self.phone_type.deserialize(self.node, phone_number_str)
self.assertIsInstance(phone_number, phonenumbers.PhoneNumber)
@@ -1004,7 +991,7 @@ def test_returns_valid_phone_number(self) -> None:
phonenumbers.format_number(
phone_number, phonenumbers.PhoneNumberFormat.E164
),
- TEST_PHONE_NUMBER,
+ phone_number_str,
)
@@ -1013,11 +1000,12 @@ def test_returns_null_for_appstruct_none(self) -> None:
self.assertIs(self.phone_type.serialize(self.node, None), null)
def test_returns_number_formatted_e164(self) -> None:
- phone_number = phonenumbers.parse(TEST_PHONE_NUMBER)
+ phone_number_str = Fake.en_gb.valid_phone_number()
+ phone_number = phonenumbers.parse(phone_number_str)
self.assertEqual(
self.phone_type.serialize(self.node, phone_number),
- TEST_PHONE_NUMBER,
+ phone_number_str,
)
diff --git a/server/camcops_server/cc_modules/tests/cc_patient_tests.py b/server/camcops_server/cc_modules/tests/cc_patient_tests.py
index eebbb491b..e543171b0 100644
--- a/server/camcops_server/cc_modules/tests/cc_patient_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_patient_tests.py
@@ -28,8 +28,8 @@
import hl7
import pendulum
+from camcops_server.cc_modules.cc_group import Group
from camcops_server.cc_modules.cc_simpleobjects import BarePatientInfo
-from camcops_server.cc_modules.cc_patient import Patient
from camcops_server.cc_modules.cc_patientidnum import PatientIdNum
from camcops_server.cc_modules.cc_simpleobjects import IdNumReference
from camcops_server.cc_modules.cc_taskschedule import (
@@ -38,9 +38,21 @@
TaskScheduleItem,
)
from camcops_server.cc_modules.cc_spreadsheet import SpreadsheetPage
+from camcops_server.cc_modules.cc_testfactories import (
+ GroupFactory,
+ NHSPatientIdNumFactory,
+ PatientFactory,
+ PatientTaskScheduleFactory,
+ RioPatientIdNumFactory,
+ ServerCreatedPatientFactory,
+ TaskScheduleFactory,
+ TaskScheduleItemFactory,
+ UserFactory,
+ UserGroupMembershipFactory,
+)
from camcops_server.cc_modules.cc_unittest import (
BasicDatabaseTestCase,
- DemoDatabaseTestCase,
+ DemoRequestTestCase,
)
from camcops_server.cc_modules.cc_xml import XmlElement
@@ -50,26 +62,30 @@
# =============================================================================
-class PatientTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
-
+class PatientTests(DemoRequestTestCase):
def test_patient(self) -> None:
- self.announce("test_patient")
- from camcops_server.cc_modules.cc_group import Group
-
req = self.req
- q = self.dbsession.query(Patient)
- p = q.first() # type: Patient
- assert p, "Missing Patient in demo database!"
+ req._debugging_user = UserFactory()
- for pidnum in p.get_idnum_objects():
+ p = PatientFactory()
+ nhs_idnum = NHSPatientIdNumFactory(patient=p)
+ RioPatientIdNumFactory(patient=p)
+
+ idnum_objects = p.get_idnum_objects()
+ self.assertEqual(len(idnum_objects), 2)
+ for pidnum in idnum_objects:
self.assertIsInstance(pidnum, PatientIdNum)
- for idref in p.get_idnum_references():
+
+ idnum_references = p.get_idnum_references()
+ self.assertEqual(len(idnum_references), 2)
+ for idref in idnum_references:
self.assertIsInstance(idref, IdNumReference)
- for idnum in p.get_idnum_raw_values_only():
+
+ idnum_raw_values = p.get_idnum_raw_values_only()
+ self.assertEqual(len(idnum_raw_values), 2)
+ for idnum in idnum_raw_values:
self.assertIsInstance(idnum, int)
+
self.assertIsInstance(p.get_xml_root(req), XmlElement)
self.assertIsInstance(p.get_spreadsheet_page(req), SpreadsheetPage)
self.assertIsInstance(p.get_bare_ptinfo(), BarePatientInfo)
@@ -99,104 +115,82 @@ def test_patient(self) -> None:
p.get_hl7_pid_segment(req, self.recipdef), hl7.Segment
)
self.assertIsInstanceOrNone(
- p.get_idnum_object(which_idnum=1), PatientIdNum
+ p.get_idnum_object(which_idnum=nhs_idnum.which_idnum), PatientIdNum
+ )
+ self.assertIsInstanceOrNone(
+ p.get_idnum_value(which_idnum=nhs_idnum.which_idnum), int
+ )
+ self.assertIsInstance(
+ p.get_iddesc(req, which_idnum=nhs_idnum.which_idnum), str
+ )
+ self.assertIsInstance(
+ p.get_idshortdesc(req, which_idnum=nhs_idnum.which_idnum), str
)
- self.assertIsInstanceOrNone(p.get_idnum_value(which_idnum=1), int)
- self.assertIsInstance(p.get_iddesc(req, which_idnum=1), str)
- self.assertIsInstance(p.get_idshortdesc(req, which_idnum=1), str)
self.assertIsInstance(p.is_preserved(), bool)
self.assertIsInstance(p.is_finalized(), bool)
self.assertIsInstance(p.user_may_edit(req), bool)
def test_surname_forename_upper(self) -> None:
- patient = Patient()
- patient.forename = "Forename"
- patient.surname = "Surname"
-
+ patient = PatientFactory(forename="Forename", surname="Surname")
self.assertEqual(
patient.get_surname_forename_upper(), "SURNAME, FORENAME"
)
def test_surname_forename_upper_no_forename(self) -> None:
- patient = Patient()
- patient.surname = "Surname"
-
+ patient = PatientFactory(forename=None, surname="Surname")
self.assertEqual(
patient.get_surname_forename_upper(), "SURNAME, (UNKNOWN)"
)
def test_surname_forename_upper_no_surname(self) -> None:
- patient = Patient()
- patient.forename = "Forename"
-
+ patient = PatientFactory(forename="Forename", surname=None)
self.assertEqual(
patient.get_surname_forename_upper(), "(UNKNOWN), FORENAME"
)
-class LineageTests(DemoDatabaseTestCase):
- def create_tasks(self) -> None:
- # Actually not creating any tasks but we don't want the patients
- # created by default in the baseclass
-
- # First record for patient 1
- self.set_era("2020-01-01")
-
- self.patient_1 = Patient()
- self.patient_1.id = 1
- self.apply_standard_db_fields(self.patient_1)
- self.dbsession.add(self.patient_1)
-
- # First ID number record for patient 1
- self.patient_idnum_1_1 = PatientIdNum()
- self.patient_idnum_1_1.id = 3
- self.apply_standard_db_fields(self.patient_idnum_1_1)
- self.patient_idnum_1_1.patient_id = 1
- self.patient_idnum_1_1.which_idnum = self.nhs_iddef.which_idnum
- self.patient_idnum_1_1.idnum_value = 555
- self.dbsession.add(self.patient_idnum_1_1)
-
- # Second ID number record for patient 1
- self.patient_idnum_1_2 = PatientIdNum()
- self.patient_idnum_1_2.id = 3
- self.apply_standard_db_fields(self.patient_idnum_1_2)
- # This one is not current
- self.patient_idnum_1_2._current = False
- self.patient_idnum_1_2.patient_id = 1
- self.patient_idnum_1_2.which_idnum = self.nhs_iddef.which_idnum
- self.patient_idnum_1_2.idnum_value = 555
- self.dbsession.add(self.patient_idnum_1_2)
+class LineageTests(DemoRequestTestCase):
+ def setUp(self) -> None:
+ super().setUp()
- self.dbsession.commit()
+ self.patient = PatientFactory()
+ self.current_patient_idnum = NHSPatientIdNumFactory(
+ patient=self.patient
+ )
+ self.assertTrue(self.current_patient_idnum._current)
+
+ self.not_current_patient_idnum = NHSPatientIdNumFactory(
+ patient=self.patient,
+ _current=False,
+ id=self.current_patient_idnum.id,
+ which_idnum=self.current_patient_idnum.which_idnum,
+ idnum_value=self.current_patient_idnum.idnum_value,
+ )
+ self.assertFalse(self.not_current_patient_idnum._current)
def test_gen_patient_idnums_even_noncurrent(self) -> None:
- idnums = list(self.patient_1.gen_patient_idnums_even_noncurrent())
+ idnums = list(self.patient.gen_patient_idnums_even_noncurrent())
self.assertEqual(len(idnums), 2)
-class PatientDeleteTests(DemoDatabaseTestCase):
+class PatientDeleteTests(DemoRequestTestCase):
def test_deletes_patient_task_schedule(self) -> None:
- schedule = TaskSchedule()
- schedule.group_id = self.group.id
- self.dbsession.add(schedule)
- self.dbsession.flush()
-
- item = TaskScheduleItem()
- item.schedule_id = schedule.id
- item.task_table_name = "ace3"
- item.due_from = pendulum.Duration(days=30)
- item.due_by = pendulum.Duration(days=60)
- self.dbsession.add(item)
- self.dbsession.flush()
-
- patient = self.create_patient()
-
- pts = PatientTaskSchedule()
- pts.schedule_id = schedule.id
- pts.patient_pk = patient.pk
- self.dbsession.add(pts)
- self.dbsession.commit()
+ schedule = TaskScheduleFactory()
+
+ item = TaskScheduleItemFactory(
+ task_schedule=schedule,
+ task_table_name="ace3",
+ due_from=pendulum.Duration(days=30),
+ due_by=pendulum.Duration(days=60),
+ )
+
+ patient = ServerCreatedPatientFactory()
+
+ pts = PatientTaskScheduleFactory(
+ task_schedule=schedule,
+ patient=patient,
+ )
self.assertIsNotNone(
self.dbsession.query(TaskSchedule)
@@ -236,60 +230,52 @@ def test_deletes_patient_task_schedule(self) -> None:
class PatientPermissionTests(BasicDatabaseTestCase):
- def test_group_administrator_may_edit_server_created(self) -> None:
- user = self.create_user(username="testuser")
- self.dbsession.flush()
+ def setUp(self) -> None:
+ super().setUp()
- patient = self.create_patient(
- _group=self.group, as_server_patient=True
- )
+ self.user = UserFactory()
+ self.group = GroupFactory()
- self.create_membership(user, self.group, groupadmin=True)
- self.dbsession.commit()
+ def test_group_administrator_may_edit_server_patient(self) -> None:
+ patient = ServerCreatedPatientFactory(_group=self.group)
+ ugm = UserGroupMembershipFactory(
+ user_id=self.user.id, group_id=self.group.id, groupadmin=True
+ )
- self.req._debugging_user = user
+ self.req._debugging_user = ugm.user
self.assertTrue(patient.user_may_edit(self.req))
- def test_group_administrator_may_edit_finalized(self) -> None:
- user = self.create_user(username="testuser")
- self.dbsession.flush()
-
- patient = self.create_patient(
- _group=self.group, as_server_patient=False
+ def test_group_administrator_may_edit_finalized_patient(self) -> None:
+ patient = PatientFactory(_group=self.group)
+ ugm = UserGroupMembershipFactory(
+ user_id=self.user.id, group_id=self.group.id, groupadmin=True
)
- self.create_membership(user, self.group, groupadmin=True)
- self.dbsession.commit()
+ self.assertTrue(ugm.groupadmin)
- self.req._debugging_user = user
+ self.req._debugging_user = ugm.user
self.assertTrue(patient.user_may_edit(self.req))
def test_group_member_with_permission_may_edit_server_created(
self,
) -> None:
- user = self.create_user(username="testuser")
- self.dbsession.flush()
-
- patient = self.create_patient(
- _group=self.group, as_server_patient=True
+ patient = ServerCreatedPatientFactory(_group=self.group)
+ ugm = UserGroupMembershipFactory(
+ user_id=self.user.id,
+ group_id=self.group.id,
+ may_manage_patients=True,
)
- self.create_membership(user, self.group, may_manage_patients=True)
- self.dbsession.commit()
-
- self.req._debugging_user = user
+ self.req._debugging_user = ugm.user
self.assertTrue(patient.user_may_edit(self.req))
def test_group_member_with_permission_may_not_edit_finalized(self) -> None:
- user = self.create_user(username="testuser")
- self.dbsession.flush()
-
- patient = self.create_patient(
- _group=self.group, as_server_patient=False
+ patient = PatientFactory(_group=self.group)
+ ugm = UserGroupMembershipFactory(
+ user_id=self.user.id,
+ group_id=self.group.id,
+ may_manage_patients=True,
)
- self.create_membership(user, self.group, may_manage_patients=True)
- self.dbsession.commit()
-
- self.req._debugging_user = user
+ self.req._debugging_user = ugm.user
self.assertFalse(patient.user_may_edit(self.req))
diff --git a/server/camcops_server/cc_modules/tests/cc_pyramid_tests.py b/server/camcops_server/cc_modules/tests/cc_pyramid_tests.py
index f6ccd8e98..d759a9f2a 100644
--- a/server/camcops_server/cc_modules/tests/cc_pyramid_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_pyramid_tests.py
@@ -32,13 +32,15 @@
CamcopsAuthenticationPolicy,
Permission,
)
-from camcops_server.cc_modules.cc_unittest import BasicDatabaseTestCase
-
+from camcops_server.cc_modules.cc_testfactories import (
+ GroupFactory,
+ UserFactory,
+ UserGroupMembershipFactory,
+)
+from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase
-class CamcopsAuthenticationPolicyTests(BasicDatabaseTestCase):
- def setUp(self) -> None:
- super().setUp()
+class CamcopsAuthenticationPolicyTests(DemoRequestTestCase):
def test_principals_for_no_user(self) -> None:
self.req._debugging_user = None
self.assertEqual(
@@ -47,10 +49,7 @@ def test_principals_for_no_user(self) -> None:
)
def test_principals_for_authenticated_user(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
-
- self.req._debugging_user = user
+ user = self.req._debugging_user = UserFactory()
self.assertIn(
Authenticated,
CamcopsAuthenticationPolicy.effective_principals(self.req),
@@ -61,23 +60,29 @@ def test_principals_for_authenticated_user(self) -> None:
)
def test_principals_when_user_must_change_pasword(self) -> None:
- user = self.create_user(username="test", must_change_password=True)
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
+ user = self.req._debugging_user = UserFactory(
+ when_agreed_terms_of_use=self.req.now,
+ must_change_password=True,
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
+ )
- self.req._debugging_user = user
self.assertIn(
Permission.MUST_CHANGE_PASSWORD,
CamcopsAuthenticationPolicy.effective_principals(self.req),
)
def test_principals_when_user_must_set_up_mfa(self) -> None:
- user = self.create_user(username="test", mfa_method=MfaMethod.NO_MFA)
- user.agree_terms(self.req)
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
+ user = self.req._debugging_user = UserFactory(
+ mfa_method=MfaMethod.NO_MFA, when_agreed_terms_of_use=self.req.now
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
+ )
- self.req._debugging_user = user
self.req.config.mfa_methods = [MfaMethod.HOTP_EMAIL]
self.assertIn(
Permission.MUST_SET_MFA,
@@ -85,23 +90,28 @@ def test_principals_when_user_must_set_up_mfa(self) -> None:
)
def test_principals_when_user_must_agree_terms(self) -> None:
- user = self.create_user(username="test", when_agreed_terms_of_use=None)
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
+ user = self.req._debugging_user = UserFactory(
+ when_agreed_terms_of_use=None
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
+ )
- self.req._debugging_user = user
self.assertIn(
Permission.MUST_AGREE_TERMS,
CamcopsAuthenticationPolicy.effective_principals(self.req),
)
def test_principals_when_everything_ok(self) -> None:
- user = self.create_user(username="test", mfa_method=MfaMethod.NO_MFA)
- user.agree_terms(self.req)
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
+ user = self.req._debugging_user = UserFactory(
+ mfa_method=MfaMethod.NO_MFA, when_agreed_terms_of_use=self.req.now
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
+ )
- self.req._debugging_user = user
self.req.config.mfa_methods = [MfaMethod.NO_MFA]
self.assertIn(
Permission.HAPPY,
@@ -109,19 +119,19 @@ def test_principals_when_everything_ok(self) -> None:
)
def test_principals_for_superuser(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ self.req._debugging_user = UserFactory(superuser=True)
- self.req._debugging_user = user
self.assertIn(
Permission.SUPERUSER,
CamcopsAuthenticationPolicy.effective_principals(self.req),
)
def test_principals_for_groupadmin(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
- self.create_membership(user, self.group, groupadmin=True)
+ user = self.req._debugging_user = UserFactory()
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, groupadmin=True
+ )
self.req._debugging_user = user
self.assertIn(
diff --git a/server/camcops_server/cc_modules/tests/cc_redcap_tests.py b/server/camcops_server/cc_modules/tests/cc_redcap_tests.py
index e31229a33..c9d30b9c9 100644
--- a/server/camcops_server/cc_modules/tests/cc_redcap_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_redcap_tests.py
@@ -25,7 +25,7 @@
import os
import tempfile
-from typing import Generator, TYPE_CHECKING
+from typing import Any, Dict, Generator
from unittest import mock, TestCase
from pandas import DataFrame
@@ -33,6 +33,10 @@
import redcap
from camcops_server.cc_modules.cc_constants import ConfigParamExportRecipient
+from camcops_server.cc_modules.cc_exportmodels import (
+ ExportedTask,
+ ExportedTaskRedcap,
+)
from camcops_server.cc_modules.cc_exportrecipient import ExportRecipient
from camcops_server.cc_modules.cc_exportrecipientinfo import (
ExportRecipientInfo,
@@ -45,11 +49,17 @@
RedcapRecordStatus,
RedcapTaskExporter,
)
-from camcops_server.cc_modules.cc_unittest import BasicDatabaseTestCase
-
-if TYPE_CHECKING:
- from camcops_server.cc_modules.cc_patient import Patient
-
+from camcops_server.cc_modules.cc_testfactories import (
+ NHSPatientIdNumFactory,
+ PatientFactory,
+)
+from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase
+from camcops_server.tasks.tests.factories import (
+ APEQCPFTPerinatalFactory,
+ BmiFactory,
+ KhandakerMojoMedicationTherapyFactory,
+ Phq9Factory,
+)
# =============================================================================
# Unit testing
@@ -124,7 +134,7 @@ def test_raises_when_fieldmap_has_unknown_symbols(self) -> None:
task = mock.Mock(tablename="bmi")
fieldmap = {"pa_height": "sys.platform"}
- field_dict = {}
+ field_dict: Dict[str, Any] = {}
with self.assertRaises(RedcapExportException) as cm:
exporter.transform_fields(field_dict, task, fieldmap)
@@ -172,7 +182,7 @@ def test_raises_when_error_from_redcap_on_import(self) -> None:
)
with self.assertRaises(RedcapExportException) as cm:
- record = {}
+ record: Dict[str, Any] = {}
exporter.upload_record(record)
message = str(cm.exception)
@@ -468,14 +478,19 @@ def test_raises_when_fields_missing_attributes(self) -> None:
# =============================================================================
-class RedcapExportTestCase(BasicDatabaseTestCase):
+class RedcapExportTestCase(DemoRequestTestCase):
fieldmap = ""
def setUp(self) -> None:
+ super().setUp()
+
+ self.patient = PatientFactory()
+ self.patient_idnum = NHSPatientIdNumFactory(patient=self.patient)
+
recipientinfo = ExportRecipientInfo()
self.recipient = ExportRecipient(recipientinfo)
- self.recipient.primary_idnum = 1001
+ self.recipient.primary_idnum = self.patient_idnum.which_idnum
# auto increment doesn't work for BigInteger with SQLite
self.recipient.id = 1
@@ -485,58 +500,12 @@ def setUp(self) -> None:
)
self.write_fieldmaps(self.recipient.redcap_fieldmap_filename)
- super().setUp()
-
def write_fieldmaps(self, filename: str) -> None:
with open(filename, "w") as f:
f.write(self.fieldmap)
- def create_patient_with_idnum_1001(self) -> "Patient":
- from camcops_server.cc_modules.cc_idnumdef import IdNumDefinition
- from camcops_server.cc_modules.cc_patient import Patient
- from camcops_server.cc_modules.cc_patientidnum import PatientIdNum
-
- patient = Patient()
- patient.id = 2
- self.apply_standard_db_fields(patient)
- patient.forename = "Forename2"
- patient.surname = "Surname2"
- patient.dob = pendulum.parse("1975-12-12")
- self.dbsession.add(patient)
-
- idnumdef_1001 = IdNumDefinition()
- idnumdef_1001.which_idnum = 1001
- idnumdef_1001.description = "Test idnumdef 1001"
- self.dbsession.add(idnumdef_1001)
- self.dbsession.commit()
-
- patient_idnum1 = PatientIdNum()
- patient_idnum1.id = 3
- self.apply_standard_db_fields(patient_idnum1)
- patient_idnum1.patient_id = patient.id
- patient_idnum1.which_idnum = 1001
- patient_idnum1.idnum_value = 555
- self.dbsession.add(patient_idnum1)
- self.dbsession.commit()
-
- return patient
-
-
-class BmiRedcapExportTestCase(RedcapExportTestCase):
- def __init__(self, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
- self.id_sequence = self.get_id()
- @staticmethod
- def get_id() -> Generator[int, None, None]:
- i = 1
-
- while True:
- yield i
- i += 1
-
-
-class BmiRedcapValidFieldmapTestCase(BmiRedcapExportTestCase):
+class BmiRedcapValidFieldmapTestCase(RedcapExportTestCase):
fieldmap = """
@@ -559,25 +528,17 @@ class BmiRedcapExportTests(BmiRedcapValidFieldmapTestCase):
related to the BMI task
"""
- def create_tasks(self) -> None:
- from camcops_server.tasks.bmi import Bmi
-
- patient = self.create_patient_with_idnum_1001()
- self.task = Bmi()
- self.apply_standard_task_fields(self.task)
- self.task.id = next(self.id_sequence)
- self.task.height_m = 1.83
- self.task.mass_kg = 67.57
- self.task.patient_id = patient.id
- self.dbsession.add(self.task)
- self.dbsession.commit()
+ def setUp(self) -> None:
+ super().setUp()
- def test_record_exported(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
+ self.task = BmiFactory(
+ patient=self.patient,
+ height_m=1.83,
+ mass_kg=67.57,
+ when_created=pendulum.parse("2010-07-07"),
)
+ def test_record_exported(self) -> None:
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -625,14 +586,9 @@ def test_record_exported(self) -> None:
rows = args[0]
record = rows[0]
- self.assertEqual(record["patient_id"], 555)
+ self.assertEqual(record["patient_id"], self.patient_idnum.idnum_value)
def test_record_exported_with_non_integer_id(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -648,11 +604,6 @@ def test_record_exported_with_non_integer_id(self) -> None:
self.assertEqual(exported_task_redcap.redcap_record_id, "15-123")
def test_record_id_generated_when_no_autonumbering(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -678,11 +629,6 @@ def test_record_id_generated_when_no_autonumbering(self) -> None:
self.assertFalse(kwargs["force_auto_number"])
def test_record_imported_when_no_existing_records(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exporter = MockRedcapTaskExporter()
project = exporter.get_project()
project.export_records.return_value = DataFrame()
@@ -701,33 +647,17 @@ def test_record_imported_when_no_existing_records(self) -> None:
class BmiRedcapUpdateTests(BmiRedcapValidFieldmapTestCase):
- def create_tasks(self) -> None:
- from camcops_server.tasks.bmi import Bmi
-
- patient = self.create_patient_with_idnum_1001()
- self.task1 = Bmi()
- self.apply_standard_task_fields(self.task1)
- self.task1.id = next(self.id_sequence)
- self.task1.height_m = 1.83
- self.task1.mass_kg = 67.57
- self.task1.patient_id = patient.id
- self.dbsession.add(self.task1)
-
- self.task2 = Bmi()
- self.apply_standard_task_fields(self.task2)
- self.task2.id = next(self.id_sequence)
- self.task2.height_m = 1.83
- self.task2.mass_kg = 68.5
- self.task2.patient_id = patient.id
- self.dbsession.add(self.task2)
- self.dbsession.commit()
+ def setUp(self) -> None:
+ super().setUp()
- def test_existing_record_id_used_for_update(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
+ self.task1 = BmiFactory(
+ patient=self.patient,
+ )
+ self.task2 = BmiFactory(
+ patient=self.patient,
)
+ def test_existing_record_id_used_for_update(self) -> None:
exporter = MockRedcapTaskExporter()
project = exporter.get_project()
project.export_records.return_value = DataFrame({"patient_id": []})
@@ -748,7 +678,7 @@ def test_existing_record_id_used_for_update(self) -> None:
project.export_records.return_value = DataFrame(
{
"record_id": ["123"],
- "patient_id": [555],
+ "patient_id": [self.patient_idnum.idnum_value],
"redcap_repeat_instrument": ["bmi"],
"redcap_repeat_instance": [1],
}
@@ -810,43 +740,26 @@ class Phq9RedcapExportTests(RedcapExportTestCase):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
- self.id_sequence = self.get_id()
-
- @staticmethod
- def get_id() -> Generator[int, None, None]:
- i = 1
- while True:
- yield i
- i += 1
-
- def create_tasks(self) -> None:
- from camcops_server.tasks.phq9 import Phq9
-
- patient = self.create_patient_with_idnum_1001()
- self.task = Phq9()
- self.apply_standard_task_fields(self.task)
- self.task.id = next(self.id_sequence)
- self.task.q1 = 0
- self.task.q2 = 1
- self.task.q3 = 2
- self.task.q4 = 3
- self.task.q5 = 0
- self.task.q6 = 1
- self.task.q7 = 2
- self.task.q8 = 3
- self.task.q9 = 0
- self.task.q10 = 3
- self.task.patient_id = patient.id
- self.dbsession.add(self.task)
- self.dbsession.commit()
+ def setUp(self) -> None:
+ super().setUp()
- def test_record_exported(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
+ self.task = Phq9Factory(
+ patient=self.patient,
+ q1=0,
+ q2=1,
+ q3=2,
+ q4=3,
+ q5=0,
+ q6=1,
+ q7=2,
+ q8=3,
+ q9=0,
+ q10=3,
+ when_created=pendulum.parse("2010-07-07"),
)
+ def test_record_exported(self) -> None:
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -883,8 +796,8 @@ def test_record_exported(self) -> None:
)
self.assertEqual(record["phq9_how_difficult"], 4)
self.assertEqual(record["phq9_total_score"], 12)
- self.assertEqual(record["phq9_first_name"], "Forename2")
- self.assertEqual(record["phq9_last_name"], "Surname2")
+ self.assertEqual(record["phq9_first_name"], self.patient.forename)
+ self.assertEqual(record["phq9_last_name"], self.patient.surname)
self.assertEqual(record["phq9_date_enrolled"], "2010-07-07")
self.assertEqual(record["phq9_1"], 0)
@@ -905,7 +818,7 @@ def test_record_exported(self) -> None:
rows = args[0]
record = rows[0]
- self.assertEqual(record["patient_id"], 555)
+ self.assertEqual(record["patient_id"], self.patient_idnum.idnum_value)
class MedicationTherapyRedcapExportTests(RedcapExportTestCase):
@@ -940,25 +853,14 @@ def get_id() -> Generator[int, None, None]:
yield i
i += 1
- def create_tasks(self) -> None:
- from camcops_server.tasks.khandaker_mojo_medicationtherapy import (
- KhandakerMojoMedicationTherapy,
- )
-
- patient = self.create_patient_with_idnum_1001()
- self.task = KhandakerMojoMedicationTherapy()
- self.apply_standard_task_fields(self.task)
- self.task.id = next(self.id_sequence)
- self.task.patient_id = patient.id
- self.dbsession.add(self.task)
- self.dbsession.commit()
+ def setUp(self) -> None:
+ super().setUp()
- def test_record_exported(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
+ self.task = KhandakerMojoMedicationTherapyFactory(
+ patient=self.patient,
)
+ def test_record_exported(self) -> None:
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -1030,46 +932,17 @@ class MultipleTaskRedcapExportTests(RedcapExportTestCase):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
- self.id_sequence = self.get_id()
-
- @staticmethod
- def get_id() -> Generator[int, None, None]:
- i = 1
- while True:
- yield i
- i += 1
+ def setUp(self) -> None:
+ super().setUp()
- def create_tasks(self) -> None:
- from camcops_server.tasks.khandaker_mojo_medicationtherapy import (
- KhandakerMojoMedicationTherapy,
+ self.mojo_task = KhandakerMojoMedicationTherapyFactory(
+ patient=self.patient
)
- patient = self.create_patient_with_idnum_1001()
- self.mojo_task = KhandakerMojoMedicationTherapy()
- self.apply_standard_task_fields(self.mojo_task)
- self.mojo_task.id = next(self.id_sequence)
- self.mojo_task.patient_id = patient.id
- self.dbsession.add(self.mojo_task)
- self.dbsession.commit()
-
- from camcops_server.tasks.bmi import Bmi
-
- self.bmi_task = Bmi()
- self.apply_standard_task_fields(self.bmi_task)
- self.bmi_task.id = next(self.id_sequence)
- self.bmi_task.height_m = 1.83
- self.bmi_task.mass_kg = 67.57
- self.bmi_task.patient_id = patient.id
- self.dbsession.add(self.bmi_task)
- self.dbsession.commit()
+ self.bmi_task = BmiFactory(patient=self.patient)
def test_instance_ids_on_different_tasks_in_same_record(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exporter = MockRedcapTaskExporter()
project = exporter.get_project()
project.export_records.return_value = DataFrame({"patient_id": []})
@@ -1091,7 +964,7 @@ def test_instance_ids_on_different_tasks_in_same_record(self) -> None:
project.export_records.return_value = DataFrame(
{
"record_id": ["123"],
- "patient_id": [555],
+ "patient_id": [self.patient_idnum.idnum_value],
"redcap_repeat_instrument": [
"khandaker_mojo_medicationtherapy"
],
@@ -1115,11 +988,6 @@ def test_instance_ids_on_different_tasks_in_same_record(self) -> None:
self.assertEqual(record["redcap_repeat_instance"], 1)
def test_imported_into_different_events(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exporter = MockRedcapTaskExporter()
project = exporter.get_project()
@@ -1163,28 +1031,11 @@ def test_imported_into_different_events(self) -> None:
class BadConfigurationRedcapTests(RedcapExportTestCase):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
- self.id_sequence = self.get_id()
- @staticmethod
- def get_id() -> Generator[int, None, None]:
- i = 1
-
- while True:
- yield i
- i += 1
-
- def create_tasks(self) -> None:
- from camcops_server.tasks.bmi import Bmi
+ def setUp(self) -> None:
+ super().setUp()
- patient = self.create_patient_with_idnum_1001()
- self.task = Bmi()
- self.apply_standard_task_fields(self.task)
- self.task.id = next(self.id_sequence)
- self.task.height_m = 1.83
- self.task.mass_kg = 67.57
- self.task.patient_id = patient.id
- self.dbsession.add(self.task)
- self.dbsession.commit()
+ self.task = BmiFactory(patient=self.patient)
class MissingInstrumentRedcapTests(BadConfigurationRedcapTests):
@@ -1201,11 +1052,6 @@ class MissingInstrumentRedcapTests(BadConfigurationRedcapTests):
"""
def test_raises_when_instrument_missing_from_fieldmap(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -1237,11 +1083,6 @@ class IncorrectRecordIdRedcapTests(BadConfigurationRedcapTests):
"""
def test_raises_when_record_id_is_incorrect(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -1250,7 +1091,7 @@ def test_raises_when_record_id_is_incorrect(self) -> None:
project.export_records.return_value = DataFrame(
{
"record_id": ["123"],
- "patient_id": [555],
+ "patient_id": [self.patient_idnum.idnum_value],
"redcap_repeat_instrument": ["bmi"],
"redcap_repeat_instance": [1],
}
@@ -1281,11 +1122,6 @@ class IncorrectPatientIdRedcapTests(BadConfigurationRedcapTests):
"""
def test_raises_when_patient_id_is_incorrect(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -1294,7 +1130,7 @@ def test_raises_when_patient_id_is_incorrect(self) -> None:
project.export_records.return_value = DataFrame(
{
"record_id": ["123"],
- "patient_id": [555],
+ "patient_id": [self.patient_idnum.idnum_value],
"redcap_repeat_instrument": ["bmi"],
"redcap_repeat_instance": [1],
}
@@ -1327,11 +1163,6 @@ class MissingPatientInstrumentRedcapTests(BadConfigurationRedcapTests):
"""
def test_raises_when_instrument_is_missing(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -1362,11 +1193,6 @@ class MissingEventRedcapTests(BadConfigurationRedcapTests):
"""
def test_raises_for_longitudinal_project(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -1400,11 +1226,6 @@ class MissingInstrumentEventRedcapTests(BadConfigurationRedcapTests):
"""
def test_raises_when_instrument_missing_event(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
@@ -1421,21 +1242,12 @@ def test_raises_when_instrument_missing_event(self) -> None:
class AnonymousTaskRedcapTests(RedcapExportTestCase):
- def create_tasks(self) -> None:
- from camcops_server.tasks.apeq_cpft_perinatal import APEQCPFTPerinatal
+ def setUp(self) -> None:
+ super().setUp()
- self.task = APEQCPFTPerinatal()
- self.apply_standard_task_fields(self.task)
- self.task.id = 1
- self.dbsession.add(self.task)
- self.dbsession.commit()
+ self.task = APEQCPFTPerinatalFactory()
def test_raises_when_task_is_anonymous(self) -> None:
- from camcops_server.cc_modules.cc_exportmodels import (
- ExportedTask,
- ExportedTaskRedcap,
- )
-
exported_task = ExportedTask(task=self.task, recipient=self.recipient)
exported_task_redcap = ExportedTaskRedcap(exported_task)
diff --git a/server/camcops_server/cc_modules/tests/cc_report_tests.py b/server/camcops_server/cc_modules/tests/cc_report_tests.py
index c77380933..fda963dc7 100644
--- a/server/camcops_server/cc_modules/tests/cc_report_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_report_tests.py
@@ -25,8 +25,10 @@
"""
+import csv
+import io
import logging
-from typing import Generator, Optional, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING
from cardinal_pythonlib.classes import classproperty
from cardinal_pythonlib.logs import BraceStyleAdapter
@@ -36,20 +38,17 @@
XlsxResponse,
)
from deform.form import Form
-import pendulum
from pyramid.httpexceptions import HTTPBadRequest
from pyramid.response import Response
from sqlalchemy.orm.query import Query
from sqlalchemy.sql.selectable import SelectBase
from camcops_server.cc_modules.cc_report import (
- AverageScoreReport,
get_all_report_classes,
PlainReportType,
Report,
)
from camcops_server.cc_modules.cc_unittest import (
- BasicDatabaseTestCase,
DemoDatabaseTestCase,
DemoRequestTestCase,
)
@@ -62,8 +61,6 @@
ReportParamForm,
ReportParamSchema,
)
- from camcops_server.cc_modules.cc_patient import Patient
- from camcops_server.cc_modules.cc_patientidnum import PatientIdNum
from camcops_server.cc_modules.cc_request import (
CamcopsRequest,
)
@@ -142,81 +139,6 @@ def test_reports(self) -> None:
pass
-class AverageScoreReportTestCase(BasicDatabaseTestCase):
- def __init__(self, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
- self.patient_id_sequence = self.get_patient_id()
- self.task_id_sequence = self.get_task_id()
- self.patient_idnum_id_sequence = self.get_patient_idnum_id()
-
- def setUp(self) -> None:
- super().setUp()
-
- self.report = self.create_report()
-
- def create_report(self) -> AverageScoreReport:
- raise NotImplementedError(
- "Report TestCase needs to implement create_report"
- )
-
- @staticmethod
- def get_patient_id() -> Generator[int, None, None]:
- i = 1
-
- while True:
- yield i
- i += 1
-
- @staticmethod
- def get_task_id() -> Generator[int, None, None]:
- i = 1
-
- while True:
- yield i
- i += 1
-
- @staticmethod
- def get_patient_idnum_id() -> Generator[int, None, None]:
- i = 1
-
- while True:
- yield i
- i += 1
-
- def create_patient(self, idnum_value: int = 333) -> "Patient":
- from camcops_server.cc_modules.cc_patient import Patient
-
- patient = Patient()
- patient.id = next(self.patient_id_sequence)
- self.apply_standard_db_fields(patient)
-
- patient.forename = f"Forename {patient.id}"
- patient.surname = f"Surname {patient.id}"
- patient.dob = pendulum.parse("1950-01-01")
- self.dbsession.add(patient)
-
- self.create_patient_idnum(patient, idnum_value)
-
- self.dbsession.commit()
-
- return patient
-
- def create_patient_idnum(
- self, patient, idnum_value: int = 333
- ) -> "PatientIdNum":
- from camcops_server.cc_modules.cc_patient import PatientIdNum
-
- patient_idnum = PatientIdNum()
- patient_idnum.id = next(self.patient_idnum_id_sequence)
- self.apply_standard_db_fields(patient_idnum)
- patient_idnum.patient_id = patient.id
- patient_idnum.which_idnum = self.nhs_iddef.which_idnum
- patient_idnum.idnum_value = idnum_value
- self.dbsession.add(patient_idnum)
-
- return patient_idnum
-
-
class TestReport(Report):
# noinspection PyMethodParameters
@classproperty
@@ -278,9 +200,6 @@ def test_render_tsv(self) -> None:
self.assertIn(".tsv", response.content_disposition)
- import csv
- import io
-
reader = csv.reader(
io.StringIO(response.body.decode()), dialect="excel-tab"
)
diff --git a/server/camcops_server/cc_modules/tests/cc_task_tests.py b/server/camcops_server/cc_modules/tests/cc_task_tests.py
index e4fc40f42..a667aa177 100644
--- a/server/camcops_server/cc_modules/tests/cc_task_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_task_tests.py
@@ -34,6 +34,7 @@
from camcops_server.cc_modules.cc_dummy_database import DummyDataInserter
from camcops_server.cc_modules.cc_task import Task
+from camcops_server.cc_modules.cc_testfactories import UserFactory
from camcops_server.cc_modules.cc_unittest import DemoDatabaseTestCase
from camcops_server.cc_modules.cc_validators import validate_task_tablename
@@ -46,9 +47,6 @@
class TaskTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
def test_query_phq9(self) -> None:
self.announce("test_query_phq9")
@@ -79,6 +77,8 @@ def test_all_tasks(self) -> None:
)
from camcops_server.cc_modules.cc_xml import XmlElement
+ user = UserFactory()
+
subclasses = Task.all_subclasses_by_tablename()
tables = [cls.tablename for cls in subclasses]
log.info("Actual task table names: {!r} (n={})", tables, len(tables))
@@ -214,7 +214,7 @@ def test_all_tasks(self) -> None:
t.get_rio_metadata(
req,
which_idnum=1,
- uploading_user_id=self.user.id,
+ uploading_user_id=user.id,
document_type="some_doc_type",
),
str,
diff --git a/server/camcops_server/cc_modules/tests/cc_taskreports_tests.py b/server/camcops_server/cc_modules/tests/cc_taskreports_tests.py
index 3d0cbbfda..2ec6417b2 100644
--- a/server/camcops_server/cc_modules/tests/cc_taskreports_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_taskreports_tests.py
@@ -81,7 +81,6 @@ def setUp(self) -> None:
UserGroupMembershipFactory(
group_id=self.group_a.id, user_id=self.jim.id, may_run_reports=True
)
- self.dbsession.commit()
self.num_jim_tasks = self.num_01_oct_2022_bmi_tasks = 2
self.num_freda_tasks = (
@@ -214,9 +213,9 @@ def test_task_counts_by_year_and_month(self) -> None:
# Default is by year and month but better to be explicit
self.req.add_get_params(
{
- ViewParam.BY_YEAR: True,
- ViewParam.BY_MONTH: True,
- ViewParam.VIA_INDEX: self.via_index,
+ ViewParam.BY_YEAR: "true",
+ ViewParam.BY_MONTH: "true",
+ ViewParam.VIA_INDEX: "true" if self.via_index else "false",
}
)
@@ -291,9 +290,9 @@ def test_task_counts_by_year(self) -> None:
self.req.add_get_params(
{
- ViewParam.BY_YEAR: True,
- ViewParam.BY_MONTH: False,
- ViewParam.VIA_INDEX: self.via_index,
+ ViewParam.BY_YEAR: "true",
+ ViewParam.BY_MONTH: "false",
+ ViewParam.VIA_INDEX: "true" if self.via_index else "false",
}
)
@@ -351,10 +350,10 @@ def test_task_counts_by_task(self) -> None:
self.req.add_get_params(
{
- ViewParam.BY_YEAR: False,
- ViewParam.BY_MONTH: False,
- ViewParam.BY_TASK: True,
- ViewParam.VIA_INDEX: self.via_index,
+ ViewParam.BY_YEAR: "false",
+ ViewParam.BY_MONTH: "false",
+ ViewParam.BY_TASK: "true",
+ ViewParam.VIA_INDEX: "true" if self.via_index else "false",
}
)
@@ -393,11 +392,11 @@ def test_task_counts_by_user(self) -> None:
self.req.add_get_params(
{
- ViewParam.BY_YEAR: False,
- ViewParam.BY_MONTH: False,
- ViewParam.BY_TASK: False,
- ViewParam.BY_USER: True,
- ViewParam.VIA_INDEX: self.via_index,
+ ViewParam.BY_YEAR: "false",
+ ViewParam.BY_MONTH: "false",
+ ViewParam.BY_TASK: "false",
+ ViewParam.BY_USER: "true",
+ ViewParam.VIA_INDEX: "true" if self.via_index else "false",
}
)
@@ -434,11 +433,11 @@ def test_total_task_count_for_superuser(self) -> None:
self.req.add_get_params(
{
- ViewParam.BY_YEAR: False,
- ViewParam.BY_MONTH: False,
- ViewParam.BY_TASK: False,
- ViewParam.BY_USER: False,
- ViewParam.VIA_INDEX: self.via_index,
+ ViewParam.BY_YEAR: "false",
+ ViewParam.BY_MONTH: "false",
+ ViewParam.BY_TASK: "false",
+ ViewParam.BY_USER: "false",
+ ViewParam.VIA_INDEX: "true" if self.via_index else "false",
}
)
@@ -477,11 +476,11 @@ def test_task_counts_by_day_of_month(self) -> None:
self.req.add_get_params(
{
- ViewParam.BY_YEAR: False,
- ViewParam.BY_MONTH: False,
- ViewParam.BY_DAY_OF_MONTH: True,
- ViewParam.BY_TASK: False,
- ViewParam.VIA_INDEX: self.via_index,
+ ViewParam.BY_YEAR: "false",
+ ViewParam.BY_MONTH: "false",
+ ViewParam.BY_DAY_OF_MONTH: "true",
+ ViewParam.BY_TASK: "false",
+ ViewParam.VIA_INDEX: "true" if self.via_index else "false",
}
)
diff --git a/server/camcops_server/cc_modules/tests/cc_taskschedule_tests.py b/server/camcops_server/cc_modules/tests/cc_taskschedule_tests.py
index 96ef47d4d..e64ec9883 100644
--- a/server/camcops_server/cc_modules/tests/cc_taskschedule_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_taskschedule_tests.py
@@ -35,13 +35,17 @@
from camcops_server.cc_modules.cc_taskschedule import (
PatientTaskSchedule,
PatientTaskScheduleEmail,
- TaskSchedule,
TaskScheduleItem,
)
-from camcops_server.cc_modules.cc_unittest import (
- DemoDatabaseTestCase,
- DemoRequestTestCase,
+from camcops_server.cc_modules.cc_testfactories import (
+ EmailFactory,
+ PatientTaskScheduleEmailFactory,
+ PatientTaskScheduleFactory,
+ ServerCreatedPatientFactory,
+ TaskScheduleFactory,
+ TaskScheduleItemFactory,
)
+from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase
# =============================================================================
@@ -49,38 +53,24 @@
# =============================================================================
-class TaskScheduleTests(DemoDatabaseTestCase):
+class TaskScheduleTests(DemoRequestTestCase):
def test_deleting_deletes_related_objects(self) -> None:
- schedule = TaskSchedule()
- schedule.group_id = self.group.id
- self.dbsession.add(schedule)
- self.dbsession.flush()
-
- item = TaskScheduleItem()
- item.schedule_id = schedule.id
- item.task_table_name = "ace3"
- item.due_from = Duration(days=30)
- item.due_by = Duration(days=60)
- self.dbsession.add(item)
- self.dbsession.flush()
-
- patient = self.create_patient()
-
- pts = PatientTaskSchedule()
- pts.schedule_id = schedule.id
- pts.patient_pk = patient.pk
- self.dbsession.add(pts)
- self.dbsession.flush()
-
- email = Email()
- self.dbsession.add(email)
- self.dbsession.flush()
-
- pts_email = PatientTaskScheduleEmail()
- pts_email.email_id = email.id
- pts_email.patient_task_schedule_id = pts.id
- self.dbsession.add(pts_email)
- self.dbsession.commit()
+ patient = ServerCreatedPatientFactory()
+ schedule = TaskScheduleFactory(group=patient._group)
+
+ item = TaskScheduleItemFactory(
+ task_schedule=schedule,
+ task_table_name="ace3",
+ )
+
+ pts = PatientTaskScheduleFactory(
+ task_schedule=schedule,
+ patient=patient,
+ )
+
+ pts_email = PatientTaskScheduleEmailFactory(
+ patient_task_schedule=pts,
+ )
self.assertIsNotNone(
self.dbsession.query(TaskScheduleItem)
@@ -101,7 +91,7 @@ def test_deleting_deletes_related_objects(self) -> None:
)
self.assertIsNotNone(
self.dbsession.query(Email)
- .filter(Email.id == email.id)
+ .filter(Email.id == pts_email.email.id)
.one_or_none()
)
@@ -127,114 +117,73 @@ def test_deleting_deletes_related_objects(self) -> None:
)
self.assertIsNone(
self.dbsession.query(Email)
- .filter(Email.id == email.id)
+ .filter(Email.id == pts_email.email.id)
.one_or_none()
)
class TaskScheduleItemTests(DemoRequestTestCase):
def test_description_shows_shortname_and_number_of_days(self) -> None:
- item = TaskScheduleItem()
- item.task_table_name = "bmi"
- item.due_from = Duration(days=30)
-
+ item = TaskScheduleItemFactory(
+ task_table_name="bmi",
+ due_from=Duration(days=30),
+ )
self.assertEqual(item.description(self.req), "BMI @ 30 days")
def test_description_with_no_durations(self) -> None:
- item = TaskScheduleItem()
- item.task_table_name = "bmi"
-
+ item = TaskScheduleItemFactory(task_table_name="bmi")
self.assertEqual(item.description(self.req), "BMI @ ? days")
def test_due_within_calculated_from_due_by_and_due_from(self) -> None:
- item = TaskScheduleItem()
- item.due_from = Duration(days=30)
- item.due_by = Duration(days=50)
-
+ item = TaskScheduleItemFactory(
+ due_from=Duration(days=30),
+ due_by=Duration(days=50),
+ )
self.assertEqual(item.due_within.in_days(), 20)
def test_due_within_is_none_when_missing_due_by(self) -> None:
- item = TaskScheduleItem()
- item.due_from = Duration(days=30)
-
+ item = TaskScheduleItemFactory(due_from=Duration(days=30))
self.assertIsNone(item.due_within)
def test_due_within_calculated_when_missing_due_from(self) -> None:
- item = TaskScheduleItem()
- item.due_by = Duration(days=30)
-
+ item = TaskScheduleItemFactory(due_by=Duration(days=30))
self.assertEqual(item.due_within.in_days(), 30)
-class PatientTaskScheduleTests(DemoDatabaseTestCase):
- def setUp(self) -> None:
- super().setUp()
-
- import datetime
-
- self.schedule = TaskSchedule()
- self.schedule.group_id = self.group.id
- self.dbsession.add(self.schedule)
-
- self.patient = self.create_patient(
- id=1,
- forename="Jo",
- surname="Patient",
- dob=datetime.date(1958, 4, 19),
- sex="F",
- address="Address",
- gp="GP",
- other="Other",
- )
-
- self.pts = PatientTaskSchedule()
- self.pts.schedule_id = self.schedule.id
- self.pts.patient_pk = self.patient.pk
- self.dbsession.add(self.pts)
- self.dbsession.flush()
-
+class PatientTaskScheduleTests(DemoRequestTestCase):
def test_email_body_contains_access_key(self) -> None:
- self.schedule.email_template = "{access_key}"
- self.dbsession.add(self.schedule)
- self.dbsession.flush()
+ schedule = TaskScheduleFactory(email_template="{access_key}")
+ pts = PatientTaskScheduleFactory(task_schedule=schedule)
self.assertIn(
- f"{self.patient.uuid_as_proquint}", self.pts.email_body(self.req)
+ f"{pts.patient.uuid_as_proquint}", pts.email_body(self.req)
)
def test_email_body_contains_server_url(self) -> None:
- self.schedule.email_template = "{server_url}"
- self.dbsession.add(self.schedule)
- self.dbsession.flush()
+ schedule = TaskScheduleFactory(email_template="{server_url}")
+ pts = PatientTaskScheduleFactory(task_schedule=schedule)
expected_url = self.req.route_url(Routes.CLIENT_API)
- self.assertIn(f"{expected_url}", self.pts.email_body(self.req))
+ self.assertIn(f"{expected_url}", pts.email_body(self.req))
def test_email_body_contains_patient_forename(self) -> None:
- self.schedule.email_template = "{forename}"
- self.dbsession.add(self.schedule)
- self.dbsession.flush()
+ schedule = TaskScheduleFactory(email_template="{forename}")
+ pts = PatientTaskScheduleFactory(task_schedule=schedule)
- self.assertIn(
- f"{self.pts.patient.forename}", self.pts.email_body(self.req)
- )
+ self.assertIn(f"{pts.patient.forename}", pts.email_body(self.req))
def test_email_body_contains_patient_surname(self) -> None:
- self.schedule.email_template = "{surname}"
- self.dbsession.add(self.schedule)
- self.dbsession.flush()
+ schedule = TaskScheduleFactory(email_template="{surname}")
+ pts = PatientTaskScheduleFactory(task_schedule=schedule)
- self.assertIn(
- f"{self.pts.patient.surname}", self.pts.email_body(self.req)
- )
+ self.assertIn(f"{pts.patient.surname}", pts.email_body(self.req))
def test_email_body_contains_android_launch_url(self) -> None:
- self.schedule.email_template = "{android_launch_url}"
- self.dbsession.add(self.schedule)
- self.dbsession.flush()
+ schedule = TaskScheduleFactory(email_template="{android_launch_url}")
+ pts = PatientTaskScheduleFactory(task_schedule=schedule)
- url = self.pts.email_body(self.req)
+ url = pts.email_body(self.req)
(scheme, netloc, path, query, fragment) = urlsplit(url)
self.assertEqual(scheme, UriSchemes.HTTP)
self.assertEqual(netloc, "camcops.org")
@@ -246,15 +195,14 @@ def test_email_body_contains_android_launch_url(self) -> None:
[self.req.route_url(Routes.CLIENT_API)],
)
self.assertEqual(
- query_dict["default_access_key"], [self.patient.uuid_as_proquint]
+ query_dict["default_access_key"], [pts.patient.uuid_as_proquint]
)
def test_email_body_contains_ios_launch_url(self) -> None:
- self.schedule.email_template = "{ios_launch_url}"
- self.dbsession.add(self.schedule)
- self.dbsession.flush()
+ schedule = TaskScheduleFactory(email_template="{ios_launch_url}")
+ pts = PatientTaskScheduleFactory(task_schedule=schedule)
- url = self.pts.email_body(self.req)
+ url = pts.email_body(self.req)
(scheme, netloc, path, query, fragment) = urlsplit(url)
self.assertEqual(scheme, "camcops")
self.assertEqual(netloc, "camcops.org")
@@ -266,59 +214,33 @@ def test_email_body_contains_ios_launch_url(self) -> None:
[self.req.route_url(Routes.CLIENT_API)],
)
self.assertEqual(
- query_dict["default_access_key"], [self.patient.uuid_as_proquint]
+ query_dict["default_access_key"], [pts.patient.uuid_as_proquint]
)
def test_email_body_disallows_invalid_template(self) -> None:
- self.schedule.email_template = "{foobar}"
- self.dbsession.add(self.schedule)
- self.dbsession.flush()
+ schedule = TaskScheduleFactory(email_template="{foobar}")
+ pts = PatientTaskScheduleFactory(task_schedule=schedule)
with self.assertRaises(KeyError):
- self.pts.email_body(self.req)
+ pts.email_body(self.req)
def test_email_body_disallows_accessing_properties(self) -> None:
- self.schedule.email_template = "{server_url.__class__}"
- self.dbsession.add(self.schedule)
- self.dbsession.flush()
+ schedule = TaskScheduleFactory(email_template="{server_url.__class__}")
+ pts = PatientTaskScheduleFactory(task_schedule=schedule)
with self.assertRaises(KeyError):
- self.pts.email_body(self.req)
+ pts.email_body(self.req)
def test_email_sent_false_for_no_emails(self) -> None:
- self.assertFalse(self.pts.email_sent)
+ pts = PatientTaskScheduleFactory()
+ self.assertFalse(pts.email_sent)
def test_email_sent_false_for_one_unsent_email(self) -> None:
- email1 = Email()
- email1.sent = False
- self.dbsession.add(email1)
- self.dbsession.flush()
- pts_email1 = PatientTaskScheduleEmail()
- pts_email1.email_id = email1.id
- pts_email1.patient_task_schedule_id = self.pts.id
- self.dbsession.add(pts_email1)
- self.dbsession.commit()
-
- self.assertFalse(self.pts.email_sent)
+ email1 = EmailFactory(sent=False)
+ pts_email1 = PatientTaskScheduleEmailFactory(email=email1)
+ self.assertFalse(pts_email1.patient_task_schedule.email_sent)
def test_email_sent_true_for_one_sent_email(self) -> None:
- email1 = Email()
- email1.sent = False
- self.dbsession.add(email1)
- self.dbsession.flush()
- pts_email1 = PatientTaskScheduleEmail()
- pts_email1.email_id = email1.id
- pts_email1.patient_task_schedule_id = self.pts.id
- self.dbsession.add(pts_email1)
-
- email2 = Email()
- email2.sent = True
- self.dbsession.add(email2)
- self.dbsession.flush()
- pts_email2 = PatientTaskScheduleEmail()
- pts_email2.email_id = email2.id
- pts_email2.patient_task_schedule_id = self.pts.id
- self.dbsession.add(pts_email2)
- self.dbsession.commit()
-
- self.assertTrue(self.pts.email_sent)
+ email1 = EmailFactory(sent=True)
+ pts_email1 = PatientTaskScheduleEmailFactory(email=email1)
+ self.assertTrue(pts_email1.patient_task_schedule.email_sent)
diff --git a/server/camcops_server/cc_modules/tests/cc_user_tests.py b/server/camcops_server/cc_modules/tests/cc_user_tests.py
index 9962f7fc3..96ec8d422 100644
--- a/server/camcops_server/cc_modules/tests/cc_user_tests.py
+++ b/server/camcops_server/cc_modules/tests/cc_user_tests.py
@@ -33,9 +33,13 @@
OBSCURE_PHONE_ASTERISKS,
)
from camcops_server.cc_modules.cc_group import Group
+from camcops_server.cc_modules.cc_testfactories import (
+ GroupFactory,
+ UserFactory,
+ UserGroupMembershipFactory,
+)
from camcops_server.cc_modules.cc_unittest import (
- BasicDatabaseTestCase,
- DemoDatabaseTestCase,
+ DemoRequestTestCase,
)
from camcops_server.cc_modules.cc_user import (
SecurityAccountLockout,
@@ -49,13 +53,15 @@
# =============================================================================
-class UserTests(DemoDatabaseTestCase):
+class UserTests(DemoRequestTestCase):
"""
Unit tests.
"""
def test_user(self) -> None:
- self.announce("test_user")
+ UserFactory()
+ GroupFactory()
+
req = self.req
SecurityAccountLockout.delete_old_account_lockouts(req)
@@ -134,7 +140,7 @@ def test_partial_email(self) -> None:
("very.unusual.”@”.unusual.com@example.com", f"v{a}m@example.com"),
)
- user = self.create_user()
+ user = UserFactory()
for email, expected_partial in tests:
user.email = email
@@ -143,31 +149,36 @@ def test_partial_email(self) -> None:
)
def test_partial_phone_number(self) -> None:
- user = self.create_user()
# https://www.ofcom.org.uk/phones-telecoms-and-internet/information-for-industry/numbering/numbers-for-drama # noqa: E501
- user.phone_number = phonenumbers.parse("+447700900123")
+ user = UserFactory(phone_number=phonenumbers.parse("+447700900123"))
a = OBSCURE_PHONE_ASTERISKS
self.assertEqual(user.partial_phone_number, f"{a}23")
-class UserPermissionTests(BasicDatabaseTestCase):
+class UserPermissionTests(DemoRequestTestCase):
def setUp(self) -> None:
super().setUp()
# Deliberately not in alphabetical order to test sorting
- self.group_c = self.create_group("groupc")
- self.group_b = self.create_group("groupb")
- self.group_a = self.create_group("groupa")
- self.group_d = self.create_group("groupd")
- self.dbsession.flush()
+ self.group_c = GroupFactory(name="groupc")
+ self.group_b = GroupFactory(name="groupb")
+ self.group_a = GroupFactory(name="groupa")
+ self.group_d = GroupFactory(name="groupd")
def test_groups_user_may_manage_patients_in(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, may_manage_patients=True)
- self.create_membership(user, self.group_c, may_manage_patients=True)
- self.create_membership(user, self.group_a, may_manage_patients=False)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, may_manage_patients=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_manage_patients=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ may_manage_patients=False,
+ )
self.assertEqual(
[self.group_c, self.group_d],
@@ -175,12 +186,17 @@ def test_groups_user_may_manage_patients_in(self) -> None:
)
def test_groups_user_may_email_patients_in(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, may_email_patients=True)
- self.create_membership(user, self.group_c, may_email_patients=True)
- self.create_membership(user, self.group_a, may_email_patients=False)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, may_email_patients=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_email_patients=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_email_patients=False
+ )
self.assertEqual(
[self.group_c, self.group_d],
@@ -188,12 +204,17 @@ def test_groups_user_may_email_patients_in(self) -> None:
)
def test_ids_of_groups_user_may_report_on(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_run_reports=False)
- self.create_membership(user, self.group_c, may_run_reports=True)
- self.create_membership(user, self.group_d, may_run_reports=True)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_run_reports=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_run_reports=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, may_run_reports=True
+ )
ids = user.ids_of_groups_user_may_report_on
@@ -203,8 +224,7 @@ def test_ids_of_groups_user_may_report_on(self) -> None:
self.assertNotIn(self.group_b.id, ids)
def test_ids_of_groups_superuser_may_report_on(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
ids = user.ids_of_groups_user_may_report_on
@@ -214,12 +234,17 @@ def test_ids_of_groups_superuser_may_report_on(self) -> None:
self.assertIn(self.group_d.id, ids)
def test_ids_of_groups_user_is_admin_for(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, groupadmin=False)
- self.create_membership(user, self.group_c, groupadmin=True)
- self.create_membership(user, self.group_d, groupadmin=True)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, groupadmin=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=True
+ )
ids = user.ids_of_groups_user_is_admin_for
@@ -229,8 +254,7 @@ def test_ids_of_groups_user_is_admin_for(self) -> None:
self.assertNotIn(self.group_b.id, ids)
def test_ids_of_groups_superuser_is_admin_for(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
ids = user.ids_of_groups_user_is_admin_for
@@ -240,12 +264,17 @@ def test_ids_of_groups_superuser_is_admin_for(self) -> None:
self.assertIn(self.group_d.id, ids)
def test_names_of_groups_user_is_admin_for(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, groupadmin=False)
- self.create_membership(user, self.group_c, groupadmin=True)
- self.create_membership(user, self.group_d, groupadmin=True)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, groupadmin=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=True
+ )
names = user.names_of_groups_user_is_admin_for
@@ -255,8 +284,7 @@ def test_names_of_groups_user_is_admin_for(self) -> None:
self.assertNotIn(self.group_b.name, names)
def test_names_of_groups_superuser_is_admin_for(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
names = user.names_of_groups_user_is_admin_for
@@ -266,33 +294,41 @@ def test_names_of_groups_superuser_is_admin_for(self) -> None:
self.assertIn(self.group_d.name, names)
def test_groups_user_is_admin_for(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, groupadmin=False)
- self.create_membership(user, self.group_c, groupadmin=True)
- self.create_membership(user, self.group_d, groupadmin=True)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, groupadmin=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=True
+ )
self.assertEqual(
[self.group_c, self.group_d], user.groups_user_is_admin_for
)
def test_user_may_administer_group(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, groupadmin=False)
- self.create_membership(user, self.group_c, groupadmin=True)
- self.create_membership(user, self.group_d, groupadmin=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, groupadmin=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=True
+ )
self.assertFalse(user.may_administer_group(self.group_a.id))
self.assertTrue(user.may_administer_group(self.group_c.id))
self.assertTrue(user.may_administer_group(self.group_d.id))
def test_superuser_may_administer_group(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.may_administer_group(self.group_a.id))
self.assertTrue(user.may_administer_group(self.group_b.id))
@@ -300,48 +336,68 @@ def test_superuser_may_administer_group(self) -> None:
self.assertTrue(user.may_administer_group(self.group_d.id))
def test_groups_user_may_dump(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, may_dump_data=True)
- self.create_membership(user, self.group_c, may_dump_data=True)
- self.create_membership(user, self.group_a, may_dump_data=False)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, may_dump_data=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_dump_data=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_dump_data=False
+ )
self.assertEqual(
[self.group_c, self.group_d], user.groups_user_may_dump
)
def test_groups_user_may_report_on(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, may_run_reports=True)
- self.create_membership(user, self.group_c, may_run_reports=True)
- self.create_membership(user, self.group_a, may_run_reports=False)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, may_run_reports=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_run_reports=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_run_reports=False
+ )
self.assertEqual(
[self.group_c, self.group_d], user.groups_user_may_report_on
)
def test_groups_user_may_upload_into(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, may_upload=True)
- self.create_membership(user, self.group_c, may_upload=True)
- self.create_membership(user, self.group_a, may_upload=False)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, may_upload=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_upload=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_upload=False
+ )
self.assertEqual(
[self.group_c, self.group_d], user.groups_user_may_upload_into
)
def test_groups_user_may_add_special_notes(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, may_add_notes=True)
- self.create_membership(user, self.group_c, may_add_notes=True)
- self.create_membership(user, self.group_a, may_add_notes=False)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, may_add_notes=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_add_notes=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_add_notes=False
+ )
self.assertEqual(
[self.group_c, self.group_d],
@@ -349,17 +405,22 @@ def test_groups_user_may_add_special_notes(self) -> None:
)
def test_groups_user_may_see_all_pts_when_unfiltered(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(
- user, self.group_d, view_all_patients_when_unfiltered=True
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_d.id,
+ view_all_patients_when_unfiltered=True,
)
- self.create_membership(
- user, self.group_c, view_all_patients_when_unfiltered=True
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_c.id,
+ view_all_patients_when_unfiltered=True,
)
- self.create_membership(
- user, self.group_a, view_all_patients_when_unfiltered=False
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ view_all_patients_when_unfiltered=False,
)
self.assertEqual(
@@ -368,252 +429,259 @@ def test_groups_user_may_see_all_pts_when_unfiltered(self) -> None:
)
def test_is_a_group_admin(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, groupadmin=True)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=True
+ )
self.assertTrue(user.is_a_groupadmin)
def test_is_not_a_group_admin(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, groupadmin=False)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=False
+ )
self.assertFalse(user.is_a_groupadmin)
def test_authorized_as_groupadmin(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, groupadmin=True)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=True
+ )
self.assertTrue(user.authorized_as_groupadmin)
def test_not_authorized_as_groupadmin(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_d, groupadmin=False)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=False
+ )
self.assertFalse(user.authorized_as_groupadmin)
def test_superuser_authorized_as_groupadmin(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.authorized_as_groupadmin)
def test_membership_for_group_id(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- ugm = self.create_membership(user, self.group_a)
+ ugm = UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id
+ )
self.assertEqual(user.membership_for_group_id(self.group_a.id), ugm)
def test_no_membership_for_group_id(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
self.assertIsNone(user.membership_for_group_id(self.group_a.id))
def test_may_use_webviewer(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_use_webviewer=False)
- self.create_membership(user, self.group_c, may_use_webviewer=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_use_webviewer=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_use_webviewer=True
+ )
self.assertTrue(user.may_use_webviewer)
def test_may_not_use_webviewer(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
self.assertFalse(user.may_use_webviewer)
def test_superuser_may_use_webviewer(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.may_use_webviewer)
def test_authorized_to_add_special_note(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_c, may_add_notes=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_add_notes=True
+ )
self.assertTrue(user.authorized_to_add_special_note(self.group_c.id))
def test_not_authorized_to_add_special_note(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_c, may_add_notes=False)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_add_notes=False
+ )
self.assertFalse(user.authorized_to_add_special_note(self.group_c.id))
def test_superuser_authorized_to_add_special_note(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.authorized_to_add_special_note(self.group_c.id))
def test_groupadmin_authorized_to_erase_tasks(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_c, groupadmin=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, groupadmin=True
+ )
self.assertTrue(user.authorized_to_erase_tasks(self.group_c.id))
def test_non_member_not_authorized_to_erase_tasks(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, groupadmin=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, groupadmin=True
+ )
self.assertFalse(user.authorized_to_erase_tasks(self.group_c.id))
def test_non_admin_not_authorized_to_erase_tasks(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_c)
- self.dbsession.commit()
+ UserGroupMembershipFactory(user_id=user.id, group_id=self.group_c.id)
self.assertFalse(user.authorized_to_erase_tasks(self.group_c.id))
def test_superuser_authorized_to_erase_tasks(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.authorized_to_erase_tasks(self.group_c.id))
def test_authorized_to_dump(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_dump_data=False)
- self.create_membership(user, self.group_c, may_dump_data=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_dump_data=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_dump_data=True
+ )
self.assertTrue(user.authorized_to_dump)
def test_not_authorized_to_dump(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
self.assertFalse(user.authorized_to_dump)
def test_superuser_authorized_to_dump(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.authorized_to_dump)
def test_authorized_for_reports(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_run_reports=False)
- self.create_membership(user, self.group_c, may_run_reports=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_run_reports=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_run_reports=True
+ )
self.assertTrue(user.authorized_for_reports)
def test_not_authorized_for_reports(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
self.assertFalse(user.authorized_for_reports)
def test_superuser_authorized_for_reports(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.authorized_for_reports)
def test_may_view_all_patients_when_unfiltered(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(
- user, self.group_a, view_all_patients_when_unfiltered=True
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ view_all_patients_when_unfiltered=True,
)
- self.create_membership(
- user, self.group_c, view_all_patients_when_unfiltered=True
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_c.id,
+ view_all_patients_when_unfiltered=True,
)
- self.dbsession.commit()
self.assertTrue(user.may_view_all_patients_when_unfiltered)
def test_may_not_view_all_patients_when_unfiltered(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(
- user, self.group_a, view_all_patients_when_unfiltered=True
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ view_all_patients_when_unfiltered=True,
)
- self.create_membership(
- user, self.group_c, view_all_patients_when_unfiltered=False
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_c.id,
+ view_all_patients_when_unfiltered=False,
)
- self.dbsession.commit()
self.assertFalse(user.may_view_all_patients_when_unfiltered)
def test_superuser_may_view_all_patients_when_unfiltered(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.may_view_all_patients_when_unfiltered)
def test_may_view_no_patients_when_unfiltered(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
self.assertTrue(user.may_view_no_patients_when_unfiltered)
def test_may_not_view_no_patients_when_unfiltered(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(
- user, self.group_a, view_all_patients_when_unfiltered=True
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ view_all_patients_when_unfiltered=True,
)
- self.create_membership(
- user, self.group_c, view_all_patients_when_unfiltered=True
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_c.id,
+ view_all_patients_when_unfiltered=True,
)
- self.dbsession.commit()
self.assertFalse(user.may_view_no_patients_when_unfiltered)
def test_superuser_may_not_view_no_patients_when_unfiltered(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertFalse(user.may_view_no_patients_when_unfiltered)
def test_group_ids_that_nonsuperuser_may_see_when_unfiltered(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(
- user, self.group_a, view_all_patients_when_unfiltered=False
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ view_all_patients_when_unfiltered=False,
)
- self.create_membership(
- user, self.group_c, view_all_patients_when_unfiltered=True
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_c.id,
+ view_all_patients_when_unfiltered=True,
)
- self.create_membership(
- user, self.group_d, view_all_patients_when_unfiltered=True
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_d.id,
+ view_all_patients_when_unfiltered=True,
)
ids = user.group_ids_nonsuperuser_may_see_when_unfiltered()
@@ -624,165 +692,167 @@ def test_group_ids_that_nonsuperuser_may_see_when_unfiltered(self) -> None:
self.assertNotIn(self.group_b.id, ids)
def test_may_upload_to_group(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_upload=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_upload=True
+ )
self.assertTrue(user.may_upload_to_group(self.group_a.id))
def test_may_not_upload_to_group(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_upload=False)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_upload=False
+ )
self.assertFalse(user.may_upload_to_group(self.group_a.id))
def test_superuser_may_upload_to_group(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.may_upload_to_group(self.group_a.id))
def test_may_upload_to_upload_group(self) -> None:
- user = self.create_user(
- username="test", upload_group_id=self.group_a.id
- )
- self.dbsession.flush()
+ user = UserFactory(upload_group_id=self.group_a.id)
- self.create_membership(user, self.group_a, may_upload=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_upload=True
+ )
self.assertTrue(user.may_upload)
def test_may_not_upload_with_no_upload_group(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_upload=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_upload=True
+ )
self.assertFalse(user.may_upload)
def test_may_not_upload_with_upload_group_but_no_permission(self) -> None:
- user = self.create_user(
- username="test", upload_group_id=self.group_a.id
- )
- self.dbsession.flush()
+ user = UserFactory(upload_group_id=self.group_a.id)
- self.create_membership(user, self.group_a, may_upload=False)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_upload=False
+ )
self.assertFalse(user.may_upload)
def test_may_register_devices_with_upload_group(self) -> None:
- user = self.create_user(
- username="test", upload_group_id=self.group_a.id
- )
- self.dbsession.flush()
+ user = UserFactory(upload_group_id=self.group_a.id)
- self.create_membership(user, self.group_a, may_register_devices=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ may_register_devices=True,
+ )
self.assertTrue(user.may_register_devices)
def test_may_not_register_devices_with_no_upload_group(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
self.assertFalse(user.may_register_devices)
def test_may_not_register_devices_with_upload_group_but_no_permission(
self,
) -> None:
- user = self.create_user(
- username="test", upload_group_id=self.group_a.id
- )
- self.dbsession.flush()
+ user = UserFactory(upload_group_id=self.group_a.id)
- self.create_membership(user, self.group_a, may_register_devices=False)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ may_register_devices=False,
+ )
self.assertFalse(user.may_register_devices)
def test_superuser_may_register_devices_with_upload_group(self) -> None:
- user = self.create_user(
- username="test", upload_group_id=self.group_a.id, superuser=True
- )
- self.dbsession.flush()
+ user = UserFactory(upload_group_id=self.group_a.id, superuser=True)
self.assertTrue(user.may_register_devices)
def test_superuser_may_not_register_devices_with_no_upload_group(
self,
) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertFalse(user.may_register_devices)
def test_authorized_to_manage_patients(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_manage_patients=False)
- self.create_membership(user, self.group_c, may_manage_patients=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ may_manage_patients=False,
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_manage_patients=True
+ )
self.assertTrue(user.authorized_to_manage_patients)
def test_not_authorized_to_manage_patients(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
self.assertFalse(user.authorized_to_manage_patients)
def test_groupadmin_authorized_to_manage_patients(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, groupadmin=True)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, groupadmin=True
+ )
self.assertTrue(user.authorized_to_manage_patients)
def test_superuser_authorized_to_manage_patients(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.authorized_to_manage_patients)
def test_user_may_manage_patients_in_group(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_manage_patients=False)
- self.create_membership(user, self.group_c, may_manage_patients=True)
- self.create_membership(user, self.group_d, may_manage_patients=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group_a.id,
+ may_manage_patients=False,
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_manage_patients=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, may_manage_patients=True
+ )
self.assertFalse(user.may_manage_patients_in_group(self.group_a.id))
self.assertTrue(user.may_manage_patients_in_group(self.group_c.id))
self.assertTrue(user.may_manage_patients_in_group(self.group_d.id))
def test_groupadmin_may_manage_patients_in_group(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, groupadmin=False)
- self.create_membership(user, self.group_c, groupadmin=True)
- self.create_membership(user, self.group_d, groupadmin=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, groupadmin=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=True
+ )
self.assertFalse(user.may_manage_patients_in_group(self.group_a.id))
self.assertTrue(user.may_manage_patients_in_group(self.group_c.id))
self.assertTrue(user.may_manage_patients_in_group(self.group_d.id))
def test_superuser_may_manage_patients_in_group(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.may_manage_patients_in_group(self.group_a.id))
self.assertTrue(user.may_manage_patients_in_group(self.group_b.id))
@@ -790,66 +860,104 @@ def test_superuser_may_manage_patients_in_group(self) -> None:
self.assertTrue(user.may_manage_patients_in_group(self.group_d.id))
def test_authorized_to_email_patients(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_email_patients=False)
- self.create_membership(user, self.group_c, may_email_patients=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_email_patients=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_email_patients=True
+ )
self.assertTrue(user.authorized_to_email_patients)
def test_not_authorized_to_email_patients(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
self.assertFalse(user.authorized_to_email_patients)
def test_groupadmin_authorized_to_email_patients(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, groupadmin=True)
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, groupadmin=True
+ )
self.assertTrue(user.authorized_to_email_patients)
def test_superuser_authorized_to_email_patients(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.authorized_to_email_patients)
def test_user_may_email_patients_in_group(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, may_email_patients=False)
- self.create_membership(user, self.group_c, may_email_patients=True)
- self.create_membership(user, self.group_d, may_email_patients=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, may_email_patients=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, may_email_patients=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, may_email_patients=True
+ )
self.assertFalse(user.may_email_patients_in_group(self.group_a.id))
self.assertTrue(user.may_email_patients_in_group(self.group_c.id))
self.assertTrue(user.may_email_patients_in_group(self.group_d.id))
def test_groupadmin_may_email_patients_in_group(self) -> None:
- user = self.create_user(username="test")
- self.dbsession.flush()
+ user = UserFactory()
- self.create_membership(user, self.group_a, groupadmin=False)
- self.create_membership(user, self.group_c, groupadmin=True)
- self.create_membership(user, self.group_d, groupadmin=True)
- self.dbsession.commit()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_a.id, groupadmin=False
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_c.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group_d.id, groupadmin=True
+ )
self.assertFalse(user.may_email_patients_in_group(self.group_a.id))
self.assertTrue(user.may_email_patients_in_group(self.group_c.id))
self.assertTrue(user.may_email_patients_in_group(self.group_d.id))
def test_superuser_may_email_patients_in_group(self) -> None:
- user = self.create_user(username="test", superuser=True)
- self.dbsession.flush()
+ user = UserFactory(superuser=True)
self.assertTrue(user.may_email_patients_in_group(self.group_a.id))
self.assertTrue(user.may_email_patients_in_group(self.group_b.id))
self.assertTrue(user.may_email_patients_in_group(self.group_c.id))
self.assertTrue(user.may_email_patients_in_group(self.group_d.id))
+
+
+class SetGroupIdsTests(DemoRequestTestCase):
+ def test_old_group_ids_removed(self) -> None:
+ group_a = GroupFactory()
+ group_b = GroupFactory()
+
+ user = UserFactory()
+
+ UserGroupMembershipFactory(user_id=user.id, group_id=group_a.id)
+ UserGroupMembershipFactory(user_id=user.id, group_id=group_b.id)
+ self.assertEqual(len(user.user_group_memberships), 2)
+
+ user.set_group_ids([])
+ self.dbsession.refresh(user)
+
+ self.assertEqual(len(user.user_group_memberships), 0)
+
+ def test_new_group_ids_added(self) -> None:
+ group_a = GroupFactory()
+ group_b = GroupFactory()
+
+ user = UserFactory()
+
+ self.assertEqual(len(user.user_group_memberships), 0)
+
+ user.set_group_ids([group_a.id, group_b.id])
+ self.dbsession.refresh(user)
+
+ self.assertEqual(len(user.user_group_memberships), 2)
diff --git a/server/camcops_server/cc_modules/tests/client_api_tests.py b/server/camcops_server/cc_modules/tests/client_api_tests.py
index 8ccf083d4..119c99c66 100644
--- a/server/camcops_server/cc_modules/tests/client_api_tests.py
+++ b/server/camcops_server/cc_modules/tests/client_api_tests.py
@@ -38,6 +38,8 @@
from cardinal_pythonlib.nhs import generate_random_nhs_number
from cardinal_pythonlib.sql.literals import sql_quote_string
from cardinal_pythonlib.text import escape_newlines, unescape_newlines
+from pendulum import DateTime as Pendulum, Duration, local, parse
+
from pyramid.response import Response
from camcops_server.cc_modules.cc_client_api_core import (
@@ -49,12 +51,25 @@
UserErrorException,
)
from camcops_server.cc_modules.cc_convert import decode_values
-from camcops_server.cc_modules.cc_ipuse import IpUse
from camcops_server.cc_modules.cc_proquint import uuid_from_proquint
-from camcops_server.cc_modules.cc_unittest import (
- BasicDatabaseTestCase,
- DemoDatabaseTestCase,
+from camcops_server.cc_modules.cc_taskindex import (
+ PatientIdNumIndexEntry,
+ TaskIndexEntry,
+)
+from camcops_server.cc_modules.cc_testfactories import (
+ DeviceFactory,
+ GroupFactory,
+ NHSPatientIdNumFactory,
+ PatientFactory,
+ PatientTaskScheduleFactory,
+ ServerCreatedNHSPatientIdNumFactory,
+ ServerCreatedPatientFactory,
+ TaskScheduleFactory,
+ TaskScheduleItemFactory,
+ UserFactory,
+ UserGroupMembershipFactory,
)
+from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase
from camcops_server.cc_modules.cc_user import User
from camcops_server.cc_modules.cc_version import MINIMUM_TABLET_VERSION
from camcops_server.cc_modules.cc_validators import (
@@ -63,11 +78,12 @@
from camcops_server.cc_modules.client_api import (
client_api,
FAILURE_CODE,
+ get_or_create_single_user,
make_single_user_mode_username,
Operations,
SUCCESS_CODE,
)
-
+from camcops_server.tasks.tests.factories import BmiFactory
TEST_NHS_NUMBER = generate_random_nhs_number()
@@ -101,14 +117,8 @@ def get_reply_dict_from_response(response: Response) -> Dict[str, str]:
return {}
-class ClientApiTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
-
+class ClientApiTests(DemoRequestTestCase):
def test_client_api_basics(self) -> None:
- self.announce("test_client_api_basics")
-
with self.assertRaises(UserErrorException):
fail_user_error("testmsg")
with self.assertRaises(ServerErrorException):
@@ -171,10 +181,12 @@ def test_client_api_basics(self) -> None:
# TODO: client_api.ClientApiTests: more tests here... ?
def test_non_existent_table_rejected(self) -> None:
+ device = DeviceFactory()
+
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: device.name,
TabletParam.OPERATION: Operations.WHICH_KEYS_TO_SEND,
TabletParam.TABLE: "nonexistent_table",
}
@@ -184,7 +196,6 @@ def test_non_existent_table_rejected(self) -> None:
self.assertEqual(d[TabletParam.SUCCESS], FAILURE_CODE)
def test_client_api_validators(self) -> None:
- self.announce("test_client_api_validators")
for x in class_attribute_names(Operations):
try:
validate_alphanum_underscore(x, self.req)
@@ -192,38 +203,26 @@ def test_client_api_validators(self) -> None:
self.fail(f"Operations.{x} fails validate_alphanum_underscore")
-class PatientRegistrationTests(BasicDatabaseTestCase):
- def test_returns_patient_info(self) -> None:
- import datetime
+class PatientRegistrationTests(DemoRequestTestCase):
+ def setUp(self) -> None:
+ super().setUp()
- patient = self.create_patient(
- forename="JO",
- surname="PATIENT",
- dob=datetime.date(1958, 4, 19),
- sex="F",
- address="Address",
- gp="GP",
- other="Other",
- as_server_patient=True,
- )
+ self.device = DeviceFactory()
- self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER,
- as_server_patient=True,
- )
+ def test_returns_patient_info(self) -> None:
+ patient = ServerCreatedPatientFactory()
+ idnum = ServerCreatedNHSPatientIdNumFactory(patient=patient)
proquint = patient.uuid_as_proquint
# For type checker
assert proquint is not None
- assert self.other_device.name is not None
+ assert self.device is not None
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: self.device.name,
TabletParam.OPERATION: Operations.REGISTER_PATIENT,
TabletParam.PATIENT_PROQUINT: proquint,
}
@@ -237,43 +236,34 @@ def test_returns_patient_info(self) -> None:
patient_dict = json.loads(reply_dict[TabletParam.PATIENT_INFO])[0]
- self.assertEqual(patient_dict[TabletParam.SURNAME], "PATIENT")
- self.assertEqual(patient_dict[TabletParam.FORENAME], "JO")
- self.assertEqual(patient_dict[TabletParam.SEX], "F")
- self.assertEqual(patient_dict[TabletParam.DOB], "1958-04-19")
- self.assertEqual(patient_dict[TabletParam.ADDRESS], "Address")
- self.assertEqual(patient_dict[TabletParam.GP], "GP")
- self.assertEqual(patient_dict[TabletParam.OTHER], "Other")
+ self.assertEqual(patient_dict[TabletParam.SURNAME], patient.surname)
+ self.assertEqual(patient_dict[TabletParam.FORENAME], patient.forename)
+ self.assertEqual(patient_dict[TabletParam.SEX], patient.sex)
self.assertEqual(
- patient_dict[f"idnum{self.nhs_iddef.which_idnum}"], TEST_NHS_NUMBER
+ patient_dict[TabletParam.DOB], patient.dob.isoformat()
)
-
- def test_creates_user(self) -> None:
- from camcops_server.cc_modules.cc_taskindex import (
- PatientIdNumIndexEntry,
+ self.assertEqual(patient_dict[TabletParam.ADDRESS], patient.address)
+ self.assertEqual(patient_dict[TabletParam.GP], patient.gp)
+ self.assertEqual(patient_dict[TabletParam.OTHER], patient.other)
+ self.assertEqual(
+ patient_dict[f"idnum{idnum.which_idnum}"], idnum.idnum_value
)
- patient = self.create_patient(
- _group_id=self.group.id, as_server_patient=True
- )
- idnum = self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER,
- as_server_patient=True,
- )
+ def test_creates_user(self) -> None:
+ patient = ServerCreatedPatientFactory()
+ idnum = ServerCreatedNHSPatientIdNumFactory(patient=patient)
PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession)
proquint = patient.uuid_as_proquint
# For type checker
assert proquint is not None
- assert self.other_device.name is not None
+ assert self.device.name is not None
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: self.device.name,
TabletParam.OPERATION: Operations.REGISTER_PATIENT,
TabletParam.PATIENT_PROQUINT: proquint,
}
@@ -288,9 +278,7 @@ def test_creates_user(self) -> None:
username = reply_dict[TabletParam.USER]
self.assertEqual(
username,
- make_single_user_mode_username(
- self.other_device.name, patient._pk
- ),
+ make_single_user_mode_username(self.device.name, patient._pk),
)
password = reply_dict[TabletParam.PASSWORD]
self.assertEqual(len(password), 32)
@@ -310,36 +298,26 @@ def test_creates_user(self) -> None:
self.assertTrue(user.may_upload)
def test_does_not_create_user_when_name_exists(self) -> None:
- from camcops_server.cc_modules.cc_taskindex import (
- PatientIdNumIndexEntry,
- )
-
- patient = self.create_patient(
- _group_id=self.group.id, as_server_patient=True
- )
- idnum = self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER,
- as_server_patient=True,
- )
+ patient = ServerCreatedPatientFactory()
+ idnum = ServerCreatedNHSPatientIdNumFactory(patient=patient)
PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession)
proquint = patient.uuid_as_proquint
- user = User(
- username=make_single_user_mode_username(
- self.other_device.name, patient._pk
- )
+ single_user_username = make_single_user_mode_username(
+ self.device.name, patient._pk
+ )
+
+ user = UserFactory(
+ username=single_user_username,
+ password="old password",
+ password__request=self.req,
)
- user.set_password(self.req, "old password")
- self.dbsession.add(user)
- self.dbsession.commit()
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: self.device.name,
TabletParam.OPERATION: Operations.REGISTER_PATIENT,
TabletParam.PATIENT_PROQUINT: proquint,
}
@@ -354,9 +332,7 @@ def test_does_not_create_user_when_name_exists(self) -> None:
username = reply_dict[TabletParam.USER]
self.assertEqual(
username,
- make_single_user_mode_username(
- self.other_device.name, patient._pk
- ),
+ make_single_user_mode_username(self.device.name, patient._pk),
)
password = reply_dict[TabletParam.PASSWORD]
self.assertEqual(len(password), 32)
@@ -377,12 +353,12 @@ def test_does_not_create_user_when_name_exists(self) -> None:
def test_raises_for_invalid_proquint(self) -> None:
# For type checker
- assert self.other_device.name is not None
+ assert self.device.name is not None
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: self.device.name,
TabletParam.OPERATION: Operations.REGISTER_PATIENT,
TabletParam.PATIENT_PROQUINT: "invalid",
}
@@ -405,12 +381,12 @@ def test_raises_for_missing_valid_proquint(self) -> None:
# test proquint really is valid (should not raise)
uuid_from_proquint(valid_proquint)
- assert self.other_device.name is not None
+ assert self.device.name is not None
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: self.device.name,
TabletParam.OPERATION: Operations.REGISTER_PATIENT,
TabletParam.PATIENT_PROQUINT: valid_proquint,
}
@@ -429,13 +405,13 @@ def test_raises_for_missing_valid_proquint(self) -> None:
def test_raises_when_no_patient_idnums(self) -> None:
# In theory this shouldn't be possible in normal operation as the
# patient cannot be created without any idnums
- patient = self.create_patient(as_server_patient=True)
+ patient = ServerCreatedPatientFactory()
proquint = patient.uuid_as_proquint
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: self.device.name,
TabletParam.OPERATION: Operations.REGISTER_PATIENT,
TabletParam.PATIENT_PROQUINT: proquint,
}
@@ -451,15 +427,13 @@ def test_raises_when_no_patient_idnums(self) -> None:
)
def test_raises_when_patient_not_created_on_server(self) -> None:
- patient = self.create_patient(
- _device_id=self.other_device.id, as_server_patient=True
- )
+ patient = PatientFactory()
proquint = patient.uuid_as_proquint
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: self.device.name,
TabletParam.OPERATION: Operations.REGISTER_PATIENT,
TabletParam.PATIENT_PROQUINT: proquint,
}
@@ -476,49 +450,21 @@ def test_raises_when_patient_not_created_on_server(self) -> None:
)
def test_returns_ip_use_flags(self) -> None:
- import datetime
- from camcops_server.cc_modules.cc_taskindex import (
- PatientIdNumIndexEntry,
- )
-
- patient = self.create_patient(
- forename="JO",
- surname="PATIENT",
- dob=datetime.date(1958, 4, 19),
- sex="F",
- address="Address",
- gp="GP",
- other="Other",
- as_server_patient=True,
- )
- idnum = self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER,
- as_server_patient=True,
- )
+ patient = ServerCreatedPatientFactory()
+ idnum = ServerCreatedNHSPatientIdNumFactory(patient=patient)
PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession)
-
- patient.group.ip_use = IpUse()
-
- patient.group.ip_use.commercial = True
- patient.group.ip_use.clinical = True
- patient.group.ip_use.educational = False
- patient.group.ip_use.research = False
-
- self.dbsession.add(patient.group)
- self.dbsession.commit()
+ ip_use = patient.group.ip_use
proquint = patient.uuid_as_proquint
# For type checker
assert proquint is not None
- assert self.other_device.name is not None
+ assert self.device.name is not None
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: self.device.name,
TabletParam.OPERATION: Operations.REGISTER_PATIENT,
TabletParam.PATIENT_PROQUINT: proquint,
}
@@ -532,123 +478,113 @@ def test_returns_ip_use_flags(self) -> None:
ip_use_info = json.loads(reply_dict[TabletParam.IP_USE_INFO])
- self.assertEqual(ip_use_info[TabletParam.IP_USE_COMMERCIAL], 1)
- self.assertEqual(ip_use_info[TabletParam.IP_USE_CLINICAL], 1)
- self.assertEqual(ip_use_info[TabletParam.IP_USE_EDUCATIONAL], 0)
- self.assertEqual(ip_use_info[TabletParam.IP_USE_RESEARCH], 0)
+ self.assertEqual(
+ ip_use_info[TabletParam.IP_USE_COMMERCIAL], ip_use.commercial
+ )
+ self.assertEqual(
+ ip_use_info[TabletParam.IP_USE_CLINICAL], ip_use.clinical
+ )
+ self.assertEqual(
+ ip_use_info[TabletParam.IP_USE_EDUCATIONAL], ip_use.educational
+ )
+ self.assertEqual(
+ ip_use_info[TabletParam.IP_USE_RESEARCH], ip_use.research
+ )
+
+
+class GetTaskSchedulesTests(DemoRequestTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.group = GroupFactory()
+ user = self.req._debugging_user = UserFactory(
+ upload_group_id=self.group.id,
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id,
+ group_id=self.group.id,
+ may_register_devices=True,
+ )
-class GetTaskSchedulesTests(BasicDatabaseTestCase):
def test_returns_task_schedules(self) -> None:
- from pendulum import DateTime as Pendulum, Duration, local, parse
-
- from camcops_server.cc_modules.cc_taskindex import (
- PatientIdNumIndexEntry,
- TaskIndexEntry,
- )
- from camcops_server.cc_modules.cc_taskschedule import (
- PatientTaskSchedule,
- TaskSchedule,
- TaskScheduleItem,
- )
- from camcops_server.tasks.bmi import Bmi
-
- schedule1 = TaskSchedule()
- schedule1.group_id = self.group.id
- schedule1.name = "Test 1"
- self.dbsession.add(schedule1)
-
- schedule2 = TaskSchedule()
- schedule2.group_id = self.group.id
- self.dbsession.add(schedule2)
- self.dbsession.commit()
-
- item1 = TaskScheduleItem()
- item1.schedule_id = schedule1.id
- item1.task_table_name = "phq9"
- item1.due_from = Duration(days=0)
- item1.due_by = Duration(days=7)
- self.dbsession.add(item1)
-
- item2 = TaskScheduleItem()
- item2.schedule_id = schedule1.id
- item2.task_table_name = "bmi"
- item2.due_from = Duration(days=0)
- item2.due_by = Duration(days=8)
- self.dbsession.add(item2)
-
- item3 = TaskScheduleItem()
- item3.schedule_id = schedule1.id
- item3.task_table_name = "phq9"
- item3.due_from = Duration(days=30)
- item3.due_by = Duration(days=37)
- self.dbsession.add(item3)
-
- item4 = TaskScheduleItem()
- item4.schedule_id = schedule1.id
- item4.task_table_name = "gmcpq"
- item4.due_from = Duration(days=30)
- item4.due_by = Duration(days=38)
- self.dbsession.add(item4)
- self.dbsession.commit()
-
- patient = self.create_patient()
- idnum = self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER,
+ schedule1 = TaskScheduleFactory(group=self.group)
+ schedule2 = TaskScheduleFactory(group=self.group)
+
+ TaskScheduleItemFactory(
+ task_schedule=schedule1,
+ task_table_name="phq9",
+ due_from=Duration(days=0),
+ due_by=Duration(days=7),
+ )
+ TaskScheduleItemFactory(
+ task_schedule=schedule1,
+ task_table_name="bmi",
+ due_from=Duration(days=0),
+ due_by=Duration(days=8),
+ )
+ TaskScheduleItemFactory(
+ task_schedule=schedule1,
+ task_table_name="phq9",
+ due_from=Duration(days=30),
+ due_by=Duration(days=37),
+ )
+ TaskScheduleItemFactory(
+ task_schedule=schedule1,
+ task_table_name="gmcpq",
+ due_from=Duration(days=30),
+ due_by=Duration(days=38),
+ )
+
+ # This is the patient originally created om the server
+ server_patient = ServerCreatedPatientFactory(_group=self.group)
+ server_idnum = ServerCreatedNHSPatientIdNumFactory(
+ patient=server_patient
+ )
+
+ # This is the same patient but from the device
+ patient = PatientFactory(_group=self.group)
+ idnum = NHSPatientIdNumFactory(
+ patient=patient,
+ which_idnum=server_idnum.which_idnum,
+ idnum_value=server_idnum.idnum_value,
)
PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession)
- server_patient = self.create_patient(as_server_patient=True)
- _ = self.create_patient_idnum(
- patient_id=server_patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER,
- as_server_patient=True,
+ PatientTaskScheduleFactory(
+ patient=server_patient,
+ task_schedule=schedule1,
+ settings={
+ "bmi": {"bmi_key": "bmi_value"},
+ "phq9": {"phq9_key": "phq9_value"},
+ },
+ start_datetime=local(2020, 7, 31),
)
- schedule_1 = PatientTaskSchedule()
- schedule_1.patient_pk = server_patient.pk
- schedule_1.schedule_id = schedule1.id
- schedule_1.settings = {
- "bmi": {"bmi_key": "bmi_value"},
- "phq9": {"phq9_key": "phq9_value"},
- }
- schedule_1.start_datetime = local(2020, 7, 31)
- self.dbsession.add(schedule_1)
-
- schedule_2 = PatientTaskSchedule()
- schedule_2.patient_pk = server_patient.pk
- schedule_2.schedule_id = schedule2.id
- self.dbsession.add(schedule_2)
-
- bmi = Bmi()
- self.apply_standard_task_fields(bmi)
- bmi.id = 1
- bmi.height_m = 1.83
- bmi.mass_kg = 67.57
- bmi.patient_id = patient.id
- bmi.when_created = local(2020, 8, 1)
- self.dbsession.add(bmi)
- self.dbsession.commit()
+ PatientTaskScheduleFactory(
+ patient=server_patient,
+ task_schedule=schedule2,
+ )
+
+ bmi = BmiFactory(
+ patient=patient,
+ when_created=local(2020, 8, 1),
+ )
self.assertTrue(bmi.is_complete())
TaskIndexEntry.index_task(
bmi, self.dbsession, indexed_at_utc=Pendulum.utcnow()
)
- self.dbsession.commit()
proquint = server_patient.uuid_as_proquint
# For type checker
assert proquint is not None
- assert self.other_device.name is not None
self.req.fake_request_post_from_dict(
{
TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION,
- TabletParam.DEVICE: self.other_device.name,
+ TabletParam.DEVICE: patient._device.name,
TabletParam.OPERATION: Operations.GET_TASK_SCHEDULES,
TabletParam.PATIENT_PROQUINT: proquint,
}
@@ -665,7 +601,7 @@ def test_returns_task_schedules(self) -> None:
self.assertEqual(len(task_schedules), 2)
s = task_schedules[0]
- self.assertEqual(s[TabletParam.TASK_SCHEDULE_NAME], "Test 1")
+ self.assertEqual(s[TabletParam.TASK_SCHEDULE_NAME], schedule1.name)
schedule_items = s[TabletParam.TASK_SCHEDULE_ITEMS]
self.assertEqual(len(schedule_items), 4)
@@ -715,3 +651,74 @@ def test_returns_task_schedules(self) -> None:
# GMCPQ
gmcpq_sched = schedule_items[3]
self.assertTrue(gmcpq_sched[TabletParam.ANONYMOUS])
+
+
+class GetOrCreateSingleUserTests(DemoRequestTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.patient = PatientFactory()
+ self.req._debugging_user = UserFactory()
+
+ def test_user_is_added_to_patient_group(self) -> None:
+ user, _ = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ self.assertIn(self.patient.group.id, user.group_ids)
+
+ def test_user_is_created_with_username(self) -> None:
+ user, _ = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ self.assertEqual(user.username, "test")
+
+ def test_user_is_assigned_password(self) -> None:
+ _, password = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ valid_chars = string.ascii_letters + string.digits + string.punctuation
+ self.assertTrue(all(c in valid_chars for c in password))
+
+ def test_user_upload_group_set(self) -> None:
+ user, _ = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ self.assertEqual(user.upload_group, self.patient.group)
+
+ def test_user_auto_generated_flag_set(self) -> None:
+ user, _ = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ self.assertTrue(user.auto_generated)
+
+ def test_user_is_not_superuser(self) -> None:
+ user, _ = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ self.assertFalse(user.superuser)
+
+ def test_single_patient_pk_set(self) -> None:
+ user, _ = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ self.assertEqual(user.single_patient_pk, self.patient._pk)
+
+ def test_user_may_register_devices(self) -> None:
+ user, _ = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ self.assertTrue(user.user_group_memberships[0].may_register_devices)
+
+ def test_user_may_upload(self) -> None:
+ user, _ = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ self.assertTrue(user.user_group_memberships[0].may_upload)
+
+ def test_existing_user_is_updated(self) -> None:
+ existing_user = UserFactory(username="test")
+
+ user, _ = get_or_create_single_user(self.req, "test", self.patient)
+ self.dbsession.flush()
+
+ self.assertEqual(user, existing_user)
diff --git a/server/camcops_server/cc_modules/tests/webview_tests.py b/server/camcops_server/cc_modules/tests/webview_tests.py
index fb4bdeb56..0c4fdd23c 100644
--- a/server/camcops_server/cc_modules/tests/webview_tests.py
+++ b/server/camcops_server/cc_modules/tests/webview_tests.py
@@ -35,8 +35,7 @@
from cardinal_pythonlib.classes import class_attribute_names
from cardinal_pythonlib.httpconst import MimeType
-from cardinal_pythonlib.nhs import generate_random_nhs_number
-from pendulum import local
+from pendulum import Duration, local
import phonenumbers
import pyotp
from pyramid.httpexceptions import HTTPBadRequest, HTTPFound
@@ -49,7 +48,6 @@
)
from camcops_server.cc_modules.cc_device import Device
from camcops_server.cc_modules.cc_group import Group
-from camcops_server.cc_modules.cc_membership import UserGroupMembership
from camcops_server.cc_modules.cc_patient import Patient
from camcops_server.cc_modules.cc_patientidnum import PatientIdNum
from camcops_server.cc_modules.cc_pyramid import (
@@ -66,9 +64,27 @@
TaskSchedule,
TaskScheduleItem,
)
+from camcops_server.cc_modules.cc_testfactories import (
+ AnyIdNumGroupFactory,
+ Fake,
+ GroupFactory,
+ NHSIdNumDefinitionFactory,
+ NHSPatientIdNumFactory,
+ PatientFactory,
+ PatientTaskScheduleFactory,
+ RioIdNumDefinitionFactory,
+ ServerCreatedNHSPatientIdNumFactory,
+ ServerCreatedPatientFactory,
+ StudyPatientIdNumFactory,
+ TaskScheduleFactory,
+ TaskScheduleItemFactory,
+ UserFactory,
+ UserGroupMembershipFactory,
+)
from camcops_server.cc_modules.cc_unittest import (
BasicDatabaseTestCase,
DemoDatabaseTestCase,
+ DemoRequestTestCase,
)
from camcops_server.cc_modules.cc_user import (
SecurityAccountLockout,
@@ -79,11 +95,13 @@
validate_alphanum_underscore,
)
from camcops_server.cc_modules.cc_view_classes import FormWizardMixin
+from camcops_server.tasks.tests.factories import BmiFactory
from camcops_server.cc_modules.tests.cc_view_classes_tests import (
TestStateMixin,
)
from camcops_server.cc_modules.webview import (
add_patient,
+ add_user,
AddPatientView,
AddTaskScheduleItemView,
AddTaskScheduleView,
@@ -123,30 +141,13 @@
UTF8 = "utf-8"
-TEST_NHS_NUMBER_1 = generate_random_nhs_number()
-TEST_NHS_NUMBER_2 = generate_random_nhs_number()
-
-# https://www.ofcom.org.uk/phones-telecoms-and-internet/information-for-industry/numbering/numbers-for-drama # noqa: E501
-# 07700 900000 to 900999 reserved for TV and Radio drama purposes
-# but unfortunately phonenumbers considers these invalid. However, it offers
-# some examples:
-TEST_PHONE_NUMBER = "+{ctry}{tel}".format(
- ctry=phonenumbers.PhoneMetadata.metadata_for_region("GB").country_code,
- tel=phonenumbers.PhoneMetadata.metadata_for_region(
- "GB"
- ).personal_number.example_number,
-)
-
class WebviewTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
-
def test_any_records_use_group_true(self) -> None:
# All tasks created in DemoDatabaseTestCase will be in this group
- self.announce("test_any_records_use_group_true")
- self.assertTrue(any_records_use_group(self.req, self.group))
+ self.assertTrue(
+ any_records_use_group(self.req, self.demo_database_group)
+ )
def test_any_records_use_group_false(self) -> None:
"""
@@ -156,15 +157,11 @@ def test_any_records_use_group_false(self) -> None:
then the base class probably needs to be declared __abstract__. See
DiagnosisItemBase as an example.
"""
- self.announce("test_any_records_use_group_false")
- group = Group()
- self.dbsession.add(self.group)
- self.dbsession.commit()
+ group = GroupFactory()
self.assertFalse(any_records_use_group(self.req, group))
def test_webview_constant_validators(self) -> None:
- self.announce("test_webview_constant_validators")
for x in class_attribute_names(ViewArg):
try:
validate_alphanum_underscore(x, self.req)
@@ -172,11 +169,7 @@ def test_webview_constant_validators(self) -> None:
self.fail(f"Operations.{x} fails validate_alphanum_underscore")
-class AddTaskScheduleViewTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
-
+class AddTaskScheduleViewTests(BasicDatabaseTestCase):
def test_schedule_form_displayed(self) -> None:
view = AddTaskScheduleView(self.req)
@@ -222,35 +215,29 @@ def test_schedule_is_created(self) -> None:
)
-class EditTaskScheduleViewTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
-
- def setUp(self) -> None:
- super().setUp()
-
- self.schedule = TaskSchedule()
- self.schedule.group_id = self.group.id
- self.schedule.name = "Test"
- self.dbsession.add(self.schedule)
- self.dbsession.commit()
-
+class EditTaskScheduleViewTests(DemoRequestTestCase):
def test_schedule_name_can_be_updated(self) -> None:
+ user = self.req._debugging_user = UserFactory()
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ group_id=group.id, user_id=user.id, groupadmin=True
+ )
+
+ schedule = TaskScheduleFactory(group=group)
multidict = MultiDict(
[
("_charset_", UTF8),
("__formid__", "deform"),
(ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()),
(ViewParam.NAME, "MOJO"),
- (ViewParam.GROUP_ID, self.group.id),
+ (ViewParam.GROUP_ID, group.id),
(FormAction.SUBMIT, "submit"),
]
)
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.SCHEDULE_ID: str(self.schedule.id)},
+ {ViewParam.SCHEDULE_ID: str(schedule.id)},
set_method_get=False,
)
@@ -269,28 +256,16 @@ def test_schedule_name_can_be_updated(self) -> None:
)
def test_group_a_schedule_cannot_be_edited_by_group_b_admin(self) -> None:
- group_a = Group()
- group_a.name = "Group A"
- self.dbsession.add(group_a)
+ group_a = GroupFactory()
+ group_b = GroupFactory()
- group_b = Group()
- group_b.name = "Group B"
- self.dbsession.add(group_b)
- self.dbsession.commit()
-
- group_a_schedule = TaskSchedule()
- group_a_schedule.group_id = group_a.id
- group_a_schedule.name = "Group A schedule"
- self.dbsession.add(group_a_schedule)
- self.dbsession.commit()
+ group_a_schedule = TaskScheduleFactory(group=group_a)
- self.user = User()
- self.user.upload_group_id = group_b.id
- self.user.username = "group b admin"
- self.user.set_password(self.req, "secret123")
- self.dbsession.add(self.user)
- self.dbsession.commit()
- self.req._debugging_user = self.user
+ group_b_user = UserFactory()
+ UserGroupMembershipFactory(
+ group_id=group_b.id, user_id=group_b_user.id, groupadmin=True
+ )
+ self.req._debugging_user = group_b_user
multidict = MultiDict(
[
@@ -298,14 +273,14 @@ def test_group_a_schedule_cannot_be_edited_by_group_b_admin(self) -> None:
("__formid__", "deform"),
(ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()),
(ViewParam.NAME, "Something else"),
- (ViewParam.GROUP_ID, self.group.id),
+ (ViewParam.GROUP_ID, group_b.id),
(FormAction.SUBMIT, "submit"),
]
)
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.SCHEDULE_ID: str(self.schedule.id)},
+ {ViewParam.SCHEDULE_ID: str(group_a_schedule.id)},
set_method_get=False,
)
@@ -317,21 +292,16 @@ def test_group_a_schedule_cannot_be_edited_by_group_b_admin(self) -> None:
self.assertIn("not a group administrator", cm.exception.message)
-class DeleteTaskScheduleViewTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
-
- def setUp(self) -> None:
- super().setUp()
-
- self.schedule = TaskSchedule()
- self.schedule.group_id = self.group.id
- self.schedule.name = "Test"
- self.dbsession.add(self.schedule)
- self.dbsession.commit()
-
+class DeleteTaskScheduleViewTests(DemoRequestTestCase):
def test_schedule_item_is_deleted(self) -> None:
+ user = self.req._debugging_user = UserFactory()
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ group_id=group.id, user_id=user.id, groupadmin=True
+ )
+ schedule = TaskScheduleFactory(group=group)
+ self.assertIsNotNone(self.dbsession.query(TaskSchedule).one_or_none())
+
multidict = MultiDict(
[
("_charset_", UTF8),
@@ -352,7 +322,7 @@ def test_schedule_item_is_deleted(self) -> None:
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.SCHEDULE_ID: str(self.schedule.id)},
+ {ViewParam.SCHEDULE_ID: str(schedule.id)},
set_method_get=False,
)
view = DeleteTaskScheduleView(self.req)
@@ -365,25 +335,14 @@ def test_schedule_item_is_deleted(self) -> None:
Routes.VIEW_TASK_SCHEDULES, e.exception.headers["Location"]
)
- item = self.dbsession.query(TaskScheduleItem).one_or_none()
-
- self.assertIsNone(item)
-
+ self.assertIsNone(self.dbsession.query(TaskSchedule).one_or_none())
-class AddTaskScheduleItemViewTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
+class AddTaskScheduleItemViewTests(BasicDatabaseTestCase):
def setUp(self) -> None:
super().setUp()
- self.schedule = TaskSchedule()
- self.schedule.group_id = self.group.id
- self.schedule.name = "Test"
-
- self.dbsession.add(self.schedule)
- self.dbsession.commit()
+ self.schedule = TaskScheduleFactory(group=self.group)
def test_schedule_item_form_displayed(self) -> None:
view = AddTaskScheduleItemView(self.req)
@@ -480,31 +439,19 @@ def test_non_existent_schedule_handled(self) -> None:
view.dispatch()
-class EditTaskScheduleItemViewTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
-
+class EditTaskScheduleItemViewTests(BasicDatabaseTestCase):
def setUp(self) -> None:
- from pendulum import Duration
-
super().setUp()
-
- self.schedule = TaskSchedule()
- self.schedule.group_id = self.group.id
- self.schedule.name = "Test"
- self.dbsession.add(self.schedule)
- self.dbsession.commit()
-
- self.item = TaskScheduleItem()
- self.item.schedule_id = self.schedule.id
- self.item.task_table_name = "ace3"
- self.item.due_from = Duration(days=30)
- self.item.due_by = Duration(days=60)
- self.dbsession.add(self.item)
- self.dbsession.commit()
+ self.schedule = TaskScheduleFactory(group=self.group)
def test_schedule_item_is_updated(self) -> None:
+ item = TaskScheduleItemFactory(
+ task_schedule=self.schedule,
+ task_table_name="ace3",
+ due_from=Duration(days=30),
+ due_by=Duration(days=60),
+ )
+
multidict = MultiDict(
[
("_charset_", UTF8),
@@ -529,7 +476,7 @@ def test_schedule_item_is_updated(self) -> None:
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.SCHEDULE_ITEM_ID: str(self.item.id)},
+ {ViewParam.SCHEDULE_ITEM_ID: str(item.id)},
set_method_get=False,
)
view = EditTaskScheduleItemView(self.req)
@@ -537,15 +484,22 @@ def test_schedule_item_is_updated(self) -> None:
with self.assertRaises(HTTPFound) as cm:
view.dispatch()
- self.assertEqual(self.item.task_table_name, "bmi")
+ self.assertEqual(item.task_table_name, "bmi")
self.assertEqual(cm.exception.status_code, 302)
self.assertIn(
f"{Routes.VIEW_TASK_SCHEDULE_ITEMS}"
- f"?{ViewParam.SCHEDULE_ID}={self.item.schedule_id}",
+ f"?{ViewParam.SCHEDULE_ID}={item.schedule_id}",
cm.exception.headers["Location"],
)
def test_schedule_item_is_not_updated_on_cancel(self) -> None:
+ item = TaskScheduleItemFactory(
+ task_schedule=self.schedule,
+ task_table_name="ace3",
+ due_from=Duration(days=30),
+ due_by=Duration(days=60),
+ )
+
multidict = MultiDict(
[
("_charset_", UTF8),
@@ -570,7 +524,7 @@ def test_schedule_item_is_not_updated_on_cancel(self) -> None:
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.SCHEDULE_ITEM_ID: str(self.item.id)},
+ {ViewParam.SCHEDULE_ITEM_ID: str(item.id)},
set_method_get=False,
)
view = EditTaskScheduleItemView(self.req)
@@ -578,7 +532,7 @@ def test_schedule_item_is_not_updated_on_cancel(self) -> None:
with self.assertRaises(HTTPFound):
view.dispatch()
- self.assertEqual(self.item.task_table_name, "ace3")
+ self.assertEqual(item.task_table_name, "ace3")
def test_non_existent_item_handled(self) -> None:
self.req.add_get_params({ViewParam.SCHEDULE_ITEM_ID: "99999"})
@@ -595,53 +549,37 @@ def test_null_item_handled(self) -> None:
view.dispatch()
def test_get_form_values(self) -> None:
+ item = TaskScheduleItemFactory(
+ task_schedule=self.schedule,
+ task_table_name="ace3",
+ due_from=Duration(days=30),
+ due_by=Duration(days=60),
+ )
view = EditTaskScheduleItemView(self.req)
- view.object = self.item
+ view.object = item
form_values = view.get_form_values()
self.assertEqual(form_values[ViewParam.SCHEDULE_ID], self.schedule.id)
self.assertEqual(
- form_values[ViewParam.TABLE_NAME], self.item.task_table_name
+ form_values[ViewParam.TABLE_NAME], item.task_table_name
)
- self.assertEqual(form_values[ViewParam.DUE_FROM], self.item.due_from)
+ self.assertEqual(form_values[ViewParam.DUE_FROM], item.due_from)
- due_within = self.item.due_by - self.item.due_from
+ due_within = item.due_by - item.due_from
self.assertEqual(form_values[ViewParam.DUE_WITHIN], due_within)
def test_group_a_item_cannot_be_edited_by_group_b_admin(self) -> None:
- from pendulum import Duration
+ group_a = GroupFactory()
+ group_b = GroupFactory()
- group_a = Group()
- group_a.name = "Group A"
- self.dbsession.add(group_a)
-
- group_b = Group()
- group_b.name = "Group B"
- self.dbsession.add(group_b)
- self.dbsession.commit()
-
- group_a_schedule = TaskSchedule()
- group_a_schedule.group_id = group_a.id
- group_a_schedule.name = "Group A schedule"
- self.dbsession.add(group_a_schedule)
- self.dbsession.commit()
-
- group_a_item = TaskScheduleItem()
- group_a_item.schedule_id = group_a_schedule.id
- group_a_item.task_table_name = "ace3"
- group_a_item.due_from = Duration(days=30)
- group_a_item.due_by = Duration(days=60)
- self.dbsession.add(group_a_item)
- self.dbsession.commit()
+ group_b_admin = self.req._debugging_user = UserFactory()
+ UserGroupMembershipFactory(
+ group_id=group_b.id, user_id=group_b_admin.id, groupadmin=True
+ )
- self.user = User()
- self.user.upload_group_id = group_b.id
- self.user.username = "group b admin"
- self.user.set_password(self.req, "secret123")
- self.dbsession.add(self.user)
- self.dbsession.commit()
- self.req._debugging_user = self.user
+ group_a_schedule = TaskScheduleFactory(group=group_a)
+ group_a_item = TaskScheduleItemFactory(task_schedule=group_a_schedule)
view = EditTaskScheduleItemView(self.req)
view.object = group_a_item
@@ -652,25 +590,16 @@ def test_group_a_item_cannot_be_edited_by_group_b_admin(self) -> None:
self.assertIn("not a group administrator", cm.exception.message)
-class DeleteTaskScheduleItemViewTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
-
+class DeleteTaskScheduleItemViewTests(BasicDatabaseTestCase):
def setUp(self) -> None:
super().setUp()
- self.schedule = TaskSchedule()
- self.schedule.group_id = self.group.id
- self.schedule.name = "Test"
- self.dbsession.add(self.schedule)
- self.dbsession.commit()
+ self.schedule = TaskScheduleFactory(group=self.group)
- self.item = TaskScheduleItem()
- self.item.schedule_id = self.schedule.id
- self.item.task_table_name = "ace3"
- self.dbsession.add(self.item)
- self.dbsession.commit()
+ self.schedule = TaskScheduleFactory()
+ self.item = TaskScheduleItemFactory(
+ task_schedule=self.schedule, task_table_name="ace3"
+ )
def test_delete_form_displayed(self) -> None:
view = DeleteTaskScheduleItemView(self.req)
@@ -754,10 +683,19 @@ def test_schedule_item_not_deleted_on_cancel(self) -> None:
self.assertIsNotNone(item)
-class EditFinalizedPatientViewTests(BasicDatabaseTestCase):
- """
- Unit tests.
- """
+class EditFinalizedPatientViewTests(DemoRequestTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.group = AnyIdNumGroupFactory()
+ user = self.req._debugging_user = UserFactory()
+
+ UserGroupMembershipFactory(
+ group_id=self.group.id,
+ user_id=user.id,
+ groupadmin=True,
+ view_all_patients_when_unfiltered=True,
+ )
def test_raises_when_patient_does_not_exists(self) -> None:
with self.assertRaises(HTTPBadRequest) as cm:
@@ -769,7 +707,7 @@ def test_raises_when_patient_does_not_exists(self) -> None:
@unittest.skip("Can't save patient in database without group")
def test_raises_when_patient_not_in_a_group(self) -> None:
- patient = self.create_patient(_group_id=None)
+ patient = PatientFactory(_group=None)
self.req.add_get_params({ViewParam.SERVER_PK: str(patient.pk)})
@@ -779,9 +717,7 @@ def test_raises_when_patient_not_in_a_group(self) -> None:
self.assertEqual(str(cm.exception), "Bad patient: not in a group")
def test_raises_when_not_authorized(self) -> None:
- patient = self.create_patient()
-
- self.req._debugging_user = User()
+ patient = PatientFactory()
with mock.patch.object(
self.req._debugging_user,
@@ -798,11 +734,7 @@ def test_raises_when_not_authorized(self) -> None:
)
def test_raises_when_patient_not_finalized(self) -> None:
- device = Device(name="Not the server device")
- self.req.dbsession.add(device)
- self.req.dbsession.commit()
-
- patient = self.create_patient(id=1, _device_id=device.id, _era=ERA_NOW)
+ patient = PatientFactory(_era=ERA_NOW, _group=self.group)
self.req.add_get_params({ViewParam.SERVER_PK: str(patient.pk)})
@@ -812,12 +744,23 @@ def test_raises_when_patient_not_finalized(self) -> None:
self.assertIn("Patient is not editable", str(cm.exception))
def test_patient_updated(self) -> None:
- patient = self.create_patient()
+ patient = PatientFactory(_group=self.group)
+ nhs_patient_idnum = NHSPatientIdNumFactory(patient=patient)
self.req.add_get_params(
{ViewParam.SERVER_PK: str(patient.pk)}, set_method_get=False
)
+ new_sex = Fake.en_gb.sex()
+ new_forename = Fake.en_gb.forename(new_sex)
+ new_surname = Fake.en_gb.last_name()
+ new_address = Fake.en_gb.address()
+ new_email = Fake.en_gb.email()
+ new_gp = Fake.en_gb.name()
+ new_other = Fake.en_us.paragraph()
+ new_dob = Fake.en_gb.consistent_date_of_birth()
+ new_nhs_number = Fake.en_gb.nhs_number()
+
multidict = MultiDict(
[
("_charset_", UTF8),
@@ -825,22 +768,22 @@ def test_patient_updated(self) -> None:
(ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()),
(ViewParam.SERVER_PK, str(patient.pk)),
(ViewParam.GROUP_ID, str(patient.group.id)),
- (ViewParam.FORENAME, "Jo"),
- (ViewParam.SURNAME, "Patient"),
+ (ViewParam.FORENAME, new_forename),
+ (ViewParam.SURNAME, new_surname),
("__start__", "dob:mapping"),
- ("date", "1958-04-19"),
+ ("date", new_dob),
("__end__", "dob:mapping"),
("__start__", "sex:rename"),
- ("deformField7", "X"),
+ ("deformField7", new_sex),
("__end__", "sex:rename"),
- (ViewParam.ADDRESS, "New address"),
- (ViewParam.EMAIL, "newjopatient@example.com"),
- (ViewParam.GP, "New GP"),
- (ViewParam.OTHER, "New other"),
+ (ViewParam.ADDRESS, new_address),
+ (ViewParam.EMAIL, new_email),
+ (ViewParam.GP, new_gp),
+ (ViewParam.OTHER, new_other),
("__start__", "id_references:sequence"),
("__start__", "idnum_sequence:mapping"),
- (ViewParam.WHICH_IDNUM, self.nhs_iddef.which_idnum),
- (ViewParam.IDNUM_VALUE, str(TEST_NHS_NUMBER_1)),
+ (ViewParam.WHICH_IDNUM, nhs_patient_idnum.which_idnum),
+ (ViewParam.IDNUM_VALUE, new_nhs_number),
("__end__", "idnum_sequence:mapping"),
("__end__", "id_references:sequence"),
("__start__", "danger:mapping"),
@@ -858,32 +801,32 @@ def test_patient_updated(self) -> None:
self.dbsession.commit()
- self.assertEqual(patient.forename, "Jo")
- self.assertEqual(patient.surname, "Patient")
- self.assertEqual(patient.dob.isoformat(), "1958-04-19")
- self.assertEqual(patient.sex, "X")
- self.assertEqual(patient.address, "New address")
- self.assertEqual(patient.email, "newjopatient@example.com")
- self.assertEqual(patient.gp, "New GP")
- self.assertEqual(patient.other, "New other")
+ self.assertEqual(patient.forename, new_forename)
+ self.assertEqual(patient.surname, new_surname)
+ self.assertEqual(patient.dob, new_dob)
+ self.assertEqual(patient.sex, new_sex)
+ self.assertEqual(patient.address, new_address)
+ self.assertEqual(patient.email, new_email)
+ self.assertEqual(patient.gp, new_gp)
+ self.assertEqual(patient.other, new_other)
idnum = patient.get_idnum_objects()[0]
self.assertEqual(idnum.patient_id, patient.id)
- self.assertEqual(idnum.which_idnum, self.nhs_iddef.which_idnum)
- self.assertEqual(idnum.idnum_value, TEST_NHS_NUMBER_1)
+ self.assertEqual(idnum.which_idnum, nhs_patient_idnum.which_idnum)
+ self.assertEqual(idnum.idnum_value, new_nhs_number)
self.assertEqual(len(patient.special_notes), 1)
note = patient.special_notes[0].note
self.assertIn("Patient details edited", note)
self.assertIn("forename", note)
- self.assertIn("Jo", note)
+ self.assertIn(new_forename, note)
self.assertIn("surname", note)
- self.assertIn("Patient", note)
+ self.assertIn(new_surname, note)
- self.assertIn("idnum1", note)
- self.assertIn(str(TEST_NHS_NUMBER_1), note)
+ self.assertIn(f"idnum{nhs_patient_idnum.which_idnum}", note)
+ self.assertIn(str(new_nhs_number), note)
messages = self.req.session.peek_flash(FlashQueue.SUCCESS)
@@ -891,46 +834,20 @@ def test_patient_updated(self) -> None:
f"Amended patient record with server PK {patient.pk}", messages[0]
)
self.assertIn("forename", messages[0])
- self.assertIn("Jo", messages[0])
+ self.assertIn(new_forename, messages[0])
self.assertIn("surname", messages[0])
- self.assertIn("Patient", messages[0])
+ self.assertIn(new_surname, messages[0])
self.assertIn("idnum1", messages[0])
- self.assertIn(str(TEST_NHS_NUMBER_1), messages[0])
+ self.assertIn(str(new_nhs_number), messages[0])
def test_message_when_no_changes(self) -> None:
- patient = self.create_patient(
- forename="Jo",
- surname="Patient",
- dob=datetime.date(1958, 4, 19),
- sex="F",
- address="Address",
- gp="GP",
- other="Other",
- )
- patient_idnum = self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER_1,
- )
- schedule1 = TaskSchedule()
- schedule1.group_id = self.group.id
- schedule1.name = "Test 1"
- self.dbsession.add(schedule1)
- self.dbsession.commit()
+ patient = PatientFactory(_group=self.group)
- patient_task_schedule = PatientTaskSchedule()
- patient_task_schedule.patient_pk = patient.pk
- patient_task_schedule.schedule_id = schedule1.id
- patient_task_schedule.start_datetime = local(2020, 6, 12, 9)
- patient_task_schedule.settings = {
- "name 1": "value 1",
- "name 2": "value 2",
- "name 3": "value 3",
- }
-
- self.dbsession.add(patient_task_schedule)
+ patient_idnum = NHSPatientIdNumFactory(
+ patient=patient,
+ )
self.req.add_get_params(
{ViewParam.SERVER_PK: str(patient.pk)}, set_method_get=False
)
@@ -951,6 +868,7 @@ def test_message_when_no_changes(self) -> None:
("deformField7", patient.sex),
("__end__", "sex:rename"),
(ViewParam.ADDRESS, patient.address),
+ (ViewParam.EMAIL, patient.email),
(ViewParam.GP, patient.gp),
(ViewParam.OTHER, patient.other),
("__start__", "id_references:sequence"),
@@ -963,25 +881,6 @@ def test_message_when_no_changes(self) -> None:
("target", "7836"),
("user_entry", "7836"),
("__end__", "danger:mapping"),
- ("__start__", "task_schedules:sequence"),
- ("__start__", "task_schedule_sequence:mapping"),
- ("schedule_id", schedule1.id),
- ("__start__", "start_datetime:mapping"),
- ("date", "2020-06-12"),
- ("time", "09:00:00"),
- ("__end__", "start_datetime:mapping"),
- (
- "settings",
- json.dumps(
- {
- "name 1": "value 1",
- "name 2": "value 2",
- "name 3": "value 3",
- }
- ),
- ),
- ("__end__", "task_schedule_sequence:mapping"),
- ("__end__", "task_schedules:sequence"),
(FormAction.SUBMIT, "submit"),
]
)
@@ -996,44 +895,11 @@ def test_message_when_no_changes(self) -> None:
self.assertIn("No changes required", messages[0])
def test_template_rendered_with_values(self) -> None:
- patient = self.create_patient(
- id=1,
- forename="Jo",
- surname="Patient",
- dob=datetime.date(1958, 4, 19),
- sex="F",
- address="Address",
- gp="GP",
- other="Other",
- )
- self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER_1,
- )
-
- from camcops_server.tasks import Bmi
-
- task1 = Bmi()
- task1.id = 1
- task1._device_id = patient.device_id
- task1._group_id = patient.group_id
- task1._era = patient.era
- task1.patient_id = patient.id
- task1.when_created = self.era_time
- task1._current = False
- self.dbsession.add(task1)
-
- task2 = Bmi()
- task2.id = 2
- task2._device_id = patient.device_id
- task2._group_id = patient.group_id
- task2._era = patient.era
- task2.patient_id = patient.id
- task2.when_created = self.era_time
- task2._current = False
- self.dbsession.add(task2)
- self.dbsession.commit()
+ patient = PatientFactory(_group=self.group)
+ NHSPatientIdNumFactory(patient=patient)
+
+ task1 = BmiFactory(patient=patient, _current=False)
+ task2 = BmiFactory(patient=patient, _current=False)
self.req.add_get_params({ViewParam.SERVER_PK: str(patient.pk)})
@@ -1053,54 +919,48 @@ def test_template_rendered_with_values(self) -> None:
def test_changes_to_simple_params(self) -> None:
view = EditFinalizedPatientView(self.req)
- patient = self.create_patient(
- id=1,
- forename="Jo",
- surname="Patient",
- dob=datetime.date(1958, 4, 19),
- sex="F",
- address="Address",
- email="jopatient@example.com",
- gp="GP",
- other=None,
- )
+ patient = PatientFactory()
+ old_forename = patient.forename
+ old_surname = patient.surname
+ old_address = patient.address
+ new_forename = Fake.en_gb.forename(patient.sex)
+ new_surname = Fake.en_gb.last_name()
+ new_address = Fake.en_gb.address()
+
view.object = patient
changes = OrderedDict() # type: OrderedDict
appstruct = {
- ViewParam.FORENAME: "Joanna",
- ViewParam.SURNAME: "Patient-Patient",
- ViewParam.DOB: datetime.date(1958, 4, 19),
- ViewParam.ADDRESS: "New address",
- ViewParam.OTHER: "",
+ ViewParam.FORENAME: new_forename,
+ ViewParam.SURNAME: new_surname,
+ ViewParam.DOB: patient.dob,
+ ViewParam.ADDRESS: new_address,
+ ViewParam.OTHER: patient.other,
}
view._save_simple_params(appstruct, changes)
- self.assertEqual(changes[ViewParam.FORENAME], ("Jo", "Joanna"))
self.assertEqual(
- changes[ViewParam.SURNAME], ("Patient", "Patient-Patient")
+ changes[ViewParam.FORENAME], (old_forename, new_forename)
+ )
+ self.assertEqual(
+ changes[ViewParam.SURNAME], (old_surname, new_surname)
)
self.assertNotIn(ViewParam.DOB, changes)
self.assertEqual(
- changes[ViewParam.ADDRESS], ("Address", "New address")
+ changes[ViewParam.ADDRESS], (old_address, new_address)
)
self.assertNotIn(ViewParam.OTHER, changes)
def test_changes_to_idrefs(self) -> None:
view = EditFinalizedPatientView(self.req)
- patient = self.create_patient(id=1)
- self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER_1,
- )
- self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.study_iddef.which_idnum,
- idnum_value=123,
- )
+ patient = PatientFactory()
+ nhs_patient_idnum = NHSPatientIdNumFactory(patient=patient)
+ study_patient_idnum = StudyPatientIdNumFactory(patient=patient)
+ rio_iddef = RioIdNumDefinitionFactory()
+ new_nhs_number = Fake.en_gb.nhs_number()
+ new_rio_number = 9999 # Below the range the factory would use
view.object = patient
@@ -1109,40 +969,43 @@ def test_changes_to_idrefs(self) -> None:
appstruct = {
ViewParam.ID_REFERENCES: [
{
- ViewParam.WHICH_IDNUM: self.nhs_iddef.which_idnum,
- ViewParam.IDNUM_VALUE: TEST_NHS_NUMBER_2,
+ ViewParam.WHICH_IDNUM: nhs_patient_idnum.which_idnum,
+ ViewParam.IDNUM_VALUE: new_nhs_number,
},
{
- ViewParam.WHICH_IDNUM: self.rio_iddef.which_idnum,
- ViewParam.IDNUM_VALUE: 456,
+ ViewParam.WHICH_IDNUM: rio_iddef.which_idnum,
+ ViewParam.IDNUM_VALUE: new_rio_number,
},
]
}
view._save_idrefs(appstruct, changes)
+ nhs_key = f"idnum{nhs_patient_idnum.which_idnum} (NHS number)"
+ self.assertIn(nhs_key, changes)
+
+ study_key = f"idnum{study_patient_idnum.which_idnum} (Study number)"
+ self.assertIn(study_key, changes)
+
+ rio_key = f"idnum{rio_iddef.which_idnum} (RiO number)"
+ self.assertIn(rio_key, changes)
+
self.assertEqual(
- changes["idnum1 (NHS number)"],
- (TEST_NHS_NUMBER_1, TEST_NHS_NUMBER_2),
+ changes[nhs_key],
+ (nhs_patient_idnum.idnum_value, new_nhs_number),
)
- self.assertEqual(changes["idnum3 (Study number)"], (123, None))
- self.assertEqual(changes["idnum2 (RiO number)"], (None, 456))
+ self.assertEqual(
+ changes[study_key], (study_patient_idnum.idnum_value, None)
+ )
+ self.assertEqual(changes[rio_key], (None, new_rio_number))
-class EditServerCreatedPatientViewTests(BasicDatabaseTestCase):
- """
- Unit tests.
- """
+class EditServerCreatedPatientViewTests(BasicDatabaseTestCase):
def test_group_updated(self) -> None:
- patient = self.create_patient(sex="F", as_server_patient=True)
- new_group = Group()
- new_group.name = "newgroup"
- new_group.description = "New group"
- new_group.upload_policy = "sex AND anyidnum"
- new_group.finalize_policy = "sex AND idnum1"
- self.dbsession.add(new_group)
- self.dbsession.commit()
+ patient = ServerCreatedPatientFactory(_group=self.group)
+ old_group = patient.group
+ new_group = GroupFactory()
view = EditServerCreatedPatientView(self.req)
view.object = patient
@@ -1155,12 +1018,12 @@ def test_group_updated(self) -> None:
messages = self.req.session.peek_flash(FlashQueue.SUCCESS)
- self.assertIn("testgroup", messages[0])
- self.assertIn("newgroup", messages[0])
+ self.assertIn(old_group.name, messages[0])
+ self.assertIn(new_group.name, messages[0])
self.assertIn("group:", messages[0])
def test_raises_when_not_created_on_the_server(self) -> None:
- patient = self.create_patient(id=1, _device_id=self.other_device.id)
+ patient = PatientFactory()
view = EditServerCreatedPatientView(self.req)
@@ -1172,41 +1035,29 @@ def test_raises_when_not_created_on_the_server(self) -> None:
self.assertIn("Patient is not editable", str(cm.exception))
def test_patient_task_schedules_updated(self) -> None:
- patient = self.create_patient(sex="F", as_server_patient=True)
-
- schedule1 = TaskSchedule()
- schedule1.group_id = self.group.id
- schedule1.name = "Test 1"
- self.dbsession.add(schedule1)
- schedule2 = TaskSchedule()
- schedule2.group_id = self.group.id
- schedule2.name = "Test 2"
- self.dbsession.add(schedule2)
- schedule3 = TaskSchedule()
- schedule3.group_id = self.group.id
- schedule3.name = "Test 3"
- self.dbsession.add(schedule3)
- self.dbsession.commit()
-
- patient_task_schedule = PatientTaskSchedule()
- patient_task_schedule.patient_pk = patient.pk
- patient_task_schedule.schedule_id = schedule1.id
- patient_task_schedule.start_datetime = local(2020, 6, 12, 9)
- patient_task_schedule.settings = {
- "name 1": "value 1",
- "name 2": "value 2",
- "name 3": "value 3",
- }
-
- self.dbsession.add(patient_task_schedule)
-
- patient_task_schedule = PatientTaskSchedule()
- patient_task_schedule.patient_pk = patient.pk
- patient_task_schedule.schedule_id = schedule3.id
-
- self.dbsession.add(patient_task_schedule)
- self.dbsession.commit()
+ patient = ServerCreatedPatientFactory()
+ nhs_patient_idnum = NHSPatientIdNumFactory(patient=patient)
+ group = patient._group
+
+ schedule1 = TaskScheduleFactory(group=group)
+ schedule2 = TaskScheduleFactory(group=group)
+ schedule3 = TaskScheduleFactory(group=group)
+
+ PatientTaskScheduleFactory(
+ patient=patient,
+ task_schedule=schedule1,
+ start_datetime=local(2020, 6, 12, 9),
+ settings={
+ "name 1": "value 1",
+ "name 2": "value 2",
+ "name 3": "value 3",
+ },
+ )
+ PatientTaskScheduleFactory(
+ patient=patient,
+ task_schedule=schedule3,
+ )
self.req.add_get_params(
{ViewParam.SERVER_PK: str(patient.pk)}, set_method_get=False
)
@@ -1216,11 +1067,14 @@ def test_patient_task_schedules_updated(self) -> None:
"name 2": "new value 2",
"name 3": "new value 3",
}
+ changed_schedule_1_datetime = local(2020, 6, 19, 8, 0, 0)
new_schedule_2_settings = {
"name 4": "value 4",
"name 5": "value 5",
"name 6": "value 6",
}
+ new_schedule_2_datetime = local(2020, 7, 1, 13, 45, 0)
+ new_nhs_number = Fake.en_gb.nhs_number()
multidict = MultiDict(
[
("_charset_", UTF8),
@@ -1241,8 +1095,8 @@ def test_patient_task_schedules_updated(self) -> None:
(ViewParam.OTHER, patient.other),
("__start__", "id_references:sequence"),
("__start__", "idnum_sequence:mapping"),
- (ViewParam.WHICH_IDNUM, self.nhs_iddef.which_idnum),
- (ViewParam.IDNUM_VALUE, str(TEST_NHS_NUMBER_1)),
+ (ViewParam.WHICH_IDNUM, nhs_patient_idnum.which_idnum),
+ (ViewParam.IDNUM_VALUE, str(new_nhs_number)),
("__end__", "idnum_sequence:mapping"),
("__end__", "id_references:sequence"),
("__start__", "danger:mapping"),
@@ -1253,16 +1107,16 @@ def test_patient_task_schedules_updated(self) -> None:
("__start__", "task_schedule_sequence:mapping"),
("schedule_id", schedule1.id),
("__start__", "start_datetime:mapping"),
- ("date", "2020-06-19"),
- ("time", "08:00:00"),
+ ("date", changed_schedule_1_datetime.to_date_string()),
+ ("time", changed_schedule_1_datetime.to_time_string()),
("__end__", "start_datetime:mapping"),
("settings", json.dumps(changed_schedule_1_settings)),
("__end__", "task_schedule_sequence:mapping"),
("__start__", "task_schedule_sequence:mapping"),
("schedule_id", schedule2.id),
("__start__", "start_datetime:mapping"),
- ("date", "2020-07-01"),
- ("time", "13:45:00"),
+ ("date", new_schedule_2_datetime.to_date_string()),
+ ("time", new_schedule_2_datetime.to_time_string()),
("__end__", "start_datetime:mapping"),
("settings", json.dumps(new_schedule_2_settings)),
("__end__", "task_schedule_sequence:mapping"),
@@ -1281,20 +1135,24 @@ def test_patient_task_schedules_updated(self) -> None:
schedules = {
pts.task_schedule.name: pts for pts in patient.task_schedules
}
- self.assertIn("Test 1", schedules)
- self.assertIn("Test 2", schedules)
- self.assertNotIn("Test 3", schedules)
+ self.assertIn(schedule1.name, schedules)
+ self.assertIn(schedule2.name, schedules)
+ self.assertNotIn(schedule3.name, schedules)
self.assertEqual(
- schedules["Test 1"].start_datetime, local(2020, 6, 19, 8)
+ schedules[schedule1.name].start_datetime,
+ changed_schedule_1_datetime,
+ )
+ self.assertEqual(
+ schedules[schedule1.name].settings, changed_schedule_1_settings
)
self.assertEqual(
- schedules["Test 1"].settings, changed_schedule_1_settings
+ schedules[schedule2.name].start_datetime,
+ new_schedule_2_datetime,
)
self.assertEqual(
- schedules["Test 2"].start_datetime, local(2020, 7, 1, 13, 45)
+ schedules[schedule2.name].settings, new_schedule_2_settings
)
- self.assertEqual(schedules["Test 2"].settings, new_schedule_2_settings)
messages = self.req.session.peek_flash(FlashQueue.SUCCESS)
@@ -1304,12 +1162,9 @@ def test_patient_task_schedules_updated(self) -> None:
self.assertIn("Task schedules", messages[0])
def test_unprivileged_user_cannot_edit_patient(self) -> None:
- patient = self.create_patient(sex="F", as_server_patient=True)
+ patient = ServerCreatedPatientFactory()
- user = self.create_user(username="testuser")
- self.dbsession.flush()
-
- self.req._debugging_user = user
+ self.req._debugging_user = UserFactory()
view = EditServerCreatedPatientView(self.req)
view.object = patient
@@ -1324,20 +1179,15 @@ def test_unprivileged_user_cannot_edit_patient(self) -> None:
)
def test_patient_can_be_assigned_the_same_schedule_twice(self) -> None:
- patient = self.create_patient(sex="F", as_server_patient=True)
-
- schedule1 = TaskSchedule()
- schedule1.group_id = self.group.id
- schedule1.name = "Test 1"
- self.dbsession.add(schedule1)
- self.dbsession.flush()
-
- pts = PatientTaskSchedule()
- pts.patient_pk = patient.pk
- pts.schedule_id = schedule1.id
- pts.start_datetime = local(2020, 6, 12, 12, 34)
- self.dbsession.add(pts)
- self.dbsession.commit()
+ patient = ServerCreatedPatientFactory()
+
+ schedule1 = TaskScheduleFactory(group=self.group)
+
+ pts = PatientTaskScheduleFactory(
+ patient=patient,
+ task_schedule=schedule1,
+ start_datetime=local(2020, 6, 12, 12, 34),
+ )
appstruct = {
ViewParam.TASK_SCHEDULES: [
@@ -1359,7 +1209,7 @@ def test_patient_can_be_assigned_the_same_schedule_twice(self) -> None:
view = EditServerCreatedPatientView(self.req)
view.object = patient
- changes = {}
+ changes: OrderedDict = OrderedDict()
view._save_task_schedules(appstruct, changes)
self.req.dbsession.commit()
@@ -1367,43 +1217,24 @@ def test_patient_can_be_assigned_the_same_schedule_twice(self) -> None:
self.assertEqual(patient.task_schedules[1].task_schedule, schedule1)
def test_form_values_for_existing_patient(self) -> None:
- patient = self.create_patient(
- id=1,
- forename="Jo",
- surname="Patient",
- dob=datetime.date(1958, 4, 19),
- sex="F",
- address="Address",
- email="jopatient@example.com",
- gp="GP",
- other="Other",
- )
-
- schedule1 = TaskSchedule()
- schedule1.group_id = self.group.id
- schedule1.name = "Test 1"
- self.dbsession.add(schedule1)
- self.dbsession.commit()
+ patient = PatientFactory()
- patient_task_schedule = PatientTaskSchedule()
- patient_task_schedule.patient_pk = patient.pk
- patient_task_schedule.schedule_id = schedule1.id
- patient_task_schedule.start_datetime = local(2020, 6, 12)
- patient_task_schedule.settings = {
- "name 1": "value 1",
- "name 2": "value 2",
- "name 3": "value 3",
- }
-
- self.dbsession.add(patient_task_schedule)
- self.dbsession.commit()
+ schedule1 = TaskScheduleFactory(
+ group=self.group,
+ )
- self.create_patient_idnum(
- patient_id=patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER_1,
+ patient_task_schedule = PatientTaskScheduleFactory(
+ patient=patient,
+ task_schedule=schedule1,
+ start_datetime=local(2020, 6, 12),
+ settings={
+ "name 1": "value 1",
+ "name 2": "value 2",
+ "name 3": "value 3",
+ },
)
+ patient_idnum = NHSPatientIdNumFactory(patient=patient)
self.req.add_get_params({ViewParam.SERVER_PK: str(patient.pk)})
view = EditServerCreatedPatientView(self.req)
@@ -1411,25 +1242,26 @@ def test_form_values_for_existing_patient(self) -> None:
form_values = view.get_form_values()
- self.assertEqual(form_values[ViewParam.FORENAME], "Jo")
- self.assertEqual(form_values[ViewParam.SURNAME], "Patient")
- self.assertEqual(
- form_values[ViewParam.DOB], datetime.date(1958, 4, 19)
- )
- self.assertEqual(form_values[ViewParam.SEX], "F")
- self.assertEqual(form_values[ViewParam.ADDRESS], "Address")
- self.assertEqual(form_values[ViewParam.EMAIL], "jopatient@example.com")
- self.assertEqual(form_values[ViewParam.GP], "GP")
- self.assertEqual(form_values[ViewParam.OTHER], "Other")
+ self.assertEqual(form_values[ViewParam.FORENAME], patient.forename)
+ self.assertEqual(form_values[ViewParam.SURNAME], patient.surname)
+ self.assertEqual(form_values[ViewParam.DOB], patient.dob)
+ self.assertEqual(form_values[ViewParam.SEX], patient.sex)
+ self.assertEqual(form_values[ViewParam.ADDRESS], patient.address)
+ self.assertEqual(form_values[ViewParam.EMAIL], patient.email)
+ self.assertEqual(form_values[ViewParam.GP], patient.gp)
+ self.assertEqual(form_values[ViewParam.OTHER], patient.other)
self.assertEqual(form_values[ViewParam.SERVER_PK], patient.pk)
self.assertEqual(form_values[ViewParam.GROUP_ID], patient.group.id)
idnum = form_values[ViewParam.ID_REFERENCES][0]
self.assertEqual(
- idnum[ViewParam.WHICH_IDNUM], self.nhs_iddef.which_idnum
+ idnum[ViewParam.WHICH_IDNUM],
+ patient_idnum.which_idnum,
+ )
+ self.assertEqual(
+ idnum[ViewParam.IDNUM_VALUE], patient_idnum.idnum_value
)
- self.assertEqual(idnum[ViewParam.IDNUM_VALUE], TEST_NHS_NUMBER_1)
task_schedule = form_values[ViewParam.TASK_SCHEDULES][0]
self.assertEqual(
@@ -1449,24 +1281,12 @@ def test_form_values_for_existing_patient(self) -> None:
)
-class AddPatientViewTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
-
+class AddPatientViewTests(BasicDatabaseTestCase):
def test_patient_created(self) -> None:
view = AddPatientView(self.req)
- schedule1 = TaskSchedule()
- schedule1.group_id = self.group.id
- schedule1.name = "Test 1"
- self.dbsession.add(schedule1)
-
- schedule2 = TaskSchedule()
- schedule2.group_id = self.group.id
- schedule2.name = "Test 2"
- self.dbsession.add(schedule2)
- self.dbsession.commit()
+ schedule1 = TaskScheduleFactory()
+ schedule2 = TaskScheduleFactory()
start_datetime1 = local(2020, 6, 12)
start_datetime2 = local(2020, 7, 1)
@@ -1475,6 +1295,9 @@ def test_patient_created(self) -> None:
{"name 1": "value 1", "name 2": "value 2", "name 3": "value 3"}
)
+ nhs_iddef = NHSIdNumDefinitionFactory()
+ nhs_number = Fake.en_gb.nhs_number()
+
appstruct = {
ViewParam.GROUP_ID: self.group.id,
ViewParam.FORENAME: "Jo",
@@ -1487,8 +1310,8 @@ def test_patient_created(self) -> None:
ViewParam.OTHER: "Other",
ViewParam.ID_REFERENCES: [
{
- ViewParam.WHICH_IDNUM: self.nhs_iddef.which_idnum,
- ViewParam.IDNUM_VALUE: 1192220552,
+ ViewParam.WHICH_IDNUM: nhs_iddef.which_idnum,
+ ViewParam.IDNUM_VALUE: nhs_number,
}
],
ViewParam.TASK_SCHEDULES: [
@@ -1506,12 +1329,12 @@ def test_patient_created(self) -> None:
}
view.save_object(appstruct)
+ self.dbsession.commit()
patient = cast(Patient, view.object)
server_device = Device.get_server_device(self.req.dbsession)
- self.assertEqual(patient.id, 1)
self.assertEqual(patient.device_id, server_device.id)
self.assertEqual(patient.era, ERA_NOW)
self.assertEqual(patient.group.id, self.group.id)
@@ -1526,27 +1349,32 @@ def test_patient_created(self) -> None:
self.assertEqual(patient.other, "Other")
idnum = patient.get_idnum_objects()[0]
- self.assertEqual(idnum.patient_id, 1)
- self.assertEqual(idnum.which_idnum, self.nhs_iddef.which_idnum)
- self.assertEqual(idnum.idnum_value, 1192220552)
+ self.assertEqual(idnum.patient_id, patient.id)
+ self.assertEqual(idnum.which_idnum, nhs_iddef.which_idnum)
+ self.assertEqual(idnum.idnum_value, nhs_number)
patient_task_schedules = {
pts.task_schedule.name: pts for pts in patient.task_schedules
}
- self.assertIn("Test 1", patient_task_schedules)
- self.assertIn("Test 2", patient_task_schedules)
+ self.assertIn(schedule1.name, patient_task_schedules)
+ self.assertIn(schedule2.name, patient_task_schedules)
self.assertEqual(
- patient_task_schedules["Test 1"].start_datetime, start_datetime1
+ patient_task_schedules[schedule1.name].start_datetime,
+ start_datetime1,
)
- self.assertEqual(patient_task_schedules["Test 1"].settings, settings1)
self.assertEqual(
- patient_task_schedules["Test 2"].start_datetime, start_datetime2
+ patient_task_schedules[schedule1.name].settings, settings1
+ )
+ self.assertEqual(
+ patient_task_schedules[schedule2.name].start_datetime,
+ start_datetime2,
)
def test_patient_takes_next_available_id(self) -> None:
- self.create_patient(id=1234, as_server_patient=True)
+ patient = ServerCreatedPatientFactory(id=1234)
+ nhs_iddef = NHSIdNumDefinitionFactory()
view = AddPatientView(self.req)
@@ -1561,8 +1389,8 @@ def test_patient_takes_next_available_id(self) -> None:
ViewParam.OTHER: "Other",
ViewParam.ID_REFERENCES: [
{
- ViewParam.WHICH_IDNUM: self.nhs_iddef.which_idnum,
- ViewParam.IDNUM_VALUE: 1192220552,
+ ViewParam.WHICH_IDNUM: nhs_iddef.which_idnum,
+ ViewParam.IDNUM_VALUE: Fake.en_gb.nhs_number(),
}
],
ViewParam.TASK_SCHEDULES: [],
@@ -1587,8 +1415,7 @@ def test_form_rendered_with_values(self) -> None:
self.assertIn("form", context)
def test_unprivileged_user_cannot_add_patient(self) -> None:
- user = self.create_user(username="testuser")
- self.dbsession.flush()
+ user = UserFactory(username="testuser")
self.req._debugging_user = user
@@ -1600,10 +1427,11 @@ def test_unprivileged_user_cannot_add_patient(self) -> None:
)
def test_group_listed_for_privileged_group_member(self) -> None:
- user = self.create_user(username="testuser")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_manage_patients=True)
- self.dbsession.commit()
+ user = UserFactory()
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_manage_patients=True
+ )
self.req._debugging_user = user
@@ -1616,50 +1444,24 @@ def test_group_listed_for_privileged_group_member(self) -> None:
context = args[0]
- self.assertIn("testgroup", context["form"])
+ self.assertIn(group.name, context["form"])
class DeleteServerCreatedPatientViewTests(BasicDatabaseTestCase):
- """
- Unit tests.
- """
-
def setUp(self) -> None:
super().setUp()
- self.patient = self.create_patient(
- as_server_patient=True,
- forename="Jo",
- surname="Patient",
- dob=datetime.date(1958, 4, 19),
- sex="F",
- address="Address",
- gp="GP",
- other="Other",
- )
-
- patient_pk = self.patient.pk
-
- idnum = self.create_patient_idnum(
- as_server_patient=True,
- patient_id=self.patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER_1,
- )
+ self.patient = ServerCreatedPatientFactory()
+ idnum = ServerCreatedNHSPatientIdNumFactory(patient=self.patient)
PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession)
- self.schedule = TaskSchedule()
- self.schedule.group_id = self.group.id
- self.schedule.name = "Test 1"
- self.dbsession.add(self.schedule)
- self.dbsession.commit()
+ self.schedule = TaskScheduleFactory(group=self.group)
- pts = PatientTaskSchedule()
- pts.patient_pk = patient_pk
- pts.schedule_id = self.schedule.id
- self.dbsession.add(pts)
- self.dbsession.commit()
+ PatientTaskScheduleFactory(
+ patient=self.patient,
+ task_schedule=self.schedule,
+ )
self.multidict = MultiDict(
[
@@ -1770,16 +1572,7 @@ def test_registered_patient_deleted(self) -> None:
# self.assertIsNone(user.single_patient_pk)
def test_unrelated_patient_unaffected(self) -> None:
- other_patient = self.create_patient(
- as_server_patient=True,
- forename="Mo",
- surname="Patient",
- dob=datetime.date(1968, 11, 30),
- sex="M",
- address="Address",
- gp="GP",
- other="Other",
- )
+ other_patient = ServerCreatedPatientFactory()
patient_pk = other_patient._pk
saved_patient = (
@@ -1790,12 +1583,7 @@ def test_unrelated_patient_unaffected(self) -> None:
self.assertIsNotNone(saved_patient)
- idnum = self.create_patient_idnum(
- as_server_patient=True,
- patient_id=other_patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER_2,
- )
+ idnum = ServerCreatedNHSPatientIdNumFactory(patient=other_patient)
PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession)
@@ -1812,11 +1600,9 @@ def test_unrelated_patient_unaffected(self) -> None:
self.assertIsNotNone(saved_idnum)
- pts = PatientTaskSchedule()
- pts.patient_pk = patient_pk
- pts.schedule_id = self.schedule.id
- self.dbsession.add(pts)
- self.dbsession.commit()
+ PatientTaskScheduleFactory(
+ patient=other_patient, task_schedule=self.schedule
+ )
self.req.fake_request_post_from_dict(self.multidict)
@@ -1866,8 +1652,7 @@ def test_unprivileged_user_cannot_delete_patient(self) -> None:
)
view = DeleteServerCreatedPatientView(self.req)
- user = self.create_user(username="testuser")
- self.dbsession.flush()
+ user = UserFactory(username="testuser")
self.req._debugging_user = user
@@ -1885,8 +1670,7 @@ def test_unprivileged_user_cannot_see_delete_form(self) -> None:
self.req.add_get_params({ViewParam.SERVER_PK: str(patient_pk)})
view = DeleteServerCreatedPatientView(self.req)
- user = self.create_user(username="testuser")
- self.dbsession.flush()
+ user = UserFactory()
self.req._debugging_user = user
@@ -1899,33 +1683,19 @@ def test_unprivileged_user_cannot_see_delete_form(self) -> None:
class EraseTaskTestCase(BasicDatabaseTestCase):
- """
- Unit tests.
- """
-
- def create_tasks(self) -> None:
- from camcops_server.tasks.bmi import Bmi
-
- self.task = Bmi()
- self.task.id = 1
- self.apply_standard_task_fields(self.task)
- patient = self.create_patient_with_one_idnum()
- self.task.patient_id = patient.id
+ def setUp(self) -> None:
+ super().setUp()
- self.dbsession.add(self.task)
- self.dbsession.commit()
+ self.patient = PatientFactory(_group=self.group)
class EraseTaskLeavingPlaceholderViewTests(EraseTaskTestCase):
- """
- Unit tests.
- """
-
def test_displays_form(self) -> None:
+ task = BmiFactory(patient=self.patient)
self.req.add_get_params(
{
- ViewParam.SERVER_PK: str(self.task.pk),
- ViewParam.TABLE_NAME: self.task.tablename,
+ ViewParam.SERVER_PK: str(task.pk),
+ ViewParam.TABLE_NAME: task.tablename,
},
set_method_get=False,
)
@@ -1940,13 +1710,14 @@ def test_displays_form(self) -> None:
self.assertIn("form", context)
def test_deletes_task_leaving_placeholder(self) -> None:
+ task = BmiFactory(patient=self.patient)
multidict = MultiDict(
[
("_charset_", UTF8),
("__formid__", "deform"),
(ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()),
- (ViewParam.SERVER_PK, self.task.pk),
- (ViewParam.TABLE_NAME, self.task.tablename),
+ (ViewParam.SERVER_PK, task.pk),
+ (ViewParam.TABLE_NAME, task.tablename),
("confirm_1_t", "true"),
("confirm_2_t", "true"),
("confirm_4_t", "true"),
@@ -1962,9 +1733,7 @@ def test_deletes_task_leaving_placeholder(self) -> None:
self.req.fake_request_post_from_dict(multidict)
view = EraseTaskLeavingPlaceholderView(self.req)
- with mock.patch.object(
- self.task, "manually_erase"
- ) as mock_manually_erase:
+ with mock.patch.object(task, "manually_erase") as mock_manually_erase:
with self.assertRaises(HTTPFound):
view.dispatch()
@@ -1976,12 +1745,13 @@ def test_deletes_task_leaving_placeholder(self) -> None:
self.assertEqual(request, self.req)
def test_task_not_deleted_on_cancel(self) -> None:
+ task = BmiFactory(patient=self.patient)
self.req.fake_request_post_from_dict({FormAction.CANCEL: "cancel"})
self.req.add_get_params(
{
- ViewParam.SERVER_PK: str(self.task.pk),
- ViewParam.TABLE_NAME: self.task.tablename,
+ ViewParam.SERVER_PK: str(task.pk),
+ ViewParam.TABLE_NAME: task.tablename,
},
set_method_get=False,
)
@@ -1990,17 +1760,18 @@ def test_task_not_deleted_on_cancel(self) -> None:
with self.assertRaises(HTTPFound):
view.dispatch()
- task = self.dbsession.query(self.task.__class__).one_or_none()
+ task = self.dbsession.query(task.__class__).one_or_none()
self.assertIsNotNone(task)
def test_redirect_on_cancel(self) -> None:
+ task = BmiFactory(patient=self.patient)
self.req.fake_request_post_from_dict({FormAction.CANCEL: "cancel"})
self.req.add_get_params(
{
- ViewParam.SERVER_PK: str(self.task.pk),
- ViewParam.TABLE_NAME: self.task.tablename,
+ ViewParam.SERVER_PK: str(task.pk),
+ ViewParam.TABLE_NAME: task.tablename,
},
set_method_get=False,
)
@@ -2012,11 +1783,11 @@ def test_redirect_on_cancel(self) -> None:
self.assertEqual(cm.exception.status_code, 302)
self.assertIn(f"/{Routes.TASK}", cm.exception.headers["Location"])
self.assertIn(
- f"{ViewParam.TABLE_NAME}={self.task.tablename}",
+ f"{ViewParam.TABLE_NAME}={task.tablename}",
cm.exception.headers["Location"],
)
self.assertIn(
- f"{ViewParam.SERVER_PK}={self.task.pk}",
+ f"{ViewParam.SERVER_PK}={task.pk}",
cm.exception.headers["Location"],
)
self.assertIn(
@@ -2037,14 +1808,12 @@ def test_raises_when_task_does_not_exist(self) -> None:
self.assertEqual(cm.exception.message, "No such task: phq9, PK=123")
def test_raises_when_task_is_live_on_tablet(self) -> None:
- self.task._era = ERA_NOW
- self.dbsession.add(self.task)
- self.dbsession.commit()
+ task = BmiFactory(patient=self.patient, _era=ERA_NOW)
self.req.add_get_params(
{
- ViewParam.SERVER_PK: str(self.task.pk),
- ViewParam.TABLE_NAME: self.task.tablename,
+ ViewParam.SERVER_PK: str(task.pk),
+ ViewParam.TABLE_NAME: task.tablename,
},
set_method_get=False,
)
@@ -2056,14 +1825,21 @@ def test_raises_when_task_is_live_on_tablet(self) -> None:
self.assertIn("Task is live on tablet", cm.exception.message)
def test_raises_when_user_not_authorized_to_erase(self) -> None:
+ task = BmiFactory(patient=self.patient)
+ user = UserFactory()
+
+ self.req._debugging_user = user
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group.id, groupadmin=True
+ )
+
with mock.patch.object(
- self.user, "authorized_to_erase_tasks", return_value=False
+ user, "authorized_to_erase_tasks", return_value=False
):
-
self.req.add_get_params(
{
- ViewParam.SERVER_PK: str(self.task.pk),
- ViewParam.TABLE_NAME: self.task.tablename,
+ ViewParam.SERVER_PK: str(task.pk),
+ ViewParam.TABLE_NAME: task.tablename,
},
set_method_get=False,
)
@@ -2075,14 +1851,12 @@ def test_raises_when_user_not_authorized_to_erase(self) -> None:
self.assertIn("Not authorized to erase tasks", cm.exception.message)
def test_raises_when_task_already_erased(self) -> None:
- self.task._manually_erased = True
- self.dbsession.add(self.task)
- self.dbsession.commit()
+ task = BmiFactory(patient=self.patient, _manually_erased=True)
self.req.add_get_params(
{
- ViewParam.SERVER_PK: str(self.task.pk),
- ViewParam.TABLE_NAME: self.task.tablename,
+ ViewParam.SERVER_PK: str(task.pk),
+ ViewParam.TABLE_NAME: task.tablename,
},
set_method_get=False,
)
@@ -2095,18 +1869,15 @@ def test_raises_when_task_already_erased(self) -> None:
class EraseTaskEntirelyViewTests(EraseTaskTestCase):
- """
- Unit tests.
- """
-
def test_deletes_task_entirely(self) -> None:
+ task = BmiFactory(patient=self.patient)
multidict = MultiDict(
[
("_charset_", UTF8),
("__formid__", "deform"),
(ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()),
- (ViewParam.SERVER_PK, self.task.pk),
- (ViewParam.TABLE_NAME, self.task.tablename),
+ (ViewParam.SERVER_PK, task.pk),
+ (ViewParam.TABLE_NAME, task.tablename),
("confirm_1_t", "true"),
("confirm_2_t", "true"),
("confirm_4_t", "true"),
@@ -2124,7 +1895,7 @@ def test_deletes_task_entirely(self) -> None:
view = EraseTaskEntirelyView(self.req)
with mock.patch.object(
- self.task, "delete_entirely"
+ task, "delete_entirely"
) as mock_delete_entirely:
with self.assertRaises(HTTPFound):
@@ -2140,36 +1911,37 @@ def test_deletes_task_entirely(self) -> None:
self.assertTrue(len(messages) > 0)
self.assertIn("Task erased", messages[0])
- self.assertIn(self.task.tablename, messages[0])
- self.assertIn("server PK {}".format(self.task.pk), messages[0])
-
+ self.assertIn(task.tablename, messages[0])
+ self.assertIn("server PK {}".format(task.pk), messages[0])
-class EditGroupViewTests(DemoDatabaseTestCase):
- """
- Unit tests.
- """
+class EditGroupViewTests(DemoRequestTestCase):
def test_group_updated(self) -> None:
- other_group_1 = Group()
- other_group_1.name = "other-group-1"
- self.dbsession.add(other_group_1)
+ groupadmin = self.req._debugging_user = UserFactory()
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ group_id=group.id, user_id=groupadmin.id, groupadmin=True
+ )
+ other_group_1 = GroupFactory()
+ other_group_2 = GroupFactory()
- other_group_2 = Group()
- other_group_2.name = "other-group-2"
- self.dbsession.add(other_group_2)
+ nhs_iddef = NHSIdNumDefinitionFactory()
- self.dbsession.commit()
+ new_name = "new-name"
+ new_description = "new description"
+ new_upload_policy = "anyidnum AND sex"
+ new_finalize_policy = f"idnum{nhs_iddef.which_idnum} AND sex"
multidict = MultiDict(
[
("_charset_", UTF8),
("__formid__", "deform"),
(ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()),
- (ViewParam.GROUP_ID, self.group.id),
- (ViewParam.NAME, "new-name"),
- (ViewParam.DESCRIPTION, "new description"),
- (ViewParam.UPLOAD_POLICY, "anyidnum AND sex"), # reversed
- (ViewParam.FINALIZE_POLICY, "idnum1 AND sex"), # reversed
+ (ViewParam.GROUP_ID, group.id),
+ (ViewParam.NAME, new_name),
+ (ViewParam.DESCRIPTION, new_description),
+ (ViewParam.UPLOAD_POLICY, new_upload_policy),
+ (ViewParam.FINALIZE_POLICY, new_finalize_policy),
("__start__", "group_ids:sequence"),
("group_id_sequence", str(other_group_1.id)),
("group_id_sequence", str(other_group_2.id)),
@@ -2182,26 +1954,38 @@ def test_group_updated(self) -> None:
with self.assertRaises(HTTPFound):
edit_group(self.req)
- self.assertEqual(self.group.name, "new-name")
- self.assertEqual(self.group.description, "new description")
- self.assertEqual(self.group.upload_policy, "anyidnum AND sex")
- self.assertEqual(self.group.finalize_policy, "idnum1 AND sex")
- self.assertIn(other_group_1, self.group.can_see_other_groups)
- self.assertIn(other_group_2, self.group.can_see_other_groups)
+ self.assertEqual(group.name, new_name)
+ self.assertEqual(group.description, new_description)
+ self.assertEqual(group.upload_policy, new_upload_policy)
+ self.assertEqual(group.finalize_policy, new_finalize_policy)
+ self.assertIn(other_group_1, group.can_see_other_groups)
+ self.assertIn(other_group_2, group.can_see_other_groups)
def test_ip_use_added(self) -> None:
from camcops_server.cc_modules.cc_ipuse import IpContexts
+ group = GroupFactory()
+ groupadmin = self.req._debugging_user = UserFactory()
+ UserGroupMembershipFactory(
+ group_id=group.id, user_id=groupadmin.id, groupadmin=True
+ )
+ nhs_iddef = NHSIdNumDefinitionFactory()
+
+ new_name = "new-name"
+ new_description = "new description"
+ new_upload_policy = "anyidnum AND sex"
+ new_finalize_policy = f"idnum{nhs_iddef.which_idnum} AND sex"
+
multidict = MultiDict(
[
("_charset_", UTF8),
("__formid__", "deform"),
(ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()),
- (ViewParam.GROUP_ID, self.group.id),
- (ViewParam.NAME, "new-name"),
- (ViewParam.DESCRIPTION, "new description"),
- (ViewParam.UPLOAD_POLICY, "anyidnum AND sex"),
- (ViewParam.FINALIZE_POLICY, "idnum1 AND sex"),
+ (ViewParam.GROUP_ID, group.id),
+ (ViewParam.NAME, new_name),
+ (ViewParam.DESCRIPTION, new_description),
+ (ViewParam.UPLOAD_POLICY, new_upload_policy),
+ (ViewParam.FINALIZE_POLICY, new_finalize_policy),
("__start__", "ip_use:mapping"),
(IpContexts.CLINICAL, "true"),
(IpContexts.COMMERCIAL, "true"),
@@ -2214,31 +1998,39 @@ def test_ip_use_added(self) -> None:
with self.assertRaises(HTTPFound):
edit_group(self.req)
- self.assertTrue(self.group.ip_use.clinical)
- self.assertTrue(self.group.ip_use.commercial)
- self.assertFalse(self.group.ip_use.educational)
- self.assertFalse(self.group.ip_use.research)
+ self.assertTrue(group.ip_use.clinical)
+ self.assertTrue(group.ip_use.commercial)
+ self.assertFalse(group.ip_use.educational)
+ self.assertFalse(group.ip_use.research)
def test_ip_use_updated(self) -> None:
from camcops_server.cc_modules.cc_ipuse import IpContexts
- self.group.ip_use.educational = True
- self.group.ip_use.research = True
- self.dbsession.add(self.group.ip_use)
- self.dbsession.commit()
+ group = GroupFactory(ip_use__educational=True, ip_use__research=True)
+ groupadmin = self.req._debugging_user = UserFactory()
+ UserGroupMembershipFactory(
+ group_id=group.id, user_id=groupadmin.id, groupadmin=True
+ )
+
+ old_id = group.ip_use.id
+
+ nhs_iddef = NHSIdNumDefinitionFactory()
- old_id = self.group.ip_use.id
+ new_name = "new-name"
+ new_description = "new description"
+ new_upload_policy = "anyidnum AND sex"
+ new_finalize_policy = f"idnum{nhs_iddef.which_idnum} AND sex"
multidict = MultiDict(
[
("_charset_", UTF8),
("__formid__", "deform"),
(ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()),
- (ViewParam.GROUP_ID, self.group.id),
- (ViewParam.NAME, "new-name"),
- (ViewParam.DESCRIPTION, "new description"),
- (ViewParam.UPLOAD_POLICY, "anyidnum AND sex"),
- (ViewParam.FINALIZE_POLICY, "idnum1 AND sex"),
+ (ViewParam.GROUP_ID, group.id),
+ (ViewParam.NAME, new_name),
+ (ViewParam.DESCRIPTION, new_description),
+ (ViewParam.UPLOAD_POLICY, new_upload_policy),
+ (ViewParam.FINALIZE_POLICY, new_finalize_policy),
("__start__", "ip_use:mapping"),
(IpContexts.CLINICAL, "true"),
(IpContexts.COMMERCIAL, "true"),
@@ -2251,32 +2043,23 @@ def test_ip_use_updated(self) -> None:
with self.assertRaises(HTTPFound):
edit_group(self.req)
- self.assertTrue(self.group.ip_use.clinical)
- self.assertTrue(self.group.ip_use.commercial)
- self.assertFalse(self.group.ip_use.educational)
- self.assertFalse(self.group.ip_use.research)
- self.assertEqual(self.group.ip_use.id, old_id)
+ self.assertTrue(group.ip_use.clinical)
+ self.assertTrue(group.ip_use.commercial)
+ self.assertFalse(group.ip_use.educational)
+ self.assertFalse(group.ip_use.research)
+ self.assertEqual(group.ip_use.id, old_id)
def test_other_groups_displayed_in_form(self) -> None:
- z_group = Group()
- z_group.name = "z-group"
- self.dbsession.add(z_group)
-
- a_group = Group()
- a_group.name = "a-group"
- self.dbsession.add(a_group)
- self.dbsession.commit()
+ z_group = GroupFactory(name="z-group")
+ a_group = GroupFactory(name="a-group")
other_groups = Group.get_groups_from_id_list(
self.dbsession, [z_group.id, a_group.id]
)
- self.group.can_see_other_groups = other_groups
-
- self.dbsession.add(self.group)
- self.dbsession.commit()
+ group = GroupFactory(can_see_other_groups=other_groups)
view = EditGroupView(self.req)
- view.object = self.group
+ view.object = group
form_values = view.get_form_values()
@@ -2285,59 +2068,38 @@ def test_other_groups_displayed_in_form(self) -> None:
)
def test_group_id_displayed_in_form(self) -> None:
+ group = GroupFactory()
view = EditGroupView(self.req)
- view.object = self.group
+ view.object = group
form_values = view.get_form_values()
- self.assertEqual(form_values[ViewParam.GROUP_ID], self.group.id)
+ self.assertEqual(form_values[ViewParam.GROUP_ID], group.id)
def test_ip_use_displayed_in_form(self) -> None:
+ group = GroupFactory()
view = EditGroupView(self.req)
- view.object = self.group
+ view.object = group
form_values = view.get_form_values()
- self.assertEqual(form_values[ViewParam.IP_USE], self.group.ip_use)
+ self.assertEqual(form_values[ViewParam.IP_USE], group.ip_use)
class SendEmailFromPatientTaskScheduleViewTests(BasicDatabaseTestCase):
def setUp(self) -> None:
super().setUp()
- self.patient = self.create_patient(
- as_server_patient=True,
- forename="Jo",
- surname="Patient",
- dob=datetime.date(1958, 4, 19),
- sex="F",
- address="Address",
- gp="GP",
- other="Other",
- )
-
- patient_pk = self.patient.pk
-
- idnum = self.create_patient_idnum(
- as_server_patient=True,
- patient_id=self.patient.id,
- which_idnum=self.nhs_iddef.which_idnum,
- idnum_value=TEST_NHS_NUMBER_1,
- )
+ self.patient = ServerCreatedPatientFactory()
+ idnum = ServerCreatedNHSPatientIdNumFactory(patient=self.patient)
PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession)
- self.schedule = TaskSchedule()
- self.schedule.group_id = self.group.id
- self.schedule.name = "Test 1"
- self.dbsession.add(self.schedule)
- self.dbsession.commit()
+ self.schedule = TaskScheduleFactory(group=self.group)
- self.pts = PatientTaskSchedule()
- self.pts.patient_pk = patient_pk
- self.pts.schedule_id = self.schedule.id
- self.dbsession.add(self.pts)
- self.dbsession.commit()
+ self.pts = PatientTaskScheduleFactory(
+ patient=self.patient, task_schedule=self.schedule
+ )
def test_displays_form(self) -> None:
self.req.add_get_params(
@@ -2576,8 +2338,7 @@ def test_email_record_created(
self.assertEqual(self.pts.emails[0].email.to, "patient@example.com")
def test_unprivileged_user_cannot_email_patient(self) -> None:
- user = self.create_user(username="testuser")
- self.dbsession.flush()
+ user = UserFactory(username="testuser")
self.req._debugging_user = user
@@ -2647,8 +2408,9 @@ def test_password_autocomplete_read_from_config(self) -> None:
self.assertIn('autocomplete="current-password"', context["form"])
def test_fails_when_user_locked_out(self) -> None:
- user = self.create_user(username="test")
- user.set_password(self.req, "secret")
+ user = UserFactory(
+ username="test", password__request=self.req, password="secret"
+ )
SecurityAccountLockout.lock_user_out(
self.req, user.username, lockout_minutes=1
)
@@ -2679,10 +2441,12 @@ def test_fails_when_user_locked_out(self) -> None:
@mock.patch("camcops_server.cc_modules.webview.audit")
def test_user_can_log_in(self, mock_audit: mock.Mock) -> None:
- user = self.create_user(username="test")
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
+ user = UserFactory(
+ username="test", password__request=self.req, password="secret"
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group.id, may_use_webviewer=True
+ )
multidict = MultiDict(
[
@@ -2715,14 +2479,17 @@ def test_user_can_log_in(self, mock_audit: mock.Mock) -> None:
self.assertEqual(kwargs["user_id"], user.id)
def test_user_with_totp_sees_token_form(self) -> None:
- user = self.create_user(
+ user = UserFactory(
username="test",
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.TOTP,
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
view = LoginView(self.req)
view.state.update(
@@ -2749,16 +2516,19 @@ def test_user_with_totp_sees_token_form(self) -> None:
def test_user_with_hotp_email_sees_token_form(
self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock
) -> None:
- user = self.create_user(
+ user = UserFactory(
username="test",
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.HOTP_EMAIL,
email="user@example.com",
hotp_counter=0,
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
view = LoginView(self.req)
view.state.update(
mfa_user_id=user.id,
@@ -2783,17 +2553,20 @@ def test_user_with_hotp_sms_sees_token_form(self) -> None:
SmsBackendNames.CONSOLE, {}
)
- phone_number = phonenumbers.parse(TEST_PHONE_NUMBER)
- user = self.create_user(
+ phone_number_str = Fake.en_gb.valid_phone_number()
+ user = UserFactory(
username="test",
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.HOTP_SMS,
- phone_number=phone_number,
+ phone_number=phonenumbers.parse(phone_number_str),
hotp_counter=0,
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
view = LoginView(self.req)
view.state.update(
@@ -2824,15 +2597,18 @@ def test_session_state_set_for_user_with_mfa(
mock_make_email: mock.Mock,
mock_send_msg: mock.Mock,
) -> None:
- user = self.create_user(
+ user = UserFactory(
username="test",
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.HOTP_EMAIL,
email="user@example.com",
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
multidict = MultiDict(
[
@@ -2876,16 +2652,19 @@ def test_user_with_hotp_is_sent_email(
self.req.config.email_use_tls = True
self.req.config.email_from = "server@example.com"
- user = self.create_user(
+ user = UserFactory(
username="test",
email="user@example.com",
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.HOTP_EMAIL,
hotp_counter=0,
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
multidict = MultiDict(
[
@@ -2926,18 +2705,21 @@ def test_user_with_hotp_is_sent_sms(self) -> None:
)
self.req.config.sms_config = test_config
- phone_number = phonenumbers.parse(TEST_PHONE_NUMBER)
- user = self.create_user(
+ phone_number_str = Fake.en_gb.valid_phone_number()
+ user = UserFactory(
username="test",
email="user@example.com",
- phone_number=phone_number,
+ phone_number=phonenumbers.parse(phone_number_str),
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.HOTP_SMS,
hotp_counter=0,
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
multidict = MultiDict(
[
@@ -2958,7 +2740,7 @@ def test_user_with_hotp_is_sent_sms(self) -> None:
expected_message = f"Your CamCOPS verification code is {expected_code}"
self.assertIn(
- ConsoleSmsBackend.make_msg(TEST_PHONE_NUMBER, expected_message),
+ ConsoleSmsBackend.make_msg(phone_number_str, expected_message),
logging_cm.output[0],
)
@@ -2967,16 +2749,19 @@ def test_user_with_hotp_is_sent_sms(self) -> None:
def test_login_with_hotp_increments_counter(
self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock
) -> None:
- user = self.create_user(
+ user = UserFactory(
username="test",
email="user@example.com",
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.HOTP_EMAIL,
hotp_counter=0,
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
multidict = MultiDict(
[
@@ -2996,15 +2781,18 @@ def test_login_with_hotp_increments_counter(
@mock.patch("camcops_server.cc_modules.webview.audit")
def test_user_with_totp_can_log_in(self, mock_audit: mock.Mock) -> None:
- user = self.create_user(
+ user = UserFactory(
username="test",
mfa_method=MfaMethod.TOTP,
mfa_secret_key=pyotp.random_base32(),
+ password__request=self.req,
+ password="secret",
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
+ )
totp = pyotp.TOTP(user.mfa_secret_key)
@@ -3042,16 +2830,18 @@ def test_user_with_totp_can_log_in(self, mock_audit: mock.Mock) -> None:
@mock.patch("camcops_server.cc_modules.webview.audit")
def test_user_with_hotp_can_log_in(self, mock_audit: mock.Mock) -> None:
- user = self.create_user(
+ user = UserFactory(
username="test",
mfa_method=MfaMethod.HOTP_EMAIL,
mfa_secret_key=pyotp.random_base32(),
hotp_counter=1,
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
-
- self.create_membership(user, self.group, may_use_webviewer=True)
hotp = pyotp.HOTP(user.mfa_secret_key)
multidict = MultiDict(
@@ -3087,15 +2877,18 @@ def test_user_with_hotp_can_log_in(self, mock_audit: mock.Mock) -> None:
self.assert_state_is_finished()
def test_form_state_cleared_on_failed_login(self) -> None:
- user = self.create_user(
+ user = UserFactory(
username="test",
mfa_method=MfaMethod.HOTP_EMAIL,
mfa_secret_key=pyotp.random_base32(),
hotp_counter=1,
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
hotp = pyotp.HOTP(user.mfa_secret_key)
@@ -3123,14 +2916,17 @@ def test_form_state_cleared_on_failed_login(self) -> None:
def test_user_cannot_log_in_if_timed_out(self) -> None:
self.req.config.mfa_timeout_s = 600
- user = self.create_user(
+ user = UserFactory(
username="test",
mfa_method=MfaMethod.TOTP,
mfa_secret_key=pyotp.random_base32(),
+ password__request=self.req,
+ password="secret",
+ )
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=group.id, may_use_webviewer=True
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
- self.create_membership(user, self.group, may_use_webviewer=True)
totp = pyotp.TOTP(user.mfa_secret_key)
@@ -3159,11 +2955,12 @@ def test_user_cannot_log_in_if_timed_out(self) -> None:
mock_fail_timed_out.assert_called_once()
def test_unprivileged_user_cannot_log_in(self) -> None:
- user = self.create_user(username="test")
- user.set_password(self.req, "secret")
- self.dbsession.flush()
-
- self.create_membership(user, self.group, may_use_webviewer=False)
+ user = UserFactory(
+ username="test", password__request=self.req, password="secret"
+ )
+ UserGroupMembershipFactory(
+ user_id=user.id, group_id=self.group.id, may_use_webviewer=False
+ )
multidict = MultiDict(
[
@@ -3235,7 +3032,7 @@ def test_timed_out_false_when_no_authenticated_user(self) -> None:
def test_timed_out_false_when_no_authentication_time(self) -> None:
view = LoginView(self.req)
- user = self.create_user(username="test")
+ user = UserFactory(username="test")
# Should never be the case that we have a user ID but no
# authentication time
view.state["mfa_user_id"] = user.id
@@ -3245,8 +3042,7 @@ def test_timed_out_false_when_no_authentication_time(self) -> None:
class EditUserViewTests(BasicDatabaseTestCase):
def test_redirect_on_cancel(self) -> None:
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
+ regular_user = UserFactory(username="regular_user")
self.req.fake_request_post_from_dict({FormAction.CANCEL: "cancel"})
self.req.add_get_params(
{ViewParam.USER_ID: str(regular_user.id)}, set_method_get=False
@@ -3261,10 +3057,9 @@ def test_redirect_on_cancel(self) -> None:
)
def test_raises_if_user_may_not_edit_another(self) -> None:
- self.req.add_get_params({ViewParam.USER_ID: str(self.user.id)})
+ self.req.add_get_params({ViewParam.USER_ID: str(self.system_user.id)})
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
+ regular_user = UserFactory(username="regular_user")
self.req._debugging_user = regular_user
with self.assertRaises(HTTPBadRequest) as cm:
edit_user(self.req)
@@ -3272,8 +3067,7 @@ def test_raises_if_user_may_not_edit_another(self) -> None:
self.assertIn("Nobody may edit the system user", cm.exception.message)
def test_superuser_sees_full_form(self) -> None:
- superuser = self.create_user(username="admin", superuser=True)
- self.dbsession.flush()
+ superuser = UserFactory(username="admin", superuser=True)
self.req._debugging_user = superuser
self.req.add_get_params({ViewParam.USER_ID: str(superuser.id)})
@@ -3283,12 +3077,13 @@ def test_superuser_sees_full_form(self) -> None:
self.assertIn("Superuser (CAUTION!)", response.body.decode(UTF8))
def test_groupadmin_sees_groupadmin_form(self) -> None:
- groupadmin = self.create_user(username="groupadmin")
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
- self.create_membership(groupadmin, self.group, groupadmin=True)
- self.create_membership(regular_user, self.group)
- self.dbsession.flush()
+ groupadmin = UserFactory(username="groupadmin")
+ regular_user = UserFactory(username="regular_user")
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=groupadmin.id, group_id=group.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(user_id=regular_user.id, group_id=group.id)
self.req._debugging_user = groupadmin
self.req.add_get_params({ViewParam.USER_ID: str(regular_user.id)})
@@ -3300,9 +3095,8 @@ def test_groupadmin_sees_groupadmin_form(self) -> None:
self.assertNotIn("Superuser (CAUTION!)", content)
def test_raises_for_conflicting_user_name(self) -> None:
- self.create_user(username="existing_user")
- other_user = self.create_user(username="other_user")
- self.dbsession.flush()
+ UserFactory(username="existing_user")
+ other_user = UserFactory(username="other_user")
multidict = MultiDict(
[
@@ -3321,13 +3115,12 @@ def test_raises_for_conflicting_user_name(self) -> None:
self.assertIn("Can't rename user", cm.exception.message)
def test_user_is_updated(self) -> None:
- user = self.create_user(
+ user = UserFactory(
username="old_username",
fullname="Old Name",
email="old@example.com",
language="da_DK",
)
- self.dbsession.flush()
multidict = MultiDict(
[
@@ -3352,9 +3145,8 @@ def test_user_is_updated(self) -> None:
self.assertEqual(user.language, "en_GB")
def test_user_is_added_to_group(self) -> None:
- user = self.create_user(username="regular_user")
- group = self.create_group("group")
- self.dbsession.flush()
+ user = UserFactory()
+ group = GroupFactory()
multidict = MultiDict(
[
@@ -3377,15 +3169,19 @@ def test_user_is_added_to_group(self) -> None:
mock_set_group_ids.assert_called_once_with([group.id])
def test_user_stays_in_group_the_groupadmin_cannot_edit(self) -> None:
- regular_user = self.create_user(username="regular_user")
- group_b_admin = self.create_user(username="group_b_admin")
- group_a = self.create_group("group_a")
- group_b = self.create_group("group_b")
- self.dbsession.flush()
- self.create_membership(regular_user, group_a)
- self.create_membership(regular_user, group_b)
- self.create_membership(group_b_admin, group_b, groupadmin=True)
- self.dbsession.flush()
+ regular_user = UserFactory(username="regular_user")
+ group_b_admin = UserFactory(username="group_b_admin")
+ group_a = GroupFactory()
+ group_b = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=regular_user.id, group_id=group_a.id
+ )
+ UserGroupMembershipFactory(
+ user_id=regular_user.id, group_id=group_b.id
+ )
+ UserGroupMembershipFactory(
+ user_id=group_b_admin.id, group_id=group_b.id, groupadmin=True
+ )
self.req._debugging_user = group_b_admin
multidict = MultiDict(
@@ -3411,18 +3207,22 @@ def test_user_stays_in_group_the_groupadmin_cannot_edit(self) -> None:
mock_set_group_ids.assert_called_once_with([group_a.id, group_b.id])
def test_upload_group_id_unset_when_membership_removed(self) -> None:
- group_a = self.create_group("group_a")
- group_b = self.create_group("group_b")
- regular_user = self.create_user(
- username="regular_user", upload_group=group_a
- )
- groupadmin = self.create_user(username="groupadmin")
- self.dbsession.flush()
- self.create_membership(regular_user, group_a)
- self.create_membership(regular_user, group_b)
- self.create_membership(groupadmin, group_a, groupadmin=True)
- self.create_membership(groupadmin, group_b, groupadmin=True)
- self.dbsession.flush()
+ group_a = GroupFactory()
+ group_b = GroupFactory()
+ regular_user = UserFactory(upload_group=group_a)
+ groupadmin = UserFactory()
+ UserGroupMembershipFactory(
+ group_id=group_a.id, user_id=regular_user.id
+ )
+ UserGroupMembershipFactory(
+ group_id=group_b.id, user_id=regular_user.id
+ )
+ UserGroupMembershipFactory(
+ group_id=group_a.id, user_id=groupadmin.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(
+ group_id=group_b.id, user_id=groupadmin.id, groupadmin=True
+ )
self.req._debugging_user = groupadmin
multidict = MultiDict(
@@ -3445,20 +3245,24 @@ def test_upload_group_id_unset_when_membership_removed(self) -> None:
self.assertIsNone(regular_user.upload_group_id)
def test_get_form_values(self) -> None:
- regular_user = self.create_user(
+ regular_user = UserFactory(
username="regular_user",
fullname="Full Name",
email="user@example.com",
language="da_DK",
)
- group_b_admin = self.create_user(username="group_b_admin")
- group_a = self.create_group("group_a")
- group_b = self.create_group("group_b")
- self.dbsession.flush()
- self.create_membership(regular_user, group_a)
- self.create_membership(regular_user, group_b)
- self.create_membership(group_b_admin, group_b, groupadmin=True)
- self.dbsession.flush()
+ group_b_admin = UserFactory(username="group_b_admin")
+ group_a = GroupFactory()
+ group_b = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=regular_user.id, group_id=group_a.id
+ )
+ UserGroupMembershipFactory(
+ user_id=regular_user.id, group_id=group_b.id
+ )
+ UserGroupMembershipFactory(
+ user_id=group_b_admin.id, group_id=group_b.id, groupadmin=True
+ )
self.req._debugging_user = group_b_admin
view = EditUserGroupAdminView(self.req)
@@ -3480,12 +3284,11 @@ def test_get_form_values(self) -> None:
self.assertEqual(form_values[ViewParam.GROUP_IDS], [group_b.id])
def test_raises_if_email_address_used_for_mfa(self) -> None:
- regular_user = self.create_user(
+ regular_user = UserFactory(
username="regular_user",
mfa_method=MfaMethod.HOTP_EMAIL,
email="user@example.com",
)
- self.dbsession.flush()
multidict = MultiDict(
[
@@ -3509,11 +3312,9 @@ def test_raises_if_email_address_used_for_mfa(self) -> None:
class EditOwnUserMfaViewTests(BasicDatabaseTestCase):
def test_get_form_values_mfa_method(self) -> None:
- regular_user = self.create_user(
+ regular_user = UserFactory(
username="regular_user", mfa_method=MfaMethod.HOTP_SMS
)
- self.dbsession.flush()
-
self.req._debugging_user = regular_user
view = EditOwnUserMfaView(self.req)
@@ -3527,13 +3328,11 @@ def test_get_form_values_mfa_method(self) -> None:
)
def test_get_form_values_hotp_email(self) -> None:
- regular_user = self.create_user(
+ regular_user = UserFactory(
username="regular_user",
mfa_method=MfaMethod.HOTP_EMAIL,
email="regular_user@example.com",
)
- self.dbsession.flush()
-
self.req._debugging_user = regular_user
view = EditOwnUserMfaView(self.req)
@@ -3556,13 +3355,11 @@ def test_get_form_values_hotp_email(self) -> None:
self.assertEqual(form_values[ViewParam.EMAIL], regular_user.email)
def test_get_form_values_hotp_sms(self) -> None:
- regular_user = self.create_user(
+ regular_user = UserFactory(
username="regular_user",
mfa_method=MfaMethod.HOTP_SMS,
- phone_number=phonenumbers.parse(TEST_PHONE_NUMBER),
+ phone_number=phonenumbers.parse(Fake.en_gb.valid_phone_number()),
)
- self.dbsession.flush()
-
self.req._debugging_user = regular_user
view = EditOwnUserMfaView(self.req)
@@ -3587,11 +3384,9 @@ def test_get_form_values_hotp_sms(self) -> None:
)
def test_get_form_values_totp(self) -> None:
- regular_user = self.create_user(
+ regular_user = UserFactory(
username="regular_user", mfa_method=MfaMethod.TOTP
)
- self.dbsession.flush()
-
self.req._debugging_user = regular_user
view = EditOwnUserMfaView(self.req)
@@ -3613,13 +3408,11 @@ def test_get_form_values_totp(self) -> None:
)
def test_user_can_set_secret_key(self) -> None:
- regular_user = self.create_user(username="regular_user")
+ regular_user = UserFactory(username="regular_user")
regular_user.mfa_method = MfaMethod.TOTP
regular_user.ensure_mfa_info()
# ... otherwise, the absence of e.g. the HOTP counter will cause a
# secret key reset.
- self.dbsession.flush()
-
mfa_secret_key = pyotp.random_base32()
multidict = MultiDict(
@@ -3640,9 +3433,7 @@ def test_user_can_set_secret_key(self) -> None:
self.assertEqual(regular_user.mfa_secret_key, mfa_secret_key)
def test_user_can_set_method_totp(self) -> None:
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
-
+ regular_user = UserFactory(username="regular_user")
multidict = MultiDict(
[
(ViewParam.MFA_METHOD, MfaMethod.TOTP),
@@ -3660,9 +3451,7 @@ def test_user_can_set_method_totp(self) -> None:
self.assertEqual(regular_user.mfa_method, MfaMethod.TOTP)
def test_user_can_set_method_hotp_email(self) -> None:
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
-
+ regular_user = UserFactory(username="regular_user")
multidict = MultiDict(
[
(ViewParam.MFA_METHOD, MfaMethod.HOTP_EMAIL),
@@ -3681,9 +3470,7 @@ def test_user_can_set_method_hotp_email(self) -> None:
self.assertEqual(regular_user.hotp_counter, 0)
def test_user_can_set_method_hotp_sms(self) -> None:
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
-
+ regular_user = UserFactory(username="regular_user")
multidict = MultiDict(
[
(ViewParam.MFA_METHOD, MfaMethod.HOTP_SMS),
@@ -3702,11 +3489,9 @@ def test_user_can_set_method_hotp_sms(self) -> None:
self.assertEqual(regular_user.hotp_counter, 0)
def test_user_can_disable_mfa(self) -> None:
- regular_user = self.create_user(
+ regular_user = UserFactory(
username="regular_user", mfa_method=MfaMethod.TOTP
)
- self.dbsession.flush()
-
multidict = MultiDict(
[
(ViewParam.MFA_METHOD, MfaMethod.NO_MFA),
@@ -3730,13 +3515,14 @@ def test_user_can_disable_mfa(self) -> None:
self.assertEqual(regular_user.mfa_method, MfaMethod.NO_MFA)
def test_user_can_set_phone_number(self) -> None:
- regular_user = self.create_user(username="regular_user")
+ regular_user = UserFactory(username="regular_user")
regular_user.mfa_method = MfaMethod.HOTP_SMS
- self.dbsession.flush()
+
+ phone_number_str = Fake.en_gb.valid_phone_number()
multidict = MultiDict(
[
- (ViewParam.PHONE_NUMBER, TEST_PHONE_NUMBER),
+ (ViewParam.PHONE_NUMBER, phone_number_str),
(FormAction.SUBMIT, "submit"),
]
)
@@ -3749,16 +3535,15 @@ def test_user_can_set_phone_number(self) -> None:
view.dispatch()
- test_number = phonenumbers.parse(TEST_PHONE_NUMBER)
- self.assertEqual(regular_user.phone_number, test_number)
+ self.assertEqual(
+ regular_user.phone_number, phonenumbers.parse(phone_number_str)
+ )
def test_user_can_set_email_address(self) -> None:
- regular_user = self.create_user(username="regular_user")
+ regular_user = UserFactory(username="regular_user")
# We're going to force this user to the e-mail verification step, so
# we need to ensure it's set to use e-mail MFA:
regular_user.mfa_method = MfaMethod.HOTP_EMAIL
- self.dbsession.flush()
-
multidict = MultiDict(
[
(ViewParam.EMAIL, "regular_user@example.com"),
@@ -3797,8 +3582,7 @@ def test_raises_for_invalid_user(self) -> None:
self.assertIn("Cannot find User with id:123", cm.exception.message)
def test_raises_when_user_may_not_edit_other_user(self) -> None:
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
+ regular_user = UserFactory(username="regular_user")
multidict = MultiDict(
[
("__start__", "new_password:mapping"),
@@ -3812,7 +3596,7 @@ def test_raises_when_user_may_not_edit_other_user(self) -> None:
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.USER_ID: str(self.user.id)}, set_method_get=False
+ {ViewParam.USER_ID: str(self.system_user.id)}, set_method_get=False
)
view = ChangeOtherPasswordView(self.req)
@@ -3822,12 +3606,13 @@ def test_raises_when_user_may_not_edit_other_user(self) -> None:
self.assertIn("Nobody may edit the system user", cm.exception.message)
def test_password_set(self) -> None:
- groupadmin = self.create_user(username="groupadmin")
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
- self.create_membership(groupadmin, self.group, groupadmin=True)
- self.create_membership(regular_user, self.group)
- self.dbsession.flush()
+ groupadmin = UserFactory(username="groupadmin")
+ regular_user = UserFactory(username="regular_user")
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=groupadmin.id, group_id=group.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(user_id=regular_user.id, group_id=group.id)
self.assertFalse(regular_user.must_change_password)
@@ -3863,13 +3648,13 @@ def test_password_set(self) -> None:
self.assertIn("Password changed for user 'regular_user'", messages[0])
def test_user_forced_to_change_password(self) -> None:
- groupadmin = self.create_user(username="groupadmin")
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
- self.create_membership(groupadmin, self.group, groupadmin=True)
- self.create_membership(regular_user, self.group)
- self.dbsession.flush()
-
+ groupadmin = UserFactory(username="groupadmin")
+ regular_user = UserFactory(username="regular_user")
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=groupadmin.id, group_id=group.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(user_id=regular_user.id, group_id=group.id)
multidict = MultiDict(
[
(ViewParam.MUST_CHANGE_PASSWORD, "true"),
@@ -3898,8 +3683,7 @@ def test_user_forced_to_change_password(self) -> None:
mock_force_change.assert_called_once()
def test_redirects_if_editing_own_account(self) -> None:
- superuser = self.create_user(username="admin", superuser=True)
- self.dbsession.flush()
+ superuser = UserFactory(username="admin", superuser=True)
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(superuser.id)}, set_method_get=False
@@ -3919,7 +3703,7 @@ def test_redirects_if_editing_own_account(self) -> None:
def test_user_sees_otp_form_if_mfa_setup(
self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock
) -> None:
- superuser = self.create_user(
+ superuser = UserFactory(
username="admin",
superuser=True,
email="admin@example.com",
@@ -3928,8 +3712,7 @@ def test_user_sees_otp_form_if_mfa_setup(
hotp_counter=0,
)
- user = self.create_user(username="user")
- self.dbsession.flush()
+ user = UserFactory(username="user")
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(user.id)}, set_method_get=False
@@ -3951,19 +3734,17 @@ def test_code_sent_if_mfa_setup(self) -> None:
SmsBackendNames.CONSOLE, {}
)
- phone_number = phonenumbers.parse(TEST_PHONE_NUMBER)
- superuser = self.create_user(
+ phone_number_str = Fake.en_gb.valid_phone_number()
+ superuser = UserFactory(
username="admin",
superuser=True,
email="admin@example.com",
- phone_number=phone_number,
+ phone_number=phonenumbers.parse(phone_number_str),
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.HOTP_SMS,
hotp_counter=0,
)
- user = self.create_user(username="user", email="user@example.com")
- self.dbsession.flush()
-
+ user = UserFactory(username="user", email="user@example.com")
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(user.id)}, set_method_get=False
@@ -3977,12 +3758,12 @@ def test_code_sent_if_mfa_setup(self) -> None:
expected_message = f"Your CamCOPS verification code is {expected_code}"
self.assertIn(
- ConsoleSmsBackend.make_msg(TEST_PHONE_NUMBER, expected_message),
+ ConsoleSmsBackend.make_msg(phone_number_str, expected_message),
logging_cm.output[0],
)
def test_user_can_enter_token(self) -> None:
- superuser = self.create_user(
+ superuser = UserFactory(
username="admin",
superuser=True,
mfa_method=MfaMethod.HOTP_EMAIL,
@@ -3990,9 +3771,7 @@ def test_user_can_enter_token(self) -> None:
email="user@example.com",
hotp_counter=1,
)
- user = self.create_user(username="user", email="user@example.com")
- self.dbsession.flush()
-
+ user = UserFactory(username="user", email="user@example.com")
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(user.id)}, set_method_get=False
@@ -4021,7 +3800,7 @@ def test_user_can_enter_token(self) -> None:
)
def test_form_state_cleared_on_invalid_token(self) -> None:
- superuser = self.create_user(
+ superuser = UserFactory(
username="superuser",
superuser=True,
mfa_method=MfaMethod.HOTP_EMAIL,
@@ -4029,9 +3808,7 @@ def test_form_state_cleared_on_invalid_token(self) -> None:
email="user@example.com",
hotp_counter=1,
)
- user = self.create_user(username="user", email="user@example.com")
- self.dbsession.flush()
-
+ user = UserFactory(username="user", email="user@example.com")
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(user.id)}, set_method_get=False
@@ -4059,15 +3836,13 @@ def test_form_state_cleared_on_invalid_token(self) -> None:
def test_cannot_change_password_if_timed_out(self) -> None:
self.req.config.mfa_timeout_s = 600
- superuser = self.create_user(
+ superuser = UserFactory(
username="admin",
superuser=True,
mfa_method=MfaMethod.TOTP,
mfa_secret_key=pyotp.random_base32(),
)
- user = self.create_user(username="user")
- self.dbsession.flush()
-
+ user = UserFactory(username="user")
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(user.id)}, set_method_get=False
@@ -4119,14 +3894,13 @@ def test_raises_for_invalid_user(self) -> None:
self.assertIn("Cannot find User with id:123", cm.exception.message)
def test_raises_when_user_may_not_edit_other_user(self) -> None:
- regular_user = self.create_user(username="regular_user")
- self.dbsession.flush()
+ regular_user = UserFactory(username="regular_user")
multidict = MultiDict([(FormAction.SUBMIT, "submit")])
self.req._debugging_user = regular_user
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.USER_ID: str(self.user.id)}, set_method_get=False
+ {ViewParam.USER_ID: str(self.system_user.id)}, set_method_get=False
)
view = EditOtherUserMfaView(self.req)
@@ -4136,15 +3910,15 @@ def test_raises_when_user_may_not_edit_other_user(self) -> None:
self.assertIn("Nobody may edit the system user", cm.exception.message)
def test_disable_mfa(self) -> None:
- groupadmin = self.create_user(username="groupadmin")
- regular_user = self.create_user(
+ groupadmin = UserFactory(username="groupadmin")
+ regular_user = UserFactory(
username="regular_user", mfa_method=MfaMethod.TOTP
)
- self.dbsession.flush()
- self.create_membership(groupadmin, self.group, groupadmin=True)
- self.create_membership(regular_user, self.group)
- self.dbsession.flush()
-
+ group = GroupFactory()
+ UserGroupMembershipFactory(
+ user_id=groupadmin.id, group_id=group.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(user_id=regular_user.id, group_id=group.id)
self.assertFalse(regular_user.must_change_password)
multidict = MultiDict(
@@ -4171,8 +3945,7 @@ def test_disable_mfa(self) -> None:
)
def test_redirects_if_editing_own_account(self) -> None:
- superuser = self.create_user(username="admin", superuser=True)
- self.dbsession.flush()
+ superuser = UserFactory(username="admin", superuser=True)
self.req._debugging_user = superuser
self.req.add_get_params({ViewParam.USER_ID: str(superuser.id)})
@@ -4190,7 +3963,7 @@ def test_redirects_if_editing_own_account(self) -> None:
def test_user_sees_otp_form_if_mfa_setup(
self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock
) -> None:
- superuser = self.create_user(
+ superuser = UserFactory(
username="admin",
superuser=True,
email="admin@example.com",
@@ -4199,8 +3972,7 @@ def test_user_sees_otp_form_if_mfa_setup(
hotp_counter=0,
)
- user = self.create_user(username="user")
- self.dbsession.flush()
+ user = UserFactory(username="user")
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(user.id)}, set_method_get=False
@@ -4222,19 +3994,17 @@ def test_code_sent_if_mfa_setup(self) -> None:
SmsBackendNames.CONSOLE, {}
)
- phone_number = phonenumbers.parse(TEST_PHONE_NUMBER)
- superuser = self.create_user(
+ phone_number_str = Fake.en_gb.valid_phone_number()
+ superuser = UserFactory(
username="admin",
superuser=True,
email="admin@example.com",
- phone_number=phone_number,
+ phone_number=phonenumbers.parse(phone_number_str),
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.HOTP_SMS,
hotp_counter=0,
)
- user = self.create_user(username="user", email="user@example.com")
- self.dbsession.flush()
-
+ user = UserFactory(username="user", email="user@example.com")
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(user.id)}, set_method_get=False
@@ -4248,12 +4018,12 @@ def test_code_sent_if_mfa_setup(self) -> None:
expected_message = f"Your CamCOPS verification code is {expected_code}"
self.assertIn(
- ConsoleSmsBackend.make_msg(TEST_PHONE_NUMBER, expected_message),
+ ConsoleSmsBackend.make_msg(phone_number_str, expected_message),
logging_cm.output[0],
)
def test_user_can_enter_token(self) -> None:
- superuser = self.create_user(
+ superuser = UserFactory(
username="admin",
superuser=True,
mfa_method=MfaMethod.HOTP_EMAIL,
@@ -4261,9 +4031,7 @@ def test_user_can_enter_token(self) -> None:
email="user@example.com",
hotp_counter=1,
)
- user = self.create_user(username="user", email="user@example.com")
- self.dbsession.flush()
-
+ user = UserFactory(username="user", email="user@example.com")
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(user.id)}, set_method_get=False
@@ -4292,7 +4060,7 @@ def test_user_can_enter_token(self) -> None:
)
def test_form_state_cleared_on_invalid_token(self) -> None:
- superuser = self.create_user(
+ superuser = UserFactory(
username="superuser",
superuser=True,
mfa_method=MfaMethod.HOTP_EMAIL,
@@ -4300,9 +4068,7 @@ def test_form_state_cleared_on_invalid_token(self) -> None:
email="user@example.com",
hotp_counter=1,
)
- user = self.create_user(username="user", email="user@example.com")
- self.dbsession.flush()
-
+ user = UserFactory(username="user", email="user@example.com")
self.req._debugging_user = superuser
self.req.add_get_params(
{ViewParam.USER_ID: str(user.id)}, set_method_get=False
@@ -4329,45 +4095,34 @@ def test_form_state_cleared_on_invalid_token(self) -> None:
self.assert_state_is_clean()
-class EditUserGroupMembershipViewTests(BasicDatabaseTestCase):
- def setUp(self) -> None:
- super().setUp()
-
- self.regular_user = User()
- self.regular_user.username = "ruser"
- self.regular_user.hashedpw = ""
- self.dbsession.add(self.regular_user)
- self.dbsession.flush()
-
- self.group_admin = User()
- self.group_admin.username = "gadmin"
- self.group_admin.hashedpw = ""
- self.dbsession.add(self.group_admin)
- self.dbsession.flush()
+class EditUserGroupMembershipViewTests(DemoRequestTestCase):
+ def test_superuser_can_update_user_group_membership(self) -> None:
+ regular_user = UserFactory()
+ groupadmin = UserFactory()
+ group = GroupFactory()
- admin_ugm = UserGroupMembership(
- user_id=self.group_admin.id, group_id=self.group.id
+ UserGroupMembershipFactory(
+ user_id=groupadmin.id,
+ group_id=group.id,
+ groupadmin=True,
)
- admin_ugm.groupadmin = True
- self.dbsession.add(admin_ugm)
- self.ugm = UserGroupMembership(
- user_id=self.regular_user.id, group_id=self.group.id
+ ugm = UserGroupMembershipFactory(
+ user_id=regular_user.id, group_id=group.id
)
- self.dbsession.add(self.ugm)
- self.dbsession.commit()
- def test_superuser_can_update_user_group_membership(self) -> None:
- self.assertFalse(self.ugm.may_upload)
- self.assertFalse(self.ugm.may_register_devices)
- self.assertFalse(self.ugm.may_use_webviewer)
- self.assertFalse(self.ugm.view_all_patients_when_unfiltered)
- self.assertFalse(self.ugm.may_dump_data)
- self.assertFalse(self.ugm.may_run_reports)
- self.assertFalse(self.ugm.may_add_notes)
- self.assertFalse(self.ugm.may_manage_patients)
- self.assertFalse(self.ugm.may_email_patients)
- self.assertFalse(self.ugm.groupadmin)
+ self.req._debugging_user = groupadmin
+
+ self.assertFalse(ugm.may_upload)
+ self.assertFalse(ugm.may_register_devices)
+ self.assertFalse(ugm.may_use_webviewer)
+ self.assertFalse(ugm.view_all_patients_when_unfiltered)
+ self.assertFalse(ugm.may_dump_data)
+ self.assertFalse(ugm.may_run_reports)
+ self.assertFalse(ugm.may_add_notes)
+ self.assertFalse(ugm.may_manage_patients)
+ self.assertFalse(ugm.may_email_patients)
+ self.assertFalse(ugm.groupadmin)
multidict = MultiDict(
[
@@ -4387,35 +4142,49 @@ def test_superuser_can_update_user_group_membership(self) -> None:
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(self.ugm.id)},
+ {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(ugm.id)},
set_method_get=False,
)
with self.assertRaises(HTTPFound):
edit_user_group_membership(self.req)
- self.assertTrue(self.ugm.may_upload)
- self.assertTrue(self.ugm.may_register_devices)
- self.assertTrue(self.ugm.may_use_webviewer)
- self.assertTrue(self.ugm.view_all_patients_when_unfiltered)
- self.assertTrue(self.ugm.may_dump_data)
- self.assertTrue(self.ugm.may_run_reports)
- self.assertTrue(self.ugm.may_add_notes)
- self.assertTrue(self.ugm.may_manage_patients)
- self.assertTrue(self.ugm.may_email_patients)
+ self.assertTrue(ugm.may_upload)
+ self.assertTrue(ugm.may_register_devices)
+ self.assertTrue(ugm.may_use_webviewer)
+ self.assertTrue(ugm.view_all_patients_when_unfiltered)
+ self.assertTrue(ugm.may_dump_data)
+ self.assertTrue(ugm.may_run_reports)
+ self.assertTrue(ugm.may_add_notes)
+ self.assertTrue(ugm.may_manage_patients)
+ self.assertTrue(ugm.may_email_patients)
def test_groupadmin_can_update_user_group_membership(self) -> None:
- self.req._debugging_user = self.group_admin
-
- self.assertFalse(self.ugm.may_upload)
- self.assertFalse(self.ugm.may_register_devices)
- self.assertFalse(self.ugm.may_use_webviewer)
- self.assertFalse(self.ugm.view_all_patients_when_unfiltered)
- self.assertFalse(self.ugm.may_dump_data)
- self.assertFalse(self.ugm.may_run_reports)
- self.assertFalse(self.ugm.may_add_notes)
- self.assertFalse(self.ugm.may_manage_patients)
- self.assertFalse(self.ugm.may_email_patients)
+ regular_user = UserFactory()
+ groupadmin = UserFactory()
+ group = GroupFactory()
+
+ UserGroupMembershipFactory(
+ user_id=groupadmin.id,
+ group_id=group.id,
+ groupadmin=True,
+ )
+
+ ugm = UserGroupMembershipFactory(
+ user_id=regular_user.id, group_id=group.id
+ )
+
+ self.req._debugging_user = groupadmin
+
+ self.assertFalse(ugm.may_upload)
+ self.assertFalse(ugm.may_register_devices)
+ self.assertFalse(ugm.may_use_webviewer)
+ self.assertFalse(ugm.view_all_patients_when_unfiltered)
+ self.assertFalse(ugm.may_dump_data)
+ self.assertFalse(ugm.may_run_reports)
+ self.assertFalse(ugm.may_add_notes)
+ self.assertFalse(ugm.may_manage_patients)
+ self.assertFalse(ugm.may_email_patients)
multidict = MultiDict(
[
@@ -4434,33 +4203,45 @@ def test_groupadmin_can_update_user_group_membership(self) -> None:
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(self.ugm.id)},
+ {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(ugm.id)},
set_method_get=False,
)
with self.assertRaises(HTTPFound):
edit_user_group_membership(self.req)
- self.assertTrue(self.ugm.may_upload)
- self.assertTrue(self.ugm.may_register_devices)
- self.assertTrue(self.ugm.may_use_webviewer)
- self.assertTrue(self.ugm.view_all_patients_when_unfiltered)
- self.assertTrue(self.ugm.may_dump_data)
- self.assertTrue(self.ugm.may_run_reports)
- self.assertTrue(self.ugm.may_add_notes)
- self.assertTrue(self.ugm.may_manage_patients)
- self.assertTrue(self.ugm.may_email_patients)
+ self.assertTrue(ugm.may_upload)
+ self.assertTrue(ugm.may_register_devices)
+ self.assertTrue(ugm.may_use_webviewer)
+ self.assertTrue(ugm.view_all_patients_when_unfiltered)
+ self.assertTrue(ugm.may_dump_data)
+ self.assertTrue(ugm.may_run_reports)
+ self.assertTrue(ugm.may_add_notes)
+ self.assertTrue(ugm.may_manage_patients)
+ self.assertTrue(ugm.may_email_patients)
def test_raises_if_cant_edit_user(self) -> None:
- self.ugm.user_id = self.user.id
- self.dbsession.add(self.ugm)
- self.dbsession.commit()
+ system_user = User.get_system_user(self.dbsession)
+ groupadmin = UserFactory()
+ group = GroupFactory()
+
+ UserGroupMembershipFactory(
+ user_id=groupadmin.id,
+ group_id=group.id,
+ groupadmin=True,
+ )
+
+ system_ugm = UserGroupMembershipFactory(
+ user_id=system_user.id, group_id=group.id
+ )
+
+ self.req._debugging_user = groupadmin
multidict = MultiDict([(FormAction.SUBMIT, "submit")])
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(self.ugm.id)},
+ {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(system_ugm.id)},
set_method_get=False,
)
@@ -4470,22 +4251,22 @@ def test_raises_if_cant_edit_user(self) -> None:
self.assertIn("Nobody may edit the system user", cm.exception.message)
def test_raises_if_cant_administer_group(self) -> None:
- group_a = self.create_group("groupa")
- group_b = self.create_group("groupb")
+ group_a = GroupFactory()
+ group_b = GroupFactory()
- user1 = self.create_user(username="user1")
- user2 = self.create_user(username="user2")
- self.dbsession.flush()
+ user1 = UserFactory()
+ user2 = UserFactory()
# User 1 is a group administrator for group A,
# User 2 is a member if group A
- self.create_membership(user1, group_a, groupadmin=True)
- self.create_membership(user2, group_a),
+ UserGroupMembershipFactory(
+ user_id=user1.id, group_id=group_a.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(user_id=user2.id, group_id=group_a.id),
# User 1 is not an administrator of group B
# User 2 is a member of group B
- ugm = self.create_membership(user2, group_b)
- self.dbsession.commit()
+ ugm = UserGroupMembershipFactory(user_id=user2.id, group_id=group_b.id)
multidict = MultiDict([(FormAction.SUBMIT, "submit")])
@@ -4505,11 +4286,26 @@ def test_raises_if_cant_administer_group(self) -> None:
)
def test_cancel_returns_to_users_list(self) -> None:
+ regular_user = UserFactory()
+ groupadmin = UserFactory()
+ group = GroupFactory()
+
+ UserGroupMembershipFactory(
+ user_id=groupadmin.id,
+ group_id=group.id,
+ groupadmin=True,
+ )
+
+ ugm = UserGroupMembershipFactory(
+ user_id=regular_user.id, group_id=group.id
+ )
+
+ self.req._debugging_user = groupadmin
multidict = MultiDict([(FormAction.CANCEL, "cancel")])
self.req.fake_request_post_from_dict(multidict)
self.req.add_get_params(
- {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(self.ugm.id)},
+ {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(ugm.id)},
set_method_get=False,
)
@@ -4530,8 +4326,12 @@ def setUp(self) -> None:
def test_user_can_change_password(self) -> None:
new_password = "monkeybusiness"
- user = self.create_user(username="user", mfa_method=MfaMethod.NO_MFA)
- user.set_password(self.req, "secret")
+ user = UserFactory(
+ username="user",
+ mfa_method=MfaMethod.NO_MFA,
+ password__request=self.req,
+ password="secret",
+ )
multidict = MultiDict(
[
(ViewParam.OLD_PASSWORD, "secret"),
@@ -4558,7 +4358,7 @@ def test_user_can_change_password(self) -> None:
self.assert_state_is_finished()
def test_user_sees_expiry_message(self) -> None:
- user = self.create_user(
+ user = UserFactory(
username="user",
mfa_method=MfaMethod.NO_MFA,
must_change_password=True,
@@ -4584,7 +4384,7 @@ def test_password_must_differ(self) -> None:
def test_user_sees_otp_form_if_mfa_setup(
self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock
) -> None:
- user = self.create_user(
+ user = UserFactory(
username="user",
email="user@example.com",
mfa_method=MfaMethod.HOTP_EMAIL,
@@ -4608,11 +4408,11 @@ def test_code_sent_if_mfa_setup(self) -> None:
self.req.config.sms_backend = get_sms_backend(
SmsBackendNames.CONSOLE, {}
)
- phone_number = phonenumbers.parse(TEST_PHONE_NUMBER)
- user = self.create_user(
+ phone_number_str = Fake.en_gb.valid_phone_number()
+ user = UserFactory(
username="user",
email="user@example.com",
- phone_number=phone_number,
+ phone_number=phonenumbers.parse(phone_number_str),
mfa_secret_key=pyotp.random_base32(),
mfa_method=MfaMethod.HOTP_SMS,
hotp_counter=0,
@@ -4627,21 +4427,20 @@ def test_code_sent_if_mfa_setup(self) -> None:
expected_message = f"Your CamCOPS verification code is {expected_code}"
self.assertIn(
- ConsoleSmsBackend.make_msg(TEST_PHONE_NUMBER, expected_message),
+ ConsoleSmsBackend.make_msg(phone_number_str, expected_message),
logging_cm.output[0],
)
def test_user_can_enter_token(self) -> None:
- user = self.create_user(
+ user = UserFactory(
username="user",
mfa_method=MfaMethod.HOTP_EMAIL,
mfa_secret_key=pyotp.random_base32(),
email="user@example.com",
hotp_counter=1,
+ password__request=self.req,
+ password="secret",
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
-
self.req._debugging_user = user
hotp = pyotp.HOTP(user.mfa_secret_key)
@@ -4667,16 +4466,15 @@ def test_user_can_enter_token(self) -> None:
)
def test_form_state_cleared_on_invalid_token(self) -> None:
- user = self.create_user(
+ user = UserFactory(
username="user",
mfa_method=MfaMethod.HOTP_EMAIL,
mfa_secret_key=pyotp.random_base32(),
email="user@example.com",
hotp_counter=1,
+ password__request=self.req,
+ password="secret",
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
-
self.req._debugging_user = user
hotp = pyotp.HOTP(user.mfa_secret_key)
@@ -4701,14 +4499,13 @@ def test_form_state_cleared_on_invalid_token(self) -> None:
def test_cannot_change_password_if_timed_out(self) -> None:
self.req.config.mfa_timeout_s = 600
- user = self.create_user(
+ user = UserFactory(
username="user",
mfa_method=MfaMethod.TOTP,
mfa_secret_key=pyotp.random_base32(),
+ password__request=self.req,
+ password="secret",
)
- user.set_password(self.req, "secret")
- self.dbsession.flush()
-
self.req._debugging_user = user
totp = pyotp.TOTP(user.mfa_secret_key)
@@ -4734,3 +4531,58 @@ def test_cannot_change_password_if_timed_out(self) -> None:
view.dispatch()
mock_fail_timed_out.assert_called_once()
+
+
+class AddUserTests(DemoRequestTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.groupadmin = self.req._debugging_user = UserFactory()
+
+ def test_user_created(self) -> None:
+ group_1 = GroupFactory()
+ group_2 = GroupFactory()
+
+ UserGroupMembershipFactory(
+ user_id=self.groupadmin.id, group_id=group_1.id, groupadmin=True
+ )
+ UserGroupMembershipFactory(
+ user_id=self.groupadmin.id, group_id=group_2.id, groupadmin=True
+ )
+
+ multidict = MultiDict(
+ [
+ ("_charset_", UTF8),
+ ("__formid__", "deform"),
+ (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()),
+ (ViewParam.USERNAME, "test"),
+ ("__start__", "new_password:mapping"),
+ (ViewParam.NEW_PASSWORD, "monkeybusiness"),
+ ("new_password-confirm", "monkeybusiness"),
+ ("__end__", "new_password:mapping"),
+ (ViewParam.MUST_CHANGE_PASSWORD, "true"),
+ ("__start__", "group_ids:sequence"),
+ ("group_id_sequence", str(group_1.id)),
+ ("group_id_sequence", str(group_2.id)),
+ ("__end__", "group_ids:sequence"),
+ (FormAction.SUBMIT, "submit"),
+ ]
+ )
+ self.req.fake_request_post_from_dict(multidict)
+
+ with self.assertRaises(HTTPFound):
+ add_user(self.req)
+
+ user = (
+ self.dbsession.query(User)
+ .filter(
+ User.username == "test",
+ )
+ .one_or_none()
+ )
+
+ self.assertIsNotNone(user)
+
+ self.assertTrue(user.must_change_password)
+ self.assertIn(group_1.id, user.group_ids)
+ self.assertIn(group_2.id, user.group_ids)
diff --git a/server/camcops_server/tasks/tests/apeq_cpft_perinatal_tests.py b/server/camcops_server/tasks/tests/apeq_cpft_perinatal_tests.py
index 7a24da711..5733eae09 100644
--- a/server/camcops_server/tasks/tests/apeq_cpft_perinatal_tests.py
+++ b/server/camcops_server/tasks/tests/apeq_cpft_perinatal_tests.py
@@ -25,33 +25,26 @@
"""
-from typing import Generator, Optional
-
import pendulum
-from camcops_server.cc_modules.cc_unittest import BasicDatabaseTestCase
+from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase
from camcops_server.tasks.apeq_cpft_perinatal import (
- APEQCPFTPerinatal,
APEQCPFTPerinatalReport,
)
-
+from camcops_server.tasks.tests.factories import APEQCPFTPerinatalFactory
# =============================================================================
# Unit tests
# =============================================================================
-class APEQCPFTPerinatalReportTestCase(BasicDatabaseTestCase):
+class APEQCPFTPerinatalReportTestCase(DemoRequestTestCase):
COL_Q = 0
COL_TOTAL = 1
COL_RESPONSE_START = 2
COL_FF_WHY = 1
- def __init__(self, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
- self.id_sequence = self.get_id()
-
def setUp(self) -> None:
super().setUp()
@@ -61,48 +54,9 @@ def setUp(self) -> None:
self.report.start_datetime = None
self.report.end_datetime = None
- @staticmethod
- def get_id() -> Generator[int, None, None]:
- i = 1
-
- while True:
- yield i
- i += 1
-
- def create_task(
- self,
- q1: Optional[int],
- q2: Optional[int],
- q3: Optional[int],
- q4: Optional[int],
- q5: Optional[int],
- q6: Optional[int],
- ff_rating: int,
- ff_why: str = None,
- comments: str = None,
- era: str = None,
- ) -> None:
- task = APEQCPFTPerinatal()
- self.apply_standard_task_fields(task)
- task.id = next(self.id_sequence)
- task.q1 = q1
- task.q2 = q2
- task.q3 = q3
- task.q4 = q4
- task.q5 = q5
- task.q6 = q6
- task.ff_rating = ff_rating
- task.ff_why = ff_why
- task.comments = comments
-
- if era is not None:
- task.when_created = pendulum.parse(era)
-
- self.dbsession.add(task)
-
class APEQCPFTPerinatalReportTests(APEQCPFTPerinatalReportTestCase):
- def create_tasks(self) -> None:
+ def setUp(self) -> None:
"""
Creates 20 tasks.
Should give us:
@@ -132,43 +86,102 @@ def create_tasks(self) -> None:
5 - 35%
"""
- # q1 q2 q3 q4 q5 q6 ff
- self.create_task(0, 1, 0, 0, 2, 2, 5, ff_why="ff_5_1")
- self.create_task(
- 0, 1, 1, 0, 2, 2, 5, ff_why="ff_5_2", comments="comments_2"
+ super().setUp()
+
+ APEQCPFTPerinatalFactory(
+ q1=0, q2=1, q3=0, q4=0, q5=2, q6=2, ff_rating=5, ff_why="ff_5_1"
+ )
+ APEQCPFTPerinatalFactory(
+ q1=0,
+ q2=1,
+ q3=1,
+ q4=0,
+ q5=2,
+ q6=2,
+ ff_rating=5,
+ ff_why="ff_5_2",
+ comments="comments_2",
+ )
+ APEQCPFTPerinatalFactory(
+ q1=0, q2=1, q3=1, q4=1, q5=2, q6=2, ff_rating=5
+ )
+ APEQCPFTPerinatalFactory(
+ q1=0, q2=1, q3=1, q4=1, q5=2, q6=2, ff_rating=5
+ )
+ APEQCPFTPerinatalFactory(
+ q1=0,
+ q2=1,
+ q3=1,
+ q4=1,
+ q5=2,
+ q6=2,
+ ff_rating=5,
+ comments="comments_5",
+ )
+
+ APEQCPFTPerinatalFactory(
+ q1=0, q2=1, q3=2, q4=1, q5=2, q6=2, ff_rating=5
+ )
+ APEQCPFTPerinatalFactory(
+ q1=0, q2=1, q3=2, q4=1, q5=1, q6=2, ff_rating=5
+ )
+ APEQCPFTPerinatalFactory(
+ q1=0, q2=1, q3=2, q4=1, q5=1, q6=2, ff_rating=4, ff_why="ff_4_1"
+ )
+ APEQCPFTPerinatalFactory(
+ q1=0, q2=1, q3=2, q4=1, q5=1, q6=2, ff_rating=3
+ )
+ APEQCPFTPerinatalFactory(
+ q1=0, q2=1, q3=2, q4=1, q5=1, q6=1, ff_rating=3, ff_why="ff_3_1"
+ )
+
+ APEQCPFTPerinatalFactory(
+ q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=2, ff_why="ff_2_1"
+ )
+ APEQCPFTPerinatalFactory(
+ q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=2
+ )
+ APEQCPFTPerinatalFactory(
+ q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=2, ff_why="ff_2_2"
+ )
+ APEQCPFTPerinatalFactory(
+ q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=1, ff_why="ff_1_1"
+ )
+ APEQCPFTPerinatalFactory(
+ q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=1, ff_why="ff_1_2"
+ )
+
+ APEQCPFTPerinatalFactory(
+ q1=2, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=0
+ )
+ APEQCPFTPerinatalFactory(
+ q1=2, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=0
+ )
+ APEQCPFTPerinatalFactory(
+ q1=2, q2=1, q3=2, q4=2, q5=0, q6=None, ff_rating=0
+ )
+ APEQCPFTPerinatalFactory(
+ q1=2, q2=1, q3=2, q4=2, q5=0, q6=None, ff_rating=0
+ )
+ APEQCPFTPerinatalFactory(
+ q1=2,
+ q2=1,
+ q3=2,
+ q4=2,
+ q5=0,
+ q6=1,
+ ff_rating=0,
+ comments="comments_20",
)
- self.create_task(0, 1, 1, 1, 2, 2, 5)
- self.create_task(0, 1, 1, 1, 2, 2, 5)
- self.create_task(0, 1, 1, 1, 2, 2, 5, comments="comments_5")
-
- self.create_task(0, 1, 2, 1, 2, 2, 5)
- self.create_task(0, 1, 2, 1, 1, 2, 5)
- self.create_task(0, 1, 2, 1, 1, 2, 4, ff_why="ff_4_1")
- self.create_task(0, 1, 2, 1, 1, 2, 3)
- self.create_task(0, 1, 2, 1, 1, 1, 3, ff_why="ff_3_1")
-
- self.create_task(1, 1, 2, 2, 1, 1, 2, ff_why="ff_2_1")
- self.create_task(1, 1, 2, 2, 1, 1, 2)
- self.create_task(1, 1, 2, 2, 1, 1, 2, ff_why="ff_2_2")
- self.create_task(1, 1, 2, 2, 1, 1, 1, ff_why="ff_1_1")
- self.create_task(1, 1, 2, 2, 1, 1, 1, ff_why="ff_1_2")
-
- self.create_task(2, 1, 2, 2, 1, 1, 0)
- self.create_task(2, 1, 2, 2, 1, 1, 0)
- self.create_task(2, 1, 2, 2, 0, None, 0)
- self.create_task(2, 1, 2, 2, 0, None, 0)
- self.create_task(2, 1, 2, 2, 0, 1, 0, comments="comments_20")
-
- self.dbsession.commit()
def test_main_rows_contain_percentages(self) -> None:
expected_percentages = [
- [20, 50, 25, 25], # q1
- [20, "", 100, ""], # q2
- [20, 5, 20, 75], # q3
- [20, 10, 40, 50], # q4
- [20, 15, 55, 30], # q5
- [18, "", 50, 50], # q6
+ ["20", "50", "25", "25"], # q1
+ ["20", "", "100", ""], # q2
+ ["20", "5", "20", "75"], # q3
+ ["20", "10", "40", "50"], # q4
+ ["20", "15", "55", "30"], # q5
+ ["18", "", "50", "50"], # q6
]
main_rows = self.report._get_main_rows(self.req)
@@ -179,7 +192,7 @@ def test_main_rows_contain_percentages(self) -> None:
for p in row[1:]:
if p != "":
- p = int(float(p))
+ p = str(int(float(p)))
percentages.append(p)
@@ -229,73 +242,75 @@ def test_ff_why_rows_contain_reasons(self) -> None:
def test_comments(self) -> None:
expected_comments = ["comments_2", "comments_5", "comments_20"]
+
comments = self.report._get_comments(self.req)
self.assertEqual(comments, expected_comments)
class APEQCPFTPerinatalReportDateRangeTests(APEQCPFTPerinatalReportTestCase):
- def create_tasks(self) -> None:
- self.create_task(
- 1,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
+ def setUp(self) -> None:
+ super().setUp()
+
+ APEQCPFTPerinatalFactory(
+ q1=1,
+ q2=0,
+ q3=0,
+ q4=0,
+ q5=0,
+ q6=0,
+ ff_rating=0,
ff_why="ff why 1",
comments="comments 1",
- era="2018-10-01T00:00:00.000000+00:00",
+ when_created=pendulum.parse("2018-10-01"),
)
- self.create_task(
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 2,
+ APEQCPFTPerinatalFactory(
+ q1=0,
+ q2=0,
+ q3=0,
+ q4=0,
+ q5=0,
+ q6=0,
+ ff_rating=2,
ff_why="ff why 2",
comments="comments 2",
- era="2018-10-02T00:00:00.000000+00:00",
+ when_created=pendulum.parse("2018-10-02"),
)
- self.create_task(
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 2,
+ APEQCPFTPerinatalFactory(
+ q1=0,
+ q2=0,
+ q3=0,
+ q4=0,
+ q5=0,
+ q6=0,
+ ff_rating=2,
ff_why="ff why 3",
comments="comments 3",
- era="2018-10-03T00:00:00.000000+00:00",
+ when_created=pendulum.parse("2018-10-03"),
)
- self.create_task(
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 2,
+ APEQCPFTPerinatalFactory(
+ q1=0,
+ q2=0,
+ q3=0,
+ q4=0,
+ q5=0,
+ q6=0,
+ ff_rating=2,
ff_why="ff why 4",
comments="comments 4",
- era="2018-10-04T00:00:00.000000+00:00",
+ when_created=pendulum.parse("2018-10-04"),
)
- self.create_task(
- 1,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
+ APEQCPFTPerinatalFactory(
+ q1=1,
+ q2=0,
+ q3=0,
+ q4=0,
+ q5=0,
+ q6=0,
+ ff_rating=0,
ff_why="ff why 5",
comments="comments 5",
- era="2018-10-05T00:00:00.000000+00:00",
+ when_created=pendulum.parse("2018-10-05"),
)
- self.dbsession.commit()
def test_main_rows_filtered_by_date(self) -> None:
self.report.start_datetime = "2018-10-02T00:00:00.000000+00:00"
diff --git a/server/camcops_server/tasks/tests/core10_tests.py b/server/camcops_server/tasks/tests/core10_tests.py
index b089447f7..4fdf19139 100644
--- a/server/camcops_server/tasks/tests/core10_tests.py
+++ b/server/camcops_server/tasks/tests/core10_tests.py
@@ -27,86 +27,75 @@
import pendulum
-from camcops_server.cc_modules.cc_patient import Patient
-from camcops_server.cc_modules.tests.cc_report_tests import (
- AverageScoreReportTestCase,
+from camcops_server.cc_modules.cc_testfactories import (
+ PatientFactory,
+ UserFactory,
)
+from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase
from camcops_server.tasks.core10 import Core10, Core10Report
+from camcops_server.tasks.tests.factories import Core10Factory
-class Core10ReportTestCase(AverageScoreReportTestCase):
+class Core10ReportTestCase(DemoRequestTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ self.report = self.create_report()
+ self.req._debugging_user = UserFactory(superuser=True)
+
def create_report(self) -> Core10Report:
return Core10Report(via_index=False)
- def create_task(
- self,
- patient: Patient,
- q1: int = 0,
- q2: int = 0,
- q3: int = 0,
- q4: int = 0,
- q5: int = 0,
- q6: int = 0,
- q7: int = 0,
- q8: int = 0,
- q9: int = 0,
- q10: int = 0,
- era: str = None,
- ) -> None:
- task = Core10()
- self.apply_standard_task_fields(task)
- task.id = next(self.task_id_sequence)
-
- task.patient_id = patient.id
-
- task.q1 = q1
- task.q2 = q2
- task.q3 = q3
- task.q4 = q4
- task.q5 = q5
- task.q6 = q6
- task.q7 = q7
- task.q8 = q8
- task.q9 = q9
- task.q10 = q10
-
- if era is not None:
- task.when_created = pendulum.parse(era)
- # log.info(f"Creating task, when_created = {task.when_created}")
-
- self.dbsession.add(task)
-
-
-class Core10ReportTests(Core10ReportTestCase):
- def create_tasks(self) -> None:
- self.patient_1 = self.create_patient(idnum_value=333)
- self.patient_2 = self.create_patient(idnum_value=444)
- self.patient_3 = self.create_patient(idnum_value=555)
+
+class Core10ReportTotalsTests(Core10ReportTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+
+ patient_1 = PatientFactory()
+ patient_2 = PatientFactory()
+ patient_3 = PatientFactory()
# Initial average score = (8 + 6 + 4) / 3 = 6
# Latest average score = (2 + 3 + 4) / 3 = 3
- self.create_task(
- patient=self.patient_1, q1=4, q2=4, era="2018-06-01"
+ Core10Factory(
+ patient=patient_1,
+ q1=4,
+ q2=4,
+ when_created=pendulum.parse("2018-06-01"),
) # Score 8
- self.create_task(
- patient=self.patient_1, q7=1, q8=1, era="2018-10-04"
+ Core10Factory(
+ patient=patient_1,
+ q7=1,
+ q8=1,
+ when_created=pendulum.parse("2018-10-04"),
) # Score 2
- self.create_task(
- patient=self.patient_2, q3=3, q4=3, era="2018-05-02"
+ Core10Factory(
+ patient=patient_2,
+ q3=3,
+ q4=3,
+ when_created=pendulum.parse("2018-05-02"),
) # Score 6
- self.create_task(
- patient=self.patient_2, q3=2, q4=1, era="2018-10-03"
+ Core10Factory(
+ patient=patient_2,
+ q3=2,
+ q4=1,
+ when_created=pendulum.parse("2018-10-03"),
) # Score 3
- self.create_task(
- patient=self.patient_3, q5=2, q6=2, era="2018-01-10"
+ Core10Factory(
+ patient=patient_3,
+ q5=2,
+ q6=2,
+ when_created=pendulum.parse("2018-01-10"),
) # Score 4
- self.create_task(
- patient=self.patient_3, q9=1, q10=3, era="2018-10-01"
+ Core10Factory(
+ patient=patient_3,
+ q9=1,
+ q10=3,
+ when_created=pendulum.parse("2018-10-01"),
) # Score 4
- self.dbsession.commit()
def test_row_has_totals_and_averages(self) -> None:
pages = self.report.get_spreadsheet_pages(req=self.req)
@@ -131,32 +120,49 @@ def test_no_rows_when_no_data(self) -> None:
class Core10ReportDoubleCountingTests(Core10ReportTestCase):
- def create_tasks(self) -> None:
- self.patient_1 = self.create_patient(idnum_value=333)
- self.patient_2 = self.create_patient(idnum_value=444)
- self.patient_3 = self.create_patient(idnum_value=555)
+ def setUp(self) -> None:
+ super().setUp()
+
+ patient_1 = PatientFactory()
+ patient_2 = PatientFactory()
+ patient_3 = PatientFactory()
# Initial average score = (8 + 6 + 4) / 3 = 6
# Latest average score = ( 3 + 3) / 2 = 3
# Progress avg score = ( 3 + 1) / 2 = 2 ... NOT 3.
- self.create_task(
- patient=self.patient_1, q1=4, q2=4, era="2018-06-01"
+
+ Core10Factory(
+ patient=patient_1,
+ q1=4,
+ q2=4,
+ when_created=pendulum.parse("2018-06-01"),
) # Score 8
- self.create_task(
- patient=self.patient_2, q3=3, q4=3, era="2018-05-02"
+ Core10Factory(
+ patient=patient_2,
+ q3=3,
+ q4=3,
+ when_created=pendulum.parse("2018-05-02"),
) # Score 6
- self.create_task(
- patient=self.patient_2, q3=2, q4=1, era="2018-10-03"
+ Core10Factory(
+ patient=patient_2,
+ q3=2,
+ q4=1,
+ when_created=pendulum.parse("2018-10-03"),
) # Score 3
- self.create_task(
- patient=self.patient_3, q5=2, q6=2, era="2018-01-10"
+ Core10Factory(
+ patient=patient_3,
+ q5=2,
+ q6=2,
+ when_created=pendulum.parse("2018-01-10"),
) # Score 4
- self.create_task(
- patient=self.patient_3, q9=1, q10=2, era="2018-10-01"
+ Core10Factory(
+ patient=patient_3,
+ q9=1,
+ q10=2,
+ when_created=pendulum.parse("2018-10-01"),
) # Score 3
- self.dbsession.commit()
def test_record_does_not_appear_in_first_and_latest(self) -> None:
pages = self.report.get_spreadsheet_pages(req=self.req)
@@ -223,45 +229,73 @@ class Core10ReportDateRangeTests(Core10ReportTestCase):
""" # noqa
- def create_tasks(self) -> None:
- self.patient_1 = self.create_patient(idnum_value=333)
- self.patient_2 = self.create_patient(idnum_value=444)
- self.patient_3 = self.create_patient(idnum_value=555)
+ def setUp(self) -> None:
+ super().setUp()
+
+ patient_1 = PatientFactory()
+ patient_2 = PatientFactory()
+ patient_3 = PatientFactory()
# 2018-06 average score = (8 + 6 + 4) / 3 = 6
# 2018-08 average score = (4 + 4 + 4) / 3 = 4
# 2018-10 average score = (2 + 3 + 4) / 3 = 3
- self.create_task(
- patient=self.patient_1, q1=4, q2=4, era="2018-06-01"
+ Core10Factory(
+ patient=patient_1,
+ q1=4,
+ q2=4,
+ when_created=pendulum.parse("2018-06-01"),
) # Score 8
- self.create_task(
- patient=self.patient_1, q7=3, q8=1, era="2018-08-01"
+ Core10Factory(
+ patient=patient_1,
+ q7=3,
+ q8=1,
+ when_created=pendulum.parse("2018-08-01"),
) # Score 4
- self.create_task(
- patient=self.patient_1, q7=1, q8=1, era="2018-10-01"
+ Core10Factory(
+ patient=patient_1,
+ q7=1,
+ q8=1,
+ when_created=pendulum.parse("2018-10-01"),
) # Score 2
- self.create_task(
- patient=self.patient_2, q3=3, q4=3, era="2018-06-01"
+ Core10Factory(
+ patient=patient_2,
+ q3=3,
+ q4=3,
+ when_created=pendulum.parse("2018-06-01"),
) # Score 6
- self.create_task(
- patient=self.patient_2, q3=2, q4=2, era="2018-08-01"
+ Core10Factory(
+ patient=patient_2,
+ q3=2,
+ q4=2,
+ when_created=pendulum.parse("2018-08-01"),
) # Score 4
- self.create_task(
- patient=self.patient_2, q3=1, q4=2, era="2018-10-01"
+ Core10Factory(
+ patient=patient_2,
+ q3=1,
+ q4=2,
+ when_created=pendulum.parse("2018-10-01"),
) # Score 3
- self.create_task(
- patient=self.patient_3, q5=2, q6=2, era="2018-06-01"
+ Core10Factory(
+ patient=patient_3,
+ q5=2,
+ q6=2,
+ when_created=pendulum.parse("2018-06-01"),
) # Score 4
- self.create_task(
- patient=self.patient_3, q9=1, q10=3, era="2018-08-01"
+ Core10Factory(
+ patient=patient_3,
+ q9=1,
+ q10=3,
+ when_created=pendulum.parse("2018-08-01"),
) # Score 4
- self.create_task(
- patient=self.patient_3, q9=1, q10=3, era="2018-10-01"
+ Core10Factory(
+ patient=patient_3,
+ q9=1,
+ q10=3,
+ when_created=pendulum.parse("2018-10-01"),
) # Score 4
- self.dbsession.commit()
self.dump_table(
Core10.__tablename__,
@@ -269,13 +303,8 @@ def create_tasks(self) -> None:
)
def test_report_filtered_by_date_range(self) -> None:
- # self.report.start_datetime = pendulum.parse("2018-05-01T00:00:00.000000+00:00") # noqa
- self.report.start_datetime = pendulum.parse(
- "2018-06-01T00:00:00.000000+00:00"
- )
- self.report.end_datetime = pendulum.parse(
- "2018-09-01T00:00:00.000000+00:00"
- )
+ self.report.start_datetime = "2018-06-01T00:00:00.000000+00:00"
+ self.report.end_datetime = "2018-09-01T00:00:00.000000+00:00"
self.set_echo(True)
pages = self.report.get_spreadsheet_pages(req=self.req)
diff --git a/server/camcops_server/tasks/tests/factories.py b/server/camcops_server/tasks/tests/factories.py
index f8d9e3cbf..673738812 100644
--- a/server/camcops_server/tasks/tests/factories.py
+++ b/server/camcops_server/tasks/tests/factories.py
@@ -29,13 +29,161 @@
import factory
import pendulum
+from typing import cast, TYPE_CHECKING
+from camcops_server.cc_modules.cc_task import Task
from camcops_server.cc_modules.cc_testfactories import (
+ BlobFactory,
+ Fake,
GenericTabletRecordFactory,
)
+from camcops_server.tasks.ace3 import Ace3, MiniAce
+from camcops_server.tasks.aims import Aims
+from camcops_server.tasks.apeqpt import Apeqpt
+from camcops_server.tasks.apeq_cpft_perinatal import APEQCPFTPerinatal
+from camcops_server.tasks.aq import Aq
+from camcops_server.tasks.asdas import Asdas
+from camcops_server.tasks.audit import Audit, AuditC
+from camcops_server.tasks.badls import Badls
+from camcops_server.tasks.basdai import Basdai
+from camcops_server.tasks.bdi import Bdi
from camcops_server.tasks.bmi import Bmi
+from camcops_server.tasks.bprs import Bprs
+from camcops_server.tasks.bprse import Bprse
+from camcops_server.tasks.cage import Cage
+from camcops_server.tasks.cape42 import Cape42
+from camcops_server.tasks.caps import Caps
+from camcops_server.tasks.cardinal_expdetthreshold import (
+ CardinalExpDetThreshold,
+)
+from camcops_server.tasks.cardinal_expectationdetection import (
+ CardinalExpectationDetection,
+)
+from camcops_server.tasks.cbir import CbiR
+from camcops_server.tasks.ceca import CecaQ3
+from camcops_server.tasks.cesd import Cesd
+from camcops_server.tasks.cesdr import Cesdr
+from camcops_server.tasks.cet import Cet
+from camcops_server.tasks.cgi_task import Cgi, CgiI
+from camcops_server.tasks.cgisch import CgiSch
+from camcops_server.tasks.chit import Chit
+from camcops_server.tasks.cia import Cia
+from camcops_server.tasks.cisr import Cisr
+from camcops_server.tasks.ciwa import Ciwa
+from camcops_server.tasks.contactlog import ContactLog
+from camcops_server.tasks.cope import CopeBrief
+from camcops_server.tasks.core10 import Core10
+from camcops_server.tasks.cpft_covid_medical import CpftCovidMedical
+from camcops_server.tasks.cpft_lps import (
+ CPFTLPSDischarge,
+ CPFTLPSReferral,
+ CPFTLPSResetResponseClock,
+)
+from camcops_server.tasks.cpft_research_preferences import (
+ CpftResearchPreferences,
+)
+from camcops_server.tasks.dad import Dad
+from camcops_server.tasks.das28 import Das28
+from camcops_server.tasks.dast import Dast
+from camcops_server.tasks.deakin_s1_healthreview import DeakinS1HealthReview
+from camcops_server.tasks.demoquestionnaire import DemoQuestionnaire
+from camcops_server.tasks.demqol import Demqol, DemqolProxy
+from camcops_server.tasks.diagnosis import (
+ DiagnosisIcd10,
+ DiagnosisIcd10Item,
+ DiagnosisIcd9CM,
+ DiagnosisIcd9CMItem,
+)
+from camcops_server.tasks.distressthermometer import DistressThermometer
+from camcops_server.tasks.edeq import Edeq
+from camcops_server.tasks.elixhauserci import ElixhauserCI
+from camcops_server.tasks.epds import Epds
+from camcops_server.tasks.eq5d5l import Eq5d5l
+from camcops_server.tasks.esspri import Esspri
+from camcops_server.tasks.factg import Factg
+from camcops_server.tasks.fast import Fast
+from camcops_server.tasks.fft import Fft
+from camcops_server.tasks.frs import Frs
+from camcops_server.tasks.gad7 import Gad7
+from camcops_server.tasks.gaf import Gaf
+from camcops_server.tasks.gbo import Gbogpc, Gbogras, Gbogres
+from camcops_server.tasks.gds import Gds15
+from camcops_server.tasks.gmcpq import GMCPQ
+from camcops_server.tasks.hads import Hads, HadsRespondent
+from camcops_server.tasks.hama import Hama
+from camcops_server.tasks.hamd import Hamd
+from camcops_server.tasks.hamd7 import Hamd7
+from camcops_server.tasks.honos import Honos, Honos65, Honosca
+from camcops_server.tasks.icd10depressive import Icd10Depressive
+from camcops_server.tasks.icd10manic import Icd10Manic
+from camcops_server.tasks.icd10mixed import Icd10Mixed
+from camcops_server.tasks.icd10schizophrenia import Icd10Schizophrenia
+from camcops_server.tasks.icd10schizotypal import Icd10Schizotypal
+from camcops_server.tasks.icd10specpd import Icd10SpecPD
+from camcops_server.tasks.ided3d import IDED3D
+from camcops_server.tasks.iesr import Iesr
+from camcops_server.tasks.ifs import Ifs
+from camcops_server.tasks.irac import Irac
+from camcops_server.tasks.isaaq10 import Isaaq10
+from camcops_server.tasks.isaaqed import IsaaqEd
+from camcops_server.tasks.khandaker_insight_medical import (
+ KhandakerInsightMedical,
+)
+from camcops_server.tasks.khandaker_mojo_medical import KhandakerMojoMedical
+from camcops_server.tasks.khandaker_mojo_sociodemographics import (
+ KhandakerMojoSociodemographics,
+)
+from camcops_server.tasks.khandaker_mojo_medicationtherapy import (
+ KhandakerMojoMedicationTherapy,
+)
+from camcops_server.tasks.kirby_mcq import Kirby
+from camcops_server.tasks.lynall_iam_life import LynallIamLifeEvents
+from camcops_server.tasks.lynall_iam_medical import LynallIamMedicalHistory
+from camcops_server.tasks.maas import Maas
+from camcops_server.tasks.mast import Mast
+from camcops_server.tasks.mds_updrs import MdsUpdrs
+from camcops_server.tasks.mfi20 import Mfi20
+from camcops_server.tasks.moca import Moca
+from camcops_server.tasks.nart import Nart
+from camcops_server.tasks.npiq import NpiQ
+from camcops_server.tasks.ors import Ors
+from camcops_server.tasks.panss import Panss
+from camcops_server.tasks.paradise24 import Paradise24
+from camcops_server.tasks.pbq import Pbq
+from camcops_server.tasks.pcl5 import Pcl5
+from camcops_server.tasks.pcl import PclC, PclM, PclS
+from camcops_server.tasks.pdss import Pdss
+from camcops_server.tasks.perinatalpoem import PerinatalPoem
+from camcops_server.tasks.photo import Photo, PhotoSequence
+from camcops_server.tasks.phq15 import Phq15
+from camcops_server.tasks.phq8 import Phq8
from camcops_server.tasks.phq9 import Phq9
+from camcops_server.tasks.progressnote import ProgressNote
+from camcops_server.tasks.pswq import Pswq
+from camcops_server.tasks.psychiatricclerking import PsychiatricClerking
+from camcops_server.tasks.qolbasic import QolBasic
+from camcops_server.tasks.qolsg import QolSG
+from camcops_server.tasks.rand36 import Rand36
+from camcops_server.tasks.rapid3 import Rapid3
+from camcops_server.tasks.service_satisfaction import (
+ PatientSatisfaction,
+ ReferrerSatisfactionGen,
+ ReferrerSatisfactionSpec,
+)
+from camcops_server.tasks.sfmpq2 import Sfmpq2
+from camcops_server.tasks.shaps import Shaps
+from camcops_server.tasks.slums import Slums
+from camcops_server.tasks.smast import Smast
+from camcops_server.tasks.srs import Srs
+from camcops_server.tasks.suppsp import Suppsp
+from camcops_server.tasks.wemwbs import Swemwbs, Wemwbs
+from camcops_server.tasks.wsas import Wsas
+from camcops_server.tasks.ybocs import Ybocs, YbocsSc
+from camcops_server.tasks.zbi import Zbi12
+
+if TYPE_CHECKING:
+ from factory.builder import Resolver
class TaskFactory(GenericTabletRecordFactory):
@@ -44,7 +192,11 @@ class Meta:
@factory.lazy_attribute
def when_created(self) -> pendulum.DateTime:
- return pendulum.parse(self.default_iso_datetime)
+ datetime = cast(
+ pendulum.DateTime, pendulum.parse(self.default_iso_datetime)
+ )
+
+ return datetime
class TaskHasPatientFactory(TaskFactory):
@@ -53,6 +205,46 @@ class Meta:
patient_id = None
+ @classmethod
+ def create(cls, *args, **kwargs) -> Task:
+ patient = kwargs.pop("patient", None)
+ if patient is not None:
+ if "patient_id" in kwargs:
+ raise TypeError(
+ "Both 'patient' and 'patient_id' keyword arguments "
+ "unexpectedly passed to a task factory. Use one or the "
+ "other."
+ )
+ kwargs["patient_id"] = patient.id
+
+ if "_device" not in kwargs:
+ kwargs["_device"] = patient._device
+
+ if "_era" not in kwargs:
+ kwargs["_era"] = patient._era
+
+ if "_group" not in kwargs:
+ kwargs["_group"] = patient._group
+
+ if "_current" not in kwargs:
+ kwargs["_current"] = True
+
+ return super().create(*args, **kwargs)
+
+
+class APEQCPFTPerinatalFactory(TaskFactory):
+ class Meta:
+ model = APEQCPFTPerinatal
+
+ id = factory.Sequence(lambda n: n)
+
+
+class ApeqptFactory(TaskFactory):
+ class Meta:
+ model = Apeqpt
+
+ id = factory.Sequence(lambda n: n)
+
class BmiFactory(TaskHasPatientFactory):
class Meta:
@@ -60,9 +252,1031 @@ class Meta:
id = factory.Sequence(lambda n: n)
+ height_m = factory.LazyFunction(Fake.en_gb.height_m)
+ mass_kg = factory.LazyFunction(Fake.en_gb.mass_kg)
+ waist_cm = factory.LazyFunction(Fake.en_gb.waist_cm)
+
+
+class Core10Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Core10
+
+ id = factory.Sequence(lambda n: n)
+
+ q1 = 0
+ q2 = 0
+ q3 = 0
+ q4 = 0
+ q5 = 0
+ q6 = 0
+ q7 = 0
+ q8 = 0
+ q9 = 0
+ q10 = 0
+
+
+class DiagnosisIcd10Factory(TaskHasPatientFactory):
+ class Meta:
+ model = DiagnosisIcd10
+
+ id = factory.Sequence(lambda n: n)
+
+
+class DiagnosisItemFactory(GenericTabletRecordFactory):
+ class Meta:
+ abstract = True
+
+
+class DiagnosisIcd10ItemFactory(DiagnosisItemFactory):
+ class Meta:
+ model = DiagnosisIcd10Item
+
+ id = factory.Sequence(lambda n: n)
+
+ @classmethod
+ def create(cls, *args, **kwargs) -> DiagnosisIcd10Item:
+ diagnosis_icd10 = kwargs.pop("diagnosis_icd10", None)
+ if diagnosis_icd10 is not None:
+ if "diagnosis_icd10_id" in kwargs:
+ raise TypeError(
+ "Both 'diagnosis_icd10' and 'diagnosis_icd10_id' keyword "
+ "arguments unexpectedly passed to a task factory. Use one "
+ "or the other."
+ )
+ kwargs["diagnosis_icd10_id"] = diagnosis_icd10.id
+
+ if "_device" not in kwargs:
+ kwargs["_device"] = diagnosis_icd10._device
+
+ if "_era" not in kwargs:
+ kwargs["_era"] = diagnosis_icd10._era
+
+ if "_current" not in kwargs:
+ kwargs["_current"] = True
+
+ return super().create(*args, **kwargs)
+
+
+class DiagnosisIcd9CMFactory(TaskHasPatientFactory):
+ class Meta:
+ model = DiagnosisIcd9CM
+
+ id = factory.Sequence(lambda n: n)
+
+
+class DiagnosisIcd9CMItemFactory(DiagnosisItemFactory):
+ class Meta:
+ model = DiagnosisIcd9CMItem
+
+ id = factory.Sequence(lambda n: n)
+
+ @classmethod
+ def create(cls, *args, **kwargs) -> DiagnosisIcd9CMItem:
+ diagnosis_icd9cm = kwargs.pop("diagnosis_icd9cm", None)
+ if diagnosis_icd9cm is not None:
+ if "diagnosis_icd9cm_id" in kwargs:
+ raise TypeError(
+ "Both 'diagnosis_icd9cm' and 'diagnosis_icd9cm_id' "
+ "keyword arguments unexpectedly passed to a task factory. "
+ "Use one or the other."
+ )
+ kwargs["diagnosis_icd9cm_id"] = diagnosis_icd9cm.id
+
+ if "_device" not in kwargs:
+ kwargs["_device"] = diagnosis_icd9cm._device
+
+ if "_era" not in kwargs:
+ kwargs["_era"] = diagnosis_icd9cm._era
+
+ if "_current" not in kwargs:
+ kwargs["_current"] = True
+
+ return super().create(*args, **kwargs)
+
+
+class Gad7Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Gad7
+
+ id = factory.Sequence(lambda n: n)
+
+
+class KhandakerMojoMedicationTherapyFactory(TaskHasPatientFactory):
+ class Meta:
+ model = KhandakerMojoMedicationTherapy
+
+ id = factory.Sequence(lambda n: n)
+
+
+class MaasFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Maas
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PerinatalPoemFactory(TaskFactory):
+ class Meta:
+ model = PerinatalPoem
+
+ id = factory.Sequence(lambda n: n)
+
class Phq9Factory(TaskHasPatientFactory):
class Meta:
model = Phq9
id = factory.Sequence(lambda n: n)
+
+
+class Ace3Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Ace3
+
+ id = factory.Sequence(lambda n: n)
+
+
+class AimsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Aims
+
+ id = factory.Sequence(lambda n: n)
+
+
+class AqFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Aq
+
+ id = factory.Sequence(lambda n: n)
+
+
+class AsdasFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Asdas
+
+ id = factory.Sequence(lambda n: n)
+
+
+class AuditFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Audit
+
+ id = factory.Sequence(lambda n: n)
+
+
+class AuditCFactory(TaskHasPatientFactory):
+ class Meta:
+ model = AuditC
+
+ id = factory.Sequence(lambda n: n)
+
+
+class BadlsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Badls
+
+ id = factory.Sequence(lambda n: n)
+
+
+class BasdaiFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Basdai
+
+ id = factory.Sequence(lambda n: n)
+
+
+class BdiFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Bdi
+
+ id = factory.Sequence(lambda n: n)
+
+
+class BprsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Bprs
+
+ id = factory.Sequence(lambda n: n)
+
+
+class BprseFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Bprse
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CageFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Cage
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Cape42Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Cape42
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CapsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Caps
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CardinalExpectationDetectionFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CardinalExpectationDetection
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CardinalExpDetThresholdFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CardinalExpDetThreshold
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CbiRFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CbiR
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CecaQ3Factory(TaskHasPatientFactory):
+ class Meta:
+ model = CecaQ3
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CesdFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Cesd
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CesdrFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Cesdr
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CetFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Cet
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CgiFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Cgi
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CgiIFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CgiI
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CgiSchFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CgiSch
+
+ id = factory.Sequence(lambda n: n)
+
+
+class ChitFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Chit
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CiaFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Cia
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CisrFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Cisr
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CiwaFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Ciwa
+
+ id = factory.Sequence(lambda n: n)
+
+
+class ContactLogFactory(TaskHasPatientFactory):
+ class Meta:
+ model = ContactLog
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CopeBriefFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CopeBrief
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CpftCovidMedicalFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CpftCovidMedical
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CPFTLPSDischargeFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CPFTLPSDischarge
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CPFTLPSReferralFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CPFTLPSReferral
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CPFTLPSResetResponseClockFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CPFTLPSResetResponseClock
+
+ id = factory.Sequence(lambda n: n)
+
+
+class CpftResearchPreferencesFactory(TaskHasPatientFactory):
+ class Meta:
+ model = CpftResearchPreferences
+
+ id = factory.Sequence(lambda n: n)
+
+
+class DadFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Dad
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Das28Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Das28
+
+ id = factory.Sequence(lambda n: n)
+
+
+class DastFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Dast
+
+ id = factory.Sequence(lambda n: n)
+
+
+class DeakinS1HealthReviewFactory(TaskHasPatientFactory):
+ class Meta:
+ model = DeakinS1HealthReview
+
+ id = factory.Sequence(lambda n: n)
+
+
+class DemoQuestionnaireFactory(TaskFactory):
+ class Meta:
+ model = DemoQuestionnaire
+
+ id = factory.Sequence(lambda n: n)
+
+
+class DemqolFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Demqol
+
+ id = factory.Sequence(lambda n: n)
+
+
+class DemqolProxyFactory(TaskHasPatientFactory):
+ class Meta:
+ model = DemqolProxy
+
+ id = factory.Sequence(lambda n: n)
+
+
+class DistressThermometerFactory(TaskHasPatientFactory):
+ class Meta:
+ model = DistressThermometer
+
+ id = factory.Sequence(lambda n: n)
+
+
+class EdeqFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Edeq
+
+ id = factory.Sequence(lambda n: n)
+
+
+class ElixhauserCIFactory(TaskHasPatientFactory):
+ class Meta:
+ model = ElixhauserCI
+
+ id = factory.Sequence(lambda n: n)
+
+
+class EpdsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Epds
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Eq5d5lFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Eq5d5l
+
+ id = factory.Sequence(lambda n: n)
+
+
+class EsspriFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Esspri
+
+ id = factory.Sequence(lambda n: n)
+
+
+class FactgFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Factg
+
+ id = factory.Sequence(lambda n: n)
+
+
+class FastFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Fast
+
+ id = factory.Sequence(lambda n: n)
+
+
+class FftFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Fft
+
+ id = factory.Sequence(lambda n: n)
+
+
+class FrsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Frs
+
+ id = factory.Sequence(lambda n: n)
+
+
+class GafFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Gaf
+
+ id = factory.Sequence(lambda n: n)
+
+
+class GbogpcFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Gbogpc
+
+ id = factory.Sequence(lambda n: n)
+
+
+class GbograsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Gbogras
+
+ id = factory.Sequence(lambda n: n)
+
+
+class GbogresFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Gbogres
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Gds15Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Gds15
+
+ id = factory.Sequence(lambda n: n)
+
+
+class GMCPQFactory(TaskFactory):
+ class Meta:
+ model = GMCPQ
+
+ id = factory.Sequence(lambda n: n)
+
+
+class HadsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Hads
+
+ id = factory.Sequence(lambda n: n)
+
+
+class HadsRespondentFactory(TaskHasPatientFactory):
+ class Meta:
+ model = HadsRespondent
+
+ id = factory.Sequence(lambda n: n)
+
+
+class HamaFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Hama
+
+ id = factory.Sequence(lambda n: n)
+
+
+class HamdFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Hamd
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Hamd7Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Hamd7
+
+ id = factory.Sequence(lambda n: n)
+
+
+class HonosFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Honos
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Honos65Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Honos65
+
+ id = factory.Sequence(lambda n: n)
+
+
+class HonoscaFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Honosca
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Icd10DepressiveFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Icd10Depressive
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Icd10ManicFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Icd10Manic
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Icd10MixedFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Icd10Mixed
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Icd10SchizophreniaFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Icd10Schizophrenia
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Icd10SchizotypalFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Icd10Schizotypal
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Icd10SpecPDFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Icd10SpecPD
+
+ id = factory.Sequence(lambda n: n)
+
+
+class IDED3DFactory(TaskHasPatientFactory):
+ class Meta:
+ model = IDED3D
+
+ id = factory.Sequence(lambda n: n)
+
+
+class IesrFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Iesr
+
+ id = factory.Sequence(lambda n: n)
+
+
+class IfsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Ifs
+
+ id = factory.Sequence(lambda n: n)
+
+
+class IracFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Irac
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Isaaq10Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Isaaq10
+
+ id = factory.Sequence(lambda n: n)
+
+
+class IsaaqEdFactory(TaskHasPatientFactory):
+ class Meta:
+ model = IsaaqEd
+
+ id = factory.Sequence(lambda n: n)
+
+
+class KhandakerInsightMedicalFactory(TaskHasPatientFactory):
+ class Meta:
+ model = KhandakerInsightMedical
+
+ id = factory.Sequence(lambda n: n)
+
+
+class KhandakerMojoMedicalFactory(TaskHasPatientFactory):
+ class Meta:
+ model = KhandakerMojoMedical
+
+ id = factory.Sequence(lambda n: n)
+
+
+class KhandakerMojoSociodemographicsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = KhandakerMojoSociodemographics
+
+ id = factory.Sequence(lambda n: n)
+
+
+class KirbyFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Kirby
+
+ id = factory.Sequence(lambda n: n)
+
+
+class LynallIamMedicalHistoryFactory(TaskHasPatientFactory):
+ class Meta:
+ model = LynallIamMedicalHistory
+
+ id = factory.Sequence(lambda n: n)
+
+
+class LynallIamLifeEventsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = LynallIamLifeEvents
+
+ id = factory.Sequence(lambda n: n)
+
+
+class MastFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Mast
+
+ id = factory.Sequence(lambda n: n)
+
+
+class MdsUpdrsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = MdsUpdrs
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Mfi20Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Mfi20
+
+ id = factory.Sequence(lambda n: n)
+
+
+class MiniAceFactory(TaskHasPatientFactory):
+ class Meta:
+ model = MiniAce
+
+ id = factory.Sequence(lambda n: n)
+
+
+class MocaFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Moca
+
+ id = factory.Sequence(lambda n: n)
+
+
+class NartFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Nart
+
+ id = factory.Sequence(lambda n: n)
+
+
+class NpiQFactory(TaskHasPatientFactory):
+ class Meta:
+ model = NpiQ
+
+ id = factory.Sequence(lambda n: n)
+
+
+class OrsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Ors
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PanssFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Panss
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Paradise24Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Paradise24
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PbqFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Pbq
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Pcl5Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Pcl5
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PclCFactory(TaskHasPatientFactory):
+ class Meta:
+ model = PclC
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PclMFactory(TaskHasPatientFactory):
+ class Meta:
+ model = PclM
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PclSFactory(TaskHasPatientFactory):
+ class Meta:
+ model = PclS
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PdssFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Pdss
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PhotoFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Photo
+
+ id = factory.Sequence(lambda n: n)
+
+ @factory.post_generation
+ def create_blob(
+ obj: "Resolver", create: bool, extracted: None, **kwargs
+ ) -> None:
+ if not create:
+ return
+
+ obj.photo = BlobFactory.create(
+ tablename=obj.tablename, tablepk=obj.id, **kwargs
+ )
+
+
+class PhotoSequenceFactory(TaskHasPatientFactory):
+ class Meta:
+ model = PhotoSequence
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Phq15Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Phq15
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Phq8Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Phq8
+
+ id = factory.Sequence(lambda n: n)
+
+
+class ProgressNoteFactory(TaskHasPatientFactory):
+ class Meta:
+ model = ProgressNote
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PswqFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Pswq
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PsychiatricClerkingFactory(TaskHasPatientFactory):
+ class Meta:
+ model = PsychiatricClerking
+
+ id = factory.Sequence(lambda n: n)
+
+
+class PatientSatisfactionFactory(TaskHasPatientFactory):
+ class Meta:
+ model = PatientSatisfaction
+
+ id = factory.Sequence(lambda n: n)
+
+
+class QolBasicFactory(TaskHasPatientFactory):
+ class Meta:
+ model = QolBasic
+
+ id = factory.Sequence(lambda n: n)
+
+
+class QolSGFactory(TaskHasPatientFactory):
+ class Meta:
+ model = QolSG
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Rand36Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Rand36
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Rapid3Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Rapid3
+
+ id = factory.Sequence(lambda n: n)
+
+
+class ReferrerSatisfactionGenFactory(TaskFactory):
+ class Meta:
+ model = ReferrerSatisfactionGen
+
+ id = factory.Sequence(lambda n: n)
+
+
+class ReferrerSatisfactionSpecFactory(TaskHasPatientFactory):
+ class Meta:
+ model = ReferrerSatisfactionSpec
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Sfmpq2Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Sfmpq2
+
+ id = factory.Sequence(lambda n: n)
+
+
+class ShapsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Shaps
+
+ id = factory.Sequence(lambda n: n)
+
+
+class SlumsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Slums
+
+ id = factory.Sequence(lambda n: n)
+
+
+class SmastFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Smast
+
+ id = factory.Sequence(lambda n: n)
+
+
+class SrsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Srs
+
+ id = factory.Sequence(lambda n: n)
+
+
+class SuppspFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Suppsp
+
+ id = factory.Sequence(lambda n: n)
+
+
+class SwemwbsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Swemwbs
+
+ id = factory.Sequence(lambda n: n)
+
+
+class WemwbsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Wemwbs
+
+ id = factory.Sequence(lambda n: n)
+
+
+class WsasFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Wsas
+
+ id = factory.Sequence(lambda n: n)
+
+
+class YbocsFactory(TaskHasPatientFactory):
+ class Meta:
+ model = Ybocs
+
+ id = factory.Sequence(lambda n: n)
+
+
+class YbocsScFactory(TaskHasPatientFactory):
+ class Meta:
+ model = YbocsSc
+
+ id = factory.Sequence(lambda n: n)
+
+
+class Zbi12Factory(TaskHasPatientFactory):
+ class Meta:
+ model = Zbi12
+
+ id = factory.Sequence(lambda n: n)
diff --git a/server/camcops_server/tasks/tests/maas_tests.py b/server/camcops_server/tasks/tests/maas_tests.py
index 9b0566991..b8c45f89c 100644
--- a/server/camcops_server/tasks/tests/maas_tests.py
+++ b/server/camcops_server/tasks/tests/maas_tests.py
@@ -27,48 +27,53 @@
import pendulum
-from camcops_server.cc_modules.cc_patient import Patient
-from camcops_server.cc_modules.tests.cc_report_tests import (
- AverageScoreReportTestCase,
+from camcops_server.cc_modules.cc_testfactories import (
+ PatientFactory,
+ UserFactory,
)
+from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase
from camcops_server.tasks.maas import Maas, MaasReport
+from camcops_server.tasks.tests.factories import MaasFactory
-class MaasReportTests(AverageScoreReportTestCase):
+class MaasReportTests(DemoRequestTestCase):
PROGRESS_COL = 4
- def create_report(self) -> MaasReport:
- return MaasReport(via_index=False)
+ def setUp(self) -> None:
+ super().setUp()
- def create_tasks(self) -> None:
- self.patient_1 = self.create_patient()
+ self.report = self.create_report()
+ self.req._debugging_user = UserFactory(superuser=True)
- self.create_task(
- patient=self.patient_1, q1=2, q2=2, era="2019-03-01"
- ) # total 17 + 2 + 2
- self.create_task(
- patient=self.patient_1, q1=5, q2=5, era="2019-06-01"
- ) # total 17 + 5 + 5
- self.dbsession.commit()
+ patient = PatientFactory()
- def create_task(self, patient: Patient, era: str = None, **kwargs) -> None:
- task = Maas()
- self.apply_standard_task_fields(task)
- task.id = next(self.task_id_sequence)
+ # Default to answering 1 to everything
+ response_dict = {q: 1 for q in Maas.TASK_FIELDS}
- task.patient_id = patient.id
- for fieldname in Maas.TASK_FIELDS:
- value = kwargs.get(fieldname, 1)
- setattr(task, fieldname, value)
+ # Same patient completing the task at different intervals
+ response_dict["q1"] = 2
+ response_dict["q2"] = 2
+ MaasFactory(
+ patient=patient,
+ when_created=pendulum.parse("2019-03-01"),
+ **response_dict,
+ ) # total 17 * 1 + 2 + 2 = 21
- if era is not None:
- task.when_created = pendulum.parse(era)
+ response_dict["q1"] = 5
+ response_dict["q2"] = 5
+ MaasFactory(
+ patient=patient,
+ when_created=pendulum.parse("2019-06-01"),
+ **response_dict,
+ ) # total 17 * 1 + 5 + 5 = 27
- self.dbsession.add(task)
+ def create_report(self) -> MaasReport:
+ return MaasReport(via_index=False)
def test_average_progress_is_positive(self) -> None:
pages = self.report.get_spreadsheet_pages(req=self.req)
+ # Numbers as above
expected_progress = 27 - 21
actual_progress = pages[0].plainrows[0][self.PROGRESS_COL]
diff --git a/server/camcops_server/tasks/tests/perinatalpoem_tests.py b/server/camcops_server/tasks/tests/perinatalpoem_tests.py
index 0689b84f4..975839dfa 100644
--- a/server/camcops_server/tasks/tests/perinatalpoem_tests.py
+++ b/server/camcops_server/tasks/tests/perinatalpoem_tests.py
@@ -25,27 +25,16 @@
"""
-from typing import Generator
-
-import pendulum
-
-from camcops_server.cc_modules.cc_unittest import BasicDatabaseTestCase
-from camcops_server.tasks.perinatalpoem import (
- PerinatalPoem,
- PerinatalPoemReport,
-)
-
+from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase
+from camcops_server.tasks.perinatalpoem import PerinatalPoemReport
+from camcops_server.tasks.tests.factories import PerinatalPoemFactory
# =============================================================================
# Unit tests
# =============================================================================
-class PerinatalPoemReportTestCase(BasicDatabaseTestCase):
- def __init__(self, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
- self.id_sequence = self.get_id()
-
+class PerinatalPoemReportTestCase(DemoRequestTestCase):
def setUp(self) -> None:
super().setUp()
@@ -55,28 +44,6 @@ def setUp(self) -> None:
self.report.start_datetime = None
self.report.end_datetime = None
- @staticmethod
- def get_id() -> Generator[int, None, None]:
- i = 1
-
- while True:
- yield i
- i += 1
-
- def create_task(self, **kwargs) -> None:
- task = PerinatalPoem()
- self.apply_standard_task_fields(task)
- task.id = next(self.id_sequence)
-
- era = kwargs.pop("era", None)
- if era is not None:
- task.when_created = pendulum.parse(era)
-
- for name, value in kwargs.items():
- setattr(task, name, value)
-
- self.dbsession.add(task)
-
class PerinatalPoemReportTests(PerinatalPoemReportTestCase):
"""
@@ -84,10 +51,16 @@ class PerinatalPoemReportTests(PerinatalPoemReportTestCase):
sanity checking here
"""
- def create_tasks(self):
- self.create_task(general_comments="comment 1")
- self.create_task(general_comments="comment 2")
- self.create_task(general_comments="comment 3")
+ def setUp(self) -> None:
+ super().setUp()
+
+ t1 = PerinatalPoemFactory(general_comments="comment 1")
+ t2 = PerinatalPoemFactory(general_comments="comment 2")
+ t3 = PerinatalPoemFactory(general_comments="comment 3")
+
+ self.dbsession.add(t1)
+ self.dbsession.add(t2)
+ self.dbsession.add(t3)
self.dbsession.commit()
@@ -146,27 +119,35 @@ def test_comments(self) -> None:
class PerinatalPoemReportDateRangeTests(PerinatalPoemReportTestCase):
- def create_tasks(self) -> None:
- self.create_task(
+ def setUp(self) -> None:
+ super().setUp()
+
+ t1 = PerinatalPoemFactory(
general_comments="comments 1",
- era="2018-10-01T00:00:00.000000+00:00",
+ when_created="2018-10-01T00:00:00.000000+00:00",
)
- self.create_task(
+ t2 = PerinatalPoemFactory(
general_comments="comments 2",
- era="2018-10-02T00:00:00.000000+00:00",
+ when_created="2018-10-02T00:00:00.000000+00:00",
)
- self.create_task(
+ t3 = PerinatalPoemFactory(
general_comments="comments 3",
- era="2018-10-03T00:00:00.000000+00:00",
+ when_created="2018-10-03T00:00:00.000000+00:00",
)
- self.create_task(
+ t4 = PerinatalPoemFactory(
general_comments="comments 4",
- era="2018-10-04T00:00:00.000000+00:00",
+ when_created="2018-10-04T00:00:00.000000+00:00",
)
- self.create_task(
+ t5 = PerinatalPoemFactory(
general_comments="comments 5",
- era="2018-10-05T00:00:00.000000+00:00",
+ when_created="2018-10-05T00:00:00.000000+00:00",
)
+ self.dbsession.add(t1)
+ self.dbsession.add(t2)
+ self.dbsession.add(t3)
+ self.dbsession.add(t4)
+ self.dbsession.add(t5)
+
self.dbsession.commit()
def test_comments_filtered_by_date(self) -> None:
diff --git a/server/camcops_server/tools/generate_task_factories.py b/server/camcops_server/tools/generate_task_factories.py
new file mode 100755
index 000000000..ab5e4f0eb
--- /dev/null
+++ b/server/camcops_server/tools/generate_task_factories.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python
+
+"""
+camcops_server/tools/generate_task_factories.py
+
+===============================================================================
+
+ Copyright (C) 2012, University of Cambridge, Department of Psychiatry.
+ Created by Rudolf Cardinal (rnc1001@cam.ac.uk).
+
+ This file is part of CamCOPS.
+
+ CamCOPS is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ CamCOPS is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with CamCOPS. If not, see .
+
+===============================================================================
+
+Script to generate skeleton Factory Boy test factories for
+camcops_server/tasks/tests/factories.py
+
+Probably not needed anymore.
+
+"""
+
+from camcops_server.cc_modules.cc_task import Task, TaskHasPatientMixin
+from camcops_server.tasks.tests import factories as task_factories
+
+
+def main() -> None:
+ task_dict = {}
+
+ for cls in Task.all_subclasses_by_tablename():
+ task_class_name = cls.__name__
+ factory_name = f"{task_class_name}Factory"
+ factory_class = getattr(task_factories, factory_name, None)
+ if factory_class is None:
+ task_dict[task_class_name.lower()] = task_class_name
+ if issubclass(cls, TaskHasPatientMixin):
+ sub_class_name = "TaskHasPatientFactory"
+ else:
+ sub_class_name = "TaskFactory"
+
+ print(
+ f"""
+class {factory_name}({sub_class_name}):
+ class Meta:
+ model = {task_class_name}
+
+ id = factory.Sequence(lambda n: n)
+"""
+ )
+
+ for filename, class_name in task_dict.items():
+ print(f"from camcops_server.tasks.{filename} import {class_name}")
+
+
+if __name__ == "__main__":
+ main()