diff --git a/docs/source/autodoc/server/camcops_server/_index.rst b/docs/source/autodoc/server/camcops_server/_index.rst index 817aaf333..a652e61d7 100644 --- a/docs/source/autodoc/server/camcops_server/_index.rst +++ b/docs/source/autodoc/server/camcops_server/_index.rst @@ -198,6 +198,7 @@ server/camcops_server cc_modules/cc_taskschedule.py.rst cc_modules/cc_taskschedulereports.py.rst cc_modules/cc_testfactories.py.rst + cc_modules/cc_testproviders.py.rst cc_modules/cc_text.py.rst cc_modules/cc_tracker.py.rst cc_modules/cc_trackerhelpers.py.rst @@ -649,5 +650,6 @@ server/camcops_server templates/test/test_template_filters.mako.rst templates/test/testpage.mako.rst tools/fetch_snomed_codes.py.rst + tools/generate_task_factories.py.rst tools/print_latest_github_version.py.rst tools/run_server_self_tests.py.rst diff --git a/docs/source/autodoc/server/camcops_server/cc_modules/cc_testproviders.py.rst b/docs/source/autodoc/server/camcops_server/cc_modules/cc_testproviders.py.rst new file mode 100644 index 000000000..d45db5afa --- /dev/null +++ b/docs/source/autodoc/server/camcops_server/cc_modules/cc_testproviders.py.rst @@ -0,0 +1,29 @@ +.. docs/source/autodoc/server/camcops_server/cc_modules/cc_testproviders.py.rst + +.. THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT. + + +.. Copyright (C) 2012, University of Cambridge, Department of Psychiatry. + Created by Rudolf Cardinal (rnc1001@cam.ac.uk). + . + This file is part of CamCOPS. + . + CamCOPS is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + . + CamCOPS is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + . + You should have received a copy of the GNU General Public License + along with CamCOPS. If not, see . + + +camcops_server.cc_modules.cc_testproviders +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. automodule:: camcops_server.cc_modules.cc_testproviders + :members: diff --git a/docs/source/autodoc/server/camcops_server/tools/generate_task_factories.py.rst b/docs/source/autodoc/server/camcops_server/tools/generate_task_factories.py.rst new file mode 100644 index 000000000..963fc05a7 --- /dev/null +++ b/docs/source/autodoc/server/camcops_server/tools/generate_task_factories.py.rst @@ -0,0 +1,29 @@ +.. docs/source/autodoc/server/camcops_server/tools/generate_task_factories.py.rst + +.. THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT. + + +.. Copyright (C) 2012, University of Cambridge, Department of Psychiatry. + Created by Rudolf Cardinal (rnc1001@cam.ac.uk). + . + This file is part of CamCOPS. + . + CamCOPS is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + . + CamCOPS is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + . + You should have received a copy of the GNU General Public License + along with CamCOPS. If not, see . + + +camcops_server.tools.generate_task_factories +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. automodule:: camcops_server.tools.generate_task_factories + :members: diff --git a/docs/source/developer/server_testing.rst b/docs/source/developer/server_testing.rst index f8b1a242a..ae65e1cde 100644 --- a/docs/source/developer/server_testing.rst +++ b/docs/source/developer/server_testing.rst @@ -20,7 +20,7 @@ .. _pytest: https://docs.pytest.org/en/stable/ - +.. _Factory Boy: https://factoryboy.readthedocs.io/en/stable/ Testing the server code ======================= @@ -32,12 +32,26 @@ with the filename of the module appended with ``_tests.py``. So the module ``camcops_server/cc_modules/cc_patient.py`` is tested in ``camcops_server/cc_modules/tests/cc_patient_tests.py``. -Test classes should end in ``Tests`` e.g. ``PatientTests``. Tests that require -an empty database should inherit from ``DemoRequestTestCase``. Tests that -require the demonstration database should inherit from -``DemoDatabaseTestCase``. See ``camcops_server/cc_modules/cc_unittest``. Tests -that do not require a database can just inherit from the standard python -``unittest.TestCase`` +Test classes should end in ``Tests`` e.g. ``PatientTests``. A number of +``unittest.TestCase`` subclasses are defined in +``camcops_server/cc_modules/cc_unittest``. + +- Tests that require an empty database and a request object should inherit from + ``DemoRequestTestCase``. + +- Tests that require a minimal database setup (system user, superuser set on the + request object, group administrator and a server device) should inherit from + ``BasicDatabaseTestCase``. + +- Tests that require the demonstration database, which has a patient and two + instances of each type of task should inherit from ``DemoDatabaseTestCase``. + +- Tests that do not require a database + can just inherit from the standard python ``unittest.TestCase``. + +Use `Factory Boy`_ test factories to create test instances of SQLAlchemy +database models. See ``camcops_server/cc_modules/cc_testfactories.py`` and +``camcops_server/tasks/tests/factories.py``. .. _run_all_server_tests: diff --git a/server/camcops_server/cc_modules/cc_membership.py b/server/camcops_server/cc_modules/cc_membership.py index b05e879d4..3d6de0f16 100644 --- a/server/camcops_server/cc_modules/cc_membership.py +++ b/server/camcops_server/cc_modules/cc_membership.py @@ -156,10 +156,6 @@ class UserGroupMembership(Base): group = relationship("Group", back_populates="user_group_memberships") user = relationship("User", back_populates="user_group_memberships") - def __init__(self, user_id: int, group_id: int): - self.user_id = user_id - self.group_id = group_id - @classmethod def get_ugm_by_id( cls, dbsession: SqlASession, ugm_id: Optional[int] diff --git a/server/camcops_server/cc_modules/cc_request.py b/server/camcops_server/cc_modules/cc_request.py index 4977733aa..30506a864 100644 --- a/server/camcops_server/cc_modules/cc_request.py +++ b/server/camcops_server/cc_modules/cc_request.py @@ -2540,8 +2540,5 @@ def get_unittest_request( req.set_get_params(params) req._debugging_db_session = dbsession - user = User() - user.superuser = True - req._debugging_user = user return req diff --git a/server/camcops_server/cc_modules/cc_testfactories.py b/server/camcops_server/cc_modules/cc_testfactories.py index 635533445..d499c8a82 100644 --- a/server/camcops_server/cc_modules/cc_testfactories.py +++ b/server/camcops_server/cc_modules/cc_testfactories.py @@ -27,19 +27,27 @@ """ +from typing import cast, Optional, TYPE_CHECKING + from cardinal_pythonlib.datetimefunc import ( convert_datetime_to_utc, format_datetime, ) import factory +from faker import Faker import pendulum -from camcops_server.cc_modules.cc_constants import DateFormat +from camcops_server.cc_modules.cc_blob import Blob +from camcops_server.cc_modules.cc_constants import DateFormat, ERA_NOW from camcops_server.cc_modules.cc_device import Device from camcops_server.cc_modules.cc_email import Email from camcops_server.cc_modules.cc_group import Group +from camcops_server.cc_modules.cc_idnumdef import IdNumDefinition +from camcops_server.cc_modules.cc_ipuse import IpUse from camcops_server.cc_modules.cc_membership import UserGroupMembership from camcops_server.cc_modules.cc_patient import Patient +from camcops_server.cc_modules.cc_patientidnum import PatientIdNum +from camcops_server.cc_modules.cc_testproviders import register_all_providers from camcops_server.cc_modules.cc_taskschedule import ( PatientTaskSchedule, PatientTaskScheduleEmail, @@ -48,89 +56,282 @@ ) from camcops_server.cc_modules.cc_user import User +if TYPE_CHECKING: + from factory.builder import Resolver + from camcops_server.cc_modules.cc_request import CamcopsRequest + + +# Avoid any ID clashes with objects not created with factories +ID_OFFSET = 1000 +RIO_ID_OFFSET = 10000 +STUDY_ID_OFFSET = 5000 + + +class Fake: + # Factory Boy has its own interface to Faker (factory.Faker()). This + # takes a function to be called at object generation time and as far as I + # can tell this doesn't support being able to create fake data based on + # other fake attributes such as notes for a patient. You can work + # around this by adding a lot of logic to the factories. To me it makes + # sense to keep the factories simple and do as much as possible of the + # content generation in the providers. So we call Faker directly instead. + en_gb = Faker("en_GB") # For UK postcodes, phone numbers etc + en_us = Faker("en_US") # en_GB gives Lorem ipsum for pad words. + + +register_all_providers(Fake.en_gb) + # sqlalchemy_session gets poked in by DemoRequestCase.setUp() class BaseFactory(factory.alchemy.SQLAlchemyModelFactory): - pass + class Meta: + sqlalchemy_session_persistence = "commit" class DeviceFactory(BaseFactory): class Meta: model = Device - id = factory.Sequence(lambda n: n) - name = factory.Sequence(lambda n: f"Test device {n}") + id = factory.Sequence(lambda n: n + ID_OFFSET) + name = factory.Sequence(lambda n: f"test-device-{n + ID_OFFSET}") + + +class IpUseFactory(BaseFactory): + class Meta: + model = IpUse + + clinical = factory.LazyFunction(Fake.en_gb.pybool) + commercial = factory.LazyFunction(Fake.en_gb.pybool) + educational = factory.LazyFunction(Fake.en_gb.pybool) + research = factory.LazyFunction(Fake.en_gb.pybool) class GroupFactory(BaseFactory): class Meta: model = Group - id = factory.Sequence(lambda n: n) - name = factory.Sequence(lambda n: f"Group {n}") + id = factory.Sequence(lambda n: n + ID_OFFSET) + name = factory.Sequence(lambda n: f"Group {n + ID_OFFSET}") + ip_use = factory.SubFactory(IpUseFactory) + + +class AnyIdNumGroupFactory(GroupFactory): + upload_policy = "sex and anyidnum" + finalize_policy = "sex and anyidnum" class UserFactory(BaseFactory): class Meta: model = User - id = factory.Sequence(lambda n: n) username = factory.Sequence(lambda n: f"user{n}") hashedpw = "" + @factory.post_generation + def password( + obj: User, + create: bool, + password: Optional[str], + request: "CamcopsRequest" = None, + **kwargs, + ) -> None: + if not create: + return + + if password is None: + return + + assert request is not None + + obj.set_password(request, password) + class GenericTabletRecordFactory(BaseFactory): + class Meta: + exclude = ("default_iso_datetime",) + abstract = True + default_iso_datetime = "1970-01-01T12:00" + _pk = factory.Sequence(lambda n: n + ID_OFFSET) _device = factory.SubFactory(DeviceFactory) - _group = factory.SubFactory(GroupFactory) + _group = factory.SubFactory(AnyIdNumGroupFactory) _adding_user = factory.SubFactory(UserFactory) @factory.lazy_attribute - def _when_added_exact(self) -> pendulum.DateTime: - return pendulum.parse(self.default_iso_datetime) + def _when_added_exact(obj: "Resolver") -> pendulum.DateTime: + datetime = cast( + pendulum.DateTime, pendulum.parse(obj.default_iso_datetime) + ) + + return datetime @factory.lazy_attribute - def _when_added_batch_utc(self) -> pendulum.DateTime: - era_time = pendulum.parse(self.default_iso_datetime) + def _when_added_batch_utc(obj: "Resolver") -> pendulum.DateTime: + era_time = pendulum.parse(obj.default_iso_datetime) return convert_datetime_to_utc(era_time) @factory.lazy_attribute - def _era(self) -> str: - era_time = pendulum.parse(self.default_iso_datetime) + def _era(obj: "Resolver") -> str: + era_time = pendulum.parse(obj.default_iso_datetime) return format_datetime(era_time, DateFormat.ISO8601) @factory.lazy_attribute - def _current(self) -> bool: + def _current(obj: "Resolver") -> bool: # _current = True gets ignored for some reason return True - class Meta: - exclude = ("default_iso_datetime",) - abstract = True - class PatientFactory(GenericTabletRecordFactory): class Meta: model = Patient - id = factory.Sequence(lambda n: n) + id = factory.Sequence(lambda n: n + ID_OFFSET) + sex = factory.LazyFunction(Fake.en_gb.sex) + dob = factory.LazyFunction(Fake.en_gb.consistent_date_of_birth) + address = factory.LazyFunction(Fake.en_gb.address) + gp = factory.LazyFunction(Fake.en_gb.name) + other = factory.LazyFunction(Fake.en_us.paragraph) + email = factory.LazyFunction(Fake.en_gb.email) + + @factory.lazy_attribute + def forename(obj: "Resolver") -> str: + return Fake.en_gb.forename(obj.sex) + + surname = factory.LazyFunction(Fake.en_gb.last_name) class ServerCreatedPatientFactory(PatientFactory): @factory.lazy_attribute - def _device(self) -> Device: - # Should have been created in BasicDatabaseTestCase.setUp + def _device(obj: "Resolver") -> Device: + # May have been created in BasicDatabaseTestCase.setUp return Device.get_server_device( ServerCreatedPatientFactory._meta.sqlalchemy_session ) + @factory.lazy_attribute + def _era(obj: "Resolver") -> str: + return ERA_NOW + + +class IdNumDefinitionFactory(BaseFactory): + class Meta: + model = IdNumDefinition + + which_idnum = factory.Sequence(lambda n: n + ID_OFFSET) + + +class NHSIdNumDefinitionFactory(IdNumDefinitionFactory): + description = "NHS number" + short_description = "NHS#" + hl7_assigning_authority = "NHS" + hl7_id_type = "NHSN" + + +class StudyIdNumDefinitionFactory(IdNumDefinitionFactory): + description = "Study number" + short_description = "Study" + + +class RioIdNumDefinitionFactory(IdNumDefinitionFactory): + description = "RiO number" + short_description = "RiO" + hl7_assigning_authority = "CPFT" + hl7_id_type = "CPRiO" + + +class PatientIdNumFactory(GenericTabletRecordFactory): + class Meta: + model = PatientIdNum + + id = factory.Sequence(lambda n: n + ID_OFFSET) + patient = factory.SubFactory(PatientFactory) + patient_id = factory.SelfAttribute("patient.id") + _group = factory.SelfAttribute("patient._group") + _device = factory.SelfAttribute("patient._device") + + +class NHSPatientIdNumFactory(PatientIdNumFactory): + class Meta: + exclude = PatientIdNumFactory._meta.exclude + ("idnum",) + + idnum = factory.SubFactory(NHSIdNumDefinitionFactory) + + which_idnum = factory.SelfAttribute("idnum.which_idnum") + idnum_value = factory.LazyFunction(Fake.en_gb.nhs_number) + + +class RioPatientIdNumFactory(PatientIdNumFactory): + class Meta: + exclude = PatientIdNumFactory._meta.exclude + ("idnum",) + + idnum = factory.SubFactory(RioIdNumDefinitionFactory) + + which_idnum = factory.SelfAttribute("idnum.which_idnum") + idnum_value = factory.Sequence(lambda n: n + RIO_ID_OFFSET) + + +class StudyPatientIdNumFactory(PatientIdNumFactory): + class Meta: + exclude = PatientIdNumFactory._meta.exclude + ("idnum",) + + idnum = factory.SubFactory(StudyIdNumDefinitionFactory) + + which_idnum = factory.SelfAttribute("idnum.which_idnum") + idnum_value = factory.Sequence(lambda n: n + STUDY_ID_OFFSET) + + +class ServerCreatedPatientIdNumFactory(PatientIdNumFactory): + patient = factory.SubFactory(ServerCreatedPatientFactory) + + @factory.lazy_attribute + def _device(obj: "Resolver") -> Device: + # Should have been created in BasicDatabaseTestCase.setUp + return Device.get_server_device( + ServerCreatedPatientIdNumFactory._meta.sqlalchemy_session + ) + + @factory.lazy_attribute + def _era(obj: "Resolver") -> str: + return ERA_NOW + + +class ServerCreatedNHSPatientIdNumFactory( + ServerCreatedPatientIdNumFactory, NHSPatientIdNumFactory +): + class Meta: + exclude = ( + ServerCreatedPatientIdNumFactory._meta.exclude + + NHSPatientIdNumFactory._meta.exclude + ) + + +class ServerCreatedRioPatientIdNumFactory( + ServerCreatedPatientIdNumFactory, RioPatientIdNumFactory +): + class Meta: + exclude = ( + ServerCreatedPatientIdNumFactory._meta.exclude + + RioPatientIdNumFactory._meta.exclude + ) + + +class ServerCreatedStudyPatientIdNumFactory( + ServerCreatedPatientIdNumFactory, StudyPatientIdNumFactory +): + class Meta: + exclude = ( + ServerCreatedPatientIdNumFactory._meta.exclude + + StudyPatientIdNumFactory._meta.exclude + ) + class TaskScheduleFactory(BaseFactory): class Meta: model = TaskSchedule group = factory.SubFactory(GroupFactory) + name = factory.Sequence(lambda n: f"Schedule {n + ID_OFFSET}") class TaskScheduleItemFactory(BaseFactory): @@ -158,37 +359,45 @@ class EmailFactory(BaseFactory): class Meta: model = Email + # Although sent and sent_at_utc are columns, they are not keyword + # arguments to Email's constructor so they are populated after the object + # has been created. For some reason 'sent' needs to be set explicitly + # when creating the factory even though the default should be False. Might + # be a SQLite thing. @factory.post_generation def sent_at_utc( - self, create: bool, sent_at_utc: pendulum.DateTime, **kwargs + obj: Email, create: bool, sent_at_utc: pendulum.DateTime, **kwargs ) -> None: if not create: return - self.sent_at_utc = sent_at_utc + obj.sent_at_utc = sent_at_utc @factory.post_generation - def sent(self, create: bool, sent: bool, **kwargs) -> None: + def sent(obj: Email, create: bool, sent: bool, **kwargs) -> None: if not create: return - self.sent = sent + obj.sent = sent class PatientTaskScheduleEmailFactory(BaseFactory): class Meta: model = PatientTaskScheduleEmail + patient_task_schedule = factory.SubFactory( + PatientTaskScheduleFactory, + ) + email = factory.SubFactory(EmailFactory, sent=True) + class UserGroupMembershipFactory(BaseFactory): class Meta: model = UserGroupMembership - @factory.post_generation - def may_run_reports( - self, create: bool, may_run_reports: bool, **kwargs - ) -> None: - if not create: - return - self.may_run_reports = may_run_reports +class BlobFactory(GenericTabletRecordFactory): + class Meta: + model = Blob + + id = factory.Sequence(lambda n: n + ID_OFFSET) diff --git a/server/camcops_server/cc_modules/cc_testproviders.py b/server/camcops_server/cc_modules/cc_testproviders.py new file mode 100644 index 000000000..32513dc2a --- /dev/null +++ b/server/camcops_server/cc_modules/cc_testproviders.py @@ -0,0 +1,154 @@ +""" +camcops_server/cc_modules/cc_testproviders.py + +=============================================================================== + + Copyright (C) 2012, University of Cambridge, Department of Psychiatry. + Created by Rudolf Cardinal (rnc1001@cam.ac.uk). + + This file is part of CamCOPS. + + CamCOPS is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + CamCOPS is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with CamCOPS. If not, see . + +=============================================================================== + +**Faker test data providers.** + +There may be some interest in a Faker Medical community provider if we felt it +was worth the effort. + +https://github.com/joke2k/faker/issues/1142 + +See also duplicate functionality in CRATE: crate_anon/testing/providers.py + +""" + +import datetime + +from cardinal_pythonlib.nhs import generate_random_nhs_number +from faker import Faker +from faker.providers import BaseProvider +import pendulum +from pendulum import DateTime as Pendulum +from typing import Any, List + + +class NhsNumberProvider(BaseProvider): + def nhs_number(self) -> str: + return generate_random_nhs_number() + + +class ChoiceProvider(BaseProvider): + def random_choice(self, choices: List, **kwargs) -> Any: + """ + Given a list of choices return a random value + """ + choices = self.generator.random.choices(choices, **kwargs) + + return choices[0] + + +# No one is born after this +_max_birth_datetime = Pendulum(year=2000, month=1, day=1, hour=9) + + +class ConsistentDateOfBirthProvider(BaseProvider): + """ + Faker date_of_birth calculates from the current time so gives different + results on different days. + """ + + def consistent_date_of_birth(self) -> datetime.datetime: + return self.generator.date_between_dates( + date_start=pendulum.date(1900, 1, 1), + date_end=_max_birth_datetime, + ) + + +class ForenameProvider(BaseProvider): + """ + Return a forename given the sex of the person + """ + + def forename(self, sex: str) -> str: + if sex == "M": + return self.generator.first_name_male() + + if sex == "F": + return self.generator.first_name_female() + + return self.generator.first_name()[:1] + + +class HeightProvider(BaseProvider): + def height_m(self) -> float: + """ + Return a random patient height in metres + """ + + return float(self.generator.random_int(min=145, max=191) / 100.0) + + +class MassProvider(BaseProvider): + def mass_kg(self) -> float: + """ + Return a random patient mass in kilograms + """ + + return float(self.generator.random_int(min=400, max=1000) / 10.0) + + +class SexProvider(ChoiceProvider): + """ + Return a random sex, with realistic distribution. + """ + + def sex(self) -> str: + return self.random_choice(["M", "F", "X"], weights=[49.8, 49.8, 0.4]) + + +class ValidPhoneNumberProvider(BaseProvider): + """ + Return a random mobile phone number + """ + + # The default Faker phone_number provider for en_GB uses + # https://www.ofcom.org.uk/phones-telecoms-and-internet/information-for-industry/numbering/numbers-for-drama # noqa: E501 + # 07700 900000 to 900999 reserved for TV and Radio drama purposes + # but unfortunately the phonenumbers library considers these invalid. + def valid_phone_number(self) -> str: + number = self.generator.random_int(min=7000000000, max=7999999999) + + return f"+44{number}" + + +class WaistProvider(BaseProvider): + """ + Return a random waist circumference in centimetres + """ + + def waist_cm(self) -> float: + return float(self.generator.random_int(min=40, max=130)) + + +def register_all_providers(fake: Faker) -> None: + fake.add_provider(ChoiceProvider) + fake.add_provider(ConsistentDateOfBirthProvider) + fake.add_provider(ForenameProvider) + fake.add_provider(HeightProvider) + fake.add_provider(MassProvider) + fake.add_provider(NhsNumberProvider) + fake.add_provider(ValidPhoneNumberProvider) + fake.add_provider(WaistProvider) + fake.add_provider(SexProvider) diff --git a/server/camcops_server/cc_modules/cc_unittest.py b/server/camcops_server/cc_modules/cc_unittest.py index 6ed2e9c54..d20b781b9 100644 --- a/server/camcops_server/cc_modules/cc_unittest.py +++ b/server/camcops_server/cc_modules/cc_unittest.py @@ -29,48 +29,44 @@ import base64 import copy +from faker import Faker import logging import os +import random import sqlite3 -from typing import Any, List, Type, TYPE_CHECKING -import unittest +from typing import Any, Dict, List, Type, TYPE_CHECKING +from unittest import mock, TestCase from cardinal_pythonlib.classes import all_subclasses from cardinal_pythonlib.dbfunc import get_fieldnames_from_cursor from cardinal_pythonlib.httpconst import MimeType from cardinal_pythonlib.logs import BraceStyleAdapter -import pendulum import pytest from sqlalchemy.engine.base import Engine from camcops_server.cc_modules.cc_baseconstants import ENVVAR_CONFIG_FILE -from camcops_server.cc_modules.cc_constants import ERA_NOW from camcops_server.cc_modules.cc_device import Device from camcops_server.cc_modules.cc_exportrecipient import ExportRecipient -from camcops_server.cc_modules.cc_group import Group -from camcops_server.cc_modules.cc_idnumdef import IdNumDefinition -from camcops_server.cc_modules.cc_ipuse import IpUse from camcops_server.cc_modules.cc_request import ( CamcopsRequest, get_unittest_request, ) from camcops_server.cc_modules.cc_sqlalchemy import sql_from_sqlite_database +from camcops_server.cc_modules.cc_task import Task, TaskHasPatientMixin from camcops_server.cc_modules.cc_user import User -from camcops_server.cc_modules.cc_membership import UserGroupMembership from camcops_server.cc_modules.cc_testfactories import ( BaseFactory, - DeviceFactory, GroupFactory, + NHSPatientIdNumFactory, + PatientFactory, + RioPatientIdNumFactory, UserFactory, + UserGroupMembershipFactory, ) -from camcops_server.cc_modules.cc_version import CAMCOPS_SERVER_VERSION +from camcops_server.tasks.tests import factories as task_factories if TYPE_CHECKING: from sqlalchemy.orm import Session - from camcops_server.cc_modules.cc_db import GenericTabletRecordMixin - from camcops_server.cc_modules.cc_patient import Patient - from camcops_server.cc_modules.cc_patientidnum import PatientIdNum - from camcops_server.cc_modules.cc_task import Task log = BraceStyleAdapter(logging.getLogger(__name__)) @@ -91,7 +87,15 @@ # ============================================================================= -class ExtendedTestCase(unittest.TestCase): +class ExtendedTestCase(TestCase): + + def setUp(self) -> None: + super().setUp() + + # Arbitrary seed + Faker.seed(1234) + random.seed(1234) + """ A subclass of :class:`unittest.TestCase` that provides some additional functionality. @@ -135,6 +139,8 @@ class DemoRequestTestCase(ExtendedTestCase): db_filename: str def setUp(self) -> None: + super().setUp() + for factory in all_subclasses(BaseFactory): factory._meta.sqlalchemy_session = self.dbsession @@ -148,7 +154,7 @@ def setUp(self) -> None: # afterwards. self.old_config = copy.copy(self.req.config) - self.req.matched_route = unittest.mock.Mock() + self.req.matched_route = mock.Mock() self.recipdef = ExportRecipient() def tearDown(self) -> None: @@ -215,246 +221,29 @@ def dump_table( class BasicDatabaseTestCase(DemoRequestTestCase): """ - Test case that sets up some useful database records for testing: - ID numbers, user, group, devices etc and has helper methods for - creating patients and tasks + Test case that sets up some minimal database records for testing. """ def setUp(self) -> None: super().setUp() - self.set_era("2010-07-07T13:40+0100") - - # Set up groups, users, etc. - # ... ID number definitions - idnum_type_nhs = 1 - idnum_type_rio = 2 - idnum_type_study = 3 - self.nhs_iddef = IdNumDefinition( - which_idnum=idnum_type_nhs, - description="NHS number", - short_description="NHS#", - hl7_assigning_authority="NHS", - hl7_id_type="NHSN", - ) - self.dbsession.add(self.nhs_iddef) - self.rio_iddef = IdNumDefinition( - which_idnum=idnum_type_rio, - description="RiO number", - short_description="RiO", - hl7_assigning_authority="CPFT", - hl7_id_type="CPRiO", - ) - self.dbsession.add(self.rio_iddef) - self.study_iddef = IdNumDefinition( - which_idnum=idnum_type_study, - description="Study number", - short_description="Study", - ) - self.dbsession.add(self.study_iddef) - # ... group - self.group = Group() - self.group.name = "testgroup" - self.group.description = "Test group" - self.group.upload_policy = "sex AND anyidnum" - self.group.finalize_policy = "sex AND idnum1" - self.group.ip_use = IpUse() - self.dbsession.add(self.group) - self.dbsession.flush() # sets PK fields - GroupFactory.reset_sequence(self.group.id + 1) - - # ... users - - self.user = User.get_system_user(self.dbsession) - self.user.upload_group_id = self.group.id - self.req._debugging_user = self.user # improve our debugging user - - # ... devices - self.server_device = Device.get_server_device(self.dbsession) - DeviceFactory.reset_sequence(self.server_device.id + 1) - self.other_device = DeviceFactory( - name="other_device", - friendly_name="Test device that may upload", - registered_by_user=self.user, - when_registered_utc=self.era_time_utc, - camcops_version=CAMCOPS_SERVER_VERSION, - ) - # ... export recipient definition (the minimum) - self.recipdef.primary_idnum = idnum_type_nhs - - self.dbsession.flush() # sets PK fields - UserFactory.reset_sequence(self.user.id + 1) - - self.create_tasks() - - def set_era(self, iso_datetime: str) -> None: - from cardinal_pythonlib.datetimefunc import ( - convert_datetime_to_utc, - format_datetime, - ) - from camcops_server.cc_modules.cc_constants import DateFormat - - self.era_time = pendulum.parse(iso_datetime) - self.era_time_utc = convert_datetime_to_utc(self.era_time) - self.era = format_datetime(self.era_time, DateFormat.ISO8601) - - def create_patient_with_two_idnums(self) -> "Patient": - from camcops_server.cc_modules.cc_patient import Patient - from camcops_server.cc_modules.cc_patientidnum import PatientIdNum - - # Populate database with two of everything - patient = Patient() - patient.id = 1 - self.apply_standard_db_fields(patient) - patient.forename = "Forename1" - patient.surname = "Surname1" - patient.dob = pendulum.parse("1950-01-01") - self.dbsession.add(patient) - patient_idnum1 = PatientIdNum() - patient_idnum1.id = 1 - self.apply_standard_db_fields(patient_idnum1) - patient_idnum1.patient_id = patient.id - patient_idnum1.which_idnum = self.nhs_iddef.which_idnum - patient_idnum1.idnum_value = 333 - self.dbsession.add(patient_idnum1) - patient_idnum2 = PatientIdNum() - patient_idnum2.id = 2 - self.apply_standard_db_fields(patient_idnum2) - patient_idnum2.patient_id = patient.id - patient_idnum2.which_idnum = self.rio_iddef.which_idnum - patient_idnum2.idnum_value = 444 - self.dbsession.add(patient_idnum2) - self.dbsession.commit() - - return patient + self.group = GroupFactory() + self.groupadmin = UserFactory() - def create_patient_with_one_idnum(self) -> "Patient": - from camcops_server.cc_modules.cc_patient import Patient + self.superuser = UserFactory(superuser=True) - patient = Patient() - patient.id = 2 - self.apply_standard_db_fields(patient) - patient.forename = "Forename2" - patient.surname = "Surname2" - patient.dob = pendulum.parse("1975-12-12") - self.dbsession.add(patient) - - self.create_patient_idnum( - id=3, - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=555, + UserGroupMembershipFactory( + group_id=self.group.id, user_id=self.groupadmin.id, groupadmin=True ) - return patient - - def create_patient_idnum( - self, as_server_patient: bool = False, **kwargs: Any - ) -> "PatientIdNum": - from camcops_server.cc_modules.cc_patientidnum import PatientIdNum - - patient_idnum = PatientIdNum() - self.apply_standard_db_fields(patient_idnum, era_now=as_server_patient) - - for key, value in kwargs.items(): - setattr(patient_idnum, key, value) - - if "id" not in kwargs: - patient_idnum.save_with_next_available_id( - self.req, patient_idnum._device_id - ) - else: - self.dbsession.add(patient_idnum) - - self.dbsession.commit() - - return patient_idnum - - def create_patient( - self, as_server_patient: bool = False, **kwargs: Any - ) -> "Patient": - from camcops_server.cc_modules.cc_patient import Patient + self.system_user = User.get_system_user(self.dbsession) + self.system_user.upload_group_id = self.group.id - patient = Patient() - self.apply_standard_db_fields(patient, era_now=as_server_patient) - - for key, value in kwargs.items(): - setattr(patient, key, value) - - if "id" not in kwargs: - patient.save_with_next_available_id(self.req, patient._device_id) - else: - self.dbsession.add(patient) + self.req._debugging_user = self.superuser # improve our debugging user + self.server_device = Device.get_server_device(self.dbsession) self.dbsession.commit() - return patient - - def create_tasks(self) -> None: - # Override in subclass - pass - - def apply_standard_task_fields(self, task: "Task") -> None: - """ - Writes some default values to an SQLAlchemy ORM object representing - a task. - """ - self.apply_standard_db_fields(task) - task.when_created = self.era_time - - def apply_standard_db_fields( - self, obj: "GenericTabletRecordMixin", era_now: bool = False - ) -> None: - """ - Writes some default values to an SQLAlchemy ORM object representing a - record uploaded from a client (tablet) device. - - Though we use the server device ID. - """ - obj._device_id = self.server_device.id - obj._era = ERA_NOW if era_now else self.era - obj._group_id = self.group.id - obj._current = True - obj._adding_user_id = self.user.id - obj._when_added_batch_utc = self.era_time_utc - - def create_user(self, **kwargs) -> User: - user = User() - user.hashedpw = "" - - for key, value in kwargs.items(): - setattr(user, key, value) - - self.dbsession.add(user) - - return user - - def create_group(self, name: str, **kwargs) -> Group: - group = Group() - group.name = name - - for key, value in kwargs.items(): - setattr(group, key, value) - - self.dbsession.add(group) - - return group - - def create_membership( - self, user: User, group: Group, **kwargs - ) -> UserGroupMembership: - ugm = UserGroupMembership(user_id=user.id, group_id=group.id) - - for key, value in kwargs.items(): - setattr(ugm, key, value) - - self.dbsession.add(ugm) - - return ugm - - def tearDown(self) -> None: - pass - class DemoDatabaseTestCase(BasicDatabaseTestCase): """ @@ -462,43 +251,39 @@ class DemoDatabaseTestCase(BasicDatabaseTestCase): each type """ - def create_tasks(self) -> None: - from camcops_server.cc_modules.cc_blob import Blob - from camcops_server.tasks.photo import Photo - from camcops_server.cc_modules.cc_task import Task + def setUp(self) -> None: + super().setUp() - patient_with_two_idnums = self.create_patient_with_two_idnums() - patient_with_one_idnum = self.create_patient_with_one_idnum() + self.demo_database_group = GroupFactory() - for cls in Task.all_subclasses_by_tablename(): - t1 = cls() - t1.id = 1 - self.apply_standard_task_fields(t1) - if t1.has_patient: - t1.patient_id = patient_with_two_idnums.id - - if isinstance(t1, Photo): - b = Blob() - b.id = 1 - self.apply_standard_db_fields(b) - b.tablename = t1.tablename - b.tablepk = t1.id - b.fieldname = "photo_blobid" - b.filename = "some_picture.png" - b.mimetype = MimeType.PNG - b.image_rotation_deg_cw = 0 - b.theblob = DEMO_PNG_BYTES - self.dbsession.add(b) - - t1.photo_blobid = b.id - - self.dbsession.add(t1) - - t2 = cls() - t2.id = 2 - self.apply_standard_task_fields(t2) - if t2.has_patient: - t2.patient_id = patient_with_one_idnum.id - self.dbsession.add(t2) + patient_with_two_idnums = PatientFactory( + _group=self.demo_database_group + ) + NHSPatientIdNumFactory(patient=patient_with_two_idnums) + RioPatientIdNumFactory(patient=patient_with_two_idnums) - self.dbsession.commit() + patient_with_one_idnum = PatientFactory( + _group=self.demo_database_group + ) + NHSPatientIdNumFactory(patient=patient_with_one_idnum) + + for cls in Task.all_subclasses_by_tablename(): + factory_class = getattr(task_factories, f"{cls.__name__}Factory") + + t1_kwargs: Dict[str, Any] = dict(_group=self.demo_database_group) + t2_kwargs = t1_kwargs + if issubclass(cls, TaskHasPatientMixin): + t1_kwargs.update(patient=patient_with_two_idnums) + t2_kwargs.update(patient=patient_with_one_idnum) + + if cls.__name__ == "Photo": + t1_kwargs.update( + create_blob__fieldname="photo_blobid", + create_blob__filename="some_picture.png", + create_blob__mimetype=MimeType.PNG, + create_blob__image_rotation_deg_cw=0, + create_blob__theblob=DEMO_PNG_BYTES, + ) + + factory_class(**t1_kwargs) + factory_class(**t2_kwargs) diff --git a/server/camcops_server/cc_modules/tests/cc_fhir_tests.py b/server/camcops_server/cc_modules/tests/cc_fhir_tests.py index 69475a630..6b5af11fa 100644 --- a/server/camcops_server/cc_modules/tests/cc_fhir_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_fhir_tests.py @@ -32,11 +32,10 @@ from unittest import mock from cardinal_pythonlib.httpconst import HttpMethod -from cardinal_pythonlib.nhs import generate_random_nhs_number import pendulum from requests.exceptions import HTTPError -from camcops_server.cc_modules.cc_constants import FHIRConst as Fc, JSON_INDENT +from camcops_server.cc_modules.cc_constants import FHIRConst as Fc from camcops_server.cc_modules.cc_exportmodels import ( ExportedTask, ExportedTaskFhir, @@ -53,36 +52,31 @@ FhirTaskExporter, ) from camcops_server.cc_modules.cc_pyramid import Routes -from camcops_server.cc_modules.cc_unittest import DemoDatabaseTestCase +from camcops_server.cc_modules.cc_testfactories import ( + NHSPatientIdNumFactory, + PatientFactory, + RioPatientIdNumFactory, +) + +from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase from camcops_server.cc_modules.cc_version_string import ( CAMCOPS_SERVER_VERSION_STRING, ) -from camcops_server.tasks.apeqpt import Apeqpt -from camcops_server.tasks.bmi import Bmi -from camcops_server.tasks.gad7 import Gad7 -from camcops_server.tasks.diagnosis import ( - DiagnosisIcd10, - DiagnosisIcd10Item, - DiagnosisIcd9CM, - DiagnosisIcd9CMItem, +from camcops_server.tasks.tests.factories import ( + ApeqptFactory, + BmiFactory, + DiagnosisIcd10Factory, + DiagnosisIcd10ItemFactory, + DiagnosisIcd9CMFactory, + DiagnosisIcd9CMItemFactory, + Gad7Factory, + Phq9Factory, ) -from camcops_server.tasks.phq9 import Phq9 log = logging.getLogger() -# ============================================================================= -# Constants -# ============================================================================= - -TEST_NHS_NUMBER = generate_random_nhs_number() -TEST_RIO_NUMBER = 12345 -TEST_FORENAME = "Gwendolyn" -TEST_SURNAME = "Ryann" -TEST_SEX = "F" - - # ============================================================================= # Helper classes # ============================================================================= @@ -100,13 +94,12 @@ def __init__(self, response_json: Dict): ) -class FhirExportTestCase(DemoDatabaseTestCase): +class FhirExportTestCase(DemoRequestTestCase): def setUp(self) -> None: super().setUp() recipientinfo = ExportRecipientInfo() self.recipient = ExportRecipient(recipientinfo) - self.recipient.primary_idnum = self.rio_iddef.which_idnum self.recipient.fhir_api_url = "https://www.example.com/fhir" # auto increment doesn't work for BigInteger with SQLite @@ -116,20 +109,16 @@ def setUp(self) -> None: self.camcops_root_url = self.req.route_url(Routes.HOME).rstrip("/") # ... no trailing slash - def create_fhir_patient(self) -> None: - self.patient = self.create_patient( - forename=TEST_FORENAME, surname=TEST_SURNAME, sex=TEST_SEX - ) - self.patient_nhs = self.create_patient_idnum( - patient_id=self.patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER, - ) - self.patient_rio = self.create_patient_idnum( - patient_id=self.patient.id, - which_idnum=self.rio_iddef.which_idnum, - idnum_value=TEST_RIO_NUMBER, - ) + +class FhirExportPatientTestCase(FhirExportTestCase): + def setUp(self) -> None: + super().setUp() + + self.patient = PatientFactory() + self.patient_nhs_idnum = NHSPatientIdNumFactory(patient=self.patient) + self.patient_rio_idnum = RioPatientIdNumFactory(patient=self.patient) + + self.recipient.primary_idnum = self.patient_rio_idnum.which_idnum # ============================================================================= @@ -137,27 +126,23 @@ def create_fhir_patient(self) -> None: # ============================================================================= -class FhirTaskExporterPhq9Tests(FhirExportTestCase): - def create_tasks(self) -> None: - self.create_fhir_patient() - - self.task = Phq9() - self.apply_standard_task_fields(self.task) - self.task.q1 = 0 - self.task.q2 = 1 - self.task.q3 = 2 - self.task.q4 = 3 - self.task.q5 = 0 - self.task.q6 = 1 - self.task.q7 = 2 - self.task.q8 = 3 - self.task.q9 = 0 - self.task.q10 = 3 - self.task.patient_id = self.patient.id - self.task.save_with_next_available_id( - self.req, self.patient._device_id +class FhirTaskExporterPhq9Tests(FhirExportPatientTestCase): + def setUp(self) -> None: + super().setUp() + + self.task = Phq9Factory( + patient=self.patient, + q1=0, + q2=1, + q3=2, + q4=3, + q5=0, + q6=1, + q7=2, + q8=3, + q9=0, + q10=3, ) - self.dbsession.commit() def test_patient_exported(self) -> None: exported_task = ExportedTask(task=self.task, recipient=self.recipient) @@ -185,7 +170,7 @@ def test_patient_exported(self) -> None: self.assertEqual(patient[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_PATIENT) identifier = patient[Fc.IDENTIFIER] - idnum_value = self.patient_rio.idnum_value + idnum_value = self.patient_rio_idnum.idnum_value patient_id = self.patient.get_fhir_identifier(self.req, self.recipient) @@ -346,7 +331,7 @@ def test_questionnaire_response_exported(self) -> None: subject = response[Fc.SUBJECT] identifier = subject[Fc.IDENTIFIER] self.assertEqual(subject[Fc.TYPE], Fc.RESOURCE_TYPE_PATIENT) - idnum_value = self.patient_rio.idnum_value + idnum_value = self.patient_rio_idnum.idnum_value patient_id = self.patient.get_fhir_identifier(self.req, self.recipient) if isinstance(identifier, list): @@ -534,18 +519,17 @@ def test_raises_for_missing_api_url(self) -> None: class FhirTaskExporterAnonymousTests(FhirExportTestCase): - def create_tasks(self) -> None: - self.task = Apeqpt() - self.apply_standard_task_fields(self.task) - self.task.q_datetime = pendulum.now() - self.task.q1_choice = 0 - self.task.q2_choice = 1 - self.task.q3_choice = 2 - self.task.q1_satisfaction = 3 - self.task.q2_satisfaction = "Service experience" - - self.task.save_with_next_available_id(self.req, self.server_device.id) - self.dbsession.commit() + def setUp(self) -> None: + super().setUp() + + self.task = ApeqptFactory( + q_datetime=pendulum.now(), + q1_choice=0, + q2_choice=1, + q3_choice=2, + q1_satisfaction=3, + q2_satisfaction="Service experience", + ) def test_questionnaire_exported(self) -> None: exported_task = ExportedTask(task=self.task, recipient=self.recipient) @@ -764,45 +748,73 @@ def test_questionnaire_response_exported(self) -> None: # ============================================================================= -class FhirTaskExporterBMITests(FhirExportTestCase): - def create_tasks(self) -> None: - self.create_fhir_patient() +class FhirTaskExporterBMITests(FhirExportPatientTestCase): + def setUp(self) -> None: + super().setUp() - self.task = Bmi() - self.apply_standard_task_fields(self.task) - self.task.mass_kg = 70 - self.task.height_m = 1.8 - self.task.waist_cm = 82 - self.task.patient_id = self.patient.id - self.task.save_with_next_available_id( - self.req, self.patient._device_id - ) - self.dbsession.commit() + self.task = BmiFactory(patient=self.patient) def test_observations(self) -> None: bundle = self.task.get_fhir_bundle( self.req, self.recipient, skip_docs_if_other_content=True ) - bundle_str = json.dumps(bundle.as_json(), indent=JSON_INDENT) - log.debug(f"Bundle:\n{bundle_str}") - # The test is that it doesn't crash. + bundle_json = bundle.as_json() + + height_entry = bundle_json[Fc.ENTRY][3] + mass_entry = bundle_json[Fc.ENTRY][4] + bmi_entry = bundle_json[Fc.ENTRY][5] + waist_entry = bundle_json[Fc.ENTRY][6] -class FhirTaskExporterDiagnosisIcd10Tests(FhirExportTestCase): - def create_tasks(self) -> None: - self.create_fhir_patient() + height_resource = height_entry[Fc.RESOURCE] + mass_resource = mass_entry[Fc.RESOURCE] + bmi_resource = bmi_entry[Fc.RESOURCE] + waist_resource = waist_entry[Fc.RESOURCE] - self.task = DiagnosisIcd10() - self.apply_standard_task_fields(self.task) - self.task.patient_id = self.patient.id - self.task.save_with_next_available_id( - self.req, self.patient._device_id + self.assertEqual( + height_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_OBSERVATION + ) + self.assertEqual( + height_resource[Fc.VALUE_QUANTITY][Fc.VALUE], self.task.height_m + ) + + self.assertEqual( + mass_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_OBSERVATION + ) + self.assertAlmostEqual( + mass_resource[Fc.VALUE_QUANTITY][Fc.VALUE], + self.task.mass_kg, + places=2, + ) + + self.assertEqual( + bmi_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_OBSERVATION + ) + self.assertAlmostEqual( + bmi_resource[Fc.VALUE_QUANTITY][Fc.VALUE], + self.task.bmi(), + places=2, + ) + + self.assertEqual( + waist_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_OBSERVATION + ) + self.assertAlmostEqual( + waist_resource[Fc.VALUE_QUANTITY][Fc.VALUE], + self.task.waist_cm, + places=2, ) - self.dbsession.commit() + + +class FhirTaskExporterDiagnosisIcd10Tests(FhirExportPatientTestCase): + def setUp(self) -> None: + super().setUp() + + self.task = DiagnosisIcd10Factory(patient=self.patient) # noinspection PyArgumentList - item1 = DiagnosisIcd10Item( - diagnosis_icd10_id=self.task.id, + self.item1 = DiagnosisIcd10ItemFactory( + diagnosis_icd10=self.task, seqnum=1, code="F33.30", description="Recurrent depressive disorder, current episode " @@ -810,89 +822,143 @@ def create_tasks(self) -> None: "with mood-congruent psychotic symptoms", comment="Cotard's syndrome", ) - self.apply_standard_db_fields(item1) - item1.save_with_next_available_id(self.req, self.task._device_id) # noinspection PyArgumentList - item2 = DiagnosisIcd10Item( - diagnosis_icd10_id=self.task.id, + self.item2 = DiagnosisIcd10ItemFactory( + diagnosis_icd10=self.task, seqnum=2, code="F43.1", description="Post-traumatic stress disorder", ) - self.apply_standard_db_fields(item2) - item2.save_with_next_available_id(self.req, self.task._device_id) def test_observations(self) -> None: bundle = self.task.get_fhir_bundle( self.req, self.recipient, skip_docs_if_other_content=True ) - bundle_str = json.dumps(bundle.as_json(), indent=JSON_INDENT) - log.debug(f"Bundle:\n{bundle_str}") - # The test is that it doesn't crash. + bundle_json = bundle.as_json() -class FhirTaskExporterDiagnosisIcd9CMTests(FhirExportTestCase): - def create_tasks(self) -> None: - self.create_fhir_patient() + cotard_resource = bundle_json[Fc.ENTRY][4][Fc.RESOURCE] + ptsd_resource = bundle_json[Fc.ENTRY][5][Fc.RESOURCE] - self.task = DiagnosisIcd9CM() - self.apply_standard_task_fields(self.task) - self.task.patient_id = self.patient.id - self.task.save_with_next_available_id( - self.req, self.patient._device_id + self.assertEqual( + cotard_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_CONDITION ) - self.dbsession.commit() + self.assertEqual( + cotard_resource[Fc.CODE][Fc.CODING][0][Fc.CODE], "F33.30" + ) + self.assertIn( + "Cotard's syndrome", + cotard_resource[Fc.CODE][Fc.CODING][0][Fc.DISPLAY], + ) + self.assertIn( + "Recurrent depressive", + cotard_resource[Fc.CODE][Fc.CODING][0][Fc.DISPLAY], + ) + + self.assertEqual( + ptsd_resource[Fc.CODE][Fc.CODING][0][Fc.CODE], "F43.1" + ) + + +class FhirTaskExporterDiagnosisIcd9CMTests(FhirExportPatientTestCase): + def setUp(self) -> None: + super().setUp() + + self.task = DiagnosisIcd9CMFactory(patient=self.patient) # noinspection PyArgumentList - item1 = DiagnosisIcd9CMItem( - diagnosis_icd9cm_id=self.task.id, + self.item1 = DiagnosisIcd9CMItemFactory( + diagnosis_icd9cm=self.task, seqnum=1, code="290.4", description="Vascular dementia", comment="or perhaps mixed dementia", ) - self.apply_standard_db_fields(item1) - item1.save_with_next_available_id(self.req, self.task._device_id) # noinspection PyArgumentList - item2 = DiagnosisIcd9CMItem( - diagnosis_icd9cm_id=self.task.id, + self.item2 = DiagnosisIcd9CMItemFactory( + diagnosis_icd9cm=self.task, seqnum=2, code="303.0", description="Acute alcoholic intoxication", ) - self.apply_standard_db_fields(item2) - item2.save_with_next_available_id(self.req, self.task._device_id) def test_observations(self) -> None: bundle = self.task.get_fhir_bundle( self.req, self.recipient, skip_docs_if_other_content=True ) - bundle_str = json.dumps(bundle.as_json(), indent=JSON_INDENT) - log.debug(f"Bundle:\n{bundle_str}") - # The test is that it doesn't crash. + bundle_json = bundle.as_json() + dementia_resource = bundle_json[Fc.ENTRY][4][Fc.RESOURCE] + intoxication_resource = bundle_json[Fc.ENTRY][5][Fc.RESOURCE] + + self.assertEqual( + dementia_resource[Fc.RESOURCE_TYPE], Fc.RESOURCE_TYPE_CONDITION + ) + self.assertEqual( + dementia_resource[Fc.CODE][Fc.CODING][0][Fc.CODE], "290.4" + ) + self.assertIn( + "Vascular dementia", + dementia_resource[Fc.CODE][Fc.CODING][0][Fc.DISPLAY], + ) + self.assertIn( + "or perhaps mixed dementia", + dementia_resource[Fc.CODE][Fc.CODING][0][Fc.DISPLAY], + ) + + self.assertEqual( + intoxication_resource[Fc.CODE][Fc.CODING][0][Fc.CODE], "303.0" + ) -class FhirTaskExporterGad7Tests(FhirExportTestCase): +class FhirTaskExporterGad7Tests(FhirExportPatientTestCase): """ The GAD7 is a standard questionnaire that we don't provide any special - FHIR support for; we rely on autodiscovery. + FHIR support for; we rely on autodiscovery. This is essentially a high + level test for _fhir_autodiscover() in cc_task.py, albeit not a + particularly thorough one. """ - def create_tasks(self) -> None: - self.create_fhir_patient() + def setUp(self) -> None: + super().setUp() - self.task = Gad7() - self.apply_standard_task_fields(self.task) - self.task.patient_id = self.patient.id - self.task.save_with_next_available_id( - self.req, self.patient._device_id + self.task = Gad7Factory( + patient=self.patient, + q1=0, + q2=1, + q3=2, + q4=3, + q5=0, + q6=1, + q7=2, ) - self.dbsession.commit() def test_observations(self) -> None: bundle = self.task.get_fhir_bundle( self.req, self.recipient, skip_docs_if_other_content=True ) - bundle_str = json.dumps(bundle.as_json(), indent=JSON_INDENT) - log.critical(f"Bundle:\n{bundle_str}") - # The test is that it doesn't crash. + bundle_json = bundle.as_json() + questions = bundle_json[Fc.ENTRY][1][Fc.RESOURCE][Fc.ITEM] + answers = bundle_json[Fc.ENTRY][2][Fc.RESOURCE][Fc.ITEM] + + # Question text + self.assertIn( + "1. Feeling nervous, anxious or on edge", questions[0][Fc.TEXT] + ) + # Comment string + self.assertIn( + "Q1, nervous/anxious/on edge (0 not at all - 3 nearly every day)", + questions[0][Fc.TEXT], + ) + + self.assertIn( + "1. Feeling nervous, anxious or on edge", answers[0][Fc.TEXT] + ) + self.assertIn( + "Q1, nervous/anxious/on edge (0 not at all - 3 nearly every day)", + answers[0][Fc.TEXT], + ) + + self.assertEqual(answers[0][Fc.ANSWER][0][Fc.VALUE_INTEGER], 0) + self.assertEqual(answers[1][Fc.ANSWER][0][Fc.VALUE_INTEGER], 1) + self.assertEqual(answers[2][Fc.ANSWER][0][Fc.VALUE_INTEGER], 2) + self.assertEqual(answers[3][Fc.ANSWER][0][Fc.VALUE_INTEGER], 3) diff --git a/server/camcops_server/cc_modules/tests/cc_forms_tests.py b/server/camcops_server/cc_modules/tests/cc_forms_tests.py index 293c281b7..42ae682b0 100644 --- a/server/camcops_server/cc_modules/tests/cc_forms_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_forms_tests.py @@ -54,33 +54,22 @@ ) from camcops_server.cc_modules.cc_ipuse import IpContexts from camcops_server.cc_modules.cc_pyramid import ViewParam -from camcops_server.cc_modules.cc_taskschedule import TaskSchedule +from camcops_server.cc_modules.cc_testfactories import ( + Fake, + GroupFactory, + TaskScheduleFactory, + UserFactory, + UserGroupMembershipFactory, +) from camcops_server.cc_modules.cc_unittest import ( BasicDatabaseTestCase, - DemoDatabaseTestCase, DemoRequestTestCase, ) -TEST_PHONE_NUMBER = "+{ctry}{tel}".format( - ctry=phonenumbers.PhoneMetadata.metadata_for_region("GB").country_code, - tel=phonenumbers.PhoneMetadata.metadata_for_region( - "GB" - ).personal_number.example_number, -) # see webview_tests.py - log = logging.getLogger(__name__) -# ============================================================================= -# Unit tests -# ============================================================================= - - class SchemaTestCase(DemoRequestTestCase): - """ - Unit tests. - """ - def serialize_deserialize( self, schema: Schema, appstruct: Dict[str, Any] ) -> None: @@ -113,7 +102,7 @@ def test_serialize_deserialize(self) -> None: self.serialize_deserialize(schema, appstruct) -class TaskScheduleSchemaTests(DemoDatabaseTestCase): +class TaskScheduleSchemaTests(BasicDatabaseTestCase): def test_invalid_for_bad_template_placeholder(self) -> None: schema = TaskScheduleSchema().bind(request=self.req) cstruct = { @@ -252,23 +241,19 @@ def test_invalid_for_negative_due_from(self) -> None: self.assertIn("must be zero or more days", cm.exception.messages()[0]) -class TaskScheduleItemSchemaIpTests(BasicDatabaseTestCase): - def setUp(self) -> None: - super().setUp() - - self.schedule = TaskSchedule() - self.schedule.group_id = self.group.id - self.dbsession.add(self.schedule) - self.dbsession.commit() - +class TaskScheduleItemSchemaIpTests(DemoRequestTestCase): def test_invalid_for_commercial_mismatch(self) -> None: - self.group.ip_use.commercial = True - self.dbsession.add(self.group) - self.dbsession.commit() + group = GroupFactory( + ip_use__clinical=False, + ip_use__commercial=True, + ip_use__educational=False, + ip_use__research=False, + ) + schedule = TaskScheduleFactory(group=group) schema = TaskScheduleItemSchema().bind(request=self.req) appstruct = { - ViewParam.SCHEDULE_ID: self.schedule.id, + ViewParam.SCHEDULE_ID: schedule.id, ViewParam.TABLE_NAME: "mfi20", ViewParam.CLINICIAN_CONFIRMATION: False, ViewParam.DUE_FROM: Duration(days=0), @@ -282,13 +267,17 @@ def test_invalid_for_commercial_mismatch(self) -> None: self.assertIn("prohibits commercial", cm.exception.messages()[0]) def test_invalid_for_clinical_mismatch(self) -> None: - self.group.ip_use.clinical = True - self.dbsession.add(self.group) - self.dbsession.commit() + group = GroupFactory( + ip_use__clinical=True, + ip_use__commercial=False, + ip_use__educational=False, + ip_use__research=False, + ) + schedule = TaskScheduleFactory(group=group) schema = TaskScheduleItemSchema().bind(request=self.req) appstruct = { - ViewParam.SCHEDULE_ID: self.schedule.id, + ViewParam.SCHEDULE_ID: schedule.id, ViewParam.TABLE_NAME: "mfi20", ViewParam.CLINICIAN_CONFIRMATION: False, ViewParam.DUE_FROM: Duration(days=0), @@ -302,13 +291,17 @@ def test_invalid_for_clinical_mismatch(self) -> None: self.assertIn("prohibits clinical", cm.exception.messages()[0]) def test_invalid_for_educational_mismatch(self) -> None: - self.group.ip_use.educational = True - self.dbsession.add(self.group) - self.dbsession.commit() + group = GroupFactory( + ip_use__clinical=False, + ip_use__commercial=False, + ip_use__educational=True, + ip_use__research=False, + ) + schedule = TaskScheduleFactory(group=group) schema = TaskScheduleItemSchema().bind(request=self.req) appstruct = { - ViewParam.SCHEDULE_ID: self.schedule.id, + ViewParam.SCHEDULE_ID: schedule.id, ViewParam.TABLE_NAME: "mfi20", ViewParam.CLINICIAN_CONFIRMATION: True, ViewParam.DUE_FROM: Duration(days=0), @@ -328,13 +321,17 @@ def test_invalid_for_educational_mismatch(self) -> None: self.assertIn("prohibits educational", cm.exception.messages()[0]) def test_invalid_for_research_mismatch(self) -> None: - self.group.ip_use.research = True - self.dbsession.add(self.group) - self.dbsession.commit() + group = GroupFactory( + ip_use__clinical=False, + ip_use__commercial=False, + ip_use__educational=False, + ip_use__research=True, + ) + schedule = TaskScheduleFactory(group=group) schema = TaskScheduleItemSchema().bind(request=self.req) appstruct = { - ViewParam.SCHEDULE_ID: self.schedule.id, + ViewParam.SCHEDULE_ID: schedule.id, ViewParam.TABLE_NAME: "moca", ViewParam.CLINICIAN_CONFIRMATION: True, ViewParam.DUE_FROM: Duration(days=0), @@ -348,13 +345,12 @@ def test_invalid_for_research_mismatch(self) -> None: self.assertIn("prohibits research", cm.exception.messages()[0]) def test_invalid_for_missing_ip_use(self) -> None: - self.group.ip_use = None - self.dbsession.add(self.group) - self.dbsession.commit() + group = GroupFactory(ip_use=None) + schedule = TaskScheduleFactory(group=group) schema = TaskScheduleItemSchema().bind(request=self.req) appstruct = { - ViewParam.SCHEDULE_ID: self.schedule.id, + ViewParam.SCHEDULE_ID: schedule.id, ViewParam.TABLE_NAME: "moca", ViewParam.CLINICIAN_CONFIRMATION: True, ViewParam.DUE_FROM: Duration(days=0), @@ -366,7 +362,7 @@ def test_invalid_for_missing_ip_use(self) -> None: schema.deserialize(cstruct) self.assertIn( - f"The group '{self.group.name}' has no intellectual property " + f"The group '{group.name}' has no intellectual property " f"settings", cm.exception.messages()[0], ) @@ -713,25 +709,17 @@ def test_deserialize_not_a_json_object_fails_validation(self) -> None: self.assertEqual(cm.exception.value, "[{}]") -class TaskScheduleSelectorTests(BasicDatabaseTestCase): +class TaskScheduleSelectorTests(DemoRequestTestCase): def test_displays_only_users_schedules(self) -> None: - user = self.create_user(username="regular_user") - my_group = self.create_group("mygroup") - not_my_group = self.create_group("notmygroup") - self.dbsession.flush() - - self.create_membership(user, my_group, may_manage_patients=True) - - my_schedule = TaskSchedule() - my_schedule.group_id = my_group.id - my_schedule.name = "My group's schedule" - self.dbsession.add(my_schedule) + user = UserFactory() + my_group = GroupFactory() + not_my_group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=my_group.id, may_manage_patients=True + ) - not_my_schedule = TaskSchedule() - not_my_schedule.group_id = not_my_group.id - not_my_schedule.name = "Not my group's schedule" - self.dbsession.add(not_my_schedule) - self.dbsession.commit() + my_schedule = TaskScheduleFactory(group=my_group) + not_my_schedule = TaskScheduleFactory(group=not_my_group) self.req._debugging_user = user @@ -994,9 +982,8 @@ def test_raises_for_invalid_parsable_number(self) -> None: self.assertIn("Invalid phone number", cm.exception.messages()[0]) def test_returns_valid_phone_number(self) -> None: - phone_number = self.phone_type.deserialize( - self.node, TEST_PHONE_NUMBER - ) + phone_number_str = Fake.en_gb.valid_phone_number() + phone_number = self.phone_type.deserialize(self.node, phone_number_str) self.assertIsInstance(phone_number, phonenumbers.PhoneNumber) @@ -1004,7 +991,7 @@ def test_returns_valid_phone_number(self) -> None: phonenumbers.format_number( phone_number, phonenumbers.PhoneNumberFormat.E164 ), - TEST_PHONE_NUMBER, + phone_number_str, ) @@ -1013,11 +1000,12 @@ def test_returns_null_for_appstruct_none(self) -> None: self.assertIs(self.phone_type.serialize(self.node, None), null) def test_returns_number_formatted_e164(self) -> None: - phone_number = phonenumbers.parse(TEST_PHONE_NUMBER) + phone_number_str = Fake.en_gb.valid_phone_number() + phone_number = phonenumbers.parse(phone_number_str) self.assertEqual( self.phone_type.serialize(self.node, phone_number), - TEST_PHONE_NUMBER, + phone_number_str, ) diff --git a/server/camcops_server/cc_modules/tests/cc_patient_tests.py b/server/camcops_server/cc_modules/tests/cc_patient_tests.py index eebbb491b..e543171b0 100644 --- a/server/camcops_server/cc_modules/tests/cc_patient_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_patient_tests.py @@ -28,8 +28,8 @@ import hl7 import pendulum +from camcops_server.cc_modules.cc_group import Group from camcops_server.cc_modules.cc_simpleobjects import BarePatientInfo -from camcops_server.cc_modules.cc_patient import Patient from camcops_server.cc_modules.cc_patientidnum import PatientIdNum from camcops_server.cc_modules.cc_simpleobjects import IdNumReference from camcops_server.cc_modules.cc_taskschedule import ( @@ -38,9 +38,21 @@ TaskScheduleItem, ) from camcops_server.cc_modules.cc_spreadsheet import SpreadsheetPage +from camcops_server.cc_modules.cc_testfactories import ( + GroupFactory, + NHSPatientIdNumFactory, + PatientFactory, + PatientTaskScheduleFactory, + RioPatientIdNumFactory, + ServerCreatedPatientFactory, + TaskScheduleFactory, + TaskScheduleItemFactory, + UserFactory, + UserGroupMembershipFactory, +) from camcops_server.cc_modules.cc_unittest import ( BasicDatabaseTestCase, - DemoDatabaseTestCase, + DemoRequestTestCase, ) from camcops_server.cc_modules.cc_xml import XmlElement @@ -50,26 +62,30 @@ # ============================================================================= -class PatientTests(DemoDatabaseTestCase): - """ - Unit tests. - """ - +class PatientTests(DemoRequestTestCase): def test_patient(self) -> None: - self.announce("test_patient") - from camcops_server.cc_modules.cc_group import Group - req = self.req - q = self.dbsession.query(Patient) - p = q.first() # type: Patient - assert p, "Missing Patient in demo database!" + req._debugging_user = UserFactory() - for pidnum in p.get_idnum_objects(): + p = PatientFactory() + nhs_idnum = NHSPatientIdNumFactory(patient=p) + RioPatientIdNumFactory(patient=p) + + idnum_objects = p.get_idnum_objects() + self.assertEqual(len(idnum_objects), 2) + for pidnum in idnum_objects: self.assertIsInstance(pidnum, PatientIdNum) - for idref in p.get_idnum_references(): + + idnum_references = p.get_idnum_references() + self.assertEqual(len(idnum_references), 2) + for idref in idnum_references: self.assertIsInstance(idref, IdNumReference) - for idnum in p.get_idnum_raw_values_only(): + + idnum_raw_values = p.get_idnum_raw_values_only() + self.assertEqual(len(idnum_raw_values), 2) + for idnum in idnum_raw_values: self.assertIsInstance(idnum, int) + self.assertIsInstance(p.get_xml_root(req), XmlElement) self.assertIsInstance(p.get_spreadsheet_page(req), SpreadsheetPage) self.assertIsInstance(p.get_bare_ptinfo(), BarePatientInfo) @@ -99,104 +115,82 @@ def test_patient(self) -> None: p.get_hl7_pid_segment(req, self.recipdef), hl7.Segment ) self.assertIsInstanceOrNone( - p.get_idnum_object(which_idnum=1), PatientIdNum + p.get_idnum_object(which_idnum=nhs_idnum.which_idnum), PatientIdNum + ) + self.assertIsInstanceOrNone( + p.get_idnum_value(which_idnum=nhs_idnum.which_idnum), int + ) + self.assertIsInstance( + p.get_iddesc(req, which_idnum=nhs_idnum.which_idnum), str + ) + self.assertIsInstance( + p.get_idshortdesc(req, which_idnum=nhs_idnum.which_idnum), str ) - self.assertIsInstanceOrNone(p.get_idnum_value(which_idnum=1), int) - self.assertIsInstance(p.get_iddesc(req, which_idnum=1), str) - self.assertIsInstance(p.get_idshortdesc(req, which_idnum=1), str) self.assertIsInstance(p.is_preserved(), bool) self.assertIsInstance(p.is_finalized(), bool) self.assertIsInstance(p.user_may_edit(req), bool) def test_surname_forename_upper(self) -> None: - patient = Patient() - patient.forename = "Forename" - patient.surname = "Surname" - + patient = PatientFactory(forename="Forename", surname="Surname") self.assertEqual( patient.get_surname_forename_upper(), "SURNAME, FORENAME" ) def test_surname_forename_upper_no_forename(self) -> None: - patient = Patient() - patient.surname = "Surname" - + patient = PatientFactory(forename=None, surname="Surname") self.assertEqual( patient.get_surname_forename_upper(), "SURNAME, (UNKNOWN)" ) def test_surname_forename_upper_no_surname(self) -> None: - patient = Patient() - patient.forename = "Forename" - + patient = PatientFactory(forename="Forename", surname=None) self.assertEqual( patient.get_surname_forename_upper(), "(UNKNOWN), FORENAME" ) -class LineageTests(DemoDatabaseTestCase): - def create_tasks(self) -> None: - # Actually not creating any tasks but we don't want the patients - # created by default in the baseclass - - # First record for patient 1 - self.set_era("2020-01-01") - - self.patient_1 = Patient() - self.patient_1.id = 1 - self.apply_standard_db_fields(self.patient_1) - self.dbsession.add(self.patient_1) - - # First ID number record for patient 1 - self.patient_idnum_1_1 = PatientIdNum() - self.patient_idnum_1_1.id = 3 - self.apply_standard_db_fields(self.patient_idnum_1_1) - self.patient_idnum_1_1.patient_id = 1 - self.patient_idnum_1_1.which_idnum = self.nhs_iddef.which_idnum - self.patient_idnum_1_1.idnum_value = 555 - self.dbsession.add(self.patient_idnum_1_1) - - # Second ID number record for patient 1 - self.patient_idnum_1_2 = PatientIdNum() - self.patient_idnum_1_2.id = 3 - self.apply_standard_db_fields(self.patient_idnum_1_2) - # This one is not current - self.patient_idnum_1_2._current = False - self.patient_idnum_1_2.patient_id = 1 - self.patient_idnum_1_2.which_idnum = self.nhs_iddef.which_idnum - self.patient_idnum_1_2.idnum_value = 555 - self.dbsession.add(self.patient_idnum_1_2) +class LineageTests(DemoRequestTestCase): + def setUp(self) -> None: + super().setUp() - self.dbsession.commit() + self.patient = PatientFactory() + self.current_patient_idnum = NHSPatientIdNumFactory( + patient=self.patient + ) + self.assertTrue(self.current_patient_idnum._current) + + self.not_current_patient_idnum = NHSPatientIdNumFactory( + patient=self.patient, + _current=False, + id=self.current_patient_idnum.id, + which_idnum=self.current_patient_idnum.which_idnum, + idnum_value=self.current_patient_idnum.idnum_value, + ) + self.assertFalse(self.not_current_patient_idnum._current) def test_gen_patient_idnums_even_noncurrent(self) -> None: - idnums = list(self.patient_1.gen_patient_idnums_even_noncurrent()) + idnums = list(self.patient.gen_patient_idnums_even_noncurrent()) self.assertEqual(len(idnums), 2) -class PatientDeleteTests(DemoDatabaseTestCase): +class PatientDeleteTests(DemoRequestTestCase): def test_deletes_patient_task_schedule(self) -> None: - schedule = TaskSchedule() - schedule.group_id = self.group.id - self.dbsession.add(schedule) - self.dbsession.flush() - - item = TaskScheduleItem() - item.schedule_id = schedule.id - item.task_table_name = "ace3" - item.due_from = pendulum.Duration(days=30) - item.due_by = pendulum.Duration(days=60) - self.dbsession.add(item) - self.dbsession.flush() - - patient = self.create_patient() - - pts = PatientTaskSchedule() - pts.schedule_id = schedule.id - pts.patient_pk = patient.pk - self.dbsession.add(pts) - self.dbsession.commit() + schedule = TaskScheduleFactory() + + item = TaskScheduleItemFactory( + task_schedule=schedule, + task_table_name="ace3", + due_from=pendulum.Duration(days=30), + due_by=pendulum.Duration(days=60), + ) + + patient = ServerCreatedPatientFactory() + + pts = PatientTaskScheduleFactory( + task_schedule=schedule, + patient=patient, + ) self.assertIsNotNone( self.dbsession.query(TaskSchedule) @@ -236,60 +230,52 @@ def test_deletes_patient_task_schedule(self) -> None: class PatientPermissionTests(BasicDatabaseTestCase): - def test_group_administrator_may_edit_server_created(self) -> None: - user = self.create_user(username="testuser") - self.dbsession.flush() + def setUp(self) -> None: + super().setUp() - patient = self.create_patient( - _group=self.group, as_server_patient=True - ) + self.user = UserFactory() + self.group = GroupFactory() - self.create_membership(user, self.group, groupadmin=True) - self.dbsession.commit() + def test_group_administrator_may_edit_server_patient(self) -> None: + patient = ServerCreatedPatientFactory(_group=self.group) + ugm = UserGroupMembershipFactory( + user_id=self.user.id, group_id=self.group.id, groupadmin=True + ) - self.req._debugging_user = user + self.req._debugging_user = ugm.user self.assertTrue(patient.user_may_edit(self.req)) - def test_group_administrator_may_edit_finalized(self) -> None: - user = self.create_user(username="testuser") - self.dbsession.flush() - - patient = self.create_patient( - _group=self.group, as_server_patient=False + def test_group_administrator_may_edit_finalized_patient(self) -> None: + patient = PatientFactory(_group=self.group) + ugm = UserGroupMembershipFactory( + user_id=self.user.id, group_id=self.group.id, groupadmin=True ) - self.create_membership(user, self.group, groupadmin=True) - self.dbsession.commit() + self.assertTrue(ugm.groupadmin) - self.req._debugging_user = user + self.req._debugging_user = ugm.user self.assertTrue(patient.user_may_edit(self.req)) def test_group_member_with_permission_may_edit_server_created( self, ) -> None: - user = self.create_user(username="testuser") - self.dbsession.flush() - - patient = self.create_patient( - _group=self.group, as_server_patient=True + patient = ServerCreatedPatientFactory(_group=self.group) + ugm = UserGroupMembershipFactory( + user_id=self.user.id, + group_id=self.group.id, + may_manage_patients=True, ) - self.create_membership(user, self.group, may_manage_patients=True) - self.dbsession.commit() - - self.req._debugging_user = user + self.req._debugging_user = ugm.user self.assertTrue(patient.user_may_edit(self.req)) def test_group_member_with_permission_may_not_edit_finalized(self) -> None: - user = self.create_user(username="testuser") - self.dbsession.flush() - - patient = self.create_patient( - _group=self.group, as_server_patient=False + patient = PatientFactory(_group=self.group) + ugm = UserGroupMembershipFactory( + user_id=self.user.id, + group_id=self.group.id, + may_manage_patients=True, ) - self.create_membership(user, self.group, may_manage_patients=True) - self.dbsession.commit() - - self.req._debugging_user = user + self.req._debugging_user = ugm.user self.assertFalse(patient.user_may_edit(self.req)) diff --git a/server/camcops_server/cc_modules/tests/cc_pyramid_tests.py b/server/camcops_server/cc_modules/tests/cc_pyramid_tests.py index f6ccd8e98..d759a9f2a 100644 --- a/server/camcops_server/cc_modules/tests/cc_pyramid_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_pyramid_tests.py @@ -32,13 +32,15 @@ CamcopsAuthenticationPolicy, Permission, ) -from camcops_server.cc_modules.cc_unittest import BasicDatabaseTestCase - +from camcops_server.cc_modules.cc_testfactories import ( + GroupFactory, + UserFactory, + UserGroupMembershipFactory, +) +from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase -class CamcopsAuthenticationPolicyTests(BasicDatabaseTestCase): - def setUp(self) -> None: - super().setUp() +class CamcopsAuthenticationPolicyTests(DemoRequestTestCase): def test_principals_for_no_user(self) -> None: self.req._debugging_user = None self.assertEqual( @@ -47,10 +49,7 @@ def test_principals_for_no_user(self) -> None: ) def test_principals_for_authenticated_user(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() - - self.req._debugging_user = user + user = self.req._debugging_user = UserFactory() self.assertIn( Authenticated, CamcopsAuthenticationPolicy.effective_principals(self.req), @@ -61,23 +60,29 @@ def test_principals_for_authenticated_user(self) -> None: ) def test_principals_when_user_must_change_pasword(self) -> None: - user = self.create_user(username="test", must_change_password=True) - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) + user = self.req._debugging_user = UserFactory( + when_agreed_terms_of_use=self.req.now, + must_change_password=True, + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True + ) - self.req._debugging_user = user self.assertIn( Permission.MUST_CHANGE_PASSWORD, CamcopsAuthenticationPolicy.effective_principals(self.req), ) def test_principals_when_user_must_set_up_mfa(self) -> None: - user = self.create_user(username="test", mfa_method=MfaMethod.NO_MFA) - user.agree_terms(self.req) - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) + user = self.req._debugging_user = UserFactory( + mfa_method=MfaMethod.NO_MFA, when_agreed_terms_of_use=self.req.now + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True + ) - self.req._debugging_user = user self.req.config.mfa_methods = [MfaMethod.HOTP_EMAIL] self.assertIn( Permission.MUST_SET_MFA, @@ -85,23 +90,28 @@ def test_principals_when_user_must_set_up_mfa(self) -> None: ) def test_principals_when_user_must_agree_terms(self) -> None: - user = self.create_user(username="test", when_agreed_terms_of_use=None) - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) + user = self.req._debugging_user = UserFactory( + when_agreed_terms_of_use=None + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True + ) - self.req._debugging_user = user self.assertIn( Permission.MUST_AGREE_TERMS, CamcopsAuthenticationPolicy.effective_principals(self.req), ) def test_principals_when_everything_ok(self) -> None: - user = self.create_user(username="test", mfa_method=MfaMethod.NO_MFA) - user.agree_terms(self.req) - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) + user = self.req._debugging_user = UserFactory( + mfa_method=MfaMethod.NO_MFA, when_agreed_terms_of_use=self.req.now + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True + ) - self.req._debugging_user = user self.req.config.mfa_methods = [MfaMethod.NO_MFA] self.assertIn( Permission.HAPPY, @@ -109,19 +119,19 @@ def test_principals_when_everything_ok(self) -> None: ) def test_principals_for_superuser(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + self.req._debugging_user = UserFactory(superuser=True) - self.req._debugging_user = user self.assertIn( Permission.SUPERUSER, CamcopsAuthenticationPolicy.effective_principals(self.req), ) def test_principals_for_groupadmin(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() - self.create_membership(user, self.group, groupadmin=True) + user = self.req._debugging_user = UserFactory() + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, groupadmin=True + ) self.req._debugging_user = user self.assertIn( diff --git a/server/camcops_server/cc_modules/tests/cc_redcap_tests.py b/server/camcops_server/cc_modules/tests/cc_redcap_tests.py index e31229a33..c9d30b9c9 100644 --- a/server/camcops_server/cc_modules/tests/cc_redcap_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_redcap_tests.py @@ -25,7 +25,7 @@ import os import tempfile -from typing import Generator, TYPE_CHECKING +from typing import Any, Dict, Generator from unittest import mock, TestCase from pandas import DataFrame @@ -33,6 +33,10 @@ import redcap from camcops_server.cc_modules.cc_constants import ConfigParamExportRecipient +from camcops_server.cc_modules.cc_exportmodels import ( + ExportedTask, + ExportedTaskRedcap, +) from camcops_server.cc_modules.cc_exportrecipient import ExportRecipient from camcops_server.cc_modules.cc_exportrecipientinfo import ( ExportRecipientInfo, @@ -45,11 +49,17 @@ RedcapRecordStatus, RedcapTaskExporter, ) -from camcops_server.cc_modules.cc_unittest import BasicDatabaseTestCase - -if TYPE_CHECKING: - from camcops_server.cc_modules.cc_patient import Patient - +from camcops_server.cc_modules.cc_testfactories import ( + NHSPatientIdNumFactory, + PatientFactory, +) +from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase +from camcops_server.tasks.tests.factories import ( + APEQCPFTPerinatalFactory, + BmiFactory, + KhandakerMojoMedicationTherapyFactory, + Phq9Factory, +) # ============================================================================= # Unit testing @@ -124,7 +134,7 @@ def test_raises_when_fieldmap_has_unknown_symbols(self) -> None: task = mock.Mock(tablename="bmi") fieldmap = {"pa_height": "sys.platform"} - field_dict = {} + field_dict: Dict[str, Any] = {} with self.assertRaises(RedcapExportException) as cm: exporter.transform_fields(field_dict, task, fieldmap) @@ -172,7 +182,7 @@ def test_raises_when_error_from_redcap_on_import(self) -> None: ) with self.assertRaises(RedcapExportException) as cm: - record = {} + record: Dict[str, Any] = {} exporter.upload_record(record) message = str(cm.exception) @@ -468,14 +478,19 @@ def test_raises_when_fields_missing_attributes(self) -> None: # ============================================================================= -class RedcapExportTestCase(BasicDatabaseTestCase): +class RedcapExportTestCase(DemoRequestTestCase): fieldmap = "" def setUp(self) -> None: + super().setUp() + + self.patient = PatientFactory() + self.patient_idnum = NHSPatientIdNumFactory(patient=self.patient) + recipientinfo = ExportRecipientInfo() self.recipient = ExportRecipient(recipientinfo) - self.recipient.primary_idnum = 1001 + self.recipient.primary_idnum = self.patient_idnum.which_idnum # auto increment doesn't work for BigInteger with SQLite self.recipient.id = 1 @@ -485,58 +500,12 @@ def setUp(self) -> None: ) self.write_fieldmaps(self.recipient.redcap_fieldmap_filename) - super().setUp() - def write_fieldmaps(self, filename: str) -> None: with open(filename, "w") as f: f.write(self.fieldmap) - def create_patient_with_idnum_1001(self) -> "Patient": - from camcops_server.cc_modules.cc_idnumdef import IdNumDefinition - from camcops_server.cc_modules.cc_patient import Patient - from camcops_server.cc_modules.cc_patientidnum import PatientIdNum - - patient = Patient() - patient.id = 2 - self.apply_standard_db_fields(patient) - patient.forename = "Forename2" - patient.surname = "Surname2" - patient.dob = pendulum.parse("1975-12-12") - self.dbsession.add(patient) - - idnumdef_1001 = IdNumDefinition() - idnumdef_1001.which_idnum = 1001 - idnumdef_1001.description = "Test idnumdef 1001" - self.dbsession.add(idnumdef_1001) - self.dbsession.commit() - - patient_idnum1 = PatientIdNum() - patient_idnum1.id = 3 - self.apply_standard_db_fields(patient_idnum1) - patient_idnum1.patient_id = patient.id - patient_idnum1.which_idnum = 1001 - patient_idnum1.idnum_value = 555 - self.dbsession.add(patient_idnum1) - self.dbsession.commit() - - return patient - - -class BmiRedcapExportTestCase(RedcapExportTestCase): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.id_sequence = self.get_id() - @staticmethod - def get_id() -> Generator[int, None, None]: - i = 1 - - while True: - yield i - i += 1 - - -class BmiRedcapValidFieldmapTestCase(BmiRedcapExportTestCase): +class BmiRedcapValidFieldmapTestCase(RedcapExportTestCase): fieldmap = """ @@ -559,25 +528,17 @@ class BmiRedcapExportTests(BmiRedcapValidFieldmapTestCase): related to the BMI task """ - def create_tasks(self) -> None: - from camcops_server.tasks.bmi import Bmi - - patient = self.create_patient_with_idnum_1001() - self.task = Bmi() - self.apply_standard_task_fields(self.task) - self.task.id = next(self.id_sequence) - self.task.height_m = 1.83 - self.task.mass_kg = 67.57 - self.task.patient_id = patient.id - self.dbsession.add(self.task) - self.dbsession.commit() + def setUp(self) -> None: + super().setUp() - def test_record_exported(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, + self.task = BmiFactory( + patient=self.patient, + height_m=1.83, + mass_kg=67.57, + when_created=pendulum.parse("2010-07-07"), ) + def test_record_exported(self) -> None: exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -625,14 +586,9 @@ def test_record_exported(self) -> None: rows = args[0] record = rows[0] - self.assertEqual(record["patient_id"], 555) + self.assertEqual(record["patient_id"], self.patient_idnum.idnum_value) def test_record_exported_with_non_integer_id(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -648,11 +604,6 @@ def test_record_exported_with_non_integer_id(self) -> None: self.assertEqual(exported_task_redcap.redcap_record_id, "15-123") def test_record_id_generated_when_no_autonumbering(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -678,11 +629,6 @@ def test_record_id_generated_when_no_autonumbering(self) -> None: self.assertFalse(kwargs["force_auto_number"]) def test_record_imported_when_no_existing_records(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exporter = MockRedcapTaskExporter() project = exporter.get_project() project.export_records.return_value = DataFrame() @@ -701,33 +647,17 @@ def test_record_imported_when_no_existing_records(self) -> None: class BmiRedcapUpdateTests(BmiRedcapValidFieldmapTestCase): - def create_tasks(self) -> None: - from camcops_server.tasks.bmi import Bmi - - patient = self.create_patient_with_idnum_1001() - self.task1 = Bmi() - self.apply_standard_task_fields(self.task1) - self.task1.id = next(self.id_sequence) - self.task1.height_m = 1.83 - self.task1.mass_kg = 67.57 - self.task1.patient_id = patient.id - self.dbsession.add(self.task1) - - self.task2 = Bmi() - self.apply_standard_task_fields(self.task2) - self.task2.id = next(self.id_sequence) - self.task2.height_m = 1.83 - self.task2.mass_kg = 68.5 - self.task2.patient_id = patient.id - self.dbsession.add(self.task2) - self.dbsession.commit() + def setUp(self) -> None: + super().setUp() - def test_existing_record_id_used_for_update(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, + self.task1 = BmiFactory( + patient=self.patient, + ) + self.task2 = BmiFactory( + patient=self.patient, ) + def test_existing_record_id_used_for_update(self) -> None: exporter = MockRedcapTaskExporter() project = exporter.get_project() project.export_records.return_value = DataFrame({"patient_id": []}) @@ -748,7 +678,7 @@ def test_existing_record_id_used_for_update(self) -> None: project.export_records.return_value = DataFrame( { "record_id": ["123"], - "patient_id": [555], + "patient_id": [self.patient_idnum.idnum_value], "redcap_repeat_instrument": ["bmi"], "redcap_repeat_instance": [1], } @@ -810,43 +740,26 @@ class Phq9RedcapExportTests(RedcapExportTestCase): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.id_sequence = self.get_id() - - @staticmethod - def get_id() -> Generator[int, None, None]: - i = 1 - while True: - yield i - i += 1 - - def create_tasks(self) -> None: - from camcops_server.tasks.phq9 import Phq9 - - patient = self.create_patient_with_idnum_1001() - self.task = Phq9() - self.apply_standard_task_fields(self.task) - self.task.id = next(self.id_sequence) - self.task.q1 = 0 - self.task.q2 = 1 - self.task.q3 = 2 - self.task.q4 = 3 - self.task.q5 = 0 - self.task.q6 = 1 - self.task.q7 = 2 - self.task.q8 = 3 - self.task.q9 = 0 - self.task.q10 = 3 - self.task.patient_id = patient.id - self.dbsession.add(self.task) - self.dbsession.commit() + def setUp(self) -> None: + super().setUp() - def test_record_exported(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, + self.task = Phq9Factory( + patient=self.patient, + q1=0, + q2=1, + q3=2, + q4=3, + q5=0, + q6=1, + q7=2, + q8=3, + q9=0, + q10=3, + when_created=pendulum.parse("2010-07-07"), ) + def test_record_exported(self) -> None: exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -883,8 +796,8 @@ def test_record_exported(self) -> None: ) self.assertEqual(record["phq9_how_difficult"], 4) self.assertEqual(record["phq9_total_score"], 12) - self.assertEqual(record["phq9_first_name"], "Forename2") - self.assertEqual(record["phq9_last_name"], "Surname2") + self.assertEqual(record["phq9_first_name"], self.patient.forename) + self.assertEqual(record["phq9_last_name"], self.patient.surname) self.assertEqual(record["phq9_date_enrolled"], "2010-07-07") self.assertEqual(record["phq9_1"], 0) @@ -905,7 +818,7 @@ def test_record_exported(self) -> None: rows = args[0] record = rows[0] - self.assertEqual(record["patient_id"], 555) + self.assertEqual(record["patient_id"], self.patient_idnum.idnum_value) class MedicationTherapyRedcapExportTests(RedcapExportTestCase): @@ -940,25 +853,14 @@ def get_id() -> Generator[int, None, None]: yield i i += 1 - def create_tasks(self) -> None: - from camcops_server.tasks.khandaker_mojo_medicationtherapy import ( - KhandakerMojoMedicationTherapy, - ) - - patient = self.create_patient_with_idnum_1001() - self.task = KhandakerMojoMedicationTherapy() - self.apply_standard_task_fields(self.task) - self.task.id = next(self.id_sequence) - self.task.patient_id = patient.id - self.dbsession.add(self.task) - self.dbsession.commit() + def setUp(self) -> None: + super().setUp() - def test_record_exported(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, + self.task = KhandakerMojoMedicationTherapyFactory( + patient=self.patient, ) + def test_record_exported(self) -> None: exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -1030,46 +932,17 @@ class MultipleTaskRedcapExportTests(RedcapExportTestCase): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.id_sequence = self.get_id() - - @staticmethod - def get_id() -> Generator[int, None, None]: - i = 1 - while True: - yield i - i += 1 + def setUp(self) -> None: + super().setUp() - def create_tasks(self) -> None: - from camcops_server.tasks.khandaker_mojo_medicationtherapy import ( - KhandakerMojoMedicationTherapy, + self.mojo_task = KhandakerMojoMedicationTherapyFactory( + patient=self.patient ) - patient = self.create_patient_with_idnum_1001() - self.mojo_task = KhandakerMojoMedicationTherapy() - self.apply_standard_task_fields(self.mojo_task) - self.mojo_task.id = next(self.id_sequence) - self.mojo_task.patient_id = patient.id - self.dbsession.add(self.mojo_task) - self.dbsession.commit() - - from camcops_server.tasks.bmi import Bmi - - self.bmi_task = Bmi() - self.apply_standard_task_fields(self.bmi_task) - self.bmi_task.id = next(self.id_sequence) - self.bmi_task.height_m = 1.83 - self.bmi_task.mass_kg = 67.57 - self.bmi_task.patient_id = patient.id - self.dbsession.add(self.bmi_task) - self.dbsession.commit() + self.bmi_task = BmiFactory(patient=self.patient) def test_instance_ids_on_different_tasks_in_same_record(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exporter = MockRedcapTaskExporter() project = exporter.get_project() project.export_records.return_value = DataFrame({"patient_id": []}) @@ -1091,7 +964,7 @@ def test_instance_ids_on_different_tasks_in_same_record(self) -> None: project.export_records.return_value = DataFrame( { "record_id": ["123"], - "patient_id": [555], + "patient_id": [self.patient_idnum.idnum_value], "redcap_repeat_instrument": [ "khandaker_mojo_medicationtherapy" ], @@ -1115,11 +988,6 @@ def test_instance_ids_on_different_tasks_in_same_record(self) -> None: self.assertEqual(record["redcap_repeat_instance"], 1) def test_imported_into_different_events(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exporter = MockRedcapTaskExporter() project = exporter.get_project() @@ -1163,28 +1031,11 @@ def test_imported_into_different_events(self) -> None: class BadConfigurationRedcapTests(RedcapExportTestCase): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.id_sequence = self.get_id() - @staticmethod - def get_id() -> Generator[int, None, None]: - i = 1 - - while True: - yield i - i += 1 - - def create_tasks(self) -> None: - from camcops_server.tasks.bmi import Bmi + def setUp(self) -> None: + super().setUp() - patient = self.create_patient_with_idnum_1001() - self.task = Bmi() - self.apply_standard_task_fields(self.task) - self.task.id = next(self.id_sequence) - self.task.height_m = 1.83 - self.task.mass_kg = 67.57 - self.task.patient_id = patient.id - self.dbsession.add(self.task) - self.dbsession.commit() + self.task = BmiFactory(patient=self.patient) class MissingInstrumentRedcapTests(BadConfigurationRedcapTests): @@ -1201,11 +1052,6 @@ class MissingInstrumentRedcapTests(BadConfigurationRedcapTests): """ def test_raises_when_instrument_missing_from_fieldmap(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -1237,11 +1083,6 @@ class IncorrectRecordIdRedcapTests(BadConfigurationRedcapTests): """ def test_raises_when_record_id_is_incorrect(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -1250,7 +1091,7 @@ def test_raises_when_record_id_is_incorrect(self) -> None: project.export_records.return_value = DataFrame( { "record_id": ["123"], - "patient_id": [555], + "patient_id": [self.patient_idnum.idnum_value], "redcap_repeat_instrument": ["bmi"], "redcap_repeat_instance": [1], } @@ -1281,11 +1122,6 @@ class IncorrectPatientIdRedcapTests(BadConfigurationRedcapTests): """ def test_raises_when_patient_id_is_incorrect(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -1294,7 +1130,7 @@ def test_raises_when_patient_id_is_incorrect(self) -> None: project.export_records.return_value = DataFrame( { "record_id": ["123"], - "patient_id": [555], + "patient_id": [self.patient_idnum.idnum_value], "redcap_repeat_instrument": ["bmi"], "redcap_repeat_instance": [1], } @@ -1327,11 +1163,6 @@ class MissingPatientInstrumentRedcapTests(BadConfigurationRedcapTests): """ def test_raises_when_instrument_is_missing(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -1362,11 +1193,6 @@ class MissingEventRedcapTests(BadConfigurationRedcapTests): """ def test_raises_for_longitudinal_project(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -1400,11 +1226,6 @@ class MissingInstrumentEventRedcapTests(BadConfigurationRedcapTests): """ def test_raises_when_instrument_missing_event(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) @@ -1421,21 +1242,12 @@ def test_raises_when_instrument_missing_event(self) -> None: class AnonymousTaskRedcapTests(RedcapExportTestCase): - def create_tasks(self) -> None: - from camcops_server.tasks.apeq_cpft_perinatal import APEQCPFTPerinatal + def setUp(self) -> None: + super().setUp() - self.task = APEQCPFTPerinatal() - self.apply_standard_task_fields(self.task) - self.task.id = 1 - self.dbsession.add(self.task) - self.dbsession.commit() + self.task = APEQCPFTPerinatalFactory() def test_raises_when_task_is_anonymous(self) -> None: - from camcops_server.cc_modules.cc_exportmodels import ( - ExportedTask, - ExportedTaskRedcap, - ) - exported_task = ExportedTask(task=self.task, recipient=self.recipient) exported_task_redcap = ExportedTaskRedcap(exported_task) diff --git a/server/camcops_server/cc_modules/tests/cc_report_tests.py b/server/camcops_server/cc_modules/tests/cc_report_tests.py index c77380933..fda963dc7 100644 --- a/server/camcops_server/cc_modules/tests/cc_report_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_report_tests.py @@ -25,8 +25,10 @@ """ +import csv +import io import logging -from typing import Generator, Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from cardinal_pythonlib.classes import classproperty from cardinal_pythonlib.logs import BraceStyleAdapter @@ -36,20 +38,17 @@ XlsxResponse, ) from deform.form import Form -import pendulum from pyramid.httpexceptions import HTTPBadRequest from pyramid.response import Response from sqlalchemy.orm.query import Query from sqlalchemy.sql.selectable import SelectBase from camcops_server.cc_modules.cc_report import ( - AverageScoreReport, get_all_report_classes, PlainReportType, Report, ) from camcops_server.cc_modules.cc_unittest import ( - BasicDatabaseTestCase, DemoDatabaseTestCase, DemoRequestTestCase, ) @@ -62,8 +61,6 @@ ReportParamForm, ReportParamSchema, ) - from camcops_server.cc_modules.cc_patient import Patient - from camcops_server.cc_modules.cc_patientidnum import PatientIdNum from camcops_server.cc_modules.cc_request import ( CamcopsRequest, ) @@ -142,81 +139,6 @@ def test_reports(self) -> None: pass -class AverageScoreReportTestCase(BasicDatabaseTestCase): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.patient_id_sequence = self.get_patient_id() - self.task_id_sequence = self.get_task_id() - self.patient_idnum_id_sequence = self.get_patient_idnum_id() - - def setUp(self) -> None: - super().setUp() - - self.report = self.create_report() - - def create_report(self) -> AverageScoreReport: - raise NotImplementedError( - "Report TestCase needs to implement create_report" - ) - - @staticmethod - def get_patient_id() -> Generator[int, None, None]: - i = 1 - - while True: - yield i - i += 1 - - @staticmethod - def get_task_id() -> Generator[int, None, None]: - i = 1 - - while True: - yield i - i += 1 - - @staticmethod - def get_patient_idnum_id() -> Generator[int, None, None]: - i = 1 - - while True: - yield i - i += 1 - - def create_patient(self, idnum_value: int = 333) -> "Patient": - from camcops_server.cc_modules.cc_patient import Patient - - patient = Patient() - patient.id = next(self.patient_id_sequence) - self.apply_standard_db_fields(patient) - - patient.forename = f"Forename {patient.id}" - patient.surname = f"Surname {patient.id}" - patient.dob = pendulum.parse("1950-01-01") - self.dbsession.add(patient) - - self.create_patient_idnum(patient, idnum_value) - - self.dbsession.commit() - - return patient - - def create_patient_idnum( - self, patient, idnum_value: int = 333 - ) -> "PatientIdNum": - from camcops_server.cc_modules.cc_patient import PatientIdNum - - patient_idnum = PatientIdNum() - patient_idnum.id = next(self.patient_idnum_id_sequence) - self.apply_standard_db_fields(patient_idnum) - patient_idnum.patient_id = patient.id - patient_idnum.which_idnum = self.nhs_iddef.which_idnum - patient_idnum.idnum_value = idnum_value - self.dbsession.add(patient_idnum) - - return patient_idnum - - class TestReport(Report): # noinspection PyMethodParameters @classproperty @@ -278,9 +200,6 @@ def test_render_tsv(self) -> None: self.assertIn(".tsv", response.content_disposition) - import csv - import io - reader = csv.reader( io.StringIO(response.body.decode()), dialect="excel-tab" ) diff --git a/server/camcops_server/cc_modules/tests/cc_task_tests.py b/server/camcops_server/cc_modules/tests/cc_task_tests.py index e4fc40f42..a667aa177 100644 --- a/server/camcops_server/cc_modules/tests/cc_task_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_task_tests.py @@ -34,6 +34,7 @@ from camcops_server.cc_modules.cc_dummy_database import DummyDataInserter from camcops_server.cc_modules.cc_task import Task +from camcops_server.cc_modules.cc_testfactories import UserFactory from camcops_server.cc_modules.cc_unittest import DemoDatabaseTestCase from camcops_server.cc_modules.cc_validators import validate_task_tablename @@ -46,9 +47,6 @@ class TaskTests(DemoDatabaseTestCase): - """ - Unit tests. - """ def test_query_phq9(self) -> None: self.announce("test_query_phq9") @@ -79,6 +77,8 @@ def test_all_tasks(self) -> None: ) from camcops_server.cc_modules.cc_xml import XmlElement + user = UserFactory() + subclasses = Task.all_subclasses_by_tablename() tables = [cls.tablename for cls in subclasses] log.info("Actual task table names: {!r} (n={})", tables, len(tables)) @@ -214,7 +214,7 @@ def test_all_tasks(self) -> None: t.get_rio_metadata( req, which_idnum=1, - uploading_user_id=self.user.id, + uploading_user_id=user.id, document_type="some_doc_type", ), str, diff --git a/server/camcops_server/cc_modules/tests/cc_taskreports_tests.py b/server/camcops_server/cc_modules/tests/cc_taskreports_tests.py index 3d0cbbfda..2ec6417b2 100644 --- a/server/camcops_server/cc_modules/tests/cc_taskreports_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_taskreports_tests.py @@ -81,7 +81,6 @@ def setUp(self) -> None: UserGroupMembershipFactory( group_id=self.group_a.id, user_id=self.jim.id, may_run_reports=True ) - self.dbsession.commit() self.num_jim_tasks = self.num_01_oct_2022_bmi_tasks = 2 self.num_freda_tasks = ( @@ -214,9 +213,9 @@ def test_task_counts_by_year_and_month(self) -> None: # Default is by year and month but better to be explicit self.req.add_get_params( { - ViewParam.BY_YEAR: True, - ViewParam.BY_MONTH: True, - ViewParam.VIA_INDEX: self.via_index, + ViewParam.BY_YEAR: "true", + ViewParam.BY_MONTH: "true", + ViewParam.VIA_INDEX: "true" if self.via_index else "false", } ) @@ -291,9 +290,9 @@ def test_task_counts_by_year(self) -> None: self.req.add_get_params( { - ViewParam.BY_YEAR: True, - ViewParam.BY_MONTH: False, - ViewParam.VIA_INDEX: self.via_index, + ViewParam.BY_YEAR: "true", + ViewParam.BY_MONTH: "false", + ViewParam.VIA_INDEX: "true" if self.via_index else "false", } ) @@ -351,10 +350,10 @@ def test_task_counts_by_task(self) -> None: self.req.add_get_params( { - ViewParam.BY_YEAR: False, - ViewParam.BY_MONTH: False, - ViewParam.BY_TASK: True, - ViewParam.VIA_INDEX: self.via_index, + ViewParam.BY_YEAR: "false", + ViewParam.BY_MONTH: "false", + ViewParam.BY_TASK: "true", + ViewParam.VIA_INDEX: "true" if self.via_index else "false", } ) @@ -393,11 +392,11 @@ def test_task_counts_by_user(self) -> None: self.req.add_get_params( { - ViewParam.BY_YEAR: False, - ViewParam.BY_MONTH: False, - ViewParam.BY_TASK: False, - ViewParam.BY_USER: True, - ViewParam.VIA_INDEX: self.via_index, + ViewParam.BY_YEAR: "false", + ViewParam.BY_MONTH: "false", + ViewParam.BY_TASK: "false", + ViewParam.BY_USER: "true", + ViewParam.VIA_INDEX: "true" if self.via_index else "false", } ) @@ -434,11 +433,11 @@ def test_total_task_count_for_superuser(self) -> None: self.req.add_get_params( { - ViewParam.BY_YEAR: False, - ViewParam.BY_MONTH: False, - ViewParam.BY_TASK: False, - ViewParam.BY_USER: False, - ViewParam.VIA_INDEX: self.via_index, + ViewParam.BY_YEAR: "false", + ViewParam.BY_MONTH: "false", + ViewParam.BY_TASK: "false", + ViewParam.BY_USER: "false", + ViewParam.VIA_INDEX: "true" if self.via_index else "false", } ) @@ -477,11 +476,11 @@ def test_task_counts_by_day_of_month(self) -> None: self.req.add_get_params( { - ViewParam.BY_YEAR: False, - ViewParam.BY_MONTH: False, - ViewParam.BY_DAY_OF_MONTH: True, - ViewParam.BY_TASK: False, - ViewParam.VIA_INDEX: self.via_index, + ViewParam.BY_YEAR: "false", + ViewParam.BY_MONTH: "false", + ViewParam.BY_DAY_OF_MONTH: "true", + ViewParam.BY_TASK: "false", + ViewParam.VIA_INDEX: "true" if self.via_index else "false", } ) diff --git a/server/camcops_server/cc_modules/tests/cc_taskschedule_tests.py b/server/camcops_server/cc_modules/tests/cc_taskschedule_tests.py index 96ef47d4d..e64ec9883 100644 --- a/server/camcops_server/cc_modules/tests/cc_taskschedule_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_taskschedule_tests.py @@ -35,13 +35,17 @@ from camcops_server.cc_modules.cc_taskschedule import ( PatientTaskSchedule, PatientTaskScheduleEmail, - TaskSchedule, TaskScheduleItem, ) -from camcops_server.cc_modules.cc_unittest import ( - DemoDatabaseTestCase, - DemoRequestTestCase, +from camcops_server.cc_modules.cc_testfactories import ( + EmailFactory, + PatientTaskScheduleEmailFactory, + PatientTaskScheduleFactory, + ServerCreatedPatientFactory, + TaskScheduleFactory, + TaskScheduleItemFactory, ) +from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase # ============================================================================= @@ -49,38 +53,24 @@ # ============================================================================= -class TaskScheduleTests(DemoDatabaseTestCase): +class TaskScheduleTests(DemoRequestTestCase): def test_deleting_deletes_related_objects(self) -> None: - schedule = TaskSchedule() - schedule.group_id = self.group.id - self.dbsession.add(schedule) - self.dbsession.flush() - - item = TaskScheduleItem() - item.schedule_id = schedule.id - item.task_table_name = "ace3" - item.due_from = Duration(days=30) - item.due_by = Duration(days=60) - self.dbsession.add(item) - self.dbsession.flush() - - patient = self.create_patient() - - pts = PatientTaskSchedule() - pts.schedule_id = schedule.id - pts.patient_pk = patient.pk - self.dbsession.add(pts) - self.dbsession.flush() - - email = Email() - self.dbsession.add(email) - self.dbsession.flush() - - pts_email = PatientTaskScheduleEmail() - pts_email.email_id = email.id - pts_email.patient_task_schedule_id = pts.id - self.dbsession.add(pts_email) - self.dbsession.commit() + patient = ServerCreatedPatientFactory() + schedule = TaskScheduleFactory(group=patient._group) + + item = TaskScheduleItemFactory( + task_schedule=schedule, + task_table_name="ace3", + ) + + pts = PatientTaskScheduleFactory( + task_schedule=schedule, + patient=patient, + ) + + pts_email = PatientTaskScheduleEmailFactory( + patient_task_schedule=pts, + ) self.assertIsNotNone( self.dbsession.query(TaskScheduleItem) @@ -101,7 +91,7 @@ def test_deleting_deletes_related_objects(self) -> None: ) self.assertIsNotNone( self.dbsession.query(Email) - .filter(Email.id == email.id) + .filter(Email.id == pts_email.email.id) .one_or_none() ) @@ -127,114 +117,73 @@ def test_deleting_deletes_related_objects(self) -> None: ) self.assertIsNone( self.dbsession.query(Email) - .filter(Email.id == email.id) + .filter(Email.id == pts_email.email.id) .one_or_none() ) class TaskScheduleItemTests(DemoRequestTestCase): def test_description_shows_shortname_and_number_of_days(self) -> None: - item = TaskScheduleItem() - item.task_table_name = "bmi" - item.due_from = Duration(days=30) - + item = TaskScheduleItemFactory( + task_table_name="bmi", + due_from=Duration(days=30), + ) self.assertEqual(item.description(self.req), "BMI @ 30 days") def test_description_with_no_durations(self) -> None: - item = TaskScheduleItem() - item.task_table_name = "bmi" - + item = TaskScheduleItemFactory(task_table_name="bmi") self.assertEqual(item.description(self.req), "BMI @ ? days") def test_due_within_calculated_from_due_by_and_due_from(self) -> None: - item = TaskScheduleItem() - item.due_from = Duration(days=30) - item.due_by = Duration(days=50) - + item = TaskScheduleItemFactory( + due_from=Duration(days=30), + due_by=Duration(days=50), + ) self.assertEqual(item.due_within.in_days(), 20) def test_due_within_is_none_when_missing_due_by(self) -> None: - item = TaskScheduleItem() - item.due_from = Duration(days=30) - + item = TaskScheduleItemFactory(due_from=Duration(days=30)) self.assertIsNone(item.due_within) def test_due_within_calculated_when_missing_due_from(self) -> None: - item = TaskScheduleItem() - item.due_by = Duration(days=30) - + item = TaskScheduleItemFactory(due_by=Duration(days=30)) self.assertEqual(item.due_within.in_days(), 30) -class PatientTaskScheduleTests(DemoDatabaseTestCase): - def setUp(self) -> None: - super().setUp() - - import datetime - - self.schedule = TaskSchedule() - self.schedule.group_id = self.group.id - self.dbsession.add(self.schedule) - - self.patient = self.create_patient( - id=1, - forename="Jo", - surname="Patient", - dob=datetime.date(1958, 4, 19), - sex="F", - address="Address", - gp="GP", - other="Other", - ) - - self.pts = PatientTaskSchedule() - self.pts.schedule_id = self.schedule.id - self.pts.patient_pk = self.patient.pk - self.dbsession.add(self.pts) - self.dbsession.flush() - +class PatientTaskScheduleTests(DemoRequestTestCase): def test_email_body_contains_access_key(self) -> None: - self.schedule.email_template = "{access_key}" - self.dbsession.add(self.schedule) - self.dbsession.flush() + schedule = TaskScheduleFactory(email_template="{access_key}") + pts = PatientTaskScheduleFactory(task_schedule=schedule) self.assertIn( - f"{self.patient.uuid_as_proquint}", self.pts.email_body(self.req) + f"{pts.patient.uuid_as_proquint}", pts.email_body(self.req) ) def test_email_body_contains_server_url(self) -> None: - self.schedule.email_template = "{server_url}" - self.dbsession.add(self.schedule) - self.dbsession.flush() + schedule = TaskScheduleFactory(email_template="{server_url}") + pts = PatientTaskScheduleFactory(task_schedule=schedule) expected_url = self.req.route_url(Routes.CLIENT_API) - self.assertIn(f"{expected_url}", self.pts.email_body(self.req)) + self.assertIn(f"{expected_url}", pts.email_body(self.req)) def test_email_body_contains_patient_forename(self) -> None: - self.schedule.email_template = "{forename}" - self.dbsession.add(self.schedule) - self.dbsession.flush() + schedule = TaskScheduleFactory(email_template="{forename}") + pts = PatientTaskScheduleFactory(task_schedule=schedule) - self.assertIn( - f"{self.pts.patient.forename}", self.pts.email_body(self.req) - ) + self.assertIn(f"{pts.patient.forename}", pts.email_body(self.req)) def test_email_body_contains_patient_surname(self) -> None: - self.schedule.email_template = "{surname}" - self.dbsession.add(self.schedule) - self.dbsession.flush() + schedule = TaskScheduleFactory(email_template="{surname}") + pts = PatientTaskScheduleFactory(task_schedule=schedule) - self.assertIn( - f"{self.pts.patient.surname}", self.pts.email_body(self.req) - ) + self.assertIn(f"{pts.patient.surname}", pts.email_body(self.req)) def test_email_body_contains_android_launch_url(self) -> None: - self.schedule.email_template = "{android_launch_url}" - self.dbsession.add(self.schedule) - self.dbsession.flush() + schedule = TaskScheduleFactory(email_template="{android_launch_url}") + pts = PatientTaskScheduleFactory(task_schedule=schedule) - url = self.pts.email_body(self.req) + url = pts.email_body(self.req) (scheme, netloc, path, query, fragment) = urlsplit(url) self.assertEqual(scheme, UriSchemes.HTTP) self.assertEqual(netloc, "camcops.org") @@ -246,15 +195,14 @@ def test_email_body_contains_android_launch_url(self) -> None: [self.req.route_url(Routes.CLIENT_API)], ) self.assertEqual( - query_dict["default_access_key"], [self.patient.uuid_as_proquint] + query_dict["default_access_key"], [pts.patient.uuid_as_proquint] ) def test_email_body_contains_ios_launch_url(self) -> None: - self.schedule.email_template = "{ios_launch_url}" - self.dbsession.add(self.schedule) - self.dbsession.flush() + schedule = TaskScheduleFactory(email_template="{ios_launch_url}") + pts = PatientTaskScheduleFactory(task_schedule=schedule) - url = self.pts.email_body(self.req) + url = pts.email_body(self.req) (scheme, netloc, path, query, fragment) = urlsplit(url) self.assertEqual(scheme, "camcops") self.assertEqual(netloc, "camcops.org") @@ -266,59 +214,33 @@ def test_email_body_contains_ios_launch_url(self) -> None: [self.req.route_url(Routes.CLIENT_API)], ) self.assertEqual( - query_dict["default_access_key"], [self.patient.uuid_as_proquint] + query_dict["default_access_key"], [pts.patient.uuid_as_proquint] ) def test_email_body_disallows_invalid_template(self) -> None: - self.schedule.email_template = "{foobar}" - self.dbsession.add(self.schedule) - self.dbsession.flush() + schedule = TaskScheduleFactory(email_template="{foobar}") + pts = PatientTaskScheduleFactory(task_schedule=schedule) with self.assertRaises(KeyError): - self.pts.email_body(self.req) + pts.email_body(self.req) def test_email_body_disallows_accessing_properties(self) -> None: - self.schedule.email_template = "{server_url.__class__}" - self.dbsession.add(self.schedule) - self.dbsession.flush() + schedule = TaskScheduleFactory(email_template="{server_url.__class__}") + pts = PatientTaskScheduleFactory(task_schedule=schedule) with self.assertRaises(KeyError): - self.pts.email_body(self.req) + pts.email_body(self.req) def test_email_sent_false_for_no_emails(self) -> None: - self.assertFalse(self.pts.email_sent) + pts = PatientTaskScheduleFactory() + self.assertFalse(pts.email_sent) def test_email_sent_false_for_one_unsent_email(self) -> None: - email1 = Email() - email1.sent = False - self.dbsession.add(email1) - self.dbsession.flush() - pts_email1 = PatientTaskScheduleEmail() - pts_email1.email_id = email1.id - pts_email1.patient_task_schedule_id = self.pts.id - self.dbsession.add(pts_email1) - self.dbsession.commit() - - self.assertFalse(self.pts.email_sent) + email1 = EmailFactory(sent=False) + pts_email1 = PatientTaskScheduleEmailFactory(email=email1) + self.assertFalse(pts_email1.patient_task_schedule.email_sent) def test_email_sent_true_for_one_sent_email(self) -> None: - email1 = Email() - email1.sent = False - self.dbsession.add(email1) - self.dbsession.flush() - pts_email1 = PatientTaskScheduleEmail() - pts_email1.email_id = email1.id - pts_email1.patient_task_schedule_id = self.pts.id - self.dbsession.add(pts_email1) - - email2 = Email() - email2.sent = True - self.dbsession.add(email2) - self.dbsession.flush() - pts_email2 = PatientTaskScheduleEmail() - pts_email2.email_id = email2.id - pts_email2.patient_task_schedule_id = self.pts.id - self.dbsession.add(pts_email2) - self.dbsession.commit() - - self.assertTrue(self.pts.email_sent) + email1 = EmailFactory(sent=True) + pts_email1 = PatientTaskScheduleEmailFactory(email=email1) + self.assertTrue(pts_email1.patient_task_schedule.email_sent) diff --git a/server/camcops_server/cc_modules/tests/cc_user_tests.py b/server/camcops_server/cc_modules/tests/cc_user_tests.py index 9962f7fc3..96ec8d422 100644 --- a/server/camcops_server/cc_modules/tests/cc_user_tests.py +++ b/server/camcops_server/cc_modules/tests/cc_user_tests.py @@ -33,9 +33,13 @@ OBSCURE_PHONE_ASTERISKS, ) from camcops_server.cc_modules.cc_group import Group +from camcops_server.cc_modules.cc_testfactories import ( + GroupFactory, + UserFactory, + UserGroupMembershipFactory, +) from camcops_server.cc_modules.cc_unittest import ( - BasicDatabaseTestCase, - DemoDatabaseTestCase, + DemoRequestTestCase, ) from camcops_server.cc_modules.cc_user import ( SecurityAccountLockout, @@ -49,13 +53,15 @@ # ============================================================================= -class UserTests(DemoDatabaseTestCase): +class UserTests(DemoRequestTestCase): """ Unit tests. """ def test_user(self) -> None: - self.announce("test_user") + UserFactory() + GroupFactory() + req = self.req SecurityAccountLockout.delete_old_account_lockouts(req) @@ -134,7 +140,7 @@ def test_partial_email(self) -> None: ("very.unusual.”@”.unusual.com@example.com", f"v{a}m@example.com"), ) - user = self.create_user() + user = UserFactory() for email, expected_partial in tests: user.email = email @@ -143,31 +149,36 @@ def test_partial_email(self) -> None: ) def test_partial_phone_number(self) -> None: - user = self.create_user() # https://www.ofcom.org.uk/phones-telecoms-and-internet/information-for-industry/numbering/numbers-for-drama # noqa: E501 - user.phone_number = phonenumbers.parse("+447700900123") + user = UserFactory(phone_number=phonenumbers.parse("+447700900123")) a = OBSCURE_PHONE_ASTERISKS self.assertEqual(user.partial_phone_number, f"{a}23") -class UserPermissionTests(BasicDatabaseTestCase): +class UserPermissionTests(DemoRequestTestCase): def setUp(self) -> None: super().setUp() # Deliberately not in alphabetical order to test sorting - self.group_c = self.create_group("groupc") - self.group_b = self.create_group("groupb") - self.group_a = self.create_group("groupa") - self.group_d = self.create_group("groupd") - self.dbsession.flush() + self.group_c = GroupFactory(name="groupc") + self.group_b = GroupFactory(name="groupb") + self.group_a = GroupFactory(name="groupa") + self.group_d = GroupFactory(name="groupd") def test_groups_user_may_manage_patients_in(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, may_manage_patients=True) - self.create_membership(user, self.group_c, may_manage_patients=True) - self.create_membership(user, self.group_a, may_manage_patients=False) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, may_manage_patients=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_manage_patients=True + ) + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + may_manage_patients=False, + ) self.assertEqual( [self.group_c, self.group_d], @@ -175,12 +186,17 @@ def test_groups_user_may_manage_patients_in(self) -> None: ) def test_groups_user_may_email_patients_in(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, may_email_patients=True) - self.create_membership(user, self.group_c, may_email_patients=True) - self.create_membership(user, self.group_a, may_email_patients=False) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, may_email_patients=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_email_patients=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_email_patients=False + ) self.assertEqual( [self.group_c, self.group_d], @@ -188,12 +204,17 @@ def test_groups_user_may_email_patients_in(self) -> None: ) def test_ids_of_groups_user_may_report_on(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_run_reports=False) - self.create_membership(user, self.group_c, may_run_reports=True) - self.create_membership(user, self.group_d, may_run_reports=True) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_run_reports=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_run_reports=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, may_run_reports=True + ) ids = user.ids_of_groups_user_may_report_on @@ -203,8 +224,7 @@ def test_ids_of_groups_user_may_report_on(self) -> None: self.assertNotIn(self.group_b.id, ids) def test_ids_of_groups_superuser_may_report_on(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) ids = user.ids_of_groups_user_may_report_on @@ -214,12 +234,17 @@ def test_ids_of_groups_superuser_may_report_on(self) -> None: self.assertIn(self.group_d.id, ids) def test_ids_of_groups_user_is_admin_for(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, groupadmin=False) - self.create_membership(user, self.group_c, groupadmin=True) - self.create_membership(user, self.group_d, groupadmin=True) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, groupadmin=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, groupadmin=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=True + ) ids = user.ids_of_groups_user_is_admin_for @@ -229,8 +254,7 @@ def test_ids_of_groups_user_is_admin_for(self) -> None: self.assertNotIn(self.group_b.id, ids) def test_ids_of_groups_superuser_is_admin_for(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) ids = user.ids_of_groups_user_is_admin_for @@ -240,12 +264,17 @@ def test_ids_of_groups_superuser_is_admin_for(self) -> None: self.assertIn(self.group_d.id, ids) def test_names_of_groups_user_is_admin_for(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, groupadmin=False) - self.create_membership(user, self.group_c, groupadmin=True) - self.create_membership(user, self.group_d, groupadmin=True) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, groupadmin=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, groupadmin=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=True + ) names = user.names_of_groups_user_is_admin_for @@ -255,8 +284,7 @@ def test_names_of_groups_user_is_admin_for(self) -> None: self.assertNotIn(self.group_b.name, names) def test_names_of_groups_superuser_is_admin_for(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) names = user.names_of_groups_user_is_admin_for @@ -266,33 +294,41 @@ def test_names_of_groups_superuser_is_admin_for(self) -> None: self.assertIn(self.group_d.name, names) def test_groups_user_is_admin_for(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, groupadmin=False) - self.create_membership(user, self.group_c, groupadmin=True) - self.create_membership(user, self.group_d, groupadmin=True) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, groupadmin=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, groupadmin=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=True + ) self.assertEqual( [self.group_c, self.group_d], user.groups_user_is_admin_for ) def test_user_may_administer_group(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, groupadmin=False) - self.create_membership(user, self.group_c, groupadmin=True) - self.create_membership(user, self.group_d, groupadmin=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, groupadmin=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, groupadmin=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=True + ) self.assertFalse(user.may_administer_group(self.group_a.id)) self.assertTrue(user.may_administer_group(self.group_c.id)) self.assertTrue(user.may_administer_group(self.group_d.id)) def test_superuser_may_administer_group(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.may_administer_group(self.group_a.id)) self.assertTrue(user.may_administer_group(self.group_b.id)) @@ -300,48 +336,68 @@ def test_superuser_may_administer_group(self) -> None: self.assertTrue(user.may_administer_group(self.group_d.id)) def test_groups_user_may_dump(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, may_dump_data=True) - self.create_membership(user, self.group_c, may_dump_data=True) - self.create_membership(user, self.group_a, may_dump_data=False) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, may_dump_data=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_dump_data=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_dump_data=False + ) self.assertEqual( [self.group_c, self.group_d], user.groups_user_may_dump ) def test_groups_user_may_report_on(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, may_run_reports=True) - self.create_membership(user, self.group_c, may_run_reports=True) - self.create_membership(user, self.group_a, may_run_reports=False) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, may_run_reports=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_run_reports=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_run_reports=False + ) self.assertEqual( [self.group_c, self.group_d], user.groups_user_may_report_on ) def test_groups_user_may_upload_into(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, may_upload=True) - self.create_membership(user, self.group_c, may_upload=True) - self.create_membership(user, self.group_a, may_upload=False) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, may_upload=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_upload=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_upload=False + ) self.assertEqual( [self.group_c, self.group_d], user.groups_user_may_upload_into ) def test_groups_user_may_add_special_notes(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, may_add_notes=True) - self.create_membership(user, self.group_c, may_add_notes=True) - self.create_membership(user, self.group_a, may_add_notes=False) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, may_add_notes=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_add_notes=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_add_notes=False + ) self.assertEqual( [self.group_c, self.group_d], @@ -349,17 +405,22 @@ def test_groups_user_may_add_special_notes(self) -> None: ) def test_groups_user_may_see_all_pts_when_unfiltered(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership( - user, self.group_d, view_all_patients_when_unfiltered=True + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_d.id, + view_all_patients_when_unfiltered=True, ) - self.create_membership( - user, self.group_c, view_all_patients_when_unfiltered=True + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_c.id, + view_all_patients_when_unfiltered=True, ) - self.create_membership( - user, self.group_a, view_all_patients_when_unfiltered=False + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + view_all_patients_when_unfiltered=False, ) self.assertEqual( @@ -368,252 +429,259 @@ def test_groups_user_may_see_all_pts_when_unfiltered(self) -> None: ) def test_is_a_group_admin(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, groupadmin=True) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=True + ) self.assertTrue(user.is_a_groupadmin) def test_is_not_a_group_admin(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, groupadmin=False) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=False + ) self.assertFalse(user.is_a_groupadmin) def test_authorized_as_groupadmin(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, groupadmin=True) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=True + ) self.assertTrue(user.authorized_as_groupadmin) def test_not_authorized_as_groupadmin(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_d, groupadmin=False) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=False + ) self.assertFalse(user.authorized_as_groupadmin) def test_superuser_authorized_as_groupadmin(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.authorized_as_groupadmin) def test_membership_for_group_id(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - ugm = self.create_membership(user, self.group_a) + ugm = UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id + ) self.assertEqual(user.membership_for_group_id(self.group_a.id), ugm) def test_no_membership_for_group_id(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() self.assertIsNone(user.membership_for_group_id(self.group_a.id)) def test_may_use_webviewer(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_use_webviewer=False) - self.create_membership(user, self.group_c, may_use_webviewer=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_use_webviewer=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_use_webviewer=True + ) self.assertTrue(user.may_use_webviewer) def test_may_not_use_webviewer(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() self.assertFalse(user.may_use_webviewer) def test_superuser_may_use_webviewer(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.may_use_webviewer) def test_authorized_to_add_special_note(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_c, may_add_notes=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_add_notes=True + ) self.assertTrue(user.authorized_to_add_special_note(self.group_c.id)) def test_not_authorized_to_add_special_note(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_c, may_add_notes=False) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_add_notes=False + ) self.assertFalse(user.authorized_to_add_special_note(self.group_c.id)) def test_superuser_authorized_to_add_special_note(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.authorized_to_add_special_note(self.group_c.id)) def test_groupadmin_authorized_to_erase_tasks(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_c, groupadmin=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, groupadmin=True + ) self.assertTrue(user.authorized_to_erase_tasks(self.group_c.id)) def test_non_member_not_authorized_to_erase_tasks(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, groupadmin=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, groupadmin=True + ) self.assertFalse(user.authorized_to_erase_tasks(self.group_c.id)) def test_non_admin_not_authorized_to_erase_tasks(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_c) - self.dbsession.commit() + UserGroupMembershipFactory(user_id=user.id, group_id=self.group_c.id) self.assertFalse(user.authorized_to_erase_tasks(self.group_c.id)) def test_superuser_authorized_to_erase_tasks(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.authorized_to_erase_tasks(self.group_c.id)) def test_authorized_to_dump(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_dump_data=False) - self.create_membership(user, self.group_c, may_dump_data=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_dump_data=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_dump_data=True + ) self.assertTrue(user.authorized_to_dump) def test_not_authorized_to_dump(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() self.assertFalse(user.authorized_to_dump) def test_superuser_authorized_to_dump(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.authorized_to_dump) def test_authorized_for_reports(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_run_reports=False) - self.create_membership(user, self.group_c, may_run_reports=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_run_reports=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_run_reports=True + ) self.assertTrue(user.authorized_for_reports) def test_not_authorized_for_reports(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() self.assertFalse(user.authorized_for_reports) def test_superuser_authorized_for_reports(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.authorized_for_reports) def test_may_view_all_patients_when_unfiltered(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership( - user, self.group_a, view_all_patients_when_unfiltered=True + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + view_all_patients_when_unfiltered=True, ) - self.create_membership( - user, self.group_c, view_all_patients_when_unfiltered=True + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_c.id, + view_all_patients_when_unfiltered=True, ) - self.dbsession.commit() self.assertTrue(user.may_view_all_patients_when_unfiltered) def test_may_not_view_all_patients_when_unfiltered(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership( - user, self.group_a, view_all_patients_when_unfiltered=True + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + view_all_patients_when_unfiltered=True, ) - self.create_membership( - user, self.group_c, view_all_patients_when_unfiltered=False + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_c.id, + view_all_patients_when_unfiltered=False, ) - self.dbsession.commit() self.assertFalse(user.may_view_all_patients_when_unfiltered) def test_superuser_may_view_all_patients_when_unfiltered(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.may_view_all_patients_when_unfiltered) def test_may_view_no_patients_when_unfiltered(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() self.assertTrue(user.may_view_no_patients_when_unfiltered) def test_may_not_view_no_patients_when_unfiltered(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership( - user, self.group_a, view_all_patients_when_unfiltered=True + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + view_all_patients_when_unfiltered=True, ) - self.create_membership( - user, self.group_c, view_all_patients_when_unfiltered=True + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_c.id, + view_all_patients_when_unfiltered=True, ) - self.dbsession.commit() self.assertFalse(user.may_view_no_patients_when_unfiltered) def test_superuser_may_not_view_no_patients_when_unfiltered(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertFalse(user.may_view_no_patients_when_unfiltered) def test_group_ids_that_nonsuperuser_may_see_when_unfiltered(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership( - user, self.group_a, view_all_patients_when_unfiltered=False + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + view_all_patients_when_unfiltered=False, ) - self.create_membership( - user, self.group_c, view_all_patients_when_unfiltered=True + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_c.id, + view_all_patients_when_unfiltered=True, ) - self.create_membership( - user, self.group_d, view_all_patients_when_unfiltered=True + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_d.id, + view_all_patients_when_unfiltered=True, ) ids = user.group_ids_nonsuperuser_may_see_when_unfiltered() @@ -624,165 +692,167 @@ def test_group_ids_that_nonsuperuser_may_see_when_unfiltered(self) -> None: self.assertNotIn(self.group_b.id, ids) def test_may_upload_to_group(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_upload=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_upload=True + ) self.assertTrue(user.may_upload_to_group(self.group_a.id)) def test_may_not_upload_to_group(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_upload=False) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_upload=False + ) self.assertFalse(user.may_upload_to_group(self.group_a.id)) def test_superuser_may_upload_to_group(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.may_upload_to_group(self.group_a.id)) def test_may_upload_to_upload_group(self) -> None: - user = self.create_user( - username="test", upload_group_id=self.group_a.id - ) - self.dbsession.flush() + user = UserFactory(upload_group_id=self.group_a.id) - self.create_membership(user, self.group_a, may_upload=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_upload=True + ) self.assertTrue(user.may_upload) def test_may_not_upload_with_no_upload_group(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_upload=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_upload=True + ) self.assertFalse(user.may_upload) def test_may_not_upload_with_upload_group_but_no_permission(self) -> None: - user = self.create_user( - username="test", upload_group_id=self.group_a.id - ) - self.dbsession.flush() + user = UserFactory(upload_group_id=self.group_a.id) - self.create_membership(user, self.group_a, may_upload=False) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_upload=False + ) self.assertFalse(user.may_upload) def test_may_register_devices_with_upload_group(self) -> None: - user = self.create_user( - username="test", upload_group_id=self.group_a.id - ) - self.dbsession.flush() + user = UserFactory(upload_group_id=self.group_a.id) - self.create_membership(user, self.group_a, may_register_devices=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + may_register_devices=True, + ) self.assertTrue(user.may_register_devices) def test_may_not_register_devices_with_no_upload_group(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() self.assertFalse(user.may_register_devices) def test_may_not_register_devices_with_upload_group_but_no_permission( self, ) -> None: - user = self.create_user( - username="test", upload_group_id=self.group_a.id - ) - self.dbsession.flush() + user = UserFactory(upload_group_id=self.group_a.id) - self.create_membership(user, self.group_a, may_register_devices=False) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + may_register_devices=False, + ) self.assertFalse(user.may_register_devices) def test_superuser_may_register_devices_with_upload_group(self) -> None: - user = self.create_user( - username="test", upload_group_id=self.group_a.id, superuser=True - ) - self.dbsession.flush() + user = UserFactory(upload_group_id=self.group_a.id, superuser=True) self.assertTrue(user.may_register_devices) def test_superuser_may_not_register_devices_with_no_upload_group( self, ) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertFalse(user.may_register_devices) def test_authorized_to_manage_patients(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_manage_patients=False) - self.create_membership(user, self.group_c, may_manage_patients=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + may_manage_patients=False, + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_manage_patients=True + ) self.assertTrue(user.authorized_to_manage_patients) def test_not_authorized_to_manage_patients(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() self.assertFalse(user.authorized_to_manage_patients) def test_groupadmin_authorized_to_manage_patients(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, groupadmin=True) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, groupadmin=True + ) self.assertTrue(user.authorized_to_manage_patients) def test_superuser_authorized_to_manage_patients(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.authorized_to_manage_patients) def test_user_may_manage_patients_in_group(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_manage_patients=False) - self.create_membership(user, self.group_c, may_manage_patients=True) - self.create_membership(user, self.group_d, may_manage_patients=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group_a.id, + may_manage_patients=False, + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_manage_patients=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, may_manage_patients=True + ) self.assertFalse(user.may_manage_patients_in_group(self.group_a.id)) self.assertTrue(user.may_manage_patients_in_group(self.group_c.id)) self.assertTrue(user.may_manage_patients_in_group(self.group_d.id)) def test_groupadmin_may_manage_patients_in_group(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, groupadmin=False) - self.create_membership(user, self.group_c, groupadmin=True) - self.create_membership(user, self.group_d, groupadmin=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, groupadmin=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, groupadmin=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=True + ) self.assertFalse(user.may_manage_patients_in_group(self.group_a.id)) self.assertTrue(user.may_manage_patients_in_group(self.group_c.id)) self.assertTrue(user.may_manage_patients_in_group(self.group_d.id)) def test_superuser_may_manage_patients_in_group(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.may_manage_patients_in_group(self.group_a.id)) self.assertTrue(user.may_manage_patients_in_group(self.group_b.id)) @@ -790,66 +860,104 @@ def test_superuser_may_manage_patients_in_group(self) -> None: self.assertTrue(user.may_manage_patients_in_group(self.group_d.id)) def test_authorized_to_email_patients(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_email_patients=False) - self.create_membership(user, self.group_c, may_email_patients=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_email_patients=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_email_patients=True + ) self.assertTrue(user.authorized_to_email_patients) def test_not_authorized_to_email_patients(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() self.assertFalse(user.authorized_to_email_patients) def test_groupadmin_authorized_to_email_patients(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, groupadmin=True) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, groupadmin=True + ) self.assertTrue(user.authorized_to_email_patients) def test_superuser_authorized_to_email_patients(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.authorized_to_email_patients) def test_user_may_email_patients_in_group(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, may_email_patients=False) - self.create_membership(user, self.group_c, may_email_patients=True) - self.create_membership(user, self.group_d, may_email_patients=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, may_email_patients=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, may_email_patients=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, may_email_patients=True + ) self.assertFalse(user.may_email_patients_in_group(self.group_a.id)) self.assertTrue(user.may_email_patients_in_group(self.group_c.id)) self.assertTrue(user.may_email_patients_in_group(self.group_d.id)) def test_groupadmin_may_email_patients_in_group(self) -> None: - user = self.create_user(username="test") - self.dbsession.flush() + user = UserFactory() - self.create_membership(user, self.group_a, groupadmin=False) - self.create_membership(user, self.group_c, groupadmin=True) - self.create_membership(user, self.group_d, groupadmin=True) - self.dbsession.commit() + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_a.id, groupadmin=False + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_c.id, groupadmin=True + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group_d.id, groupadmin=True + ) self.assertFalse(user.may_email_patients_in_group(self.group_a.id)) self.assertTrue(user.may_email_patients_in_group(self.group_c.id)) self.assertTrue(user.may_email_patients_in_group(self.group_d.id)) def test_superuser_may_email_patients_in_group(self) -> None: - user = self.create_user(username="test", superuser=True) - self.dbsession.flush() + user = UserFactory(superuser=True) self.assertTrue(user.may_email_patients_in_group(self.group_a.id)) self.assertTrue(user.may_email_patients_in_group(self.group_b.id)) self.assertTrue(user.may_email_patients_in_group(self.group_c.id)) self.assertTrue(user.may_email_patients_in_group(self.group_d.id)) + + +class SetGroupIdsTests(DemoRequestTestCase): + def test_old_group_ids_removed(self) -> None: + group_a = GroupFactory() + group_b = GroupFactory() + + user = UserFactory() + + UserGroupMembershipFactory(user_id=user.id, group_id=group_a.id) + UserGroupMembershipFactory(user_id=user.id, group_id=group_b.id) + self.assertEqual(len(user.user_group_memberships), 2) + + user.set_group_ids([]) + self.dbsession.refresh(user) + + self.assertEqual(len(user.user_group_memberships), 0) + + def test_new_group_ids_added(self) -> None: + group_a = GroupFactory() + group_b = GroupFactory() + + user = UserFactory() + + self.assertEqual(len(user.user_group_memberships), 0) + + user.set_group_ids([group_a.id, group_b.id]) + self.dbsession.refresh(user) + + self.assertEqual(len(user.user_group_memberships), 2) diff --git a/server/camcops_server/cc_modules/tests/client_api_tests.py b/server/camcops_server/cc_modules/tests/client_api_tests.py index 8ccf083d4..119c99c66 100644 --- a/server/camcops_server/cc_modules/tests/client_api_tests.py +++ b/server/camcops_server/cc_modules/tests/client_api_tests.py @@ -38,6 +38,8 @@ from cardinal_pythonlib.nhs import generate_random_nhs_number from cardinal_pythonlib.sql.literals import sql_quote_string from cardinal_pythonlib.text import escape_newlines, unescape_newlines +from pendulum import DateTime as Pendulum, Duration, local, parse + from pyramid.response import Response from camcops_server.cc_modules.cc_client_api_core import ( @@ -49,12 +51,25 @@ UserErrorException, ) from camcops_server.cc_modules.cc_convert import decode_values -from camcops_server.cc_modules.cc_ipuse import IpUse from camcops_server.cc_modules.cc_proquint import uuid_from_proquint -from camcops_server.cc_modules.cc_unittest import ( - BasicDatabaseTestCase, - DemoDatabaseTestCase, +from camcops_server.cc_modules.cc_taskindex import ( + PatientIdNumIndexEntry, + TaskIndexEntry, +) +from camcops_server.cc_modules.cc_testfactories import ( + DeviceFactory, + GroupFactory, + NHSPatientIdNumFactory, + PatientFactory, + PatientTaskScheduleFactory, + ServerCreatedNHSPatientIdNumFactory, + ServerCreatedPatientFactory, + TaskScheduleFactory, + TaskScheduleItemFactory, + UserFactory, + UserGroupMembershipFactory, ) +from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase from camcops_server.cc_modules.cc_user import User from camcops_server.cc_modules.cc_version import MINIMUM_TABLET_VERSION from camcops_server.cc_modules.cc_validators import ( @@ -63,11 +78,12 @@ from camcops_server.cc_modules.client_api import ( client_api, FAILURE_CODE, + get_or_create_single_user, make_single_user_mode_username, Operations, SUCCESS_CODE, ) - +from camcops_server.tasks.tests.factories import BmiFactory TEST_NHS_NUMBER = generate_random_nhs_number() @@ -101,14 +117,8 @@ def get_reply_dict_from_response(response: Response) -> Dict[str, str]: return {} -class ClientApiTests(DemoDatabaseTestCase): - """ - Unit tests. - """ - +class ClientApiTests(DemoRequestTestCase): def test_client_api_basics(self) -> None: - self.announce("test_client_api_basics") - with self.assertRaises(UserErrorException): fail_user_error("testmsg") with self.assertRaises(ServerErrorException): @@ -171,10 +181,12 @@ def test_client_api_basics(self) -> None: # TODO: client_api.ClientApiTests: more tests here... ? def test_non_existent_table_rejected(self) -> None: + device = DeviceFactory() + self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: device.name, TabletParam.OPERATION: Operations.WHICH_KEYS_TO_SEND, TabletParam.TABLE: "nonexistent_table", } @@ -184,7 +196,6 @@ def test_non_existent_table_rejected(self) -> None: self.assertEqual(d[TabletParam.SUCCESS], FAILURE_CODE) def test_client_api_validators(self) -> None: - self.announce("test_client_api_validators") for x in class_attribute_names(Operations): try: validate_alphanum_underscore(x, self.req) @@ -192,38 +203,26 @@ def test_client_api_validators(self) -> None: self.fail(f"Operations.{x} fails validate_alphanum_underscore") -class PatientRegistrationTests(BasicDatabaseTestCase): - def test_returns_patient_info(self) -> None: - import datetime +class PatientRegistrationTests(DemoRequestTestCase): + def setUp(self) -> None: + super().setUp() - patient = self.create_patient( - forename="JO", - surname="PATIENT", - dob=datetime.date(1958, 4, 19), - sex="F", - address="Address", - gp="GP", - other="Other", - as_server_patient=True, - ) + self.device = DeviceFactory() - self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER, - as_server_patient=True, - ) + def test_returns_patient_info(self) -> None: + patient = ServerCreatedPatientFactory() + idnum = ServerCreatedNHSPatientIdNumFactory(patient=patient) proquint = patient.uuid_as_proquint # For type checker assert proquint is not None - assert self.other_device.name is not None + assert self.device is not None self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: self.device.name, TabletParam.OPERATION: Operations.REGISTER_PATIENT, TabletParam.PATIENT_PROQUINT: proquint, } @@ -237,43 +236,34 @@ def test_returns_patient_info(self) -> None: patient_dict = json.loads(reply_dict[TabletParam.PATIENT_INFO])[0] - self.assertEqual(patient_dict[TabletParam.SURNAME], "PATIENT") - self.assertEqual(patient_dict[TabletParam.FORENAME], "JO") - self.assertEqual(patient_dict[TabletParam.SEX], "F") - self.assertEqual(patient_dict[TabletParam.DOB], "1958-04-19") - self.assertEqual(patient_dict[TabletParam.ADDRESS], "Address") - self.assertEqual(patient_dict[TabletParam.GP], "GP") - self.assertEqual(patient_dict[TabletParam.OTHER], "Other") + self.assertEqual(patient_dict[TabletParam.SURNAME], patient.surname) + self.assertEqual(patient_dict[TabletParam.FORENAME], patient.forename) + self.assertEqual(patient_dict[TabletParam.SEX], patient.sex) self.assertEqual( - patient_dict[f"idnum{self.nhs_iddef.which_idnum}"], TEST_NHS_NUMBER + patient_dict[TabletParam.DOB], patient.dob.isoformat() ) - - def test_creates_user(self) -> None: - from camcops_server.cc_modules.cc_taskindex import ( - PatientIdNumIndexEntry, + self.assertEqual(patient_dict[TabletParam.ADDRESS], patient.address) + self.assertEqual(patient_dict[TabletParam.GP], patient.gp) + self.assertEqual(patient_dict[TabletParam.OTHER], patient.other) + self.assertEqual( + patient_dict[f"idnum{idnum.which_idnum}"], idnum.idnum_value ) - patient = self.create_patient( - _group_id=self.group.id, as_server_patient=True - ) - idnum = self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER, - as_server_patient=True, - ) + def test_creates_user(self) -> None: + patient = ServerCreatedPatientFactory() + idnum = ServerCreatedNHSPatientIdNumFactory(patient=patient) PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession) proquint = patient.uuid_as_proquint # For type checker assert proquint is not None - assert self.other_device.name is not None + assert self.device.name is not None self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: self.device.name, TabletParam.OPERATION: Operations.REGISTER_PATIENT, TabletParam.PATIENT_PROQUINT: proquint, } @@ -288,9 +278,7 @@ def test_creates_user(self) -> None: username = reply_dict[TabletParam.USER] self.assertEqual( username, - make_single_user_mode_username( - self.other_device.name, patient._pk - ), + make_single_user_mode_username(self.device.name, patient._pk), ) password = reply_dict[TabletParam.PASSWORD] self.assertEqual(len(password), 32) @@ -310,36 +298,26 @@ def test_creates_user(self) -> None: self.assertTrue(user.may_upload) def test_does_not_create_user_when_name_exists(self) -> None: - from camcops_server.cc_modules.cc_taskindex import ( - PatientIdNumIndexEntry, - ) - - patient = self.create_patient( - _group_id=self.group.id, as_server_patient=True - ) - idnum = self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER, - as_server_patient=True, - ) + patient = ServerCreatedPatientFactory() + idnum = ServerCreatedNHSPatientIdNumFactory(patient=patient) PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession) proquint = patient.uuid_as_proquint - user = User( - username=make_single_user_mode_username( - self.other_device.name, patient._pk - ) + single_user_username = make_single_user_mode_username( + self.device.name, patient._pk + ) + + user = UserFactory( + username=single_user_username, + password="old password", + password__request=self.req, ) - user.set_password(self.req, "old password") - self.dbsession.add(user) - self.dbsession.commit() self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: self.device.name, TabletParam.OPERATION: Operations.REGISTER_PATIENT, TabletParam.PATIENT_PROQUINT: proquint, } @@ -354,9 +332,7 @@ def test_does_not_create_user_when_name_exists(self) -> None: username = reply_dict[TabletParam.USER] self.assertEqual( username, - make_single_user_mode_username( - self.other_device.name, patient._pk - ), + make_single_user_mode_username(self.device.name, patient._pk), ) password = reply_dict[TabletParam.PASSWORD] self.assertEqual(len(password), 32) @@ -377,12 +353,12 @@ def test_does_not_create_user_when_name_exists(self) -> None: def test_raises_for_invalid_proquint(self) -> None: # For type checker - assert self.other_device.name is not None + assert self.device.name is not None self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: self.device.name, TabletParam.OPERATION: Operations.REGISTER_PATIENT, TabletParam.PATIENT_PROQUINT: "invalid", } @@ -405,12 +381,12 @@ def test_raises_for_missing_valid_proquint(self) -> None: # test proquint really is valid (should not raise) uuid_from_proquint(valid_proquint) - assert self.other_device.name is not None + assert self.device.name is not None self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: self.device.name, TabletParam.OPERATION: Operations.REGISTER_PATIENT, TabletParam.PATIENT_PROQUINT: valid_proquint, } @@ -429,13 +405,13 @@ def test_raises_for_missing_valid_proquint(self) -> None: def test_raises_when_no_patient_idnums(self) -> None: # In theory this shouldn't be possible in normal operation as the # patient cannot be created without any idnums - patient = self.create_patient(as_server_patient=True) + patient = ServerCreatedPatientFactory() proquint = patient.uuid_as_proquint self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: self.device.name, TabletParam.OPERATION: Operations.REGISTER_PATIENT, TabletParam.PATIENT_PROQUINT: proquint, } @@ -451,15 +427,13 @@ def test_raises_when_no_patient_idnums(self) -> None: ) def test_raises_when_patient_not_created_on_server(self) -> None: - patient = self.create_patient( - _device_id=self.other_device.id, as_server_patient=True - ) + patient = PatientFactory() proquint = patient.uuid_as_proquint self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: self.device.name, TabletParam.OPERATION: Operations.REGISTER_PATIENT, TabletParam.PATIENT_PROQUINT: proquint, } @@ -476,49 +450,21 @@ def test_raises_when_patient_not_created_on_server(self) -> None: ) def test_returns_ip_use_flags(self) -> None: - import datetime - from camcops_server.cc_modules.cc_taskindex import ( - PatientIdNumIndexEntry, - ) - - patient = self.create_patient( - forename="JO", - surname="PATIENT", - dob=datetime.date(1958, 4, 19), - sex="F", - address="Address", - gp="GP", - other="Other", - as_server_patient=True, - ) - idnum = self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER, - as_server_patient=True, - ) + patient = ServerCreatedPatientFactory() + idnum = ServerCreatedNHSPatientIdNumFactory(patient=patient) PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession) - - patient.group.ip_use = IpUse() - - patient.group.ip_use.commercial = True - patient.group.ip_use.clinical = True - patient.group.ip_use.educational = False - patient.group.ip_use.research = False - - self.dbsession.add(patient.group) - self.dbsession.commit() + ip_use = patient.group.ip_use proquint = patient.uuid_as_proquint # For type checker assert proquint is not None - assert self.other_device.name is not None + assert self.device.name is not None self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: self.device.name, TabletParam.OPERATION: Operations.REGISTER_PATIENT, TabletParam.PATIENT_PROQUINT: proquint, } @@ -532,123 +478,113 @@ def test_returns_ip_use_flags(self) -> None: ip_use_info = json.loads(reply_dict[TabletParam.IP_USE_INFO]) - self.assertEqual(ip_use_info[TabletParam.IP_USE_COMMERCIAL], 1) - self.assertEqual(ip_use_info[TabletParam.IP_USE_CLINICAL], 1) - self.assertEqual(ip_use_info[TabletParam.IP_USE_EDUCATIONAL], 0) - self.assertEqual(ip_use_info[TabletParam.IP_USE_RESEARCH], 0) + self.assertEqual( + ip_use_info[TabletParam.IP_USE_COMMERCIAL], ip_use.commercial + ) + self.assertEqual( + ip_use_info[TabletParam.IP_USE_CLINICAL], ip_use.clinical + ) + self.assertEqual( + ip_use_info[TabletParam.IP_USE_EDUCATIONAL], ip_use.educational + ) + self.assertEqual( + ip_use_info[TabletParam.IP_USE_RESEARCH], ip_use.research + ) + + +class GetTaskSchedulesTests(DemoRequestTestCase): + def setUp(self) -> None: + super().setUp() + + self.group = GroupFactory() + user = self.req._debugging_user = UserFactory( + upload_group_id=self.group.id, + ) + UserGroupMembershipFactory( + user_id=user.id, + group_id=self.group.id, + may_register_devices=True, + ) -class GetTaskSchedulesTests(BasicDatabaseTestCase): def test_returns_task_schedules(self) -> None: - from pendulum import DateTime as Pendulum, Duration, local, parse - - from camcops_server.cc_modules.cc_taskindex import ( - PatientIdNumIndexEntry, - TaskIndexEntry, - ) - from camcops_server.cc_modules.cc_taskschedule import ( - PatientTaskSchedule, - TaskSchedule, - TaskScheduleItem, - ) - from camcops_server.tasks.bmi import Bmi - - schedule1 = TaskSchedule() - schedule1.group_id = self.group.id - schedule1.name = "Test 1" - self.dbsession.add(schedule1) - - schedule2 = TaskSchedule() - schedule2.group_id = self.group.id - self.dbsession.add(schedule2) - self.dbsession.commit() - - item1 = TaskScheduleItem() - item1.schedule_id = schedule1.id - item1.task_table_name = "phq9" - item1.due_from = Duration(days=0) - item1.due_by = Duration(days=7) - self.dbsession.add(item1) - - item2 = TaskScheduleItem() - item2.schedule_id = schedule1.id - item2.task_table_name = "bmi" - item2.due_from = Duration(days=0) - item2.due_by = Duration(days=8) - self.dbsession.add(item2) - - item3 = TaskScheduleItem() - item3.schedule_id = schedule1.id - item3.task_table_name = "phq9" - item3.due_from = Duration(days=30) - item3.due_by = Duration(days=37) - self.dbsession.add(item3) - - item4 = TaskScheduleItem() - item4.schedule_id = schedule1.id - item4.task_table_name = "gmcpq" - item4.due_from = Duration(days=30) - item4.due_by = Duration(days=38) - self.dbsession.add(item4) - self.dbsession.commit() - - patient = self.create_patient() - idnum = self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER, + schedule1 = TaskScheduleFactory(group=self.group) + schedule2 = TaskScheduleFactory(group=self.group) + + TaskScheduleItemFactory( + task_schedule=schedule1, + task_table_name="phq9", + due_from=Duration(days=0), + due_by=Duration(days=7), + ) + TaskScheduleItemFactory( + task_schedule=schedule1, + task_table_name="bmi", + due_from=Duration(days=0), + due_by=Duration(days=8), + ) + TaskScheduleItemFactory( + task_schedule=schedule1, + task_table_name="phq9", + due_from=Duration(days=30), + due_by=Duration(days=37), + ) + TaskScheduleItemFactory( + task_schedule=schedule1, + task_table_name="gmcpq", + due_from=Duration(days=30), + due_by=Duration(days=38), + ) + + # This is the patient originally created om the server + server_patient = ServerCreatedPatientFactory(_group=self.group) + server_idnum = ServerCreatedNHSPatientIdNumFactory( + patient=server_patient + ) + + # This is the same patient but from the device + patient = PatientFactory(_group=self.group) + idnum = NHSPatientIdNumFactory( + patient=patient, + which_idnum=server_idnum.which_idnum, + idnum_value=server_idnum.idnum_value, ) PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession) - server_patient = self.create_patient(as_server_patient=True) - _ = self.create_patient_idnum( - patient_id=server_patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER, - as_server_patient=True, + PatientTaskScheduleFactory( + patient=server_patient, + task_schedule=schedule1, + settings={ + "bmi": {"bmi_key": "bmi_value"}, + "phq9": {"phq9_key": "phq9_value"}, + }, + start_datetime=local(2020, 7, 31), ) - schedule_1 = PatientTaskSchedule() - schedule_1.patient_pk = server_patient.pk - schedule_1.schedule_id = schedule1.id - schedule_1.settings = { - "bmi": {"bmi_key": "bmi_value"}, - "phq9": {"phq9_key": "phq9_value"}, - } - schedule_1.start_datetime = local(2020, 7, 31) - self.dbsession.add(schedule_1) - - schedule_2 = PatientTaskSchedule() - schedule_2.patient_pk = server_patient.pk - schedule_2.schedule_id = schedule2.id - self.dbsession.add(schedule_2) - - bmi = Bmi() - self.apply_standard_task_fields(bmi) - bmi.id = 1 - bmi.height_m = 1.83 - bmi.mass_kg = 67.57 - bmi.patient_id = patient.id - bmi.when_created = local(2020, 8, 1) - self.dbsession.add(bmi) - self.dbsession.commit() + PatientTaskScheduleFactory( + patient=server_patient, + task_schedule=schedule2, + ) + + bmi = BmiFactory( + patient=patient, + when_created=local(2020, 8, 1), + ) self.assertTrue(bmi.is_complete()) TaskIndexEntry.index_task( bmi, self.dbsession, indexed_at_utc=Pendulum.utcnow() ) - self.dbsession.commit() proquint = server_patient.uuid_as_proquint # For type checker assert proquint is not None - assert self.other_device.name is not None self.req.fake_request_post_from_dict( { TabletParam.CAMCOPS_VERSION: MINIMUM_TABLET_VERSION, - TabletParam.DEVICE: self.other_device.name, + TabletParam.DEVICE: patient._device.name, TabletParam.OPERATION: Operations.GET_TASK_SCHEDULES, TabletParam.PATIENT_PROQUINT: proquint, } @@ -665,7 +601,7 @@ def test_returns_task_schedules(self) -> None: self.assertEqual(len(task_schedules), 2) s = task_schedules[0] - self.assertEqual(s[TabletParam.TASK_SCHEDULE_NAME], "Test 1") + self.assertEqual(s[TabletParam.TASK_SCHEDULE_NAME], schedule1.name) schedule_items = s[TabletParam.TASK_SCHEDULE_ITEMS] self.assertEqual(len(schedule_items), 4) @@ -715,3 +651,74 @@ def test_returns_task_schedules(self) -> None: # GMCPQ gmcpq_sched = schedule_items[3] self.assertTrue(gmcpq_sched[TabletParam.ANONYMOUS]) + + +class GetOrCreateSingleUserTests(DemoRequestTestCase): + def setUp(self) -> None: + super().setUp() + + self.patient = PatientFactory() + self.req._debugging_user = UserFactory() + + def test_user_is_added_to_patient_group(self) -> None: + user, _ = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + self.assertIn(self.patient.group.id, user.group_ids) + + def test_user_is_created_with_username(self) -> None: + user, _ = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + self.assertEqual(user.username, "test") + + def test_user_is_assigned_password(self) -> None: + _, password = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + valid_chars = string.ascii_letters + string.digits + string.punctuation + self.assertTrue(all(c in valid_chars for c in password)) + + def test_user_upload_group_set(self) -> None: + user, _ = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + self.assertEqual(user.upload_group, self.patient.group) + + def test_user_auto_generated_flag_set(self) -> None: + user, _ = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + self.assertTrue(user.auto_generated) + + def test_user_is_not_superuser(self) -> None: + user, _ = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + self.assertFalse(user.superuser) + + def test_single_patient_pk_set(self) -> None: + user, _ = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + self.assertEqual(user.single_patient_pk, self.patient._pk) + + def test_user_may_register_devices(self) -> None: + user, _ = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + self.assertTrue(user.user_group_memberships[0].may_register_devices) + + def test_user_may_upload(self) -> None: + user, _ = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + self.assertTrue(user.user_group_memberships[0].may_upload) + + def test_existing_user_is_updated(self) -> None: + existing_user = UserFactory(username="test") + + user, _ = get_or_create_single_user(self.req, "test", self.patient) + self.dbsession.flush() + + self.assertEqual(user, existing_user) diff --git a/server/camcops_server/cc_modules/tests/webview_tests.py b/server/camcops_server/cc_modules/tests/webview_tests.py index fb4bdeb56..0c4fdd23c 100644 --- a/server/camcops_server/cc_modules/tests/webview_tests.py +++ b/server/camcops_server/cc_modules/tests/webview_tests.py @@ -35,8 +35,7 @@ from cardinal_pythonlib.classes import class_attribute_names from cardinal_pythonlib.httpconst import MimeType -from cardinal_pythonlib.nhs import generate_random_nhs_number -from pendulum import local +from pendulum import Duration, local import phonenumbers import pyotp from pyramid.httpexceptions import HTTPBadRequest, HTTPFound @@ -49,7 +48,6 @@ ) from camcops_server.cc_modules.cc_device import Device from camcops_server.cc_modules.cc_group import Group -from camcops_server.cc_modules.cc_membership import UserGroupMembership from camcops_server.cc_modules.cc_patient import Patient from camcops_server.cc_modules.cc_patientidnum import PatientIdNum from camcops_server.cc_modules.cc_pyramid import ( @@ -66,9 +64,27 @@ TaskSchedule, TaskScheduleItem, ) +from camcops_server.cc_modules.cc_testfactories import ( + AnyIdNumGroupFactory, + Fake, + GroupFactory, + NHSIdNumDefinitionFactory, + NHSPatientIdNumFactory, + PatientFactory, + PatientTaskScheduleFactory, + RioIdNumDefinitionFactory, + ServerCreatedNHSPatientIdNumFactory, + ServerCreatedPatientFactory, + StudyPatientIdNumFactory, + TaskScheduleFactory, + TaskScheduleItemFactory, + UserFactory, + UserGroupMembershipFactory, +) from camcops_server.cc_modules.cc_unittest import ( BasicDatabaseTestCase, DemoDatabaseTestCase, + DemoRequestTestCase, ) from camcops_server.cc_modules.cc_user import ( SecurityAccountLockout, @@ -79,11 +95,13 @@ validate_alphanum_underscore, ) from camcops_server.cc_modules.cc_view_classes import FormWizardMixin +from camcops_server.tasks.tests.factories import BmiFactory from camcops_server.cc_modules.tests.cc_view_classes_tests import ( TestStateMixin, ) from camcops_server.cc_modules.webview import ( add_patient, + add_user, AddPatientView, AddTaskScheduleItemView, AddTaskScheduleView, @@ -123,30 +141,13 @@ UTF8 = "utf-8" -TEST_NHS_NUMBER_1 = generate_random_nhs_number() -TEST_NHS_NUMBER_2 = generate_random_nhs_number() - -# https://www.ofcom.org.uk/phones-telecoms-and-internet/information-for-industry/numbering/numbers-for-drama # noqa: E501 -# 07700 900000 to 900999 reserved for TV and Radio drama purposes -# but unfortunately phonenumbers considers these invalid. However, it offers -# some examples: -TEST_PHONE_NUMBER = "+{ctry}{tel}".format( - ctry=phonenumbers.PhoneMetadata.metadata_for_region("GB").country_code, - tel=phonenumbers.PhoneMetadata.metadata_for_region( - "GB" - ).personal_number.example_number, -) - class WebviewTests(DemoDatabaseTestCase): - """ - Unit tests. - """ - def test_any_records_use_group_true(self) -> None: # All tasks created in DemoDatabaseTestCase will be in this group - self.announce("test_any_records_use_group_true") - self.assertTrue(any_records_use_group(self.req, self.group)) + self.assertTrue( + any_records_use_group(self.req, self.demo_database_group) + ) def test_any_records_use_group_false(self) -> None: """ @@ -156,15 +157,11 @@ def test_any_records_use_group_false(self) -> None: then the base class probably needs to be declared __abstract__. See DiagnosisItemBase as an example. """ - self.announce("test_any_records_use_group_false") - group = Group() - self.dbsession.add(self.group) - self.dbsession.commit() + group = GroupFactory() self.assertFalse(any_records_use_group(self.req, group)) def test_webview_constant_validators(self) -> None: - self.announce("test_webview_constant_validators") for x in class_attribute_names(ViewArg): try: validate_alphanum_underscore(x, self.req) @@ -172,11 +169,7 @@ def test_webview_constant_validators(self) -> None: self.fail(f"Operations.{x} fails validate_alphanum_underscore") -class AddTaskScheduleViewTests(DemoDatabaseTestCase): - """ - Unit tests. - """ - +class AddTaskScheduleViewTests(BasicDatabaseTestCase): def test_schedule_form_displayed(self) -> None: view = AddTaskScheduleView(self.req) @@ -222,35 +215,29 @@ def test_schedule_is_created(self) -> None: ) -class EditTaskScheduleViewTests(DemoDatabaseTestCase): - """ - Unit tests. - """ - - def setUp(self) -> None: - super().setUp() - - self.schedule = TaskSchedule() - self.schedule.group_id = self.group.id - self.schedule.name = "Test" - self.dbsession.add(self.schedule) - self.dbsession.commit() - +class EditTaskScheduleViewTests(DemoRequestTestCase): def test_schedule_name_can_be_updated(self) -> None: + user = self.req._debugging_user = UserFactory() + group = GroupFactory() + UserGroupMembershipFactory( + group_id=group.id, user_id=user.id, groupadmin=True + ) + + schedule = TaskScheduleFactory(group=group) multidict = MultiDict( [ ("_charset_", UTF8), ("__formid__", "deform"), (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()), (ViewParam.NAME, "MOJO"), - (ViewParam.GROUP_ID, self.group.id), + (ViewParam.GROUP_ID, group.id), (FormAction.SUBMIT, "submit"), ] ) self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.SCHEDULE_ID: str(self.schedule.id)}, + {ViewParam.SCHEDULE_ID: str(schedule.id)}, set_method_get=False, ) @@ -269,28 +256,16 @@ def test_schedule_name_can_be_updated(self) -> None: ) def test_group_a_schedule_cannot_be_edited_by_group_b_admin(self) -> None: - group_a = Group() - group_a.name = "Group A" - self.dbsession.add(group_a) + group_a = GroupFactory() + group_b = GroupFactory() - group_b = Group() - group_b.name = "Group B" - self.dbsession.add(group_b) - self.dbsession.commit() - - group_a_schedule = TaskSchedule() - group_a_schedule.group_id = group_a.id - group_a_schedule.name = "Group A schedule" - self.dbsession.add(group_a_schedule) - self.dbsession.commit() + group_a_schedule = TaskScheduleFactory(group=group_a) - self.user = User() - self.user.upload_group_id = group_b.id - self.user.username = "group b admin" - self.user.set_password(self.req, "secret123") - self.dbsession.add(self.user) - self.dbsession.commit() - self.req._debugging_user = self.user + group_b_user = UserFactory() + UserGroupMembershipFactory( + group_id=group_b.id, user_id=group_b_user.id, groupadmin=True + ) + self.req._debugging_user = group_b_user multidict = MultiDict( [ @@ -298,14 +273,14 @@ def test_group_a_schedule_cannot_be_edited_by_group_b_admin(self) -> None: ("__formid__", "deform"), (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()), (ViewParam.NAME, "Something else"), - (ViewParam.GROUP_ID, self.group.id), + (ViewParam.GROUP_ID, group_b.id), (FormAction.SUBMIT, "submit"), ] ) self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.SCHEDULE_ID: str(self.schedule.id)}, + {ViewParam.SCHEDULE_ID: str(group_a_schedule.id)}, set_method_get=False, ) @@ -317,21 +292,16 @@ def test_group_a_schedule_cannot_be_edited_by_group_b_admin(self) -> None: self.assertIn("not a group administrator", cm.exception.message) -class DeleteTaskScheduleViewTests(DemoDatabaseTestCase): - """ - Unit tests. - """ - - def setUp(self) -> None: - super().setUp() - - self.schedule = TaskSchedule() - self.schedule.group_id = self.group.id - self.schedule.name = "Test" - self.dbsession.add(self.schedule) - self.dbsession.commit() - +class DeleteTaskScheduleViewTests(DemoRequestTestCase): def test_schedule_item_is_deleted(self) -> None: + user = self.req._debugging_user = UserFactory() + group = GroupFactory() + UserGroupMembershipFactory( + group_id=group.id, user_id=user.id, groupadmin=True + ) + schedule = TaskScheduleFactory(group=group) + self.assertIsNotNone(self.dbsession.query(TaskSchedule).one_or_none()) + multidict = MultiDict( [ ("_charset_", UTF8), @@ -352,7 +322,7 @@ def test_schedule_item_is_deleted(self) -> None: self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.SCHEDULE_ID: str(self.schedule.id)}, + {ViewParam.SCHEDULE_ID: str(schedule.id)}, set_method_get=False, ) view = DeleteTaskScheduleView(self.req) @@ -365,25 +335,14 @@ def test_schedule_item_is_deleted(self) -> None: Routes.VIEW_TASK_SCHEDULES, e.exception.headers["Location"] ) - item = self.dbsession.query(TaskScheduleItem).one_or_none() - - self.assertIsNone(item) - + self.assertIsNone(self.dbsession.query(TaskSchedule).one_or_none()) -class AddTaskScheduleItemViewTests(DemoDatabaseTestCase): - """ - Unit tests. - """ +class AddTaskScheduleItemViewTests(BasicDatabaseTestCase): def setUp(self) -> None: super().setUp() - self.schedule = TaskSchedule() - self.schedule.group_id = self.group.id - self.schedule.name = "Test" - - self.dbsession.add(self.schedule) - self.dbsession.commit() + self.schedule = TaskScheduleFactory(group=self.group) def test_schedule_item_form_displayed(self) -> None: view = AddTaskScheduleItemView(self.req) @@ -480,31 +439,19 @@ def test_non_existent_schedule_handled(self) -> None: view.dispatch() -class EditTaskScheduleItemViewTests(DemoDatabaseTestCase): - """ - Unit tests. - """ - +class EditTaskScheduleItemViewTests(BasicDatabaseTestCase): def setUp(self) -> None: - from pendulum import Duration - super().setUp() - - self.schedule = TaskSchedule() - self.schedule.group_id = self.group.id - self.schedule.name = "Test" - self.dbsession.add(self.schedule) - self.dbsession.commit() - - self.item = TaskScheduleItem() - self.item.schedule_id = self.schedule.id - self.item.task_table_name = "ace3" - self.item.due_from = Duration(days=30) - self.item.due_by = Duration(days=60) - self.dbsession.add(self.item) - self.dbsession.commit() + self.schedule = TaskScheduleFactory(group=self.group) def test_schedule_item_is_updated(self) -> None: + item = TaskScheduleItemFactory( + task_schedule=self.schedule, + task_table_name="ace3", + due_from=Duration(days=30), + due_by=Duration(days=60), + ) + multidict = MultiDict( [ ("_charset_", UTF8), @@ -529,7 +476,7 @@ def test_schedule_item_is_updated(self) -> None: self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.SCHEDULE_ITEM_ID: str(self.item.id)}, + {ViewParam.SCHEDULE_ITEM_ID: str(item.id)}, set_method_get=False, ) view = EditTaskScheduleItemView(self.req) @@ -537,15 +484,22 @@ def test_schedule_item_is_updated(self) -> None: with self.assertRaises(HTTPFound) as cm: view.dispatch() - self.assertEqual(self.item.task_table_name, "bmi") + self.assertEqual(item.task_table_name, "bmi") self.assertEqual(cm.exception.status_code, 302) self.assertIn( f"{Routes.VIEW_TASK_SCHEDULE_ITEMS}" - f"?{ViewParam.SCHEDULE_ID}={self.item.schedule_id}", + f"?{ViewParam.SCHEDULE_ID}={item.schedule_id}", cm.exception.headers["Location"], ) def test_schedule_item_is_not_updated_on_cancel(self) -> None: + item = TaskScheduleItemFactory( + task_schedule=self.schedule, + task_table_name="ace3", + due_from=Duration(days=30), + due_by=Duration(days=60), + ) + multidict = MultiDict( [ ("_charset_", UTF8), @@ -570,7 +524,7 @@ def test_schedule_item_is_not_updated_on_cancel(self) -> None: self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.SCHEDULE_ITEM_ID: str(self.item.id)}, + {ViewParam.SCHEDULE_ITEM_ID: str(item.id)}, set_method_get=False, ) view = EditTaskScheduleItemView(self.req) @@ -578,7 +532,7 @@ def test_schedule_item_is_not_updated_on_cancel(self) -> None: with self.assertRaises(HTTPFound): view.dispatch() - self.assertEqual(self.item.task_table_name, "ace3") + self.assertEqual(item.task_table_name, "ace3") def test_non_existent_item_handled(self) -> None: self.req.add_get_params({ViewParam.SCHEDULE_ITEM_ID: "99999"}) @@ -595,53 +549,37 @@ def test_null_item_handled(self) -> None: view.dispatch() def test_get_form_values(self) -> None: + item = TaskScheduleItemFactory( + task_schedule=self.schedule, + task_table_name="ace3", + due_from=Duration(days=30), + due_by=Duration(days=60), + ) view = EditTaskScheduleItemView(self.req) - view.object = self.item + view.object = item form_values = view.get_form_values() self.assertEqual(form_values[ViewParam.SCHEDULE_ID], self.schedule.id) self.assertEqual( - form_values[ViewParam.TABLE_NAME], self.item.task_table_name + form_values[ViewParam.TABLE_NAME], item.task_table_name ) - self.assertEqual(form_values[ViewParam.DUE_FROM], self.item.due_from) + self.assertEqual(form_values[ViewParam.DUE_FROM], item.due_from) - due_within = self.item.due_by - self.item.due_from + due_within = item.due_by - item.due_from self.assertEqual(form_values[ViewParam.DUE_WITHIN], due_within) def test_group_a_item_cannot_be_edited_by_group_b_admin(self) -> None: - from pendulum import Duration + group_a = GroupFactory() + group_b = GroupFactory() - group_a = Group() - group_a.name = "Group A" - self.dbsession.add(group_a) - - group_b = Group() - group_b.name = "Group B" - self.dbsession.add(group_b) - self.dbsession.commit() - - group_a_schedule = TaskSchedule() - group_a_schedule.group_id = group_a.id - group_a_schedule.name = "Group A schedule" - self.dbsession.add(group_a_schedule) - self.dbsession.commit() - - group_a_item = TaskScheduleItem() - group_a_item.schedule_id = group_a_schedule.id - group_a_item.task_table_name = "ace3" - group_a_item.due_from = Duration(days=30) - group_a_item.due_by = Duration(days=60) - self.dbsession.add(group_a_item) - self.dbsession.commit() + group_b_admin = self.req._debugging_user = UserFactory() + UserGroupMembershipFactory( + group_id=group_b.id, user_id=group_b_admin.id, groupadmin=True + ) - self.user = User() - self.user.upload_group_id = group_b.id - self.user.username = "group b admin" - self.user.set_password(self.req, "secret123") - self.dbsession.add(self.user) - self.dbsession.commit() - self.req._debugging_user = self.user + group_a_schedule = TaskScheduleFactory(group=group_a) + group_a_item = TaskScheduleItemFactory(task_schedule=group_a_schedule) view = EditTaskScheduleItemView(self.req) view.object = group_a_item @@ -652,25 +590,16 @@ def test_group_a_item_cannot_be_edited_by_group_b_admin(self) -> None: self.assertIn("not a group administrator", cm.exception.message) -class DeleteTaskScheduleItemViewTests(DemoDatabaseTestCase): - """ - Unit tests. - """ - +class DeleteTaskScheduleItemViewTests(BasicDatabaseTestCase): def setUp(self) -> None: super().setUp() - self.schedule = TaskSchedule() - self.schedule.group_id = self.group.id - self.schedule.name = "Test" - self.dbsession.add(self.schedule) - self.dbsession.commit() + self.schedule = TaskScheduleFactory(group=self.group) - self.item = TaskScheduleItem() - self.item.schedule_id = self.schedule.id - self.item.task_table_name = "ace3" - self.dbsession.add(self.item) - self.dbsession.commit() + self.schedule = TaskScheduleFactory() + self.item = TaskScheduleItemFactory( + task_schedule=self.schedule, task_table_name="ace3" + ) def test_delete_form_displayed(self) -> None: view = DeleteTaskScheduleItemView(self.req) @@ -754,10 +683,19 @@ def test_schedule_item_not_deleted_on_cancel(self) -> None: self.assertIsNotNone(item) -class EditFinalizedPatientViewTests(BasicDatabaseTestCase): - """ - Unit tests. - """ +class EditFinalizedPatientViewTests(DemoRequestTestCase): + def setUp(self) -> None: + super().setUp() + + self.group = AnyIdNumGroupFactory() + user = self.req._debugging_user = UserFactory() + + UserGroupMembershipFactory( + group_id=self.group.id, + user_id=user.id, + groupadmin=True, + view_all_patients_when_unfiltered=True, + ) def test_raises_when_patient_does_not_exists(self) -> None: with self.assertRaises(HTTPBadRequest) as cm: @@ -769,7 +707,7 @@ def test_raises_when_patient_does_not_exists(self) -> None: @unittest.skip("Can't save patient in database without group") def test_raises_when_patient_not_in_a_group(self) -> None: - patient = self.create_patient(_group_id=None) + patient = PatientFactory(_group=None) self.req.add_get_params({ViewParam.SERVER_PK: str(patient.pk)}) @@ -779,9 +717,7 @@ def test_raises_when_patient_not_in_a_group(self) -> None: self.assertEqual(str(cm.exception), "Bad patient: not in a group") def test_raises_when_not_authorized(self) -> None: - patient = self.create_patient() - - self.req._debugging_user = User() + patient = PatientFactory() with mock.patch.object( self.req._debugging_user, @@ -798,11 +734,7 @@ def test_raises_when_not_authorized(self) -> None: ) def test_raises_when_patient_not_finalized(self) -> None: - device = Device(name="Not the server device") - self.req.dbsession.add(device) - self.req.dbsession.commit() - - patient = self.create_patient(id=1, _device_id=device.id, _era=ERA_NOW) + patient = PatientFactory(_era=ERA_NOW, _group=self.group) self.req.add_get_params({ViewParam.SERVER_PK: str(patient.pk)}) @@ -812,12 +744,23 @@ def test_raises_when_patient_not_finalized(self) -> None: self.assertIn("Patient is not editable", str(cm.exception)) def test_patient_updated(self) -> None: - patient = self.create_patient() + patient = PatientFactory(_group=self.group) + nhs_patient_idnum = NHSPatientIdNumFactory(patient=patient) self.req.add_get_params( {ViewParam.SERVER_PK: str(patient.pk)}, set_method_get=False ) + new_sex = Fake.en_gb.sex() + new_forename = Fake.en_gb.forename(new_sex) + new_surname = Fake.en_gb.last_name() + new_address = Fake.en_gb.address() + new_email = Fake.en_gb.email() + new_gp = Fake.en_gb.name() + new_other = Fake.en_us.paragraph() + new_dob = Fake.en_gb.consistent_date_of_birth() + new_nhs_number = Fake.en_gb.nhs_number() + multidict = MultiDict( [ ("_charset_", UTF8), @@ -825,22 +768,22 @@ def test_patient_updated(self) -> None: (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()), (ViewParam.SERVER_PK, str(patient.pk)), (ViewParam.GROUP_ID, str(patient.group.id)), - (ViewParam.FORENAME, "Jo"), - (ViewParam.SURNAME, "Patient"), + (ViewParam.FORENAME, new_forename), + (ViewParam.SURNAME, new_surname), ("__start__", "dob:mapping"), - ("date", "1958-04-19"), + ("date", new_dob), ("__end__", "dob:mapping"), ("__start__", "sex:rename"), - ("deformField7", "X"), + ("deformField7", new_sex), ("__end__", "sex:rename"), - (ViewParam.ADDRESS, "New address"), - (ViewParam.EMAIL, "newjopatient@example.com"), - (ViewParam.GP, "New GP"), - (ViewParam.OTHER, "New other"), + (ViewParam.ADDRESS, new_address), + (ViewParam.EMAIL, new_email), + (ViewParam.GP, new_gp), + (ViewParam.OTHER, new_other), ("__start__", "id_references:sequence"), ("__start__", "idnum_sequence:mapping"), - (ViewParam.WHICH_IDNUM, self.nhs_iddef.which_idnum), - (ViewParam.IDNUM_VALUE, str(TEST_NHS_NUMBER_1)), + (ViewParam.WHICH_IDNUM, nhs_patient_idnum.which_idnum), + (ViewParam.IDNUM_VALUE, new_nhs_number), ("__end__", "idnum_sequence:mapping"), ("__end__", "id_references:sequence"), ("__start__", "danger:mapping"), @@ -858,32 +801,32 @@ def test_patient_updated(self) -> None: self.dbsession.commit() - self.assertEqual(patient.forename, "Jo") - self.assertEqual(patient.surname, "Patient") - self.assertEqual(patient.dob.isoformat(), "1958-04-19") - self.assertEqual(patient.sex, "X") - self.assertEqual(patient.address, "New address") - self.assertEqual(patient.email, "newjopatient@example.com") - self.assertEqual(patient.gp, "New GP") - self.assertEqual(patient.other, "New other") + self.assertEqual(patient.forename, new_forename) + self.assertEqual(patient.surname, new_surname) + self.assertEqual(patient.dob, new_dob) + self.assertEqual(patient.sex, new_sex) + self.assertEqual(patient.address, new_address) + self.assertEqual(patient.email, new_email) + self.assertEqual(patient.gp, new_gp) + self.assertEqual(patient.other, new_other) idnum = patient.get_idnum_objects()[0] self.assertEqual(idnum.patient_id, patient.id) - self.assertEqual(idnum.which_idnum, self.nhs_iddef.which_idnum) - self.assertEqual(idnum.idnum_value, TEST_NHS_NUMBER_1) + self.assertEqual(idnum.which_idnum, nhs_patient_idnum.which_idnum) + self.assertEqual(idnum.idnum_value, new_nhs_number) self.assertEqual(len(patient.special_notes), 1) note = patient.special_notes[0].note self.assertIn("Patient details edited", note) self.assertIn("forename", note) - self.assertIn("Jo", note) + self.assertIn(new_forename, note) self.assertIn("surname", note) - self.assertIn("Patient", note) + self.assertIn(new_surname, note) - self.assertIn("idnum1", note) - self.assertIn(str(TEST_NHS_NUMBER_1), note) + self.assertIn(f"idnum{nhs_patient_idnum.which_idnum}", note) + self.assertIn(str(new_nhs_number), note) messages = self.req.session.peek_flash(FlashQueue.SUCCESS) @@ -891,46 +834,20 @@ def test_patient_updated(self) -> None: f"Amended patient record with server PK {patient.pk}", messages[0] ) self.assertIn("forename", messages[0]) - self.assertIn("Jo", messages[0]) + self.assertIn(new_forename, messages[0]) self.assertIn("surname", messages[0]) - self.assertIn("Patient", messages[0]) + self.assertIn(new_surname, messages[0]) self.assertIn("idnum1", messages[0]) - self.assertIn(str(TEST_NHS_NUMBER_1), messages[0]) + self.assertIn(str(new_nhs_number), messages[0]) def test_message_when_no_changes(self) -> None: - patient = self.create_patient( - forename="Jo", - surname="Patient", - dob=datetime.date(1958, 4, 19), - sex="F", - address="Address", - gp="GP", - other="Other", - ) - patient_idnum = self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER_1, - ) - schedule1 = TaskSchedule() - schedule1.group_id = self.group.id - schedule1.name = "Test 1" - self.dbsession.add(schedule1) - self.dbsession.commit() + patient = PatientFactory(_group=self.group) - patient_task_schedule = PatientTaskSchedule() - patient_task_schedule.patient_pk = patient.pk - patient_task_schedule.schedule_id = schedule1.id - patient_task_schedule.start_datetime = local(2020, 6, 12, 9) - patient_task_schedule.settings = { - "name 1": "value 1", - "name 2": "value 2", - "name 3": "value 3", - } - - self.dbsession.add(patient_task_schedule) + patient_idnum = NHSPatientIdNumFactory( + patient=patient, + ) self.req.add_get_params( {ViewParam.SERVER_PK: str(patient.pk)}, set_method_get=False ) @@ -951,6 +868,7 @@ def test_message_when_no_changes(self) -> None: ("deformField7", patient.sex), ("__end__", "sex:rename"), (ViewParam.ADDRESS, patient.address), + (ViewParam.EMAIL, patient.email), (ViewParam.GP, patient.gp), (ViewParam.OTHER, patient.other), ("__start__", "id_references:sequence"), @@ -963,25 +881,6 @@ def test_message_when_no_changes(self) -> None: ("target", "7836"), ("user_entry", "7836"), ("__end__", "danger:mapping"), - ("__start__", "task_schedules:sequence"), - ("__start__", "task_schedule_sequence:mapping"), - ("schedule_id", schedule1.id), - ("__start__", "start_datetime:mapping"), - ("date", "2020-06-12"), - ("time", "09:00:00"), - ("__end__", "start_datetime:mapping"), - ( - "settings", - json.dumps( - { - "name 1": "value 1", - "name 2": "value 2", - "name 3": "value 3", - } - ), - ), - ("__end__", "task_schedule_sequence:mapping"), - ("__end__", "task_schedules:sequence"), (FormAction.SUBMIT, "submit"), ] ) @@ -996,44 +895,11 @@ def test_message_when_no_changes(self) -> None: self.assertIn("No changes required", messages[0]) def test_template_rendered_with_values(self) -> None: - patient = self.create_patient( - id=1, - forename="Jo", - surname="Patient", - dob=datetime.date(1958, 4, 19), - sex="F", - address="Address", - gp="GP", - other="Other", - ) - self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER_1, - ) - - from camcops_server.tasks import Bmi - - task1 = Bmi() - task1.id = 1 - task1._device_id = patient.device_id - task1._group_id = patient.group_id - task1._era = patient.era - task1.patient_id = patient.id - task1.when_created = self.era_time - task1._current = False - self.dbsession.add(task1) - - task2 = Bmi() - task2.id = 2 - task2._device_id = patient.device_id - task2._group_id = patient.group_id - task2._era = patient.era - task2.patient_id = patient.id - task2.when_created = self.era_time - task2._current = False - self.dbsession.add(task2) - self.dbsession.commit() + patient = PatientFactory(_group=self.group) + NHSPatientIdNumFactory(patient=patient) + + task1 = BmiFactory(patient=patient, _current=False) + task2 = BmiFactory(patient=patient, _current=False) self.req.add_get_params({ViewParam.SERVER_PK: str(patient.pk)}) @@ -1053,54 +919,48 @@ def test_template_rendered_with_values(self) -> None: def test_changes_to_simple_params(self) -> None: view = EditFinalizedPatientView(self.req) - patient = self.create_patient( - id=1, - forename="Jo", - surname="Patient", - dob=datetime.date(1958, 4, 19), - sex="F", - address="Address", - email="jopatient@example.com", - gp="GP", - other=None, - ) + patient = PatientFactory() + old_forename = patient.forename + old_surname = patient.surname + old_address = patient.address + new_forename = Fake.en_gb.forename(patient.sex) + new_surname = Fake.en_gb.last_name() + new_address = Fake.en_gb.address() + view.object = patient changes = OrderedDict() # type: OrderedDict appstruct = { - ViewParam.FORENAME: "Joanna", - ViewParam.SURNAME: "Patient-Patient", - ViewParam.DOB: datetime.date(1958, 4, 19), - ViewParam.ADDRESS: "New address", - ViewParam.OTHER: "", + ViewParam.FORENAME: new_forename, + ViewParam.SURNAME: new_surname, + ViewParam.DOB: patient.dob, + ViewParam.ADDRESS: new_address, + ViewParam.OTHER: patient.other, } view._save_simple_params(appstruct, changes) - self.assertEqual(changes[ViewParam.FORENAME], ("Jo", "Joanna")) self.assertEqual( - changes[ViewParam.SURNAME], ("Patient", "Patient-Patient") + changes[ViewParam.FORENAME], (old_forename, new_forename) + ) + self.assertEqual( + changes[ViewParam.SURNAME], (old_surname, new_surname) ) self.assertNotIn(ViewParam.DOB, changes) self.assertEqual( - changes[ViewParam.ADDRESS], ("Address", "New address") + changes[ViewParam.ADDRESS], (old_address, new_address) ) self.assertNotIn(ViewParam.OTHER, changes) def test_changes_to_idrefs(self) -> None: view = EditFinalizedPatientView(self.req) - patient = self.create_patient(id=1) - self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER_1, - ) - self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.study_iddef.which_idnum, - idnum_value=123, - ) + patient = PatientFactory() + nhs_patient_idnum = NHSPatientIdNumFactory(patient=patient) + study_patient_idnum = StudyPatientIdNumFactory(patient=patient) + rio_iddef = RioIdNumDefinitionFactory() + new_nhs_number = Fake.en_gb.nhs_number() + new_rio_number = 9999 # Below the range the factory would use view.object = patient @@ -1109,40 +969,43 @@ def test_changes_to_idrefs(self) -> None: appstruct = { ViewParam.ID_REFERENCES: [ { - ViewParam.WHICH_IDNUM: self.nhs_iddef.which_idnum, - ViewParam.IDNUM_VALUE: TEST_NHS_NUMBER_2, + ViewParam.WHICH_IDNUM: nhs_patient_idnum.which_idnum, + ViewParam.IDNUM_VALUE: new_nhs_number, }, { - ViewParam.WHICH_IDNUM: self.rio_iddef.which_idnum, - ViewParam.IDNUM_VALUE: 456, + ViewParam.WHICH_IDNUM: rio_iddef.which_idnum, + ViewParam.IDNUM_VALUE: new_rio_number, }, ] } view._save_idrefs(appstruct, changes) + nhs_key = f"idnum{nhs_patient_idnum.which_idnum} (NHS number)" + self.assertIn(nhs_key, changes) + + study_key = f"idnum{study_patient_idnum.which_idnum} (Study number)" + self.assertIn(study_key, changes) + + rio_key = f"idnum{rio_iddef.which_idnum} (RiO number)" + self.assertIn(rio_key, changes) + self.assertEqual( - changes["idnum1 (NHS number)"], - (TEST_NHS_NUMBER_1, TEST_NHS_NUMBER_2), + changes[nhs_key], + (nhs_patient_idnum.idnum_value, new_nhs_number), ) - self.assertEqual(changes["idnum3 (Study number)"], (123, None)) - self.assertEqual(changes["idnum2 (RiO number)"], (None, 456)) + self.assertEqual( + changes[study_key], (study_patient_idnum.idnum_value, None) + ) + self.assertEqual(changes[rio_key], (None, new_rio_number)) -class EditServerCreatedPatientViewTests(BasicDatabaseTestCase): - """ - Unit tests. - """ +class EditServerCreatedPatientViewTests(BasicDatabaseTestCase): def test_group_updated(self) -> None: - patient = self.create_patient(sex="F", as_server_patient=True) - new_group = Group() - new_group.name = "newgroup" - new_group.description = "New group" - new_group.upload_policy = "sex AND anyidnum" - new_group.finalize_policy = "sex AND idnum1" - self.dbsession.add(new_group) - self.dbsession.commit() + patient = ServerCreatedPatientFactory(_group=self.group) + old_group = patient.group + new_group = GroupFactory() view = EditServerCreatedPatientView(self.req) view.object = patient @@ -1155,12 +1018,12 @@ def test_group_updated(self) -> None: messages = self.req.session.peek_flash(FlashQueue.SUCCESS) - self.assertIn("testgroup", messages[0]) - self.assertIn("newgroup", messages[0]) + self.assertIn(old_group.name, messages[0]) + self.assertIn(new_group.name, messages[0]) self.assertIn("group:", messages[0]) def test_raises_when_not_created_on_the_server(self) -> None: - patient = self.create_patient(id=1, _device_id=self.other_device.id) + patient = PatientFactory() view = EditServerCreatedPatientView(self.req) @@ -1172,41 +1035,29 @@ def test_raises_when_not_created_on_the_server(self) -> None: self.assertIn("Patient is not editable", str(cm.exception)) def test_patient_task_schedules_updated(self) -> None: - patient = self.create_patient(sex="F", as_server_patient=True) - - schedule1 = TaskSchedule() - schedule1.group_id = self.group.id - schedule1.name = "Test 1" - self.dbsession.add(schedule1) - schedule2 = TaskSchedule() - schedule2.group_id = self.group.id - schedule2.name = "Test 2" - self.dbsession.add(schedule2) - schedule3 = TaskSchedule() - schedule3.group_id = self.group.id - schedule3.name = "Test 3" - self.dbsession.add(schedule3) - self.dbsession.commit() - - patient_task_schedule = PatientTaskSchedule() - patient_task_schedule.patient_pk = patient.pk - patient_task_schedule.schedule_id = schedule1.id - patient_task_schedule.start_datetime = local(2020, 6, 12, 9) - patient_task_schedule.settings = { - "name 1": "value 1", - "name 2": "value 2", - "name 3": "value 3", - } - - self.dbsession.add(patient_task_schedule) - - patient_task_schedule = PatientTaskSchedule() - patient_task_schedule.patient_pk = patient.pk - patient_task_schedule.schedule_id = schedule3.id - - self.dbsession.add(patient_task_schedule) - self.dbsession.commit() + patient = ServerCreatedPatientFactory() + nhs_patient_idnum = NHSPatientIdNumFactory(patient=patient) + group = patient._group + + schedule1 = TaskScheduleFactory(group=group) + schedule2 = TaskScheduleFactory(group=group) + schedule3 = TaskScheduleFactory(group=group) + + PatientTaskScheduleFactory( + patient=patient, + task_schedule=schedule1, + start_datetime=local(2020, 6, 12, 9), + settings={ + "name 1": "value 1", + "name 2": "value 2", + "name 3": "value 3", + }, + ) + PatientTaskScheduleFactory( + patient=patient, + task_schedule=schedule3, + ) self.req.add_get_params( {ViewParam.SERVER_PK: str(patient.pk)}, set_method_get=False ) @@ -1216,11 +1067,14 @@ def test_patient_task_schedules_updated(self) -> None: "name 2": "new value 2", "name 3": "new value 3", } + changed_schedule_1_datetime = local(2020, 6, 19, 8, 0, 0) new_schedule_2_settings = { "name 4": "value 4", "name 5": "value 5", "name 6": "value 6", } + new_schedule_2_datetime = local(2020, 7, 1, 13, 45, 0) + new_nhs_number = Fake.en_gb.nhs_number() multidict = MultiDict( [ ("_charset_", UTF8), @@ -1241,8 +1095,8 @@ def test_patient_task_schedules_updated(self) -> None: (ViewParam.OTHER, patient.other), ("__start__", "id_references:sequence"), ("__start__", "idnum_sequence:mapping"), - (ViewParam.WHICH_IDNUM, self.nhs_iddef.which_idnum), - (ViewParam.IDNUM_VALUE, str(TEST_NHS_NUMBER_1)), + (ViewParam.WHICH_IDNUM, nhs_patient_idnum.which_idnum), + (ViewParam.IDNUM_VALUE, str(new_nhs_number)), ("__end__", "idnum_sequence:mapping"), ("__end__", "id_references:sequence"), ("__start__", "danger:mapping"), @@ -1253,16 +1107,16 @@ def test_patient_task_schedules_updated(self) -> None: ("__start__", "task_schedule_sequence:mapping"), ("schedule_id", schedule1.id), ("__start__", "start_datetime:mapping"), - ("date", "2020-06-19"), - ("time", "08:00:00"), + ("date", changed_schedule_1_datetime.to_date_string()), + ("time", changed_schedule_1_datetime.to_time_string()), ("__end__", "start_datetime:mapping"), ("settings", json.dumps(changed_schedule_1_settings)), ("__end__", "task_schedule_sequence:mapping"), ("__start__", "task_schedule_sequence:mapping"), ("schedule_id", schedule2.id), ("__start__", "start_datetime:mapping"), - ("date", "2020-07-01"), - ("time", "13:45:00"), + ("date", new_schedule_2_datetime.to_date_string()), + ("time", new_schedule_2_datetime.to_time_string()), ("__end__", "start_datetime:mapping"), ("settings", json.dumps(new_schedule_2_settings)), ("__end__", "task_schedule_sequence:mapping"), @@ -1281,20 +1135,24 @@ def test_patient_task_schedules_updated(self) -> None: schedules = { pts.task_schedule.name: pts for pts in patient.task_schedules } - self.assertIn("Test 1", schedules) - self.assertIn("Test 2", schedules) - self.assertNotIn("Test 3", schedules) + self.assertIn(schedule1.name, schedules) + self.assertIn(schedule2.name, schedules) + self.assertNotIn(schedule3.name, schedules) self.assertEqual( - schedules["Test 1"].start_datetime, local(2020, 6, 19, 8) + schedules[schedule1.name].start_datetime, + changed_schedule_1_datetime, + ) + self.assertEqual( + schedules[schedule1.name].settings, changed_schedule_1_settings ) self.assertEqual( - schedules["Test 1"].settings, changed_schedule_1_settings + schedules[schedule2.name].start_datetime, + new_schedule_2_datetime, ) self.assertEqual( - schedules["Test 2"].start_datetime, local(2020, 7, 1, 13, 45) + schedules[schedule2.name].settings, new_schedule_2_settings ) - self.assertEqual(schedules["Test 2"].settings, new_schedule_2_settings) messages = self.req.session.peek_flash(FlashQueue.SUCCESS) @@ -1304,12 +1162,9 @@ def test_patient_task_schedules_updated(self) -> None: self.assertIn("Task schedules", messages[0]) def test_unprivileged_user_cannot_edit_patient(self) -> None: - patient = self.create_patient(sex="F", as_server_patient=True) + patient = ServerCreatedPatientFactory() - user = self.create_user(username="testuser") - self.dbsession.flush() - - self.req._debugging_user = user + self.req._debugging_user = UserFactory() view = EditServerCreatedPatientView(self.req) view.object = patient @@ -1324,20 +1179,15 @@ def test_unprivileged_user_cannot_edit_patient(self) -> None: ) def test_patient_can_be_assigned_the_same_schedule_twice(self) -> None: - patient = self.create_patient(sex="F", as_server_patient=True) - - schedule1 = TaskSchedule() - schedule1.group_id = self.group.id - schedule1.name = "Test 1" - self.dbsession.add(schedule1) - self.dbsession.flush() - - pts = PatientTaskSchedule() - pts.patient_pk = patient.pk - pts.schedule_id = schedule1.id - pts.start_datetime = local(2020, 6, 12, 12, 34) - self.dbsession.add(pts) - self.dbsession.commit() + patient = ServerCreatedPatientFactory() + + schedule1 = TaskScheduleFactory(group=self.group) + + pts = PatientTaskScheduleFactory( + patient=patient, + task_schedule=schedule1, + start_datetime=local(2020, 6, 12, 12, 34), + ) appstruct = { ViewParam.TASK_SCHEDULES: [ @@ -1359,7 +1209,7 @@ def test_patient_can_be_assigned_the_same_schedule_twice(self) -> None: view = EditServerCreatedPatientView(self.req) view.object = patient - changes = {} + changes: OrderedDict = OrderedDict() view._save_task_schedules(appstruct, changes) self.req.dbsession.commit() @@ -1367,43 +1217,24 @@ def test_patient_can_be_assigned_the_same_schedule_twice(self) -> None: self.assertEqual(patient.task_schedules[1].task_schedule, schedule1) def test_form_values_for_existing_patient(self) -> None: - patient = self.create_patient( - id=1, - forename="Jo", - surname="Patient", - dob=datetime.date(1958, 4, 19), - sex="F", - address="Address", - email="jopatient@example.com", - gp="GP", - other="Other", - ) - - schedule1 = TaskSchedule() - schedule1.group_id = self.group.id - schedule1.name = "Test 1" - self.dbsession.add(schedule1) - self.dbsession.commit() + patient = PatientFactory() - patient_task_schedule = PatientTaskSchedule() - patient_task_schedule.patient_pk = patient.pk - patient_task_schedule.schedule_id = schedule1.id - patient_task_schedule.start_datetime = local(2020, 6, 12) - patient_task_schedule.settings = { - "name 1": "value 1", - "name 2": "value 2", - "name 3": "value 3", - } - - self.dbsession.add(patient_task_schedule) - self.dbsession.commit() + schedule1 = TaskScheduleFactory( + group=self.group, + ) - self.create_patient_idnum( - patient_id=patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER_1, + patient_task_schedule = PatientTaskScheduleFactory( + patient=patient, + task_schedule=schedule1, + start_datetime=local(2020, 6, 12), + settings={ + "name 1": "value 1", + "name 2": "value 2", + "name 3": "value 3", + }, ) + patient_idnum = NHSPatientIdNumFactory(patient=patient) self.req.add_get_params({ViewParam.SERVER_PK: str(patient.pk)}) view = EditServerCreatedPatientView(self.req) @@ -1411,25 +1242,26 @@ def test_form_values_for_existing_patient(self) -> None: form_values = view.get_form_values() - self.assertEqual(form_values[ViewParam.FORENAME], "Jo") - self.assertEqual(form_values[ViewParam.SURNAME], "Patient") - self.assertEqual( - form_values[ViewParam.DOB], datetime.date(1958, 4, 19) - ) - self.assertEqual(form_values[ViewParam.SEX], "F") - self.assertEqual(form_values[ViewParam.ADDRESS], "Address") - self.assertEqual(form_values[ViewParam.EMAIL], "jopatient@example.com") - self.assertEqual(form_values[ViewParam.GP], "GP") - self.assertEqual(form_values[ViewParam.OTHER], "Other") + self.assertEqual(form_values[ViewParam.FORENAME], patient.forename) + self.assertEqual(form_values[ViewParam.SURNAME], patient.surname) + self.assertEqual(form_values[ViewParam.DOB], patient.dob) + self.assertEqual(form_values[ViewParam.SEX], patient.sex) + self.assertEqual(form_values[ViewParam.ADDRESS], patient.address) + self.assertEqual(form_values[ViewParam.EMAIL], patient.email) + self.assertEqual(form_values[ViewParam.GP], patient.gp) + self.assertEqual(form_values[ViewParam.OTHER], patient.other) self.assertEqual(form_values[ViewParam.SERVER_PK], patient.pk) self.assertEqual(form_values[ViewParam.GROUP_ID], patient.group.id) idnum = form_values[ViewParam.ID_REFERENCES][0] self.assertEqual( - idnum[ViewParam.WHICH_IDNUM], self.nhs_iddef.which_idnum + idnum[ViewParam.WHICH_IDNUM], + patient_idnum.which_idnum, + ) + self.assertEqual( + idnum[ViewParam.IDNUM_VALUE], patient_idnum.idnum_value ) - self.assertEqual(idnum[ViewParam.IDNUM_VALUE], TEST_NHS_NUMBER_1) task_schedule = form_values[ViewParam.TASK_SCHEDULES][0] self.assertEqual( @@ -1449,24 +1281,12 @@ def test_form_values_for_existing_patient(self) -> None: ) -class AddPatientViewTests(DemoDatabaseTestCase): - """ - Unit tests. - """ - +class AddPatientViewTests(BasicDatabaseTestCase): def test_patient_created(self) -> None: view = AddPatientView(self.req) - schedule1 = TaskSchedule() - schedule1.group_id = self.group.id - schedule1.name = "Test 1" - self.dbsession.add(schedule1) - - schedule2 = TaskSchedule() - schedule2.group_id = self.group.id - schedule2.name = "Test 2" - self.dbsession.add(schedule2) - self.dbsession.commit() + schedule1 = TaskScheduleFactory() + schedule2 = TaskScheduleFactory() start_datetime1 = local(2020, 6, 12) start_datetime2 = local(2020, 7, 1) @@ -1475,6 +1295,9 @@ def test_patient_created(self) -> None: {"name 1": "value 1", "name 2": "value 2", "name 3": "value 3"} ) + nhs_iddef = NHSIdNumDefinitionFactory() + nhs_number = Fake.en_gb.nhs_number() + appstruct = { ViewParam.GROUP_ID: self.group.id, ViewParam.FORENAME: "Jo", @@ -1487,8 +1310,8 @@ def test_patient_created(self) -> None: ViewParam.OTHER: "Other", ViewParam.ID_REFERENCES: [ { - ViewParam.WHICH_IDNUM: self.nhs_iddef.which_idnum, - ViewParam.IDNUM_VALUE: 1192220552, + ViewParam.WHICH_IDNUM: nhs_iddef.which_idnum, + ViewParam.IDNUM_VALUE: nhs_number, } ], ViewParam.TASK_SCHEDULES: [ @@ -1506,12 +1329,12 @@ def test_patient_created(self) -> None: } view.save_object(appstruct) + self.dbsession.commit() patient = cast(Patient, view.object) server_device = Device.get_server_device(self.req.dbsession) - self.assertEqual(patient.id, 1) self.assertEqual(patient.device_id, server_device.id) self.assertEqual(patient.era, ERA_NOW) self.assertEqual(patient.group.id, self.group.id) @@ -1526,27 +1349,32 @@ def test_patient_created(self) -> None: self.assertEqual(patient.other, "Other") idnum = patient.get_idnum_objects()[0] - self.assertEqual(idnum.patient_id, 1) - self.assertEqual(idnum.which_idnum, self.nhs_iddef.which_idnum) - self.assertEqual(idnum.idnum_value, 1192220552) + self.assertEqual(idnum.patient_id, patient.id) + self.assertEqual(idnum.which_idnum, nhs_iddef.which_idnum) + self.assertEqual(idnum.idnum_value, nhs_number) patient_task_schedules = { pts.task_schedule.name: pts for pts in patient.task_schedules } - self.assertIn("Test 1", patient_task_schedules) - self.assertIn("Test 2", patient_task_schedules) + self.assertIn(schedule1.name, patient_task_schedules) + self.assertIn(schedule2.name, patient_task_schedules) self.assertEqual( - patient_task_schedules["Test 1"].start_datetime, start_datetime1 + patient_task_schedules[schedule1.name].start_datetime, + start_datetime1, ) - self.assertEqual(patient_task_schedules["Test 1"].settings, settings1) self.assertEqual( - patient_task_schedules["Test 2"].start_datetime, start_datetime2 + patient_task_schedules[schedule1.name].settings, settings1 + ) + self.assertEqual( + patient_task_schedules[schedule2.name].start_datetime, + start_datetime2, ) def test_patient_takes_next_available_id(self) -> None: - self.create_patient(id=1234, as_server_patient=True) + patient = ServerCreatedPatientFactory(id=1234) + nhs_iddef = NHSIdNumDefinitionFactory() view = AddPatientView(self.req) @@ -1561,8 +1389,8 @@ def test_patient_takes_next_available_id(self) -> None: ViewParam.OTHER: "Other", ViewParam.ID_REFERENCES: [ { - ViewParam.WHICH_IDNUM: self.nhs_iddef.which_idnum, - ViewParam.IDNUM_VALUE: 1192220552, + ViewParam.WHICH_IDNUM: nhs_iddef.which_idnum, + ViewParam.IDNUM_VALUE: Fake.en_gb.nhs_number(), } ], ViewParam.TASK_SCHEDULES: [], @@ -1587,8 +1415,7 @@ def test_form_rendered_with_values(self) -> None: self.assertIn("form", context) def test_unprivileged_user_cannot_add_patient(self) -> None: - user = self.create_user(username="testuser") - self.dbsession.flush() + user = UserFactory(username="testuser") self.req._debugging_user = user @@ -1600,10 +1427,11 @@ def test_unprivileged_user_cannot_add_patient(self) -> None: ) def test_group_listed_for_privileged_group_member(self) -> None: - user = self.create_user(username="testuser") - self.dbsession.flush() - self.create_membership(user, self.group, may_manage_patients=True) - self.dbsession.commit() + user = UserFactory() + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_manage_patients=True + ) self.req._debugging_user = user @@ -1616,50 +1444,24 @@ def test_group_listed_for_privileged_group_member(self) -> None: context = args[0] - self.assertIn("testgroup", context["form"]) + self.assertIn(group.name, context["form"]) class DeleteServerCreatedPatientViewTests(BasicDatabaseTestCase): - """ - Unit tests. - """ - def setUp(self) -> None: super().setUp() - self.patient = self.create_patient( - as_server_patient=True, - forename="Jo", - surname="Patient", - dob=datetime.date(1958, 4, 19), - sex="F", - address="Address", - gp="GP", - other="Other", - ) - - patient_pk = self.patient.pk - - idnum = self.create_patient_idnum( - as_server_patient=True, - patient_id=self.patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER_1, - ) + self.patient = ServerCreatedPatientFactory() + idnum = ServerCreatedNHSPatientIdNumFactory(patient=self.patient) PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession) - self.schedule = TaskSchedule() - self.schedule.group_id = self.group.id - self.schedule.name = "Test 1" - self.dbsession.add(self.schedule) - self.dbsession.commit() + self.schedule = TaskScheduleFactory(group=self.group) - pts = PatientTaskSchedule() - pts.patient_pk = patient_pk - pts.schedule_id = self.schedule.id - self.dbsession.add(pts) - self.dbsession.commit() + PatientTaskScheduleFactory( + patient=self.patient, + task_schedule=self.schedule, + ) self.multidict = MultiDict( [ @@ -1770,16 +1572,7 @@ def test_registered_patient_deleted(self) -> None: # self.assertIsNone(user.single_patient_pk) def test_unrelated_patient_unaffected(self) -> None: - other_patient = self.create_patient( - as_server_patient=True, - forename="Mo", - surname="Patient", - dob=datetime.date(1968, 11, 30), - sex="M", - address="Address", - gp="GP", - other="Other", - ) + other_patient = ServerCreatedPatientFactory() patient_pk = other_patient._pk saved_patient = ( @@ -1790,12 +1583,7 @@ def test_unrelated_patient_unaffected(self) -> None: self.assertIsNotNone(saved_patient) - idnum = self.create_patient_idnum( - as_server_patient=True, - patient_id=other_patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER_2, - ) + idnum = ServerCreatedNHSPatientIdNumFactory(patient=other_patient) PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession) @@ -1812,11 +1600,9 @@ def test_unrelated_patient_unaffected(self) -> None: self.assertIsNotNone(saved_idnum) - pts = PatientTaskSchedule() - pts.patient_pk = patient_pk - pts.schedule_id = self.schedule.id - self.dbsession.add(pts) - self.dbsession.commit() + PatientTaskScheduleFactory( + patient=other_patient, task_schedule=self.schedule + ) self.req.fake_request_post_from_dict(self.multidict) @@ -1866,8 +1652,7 @@ def test_unprivileged_user_cannot_delete_patient(self) -> None: ) view = DeleteServerCreatedPatientView(self.req) - user = self.create_user(username="testuser") - self.dbsession.flush() + user = UserFactory(username="testuser") self.req._debugging_user = user @@ -1885,8 +1670,7 @@ def test_unprivileged_user_cannot_see_delete_form(self) -> None: self.req.add_get_params({ViewParam.SERVER_PK: str(patient_pk)}) view = DeleteServerCreatedPatientView(self.req) - user = self.create_user(username="testuser") - self.dbsession.flush() + user = UserFactory() self.req._debugging_user = user @@ -1899,33 +1683,19 @@ def test_unprivileged_user_cannot_see_delete_form(self) -> None: class EraseTaskTestCase(BasicDatabaseTestCase): - """ - Unit tests. - """ - - def create_tasks(self) -> None: - from camcops_server.tasks.bmi import Bmi - - self.task = Bmi() - self.task.id = 1 - self.apply_standard_task_fields(self.task) - patient = self.create_patient_with_one_idnum() - self.task.patient_id = patient.id + def setUp(self) -> None: + super().setUp() - self.dbsession.add(self.task) - self.dbsession.commit() + self.patient = PatientFactory(_group=self.group) class EraseTaskLeavingPlaceholderViewTests(EraseTaskTestCase): - """ - Unit tests. - """ - def test_displays_form(self) -> None: + task = BmiFactory(patient=self.patient) self.req.add_get_params( { - ViewParam.SERVER_PK: str(self.task.pk), - ViewParam.TABLE_NAME: self.task.tablename, + ViewParam.SERVER_PK: str(task.pk), + ViewParam.TABLE_NAME: task.tablename, }, set_method_get=False, ) @@ -1940,13 +1710,14 @@ def test_displays_form(self) -> None: self.assertIn("form", context) def test_deletes_task_leaving_placeholder(self) -> None: + task = BmiFactory(patient=self.patient) multidict = MultiDict( [ ("_charset_", UTF8), ("__formid__", "deform"), (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()), - (ViewParam.SERVER_PK, self.task.pk), - (ViewParam.TABLE_NAME, self.task.tablename), + (ViewParam.SERVER_PK, task.pk), + (ViewParam.TABLE_NAME, task.tablename), ("confirm_1_t", "true"), ("confirm_2_t", "true"), ("confirm_4_t", "true"), @@ -1962,9 +1733,7 @@ def test_deletes_task_leaving_placeholder(self) -> None: self.req.fake_request_post_from_dict(multidict) view = EraseTaskLeavingPlaceholderView(self.req) - with mock.patch.object( - self.task, "manually_erase" - ) as mock_manually_erase: + with mock.patch.object(task, "manually_erase") as mock_manually_erase: with self.assertRaises(HTTPFound): view.dispatch() @@ -1976,12 +1745,13 @@ def test_deletes_task_leaving_placeholder(self) -> None: self.assertEqual(request, self.req) def test_task_not_deleted_on_cancel(self) -> None: + task = BmiFactory(patient=self.patient) self.req.fake_request_post_from_dict({FormAction.CANCEL: "cancel"}) self.req.add_get_params( { - ViewParam.SERVER_PK: str(self.task.pk), - ViewParam.TABLE_NAME: self.task.tablename, + ViewParam.SERVER_PK: str(task.pk), + ViewParam.TABLE_NAME: task.tablename, }, set_method_get=False, ) @@ -1990,17 +1760,18 @@ def test_task_not_deleted_on_cancel(self) -> None: with self.assertRaises(HTTPFound): view.dispatch() - task = self.dbsession.query(self.task.__class__).one_or_none() + task = self.dbsession.query(task.__class__).one_or_none() self.assertIsNotNone(task) def test_redirect_on_cancel(self) -> None: + task = BmiFactory(patient=self.patient) self.req.fake_request_post_from_dict({FormAction.CANCEL: "cancel"}) self.req.add_get_params( { - ViewParam.SERVER_PK: str(self.task.pk), - ViewParam.TABLE_NAME: self.task.tablename, + ViewParam.SERVER_PK: str(task.pk), + ViewParam.TABLE_NAME: task.tablename, }, set_method_get=False, ) @@ -2012,11 +1783,11 @@ def test_redirect_on_cancel(self) -> None: self.assertEqual(cm.exception.status_code, 302) self.assertIn(f"/{Routes.TASK}", cm.exception.headers["Location"]) self.assertIn( - f"{ViewParam.TABLE_NAME}={self.task.tablename}", + f"{ViewParam.TABLE_NAME}={task.tablename}", cm.exception.headers["Location"], ) self.assertIn( - f"{ViewParam.SERVER_PK}={self.task.pk}", + f"{ViewParam.SERVER_PK}={task.pk}", cm.exception.headers["Location"], ) self.assertIn( @@ -2037,14 +1808,12 @@ def test_raises_when_task_does_not_exist(self) -> None: self.assertEqual(cm.exception.message, "No such task: phq9, PK=123") def test_raises_when_task_is_live_on_tablet(self) -> None: - self.task._era = ERA_NOW - self.dbsession.add(self.task) - self.dbsession.commit() + task = BmiFactory(patient=self.patient, _era=ERA_NOW) self.req.add_get_params( { - ViewParam.SERVER_PK: str(self.task.pk), - ViewParam.TABLE_NAME: self.task.tablename, + ViewParam.SERVER_PK: str(task.pk), + ViewParam.TABLE_NAME: task.tablename, }, set_method_get=False, ) @@ -2056,14 +1825,21 @@ def test_raises_when_task_is_live_on_tablet(self) -> None: self.assertIn("Task is live on tablet", cm.exception.message) def test_raises_when_user_not_authorized_to_erase(self) -> None: + task = BmiFactory(patient=self.patient) + user = UserFactory() + + self.req._debugging_user = user + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group.id, groupadmin=True + ) + with mock.patch.object( - self.user, "authorized_to_erase_tasks", return_value=False + user, "authorized_to_erase_tasks", return_value=False ): - self.req.add_get_params( { - ViewParam.SERVER_PK: str(self.task.pk), - ViewParam.TABLE_NAME: self.task.tablename, + ViewParam.SERVER_PK: str(task.pk), + ViewParam.TABLE_NAME: task.tablename, }, set_method_get=False, ) @@ -2075,14 +1851,12 @@ def test_raises_when_user_not_authorized_to_erase(self) -> None: self.assertIn("Not authorized to erase tasks", cm.exception.message) def test_raises_when_task_already_erased(self) -> None: - self.task._manually_erased = True - self.dbsession.add(self.task) - self.dbsession.commit() + task = BmiFactory(patient=self.patient, _manually_erased=True) self.req.add_get_params( { - ViewParam.SERVER_PK: str(self.task.pk), - ViewParam.TABLE_NAME: self.task.tablename, + ViewParam.SERVER_PK: str(task.pk), + ViewParam.TABLE_NAME: task.tablename, }, set_method_get=False, ) @@ -2095,18 +1869,15 @@ def test_raises_when_task_already_erased(self) -> None: class EraseTaskEntirelyViewTests(EraseTaskTestCase): - """ - Unit tests. - """ - def test_deletes_task_entirely(self) -> None: + task = BmiFactory(patient=self.patient) multidict = MultiDict( [ ("_charset_", UTF8), ("__formid__", "deform"), (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()), - (ViewParam.SERVER_PK, self.task.pk), - (ViewParam.TABLE_NAME, self.task.tablename), + (ViewParam.SERVER_PK, task.pk), + (ViewParam.TABLE_NAME, task.tablename), ("confirm_1_t", "true"), ("confirm_2_t", "true"), ("confirm_4_t", "true"), @@ -2124,7 +1895,7 @@ def test_deletes_task_entirely(self) -> None: view = EraseTaskEntirelyView(self.req) with mock.patch.object( - self.task, "delete_entirely" + task, "delete_entirely" ) as mock_delete_entirely: with self.assertRaises(HTTPFound): @@ -2140,36 +1911,37 @@ def test_deletes_task_entirely(self) -> None: self.assertTrue(len(messages) > 0) self.assertIn("Task erased", messages[0]) - self.assertIn(self.task.tablename, messages[0]) - self.assertIn("server PK {}".format(self.task.pk), messages[0]) - + self.assertIn(task.tablename, messages[0]) + self.assertIn("server PK {}".format(task.pk), messages[0]) -class EditGroupViewTests(DemoDatabaseTestCase): - """ - Unit tests. - """ +class EditGroupViewTests(DemoRequestTestCase): def test_group_updated(self) -> None: - other_group_1 = Group() - other_group_1.name = "other-group-1" - self.dbsession.add(other_group_1) + groupadmin = self.req._debugging_user = UserFactory() + group = GroupFactory() + UserGroupMembershipFactory( + group_id=group.id, user_id=groupadmin.id, groupadmin=True + ) + other_group_1 = GroupFactory() + other_group_2 = GroupFactory() - other_group_2 = Group() - other_group_2.name = "other-group-2" - self.dbsession.add(other_group_2) + nhs_iddef = NHSIdNumDefinitionFactory() - self.dbsession.commit() + new_name = "new-name" + new_description = "new description" + new_upload_policy = "anyidnum AND sex" + new_finalize_policy = f"idnum{nhs_iddef.which_idnum} AND sex" multidict = MultiDict( [ ("_charset_", UTF8), ("__formid__", "deform"), (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()), - (ViewParam.GROUP_ID, self.group.id), - (ViewParam.NAME, "new-name"), - (ViewParam.DESCRIPTION, "new description"), - (ViewParam.UPLOAD_POLICY, "anyidnum AND sex"), # reversed - (ViewParam.FINALIZE_POLICY, "idnum1 AND sex"), # reversed + (ViewParam.GROUP_ID, group.id), + (ViewParam.NAME, new_name), + (ViewParam.DESCRIPTION, new_description), + (ViewParam.UPLOAD_POLICY, new_upload_policy), + (ViewParam.FINALIZE_POLICY, new_finalize_policy), ("__start__", "group_ids:sequence"), ("group_id_sequence", str(other_group_1.id)), ("group_id_sequence", str(other_group_2.id)), @@ -2182,26 +1954,38 @@ def test_group_updated(self) -> None: with self.assertRaises(HTTPFound): edit_group(self.req) - self.assertEqual(self.group.name, "new-name") - self.assertEqual(self.group.description, "new description") - self.assertEqual(self.group.upload_policy, "anyidnum AND sex") - self.assertEqual(self.group.finalize_policy, "idnum1 AND sex") - self.assertIn(other_group_1, self.group.can_see_other_groups) - self.assertIn(other_group_2, self.group.can_see_other_groups) + self.assertEqual(group.name, new_name) + self.assertEqual(group.description, new_description) + self.assertEqual(group.upload_policy, new_upload_policy) + self.assertEqual(group.finalize_policy, new_finalize_policy) + self.assertIn(other_group_1, group.can_see_other_groups) + self.assertIn(other_group_2, group.can_see_other_groups) def test_ip_use_added(self) -> None: from camcops_server.cc_modules.cc_ipuse import IpContexts + group = GroupFactory() + groupadmin = self.req._debugging_user = UserFactory() + UserGroupMembershipFactory( + group_id=group.id, user_id=groupadmin.id, groupadmin=True + ) + nhs_iddef = NHSIdNumDefinitionFactory() + + new_name = "new-name" + new_description = "new description" + new_upload_policy = "anyidnum AND sex" + new_finalize_policy = f"idnum{nhs_iddef.which_idnum} AND sex" + multidict = MultiDict( [ ("_charset_", UTF8), ("__formid__", "deform"), (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()), - (ViewParam.GROUP_ID, self.group.id), - (ViewParam.NAME, "new-name"), - (ViewParam.DESCRIPTION, "new description"), - (ViewParam.UPLOAD_POLICY, "anyidnum AND sex"), - (ViewParam.FINALIZE_POLICY, "idnum1 AND sex"), + (ViewParam.GROUP_ID, group.id), + (ViewParam.NAME, new_name), + (ViewParam.DESCRIPTION, new_description), + (ViewParam.UPLOAD_POLICY, new_upload_policy), + (ViewParam.FINALIZE_POLICY, new_finalize_policy), ("__start__", "ip_use:mapping"), (IpContexts.CLINICAL, "true"), (IpContexts.COMMERCIAL, "true"), @@ -2214,31 +1998,39 @@ def test_ip_use_added(self) -> None: with self.assertRaises(HTTPFound): edit_group(self.req) - self.assertTrue(self.group.ip_use.clinical) - self.assertTrue(self.group.ip_use.commercial) - self.assertFalse(self.group.ip_use.educational) - self.assertFalse(self.group.ip_use.research) + self.assertTrue(group.ip_use.clinical) + self.assertTrue(group.ip_use.commercial) + self.assertFalse(group.ip_use.educational) + self.assertFalse(group.ip_use.research) def test_ip_use_updated(self) -> None: from camcops_server.cc_modules.cc_ipuse import IpContexts - self.group.ip_use.educational = True - self.group.ip_use.research = True - self.dbsession.add(self.group.ip_use) - self.dbsession.commit() + group = GroupFactory(ip_use__educational=True, ip_use__research=True) + groupadmin = self.req._debugging_user = UserFactory() + UserGroupMembershipFactory( + group_id=group.id, user_id=groupadmin.id, groupadmin=True + ) + + old_id = group.ip_use.id + + nhs_iddef = NHSIdNumDefinitionFactory() - old_id = self.group.ip_use.id + new_name = "new-name" + new_description = "new description" + new_upload_policy = "anyidnum AND sex" + new_finalize_policy = f"idnum{nhs_iddef.which_idnum} AND sex" multidict = MultiDict( [ ("_charset_", UTF8), ("__formid__", "deform"), (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()), - (ViewParam.GROUP_ID, self.group.id), - (ViewParam.NAME, "new-name"), - (ViewParam.DESCRIPTION, "new description"), - (ViewParam.UPLOAD_POLICY, "anyidnum AND sex"), - (ViewParam.FINALIZE_POLICY, "idnum1 AND sex"), + (ViewParam.GROUP_ID, group.id), + (ViewParam.NAME, new_name), + (ViewParam.DESCRIPTION, new_description), + (ViewParam.UPLOAD_POLICY, new_upload_policy), + (ViewParam.FINALIZE_POLICY, new_finalize_policy), ("__start__", "ip_use:mapping"), (IpContexts.CLINICAL, "true"), (IpContexts.COMMERCIAL, "true"), @@ -2251,32 +2043,23 @@ def test_ip_use_updated(self) -> None: with self.assertRaises(HTTPFound): edit_group(self.req) - self.assertTrue(self.group.ip_use.clinical) - self.assertTrue(self.group.ip_use.commercial) - self.assertFalse(self.group.ip_use.educational) - self.assertFalse(self.group.ip_use.research) - self.assertEqual(self.group.ip_use.id, old_id) + self.assertTrue(group.ip_use.clinical) + self.assertTrue(group.ip_use.commercial) + self.assertFalse(group.ip_use.educational) + self.assertFalse(group.ip_use.research) + self.assertEqual(group.ip_use.id, old_id) def test_other_groups_displayed_in_form(self) -> None: - z_group = Group() - z_group.name = "z-group" - self.dbsession.add(z_group) - - a_group = Group() - a_group.name = "a-group" - self.dbsession.add(a_group) - self.dbsession.commit() + z_group = GroupFactory(name="z-group") + a_group = GroupFactory(name="a-group") other_groups = Group.get_groups_from_id_list( self.dbsession, [z_group.id, a_group.id] ) - self.group.can_see_other_groups = other_groups - - self.dbsession.add(self.group) - self.dbsession.commit() + group = GroupFactory(can_see_other_groups=other_groups) view = EditGroupView(self.req) - view.object = self.group + view.object = group form_values = view.get_form_values() @@ -2285,59 +2068,38 @@ def test_other_groups_displayed_in_form(self) -> None: ) def test_group_id_displayed_in_form(self) -> None: + group = GroupFactory() view = EditGroupView(self.req) - view.object = self.group + view.object = group form_values = view.get_form_values() - self.assertEqual(form_values[ViewParam.GROUP_ID], self.group.id) + self.assertEqual(form_values[ViewParam.GROUP_ID], group.id) def test_ip_use_displayed_in_form(self) -> None: + group = GroupFactory() view = EditGroupView(self.req) - view.object = self.group + view.object = group form_values = view.get_form_values() - self.assertEqual(form_values[ViewParam.IP_USE], self.group.ip_use) + self.assertEqual(form_values[ViewParam.IP_USE], group.ip_use) class SendEmailFromPatientTaskScheduleViewTests(BasicDatabaseTestCase): def setUp(self) -> None: super().setUp() - self.patient = self.create_patient( - as_server_patient=True, - forename="Jo", - surname="Patient", - dob=datetime.date(1958, 4, 19), - sex="F", - address="Address", - gp="GP", - other="Other", - ) - - patient_pk = self.patient.pk - - idnum = self.create_patient_idnum( - as_server_patient=True, - patient_id=self.patient.id, - which_idnum=self.nhs_iddef.which_idnum, - idnum_value=TEST_NHS_NUMBER_1, - ) + self.patient = ServerCreatedPatientFactory() + idnum = ServerCreatedNHSPatientIdNumFactory(patient=self.patient) PatientIdNumIndexEntry.index_idnum(idnum, self.dbsession) - self.schedule = TaskSchedule() - self.schedule.group_id = self.group.id - self.schedule.name = "Test 1" - self.dbsession.add(self.schedule) - self.dbsession.commit() + self.schedule = TaskScheduleFactory(group=self.group) - self.pts = PatientTaskSchedule() - self.pts.patient_pk = patient_pk - self.pts.schedule_id = self.schedule.id - self.dbsession.add(self.pts) - self.dbsession.commit() + self.pts = PatientTaskScheduleFactory( + patient=self.patient, task_schedule=self.schedule + ) def test_displays_form(self) -> None: self.req.add_get_params( @@ -2576,8 +2338,7 @@ def test_email_record_created( self.assertEqual(self.pts.emails[0].email.to, "patient@example.com") def test_unprivileged_user_cannot_email_patient(self) -> None: - user = self.create_user(username="testuser") - self.dbsession.flush() + user = UserFactory(username="testuser") self.req._debugging_user = user @@ -2647,8 +2408,9 @@ def test_password_autocomplete_read_from_config(self) -> None: self.assertIn('autocomplete="current-password"', context["form"]) def test_fails_when_user_locked_out(self) -> None: - user = self.create_user(username="test") - user.set_password(self.req, "secret") + user = UserFactory( + username="test", password__request=self.req, password="secret" + ) SecurityAccountLockout.lock_user_out( self.req, user.username, lockout_minutes=1 ) @@ -2679,10 +2441,12 @@ def test_fails_when_user_locked_out(self) -> None: @mock.patch("camcops_server.cc_modules.webview.audit") def test_user_can_log_in(self, mock_audit: mock.Mock) -> None: - user = self.create_user(username="test") - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) + user = UserFactory( + username="test", password__request=self.req, password="secret" + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group.id, may_use_webviewer=True + ) multidict = MultiDict( [ @@ -2715,14 +2479,17 @@ def test_user_can_log_in(self, mock_audit: mock.Mock) -> None: self.assertEqual(kwargs["user_id"], user.id) def test_user_with_totp_sees_token_form(self) -> None: - user = self.create_user( + user = UserFactory( username="test", mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.TOTP, + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) view = LoginView(self.req) view.state.update( @@ -2749,16 +2516,19 @@ def test_user_with_totp_sees_token_form(self) -> None: def test_user_with_hotp_email_sees_token_form( self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock ) -> None: - user = self.create_user( + user = UserFactory( username="test", mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.HOTP_EMAIL, email="user@example.com", hotp_counter=0, + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) view = LoginView(self.req) view.state.update( mfa_user_id=user.id, @@ -2783,17 +2553,20 @@ def test_user_with_hotp_sms_sees_token_form(self) -> None: SmsBackendNames.CONSOLE, {} ) - phone_number = phonenumbers.parse(TEST_PHONE_NUMBER) - user = self.create_user( + phone_number_str = Fake.en_gb.valid_phone_number() + user = UserFactory( username="test", mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.HOTP_SMS, - phone_number=phone_number, + phone_number=phonenumbers.parse(phone_number_str), hotp_counter=0, + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) view = LoginView(self.req) view.state.update( @@ -2824,15 +2597,18 @@ def test_session_state_set_for_user_with_mfa( mock_make_email: mock.Mock, mock_send_msg: mock.Mock, ) -> None: - user = self.create_user( + user = UserFactory( username="test", mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.HOTP_EMAIL, email="user@example.com", + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) multidict = MultiDict( [ @@ -2876,16 +2652,19 @@ def test_user_with_hotp_is_sent_email( self.req.config.email_use_tls = True self.req.config.email_from = "server@example.com" - user = self.create_user( + user = UserFactory( username="test", email="user@example.com", mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.HOTP_EMAIL, hotp_counter=0, + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) multidict = MultiDict( [ @@ -2926,18 +2705,21 @@ def test_user_with_hotp_is_sent_sms(self) -> None: ) self.req.config.sms_config = test_config - phone_number = phonenumbers.parse(TEST_PHONE_NUMBER) - user = self.create_user( + phone_number_str = Fake.en_gb.valid_phone_number() + user = UserFactory( username="test", email="user@example.com", - phone_number=phone_number, + phone_number=phonenumbers.parse(phone_number_str), mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.HOTP_SMS, hotp_counter=0, + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) multidict = MultiDict( [ @@ -2958,7 +2740,7 @@ def test_user_with_hotp_is_sent_sms(self) -> None: expected_message = f"Your CamCOPS verification code is {expected_code}" self.assertIn( - ConsoleSmsBackend.make_msg(TEST_PHONE_NUMBER, expected_message), + ConsoleSmsBackend.make_msg(phone_number_str, expected_message), logging_cm.output[0], ) @@ -2967,16 +2749,19 @@ def test_user_with_hotp_is_sent_sms(self) -> None: def test_login_with_hotp_increments_counter( self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock ) -> None: - user = self.create_user( + user = UserFactory( username="test", email="user@example.com", mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.HOTP_EMAIL, hotp_counter=0, + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) multidict = MultiDict( [ @@ -2996,15 +2781,18 @@ def test_login_with_hotp_increments_counter( @mock.patch("camcops_server.cc_modules.webview.audit") def test_user_with_totp_can_log_in(self, mock_audit: mock.Mock) -> None: - user = self.create_user( + user = UserFactory( username="test", mfa_method=MfaMethod.TOTP, mfa_secret_key=pyotp.random_base32(), + password__request=self.req, + password="secret", ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True + ) totp = pyotp.TOTP(user.mfa_secret_key) @@ -3042,16 +2830,18 @@ def test_user_with_totp_can_log_in(self, mock_audit: mock.Mock) -> None: @mock.patch("camcops_server.cc_modules.webview.audit") def test_user_with_hotp_can_log_in(self, mock_audit: mock.Mock) -> None: - user = self.create_user( + user = UserFactory( username="test", mfa_method=MfaMethod.HOTP_EMAIL, mfa_secret_key=pyotp.random_base32(), hotp_counter=1, + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - - self.create_membership(user, self.group, may_use_webviewer=True) hotp = pyotp.HOTP(user.mfa_secret_key) multidict = MultiDict( @@ -3087,15 +2877,18 @@ def test_user_with_hotp_can_log_in(self, mock_audit: mock.Mock) -> None: self.assert_state_is_finished() def test_form_state_cleared_on_failed_login(self) -> None: - user = self.create_user( + user = UserFactory( username="test", mfa_method=MfaMethod.HOTP_EMAIL, mfa_secret_key=pyotp.random_base32(), hotp_counter=1, + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) hotp = pyotp.HOTP(user.mfa_secret_key) @@ -3123,14 +2916,17 @@ def test_form_state_cleared_on_failed_login(self) -> None: def test_user_cannot_log_in_if_timed_out(self) -> None: self.req.config.mfa_timeout_s = 600 - user = self.create_user( + user = UserFactory( username="test", mfa_method=MfaMethod.TOTP, mfa_secret_key=pyotp.random_base32(), + password__request=self.req, + password="secret", + ) + group = GroupFactory() + UserGroupMembershipFactory( + user_id=user.id, group_id=group.id, may_use_webviewer=True ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.create_membership(user, self.group, may_use_webviewer=True) totp = pyotp.TOTP(user.mfa_secret_key) @@ -3159,11 +2955,12 @@ def test_user_cannot_log_in_if_timed_out(self) -> None: mock_fail_timed_out.assert_called_once() def test_unprivileged_user_cannot_log_in(self) -> None: - user = self.create_user(username="test") - user.set_password(self.req, "secret") - self.dbsession.flush() - - self.create_membership(user, self.group, may_use_webviewer=False) + user = UserFactory( + username="test", password__request=self.req, password="secret" + ) + UserGroupMembershipFactory( + user_id=user.id, group_id=self.group.id, may_use_webviewer=False + ) multidict = MultiDict( [ @@ -3235,7 +3032,7 @@ def test_timed_out_false_when_no_authenticated_user(self) -> None: def test_timed_out_false_when_no_authentication_time(self) -> None: view = LoginView(self.req) - user = self.create_user(username="test") + user = UserFactory(username="test") # Should never be the case that we have a user ID but no # authentication time view.state["mfa_user_id"] = user.id @@ -3245,8 +3042,7 @@ def test_timed_out_false_when_no_authentication_time(self) -> None: class EditUserViewTests(BasicDatabaseTestCase): def test_redirect_on_cancel(self) -> None: - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() + regular_user = UserFactory(username="regular_user") self.req.fake_request_post_from_dict({FormAction.CANCEL: "cancel"}) self.req.add_get_params( {ViewParam.USER_ID: str(regular_user.id)}, set_method_get=False @@ -3261,10 +3057,9 @@ def test_redirect_on_cancel(self) -> None: ) def test_raises_if_user_may_not_edit_another(self) -> None: - self.req.add_get_params({ViewParam.USER_ID: str(self.user.id)}) + self.req.add_get_params({ViewParam.USER_ID: str(self.system_user.id)}) - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() + regular_user = UserFactory(username="regular_user") self.req._debugging_user = regular_user with self.assertRaises(HTTPBadRequest) as cm: edit_user(self.req) @@ -3272,8 +3067,7 @@ def test_raises_if_user_may_not_edit_another(self) -> None: self.assertIn("Nobody may edit the system user", cm.exception.message) def test_superuser_sees_full_form(self) -> None: - superuser = self.create_user(username="admin", superuser=True) - self.dbsession.flush() + superuser = UserFactory(username="admin", superuser=True) self.req._debugging_user = superuser self.req.add_get_params({ViewParam.USER_ID: str(superuser.id)}) @@ -3283,12 +3077,13 @@ def test_superuser_sees_full_form(self) -> None: self.assertIn("Superuser (CAUTION!)", response.body.decode(UTF8)) def test_groupadmin_sees_groupadmin_form(self) -> None: - groupadmin = self.create_user(username="groupadmin") - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() - self.create_membership(groupadmin, self.group, groupadmin=True) - self.create_membership(regular_user, self.group) - self.dbsession.flush() + groupadmin = UserFactory(username="groupadmin") + regular_user = UserFactory(username="regular_user") + group = GroupFactory() + UserGroupMembershipFactory( + user_id=groupadmin.id, group_id=group.id, groupadmin=True + ) + UserGroupMembershipFactory(user_id=regular_user.id, group_id=group.id) self.req._debugging_user = groupadmin self.req.add_get_params({ViewParam.USER_ID: str(regular_user.id)}) @@ -3300,9 +3095,8 @@ def test_groupadmin_sees_groupadmin_form(self) -> None: self.assertNotIn("Superuser (CAUTION!)", content) def test_raises_for_conflicting_user_name(self) -> None: - self.create_user(username="existing_user") - other_user = self.create_user(username="other_user") - self.dbsession.flush() + UserFactory(username="existing_user") + other_user = UserFactory(username="other_user") multidict = MultiDict( [ @@ -3321,13 +3115,12 @@ def test_raises_for_conflicting_user_name(self) -> None: self.assertIn("Can't rename user", cm.exception.message) def test_user_is_updated(self) -> None: - user = self.create_user( + user = UserFactory( username="old_username", fullname="Old Name", email="old@example.com", language="da_DK", ) - self.dbsession.flush() multidict = MultiDict( [ @@ -3352,9 +3145,8 @@ def test_user_is_updated(self) -> None: self.assertEqual(user.language, "en_GB") def test_user_is_added_to_group(self) -> None: - user = self.create_user(username="regular_user") - group = self.create_group("group") - self.dbsession.flush() + user = UserFactory() + group = GroupFactory() multidict = MultiDict( [ @@ -3377,15 +3169,19 @@ def test_user_is_added_to_group(self) -> None: mock_set_group_ids.assert_called_once_with([group.id]) def test_user_stays_in_group_the_groupadmin_cannot_edit(self) -> None: - regular_user = self.create_user(username="regular_user") - group_b_admin = self.create_user(username="group_b_admin") - group_a = self.create_group("group_a") - group_b = self.create_group("group_b") - self.dbsession.flush() - self.create_membership(regular_user, group_a) - self.create_membership(regular_user, group_b) - self.create_membership(group_b_admin, group_b, groupadmin=True) - self.dbsession.flush() + regular_user = UserFactory(username="regular_user") + group_b_admin = UserFactory(username="group_b_admin") + group_a = GroupFactory() + group_b = GroupFactory() + UserGroupMembershipFactory( + user_id=regular_user.id, group_id=group_a.id + ) + UserGroupMembershipFactory( + user_id=regular_user.id, group_id=group_b.id + ) + UserGroupMembershipFactory( + user_id=group_b_admin.id, group_id=group_b.id, groupadmin=True + ) self.req._debugging_user = group_b_admin multidict = MultiDict( @@ -3411,18 +3207,22 @@ def test_user_stays_in_group_the_groupadmin_cannot_edit(self) -> None: mock_set_group_ids.assert_called_once_with([group_a.id, group_b.id]) def test_upload_group_id_unset_when_membership_removed(self) -> None: - group_a = self.create_group("group_a") - group_b = self.create_group("group_b") - regular_user = self.create_user( - username="regular_user", upload_group=group_a - ) - groupadmin = self.create_user(username="groupadmin") - self.dbsession.flush() - self.create_membership(regular_user, group_a) - self.create_membership(regular_user, group_b) - self.create_membership(groupadmin, group_a, groupadmin=True) - self.create_membership(groupadmin, group_b, groupadmin=True) - self.dbsession.flush() + group_a = GroupFactory() + group_b = GroupFactory() + regular_user = UserFactory(upload_group=group_a) + groupadmin = UserFactory() + UserGroupMembershipFactory( + group_id=group_a.id, user_id=regular_user.id + ) + UserGroupMembershipFactory( + group_id=group_b.id, user_id=regular_user.id + ) + UserGroupMembershipFactory( + group_id=group_a.id, user_id=groupadmin.id, groupadmin=True + ) + UserGroupMembershipFactory( + group_id=group_b.id, user_id=groupadmin.id, groupadmin=True + ) self.req._debugging_user = groupadmin multidict = MultiDict( @@ -3445,20 +3245,24 @@ def test_upload_group_id_unset_when_membership_removed(self) -> None: self.assertIsNone(regular_user.upload_group_id) def test_get_form_values(self) -> None: - regular_user = self.create_user( + regular_user = UserFactory( username="regular_user", fullname="Full Name", email="user@example.com", language="da_DK", ) - group_b_admin = self.create_user(username="group_b_admin") - group_a = self.create_group("group_a") - group_b = self.create_group("group_b") - self.dbsession.flush() - self.create_membership(regular_user, group_a) - self.create_membership(regular_user, group_b) - self.create_membership(group_b_admin, group_b, groupadmin=True) - self.dbsession.flush() + group_b_admin = UserFactory(username="group_b_admin") + group_a = GroupFactory() + group_b = GroupFactory() + UserGroupMembershipFactory( + user_id=regular_user.id, group_id=group_a.id + ) + UserGroupMembershipFactory( + user_id=regular_user.id, group_id=group_b.id + ) + UserGroupMembershipFactory( + user_id=group_b_admin.id, group_id=group_b.id, groupadmin=True + ) self.req._debugging_user = group_b_admin view = EditUserGroupAdminView(self.req) @@ -3480,12 +3284,11 @@ def test_get_form_values(self) -> None: self.assertEqual(form_values[ViewParam.GROUP_IDS], [group_b.id]) def test_raises_if_email_address_used_for_mfa(self) -> None: - regular_user = self.create_user( + regular_user = UserFactory( username="regular_user", mfa_method=MfaMethod.HOTP_EMAIL, email="user@example.com", ) - self.dbsession.flush() multidict = MultiDict( [ @@ -3509,11 +3312,9 @@ def test_raises_if_email_address_used_for_mfa(self) -> None: class EditOwnUserMfaViewTests(BasicDatabaseTestCase): def test_get_form_values_mfa_method(self) -> None: - regular_user = self.create_user( + regular_user = UserFactory( username="regular_user", mfa_method=MfaMethod.HOTP_SMS ) - self.dbsession.flush() - self.req._debugging_user = regular_user view = EditOwnUserMfaView(self.req) @@ -3527,13 +3328,11 @@ def test_get_form_values_mfa_method(self) -> None: ) def test_get_form_values_hotp_email(self) -> None: - regular_user = self.create_user( + regular_user = UserFactory( username="regular_user", mfa_method=MfaMethod.HOTP_EMAIL, email="regular_user@example.com", ) - self.dbsession.flush() - self.req._debugging_user = regular_user view = EditOwnUserMfaView(self.req) @@ -3556,13 +3355,11 @@ def test_get_form_values_hotp_email(self) -> None: self.assertEqual(form_values[ViewParam.EMAIL], regular_user.email) def test_get_form_values_hotp_sms(self) -> None: - regular_user = self.create_user( + regular_user = UserFactory( username="regular_user", mfa_method=MfaMethod.HOTP_SMS, - phone_number=phonenumbers.parse(TEST_PHONE_NUMBER), + phone_number=phonenumbers.parse(Fake.en_gb.valid_phone_number()), ) - self.dbsession.flush() - self.req._debugging_user = regular_user view = EditOwnUserMfaView(self.req) @@ -3587,11 +3384,9 @@ def test_get_form_values_hotp_sms(self) -> None: ) def test_get_form_values_totp(self) -> None: - regular_user = self.create_user( + regular_user = UserFactory( username="regular_user", mfa_method=MfaMethod.TOTP ) - self.dbsession.flush() - self.req._debugging_user = regular_user view = EditOwnUserMfaView(self.req) @@ -3613,13 +3408,11 @@ def test_get_form_values_totp(self) -> None: ) def test_user_can_set_secret_key(self) -> None: - regular_user = self.create_user(username="regular_user") + regular_user = UserFactory(username="regular_user") regular_user.mfa_method = MfaMethod.TOTP regular_user.ensure_mfa_info() # ... otherwise, the absence of e.g. the HOTP counter will cause a # secret key reset. - self.dbsession.flush() - mfa_secret_key = pyotp.random_base32() multidict = MultiDict( @@ -3640,9 +3433,7 @@ def test_user_can_set_secret_key(self) -> None: self.assertEqual(regular_user.mfa_secret_key, mfa_secret_key) def test_user_can_set_method_totp(self) -> None: - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() - + regular_user = UserFactory(username="regular_user") multidict = MultiDict( [ (ViewParam.MFA_METHOD, MfaMethod.TOTP), @@ -3660,9 +3451,7 @@ def test_user_can_set_method_totp(self) -> None: self.assertEqual(regular_user.mfa_method, MfaMethod.TOTP) def test_user_can_set_method_hotp_email(self) -> None: - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() - + regular_user = UserFactory(username="regular_user") multidict = MultiDict( [ (ViewParam.MFA_METHOD, MfaMethod.HOTP_EMAIL), @@ -3681,9 +3470,7 @@ def test_user_can_set_method_hotp_email(self) -> None: self.assertEqual(regular_user.hotp_counter, 0) def test_user_can_set_method_hotp_sms(self) -> None: - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() - + regular_user = UserFactory(username="regular_user") multidict = MultiDict( [ (ViewParam.MFA_METHOD, MfaMethod.HOTP_SMS), @@ -3702,11 +3489,9 @@ def test_user_can_set_method_hotp_sms(self) -> None: self.assertEqual(regular_user.hotp_counter, 0) def test_user_can_disable_mfa(self) -> None: - regular_user = self.create_user( + regular_user = UserFactory( username="regular_user", mfa_method=MfaMethod.TOTP ) - self.dbsession.flush() - multidict = MultiDict( [ (ViewParam.MFA_METHOD, MfaMethod.NO_MFA), @@ -3730,13 +3515,14 @@ def test_user_can_disable_mfa(self) -> None: self.assertEqual(regular_user.mfa_method, MfaMethod.NO_MFA) def test_user_can_set_phone_number(self) -> None: - regular_user = self.create_user(username="regular_user") + regular_user = UserFactory(username="regular_user") regular_user.mfa_method = MfaMethod.HOTP_SMS - self.dbsession.flush() + + phone_number_str = Fake.en_gb.valid_phone_number() multidict = MultiDict( [ - (ViewParam.PHONE_NUMBER, TEST_PHONE_NUMBER), + (ViewParam.PHONE_NUMBER, phone_number_str), (FormAction.SUBMIT, "submit"), ] ) @@ -3749,16 +3535,15 @@ def test_user_can_set_phone_number(self) -> None: view.dispatch() - test_number = phonenumbers.parse(TEST_PHONE_NUMBER) - self.assertEqual(regular_user.phone_number, test_number) + self.assertEqual( + regular_user.phone_number, phonenumbers.parse(phone_number_str) + ) def test_user_can_set_email_address(self) -> None: - regular_user = self.create_user(username="regular_user") + regular_user = UserFactory(username="regular_user") # We're going to force this user to the e-mail verification step, so # we need to ensure it's set to use e-mail MFA: regular_user.mfa_method = MfaMethod.HOTP_EMAIL - self.dbsession.flush() - multidict = MultiDict( [ (ViewParam.EMAIL, "regular_user@example.com"), @@ -3797,8 +3582,7 @@ def test_raises_for_invalid_user(self) -> None: self.assertIn("Cannot find User with id:123", cm.exception.message) def test_raises_when_user_may_not_edit_other_user(self) -> None: - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() + regular_user = UserFactory(username="regular_user") multidict = MultiDict( [ ("__start__", "new_password:mapping"), @@ -3812,7 +3596,7 @@ def test_raises_when_user_may_not_edit_other_user(self) -> None: self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.USER_ID: str(self.user.id)}, set_method_get=False + {ViewParam.USER_ID: str(self.system_user.id)}, set_method_get=False ) view = ChangeOtherPasswordView(self.req) @@ -3822,12 +3606,13 @@ def test_raises_when_user_may_not_edit_other_user(self) -> None: self.assertIn("Nobody may edit the system user", cm.exception.message) def test_password_set(self) -> None: - groupadmin = self.create_user(username="groupadmin") - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() - self.create_membership(groupadmin, self.group, groupadmin=True) - self.create_membership(regular_user, self.group) - self.dbsession.flush() + groupadmin = UserFactory(username="groupadmin") + regular_user = UserFactory(username="regular_user") + group = GroupFactory() + UserGroupMembershipFactory( + user_id=groupadmin.id, group_id=group.id, groupadmin=True + ) + UserGroupMembershipFactory(user_id=regular_user.id, group_id=group.id) self.assertFalse(regular_user.must_change_password) @@ -3863,13 +3648,13 @@ def test_password_set(self) -> None: self.assertIn("Password changed for user 'regular_user'", messages[0]) def test_user_forced_to_change_password(self) -> None: - groupadmin = self.create_user(username="groupadmin") - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() - self.create_membership(groupadmin, self.group, groupadmin=True) - self.create_membership(regular_user, self.group) - self.dbsession.flush() - + groupadmin = UserFactory(username="groupadmin") + regular_user = UserFactory(username="regular_user") + group = GroupFactory() + UserGroupMembershipFactory( + user_id=groupadmin.id, group_id=group.id, groupadmin=True + ) + UserGroupMembershipFactory(user_id=regular_user.id, group_id=group.id) multidict = MultiDict( [ (ViewParam.MUST_CHANGE_PASSWORD, "true"), @@ -3898,8 +3683,7 @@ def test_user_forced_to_change_password(self) -> None: mock_force_change.assert_called_once() def test_redirects_if_editing_own_account(self) -> None: - superuser = self.create_user(username="admin", superuser=True) - self.dbsession.flush() + superuser = UserFactory(username="admin", superuser=True) self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(superuser.id)}, set_method_get=False @@ -3919,7 +3703,7 @@ def test_redirects_if_editing_own_account(self) -> None: def test_user_sees_otp_form_if_mfa_setup( self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock ) -> None: - superuser = self.create_user( + superuser = UserFactory( username="admin", superuser=True, email="admin@example.com", @@ -3928,8 +3712,7 @@ def test_user_sees_otp_form_if_mfa_setup( hotp_counter=0, ) - user = self.create_user(username="user") - self.dbsession.flush() + user = UserFactory(username="user") self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(user.id)}, set_method_get=False @@ -3951,19 +3734,17 @@ def test_code_sent_if_mfa_setup(self) -> None: SmsBackendNames.CONSOLE, {} ) - phone_number = phonenumbers.parse(TEST_PHONE_NUMBER) - superuser = self.create_user( + phone_number_str = Fake.en_gb.valid_phone_number() + superuser = UserFactory( username="admin", superuser=True, email="admin@example.com", - phone_number=phone_number, + phone_number=phonenumbers.parse(phone_number_str), mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.HOTP_SMS, hotp_counter=0, ) - user = self.create_user(username="user", email="user@example.com") - self.dbsession.flush() - + user = UserFactory(username="user", email="user@example.com") self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(user.id)}, set_method_get=False @@ -3977,12 +3758,12 @@ def test_code_sent_if_mfa_setup(self) -> None: expected_message = f"Your CamCOPS verification code is {expected_code}" self.assertIn( - ConsoleSmsBackend.make_msg(TEST_PHONE_NUMBER, expected_message), + ConsoleSmsBackend.make_msg(phone_number_str, expected_message), logging_cm.output[0], ) def test_user_can_enter_token(self) -> None: - superuser = self.create_user( + superuser = UserFactory( username="admin", superuser=True, mfa_method=MfaMethod.HOTP_EMAIL, @@ -3990,9 +3771,7 @@ def test_user_can_enter_token(self) -> None: email="user@example.com", hotp_counter=1, ) - user = self.create_user(username="user", email="user@example.com") - self.dbsession.flush() - + user = UserFactory(username="user", email="user@example.com") self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(user.id)}, set_method_get=False @@ -4021,7 +3800,7 @@ def test_user_can_enter_token(self) -> None: ) def test_form_state_cleared_on_invalid_token(self) -> None: - superuser = self.create_user( + superuser = UserFactory( username="superuser", superuser=True, mfa_method=MfaMethod.HOTP_EMAIL, @@ -4029,9 +3808,7 @@ def test_form_state_cleared_on_invalid_token(self) -> None: email="user@example.com", hotp_counter=1, ) - user = self.create_user(username="user", email="user@example.com") - self.dbsession.flush() - + user = UserFactory(username="user", email="user@example.com") self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(user.id)}, set_method_get=False @@ -4059,15 +3836,13 @@ def test_form_state_cleared_on_invalid_token(self) -> None: def test_cannot_change_password_if_timed_out(self) -> None: self.req.config.mfa_timeout_s = 600 - superuser = self.create_user( + superuser = UserFactory( username="admin", superuser=True, mfa_method=MfaMethod.TOTP, mfa_secret_key=pyotp.random_base32(), ) - user = self.create_user(username="user") - self.dbsession.flush() - + user = UserFactory(username="user") self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(user.id)}, set_method_get=False @@ -4119,14 +3894,13 @@ def test_raises_for_invalid_user(self) -> None: self.assertIn("Cannot find User with id:123", cm.exception.message) def test_raises_when_user_may_not_edit_other_user(self) -> None: - regular_user = self.create_user(username="regular_user") - self.dbsession.flush() + regular_user = UserFactory(username="regular_user") multidict = MultiDict([(FormAction.SUBMIT, "submit")]) self.req._debugging_user = regular_user self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.USER_ID: str(self.user.id)}, set_method_get=False + {ViewParam.USER_ID: str(self.system_user.id)}, set_method_get=False ) view = EditOtherUserMfaView(self.req) @@ -4136,15 +3910,15 @@ def test_raises_when_user_may_not_edit_other_user(self) -> None: self.assertIn("Nobody may edit the system user", cm.exception.message) def test_disable_mfa(self) -> None: - groupadmin = self.create_user(username="groupadmin") - regular_user = self.create_user( + groupadmin = UserFactory(username="groupadmin") + regular_user = UserFactory( username="regular_user", mfa_method=MfaMethod.TOTP ) - self.dbsession.flush() - self.create_membership(groupadmin, self.group, groupadmin=True) - self.create_membership(regular_user, self.group) - self.dbsession.flush() - + group = GroupFactory() + UserGroupMembershipFactory( + user_id=groupadmin.id, group_id=group.id, groupadmin=True + ) + UserGroupMembershipFactory(user_id=regular_user.id, group_id=group.id) self.assertFalse(regular_user.must_change_password) multidict = MultiDict( @@ -4171,8 +3945,7 @@ def test_disable_mfa(self) -> None: ) def test_redirects_if_editing_own_account(self) -> None: - superuser = self.create_user(username="admin", superuser=True) - self.dbsession.flush() + superuser = UserFactory(username="admin", superuser=True) self.req._debugging_user = superuser self.req.add_get_params({ViewParam.USER_ID: str(superuser.id)}) @@ -4190,7 +3963,7 @@ def test_redirects_if_editing_own_account(self) -> None: def test_user_sees_otp_form_if_mfa_setup( self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock ) -> None: - superuser = self.create_user( + superuser = UserFactory( username="admin", superuser=True, email="admin@example.com", @@ -4199,8 +3972,7 @@ def test_user_sees_otp_form_if_mfa_setup( hotp_counter=0, ) - user = self.create_user(username="user") - self.dbsession.flush() + user = UserFactory(username="user") self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(user.id)}, set_method_get=False @@ -4222,19 +3994,17 @@ def test_code_sent_if_mfa_setup(self) -> None: SmsBackendNames.CONSOLE, {} ) - phone_number = phonenumbers.parse(TEST_PHONE_NUMBER) - superuser = self.create_user( + phone_number_str = Fake.en_gb.valid_phone_number() + superuser = UserFactory( username="admin", superuser=True, email="admin@example.com", - phone_number=phone_number, + phone_number=phonenumbers.parse(phone_number_str), mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.HOTP_SMS, hotp_counter=0, ) - user = self.create_user(username="user", email="user@example.com") - self.dbsession.flush() - + user = UserFactory(username="user", email="user@example.com") self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(user.id)}, set_method_get=False @@ -4248,12 +4018,12 @@ def test_code_sent_if_mfa_setup(self) -> None: expected_message = f"Your CamCOPS verification code is {expected_code}" self.assertIn( - ConsoleSmsBackend.make_msg(TEST_PHONE_NUMBER, expected_message), + ConsoleSmsBackend.make_msg(phone_number_str, expected_message), logging_cm.output[0], ) def test_user_can_enter_token(self) -> None: - superuser = self.create_user( + superuser = UserFactory( username="admin", superuser=True, mfa_method=MfaMethod.HOTP_EMAIL, @@ -4261,9 +4031,7 @@ def test_user_can_enter_token(self) -> None: email="user@example.com", hotp_counter=1, ) - user = self.create_user(username="user", email="user@example.com") - self.dbsession.flush() - + user = UserFactory(username="user", email="user@example.com") self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(user.id)}, set_method_get=False @@ -4292,7 +4060,7 @@ def test_user_can_enter_token(self) -> None: ) def test_form_state_cleared_on_invalid_token(self) -> None: - superuser = self.create_user( + superuser = UserFactory( username="superuser", superuser=True, mfa_method=MfaMethod.HOTP_EMAIL, @@ -4300,9 +4068,7 @@ def test_form_state_cleared_on_invalid_token(self) -> None: email="user@example.com", hotp_counter=1, ) - user = self.create_user(username="user", email="user@example.com") - self.dbsession.flush() - + user = UserFactory(username="user", email="user@example.com") self.req._debugging_user = superuser self.req.add_get_params( {ViewParam.USER_ID: str(user.id)}, set_method_get=False @@ -4329,45 +4095,34 @@ def test_form_state_cleared_on_invalid_token(self) -> None: self.assert_state_is_clean() -class EditUserGroupMembershipViewTests(BasicDatabaseTestCase): - def setUp(self) -> None: - super().setUp() - - self.regular_user = User() - self.regular_user.username = "ruser" - self.regular_user.hashedpw = "" - self.dbsession.add(self.regular_user) - self.dbsession.flush() - - self.group_admin = User() - self.group_admin.username = "gadmin" - self.group_admin.hashedpw = "" - self.dbsession.add(self.group_admin) - self.dbsession.flush() +class EditUserGroupMembershipViewTests(DemoRequestTestCase): + def test_superuser_can_update_user_group_membership(self) -> None: + regular_user = UserFactory() + groupadmin = UserFactory() + group = GroupFactory() - admin_ugm = UserGroupMembership( - user_id=self.group_admin.id, group_id=self.group.id + UserGroupMembershipFactory( + user_id=groupadmin.id, + group_id=group.id, + groupadmin=True, ) - admin_ugm.groupadmin = True - self.dbsession.add(admin_ugm) - self.ugm = UserGroupMembership( - user_id=self.regular_user.id, group_id=self.group.id + ugm = UserGroupMembershipFactory( + user_id=regular_user.id, group_id=group.id ) - self.dbsession.add(self.ugm) - self.dbsession.commit() - def test_superuser_can_update_user_group_membership(self) -> None: - self.assertFalse(self.ugm.may_upload) - self.assertFalse(self.ugm.may_register_devices) - self.assertFalse(self.ugm.may_use_webviewer) - self.assertFalse(self.ugm.view_all_patients_when_unfiltered) - self.assertFalse(self.ugm.may_dump_data) - self.assertFalse(self.ugm.may_run_reports) - self.assertFalse(self.ugm.may_add_notes) - self.assertFalse(self.ugm.may_manage_patients) - self.assertFalse(self.ugm.may_email_patients) - self.assertFalse(self.ugm.groupadmin) + self.req._debugging_user = groupadmin + + self.assertFalse(ugm.may_upload) + self.assertFalse(ugm.may_register_devices) + self.assertFalse(ugm.may_use_webviewer) + self.assertFalse(ugm.view_all_patients_when_unfiltered) + self.assertFalse(ugm.may_dump_data) + self.assertFalse(ugm.may_run_reports) + self.assertFalse(ugm.may_add_notes) + self.assertFalse(ugm.may_manage_patients) + self.assertFalse(ugm.may_email_patients) + self.assertFalse(ugm.groupadmin) multidict = MultiDict( [ @@ -4387,35 +4142,49 @@ def test_superuser_can_update_user_group_membership(self) -> None: self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(self.ugm.id)}, + {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(ugm.id)}, set_method_get=False, ) with self.assertRaises(HTTPFound): edit_user_group_membership(self.req) - self.assertTrue(self.ugm.may_upload) - self.assertTrue(self.ugm.may_register_devices) - self.assertTrue(self.ugm.may_use_webviewer) - self.assertTrue(self.ugm.view_all_patients_when_unfiltered) - self.assertTrue(self.ugm.may_dump_data) - self.assertTrue(self.ugm.may_run_reports) - self.assertTrue(self.ugm.may_add_notes) - self.assertTrue(self.ugm.may_manage_patients) - self.assertTrue(self.ugm.may_email_patients) + self.assertTrue(ugm.may_upload) + self.assertTrue(ugm.may_register_devices) + self.assertTrue(ugm.may_use_webviewer) + self.assertTrue(ugm.view_all_patients_when_unfiltered) + self.assertTrue(ugm.may_dump_data) + self.assertTrue(ugm.may_run_reports) + self.assertTrue(ugm.may_add_notes) + self.assertTrue(ugm.may_manage_patients) + self.assertTrue(ugm.may_email_patients) def test_groupadmin_can_update_user_group_membership(self) -> None: - self.req._debugging_user = self.group_admin - - self.assertFalse(self.ugm.may_upload) - self.assertFalse(self.ugm.may_register_devices) - self.assertFalse(self.ugm.may_use_webviewer) - self.assertFalse(self.ugm.view_all_patients_when_unfiltered) - self.assertFalse(self.ugm.may_dump_data) - self.assertFalse(self.ugm.may_run_reports) - self.assertFalse(self.ugm.may_add_notes) - self.assertFalse(self.ugm.may_manage_patients) - self.assertFalse(self.ugm.may_email_patients) + regular_user = UserFactory() + groupadmin = UserFactory() + group = GroupFactory() + + UserGroupMembershipFactory( + user_id=groupadmin.id, + group_id=group.id, + groupadmin=True, + ) + + ugm = UserGroupMembershipFactory( + user_id=regular_user.id, group_id=group.id + ) + + self.req._debugging_user = groupadmin + + self.assertFalse(ugm.may_upload) + self.assertFalse(ugm.may_register_devices) + self.assertFalse(ugm.may_use_webviewer) + self.assertFalse(ugm.view_all_patients_when_unfiltered) + self.assertFalse(ugm.may_dump_data) + self.assertFalse(ugm.may_run_reports) + self.assertFalse(ugm.may_add_notes) + self.assertFalse(ugm.may_manage_patients) + self.assertFalse(ugm.may_email_patients) multidict = MultiDict( [ @@ -4434,33 +4203,45 @@ def test_groupadmin_can_update_user_group_membership(self) -> None: self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(self.ugm.id)}, + {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(ugm.id)}, set_method_get=False, ) with self.assertRaises(HTTPFound): edit_user_group_membership(self.req) - self.assertTrue(self.ugm.may_upload) - self.assertTrue(self.ugm.may_register_devices) - self.assertTrue(self.ugm.may_use_webviewer) - self.assertTrue(self.ugm.view_all_patients_when_unfiltered) - self.assertTrue(self.ugm.may_dump_data) - self.assertTrue(self.ugm.may_run_reports) - self.assertTrue(self.ugm.may_add_notes) - self.assertTrue(self.ugm.may_manage_patients) - self.assertTrue(self.ugm.may_email_patients) + self.assertTrue(ugm.may_upload) + self.assertTrue(ugm.may_register_devices) + self.assertTrue(ugm.may_use_webviewer) + self.assertTrue(ugm.view_all_patients_when_unfiltered) + self.assertTrue(ugm.may_dump_data) + self.assertTrue(ugm.may_run_reports) + self.assertTrue(ugm.may_add_notes) + self.assertTrue(ugm.may_manage_patients) + self.assertTrue(ugm.may_email_patients) def test_raises_if_cant_edit_user(self) -> None: - self.ugm.user_id = self.user.id - self.dbsession.add(self.ugm) - self.dbsession.commit() + system_user = User.get_system_user(self.dbsession) + groupadmin = UserFactory() + group = GroupFactory() + + UserGroupMembershipFactory( + user_id=groupadmin.id, + group_id=group.id, + groupadmin=True, + ) + + system_ugm = UserGroupMembershipFactory( + user_id=system_user.id, group_id=group.id + ) + + self.req._debugging_user = groupadmin multidict = MultiDict([(FormAction.SUBMIT, "submit")]) self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(self.ugm.id)}, + {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(system_ugm.id)}, set_method_get=False, ) @@ -4470,22 +4251,22 @@ def test_raises_if_cant_edit_user(self) -> None: self.assertIn("Nobody may edit the system user", cm.exception.message) def test_raises_if_cant_administer_group(self) -> None: - group_a = self.create_group("groupa") - group_b = self.create_group("groupb") + group_a = GroupFactory() + group_b = GroupFactory() - user1 = self.create_user(username="user1") - user2 = self.create_user(username="user2") - self.dbsession.flush() + user1 = UserFactory() + user2 = UserFactory() # User 1 is a group administrator for group A, # User 2 is a member if group A - self.create_membership(user1, group_a, groupadmin=True) - self.create_membership(user2, group_a), + UserGroupMembershipFactory( + user_id=user1.id, group_id=group_a.id, groupadmin=True + ) + UserGroupMembershipFactory(user_id=user2.id, group_id=group_a.id), # User 1 is not an administrator of group B # User 2 is a member of group B - ugm = self.create_membership(user2, group_b) - self.dbsession.commit() + ugm = UserGroupMembershipFactory(user_id=user2.id, group_id=group_b.id) multidict = MultiDict([(FormAction.SUBMIT, "submit")]) @@ -4505,11 +4286,26 @@ def test_raises_if_cant_administer_group(self) -> None: ) def test_cancel_returns_to_users_list(self) -> None: + regular_user = UserFactory() + groupadmin = UserFactory() + group = GroupFactory() + + UserGroupMembershipFactory( + user_id=groupadmin.id, + group_id=group.id, + groupadmin=True, + ) + + ugm = UserGroupMembershipFactory( + user_id=regular_user.id, group_id=group.id + ) + + self.req._debugging_user = groupadmin multidict = MultiDict([(FormAction.CANCEL, "cancel")]) self.req.fake_request_post_from_dict(multidict) self.req.add_get_params( - {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(self.ugm.id)}, + {ViewParam.USER_GROUP_MEMBERSHIP_ID: str(ugm.id)}, set_method_get=False, ) @@ -4530,8 +4326,12 @@ def setUp(self) -> None: def test_user_can_change_password(self) -> None: new_password = "monkeybusiness" - user = self.create_user(username="user", mfa_method=MfaMethod.NO_MFA) - user.set_password(self.req, "secret") + user = UserFactory( + username="user", + mfa_method=MfaMethod.NO_MFA, + password__request=self.req, + password="secret", + ) multidict = MultiDict( [ (ViewParam.OLD_PASSWORD, "secret"), @@ -4558,7 +4358,7 @@ def test_user_can_change_password(self) -> None: self.assert_state_is_finished() def test_user_sees_expiry_message(self) -> None: - user = self.create_user( + user = UserFactory( username="user", mfa_method=MfaMethod.NO_MFA, must_change_password=True, @@ -4584,7 +4384,7 @@ def test_password_must_differ(self) -> None: def test_user_sees_otp_form_if_mfa_setup( self, mock_make_email: mock.Mock, mock_send_msg: mock.Mock ) -> None: - user = self.create_user( + user = UserFactory( username="user", email="user@example.com", mfa_method=MfaMethod.HOTP_EMAIL, @@ -4608,11 +4408,11 @@ def test_code_sent_if_mfa_setup(self) -> None: self.req.config.sms_backend = get_sms_backend( SmsBackendNames.CONSOLE, {} ) - phone_number = phonenumbers.parse(TEST_PHONE_NUMBER) - user = self.create_user( + phone_number_str = Fake.en_gb.valid_phone_number() + user = UserFactory( username="user", email="user@example.com", - phone_number=phone_number, + phone_number=phonenumbers.parse(phone_number_str), mfa_secret_key=pyotp.random_base32(), mfa_method=MfaMethod.HOTP_SMS, hotp_counter=0, @@ -4627,21 +4427,20 @@ def test_code_sent_if_mfa_setup(self) -> None: expected_message = f"Your CamCOPS verification code is {expected_code}" self.assertIn( - ConsoleSmsBackend.make_msg(TEST_PHONE_NUMBER, expected_message), + ConsoleSmsBackend.make_msg(phone_number_str, expected_message), logging_cm.output[0], ) def test_user_can_enter_token(self) -> None: - user = self.create_user( + user = UserFactory( username="user", mfa_method=MfaMethod.HOTP_EMAIL, mfa_secret_key=pyotp.random_base32(), email="user@example.com", hotp_counter=1, + password__request=self.req, + password="secret", ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.req._debugging_user = user hotp = pyotp.HOTP(user.mfa_secret_key) @@ -4667,16 +4466,15 @@ def test_user_can_enter_token(self) -> None: ) def test_form_state_cleared_on_invalid_token(self) -> None: - user = self.create_user( + user = UserFactory( username="user", mfa_method=MfaMethod.HOTP_EMAIL, mfa_secret_key=pyotp.random_base32(), email="user@example.com", hotp_counter=1, + password__request=self.req, + password="secret", ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.req._debugging_user = user hotp = pyotp.HOTP(user.mfa_secret_key) @@ -4701,14 +4499,13 @@ def test_form_state_cleared_on_invalid_token(self) -> None: def test_cannot_change_password_if_timed_out(self) -> None: self.req.config.mfa_timeout_s = 600 - user = self.create_user( + user = UserFactory( username="user", mfa_method=MfaMethod.TOTP, mfa_secret_key=pyotp.random_base32(), + password__request=self.req, + password="secret", ) - user.set_password(self.req, "secret") - self.dbsession.flush() - self.req._debugging_user = user totp = pyotp.TOTP(user.mfa_secret_key) @@ -4734,3 +4531,58 @@ def test_cannot_change_password_if_timed_out(self) -> None: view.dispatch() mock_fail_timed_out.assert_called_once() + + +class AddUserTests(DemoRequestTestCase): + def setUp(self) -> None: + super().setUp() + + self.groupadmin = self.req._debugging_user = UserFactory() + + def test_user_created(self) -> None: + group_1 = GroupFactory() + group_2 = GroupFactory() + + UserGroupMembershipFactory( + user_id=self.groupadmin.id, group_id=group_1.id, groupadmin=True + ) + UserGroupMembershipFactory( + user_id=self.groupadmin.id, group_id=group_2.id, groupadmin=True + ) + + multidict = MultiDict( + [ + ("_charset_", UTF8), + ("__formid__", "deform"), + (ViewParam.CSRF_TOKEN, self.req.session.get_csrf_token()), + (ViewParam.USERNAME, "test"), + ("__start__", "new_password:mapping"), + (ViewParam.NEW_PASSWORD, "monkeybusiness"), + ("new_password-confirm", "monkeybusiness"), + ("__end__", "new_password:mapping"), + (ViewParam.MUST_CHANGE_PASSWORD, "true"), + ("__start__", "group_ids:sequence"), + ("group_id_sequence", str(group_1.id)), + ("group_id_sequence", str(group_2.id)), + ("__end__", "group_ids:sequence"), + (FormAction.SUBMIT, "submit"), + ] + ) + self.req.fake_request_post_from_dict(multidict) + + with self.assertRaises(HTTPFound): + add_user(self.req) + + user = ( + self.dbsession.query(User) + .filter( + User.username == "test", + ) + .one_or_none() + ) + + self.assertIsNotNone(user) + + self.assertTrue(user.must_change_password) + self.assertIn(group_1.id, user.group_ids) + self.assertIn(group_2.id, user.group_ids) diff --git a/server/camcops_server/tasks/tests/apeq_cpft_perinatal_tests.py b/server/camcops_server/tasks/tests/apeq_cpft_perinatal_tests.py index 7a24da711..5733eae09 100644 --- a/server/camcops_server/tasks/tests/apeq_cpft_perinatal_tests.py +++ b/server/camcops_server/tasks/tests/apeq_cpft_perinatal_tests.py @@ -25,33 +25,26 @@ """ -from typing import Generator, Optional - import pendulum -from camcops_server.cc_modules.cc_unittest import BasicDatabaseTestCase +from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase from camcops_server.tasks.apeq_cpft_perinatal import ( - APEQCPFTPerinatal, APEQCPFTPerinatalReport, ) - +from camcops_server.tasks.tests.factories import APEQCPFTPerinatalFactory # ============================================================================= # Unit tests # ============================================================================= -class APEQCPFTPerinatalReportTestCase(BasicDatabaseTestCase): +class APEQCPFTPerinatalReportTestCase(DemoRequestTestCase): COL_Q = 0 COL_TOTAL = 1 COL_RESPONSE_START = 2 COL_FF_WHY = 1 - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.id_sequence = self.get_id() - def setUp(self) -> None: super().setUp() @@ -61,48 +54,9 @@ def setUp(self) -> None: self.report.start_datetime = None self.report.end_datetime = None - @staticmethod - def get_id() -> Generator[int, None, None]: - i = 1 - - while True: - yield i - i += 1 - - def create_task( - self, - q1: Optional[int], - q2: Optional[int], - q3: Optional[int], - q4: Optional[int], - q5: Optional[int], - q6: Optional[int], - ff_rating: int, - ff_why: str = None, - comments: str = None, - era: str = None, - ) -> None: - task = APEQCPFTPerinatal() - self.apply_standard_task_fields(task) - task.id = next(self.id_sequence) - task.q1 = q1 - task.q2 = q2 - task.q3 = q3 - task.q4 = q4 - task.q5 = q5 - task.q6 = q6 - task.ff_rating = ff_rating - task.ff_why = ff_why - task.comments = comments - - if era is not None: - task.when_created = pendulum.parse(era) - - self.dbsession.add(task) - class APEQCPFTPerinatalReportTests(APEQCPFTPerinatalReportTestCase): - def create_tasks(self) -> None: + def setUp(self) -> None: """ Creates 20 tasks. Should give us: @@ -132,43 +86,102 @@ def create_tasks(self) -> None: 5 - 35% """ - # q1 q2 q3 q4 q5 q6 ff - self.create_task(0, 1, 0, 0, 2, 2, 5, ff_why="ff_5_1") - self.create_task( - 0, 1, 1, 0, 2, 2, 5, ff_why="ff_5_2", comments="comments_2" + super().setUp() + + APEQCPFTPerinatalFactory( + q1=0, q2=1, q3=0, q4=0, q5=2, q6=2, ff_rating=5, ff_why="ff_5_1" + ) + APEQCPFTPerinatalFactory( + q1=0, + q2=1, + q3=1, + q4=0, + q5=2, + q6=2, + ff_rating=5, + ff_why="ff_5_2", + comments="comments_2", + ) + APEQCPFTPerinatalFactory( + q1=0, q2=1, q3=1, q4=1, q5=2, q6=2, ff_rating=5 + ) + APEQCPFTPerinatalFactory( + q1=0, q2=1, q3=1, q4=1, q5=2, q6=2, ff_rating=5 + ) + APEQCPFTPerinatalFactory( + q1=0, + q2=1, + q3=1, + q4=1, + q5=2, + q6=2, + ff_rating=5, + comments="comments_5", + ) + + APEQCPFTPerinatalFactory( + q1=0, q2=1, q3=2, q4=1, q5=2, q6=2, ff_rating=5 + ) + APEQCPFTPerinatalFactory( + q1=0, q2=1, q3=2, q4=1, q5=1, q6=2, ff_rating=5 + ) + APEQCPFTPerinatalFactory( + q1=0, q2=1, q3=2, q4=1, q5=1, q6=2, ff_rating=4, ff_why="ff_4_1" + ) + APEQCPFTPerinatalFactory( + q1=0, q2=1, q3=2, q4=1, q5=1, q6=2, ff_rating=3 + ) + APEQCPFTPerinatalFactory( + q1=0, q2=1, q3=2, q4=1, q5=1, q6=1, ff_rating=3, ff_why="ff_3_1" + ) + + APEQCPFTPerinatalFactory( + q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=2, ff_why="ff_2_1" + ) + APEQCPFTPerinatalFactory( + q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=2 + ) + APEQCPFTPerinatalFactory( + q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=2, ff_why="ff_2_2" + ) + APEQCPFTPerinatalFactory( + q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=1, ff_why="ff_1_1" + ) + APEQCPFTPerinatalFactory( + q1=1, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=1, ff_why="ff_1_2" + ) + + APEQCPFTPerinatalFactory( + q1=2, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=0 + ) + APEQCPFTPerinatalFactory( + q1=2, q2=1, q3=2, q4=2, q5=1, q6=1, ff_rating=0 + ) + APEQCPFTPerinatalFactory( + q1=2, q2=1, q3=2, q4=2, q5=0, q6=None, ff_rating=0 + ) + APEQCPFTPerinatalFactory( + q1=2, q2=1, q3=2, q4=2, q5=0, q6=None, ff_rating=0 + ) + APEQCPFTPerinatalFactory( + q1=2, + q2=1, + q3=2, + q4=2, + q5=0, + q6=1, + ff_rating=0, + comments="comments_20", ) - self.create_task(0, 1, 1, 1, 2, 2, 5) - self.create_task(0, 1, 1, 1, 2, 2, 5) - self.create_task(0, 1, 1, 1, 2, 2, 5, comments="comments_5") - - self.create_task(0, 1, 2, 1, 2, 2, 5) - self.create_task(0, 1, 2, 1, 1, 2, 5) - self.create_task(0, 1, 2, 1, 1, 2, 4, ff_why="ff_4_1") - self.create_task(0, 1, 2, 1, 1, 2, 3) - self.create_task(0, 1, 2, 1, 1, 1, 3, ff_why="ff_3_1") - - self.create_task(1, 1, 2, 2, 1, 1, 2, ff_why="ff_2_1") - self.create_task(1, 1, 2, 2, 1, 1, 2) - self.create_task(1, 1, 2, 2, 1, 1, 2, ff_why="ff_2_2") - self.create_task(1, 1, 2, 2, 1, 1, 1, ff_why="ff_1_1") - self.create_task(1, 1, 2, 2, 1, 1, 1, ff_why="ff_1_2") - - self.create_task(2, 1, 2, 2, 1, 1, 0) - self.create_task(2, 1, 2, 2, 1, 1, 0) - self.create_task(2, 1, 2, 2, 0, None, 0) - self.create_task(2, 1, 2, 2, 0, None, 0) - self.create_task(2, 1, 2, 2, 0, 1, 0, comments="comments_20") - - self.dbsession.commit() def test_main_rows_contain_percentages(self) -> None: expected_percentages = [ - [20, 50, 25, 25], # q1 - [20, "", 100, ""], # q2 - [20, 5, 20, 75], # q3 - [20, 10, 40, 50], # q4 - [20, 15, 55, 30], # q5 - [18, "", 50, 50], # q6 + ["20", "50", "25", "25"], # q1 + ["20", "", "100", ""], # q2 + ["20", "5", "20", "75"], # q3 + ["20", "10", "40", "50"], # q4 + ["20", "15", "55", "30"], # q5 + ["18", "", "50", "50"], # q6 ] main_rows = self.report._get_main_rows(self.req) @@ -179,7 +192,7 @@ def test_main_rows_contain_percentages(self) -> None: for p in row[1:]: if p != "": - p = int(float(p)) + p = str(int(float(p))) percentages.append(p) @@ -229,73 +242,75 @@ def test_ff_why_rows_contain_reasons(self) -> None: def test_comments(self) -> None: expected_comments = ["comments_2", "comments_5", "comments_20"] + comments = self.report._get_comments(self.req) self.assertEqual(comments, expected_comments) class APEQCPFTPerinatalReportDateRangeTests(APEQCPFTPerinatalReportTestCase): - def create_tasks(self) -> None: - self.create_task( - 1, - 0, - 0, - 0, - 0, - 0, - 0, + def setUp(self) -> None: + super().setUp() + + APEQCPFTPerinatalFactory( + q1=1, + q2=0, + q3=0, + q4=0, + q5=0, + q6=0, + ff_rating=0, ff_why="ff why 1", comments="comments 1", - era="2018-10-01T00:00:00.000000+00:00", + when_created=pendulum.parse("2018-10-01"), ) - self.create_task( - 0, - 0, - 0, - 0, - 0, - 0, - 2, + APEQCPFTPerinatalFactory( + q1=0, + q2=0, + q3=0, + q4=0, + q5=0, + q6=0, + ff_rating=2, ff_why="ff why 2", comments="comments 2", - era="2018-10-02T00:00:00.000000+00:00", + when_created=pendulum.parse("2018-10-02"), ) - self.create_task( - 0, - 0, - 0, - 0, - 0, - 0, - 2, + APEQCPFTPerinatalFactory( + q1=0, + q2=0, + q3=0, + q4=0, + q5=0, + q6=0, + ff_rating=2, ff_why="ff why 3", comments="comments 3", - era="2018-10-03T00:00:00.000000+00:00", + when_created=pendulum.parse("2018-10-03"), ) - self.create_task( - 0, - 0, - 0, - 0, - 0, - 0, - 2, + APEQCPFTPerinatalFactory( + q1=0, + q2=0, + q3=0, + q4=0, + q5=0, + q6=0, + ff_rating=2, ff_why="ff why 4", comments="comments 4", - era="2018-10-04T00:00:00.000000+00:00", + when_created=pendulum.parse("2018-10-04"), ) - self.create_task( - 1, - 0, - 0, - 0, - 0, - 0, - 0, + APEQCPFTPerinatalFactory( + q1=1, + q2=0, + q3=0, + q4=0, + q5=0, + q6=0, + ff_rating=0, ff_why="ff why 5", comments="comments 5", - era="2018-10-05T00:00:00.000000+00:00", + when_created=pendulum.parse("2018-10-05"), ) - self.dbsession.commit() def test_main_rows_filtered_by_date(self) -> None: self.report.start_datetime = "2018-10-02T00:00:00.000000+00:00" diff --git a/server/camcops_server/tasks/tests/core10_tests.py b/server/camcops_server/tasks/tests/core10_tests.py index b089447f7..4fdf19139 100644 --- a/server/camcops_server/tasks/tests/core10_tests.py +++ b/server/camcops_server/tasks/tests/core10_tests.py @@ -27,86 +27,75 @@ import pendulum -from camcops_server.cc_modules.cc_patient import Patient -from camcops_server.cc_modules.tests.cc_report_tests import ( - AverageScoreReportTestCase, +from camcops_server.cc_modules.cc_testfactories import ( + PatientFactory, + UserFactory, ) +from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase from camcops_server.tasks.core10 import Core10, Core10Report +from camcops_server.tasks.tests.factories import Core10Factory -class Core10ReportTestCase(AverageScoreReportTestCase): +class Core10ReportTestCase(DemoRequestTestCase): + def setUp(self) -> None: + super().setUp() + + self.report = self.create_report() + self.req._debugging_user = UserFactory(superuser=True) + def create_report(self) -> Core10Report: return Core10Report(via_index=False) - def create_task( - self, - patient: Patient, - q1: int = 0, - q2: int = 0, - q3: int = 0, - q4: int = 0, - q5: int = 0, - q6: int = 0, - q7: int = 0, - q8: int = 0, - q9: int = 0, - q10: int = 0, - era: str = None, - ) -> None: - task = Core10() - self.apply_standard_task_fields(task) - task.id = next(self.task_id_sequence) - - task.patient_id = patient.id - - task.q1 = q1 - task.q2 = q2 - task.q3 = q3 - task.q4 = q4 - task.q5 = q5 - task.q6 = q6 - task.q7 = q7 - task.q8 = q8 - task.q9 = q9 - task.q10 = q10 - - if era is not None: - task.when_created = pendulum.parse(era) - # log.info(f"Creating task, when_created = {task.when_created}") - - self.dbsession.add(task) - - -class Core10ReportTests(Core10ReportTestCase): - def create_tasks(self) -> None: - self.patient_1 = self.create_patient(idnum_value=333) - self.patient_2 = self.create_patient(idnum_value=444) - self.patient_3 = self.create_patient(idnum_value=555) + +class Core10ReportTotalsTests(Core10ReportTestCase): + def setUp(self) -> None: + super().setUp() + + patient_1 = PatientFactory() + patient_2 = PatientFactory() + patient_3 = PatientFactory() # Initial average score = (8 + 6 + 4) / 3 = 6 # Latest average score = (2 + 3 + 4) / 3 = 3 - self.create_task( - patient=self.patient_1, q1=4, q2=4, era="2018-06-01" + Core10Factory( + patient=patient_1, + q1=4, + q2=4, + when_created=pendulum.parse("2018-06-01"), ) # Score 8 - self.create_task( - patient=self.patient_1, q7=1, q8=1, era="2018-10-04" + Core10Factory( + patient=patient_1, + q7=1, + q8=1, + when_created=pendulum.parse("2018-10-04"), ) # Score 2 - self.create_task( - patient=self.patient_2, q3=3, q4=3, era="2018-05-02" + Core10Factory( + patient=patient_2, + q3=3, + q4=3, + when_created=pendulum.parse("2018-05-02"), ) # Score 6 - self.create_task( - patient=self.patient_2, q3=2, q4=1, era="2018-10-03" + Core10Factory( + patient=patient_2, + q3=2, + q4=1, + when_created=pendulum.parse("2018-10-03"), ) # Score 3 - self.create_task( - patient=self.patient_3, q5=2, q6=2, era="2018-01-10" + Core10Factory( + patient=patient_3, + q5=2, + q6=2, + when_created=pendulum.parse("2018-01-10"), ) # Score 4 - self.create_task( - patient=self.patient_3, q9=1, q10=3, era="2018-10-01" + Core10Factory( + patient=patient_3, + q9=1, + q10=3, + when_created=pendulum.parse("2018-10-01"), ) # Score 4 - self.dbsession.commit() def test_row_has_totals_and_averages(self) -> None: pages = self.report.get_spreadsheet_pages(req=self.req) @@ -131,32 +120,49 @@ def test_no_rows_when_no_data(self) -> None: class Core10ReportDoubleCountingTests(Core10ReportTestCase): - def create_tasks(self) -> None: - self.patient_1 = self.create_patient(idnum_value=333) - self.patient_2 = self.create_patient(idnum_value=444) - self.patient_3 = self.create_patient(idnum_value=555) + def setUp(self) -> None: + super().setUp() + + patient_1 = PatientFactory() + patient_2 = PatientFactory() + patient_3 = PatientFactory() # Initial average score = (8 + 6 + 4) / 3 = 6 # Latest average score = ( 3 + 3) / 2 = 3 # Progress avg score = ( 3 + 1) / 2 = 2 ... NOT 3. - self.create_task( - patient=self.patient_1, q1=4, q2=4, era="2018-06-01" + + Core10Factory( + patient=patient_1, + q1=4, + q2=4, + when_created=pendulum.parse("2018-06-01"), ) # Score 8 - self.create_task( - patient=self.patient_2, q3=3, q4=3, era="2018-05-02" + Core10Factory( + patient=patient_2, + q3=3, + q4=3, + when_created=pendulum.parse("2018-05-02"), ) # Score 6 - self.create_task( - patient=self.patient_2, q3=2, q4=1, era="2018-10-03" + Core10Factory( + patient=patient_2, + q3=2, + q4=1, + when_created=pendulum.parse("2018-10-03"), ) # Score 3 - self.create_task( - patient=self.patient_3, q5=2, q6=2, era="2018-01-10" + Core10Factory( + patient=patient_3, + q5=2, + q6=2, + when_created=pendulum.parse("2018-01-10"), ) # Score 4 - self.create_task( - patient=self.patient_3, q9=1, q10=2, era="2018-10-01" + Core10Factory( + patient=patient_3, + q9=1, + q10=2, + when_created=pendulum.parse("2018-10-01"), ) # Score 3 - self.dbsession.commit() def test_record_does_not_appear_in_first_and_latest(self) -> None: pages = self.report.get_spreadsheet_pages(req=self.req) @@ -223,45 +229,73 @@ class Core10ReportDateRangeTests(Core10ReportTestCase): """ # noqa - def create_tasks(self) -> None: - self.patient_1 = self.create_patient(idnum_value=333) - self.patient_2 = self.create_patient(idnum_value=444) - self.patient_3 = self.create_patient(idnum_value=555) + def setUp(self) -> None: + super().setUp() + + patient_1 = PatientFactory() + patient_2 = PatientFactory() + patient_3 = PatientFactory() # 2018-06 average score = (8 + 6 + 4) / 3 = 6 # 2018-08 average score = (4 + 4 + 4) / 3 = 4 # 2018-10 average score = (2 + 3 + 4) / 3 = 3 - self.create_task( - patient=self.patient_1, q1=4, q2=4, era="2018-06-01" + Core10Factory( + patient=patient_1, + q1=4, + q2=4, + when_created=pendulum.parse("2018-06-01"), ) # Score 8 - self.create_task( - patient=self.patient_1, q7=3, q8=1, era="2018-08-01" + Core10Factory( + patient=patient_1, + q7=3, + q8=1, + when_created=pendulum.parse("2018-08-01"), ) # Score 4 - self.create_task( - patient=self.patient_1, q7=1, q8=1, era="2018-10-01" + Core10Factory( + patient=patient_1, + q7=1, + q8=1, + when_created=pendulum.parse("2018-10-01"), ) # Score 2 - self.create_task( - patient=self.patient_2, q3=3, q4=3, era="2018-06-01" + Core10Factory( + patient=patient_2, + q3=3, + q4=3, + when_created=pendulum.parse("2018-06-01"), ) # Score 6 - self.create_task( - patient=self.patient_2, q3=2, q4=2, era="2018-08-01" + Core10Factory( + patient=patient_2, + q3=2, + q4=2, + when_created=pendulum.parse("2018-08-01"), ) # Score 4 - self.create_task( - patient=self.patient_2, q3=1, q4=2, era="2018-10-01" + Core10Factory( + patient=patient_2, + q3=1, + q4=2, + when_created=pendulum.parse("2018-10-01"), ) # Score 3 - self.create_task( - patient=self.patient_3, q5=2, q6=2, era="2018-06-01" + Core10Factory( + patient=patient_3, + q5=2, + q6=2, + when_created=pendulum.parse("2018-06-01"), ) # Score 4 - self.create_task( - patient=self.patient_3, q9=1, q10=3, era="2018-08-01" + Core10Factory( + patient=patient_3, + q9=1, + q10=3, + when_created=pendulum.parse("2018-08-01"), ) # Score 4 - self.create_task( - patient=self.patient_3, q9=1, q10=3, era="2018-10-01" + Core10Factory( + patient=patient_3, + q9=1, + q10=3, + when_created=pendulum.parse("2018-10-01"), ) # Score 4 - self.dbsession.commit() self.dump_table( Core10.__tablename__, @@ -269,13 +303,8 @@ def create_tasks(self) -> None: ) def test_report_filtered_by_date_range(self) -> None: - # self.report.start_datetime = pendulum.parse("2018-05-01T00:00:00.000000+00:00") # noqa - self.report.start_datetime = pendulum.parse( - "2018-06-01T00:00:00.000000+00:00" - ) - self.report.end_datetime = pendulum.parse( - "2018-09-01T00:00:00.000000+00:00" - ) + self.report.start_datetime = "2018-06-01T00:00:00.000000+00:00" + self.report.end_datetime = "2018-09-01T00:00:00.000000+00:00" self.set_echo(True) pages = self.report.get_spreadsheet_pages(req=self.req) diff --git a/server/camcops_server/tasks/tests/factories.py b/server/camcops_server/tasks/tests/factories.py index f8d9e3cbf..673738812 100644 --- a/server/camcops_server/tasks/tests/factories.py +++ b/server/camcops_server/tasks/tests/factories.py @@ -29,13 +29,161 @@ import factory import pendulum +from typing import cast, TYPE_CHECKING +from camcops_server.cc_modules.cc_task import Task from camcops_server.cc_modules.cc_testfactories import ( + BlobFactory, + Fake, GenericTabletRecordFactory, ) +from camcops_server.tasks.ace3 import Ace3, MiniAce +from camcops_server.tasks.aims import Aims +from camcops_server.tasks.apeqpt import Apeqpt +from camcops_server.tasks.apeq_cpft_perinatal import APEQCPFTPerinatal +from camcops_server.tasks.aq import Aq +from camcops_server.tasks.asdas import Asdas +from camcops_server.tasks.audit import Audit, AuditC +from camcops_server.tasks.badls import Badls +from camcops_server.tasks.basdai import Basdai +from camcops_server.tasks.bdi import Bdi from camcops_server.tasks.bmi import Bmi +from camcops_server.tasks.bprs import Bprs +from camcops_server.tasks.bprse import Bprse +from camcops_server.tasks.cage import Cage +from camcops_server.tasks.cape42 import Cape42 +from camcops_server.tasks.caps import Caps +from camcops_server.tasks.cardinal_expdetthreshold import ( + CardinalExpDetThreshold, +) +from camcops_server.tasks.cardinal_expectationdetection import ( + CardinalExpectationDetection, +) +from camcops_server.tasks.cbir import CbiR +from camcops_server.tasks.ceca import CecaQ3 +from camcops_server.tasks.cesd import Cesd +from camcops_server.tasks.cesdr import Cesdr +from camcops_server.tasks.cet import Cet +from camcops_server.tasks.cgi_task import Cgi, CgiI +from camcops_server.tasks.cgisch import CgiSch +from camcops_server.tasks.chit import Chit +from camcops_server.tasks.cia import Cia +from camcops_server.tasks.cisr import Cisr +from camcops_server.tasks.ciwa import Ciwa +from camcops_server.tasks.contactlog import ContactLog +from camcops_server.tasks.cope import CopeBrief +from camcops_server.tasks.core10 import Core10 +from camcops_server.tasks.cpft_covid_medical import CpftCovidMedical +from camcops_server.tasks.cpft_lps import ( + CPFTLPSDischarge, + CPFTLPSReferral, + CPFTLPSResetResponseClock, +) +from camcops_server.tasks.cpft_research_preferences import ( + CpftResearchPreferences, +) +from camcops_server.tasks.dad import Dad +from camcops_server.tasks.das28 import Das28 +from camcops_server.tasks.dast import Dast +from camcops_server.tasks.deakin_s1_healthreview import DeakinS1HealthReview +from camcops_server.tasks.demoquestionnaire import DemoQuestionnaire +from camcops_server.tasks.demqol import Demqol, DemqolProxy +from camcops_server.tasks.diagnosis import ( + DiagnosisIcd10, + DiagnosisIcd10Item, + DiagnosisIcd9CM, + DiagnosisIcd9CMItem, +) +from camcops_server.tasks.distressthermometer import DistressThermometer +from camcops_server.tasks.edeq import Edeq +from camcops_server.tasks.elixhauserci import ElixhauserCI +from camcops_server.tasks.epds import Epds +from camcops_server.tasks.eq5d5l import Eq5d5l +from camcops_server.tasks.esspri import Esspri +from camcops_server.tasks.factg import Factg +from camcops_server.tasks.fast import Fast +from camcops_server.tasks.fft import Fft +from camcops_server.tasks.frs import Frs +from camcops_server.tasks.gad7 import Gad7 +from camcops_server.tasks.gaf import Gaf +from camcops_server.tasks.gbo import Gbogpc, Gbogras, Gbogres +from camcops_server.tasks.gds import Gds15 +from camcops_server.tasks.gmcpq import GMCPQ +from camcops_server.tasks.hads import Hads, HadsRespondent +from camcops_server.tasks.hama import Hama +from camcops_server.tasks.hamd import Hamd +from camcops_server.tasks.hamd7 import Hamd7 +from camcops_server.tasks.honos import Honos, Honos65, Honosca +from camcops_server.tasks.icd10depressive import Icd10Depressive +from camcops_server.tasks.icd10manic import Icd10Manic +from camcops_server.tasks.icd10mixed import Icd10Mixed +from camcops_server.tasks.icd10schizophrenia import Icd10Schizophrenia +from camcops_server.tasks.icd10schizotypal import Icd10Schizotypal +from camcops_server.tasks.icd10specpd import Icd10SpecPD +from camcops_server.tasks.ided3d import IDED3D +from camcops_server.tasks.iesr import Iesr +from camcops_server.tasks.ifs import Ifs +from camcops_server.tasks.irac import Irac +from camcops_server.tasks.isaaq10 import Isaaq10 +from camcops_server.tasks.isaaqed import IsaaqEd +from camcops_server.tasks.khandaker_insight_medical import ( + KhandakerInsightMedical, +) +from camcops_server.tasks.khandaker_mojo_medical import KhandakerMojoMedical +from camcops_server.tasks.khandaker_mojo_sociodemographics import ( + KhandakerMojoSociodemographics, +) +from camcops_server.tasks.khandaker_mojo_medicationtherapy import ( + KhandakerMojoMedicationTherapy, +) +from camcops_server.tasks.kirby_mcq import Kirby +from camcops_server.tasks.lynall_iam_life import LynallIamLifeEvents +from camcops_server.tasks.lynall_iam_medical import LynallIamMedicalHistory +from camcops_server.tasks.maas import Maas +from camcops_server.tasks.mast import Mast +from camcops_server.tasks.mds_updrs import MdsUpdrs +from camcops_server.tasks.mfi20 import Mfi20 +from camcops_server.tasks.moca import Moca +from camcops_server.tasks.nart import Nart +from camcops_server.tasks.npiq import NpiQ +from camcops_server.tasks.ors import Ors +from camcops_server.tasks.panss import Panss +from camcops_server.tasks.paradise24 import Paradise24 +from camcops_server.tasks.pbq import Pbq +from camcops_server.tasks.pcl5 import Pcl5 +from camcops_server.tasks.pcl import PclC, PclM, PclS +from camcops_server.tasks.pdss import Pdss +from camcops_server.tasks.perinatalpoem import PerinatalPoem +from camcops_server.tasks.photo import Photo, PhotoSequence +from camcops_server.tasks.phq15 import Phq15 +from camcops_server.tasks.phq8 import Phq8 from camcops_server.tasks.phq9 import Phq9 +from camcops_server.tasks.progressnote import ProgressNote +from camcops_server.tasks.pswq import Pswq +from camcops_server.tasks.psychiatricclerking import PsychiatricClerking +from camcops_server.tasks.qolbasic import QolBasic +from camcops_server.tasks.qolsg import QolSG +from camcops_server.tasks.rand36 import Rand36 +from camcops_server.tasks.rapid3 import Rapid3 +from camcops_server.tasks.service_satisfaction import ( + PatientSatisfaction, + ReferrerSatisfactionGen, + ReferrerSatisfactionSpec, +) +from camcops_server.tasks.sfmpq2 import Sfmpq2 +from camcops_server.tasks.shaps import Shaps +from camcops_server.tasks.slums import Slums +from camcops_server.tasks.smast import Smast +from camcops_server.tasks.srs import Srs +from camcops_server.tasks.suppsp import Suppsp +from camcops_server.tasks.wemwbs import Swemwbs, Wemwbs +from camcops_server.tasks.wsas import Wsas +from camcops_server.tasks.ybocs import Ybocs, YbocsSc +from camcops_server.tasks.zbi import Zbi12 + +if TYPE_CHECKING: + from factory.builder import Resolver class TaskFactory(GenericTabletRecordFactory): @@ -44,7 +192,11 @@ class Meta: @factory.lazy_attribute def when_created(self) -> pendulum.DateTime: - return pendulum.parse(self.default_iso_datetime) + datetime = cast( + pendulum.DateTime, pendulum.parse(self.default_iso_datetime) + ) + + return datetime class TaskHasPatientFactory(TaskFactory): @@ -53,6 +205,46 @@ class Meta: patient_id = None + @classmethod + def create(cls, *args, **kwargs) -> Task: + patient = kwargs.pop("patient", None) + if patient is not None: + if "patient_id" in kwargs: + raise TypeError( + "Both 'patient' and 'patient_id' keyword arguments " + "unexpectedly passed to a task factory. Use one or the " + "other." + ) + kwargs["patient_id"] = patient.id + + if "_device" not in kwargs: + kwargs["_device"] = patient._device + + if "_era" not in kwargs: + kwargs["_era"] = patient._era + + if "_group" not in kwargs: + kwargs["_group"] = patient._group + + if "_current" not in kwargs: + kwargs["_current"] = True + + return super().create(*args, **kwargs) + + +class APEQCPFTPerinatalFactory(TaskFactory): + class Meta: + model = APEQCPFTPerinatal + + id = factory.Sequence(lambda n: n) + + +class ApeqptFactory(TaskFactory): + class Meta: + model = Apeqpt + + id = factory.Sequence(lambda n: n) + class BmiFactory(TaskHasPatientFactory): class Meta: @@ -60,9 +252,1031 @@ class Meta: id = factory.Sequence(lambda n: n) + height_m = factory.LazyFunction(Fake.en_gb.height_m) + mass_kg = factory.LazyFunction(Fake.en_gb.mass_kg) + waist_cm = factory.LazyFunction(Fake.en_gb.waist_cm) + + +class Core10Factory(TaskHasPatientFactory): + class Meta: + model = Core10 + + id = factory.Sequence(lambda n: n) + + q1 = 0 + q2 = 0 + q3 = 0 + q4 = 0 + q5 = 0 + q6 = 0 + q7 = 0 + q8 = 0 + q9 = 0 + q10 = 0 + + +class DiagnosisIcd10Factory(TaskHasPatientFactory): + class Meta: + model = DiagnosisIcd10 + + id = factory.Sequence(lambda n: n) + + +class DiagnosisItemFactory(GenericTabletRecordFactory): + class Meta: + abstract = True + + +class DiagnosisIcd10ItemFactory(DiagnosisItemFactory): + class Meta: + model = DiagnosisIcd10Item + + id = factory.Sequence(lambda n: n) + + @classmethod + def create(cls, *args, **kwargs) -> DiagnosisIcd10Item: + diagnosis_icd10 = kwargs.pop("diagnosis_icd10", None) + if diagnosis_icd10 is not None: + if "diagnosis_icd10_id" in kwargs: + raise TypeError( + "Both 'diagnosis_icd10' and 'diagnosis_icd10_id' keyword " + "arguments unexpectedly passed to a task factory. Use one " + "or the other." + ) + kwargs["diagnosis_icd10_id"] = diagnosis_icd10.id + + if "_device" not in kwargs: + kwargs["_device"] = diagnosis_icd10._device + + if "_era" not in kwargs: + kwargs["_era"] = diagnosis_icd10._era + + if "_current" not in kwargs: + kwargs["_current"] = True + + return super().create(*args, **kwargs) + + +class DiagnosisIcd9CMFactory(TaskHasPatientFactory): + class Meta: + model = DiagnosisIcd9CM + + id = factory.Sequence(lambda n: n) + + +class DiagnosisIcd9CMItemFactory(DiagnosisItemFactory): + class Meta: + model = DiagnosisIcd9CMItem + + id = factory.Sequence(lambda n: n) + + @classmethod + def create(cls, *args, **kwargs) -> DiagnosisIcd9CMItem: + diagnosis_icd9cm = kwargs.pop("diagnosis_icd9cm", None) + if diagnosis_icd9cm is not None: + if "diagnosis_icd9cm_id" in kwargs: + raise TypeError( + "Both 'diagnosis_icd9cm' and 'diagnosis_icd9cm_id' " + "keyword arguments unexpectedly passed to a task factory. " + "Use one or the other." + ) + kwargs["diagnosis_icd9cm_id"] = diagnosis_icd9cm.id + + if "_device" not in kwargs: + kwargs["_device"] = diagnosis_icd9cm._device + + if "_era" not in kwargs: + kwargs["_era"] = diagnosis_icd9cm._era + + if "_current" not in kwargs: + kwargs["_current"] = True + + return super().create(*args, **kwargs) + + +class Gad7Factory(TaskHasPatientFactory): + class Meta: + model = Gad7 + + id = factory.Sequence(lambda n: n) + + +class KhandakerMojoMedicationTherapyFactory(TaskHasPatientFactory): + class Meta: + model = KhandakerMojoMedicationTherapy + + id = factory.Sequence(lambda n: n) + + +class MaasFactory(TaskHasPatientFactory): + class Meta: + model = Maas + + id = factory.Sequence(lambda n: n) + + +class PerinatalPoemFactory(TaskFactory): + class Meta: + model = PerinatalPoem + + id = factory.Sequence(lambda n: n) + class Phq9Factory(TaskHasPatientFactory): class Meta: model = Phq9 id = factory.Sequence(lambda n: n) + + +class Ace3Factory(TaskHasPatientFactory): + class Meta: + model = Ace3 + + id = factory.Sequence(lambda n: n) + + +class AimsFactory(TaskHasPatientFactory): + class Meta: + model = Aims + + id = factory.Sequence(lambda n: n) + + +class AqFactory(TaskHasPatientFactory): + class Meta: + model = Aq + + id = factory.Sequence(lambda n: n) + + +class AsdasFactory(TaskHasPatientFactory): + class Meta: + model = Asdas + + id = factory.Sequence(lambda n: n) + + +class AuditFactory(TaskHasPatientFactory): + class Meta: + model = Audit + + id = factory.Sequence(lambda n: n) + + +class AuditCFactory(TaskHasPatientFactory): + class Meta: + model = AuditC + + id = factory.Sequence(lambda n: n) + + +class BadlsFactory(TaskHasPatientFactory): + class Meta: + model = Badls + + id = factory.Sequence(lambda n: n) + + +class BasdaiFactory(TaskHasPatientFactory): + class Meta: + model = Basdai + + id = factory.Sequence(lambda n: n) + + +class BdiFactory(TaskHasPatientFactory): + class Meta: + model = Bdi + + id = factory.Sequence(lambda n: n) + + +class BprsFactory(TaskHasPatientFactory): + class Meta: + model = Bprs + + id = factory.Sequence(lambda n: n) + + +class BprseFactory(TaskHasPatientFactory): + class Meta: + model = Bprse + + id = factory.Sequence(lambda n: n) + + +class CageFactory(TaskHasPatientFactory): + class Meta: + model = Cage + + id = factory.Sequence(lambda n: n) + + +class Cape42Factory(TaskHasPatientFactory): + class Meta: + model = Cape42 + + id = factory.Sequence(lambda n: n) + + +class CapsFactory(TaskHasPatientFactory): + class Meta: + model = Caps + + id = factory.Sequence(lambda n: n) + + +class CardinalExpectationDetectionFactory(TaskHasPatientFactory): + class Meta: + model = CardinalExpectationDetection + + id = factory.Sequence(lambda n: n) + + +class CardinalExpDetThresholdFactory(TaskHasPatientFactory): + class Meta: + model = CardinalExpDetThreshold + + id = factory.Sequence(lambda n: n) + + +class CbiRFactory(TaskHasPatientFactory): + class Meta: + model = CbiR + + id = factory.Sequence(lambda n: n) + + +class CecaQ3Factory(TaskHasPatientFactory): + class Meta: + model = CecaQ3 + + id = factory.Sequence(lambda n: n) + + +class CesdFactory(TaskHasPatientFactory): + class Meta: + model = Cesd + + id = factory.Sequence(lambda n: n) + + +class CesdrFactory(TaskHasPatientFactory): + class Meta: + model = Cesdr + + id = factory.Sequence(lambda n: n) + + +class CetFactory(TaskHasPatientFactory): + class Meta: + model = Cet + + id = factory.Sequence(lambda n: n) + + +class CgiFactory(TaskHasPatientFactory): + class Meta: + model = Cgi + + id = factory.Sequence(lambda n: n) + + +class CgiIFactory(TaskHasPatientFactory): + class Meta: + model = CgiI + + id = factory.Sequence(lambda n: n) + + +class CgiSchFactory(TaskHasPatientFactory): + class Meta: + model = CgiSch + + id = factory.Sequence(lambda n: n) + + +class ChitFactory(TaskHasPatientFactory): + class Meta: + model = Chit + + id = factory.Sequence(lambda n: n) + + +class CiaFactory(TaskHasPatientFactory): + class Meta: + model = Cia + + id = factory.Sequence(lambda n: n) + + +class CisrFactory(TaskHasPatientFactory): + class Meta: + model = Cisr + + id = factory.Sequence(lambda n: n) + + +class CiwaFactory(TaskHasPatientFactory): + class Meta: + model = Ciwa + + id = factory.Sequence(lambda n: n) + + +class ContactLogFactory(TaskHasPatientFactory): + class Meta: + model = ContactLog + + id = factory.Sequence(lambda n: n) + + +class CopeBriefFactory(TaskHasPatientFactory): + class Meta: + model = CopeBrief + + id = factory.Sequence(lambda n: n) + + +class CpftCovidMedicalFactory(TaskHasPatientFactory): + class Meta: + model = CpftCovidMedical + + id = factory.Sequence(lambda n: n) + + +class CPFTLPSDischargeFactory(TaskHasPatientFactory): + class Meta: + model = CPFTLPSDischarge + + id = factory.Sequence(lambda n: n) + + +class CPFTLPSReferralFactory(TaskHasPatientFactory): + class Meta: + model = CPFTLPSReferral + + id = factory.Sequence(lambda n: n) + + +class CPFTLPSResetResponseClockFactory(TaskHasPatientFactory): + class Meta: + model = CPFTLPSResetResponseClock + + id = factory.Sequence(lambda n: n) + + +class CpftResearchPreferencesFactory(TaskHasPatientFactory): + class Meta: + model = CpftResearchPreferences + + id = factory.Sequence(lambda n: n) + + +class DadFactory(TaskHasPatientFactory): + class Meta: + model = Dad + + id = factory.Sequence(lambda n: n) + + +class Das28Factory(TaskHasPatientFactory): + class Meta: + model = Das28 + + id = factory.Sequence(lambda n: n) + + +class DastFactory(TaskHasPatientFactory): + class Meta: + model = Dast + + id = factory.Sequence(lambda n: n) + + +class DeakinS1HealthReviewFactory(TaskHasPatientFactory): + class Meta: + model = DeakinS1HealthReview + + id = factory.Sequence(lambda n: n) + + +class DemoQuestionnaireFactory(TaskFactory): + class Meta: + model = DemoQuestionnaire + + id = factory.Sequence(lambda n: n) + + +class DemqolFactory(TaskHasPatientFactory): + class Meta: + model = Demqol + + id = factory.Sequence(lambda n: n) + + +class DemqolProxyFactory(TaskHasPatientFactory): + class Meta: + model = DemqolProxy + + id = factory.Sequence(lambda n: n) + + +class DistressThermometerFactory(TaskHasPatientFactory): + class Meta: + model = DistressThermometer + + id = factory.Sequence(lambda n: n) + + +class EdeqFactory(TaskHasPatientFactory): + class Meta: + model = Edeq + + id = factory.Sequence(lambda n: n) + + +class ElixhauserCIFactory(TaskHasPatientFactory): + class Meta: + model = ElixhauserCI + + id = factory.Sequence(lambda n: n) + + +class EpdsFactory(TaskHasPatientFactory): + class Meta: + model = Epds + + id = factory.Sequence(lambda n: n) + + +class Eq5d5lFactory(TaskHasPatientFactory): + class Meta: + model = Eq5d5l + + id = factory.Sequence(lambda n: n) + + +class EsspriFactory(TaskHasPatientFactory): + class Meta: + model = Esspri + + id = factory.Sequence(lambda n: n) + + +class FactgFactory(TaskHasPatientFactory): + class Meta: + model = Factg + + id = factory.Sequence(lambda n: n) + + +class FastFactory(TaskHasPatientFactory): + class Meta: + model = Fast + + id = factory.Sequence(lambda n: n) + + +class FftFactory(TaskHasPatientFactory): + class Meta: + model = Fft + + id = factory.Sequence(lambda n: n) + + +class FrsFactory(TaskHasPatientFactory): + class Meta: + model = Frs + + id = factory.Sequence(lambda n: n) + + +class GafFactory(TaskHasPatientFactory): + class Meta: + model = Gaf + + id = factory.Sequence(lambda n: n) + + +class GbogpcFactory(TaskHasPatientFactory): + class Meta: + model = Gbogpc + + id = factory.Sequence(lambda n: n) + + +class GbograsFactory(TaskHasPatientFactory): + class Meta: + model = Gbogras + + id = factory.Sequence(lambda n: n) + + +class GbogresFactory(TaskHasPatientFactory): + class Meta: + model = Gbogres + + id = factory.Sequence(lambda n: n) + + +class Gds15Factory(TaskHasPatientFactory): + class Meta: + model = Gds15 + + id = factory.Sequence(lambda n: n) + + +class GMCPQFactory(TaskFactory): + class Meta: + model = GMCPQ + + id = factory.Sequence(lambda n: n) + + +class HadsFactory(TaskHasPatientFactory): + class Meta: + model = Hads + + id = factory.Sequence(lambda n: n) + + +class HadsRespondentFactory(TaskHasPatientFactory): + class Meta: + model = HadsRespondent + + id = factory.Sequence(lambda n: n) + + +class HamaFactory(TaskHasPatientFactory): + class Meta: + model = Hama + + id = factory.Sequence(lambda n: n) + + +class HamdFactory(TaskHasPatientFactory): + class Meta: + model = Hamd + + id = factory.Sequence(lambda n: n) + + +class Hamd7Factory(TaskHasPatientFactory): + class Meta: + model = Hamd7 + + id = factory.Sequence(lambda n: n) + + +class HonosFactory(TaskHasPatientFactory): + class Meta: + model = Honos + + id = factory.Sequence(lambda n: n) + + +class Honos65Factory(TaskHasPatientFactory): + class Meta: + model = Honos65 + + id = factory.Sequence(lambda n: n) + + +class HonoscaFactory(TaskHasPatientFactory): + class Meta: + model = Honosca + + id = factory.Sequence(lambda n: n) + + +class Icd10DepressiveFactory(TaskHasPatientFactory): + class Meta: + model = Icd10Depressive + + id = factory.Sequence(lambda n: n) + + +class Icd10ManicFactory(TaskHasPatientFactory): + class Meta: + model = Icd10Manic + + id = factory.Sequence(lambda n: n) + + +class Icd10MixedFactory(TaskHasPatientFactory): + class Meta: + model = Icd10Mixed + + id = factory.Sequence(lambda n: n) + + +class Icd10SchizophreniaFactory(TaskHasPatientFactory): + class Meta: + model = Icd10Schizophrenia + + id = factory.Sequence(lambda n: n) + + +class Icd10SchizotypalFactory(TaskHasPatientFactory): + class Meta: + model = Icd10Schizotypal + + id = factory.Sequence(lambda n: n) + + +class Icd10SpecPDFactory(TaskHasPatientFactory): + class Meta: + model = Icd10SpecPD + + id = factory.Sequence(lambda n: n) + + +class IDED3DFactory(TaskHasPatientFactory): + class Meta: + model = IDED3D + + id = factory.Sequence(lambda n: n) + + +class IesrFactory(TaskHasPatientFactory): + class Meta: + model = Iesr + + id = factory.Sequence(lambda n: n) + + +class IfsFactory(TaskHasPatientFactory): + class Meta: + model = Ifs + + id = factory.Sequence(lambda n: n) + + +class IracFactory(TaskHasPatientFactory): + class Meta: + model = Irac + + id = factory.Sequence(lambda n: n) + + +class Isaaq10Factory(TaskHasPatientFactory): + class Meta: + model = Isaaq10 + + id = factory.Sequence(lambda n: n) + + +class IsaaqEdFactory(TaskHasPatientFactory): + class Meta: + model = IsaaqEd + + id = factory.Sequence(lambda n: n) + + +class KhandakerInsightMedicalFactory(TaskHasPatientFactory): + class Meta: + model = KhandakerInsightMedical + + id = factory.Sequence(lambda n: n) + + +class KhandakerMojoMedicalFactory(TaskHasPatientFactory): + class Meta: + model = KhandakerMojoMedical + + id = factory.Sequence(lambda n: n) + + +class KhandakerMojoSociodemographicsFactory(TaskHasPatientFactory): + class Meta: + model = KhandakerMojoSociodemographics + + id = factory.Sequence(lambda n: n) + + +class KirbyFactory(TaskHasPatientFactory): + class Meta: + model = Kirby + + id = factory.Sequence(lambda n: n) + + +class LynallIamMedicalHistoryFactory(TaskHasPatientFactory): + class Meta: + model = LynallIamMedicalHistory + + id = factory.Sequence(lambda n: n) + + +class LynallIamLifeEventsFactory(TaskHasPatientFactory): + class Meta: + model = LynallIamLifeEvents + + id = factory.Sequence(lambda n: n) + + +class MastFactory(TaskHasPatientFactory): + class Meta: + model = Mast + + id = factory.Sequence(lambda n: n) + + +class MdsUpdrsFactory(TaskHasPatientFactory): + class Meta: + model = MdsUpdrs + + id = factory.Sequence(lambda n: n) + + +class Mfi20Factory(TaskHasPatientFactory): + class Meta: + model = Mfi20 + + id = factory.Sequence(lambda n: n) + + +class MiniAceFactory(TaskHasPatientFactory): + class Meta: + model = MiniAce + + id = factory.Sequence(lambda n: n) + + +class MocaFactory(TaskHasPatientFactory): + class Meta: + model = Moca + + id = factory.Sequence(lambda n: n) + + +class NartFactory(TaskHasPatientFactory): + class Meta: + model = Nart + + id = factory.Sequence(lambda n: n) + + +class NpiQFactory(TaskHasPatientFactory): + class Meta: + model = NpiQ + + id = factory.Sequence(lambda n: n) + + +class OrsFactory(TaskHasPatientFactory): + class Meta: + model = Ors + + id = factory.Sequence(lambda n: n) + + +class PanssFactory(TaskHasPatientFactory): + class Meta: + model = Panss + + id = factory.Sequence(lambda n: n) + + +class Paradise24Factory(TaskHasPatientFactory): + class Meta: + model = Paradise24 + + id = factory.Sequence(lambda n: n) + + +class PbqFactory(TaskHasPatientFactory): + class Meta: + model = Pbq + + id = factory.Sequence(lambda n: n) + + +class Pcl5Factory(TaskHasPatientFactory): + class Meta: + model = Pcl5 + + id = factory.Sequence(lambda n: n) + + +class PclCFactory(TaskHasPatientFactory): + class Meta: + model = PclC + + id = factory.Sequence(lambda n: n) + + +class PclMFactory(TaskHasPatientFactory): + class Meta: + model = PclM + + id = factory.Sequence(lambda n: n) + + +class PclSFactory(TaskHasPatientFactory): + class Meta: + model = PclS + + id = factory.Sequence(lambda n: n) + + +class PdssFactory(TaskHasPatientFactory): + class Meta: + model = Pdss + + id = factory.Sequence(lambda n: n) + + +class PhotoFactory(TaskHasPatientFactory): + class Meta: + model = Photo + + id = factory.Sequence(lambda n: n) + + @factory.post_generation + def create_blob( + obj: "Resolver", create: bool, extracted: None, **kwargs + ) -> None: + if not create: + return + + obj.photo = BlobFactory.create( + tablename=obj.tablename, tablepk=obj.id, **kwargs + ) + + +class PhotoSequenceFactory(TaskHasPatientFactory): + class Meta: + model = PhotoSequence + + id = factory.Sequence(lambda n: n) + + +class Phq15Factory(TaskHasPatientFactory): + class Meta: + model = Phq15 + + id = factory.Sequence(lambda n: n) + + +class Phq8Factory(TaskHasPatientFactory): + class Meta: + model = Phq8 + + id = factory.Sequence(lambda n: n) + + +class ProgressNoteFactory(TaskHasPatientFactory): + class Meta: + model = ProgressNote + + id = factory.Sequence(lambda n: n) + + +class PswqFactory(TaskHasPatientFactory): + class Meta: + model = Pswq + + id = factory.Sequence(lambda n: n) + + +class PsychiatricClerkingFactory(TaskHasPatientFactory): + class Meta: + model = PsychiatricClerking + + id = factory.Sequence(lambda n: n) + + +class PatientSatisfactionFactory(TaskHasPatientFactory): + class Meta: + model = PatientSatisfaction + + id = factory.Sequence(lambda n: n) + + +class QolBasicFactory(TaskHasPatientFactory): + class Meta: + model = QolBasic + + id = factory.Sequence(lambda n: n) + + +class QolSGFactory(TaskHasPatientFactory): + class Meta: + model = QolSG + + id = factory.Sequence(lambda n: n) + + +class Rand36Factory(TaskHasPatientFactory): + class Meta: + model = Rand36 + + id = factory.Sequence(lambda n: n) + + +class Rapid3Factory(TaskHasPatientFactory): + class Meta: + model = Rapid3 + + id = factory.Sequence(lambda n: n) + + +class ReferrerSatisfactionGenFactory(TaskFactory): + class Meta: + model = ReferrerSatisfactionGen + + id = factory.Sequence(lambda n: n) + + +class ReferrerSatisfactionSpecFactory(TaskHasPatientFactory): + class Meta: + model = ReferrerSatisfactionSpec + + id = factory.Sequence(lambda n: n) + + +class Sfmpq2Factory(TaskHasPatientFactory): + class Meta: + model = Sfmpq2 + + id = factory.Sequence(lambda n: n) + + +class ShapsFactory(TaskHasPatientFactory): + class Meta: + model = Shaps + + id = factory.Sequence(lambda n: n) + + +class SlumsFactory(TaskHasPatientFactory): + class Meta: + model = Slums + + id = factory.Sequence(lambda n: n) + + +class SmastFactory(TaskHasPatientFactory): + class Meta: + model = Smast + + id = factory.Sequence(lambda n: n) + + +class SrsFactory(TaskHasPatientFactory): + class Meta: + model = Srs + + id = factory.Sequence(lambda n: n) + + +class SuppspFactory(TaskHasPatientFactory): + class Meta: + model = Suppsp + + id = factory.Sequence(lambda n: n) + + +class SwemwbsFactory(TaskHasPatientFactory): + class Meta: + model = Swemwbs + + id = factory.Sequence(lambda n: n) + + +class WemwbsFactory(TaskHasPatientFactory): + class Meta: + model = Wemwbs + + id = factory.Sequence(lambda n: n) + + +class WsasFactory(TaskHasPatientFactory): + class Meta: + model = Wsas + + id = factory.Sequence(lambda n: n) + + +class YbocsFactory(TaskHasPatientFactory): + class Meta: + model = Ybocs + + id = factory.Sequence(lambda n: n) + + +class YbocsScFactory(TaskHasPatientFactory): + class Meta: + model = YbocsSc + + id = factory.Sequence(lambda n: n) + + +class Zbi12Factory(TaskHasPatientFactory): + class Meta: + model = Zbi12 + + id = factory.Sequence(lambda n: n) diff --git a/server/camcops_server/tasks/tests/maas_tests.py b/server/camcops_server/tasks/tests/maas_tests.py index 9b0566991..b8c45f89c 100644 --- a/server/camcops_server/tasks/tests/maas_tests.py +++ b/server/camcops_server/tasks/tests/maas_tests.py @@ -27,48 +27,53 @@ import pendulum -from camcops_server.cc_modules.cc_patient import Patient -from camcops_server.cc_modules.tests.cc_report_tests import ( - AverageScoreReportTestCase, +from camcops_server.cc_modules.cc_testfactories import ( + PatientFactory, + UserFactory, ) +from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase from camcops_server.tasks.maas import Maas, MaasReport +from camcops_server.tasks.tests.factories import MaasFactory -class MaasReportTests(AverageScoreReportTestCase): +class MaasReportTests(DemoRequestTestCase): PROGRESS_COL = 4 - def create_report(self) -> MaasReport: - return MaasReport(via_index=False) + def setUp(self) -> None: + super().setUp() - def create_tasks(self) -> None: - self.patient_1 = self.create_patient() + self.report = self.create_report() + self.req._debugging_user = UserFactory(superuser=True) - self.create_task( - patient=self.patient_1, q1=2, q2=2, era="2019-03-01" - ) # total 17 + 2 + 2 - self.create_task( - patient=self.patient_1, q1=5, q2=5, era="2019-06-01" - ) # total 17 + 5 + 5 - self.dbsession.commit() + patient = PatientFactory() - def create_task(self, patient: Patient, era: str = None, **kwargs) -> None: - task = Maas() - self.apply_standard_task_fields(task) - task.id = next(self.task_id_sequence) + # Default to answering 1 to everything + response_dict = {q: 1 for q in Maas.TASK_FIELDS} - task.patient_id = patient.id - for fieldname in Maas.TASK_FIELDS: - value = kwargs.get(fieldname, 1) - setattr(task, fieldname, value) + # Same patient completing the task at different intervals + response_dict["q1"] = 2 + response_dict["q2"] = 2 + MaasFactory( + patient=patient, + when_created=pendulum.parse("2019-03-01"), + **response_dict, + ) # total 17 * 1 + 2 + 2 = 21 - if era is not None: - task.when_created = pendulum.parse(era) + response_dict["q1"] = 5 + response_dict["q2"] = 5 + MaasFactory( + patient=patient, + when_created=pendulum.parse("2019-06-01"), + **response_dict, + ) # total 17 * 1 + 5 + 5 = 27 - self.dbsession.add(task) + def create_report(self) -> MaasReport: + return MaasReport(via_index=False) def test_average_progress_is_positive(self) -> None: pages = self.report.get_spreadsheet_pages(req=self.req) + # Numbers as above expected_progress = 27 - 21 actual_progress = pages[0].plainrows[0][self.PROGRESS_COL] diff --git a/server/camcops_server/tasks/tests/perinatalpoem_tests.py b/server/camcops_server/tasks/tests/perinatalpoem_tests.py index 0689b84f4..975839dfa 100644 --- a/server/camcops_server/tasks/tests/perinatalpoem_tests.py +++ b/server/camcops_server/tasks/tests/perinatalpoem_tests.py @@ -25,27 +25,16 @@ """ -from typing import Generator - -import pendulum - -from camcops_server.cc_modules.cc_unittest import BasicDatabaseTestCase -from camcops_server.tasks.perinatalpoem import ( - PerinatalPoem, - PerinatalPoemReport, -) - +from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase +from camcops_server.tasks.perinatalpoem import PerinatalPoemReport +from camcops_server.tasks.tests.factories import PerinatalPoemFactory # ============================================================================= # Unit tests # ============================================================================= -class PerinatalPoemReportTestCase(BasicDatabaseTestCase): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.id_sequence = self.get_id() - +class PerinatalPoemReportTestCase(DemoRequestTestCase): def setUp(self) -> None: super().setUp() @@ -55,28 +44,6 @@ def setUp(self) -> None: self.report.start_datetime = None self.report.end_datetime = None - @staticmethod - def get_id() -> Generator[int, None, None]: - i = 1 - - while True: - yield i - i += 1 - - def create_task(self, **kwargs) -> None: - task = PerinatalPoem() - self.apply_standard_task_fields(task) - task.id = next(self.id_sequence) - - era = kwargs.pop("era", None) - if era is not None: - task.when_created = pendulum.parse(era) - - for name, value in kwargs.items(): - setattr(task, name, value) - - self.dbsession.add(task) - class PerinatalPoemReportTests(PerinatalPoemReportTestCase): """ @@ -84,10 +51,16 @@ class PerinatalPoemReportTests(PerinatalPoemReportTestCase): sanity checking here """ - def create_tasks(self): - self.create_task(general_comments="comment 1") - self.create_task(general_comments="comment 2") - self.create_task(general_comments="comment 3") + def setUp(self) -> None: + super().setUp() + + t1 = PerinatalPoemFactory(general_comments="comment 1") + t2 = PerinatalPoemFactory(general_comments="comment 2") + t3 = PerinatalPoemFactory(general_comments="comment 3") + + self.dbsession.add(t1) + self.dbsession.add(t2) + self.dbsession.add(t3) self.dbsession.commit() @@ -146,27 +119,35 @@ def test_comments(self) -> None: class PerinatalPoemReportDateRangeTests(PerinatalPoemReportTestCase): - def create_tasks(self) -> None: - self.create_task( + def setUp(self) -> None: + super().setUp() + + t1 = PerinatalPoemFactory( general_comments="comments 1", - era="2018-10-01T00:00:00.000000+00:00", + when_created="2018-10-01T00:00:00.000000+00:00", ) - self.create_task( + t2 = PerinatalPoemFactory( general_comments="comments 2", - era="2018-10-02T00:00:00.000000+00:00", + when_created="2018-10-02T00:00:00.000000+00:00", ) - self.create_task( + t3 = PerinatalPoemFactory( general_comments="comments 3", - era="2018-10-03T00:00:00.000000+00:00", + when_created="2018-10-03T00:00:00.000000+00:00", ) - self.create_task( + t4 = PerinatalPoemFactory( general_comments="comments 4", - era="2018-10-04T00:00:00.000000+00:00", + when_created="2018-10-04T00:00:00.000000+00:00", ) - self.create_task( + t5 = PerinatalPoemFactory( general_comments="comments 5", - era="2018-10-05T00:00:00.000000+00:00", + when_created="2018-10-05T00:00:00.000000+00:00", ) + self.dbsession.add(t1) + self.dbsession.add(t2) + self.dbsession.add(t3) + self.dbsession.add(t4) + self.dbsession.add(t5) + self.dbsession.commit() def test_comments_filtered_by_date(self) -> None: diff --git a/server/camcops_server/tools/generate_task_factories.py b/server/camcops_server/tools/generate_task_factories.py new file mode 100755 index 000000000..ab5e4f0eb --- /dev/null +++ b/server/camcops_server/tools/generate_task_factories.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python + +""" +camcops_server/tools/generate_task_factories.py + +=============================================================================== + + Copyright (C) 2012, University of Cambridge, Department of Psychiatry. + Created by Rudolf Cardinal (rnc1001@cam.ac.uk). + + This file is part of CamCOPS. + + CamCOPS is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + CamCOPS is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with CamCOPS. If not, see . + +=============================================================================== + +Script to generate skeleton Factory Boy test factories for +camcops_server/tasks/tests/factories.py + +Probably not needed anymore. + +""" + +from camcops_server.cc_modules.cc_task import Task, TaskHasPatientMixin +from camcops_server.tasks.tests import factories as task_factories + + +def main() -> None: + task_dict = {} + + for cls in Task.all_subclasses_by_tablename(): + task_class_name = cls.__name__ + factory_name = f"{task_class_name}Factory" + factory_class = getattr(task_factories, factory_name, None) + if factory_class is None: + task_dict[task_class_name.lower()] = task_class_name + if issubclass(cls, TaskHasPatientMixin): + sub_class_name = "TaskHasPatientFactory" + else: + sub_class_name = "TaskFactory" + + print( + f""" +class {factory_name}({sub_class_name}): + class Meta: + model = {task_class_name} + + id = factory.Sequence(lambda n: n) +""" + ) + + for filename, class_name in task_dict.items(): + print(f"from camcops_server.tasks.{filename} import {class_name}") + + +if __name__ == "__main__": + main()