Skip to content

Commit

Permalink
Merge pull request #110 from federicoemartinez/issue/109
Browse files Browse the repository at this point in the history
adding a commit=True kwarg to save and other methods of active record
  • Loading branch information
michaelbukachi authored Jul 31, 2023
2 parents cbd35a6 + 59285d3 commit e3f9d38
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 16 deletions.
38 changes: 22 additions & 16 deletions sqlalchemy_mixins/activerecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,50 +23,56 @@ def fill(self, **kwargs):

return self

def save(self):
def save(self, commit=True):
"""Saves the updated model to the current entity db.
:param commit: where to commit the transaction
"""
try:
self.session.add(self)
self.session.commit()
return self
except:
self.session.rollback()
raise
self.session.add(self)
if commit:
self._commit_or_fail()
return self

@classmethod
def create(cls, **kwargs):
def create(cls, commit=True, **kwargs):
"""Create and persist a new record for the model
:param commit: where to commit the transaction
:param kwargs: attributes for the record
:return: the new model instance
"""
return cls().fill(**kwargs).save()
return cls().fill(**kwargs).save(commit=commit)

def update(self, **kwargs):
def update(self, commit=True, **kwargs):
"""Same as :meth:`fill` method but persists changes to database.
:param commit: where to commit the transaction
"""
return self.fill(**kwargs).save()
return self.fill(**kwargs).save(commit=commit)

def delete(self):
def delete(self, commit=True):
"""Removes the model from the current entity session and mark for deletion.
:param commit: where to commit the transaction
"""
self.session.delete(self)
if commit:
self._commit_or_fail()

def _commit_or_fail(self):
try:
self.session.delete(self)
self.session.commit()
except:
self.session.rollback()
raise

@classmethod
def destroy(cls, *ids):
def destroy(cls, *ids, commit=True):
"""Delete the records with the given ids
:type ids: list
:param ids: primary key ids of records
:param commit: where to commit the transaction
"""
for pk in ids:
obj = cls.find(pk)
if obj:
obj.delete()
obj.delete(commit=commit)
cls.session.flush()

@classmethod
Expand Down
60 changes: 60 additions & 0 deletions sqlalchemy_mixins/tests/test_activerecord.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

import sqlalchemy
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.hybrid import hybrid_property
Expand Down Expand Up @@ -115,6 +116,30 @@ def test_fill_and_save(self):
self.assertEqual(p11, sess.query(Post).first())
self.assertEqual(p11.archived, True)

def test_save_commits(self):
with self.assertRaises(sqlalchemy.exc.InvalidRequestError):
with sess.begin():
u1 = User()
u1.fill(name='Bill u1')
u1.save()
u2 = User()
u2.fill(name='Bill u2')
u2.save()
self.assertEqual([u1, u2], sess.query(User).order_by(User.id.asc()).all())
# The first user is saved even when the block raises a Exception
self.assertEqual([u1], sess.query(User).order_by(User.id.asc()).all())

def test_save_do_not_commit(self):
with sess.begin():
u1 = User()
u1.fill(name='Bill u1')
u1.save(commit=False)
u2 = User()
u2.fill(name='Bill u2')
u2.save(commit=False)

self.assertEqual([u1,u2], sess.query(User).order_by(User.id.asc()).all())

def test_create(self):
u1 = User.create(name='Bill u1')
self.assertEqual(u1, sess.query(User).first())
Expand Down Expand Up @@ -158,6 +183,16 @@ def test_update(self):
self.assertEqual(sess.query(Post).get(11).public, True)
self.assertEqual(sess.query(Post).get(11).user, u2)

def test_update_no_commit(self):
u1 = User(name='Bill', id=1)
u1.save()
u1.update(name='Joe', commit=False)
self.assertEqual('Joe', sess.query(User).where(User.id==1).first().name)
sess.rollback()
self.assertEqual('Bill', sess.query(User).where(User.id==1).first().name)



def test_fill_wrong_attribute(self):
u1 = User(name='Bill u1')
sess.add(u1)
Expand All @@ -179,13 +214,32 @@ def test_delete(self):
u1.delete()
self.assertEqual(sess.query(User).get(1), None)

def test_delete_without_commit(self):
u1 = User()
u1.save()
u1.delete(commit=False)
self.assertIsNone(sess.query(User).one_or_none())
sess.rollback()
self.assertIsNotNone(sess.query(User).one_or_none())


def test_destroy(self):
u1, u2, p11, p12, p13 = self._seed()

self.assertEqual(set(sess.query(Post).all()), {p11, p12, p13})
Post.destroy(11, 12)
self.assertEqual(set(sess.query(Post).all()), {p13})


def test_destroy_no_commit(self):
u1, u2, p11, p12, p13 = self._seed()
sess.commit()
self.assertEqual(set(sess.query(Post).order_by(Post.id).all()), {p11, p12, p13})
Post.destroy(11, 12, commit=False)
self.assertEqual(set(sess.query(Post).order_by(Post.id).all()), {p13})
sess.rollback()
self.assertEqual(set(sess.query(Post).order_by(Post.id).all()), {p11, p12, p13})

def test_all(self):
u1, u2, p11, p12, p13 = self._seed()

Expand Down Expand Up @@ -231,6 +285,12 @@ def test_create(self):
u1 = UserAlternative.create(name='Bill u1')
self.assertEqual(u1, sess.query(UserAlternative).first())

def test_create_no_commit(self):
u1 = UserAlternative.create(name='Bill u1', commit=False)
self.assertEqual(u1, sess.query(UserAlternative).first())
sess.rollback()
self.assertIsNone(sess.query(UserAlternative).one_or_none())



if __name__ == '__main__': # pragma: no cover
Expand Down

0 comments on commit e3f9d38

Please sign in to comment.