Skip to content

Commit

Permalink
Fix DAL session management to support connection pooling (#911)
Browse files Browse the repository at this point in the history
* release v1.0.4

* Select person_id from patient table instead of person table.

* Update session handling to support connection pooling

* add logging to debug

* add logging to debug

* make initial_schema() idempotent

* initialize_schema during insert_match_patient

* copying cte_query_table

* add additional logging

* removed extra schema init

* Update mpi.py

* Update mpi.py

* Update mpi.py

* Update mpi.py

* Update mpi.py

* Update mpi.py

* remove print statements added for debuging purposes.

* black

* flake8

* Update phdi/linkage/mpi.py

Co-authored-by: Marcelle <[email protected]>

* fix mpi tests

---------

Co-authored-by: Brady Fausett at Skylight <[email protected]>
Co-authored-by: Marcelle <[email protected]>
  • Loading branch information
3 people authored Nov 7, 2023
1 parent c4b5712 commit b501f1d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 45 deletions.
21 changes: 10 additions & 11 deletions phdi/linkage/dal.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class DataAccessLayer(object):

def __init__(self) -> None:
self.engine = None
self.session = None
self.Meta = MetaData()
self.PATIENT_TABLE = None
self.PERSON_TABLE = None
Expand Down Expand Up @@ -63,10 +62,6 @@ def get_connection(
max_overflow=max_overflow,
)

self.session = scoped_session(
sessionmaker(bind=self.engine)
) # NOTE extra config can be implemented in this call to sessionmaker factory

def initialize_schema(self) -> None:
"""
Initialize the database schema
Expand Down Expand Up @@ -95,6 +90,7 @@ def initialize_schema(self) -> None:

# order of the list determines the order of
# inserts due to FK constraints
self.TABLE_LIST = []
self.TABLE_LIST.append(self.PERSON_TABLE)
self.TABLE_LIST.append(self.EXTERNAL_SOURCE_TABLE)
self.TABLE_LIST.append(self.EXTERNAL_PERSON_TABLE)
Expand All @@ -116,7 +112,7 @@ def transaction(self) -> None:
:yield: SQLAlchemy session object
:raises ValueError: if an error occurs during the transaction
"""
session = self.session()
session = self.get_session()

try:
yield session
Expand Down Expand Up @@ -230,7 +226,10 @@ def get_session(self) -> scoped_session:
:return: SQLAlchemy scoped session
"""

return self.session()
session = scoped_session(
sessionmaker(bind=self.engine)
) # NOTE extra config can be implemented in this call to sessionmaker factory
return session()

def get_table_by_name(self, table_name: str) -> Table:
"""
Expand All @@ -240,8 +239,8 @@ def get_table_by_name(self, table_name: str) -> Table:
:param table_name: the name of the table you want to get.
:return: SqlAlchemy ORM Table Object.
"""
if len(self.TABLE_LIST) == 0:
self.initialize_schema()

self.initialize_schema()

if table_name is not None and table_name != "":
# TODO: I am sure there is an easier way to do this
Expand All @@ -260,8 +259,8 @@ def get_table_by_column(self, column_name: str) -> Table:
table it belongs to.
:return: SqlAlchemy ORM Table Object.
"""
if len(self.TABLE_LIST) == 0:
self.initialize_schema()

self.initialize_schema()

if column_name is not None and column_name != "":
# TODO: I am sure there is an easier way to do this
Expand Down
13 changes: 7 additions & 6 deletions phdi/linkage/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from phdi.linkage.dal import DataAccessLayer
from phdi.fhir.utils import extract_value_with_resource_path
import uuid
import copy


class DIBBsMPIConnectorClient(BaseMPIConnectorClient):
Expand Down Expand Up @@ -168,6 +169,7 @@ def insert_matched_patient(
the patient record if a match has been found in the MPI, defaults to None.
:return: the person id
"""

try:
correct_person_id = self._get_person_id(
person_id=person_id, external_person_id=external_person_id
Expand Down Expand Up @@ -248,23 +250,22 @@ def _generate_block_query(
.cte(f"{table_key}_cte")
)
else:
fk_info = cte_query_table.foreign_keys.pop()
fk_query_table = copy.deepcopy(cte_query_table)
fk_info = fk_query_table.foreign_keys.pop()
fk_column = fk_info.column
fk_table = fk_info.column.table
sub_query = (
select(cte_query_table)
.where(text(" AND ".join(query_criteria)))
.subquery(f"{cte_query_table.name}_cte_subq")
)

cte_query = (
select(fk_table.c.patient_id.label("patient_id"))
.join(sub_query)
.where(
select(fk_table.c.patient_id).join(
sub_query,
text(
f"{fk_table.name}.{fk_column.name} = "
+ f"{sub_query.name}.{fk_column.name}"
)
),
)
).cte(f"{table_key}_cte")
if cte_query is not None:
Expand Down
20 changes: 12 additions & 8 deletions tests/linkage/test_dal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pathlib
from phdi.linkage.dal import DataAccessLayer
from sqlalchemy import Engine, Table, select, text
from sqlalchemy.orm import scoped_session
from phdi.linkage.mpi import DIBBsMPIConnectorClient


Expand Down Expand Up @@ -31,6 +30,7 @@ def _init_db() -> DataAccessLayer:
with dal.engine.connect() as db_conn:
db_conn.rollback()
dal.initialize_schema()
dal.get_session()
return dal


Expand All @@ -54,7 +54,6 @@ def test_init_dal():
dal = DataAccessLayer()

assert dal.engine is None
assert dal.session is None
assert dal.PATIENT_TABLE is None
assert dal.PERSON_TABLE is None

Expand All @@ -67,8 +66,6 @@ def test_get_connection():

assert dal.engine is not None
assert isinstance(dal.engine, Engine)
assert dal.session is not None
assert isinstance(dal.session, scoped_session)

assert dal.PATIENT_TABLE is None
assert dal.PERSON_TABLE is None
Expand All @@ -81,13 +78,19 @@ def test_get_connection():
assert dal.EXTERNAL_SOURCE_TABLE is None


def test_get_session():
dal = DataAccessLayer()
dal.get_connection(
engine_url="postgresql+psycopg2://postgres:pw@localhost:5432/testdb"
)
dal.get_session()


def test_initialize_schema():
dal = _init_db()

assert dal.engine is not None
assert isinstance(dal.engine, Engine)
assert dal.session is not None
assert isinstance(dal.session, scoped_session)
assert isinstance(dal.PATIENT_TABLE, Table)
assert isinstance(dal.PERSON_TABLE, Table)
assert isinstance(dal.NAME_TABLE, Table)
Expand Down Expand Up @@ -140,9 +143,10 @@ def test_bulk_insert_dict():
error_msg = error
finally:
assert error_msg != ""
query = dal.session.query(dal.PATIENT_TABLE)
session = dal.get_session()
query = session.query(dal.PATIENT_TABLE)
results = query.all()
dal.session.close()
session.close()
assert len(results) == 1
assert results[0].dob == datetime.date(1977, 11, 11)
assert results[0].sex == "male"
Expand Down
32 changes: 12 additions & 20 deletions tests/linkage/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,27 +296,19 @@ def test_generate_block_query():
}

expected_result2 = (
"WITH given_name_cte AS"
+ "(SELECT name.patient_id AS patient_id"
+ "FROM name JOIN (SELECT given_name.given_name_id "
+ "AS given_name_id, given_name.name_id AS name_id, "
+ "given_name.given_name AS given_name, "
+ "given_name.given_name_index AS given_name_index"
+ "FROM given_name"
+ "WHERE given_name.given_name = 'Homer') "
+ "AS given_name_cte_subq ON name.name_id = given_name_cte_subq.name_id"
+ "WHERE name.name_id = given_name_cte_subq.name_id),"
+ "name_cte AS"
+ "(SELECT name.patient_id AS patient_id"
+ "FROM name"
+ "WHERE name.last_name = 'Simpson')"
+ "SELECT patient.patient_id, patient.person_id,"
+ " patient.dob, patient.sex, patient.race, patient.ethnicity"
+ "FROM patient JOIN given_name_cte ON "
+ "given_name_cte.patient_id = patient.patient_id "
+ "JOIN name_cte ON name_cte.patient_id = patient.patient_id"
"WITH given_name_cte AS "
"(SELECT name.patient_id AS patient_id FROM name JOIN "
"(SELECT given_name.given_name_id AS given_name_id, "
"given_name.name_id AS name_id, given_name.given_name AS given_name, "
"given_name.given_name_index AS given_name_index FROM given_name "
"WHERE given_name.given_name = 'Homer') AS given_name_cte_subq "
"ON name.name_id = given_name_cte_subq.name_id), name_cte AS "
"(SELECT name.patient_id AS patient_id FROM name WHERE "
"name.last_name = 'Simpson') SELECT patient.patient_id, patient.person_id, "
"patient.dob, patient.sex, patient.race, patient.ethnicity FROM patient JOIN "
"given_name_cte ON given_name_cte.patient_id = patient.patient_id JOIN "
"name_cte ON name_cte.patient_id = patient.patient_id"
)

base_query2 = select(MPI.dal.PATIENT_TABLE)
my_query2 = MPI._generate_block_query(block_data2, base_query2)

Expand Down

0 comments on commit b501f1d

Please sign in to comment.