"""
-import datetime
import os
from io import StringIO, open
from unittest.mock import MagicMock
@@ -32,7 +31,7 @@
import rdiffweb.test
from rdiffweb.core import authorizedkeys
-from rdiffweb.core.model import DuplicateSSHKeyError, RepoObject, Token, UserObject
+from rdiffweb.core.model import DuplicateSSHKeyError, RepoObject, UserObject
from rdiffweb.core.passwd import check_password
@@ -73,7 +72,7 @@ def tearDown(self):
def test_add_user(self):
"""Add user to database."""
userobj = UserObject.add_user('joe')
- self.assertIsNotNone(userobj)
+ userobj.commit()
self.assertIsNotNone(UserObject.get_user('joe'))
# Check if listener called
self.listener.user_added.assert_called_once_with(userobj)
@@ -87,7 +86,7 @@ def change_user_obj(userobj):
self.listener.user_added.side_effect = change_user_obj
# When adding user
userobj = UserObject.add_user('joe')
- self.assertIsNotNone(userobj)
+ userobj.commit()
self.assertIsNotNone(UserObject.get_user('joe'))
# Then lister get called
self.listener.user_added.assert_called_once_with(userobj)
@@ -96,7 +95,8 @@ def change_user_obj(userobj):
def test_add_user_with_duplicate(self):
"""Add user to database."""
- UserObject.add_user('denise')
+ user = UserObject.add_user('denise')
+ user.commit()
self.listener.user_added.reset_mock()
with self.assertRaises(ValueError):
UserObject.add_user('denise')
@@ -106,6 +106,7 @@ def test_add_user_with_duplicate(self):
def test_add_user_with_password(self):
"""Add user to database with password."""
userobj = UserObject.add_user('jo', 'password')
+ userobj.commit()
self.assertIsNotNone(UserObject.get_user('jo'))
# Check if listener called
self.listener.user_added.assert_called_once_with(userobj)
@@ -120,7 +121,8 @@ def test_users(self):
# Check admin exists
self.assertEqual(1, UserObject.query.count())
# Create user.
- UserObject.add_user('annik')
+ user = UserObject.add_user('annik')
+ user.commit()
users = UserObject.query.all()
self.assertEqual(2, len(users))
self.assertEqual('annik', users[1].username)
@@ -134,9 +136,11 @@ def test_get_user(self):
user.role = UserObject.ADMIN_ROLE
user.email = 'bernie@gmail.com'
user.refresh_repos()
+ user.commit()
self.assertEqual(['broker-repo', 'testcases'], sorted([r.name for r in user.repo_objs]))
user.repo_objs[0].maxage = -1
user.repo_objs[1].maxage = 3
+ user.commit()
# Get user record.
obj = UserObject.get_user('bernie')
@@ -159,6 +163,7 @@ def test_get_user_with_invalid_user(self):
def test_get_set(self):
user = UserObject.add_user('larry', 'password')
+ user.add().commit()
self.assertEqual('', user.email)
self.assertEqual([], user.repo_objs)
@@ -168,18 +173,19 @@ def test_get_set(self):
user.user_root = self.testcases
user.refresh_repos()
+ user.commit()
self.listener.user_attr_changed.assert_called_with(user, {'user_root': ('', self.testcases)})
self.listener.user_attr_changed.reset_mock()
user = UserObject.get_user('larry')
user.role = UserObject.ADMIN_ROLE
- user.add()
+ user.commit()
self.listener.user_attr_changed.assert_called_with(
user, {'role': (UserObject.USER_ROLE, UserObject.ADMIN_ROLE)}
)
self.listener.user_attr_changed.reset_mock()
user = UserObject.get_user('larry')
user.email = 'larry@gmail.com'
- user.add()
+ user.commit()
self.listener.user_attr_changed.assert_called_with(user, {'email': ('', 'larry@gmail.com')})
self.listener.user_attr_changed.reset_mock()
@@ -192,11 +198,12 @@ def test_get_set(self):
def test_set_role_null(self):
# Given a user
user = UserObject.add_user('annik', 'password')
+ user.add().commit()
# When trying to set the role to null
user.role = None
# Then an exception is raised
with self.assertRaises(Exception):
- user.add()
+ user.add().commit()
@parameterized.expand(
[
@@ -212,7 +219,7 @@ def test_is_admin(self, role, expected_is_admin):
user = UserObject.add_user('annik', 'password')
# When setting the role value
user.role = role
- user.add()
+ user.commit()
# Then the is_admin value get updated too
self.assertEqual(expected_is_admin, user.is_admin)
@@ -230,16 +237,18 @@ def test_is_maintainer(self, role, expected_is_maintainer):
user = UserObject.add_user('annik', 'password')
# When setting the role value
user.role = role
- user.add()
+ user.commit()
# Then the is_admin value get updated too
self.assertEqual(expected_is_maintainer, user.is_maintainer)
def test_set_password_update(self):
# Given a user in database with a password
userobj = UserObject.add_user('annik', 'password')
+ userobj.commit()
self.listener.user_password_changed.reset_mock()
# When updating the user's password
userobj.set_password('new_password')
+ userobj.commit()
# Then password is SSHA
self.assertTrue(check_password('new_password', userobj.hash_password))
# Check if listener called
@@ -248,9 +257,11 @@ def test_set_password_update(self):
def test_delete_user(self):
# Given an existing user in database
userobj = UserObject.add_user('vicky')
+ userobj.commit()
self.assertIsNotNone(UserObject.get_user('vicky'))
# When deleting that user
userobj.delete()
+ userobj.commit()
# Then user it no longer in database
self.assertIsNone(UserObject.get_user('vicky'))
# Then listner was called
@@ -259,6 +270,7 @@ def test_delete_user(self):
def test_set_password_empty(self):
"""Expect error when trying to update password of invalid user."""
userobj = UserObject.add_user('john')
+ userobj.commit()
with self.assertRaises(ValueError):
self.assertFalse(userobj.set_password(''))
@@ -286,6 +298,7 @@ def test_add_authorizedkey_without_file(self):
# Add the key to the user
userobj = UserObject.get_user(self.USERNAME)
userobj.add_authorizedkey(key)
+ userobj.commit()
# validate
keys = list(userobj.authorizedkeys)
@@ -298,9 +311,11 @@ def test_add_authorizedkey_duplicate(self):
# Add the key to the user
userobj = UserObject.get_user(self.USERNAME)
userobj.add_authorizedkey(key)
+ userobj.commit()
# Add the same key
with self.assertRaises(DuplicateSSHKeyError):
userobj.add_authorizedkey(key)
+ userobj.commit()
def test_add_authorizedkey_with_file(self):
"""
@@ -316,6 +331,7 @@ def test_add_authorizedkey_with_file(self):
# Read the pub key
key = self._read_ssh_key()
userobj.add_authorizedkey(key)
+ userobj.commit()
# Validate
with open(filename, 'r') as fh:
@@ -341,6 +357,7 @@ def test_delete_authorizedkey_without_file(self):
# Remove a key
userobj.delete_authorizedkey("9a:f1:69:3c:bc:5a:cd:02:5e:33:bc:cd:c0:01:eb:4c")
+ userobj.commit()
# Validate
keys = list(userobj.authorizedkeys)
@@ -376,6 +393,7 @@ def test_repo_objs(self):
self.assertEqual(['broker-repo', 'testcases'], [r.name for r in repos])
# When deleting a repository empty list
repos[1].delete()
+ repos[1].commit()
# Then the repository is removed from the list.
self.assertEqual(['broker-repo'], sorted([r.name for r in userobj.repo_objs]))
@@ -383,10 +401,11 @@ def test_refresh_repos_without_delete(self):
# Given a user with invalid repositories
userobj = UserObject.get_user(self.USERNAME)
RepoObject.query.delete()
- RepoObject(userid=userobj.userid, repopath='invalid').add()
+ RepoObject(userid=userobj.userid, repopath='invalid').add().commit()
self.assertEqual(['invalid'], sorted([r.name for r in userobj.repo_objs]))
# When updating the repository list without deletion
userobj.refresh_repos()
+ userobj.commit()
# Then the list invlaid the invalid repo and new repos
self.assertEqual(['broker-repo', 'invalid', 'testcases'], sorted([r.name for r in userobj.repo_objs]))
@@ -394,10 +413,11 @@ def test_refresh_repos_with_delete(self):
# Given a user with invalid repositories
userobj = UserObject.get_user(self.USERNAME)
RepoObject.query.delete()
- RepoObject(userid=userobj.userid, repopath='invalid').add()
+ RepoObject(userid=userobj.userid, repopath='invalid').add().commit()
self.assertEqual(['invalid'], sorted([r.name for r in userobj.repo_objs]))
# When updating the repository list without deletion
userobj.refresh_repos(delete=True)
+ userobj.commit()
# Then the list invlaid the invalid repo and new repos
userobj.expire()
self.assertEqual(['broker-repo', 'testcases'], sorted([r.name for r in userobj.repo_objs]))
@@ -408,102 +428,11 @@ def test_refresh_repos_with_single_repo(self):
userobj.user_root = os.path.join(self.testcases, 'testcases')
# When updating the repository list without deletion
userobj.refresh_repos(delete=True)
+ userobj.commit()
# Then the list invlaid the invalid repo and new repos
userobj.expire()
self.assertEqual([''], sorted([r.name for r in userobj.repo_objs]))
- def test_add_access_token(self):
- # Given a user with an email
- userobj = UserObject.get_user(self.USERNAME)
- userobj.email = 'test@examples.com'
- userobj.add()
- # When adding a new token
- token = userobj.add_access_token('test')
- # Then a new token get created
- self.assertTrue(token)
- tokenobj = Token.query.filter(Token.userid == userobj.userid).first()
- self.assertTrue(tokenobj)
- self.assertEqual(None, tokenobj.expiration_time)
- self.assertEqual(None, tokenobj.access_time)
- # Then an email is sent to the user.
- self.listener.access_token_added.assert_called_once_with(userobj, 'test')
- self.listener.queue_mail.assert_called_once()
-
- def test_add_access_token_duplicate_name(self):
- # Given a user with an existing token
- userobj = UserObject.get_user(self.USERNAME)
- userobj.add_access_token('test')
- self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count())
- # When adding a new token with the same name
- with self.assertRaises(ValueError):
- userobj.add_access_token('test')
- # Then token is not created
- self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count())
- # Then an email is not sent.
- self.listener.access_token_added.assert_called_once_with(userobj, 'test')
-
- def test_delete_access_token(self):
- # Given a user with an existing token
- userobj = UserObject.get_user(self.USERNAME)
- userobj.add_access_token('test')
- self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count())
- # When deleting an access token
- userobj.delete_access_token('test')
- # Then Token get deleted
- self.assertEqual(0, Token.query.filter(Token.userid == userobj.userid).count())
-
- def test_delete_access_token_invalid(self):
- # Given a user with an existing token
- userobj = UserObject.get_user(self.USERNAME)
- userobj.add_access_token('test')
- self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count())
- # When deleting an invalid access token
- with self.assertRaises(ValueError):
- userobj.delete_access_token('invalid')
- # Then Token not deleted
- self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count())
-
- def test_delete_user_remove_access_tokens(self):
- # Given a user with an existing token
- userobj = UserObject.add_user('testuser', 'password')
- userobj.add_access_token('test')
- self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count())
- # When deleting the user
- userobj.delete()
- # Then Token get deleted
- self.assertEqual(0, Token.query.filter(Token.userid == userobj.userid).count())
-
- def test_verify_access_token(self):
- # Given a user with an existing token
- userobj = UserObject.get_user(self.USERNAME)
- token = userobj.add_access_token('test')
- self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count())
- # When validating the token
- # Then token is valid
- self.assertTrue(userobj.validate_access_token(token))
-
- def test_verify_access_token_with_expired(self):
- # Given a user with an existing token
- userobj = UserObject.get_user(self.USERNAME)
- token = userobj.add_access_token(
- 'test', expiration_time=datetime.datetime.now() - datetime.timedelta(seconds=1)
- )
- self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count())
- # When validating the token
- # Then token is invalid
- self.assertFalse(userobj.validate_access_token(token))
- # Then token get removed
- self.assertEqual(0, Token.query.filter(Token.userid == userobj.userid).count())
-
- def test_verify_access_token_with_invalid(self):
- # Given a user with an existing token
- userobj = UserObject.get_user(self.USERNAME)
- userobj.add_access_token('test', expiration_time=datetime.datetime.now())
- self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count())
- # When validating the token
- # Then token is invalid
- self.assertFalse(userobj.validate_access_token('invalid'))
-
class UserObjectWithAdminPassword(rdiffweb.test.WebCase):
@@ -523,7 +452,6 @@ def test_create_admin_user(self):
self.assertEqual('{SSHA}wbSK4hlEX7mtGJplFi2oN6ABm6Y3Bo1e', userobj.hash_password)
self.assertTrue(check_password('test', userobj.hash_password))
- def test_set_password(self):
# Given admin-password is configure
# When trying to update admin password
# Then an exception is raised
diff --git a/rdiffweb/core/tests/test_login.py b/rdiffweb/core/tests/test_login.py
index 3604b7c9..65a0762d 100644
--- a/rdiffweb/core/tests/test_login.py
+++ b/rdiffweb/core/tests/test_login.py
@@ -45,6 +45,7 @@ class LoginTest(LoginAbstractTest):
def test_login(self):
# Given a valid user in database with a password
userobj = UserObject.add_user('tom', 'password')
+ userobj.commit()
# When trying to login with valid password
login = cherrypy.engine.publish('login', 'tom', 'password')
# Then login is successful
@@ -53,7 +54,8 @@ def test_login(self):
self.listener.user_login.assert_called_once_with(userobj)
def test_login_with_invalid_password(self):
- UserObject.add_user('jeff', 'password')
+ userobj = UserObject.add_user('jeff', 'password')
+ userobj.commit()
self.assertFalse(any(cherrypy.engine.publish('login', 'jeff', 'invalid')))
# password is case sensitive
self.assertFalse(any(cherrypy.engine.publish('login', 'jeff', 'Password')))
diff --git a/rdiffweb/core/tests/test_notification.py b/rdiffweb/core/tests/test_notification.py
index 6fe44b49..3d82a503 100644
--- a/rdiffweb/core/tests/test_notification.py
+++ b/rdiffweb/core/tests/test_notification.py
@@ -53,10 +53,10 @@ def test_notification_job(self):
# Set user config
user = UserObject.get_user(self.USERNAME)
user.email = 'test@test.com'
- user.add()
+ user.commit()
repo = RepoObject.query.filter(RepoObject.user == user, RepoObject.repopath == self.REPO).first()
repo.maxage = 1
- repo.add()
+ repo.commit()
# When running notification_job
cherrypy.notification.notification_job()
@@ -71,11 +71,11 @@ def test_notification_job_undefined_last_backup_date(self):
# Given a valid user with a repository configured for notification
user = UserObject.get_user(self.USERNAME)
user.email = 'test@test.com'
- user.add()
+ user.add().commit()
# Given a repo with last_backup_date None
repo = RepoObject.query.filter(RepoObject.user == user, RepoObject.repopath == 'broker-repo').first()
repo.maxage = 1
- repo.add()
+ repo.add().commit()
self.assertIsNone(repo.last_backup_date)
# When Notification job is running
@@ -92,10 +92,10 @@ def test_notification_job_without_notification(self):
# Given a valid user with a repository configured without notification (-1)
user = UserObject.get_user(self.USERNAME)
user.email = 'test@test.com'
- user.add()
+ user.add().commit()
repo = RepoObject.query.filter(RepoObject.user == user, RepoObject.repopath == self.REPO).first()
repo.maxage = -1
- repo.add()
+ repo.add().commit()
# Call notification.
cherrypy.notification.notification_job()
@@ -123,13 +123,13 @@ def test_email_changed(self):
# Given a user with an email address
user = UserObject.get_user(self.USERNAME)
user.email = 'original_email@test.com'
- user.add()
+ user.add().commit()
self.listener.queue_email.reset_mock()
# When updating the user's email
user = UserObject.get_user(self.USERNAME)
user.email = 'email_changed@test.com'
- user.add()
+ user.add().commit()
# Then a email is queue to notify the user.
self.listener.queue_email.assert_called_once_with(
@@ -142,10 +142,12 @@ def test_email_updated_with_same_value(self):
# Given a user with an email
user = UserObject.get_user(self.USERNAME)
user.email = 'email_changed@test.com'
+ user.add().commit()
self.listener.queue_email.reset_mock()
# When updating the user's email with the same value
user.email = 'email_changed@test.com'
+ user.add().commit()
# Then no email are sent to the user
self.listener.queue_email.assert_not_called()
@@ -154,10 +156,12 @@ def test_password_change_notification(self):
# Given a user with a email.
user = UserObject.get_user(self.USERNAME)
user.email = 'password_change@test.com'
+ user.add().commit()
self.listener.queue_email.reset_mock()
# When updating the user password
user.set_password('new_password')
+ user.add().commit()
# Then a email is send to the user
self.listener.queue_email.assert_called_once_with(
@@ -171,10 +175,12 @@ def test_password_change_with_same_value(self):
user = UserObject.get_user(self.USERNAME)
user.email = 'password_change@test.com'
user.set_password('new_password')
+ user.add().commit()
self.listener.queue_email.reset_mock()
# When updating the user password with the same value
user.set_password('new_password')
+ user.add().commit()
# Then an email is sent to the user
self.listener.queue_email.assert_called_once_with(
diff --git a/rdiffweb/core/tests/test_quota.py b/rdiffweb/core/tests/test_quota.py
index b220af76..edd4556e 100644
--- a/rdiffweb/core/tests/test_quota.py
+++ b/rdiffweb/core/tests/test_quota.py
@@ -40,6 +40,7 @@ class QuotaPluginTest(test.WebCase):
def test_get_disk_usage(self):
# Given a user
userobj = UserObject.add_user('bob')
+ userobj.commit()
# When querying quota for a userobj
result = cherrypy.engine.publish('get_disk_usage', userobj)
# Then quota return a value
@@ -48,6 +49,7 @@ def test_get_disk_usage(self):
def test_get_disk_quota(self):
# Given a user
userobj = UserObject.add_user('bob')
+ userobj.commit()
# When querying quota for a userobj
result = cherrypy.engine.publish('get_disk_quota', userobj)
# Then quota return a value
@@ -56,6 +58,7 @@ def test_get_disk_quota(self):
def test_set_disk_quota(self):
# Given a used cmd
userobj = UserObject.add_user('bob')
+ userobj.commit()
# When querying quota for a userobj
results = cherrypy.engine.publish('set_disk_quota', userobj, 98765)
# Then quota return a value
@@ -73,6 +76,7 @@ class QuotaPluginTestWithFailure(test.WebCase):
def test_set_disk_quota_with_failure(self):
# Given a user object
userobj = UserObject.add_user('bob')
+ userobj.commit()
# When settings the quota
results = cherrypy.engine.publish('set_disk_quota', userobj, 98765)
# Then False is returned
@@ -108,6 +112,7 @@ def test_get_disk_usage_with_empty_user_root(self):
# Given a user with an empty user_root.
userobj = UserObject.add_user('bob')
userobj.user_root = ''
+ userobj.commit()
# When getting disk usage
results = cherrypy.engine.publish('get_disk_usage', userobj)
# Then default disk usage is return
@@ -117,6 +122,7 @@ def test_get_disk_usage_with_invalid_user_root(self):
# Given a user with an invalid user_root.
userobj = UserObject.add_user('bob')
userobj.user_root = 'invalid'
+ userobj.commit()
# When getting disk usage
results = cherrypy.engine.publish('get_disk_usage', userobj)
# Then default disk usage is return
diff --git a/rdiffweb/core/tests/test_rdw_templating.py b/rdiffweb/core/tests/test_rdw_templating.py
index c677022f..c8d8bab3 100644
--- a/rdiffweb/core/tests/test_rdw_templating.py
+++ b/rdiffweb/core/tests/test_rdw_templating.py
@@ -22,7 +22,7 @@
from rdiffweb.core.librdiff import RdiffTime
from rdiffweb.core.model import RepoObject, UserObject
from rdiffweb.core.rdw_templating import _ParentEntry, attrib, do_format_lastupdated, list_parents, url_for
-from rdiffweb.test import AppTestCase, WebCase
+from rdiffweb.test import WebCase
class TemplateManagerTest(unittest.TestCase):
@@ -91,7 +91,7 @@ def test_do_format_lastupdated(self):
self.assertEqual('4 years ago', do_format_lastupdated(RdiffTime(value=1452442324), now=1591978846))
-class ListParentsTest(AppTestCase):
+class ListParentsTest(WebCase):
def test_list_parents_with_root_dir(self):
repo, path = RepoObject.get_repo_path(b'admin/testcases', as_user=UserObject.get_user('admin'))
self.assertEqual(list_parents(repo, path), [_ParentEntry(path=b'', display_name='testcases')])
diff --git a/rdiffweb/rdw_app.py b/rdiffweb/rdw_app.py
index 34a1e677..6bc5c77d 100644
--- a/rdiffweb/rdw_app.py
+++ b/rdiffweb/rdw_app.py
@@ -232,7 +232,8 @@ def __init__(self, cfg):
os.environ["TMPDIR"] = self._tempdir
# create user manager
- UserObject.create_admin_user(cfg.admin_user, cfg.admin_password)
+ user = UserObject.create_admin_user(cfg.admin_user, cfg.admin_password)
+ user.commit()
@property
def currentuser(self):
diff --git a/rdiffweb/templates/access_token_added.html b/rdiffweb/templates/access_token_added.html
index 5142d38f..53b6bf5c 100644
--- a/rdiffweb/templates/access_token_added.html
+++ b/rdiffweb/templates/access_token_added.html
@@ -2,7 +2,7 @@
{% trans username=user.username %}Hey {{ username }},{% endtrans %}
- {% trans %}A new access token, named "{{ test }}", has been created.{% endtrans %}
+ {% trans %}A new access token, named "{{ name }}", has been created.{% endtrans %}
{% trans %}If you did not make this change and believe your account has been compromised, please contact your administrator.{% endtrans %}
diff --git a/rdiffweb/test.py b/rdiffweb/test.py
index 2cd452c2..821b6cb8 100644
--- a/rdiffweb/test.py
+++ b/rdiffweb/test.py
@@ -51,55 +51,6 @@ def create_testcases_repo(app):
return new
-class AppTestCase(unittest.TestCase):
-
- REPO = 'testcases'
-
- USERNAME = 'admin'
-
- PASSWORD = 'admin123'
-
- default_config = {}
-
- app_class = RdiffwebApp
-
- @classmethod
- def setup_class(cls):
- if cls is AppTestCase:
- raise unittest.SkipTest("%s is an abstract base class" % cls.__name__)
-
- @classmethod
- def teardown_class(cls):
- pass
-
- def setUp(self):
- # Allow defining a custom database uri for testing.
- self.database_dir = tempfile.mkdtemp(prefix='rdiffweb_tests_db_')
- uri = os.path.join(self.database_dir, 'rdiffweb.tmp.db')
- uri = os.environ.get('RDIFFWEB_TEST_DATABASE_URI', uri)
- self.default_config['database-uri'] = uri
- cfg = self.app_class.parse_args(
- args=[], config_file_contents='\n'.join('%s=%s' % (k, v) for k, v in self.default_config.items())
- )
- # Create Application
- self.app = self.app_class(cfg)
- # Create repositories
- self.testcases = create_testcases_repo(self.app)
- # Register repository
- admin_user = UserObject.get_user(self.USERNAME)
- if admin_user:
- admin_user.user_root = self.testcases
- admin_user.refresh_repos()
-
- def tearDown(self):
- if hasattr(self, 'database_dir'):
- shutil.rmtree(self.database_dir)
- delattr(self, 'database_dir')
- if hasattr(self, 'testcases'):
- shutil.rmtree(self.testcases)
- delattr(self, 'testcases')
-
-
class WebCase(helper.CPWebCase):
"""
Helper class for the rdiffweb test suite.
@@ -152,13 +103,14 @@ def setUp(self):
cherrypy.tools.db.drop_all()
cherrypy.tools.db.create_all()
# Create default admin
- UserObject.create_admin_user(self.USERNAME, self.PASSWORD)
+ admin_user = UserObject.create_admin_user(self.USERNAME, self.PASSWORD)
+ admin_user.commit()
# Create testcases repo
self.testcases = create_testcases_repo(self.app)
- admin_user = UserObject.get_user(self.USERNAME)
if admin_user:
admin_user.user_root = self.testcases
admin_user.refresh_repos()
+ admin_user.commit()
# Login to web application.
if self.login:
self._login()
@@ -167,6 +119,7 @@ def tearDown(self):
if hasattr(self, 'testcases'):
shutil.rmtree(self.testcases)
delattr(self, 'testcases')
+ cherrypy.tools.db.drop_all()
@property
def app(self):
diff --git a/rdiffweb/tools/db.py b/rdiffweb/tools/db.py
index 0aee0d07..76350b90 100644
--- a/rdiffweb/tools/db.py
+++ b/rdiffweb/tools/db.py
@@ -20,8 +20,9 @@
import logging
import cherrypy
-from sqlalchemy import create_engine, event
+from sqlalchemy import create_engine, event, inspect
from sqlalchemy.engine import Engine
+from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
@@ -39,6 +40,31 @@ def _set_sqlite_journal_mode_wal(connection, connection_record):
cursor.close()
+def _get_model_changes(model):
+ """
+ Return a dictionary containing changes made to the model since it was
+ fetched from the database.
+
+ The dictionary is of the form {'property_name': [old_value, new_value]}
+ """
+ state = inspect(model)
+ changes = {}
+ for attr in state.attrs:
+ hist = attr.history
+ if not hist.has_changes():
+ continue
+ if isinstance(attr.value, (list, tuple)) or len(hist.deleted) > 1 or len(hist.added) > 1:
+ # If array, store array
+ changes[attr.key] = [hist.deleted, hist.added]
+ else:
+ # If primitive, store primitive
+ changes[attr.key] = [
+ hist.deleted[0] if len(hist.deleted) >= 1 else None,
+ hist.added[0] if len(hist.added) >= 1 else None,
+ ]
+ return changes
+
+
class Base:
'''
Extends declarative base to provide convenience methods to models similar to
@@ -53,35 +79,32 @@ class Base:
changed = User.from_dict({}) # update record based on dict argument passed in and returns any keys changed
'''
- def add(self, commit=True):
+ def add(self):
"""
Add current object to session.
"""
self.__class__.session.add(self)
- if commit:
- self.__class__.session.commit()
return self
- def delete(self, commit=True):
- """
- Delete current object to session.
- """
+ def delete(self):
self.__class__.session.delete(self)
- if commit:
- self.__class__.session.commit()
return self
- def merge(self, commit=True):
- """
- Merge current object to session.
- """
- self.__class__.session.merge(self)
- if commit:
- self.__class__.session.commit()
+ def commit(self):
+ self.__class__.session.commit()
+ return self
+
+ def flush(self):
+ self.__class__.session.flush()
return self
def expire(self):
self.__class__.session.expire(self)
+ return self
+
+ def rollback(self):
+ self.__class__.session.rollback()
+ return self
class BaseExtensions(DeclarativeMeta):
@@ -117,6 +140,7 @@ def create_all(self):
if debug:
logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG)
base.metadata.create_all()
+ self.get_session().commit()
def drop_all(self):
# Release opened sessions.
@@ -124,6 +148,7 @@ def drop_all(self):
# Drop all
base = self.get_base()
base.metadata.drop_all()
+ self.get_session().commit()
def get_base(self):
if self._base is None:
@@ -140,13 +165,17 @@ def on_end_resource(self):
if self._session is None:
return
try:
- self._session.flush()
- self._session.commit()
- except Exception:
- logger.exception('error trying to flush and commit session')
+ # When terminating, raise an error if objects are not commit.
+ if self._session.dirty or self._session.new or self._session.deleted:
+ changes = ', '.join([str(_get_model_changes(obj)) for obj in self._session.dirty])
+ logger.exception(
+ 'session is dirty, some database object(s) are not commited, this indicate a bug in the application '
+ 'dirty %s new %s deleted %s' % (changes, self._session.new, self._session.deleted)
+ )
+ raise SQLAlchemyError('session is dirty')
+ finally:
self._session.rollback()
self._session.expunge_all()
- finally:
self._session.remove()