diff --git a/README.md b/README.md index ec5b4b1b..fc6df341 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,7 @@ This next release focus on two-factor-authentication as a measure to increase se * Send email notification when enabling or disabling MFA [CVE-2022-3363](https://nvd.nist.gov/vuln/detail/CVE-2022-3363) * Use Argon2id to store password hash #231 * Fixed plugin priorities to ensure that jobs are scheduled at each startup #232 +* Revoke previous user's sessions on password change [CVE-2022-3362](https://nvd.nist.gov/vuln/detail/CVE-2022-3362) Breaking changes: diff --git a/rdiffweb/controller/api.py b/rdiffweb/controller/api.py index 9b1e2bf6..87dc2bca 100644 --- a/rdiffweb/controller/api.py +++ b/rdiffweb/controller/api.py @@ -60,7 +60,10 @@ def _checkpassword(realm, username, password): userobj = UserObject.get_user(username) if userobj is not None: # Verify if the password matches a token. - if userobj.validate_access_token(password): + access_token = userobj.validate_access_token(password) + if access_token: + access_token.accessed() + access_token.commit() return True # Disable password authentication for MFA if userobj.mfa == UserObject.ENABLED_MFA: @@ -78,7 +81,8 @@ class ApiCurrentUser(Controller): @cherrypy.expose def default(self): u = self.app.currentuser - u.refresh_repos() + if u.refresh_repos(): + u.commit() return { "email": u.email, "username": u.username, diff --git a/rdiffweb/controller/page_admin_session.py b/rdiffweb/controller/page_admin_session.py index 065167dd..a1747f44 100644 --- a/rdiffweb/controller/page_admin_session.py +++ b/rdiffweb/controller/page_admin_session.py @@ -46,6 +46,7 @@ def default(self, action=None, **kwargs): flash(_('You cannot revoke your current session.'), level='warning') else: session.delete() + session.commit() flash(_('The session was successfully revoked.'), level='success') else: flash(form.error_message, level='error') diff --git a/rdiffweb/controller/page_admin_users.py b/rdiffweb/controller/page_admin_users.py index 668126f8..810ce423 100644 --- a/rdiffweb/controller/page_admin_users.py +++ b/rdiffweb/controller/page_admin_users.py @@ -182,6 +182,7 @@ def populate_obj(self, userobj): # Setting quota will silently fail. Check if quota was updated. if userobj.disk_quota != new_quota: flash(_("Setting user's quota is not supported"), level='warning') + userobj.commit() class EditUserForm(UserForm): @@ -214,6 +215,7 @@ def _delete_user(self, action, form): user = UserObject.get_user(form.username.data) if user: user.delete() + user.commit() flash(_("User account removed.")) else: flash(_("User doesn't exists!"), level='warning') diff --git a/rdiffweb/controller/page_delete.py b/rdiffweb/controller/page_delete.py index 3a79cca3..677a6c1b 100644 --- a/rdiffweb/controller/page_delete.py +++ b/rdiffweb/controller/page_delete.py @@ -40,6 +40,11 @@ _logger = logging.getLogger(__name__) +def delete_repo(repoobj, path): + repoobj.delete(path) + repoobj.commit() + + class DeleteRepoForm(CherryForm): confirm = StringField(_('Confirmation'), validators=[DataRequired()]) @@ -72,7 +77,7 @@ def default(self, path=b"", **kwargs): if form.is_submitted(): if form.validate(): RepoObject.session.expunge(repo) - cherrypy.engine.publish('schedule_task', repo.delete, path) + cherrypy.engine.publish('schedule_task', delete_repo, repo, path) # Redirect to parent folder or to root if repo get deleted if path_obj.isroot: raise cherrypy.HTTPRedirect(url_for('/')) diff --git a/rdiffweb/controller/page_locations.py b/rdiffweb/controller/page_locations.py index 74570485..748b7aa5 100644 --- a/rdiffweb/controller/page_locations.py +++ b/rdiffweb/controller/page_locations.py @@ -34,7 +34,8 @@ class LocationsPage(Controller): @cherrypy.expose def index(self): # Get page params - self.app.currentuser.refresh_repos() + if self.app.currentuser.refresh_repos(): + self.app.currentuser.commit() params = { "repos": self.app.currentuser.repo_objs, "disk_usage": self.app.currentuser.disk_usage, diff --git a/rdiffweb/controller/page_pref_general.py b/rdiffweb/controller/page_pref_general.py index eb46eda6..9d51f34b 100644 --- a/rdiffweb/controller/page_pref_general.py +++ b/rdiffweb/controller/page_pref_general.py @@ -58,7 +58,7 @@ def is_submitted(self): def populate_obj(self, user): user.fullname = self.fullname.data user.email = self.email.data - user.add() + user.commit() class UserPasswordForm(CherryForm): @@ -99,6 +99,7 @@ def populate_obj(self, user): return False try: user.set_password(self.new.data) + user.commit() return True except ValueError as e: self.new.errors = [str(e)] @@ -120,7 +121,8 @@ def is_submitted(self): def populate_obj(self, user): try: - user.refresh_repos(delete=True) + if user.refresh_repos(delete=True): + user.commit() flash(_("Repositories successfully updated"), level='success') except ValueError as e: flash(str(e), level='warning') diff --git a/rdiffweb/controller/page_pref_mfa.py b/rdiffweb/controller/page_pref_mfa.py index 918bf8d1..7b760268 100644 --- a/rdiffweb/controller/page_pref_mfa.py +++ b/rdiffweb/controller/page_pref_mfa.py @@ -78,9 +78,11 @@ def populate_obj(self, userobj): # Enable or disable MFA only when a code is provided. if self.enable_mfa.data: userobj.mfa = UserObject.ENABLED_MFA + userobj.commit() flash(_("Two-Factor authentication enabled successfully."), level='success') elif self.disable_mfa.data: userobj.mfa = UserObject.DISABLED_MFA + userobj.commit() flash(_("Two-Factor authentication disabled successfully."), level='success') def validate_code(self, field): diff --git a/rdiffweb/controller/page_pref_notification.py b/rdiffweb/controller/page_pref_notification.py index 1e2d8971..5ae485b4 100644 --- a/rdiffweb/controller/page_pref_notification.py +++ b/rdiffweb/controller/page_pref_notification.py @@ -76,6 +76,7 @@ def populate_obj(self, userobj): if repo.display_name in self: # Update the maxage repo.maxage = self[repo.display_name].data + userobj.commit() class PagePrefNotification(Controller): diff --git a/rdiffweb/controller/page_pref_session.py b/rdiffweb/controller/page_pref_session.py index 72094e26..95e27f24 100644 --- a/rdiffweb/controller/page_pref_session.py +++ b/rdiffweb/controller/page_pref_session.py @@ -47,6 +47,7 @@ def default(self, action=None, **kwargs): flash(_('You cannot revoke your current session.'), level='warning') else: session.delete() + session.commit() flash(_('The session was successfully revoked.'), level='success') else: flash(form.error_message, level='error') diff --git a/rdiffweb/controller/page_pref_sshkeys.py b/rdiffweb/controller/page_pref_sshkeys.py index ee5dc95c..26e704fd 100644 --- a/rdiffweb/controller/page_pref_sshkeys.py +++ b/rdiffweb/controller/page_pref_sshkeys.py @@ -72,7 +72,9 @@ class SshForm(CherryForm): def populate_obj(self, userobj): try: userobj.add_authorizedkey(key=self.key.data, comment=self.title.data) + userobj.commit() except DuplicateSSHKeyError as e: + userobj.rollback() flash(str(e), level='error') except Exception: flash(_("Unknown error while adding the SSH Key"), level='error') @@ -86,6 +88,7 @@ def populate_obj(self, userobj): is_maintainer() try: userobj.delete_authorizedkey(self.fingerprint.data) + userobj.commit() except Exception: flash(_("Unknown error while removing the SSH Key"), level='error') _logger.warning("error removing ssh key", exc_info=1) diff --git a/rdiffweb/controller/page_pref_tokens.py b/rdiffweb/controller/page_pref_tokens.py index f961587e..74365b85 100644 --- a/rdiffweb/controller/page_pref_tokens.py +++ b/rdiffweb/controller/page_pref_tokens.py @@ -62,6 +62,7 @@ def is_submitted(self): def populate_obj(self, userobj): try: token = userobj.add_access_token(self.name.data, self.expiration.data) + userobj.commit() flash( _( "Your new personal access token has been created.\n" diff --git a/rdiffweb/controller/tests/test_api.py b/rdiffweb/controller/tests/test_api.py index d7b0a846..d2223689 100644 --- a/rdiffweb/controller/tests/test_api.py +++ b/rdiffweb/controller/tests/test_api.py @@ -101,6 +101,7 @@ def test_auth_with_access_token(self): # Given a user with an access token userobj = UserObject.get_user(self.USERNAME) token = userobj.add_access_token('test').encode('ascii') + userobj.commit() # When using this token to authenticated with /api self.getPage('/api/', headers=[("Authorization", "Basic " + b64encode(b"admin:" + token).decode('ascii'))]) # Then authentication is successful @@ -110,7 +111,7 @@ def test_auth_failed_with_mfa_enabled(self): # Given a user with MFA enabled userobj = UserObject.get_user(self.USERNAME) userobj.mfa = UserObject.ENABLED_MFA - userobj.add() + userobj.commit() # When authenticating with /api/ self.getPage('/api/', headers=self.headers) # Then access is refused diff --git a/rdiffweb/controller/tests/test_controller.py b/rdiffweb/controller/tests/test_controller.py index 5d477ead..309f76c6 100644 --- a/rdiffweb/controller/tests/test_controller.py +++ b/rdiffweb/controller/tests/test_controller.py @@ -178,7 +178,7 @@ def test_clean_up_session(self): # When this session get old data = SessionObject.query.filter(SessionObject.id == self.session_id).first() data.expiration_time = datetime.datetime.now() - datetime.timedelta(seconds=1) - data.add() + data.commit() session = DbSession(id=self.session_id) # Then the session get deleted by clean_up process session.clean_up() diff --git a/rdiffweb/controller/tests/test_page_admin.py b/rdiffweb/controller/tests/test_page_admin.py index bb01562f..7fbac040 100644 --- a/rdiffweb/controller/tests/test_page_admin.py +++ b/rdiffweb/controller/tests/test_page_admin.py @@ -23,7 +23,8 @@ class AdminPagesAsUser(rdiffweb.test.WebCase): def setUp(self): super().setUp() # Add test user - UserObject.add_user('test', 'test123') + userobj = UserObject.add_user('test', 'test123') + userobj.commit() self._login('test', 'test123') @parameterized.expand( diff --git a/rdiffweb/controller/tests/test_page_admin_users.py b/rdiffweb/controller/tests/test_page_admin_users.py index c32a88f4..b437d67c 100644 --- a/rdiffweb/controller/tests/test_page_admin_users.py +++ b/rdiffweb/controller/tests/test_page_admin_users.py @@ -337,7 +337,8 @@ def test_delete_user_admin(self): def test_delete_user_method_get(self): # Given a user - UserObject.add_user('newuser') + user = UserObject.add_user('newuser') + user.commit() # When trying to delete this user using method GET self.getPage("/admin/users/?action=delete&username=newuser", method='GET') # Then page return without error @@ -370,7 +371,8 @@ def test_edit_user_with_invalid_path(self): """ Verify failure trying to update user with invalid path. """ - UserObject.add_user('test1') + userobj = UserObject.add_user('test1') + userobj.commit() self._edit_user("test1", "test1@test.com", "pr3j5Dwi", "/var/invalid/", UserObject.USER_ROLE) self.assertNotInBody("User added successfully.") self.assertInBody("User's root directory /var/invalid/ is not accessible!") @@ -397,18 +399,18 @@ def test_user_invalid_root(self): # Delete all user's for user in UserObject.query.all(): if user.username != self.USERNAME: - user.delete() + user.delete().commit() # Change the user's root user = UserObject.get_user('admin') user.user_root = "/invalid" - user.add() + user.commit() self.getPage("/admin/users") self.assertInBody("Root directory not accessible!") # Query the page by default user = UserObject.get_user('admin') user.user_root = "/tmp/" - user.add() + user.commit() self.getPage("/admin/users") self.assertNotInBody("Root directory not accessible!") diff --git a/rdiffweb/controller/tests/test_page_browse.py b/rdiffweb/controller/tests/test_page_browse.py index 420256c9..dc41bc4b 100644 --- a/rdiffweb/controller/tests/test_page_browse.py +++ b/rdiffweb/controller/tests/test_page_browse.py @@ -49,8 +49,8 @@ def test_locations(self): def test_locations_with_broken_tree(self): userobj = UserObject.get_user(self.USERNAME) - RepoObject(userid=userobj.userid, repopath='testcases/broker-repo').add() - RepoObject(userid=userobj.userid, repopath='testcases/testcases').add() + RepoObject(userid=userobj.userid, repopath='testcases/broker-repo').add().commit() + RepoObject(userid=userobj.userid, repopath='testcases/testcases').add().commit() self.getPage("/") def test_WithRelativePath(self): @@ -229,7 +229,7 @@ def test_with_single_repo(self): user = UserObject.get_user(self.USERNAME) user.user_root = os.path.join(self.testcases, 'testcases') user.refresh_repos() - user.add() + user.commit() self.assertEqual(['', 'broker-repo', 'testcases'], [r.name for r in user.repo_objs]) # Check if listing locations is working self.getPage('/') @@ -249,7 +249,7 @@ def test_browse_with_permissions(self): user_obj = UserObject.add_user('anotheruser', 'password') user_obj.user_root = self.testcases user_obj.refresh_repos() - user_obj.add() + user_obj.commit() self.getPage('/browse/admin') self.assertStatus('404 Not Found') @@ -265,8 +265,8 @@ def test_browse_without_permissions(self): # Remove admin role. admin = UserObject.get_user('admin') admin.role = UserObject.USER_ROLE - admin.add() admin.refresh_repos() + admin.commit() # Browse other user's repos self.getPage('/browse/anotheruser/testcases') @@ -278,7 +278,7 @@ def test_browser_with_failed_repo(self): # Given a failed repo admin = UserObject.get_user('admin') admin.user_root = 'invalid' - admin.add() + admin.commit() # When querying the logs self._browse(self.USERNAME, self.REPO, '') # Then the page is return with an error message diff --git a/rdiffweb/controller/tests/test_page_delete.py b/rdiffweb/controller/tests/test_page_delete.py index 2ed89eff..8dc680ce 100644 --- a/rdiffweb/controller/tests/test_page_delete.py +++ b/rdiffweb/controller/tests/test_page_delete.py @@ -158,6 +158,7 @@ def test_delete_repo_as_admin(self): user_obj = UserObject.add_user('anotheruser', 'password') user_obj.user_root = self.testcases user_obj.refresh_repos() + user_obj.commit() self.assertEqual(['broker-repo', 'testcases'], [r.name for r in user_obj.repo_objs]) self._delete('anotheruser', 'testcases', 'testcases') @@ -178,6 +179,7 @@ def test_delete_repo_as_maintainer(self): user_obj.user_root = self.testcases user_obj.role = UserObject.MAINTAINER_ROLE user_obj.refresh_repos() + user_obj.commit() self.assertEqual(['broker-repo', 'testcases'], [r.name for r in user_obj.repo_objs]) # Login as maintainer @@ -200,6 +202,7 @@ def test_delete_repo_as_user(self): user_obj.user_root = self.testcases user_obj.role = UserObject.USER_ROLE user_obj.refresh_repos() + user_obj.commit() self.assertEqual(['broker-repo', 'testcases'], [r.name for r in user_obj.repo_objs]) # Login as maintainer diff --git a/rdiffweb/controller/tests/test_page_graphs.py b/rdiffweb/controller/tests/test_page_graphs.py index a8142928..7eae458e 100644 --- a/rdiffweb/controller/tests/test_page_graphs.py +++ b/rdiffweb/controller/tests/test_page_graphs.py @@ -54,6 +54,7 @@ def test_as_another_user(self): user_obj = UserObject.add_user('anotheruser', 'password') user_obj.user_root = self.testcases user_obj.refresh_repos() + user_obj.commit() self.getPage("/graphs/activities/anotheruser/testcases") self.assertStatus('200 OK') @@ -62,7 +63,7 @@ def test_as_another_user(self): # Remove admin role admin = UserObject.get_user('admin') admin.role = UserObject.USER_ROLE - admin.add() + admin.commit() # Browse admin's repos self.getPage("/graphs/activities/anotheruser/testcases") @@ -90,7 +91,7 @@ def test_browser_with_failed_repo(self): # Given a failed repo admin = UserObject.get_user('admin') admin.user_root = 'invalid' - admin.add() + admin.commit() # When querying the logs self.getPage("/graphs/activities/" + self.USERNAME + "/" + self.REPO + "/") # Then the page is return with an error message diff --git a/rdiffweb/controller/tests/test_page_history.py b/rdiffweb/controller/tests/test_page_history.py index 51e9b074..653ceb4d 100644 --- a/rdiffweb/controller/tests/test_page_history.py +++ b/rdiffweb/controller/tests/test_page_history.py @@ -75,13 +75,14 @@ def test_as_another_user(self): user_obj = UserObject.add_user('anotheruser', 'password') user_obj.user_root = self.testcases user_obj.refresh_repos() + user_obj.commit() self.getPage("/history/anotheruser/testcases") self.assertStatus('200 OK') # Remove admin right admin = UserObject.get_user('admin') admin.role = UserObject.USER_ROLE - admin.add() + admin.commit() # Browse admin's repos self.getPage("/history/anotheruser/testcases") @@ -99,7 +100,7 @@ def test_browser_with_failed_repo(self): # Given a failed repo admin = UserObject.get_user('admin') admin.user_root = 'invalid' - admin.add() + admin.commit() # When querying the logs self.getPage("/history/" + self.USERNAME + "/" + self.REPO) # Then the page is return with an error message diff --git a/rdiffweb/controller/tests/test_page_login.py b/rdiffweb/controller/tests/test_page_login.py index 1f096b42..5aa89fc0 100644 --- a/rdiffweb/controller/tests/test_page_login.py +++ b/rdiffweb/controller/tests/test_page_login.py @@ -170,7 +170,8 @@ def test_login_twice(self): self.assertStatus(200) self.assertInBody(self.USERNAME) # Given another user - UserObject.add_user('otheruser', password='password') + userobj = UserObject.add_user('otheruser', password='password') + userobj.commit() # When trying to re-authenticated with login page self.getPage('/login/', method='POST', body={'login': 'otheruser', 'password': 'password'}) # Then user is still authenticated with previous user diff --git a/rdiffweb/controller/tests/test_page_logs.py b/rdiffweb/controller/tests/test_page_logs.py index bb843944..54663700 100644 --- a/rdiffweb/controller/tests/test_page_logs.py +++ b/rdiffweb/controller/tests/test_page_logs.py @@ -100,7 +100,7 @@ def test_browser_with_failed_repo(self): # Given a failed repo admin = UserObject.get_user('admin') admin.user_root = 'invalid' - admin.add() + admin.commit() # When querying the logs self._log(self.USERNAME, self.REPO) # Then the page is return with an error message diff --git a/rdiffweb/controller/tests/test_page_mfa.py b/rdiffweb/controller/tests/test_page_mfa.py index 091c5294..a84c8ab1 100644 --- a/rdiffweb/controller/tests/test_page_mfa.py +++ b/rdiffweb/controller/tests/test_page_mfa.py @@ -48,7 +48,7 @@ def setUp(self): userobj = UserObject.get_user(self.USERNAME) userobj.mfa = UserObject.ENABLED_MFA userobj.email = 'admin@example.com' - userobj.add() + userobj.commit() def test_get_without_login(self): # Given an unauthenticated user @@ -64,7 +64,7 @@ def test_get_with_mfa_disabled(self): # Given an authenticated user with MFA Disable userobj = UserObject.get_user(self.USERNAME) userobj.mfa = UserObject.DISABLED_MFA - userobj.add() + userobj.commit() self.getPage("/") self.assertStatus(200) # When requesting /mfa/ page @@ -77,7 +77,7 @@ def test_get_with_user_without_email(self): # Given an authenticated user without email. userobj = UserObject.get_user(self.USERNAME) userobj.email = '' - userobj.add() + userobj.commit() # When requesting /mfa/ page self.getPage("/mfa/") # Then user is redirected to root page @@ -282,7 +282,7 @@ def setUp(self): userobj = UserObject.get_user(self.USERNAME) userobj.mfa = UserObject.ENABLED_MFA userobj.email = 'admin@example.com' - userobj.add() + userobj.commit() def test_getpage_default(self): # Given a user with MFA enabled diff --git a/rdiffweb/controller/tests/test_page_prefs_general.py b/rdiffweb/controller/tests/test_page_prefs_general.py index 940f6050..e7bba666 100644 --- a/rdiffweb/controller/tests/test_page_prefs_general.py +++ b/rdiffweb/controller/tests/test_page_prefs_general.py @@ -26,7 +26,7 @@ from parameterized import parameterized import rdiffweb.test -from rdiffweb.core.model import RepoObject, UserObject +from rdiffweb.core.model import RepoObject, SessionObject, UserObject class PagePrefGeneralTest(rdiffweb.test.WebCase): @@ -175,8 +175,14 @@ def test_change_email_with_too_long(self): self.assertInBody("Email too long.") def test_change_password(self): - self.listener.user_password_changed.reset_mock() + # Given a user with 3 active sessions + self.cookies = None + self._login(self.USERNAME, self.PASSWORD) + self.cookies = None + self._login(self.USERNAME, self.PASSWORD) + self.assertEqual(3, SessionObject.query.count()) # When udating user's password + self.listener.user_password_changed.reset_mock() self._set_password(self.PASSWORD, "pr3j5Dwi", "pr3j5Dwi") # Then user is redirect to same page self.assertStatus(303) @@ -185,6 +191,9 @@ def test_change_password(self): self.assertInBody("Password updated successfully.") # Then a notification is raised self.listener.user_password_changed.assert_called_once() + # Then all users session get deleted except our own session. + self.assertEqual(1, SessionObject.query.count()) + self.assertEqual(self.session_id, SessionObject.query.first().id) def test_change_password_with_wrong_confirmation(self): self._set_password(self.PASSWORD, "t", "a") @@ -231,7 +240,7 @@ def test_invalid_pref(self): def test_update_repos(self): # Given a user with invalid repositories userobj = UserObject.get_user(self.USERNAME) - RepoObject(userid=userobj.userid, repopath='invalid').add() + RepoObject(userid=userobj.userid, repopath='invalid').add().commit() self.assertEqual(['broker-repo', 'invalid', 'testcases'], sorted([r.name for r in userobj.repo_objs])) # When updating the repository list self.getPage(self.PREFS, method='POST', body={'action': 'update_repos'}) diff --git a/rdiffweb/controller/tests/test_page_prefs_mfa.py b/rdiffweb/controller/tests/test_page_prefs_mfa.py index 97e8e532..0204cb77 100644 --- a/rdiffweb/controller/tests/test_page_prefs_mfa.py +++ b/rdiffweb/controller/tests/test_page_prefs_mfa.py @@ -33,7 +33,7 @@ def setUp(self): # Define email for all test userobj = UserObject.get_user(self.USERNAME) userobj.email = 'admin@example.com' - userobj.add() + userobj.commit() # Register a listener on email self.listener = MagicMock() cherrypy.engine.subscribe('queue_mail', self.listener.queue_email, priority=50) @@ -46,7 +46,7 @@ def _set_mfa(self, mfa): # Define mfa for user userobj = UserObject.get_user(self.USERNAME) userobj.mfa = mfa - userobj.add() + userobj.commit() # Reset mock. self.listener.queue_email.reset_mock() # Leave to disable mfa @@ -142,7 +142,7 @@ def test_without_email(self, action, initial_mfa): # Given a user without email requesting a code userobj = UserObject.get_user(self.USERNAME) userobj.email = '' - userobj.add() + userobj.commit() # When trying to enable or disable mfa self.getPage("/prefs/mfa", method='POST', body={action: '1'}) # Then an error is return to the user diff --git a/rdiffweb/controller/tests/test_page_prefs_ssh.py b/rdiffweb/controller/tests/test_page_prefs_ssh.py index e2406b77..005bea4f 100644 --- a/rdiffweb/controller/tests/test_page_prefs_ssh.py +++ b/rdiffweb/controller/tests/test_page_prefs_ssh.py @@ -69,6 +69,7 @@ def test_add_duplicate(self): user = UserObject.get_user('admin') for key in user.authorizedkeys: user.delete_authorizedkey(key.fingerprint) + user.commit() self.assertEqual(0, len(list(user.authorizedkeys))) # Add a new key @@ -206,6 +207,7 @@ def test_delete(self): key="ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDSEN5VTn9MLituZvdYTZMbZEaMxe0UuU7BelxHkvxzSpVWtazrIBEc3KZjtVoK9F3+0kd26P4DzSQuPUl3yZDgyZZeXrF6p2GlEA7A3tPuOEsAQ9c0oTiDYktq5/Go8vD+XAZKLd//qmCWW1Jg4datkWchMKJzbHUgBrBH015FDbGvGDWYTfVyb8I9H+LQ0GmbTHsuTu63DhPODncMtWPuS9be/flb4EEojMIx5Vce0SNO9Eih38W7jTvNWxZb75k5yfPJxBULRnS5v/fPnDVVtD3JSGybSwKoMdsMX5iImAeNhqnvd8gBu1f0IycUQexTbJXk1rPiRcF13SjKrfXz ikus060@ikus060-t530", comment="test@mysshkey", ) + user.commit() self.assertEqual(1, len(list(user.authorizedkeys))) # When deleting the ssh key self.getPage( @@ -228,6 +230,7 @@ def test_get(self): key="ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDSEN5VTn9MLituZvdYTZMbZEaMxe0UuU7BelxHkvxzSpVWtazrIBEc3KZjtVoK9F3+0kd26P4DzSQuPUl3yZDgyZZeXrF6p2GlEA7A3tPuOEsAQ9c0oTiDYktq5/Go8vD+XAZKLd//qmCWW1Jg4datkWchMKJzbHUgBrBH015FDbGvGDWYTfVyb8I9H+LQ0GmbTHsuTu63DhPODncMtWPuS9be/flb4EEojMIx5Vce0SNO9Eih38W7jTvNWxZb75k5yfPJxBULRnS5v/fPnDVVtD3JSGybSwKoMdsMX5iImAeNhqnvd8gBu1f0IycUQexTbJXk1rPiRcF13SjKrfXz ikus060@ikus060-t530", comment="test@mysshkey", ) + user.commit() self.assertEqual(1, len(list(user.authorizedkeys))) # When deleting the ssh key data = self.getJson( @@ -250,6 +253,7 @@ def test_get_invalid(self): key="ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDSEN5VTn9MLituZvdYTZMbZEaMxe0UuU7BelxHkvxzSpVWtazrIBEc3KZjtVoK9F3+0kd26P4DzSQuPUl3yZDgyZZeXrF6p2GlEA7A3tPuOEsAQ9c0oTiDYktq5/Go8vD+XAZKLd//qmCWW1Jg4datkWchMKJzbHUgBrBH015FDbGvGDWYTfVyb8I9H+LQ0GmbTHsuTu63DhPODncMtWPuS9be/flb4EEojMIx5Vce0SNO9Eih38W7jTvNWxZb75k5yfPJxBULRnS5v/fPnDVVtD3JSGybSwKoMdsMX5iImAeNhqnvd8gBu1f0IycUQexTbJXk1rPiRcF13SjKrfXz ikus060@ikus060-t530", comment="test@mysshkey", ) + user.commit() self.assertEqual(1, len(list(user.authorizedkeys))) # When deleting the ssh key self.getPage( @@ -269,6 +273,7 @@ def test_list(self): key="ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDSEN5VTn9MLituZvdYTZMbZEaMxe0UuU7BelxHkvxzSpVWtazrIBEc3KZjtVoK9F3+0kd26P4DzSQuPUl3yZDgyZZeXrF6p2GlEA7A3tPuOEsAQ9c0oTiDYktq5/Go8vD+XAZKLd//qmCWW1Jg4datkWchMKJzbHUgBrBH015FDbGvGDWYTfVyb8I9H+LQ0GmbTHsuTu63DhPODncMtWPuS9be/flb4EEojMIx5Vce0SNO9Eih38W7jTvNWxZb75k5yfPJxBULRnS5v/fPnDVVtD3JSGybSwKoMdsMX5iImAeNhqnvd8gBu1f0IycUQexTbJXk1rPiRcF13SjKrfXz ikus060@ikus060-t530", comment="test@mysshkey", ) + user.commit() self.assertEqual(1, len(list(user.authorizedkeys))) # When deleting the ssh key data = self.getJson( diff --git a/rdiffweb/controller/tests/test_page_prefs_tokens.py b/rdiffweb/controller/tests/test_page_prefs_tokens.py index 5f7e24b1..88fde349 100644 --- a/rdiffweb/controller/tests/test_page_prefs_tokens.py +++ b/rdiffweb/controller/tests/test_page_prefs_tokens.py @@ -95,6 +95,7 @@ def test_delete_access_token(self): # Given an existing user with access_token userobj = UserObject.get_user(self.USERNAME) userobj.add_access_token('test-token-name') + userobj.commit() # When deleting access token self.getPage( "/prefs/tokens", diff --git a/rdiffweb/controller/tests/test_page_restore.py b/rdiffweb/controller/tests/test_page_restore.py index f5002239..54362afa 100644 --- a/rdiffweb/controller/tests/test_page_restore.py +++ b/rdiffweb/controller/tests/test_page_restore.py @@ -347,6 +347,7 @@ def test_as_another_user(self): user_obj = UserObject.add_user('anotheruser', 'password') user_obj.user_root = self.testcases user_obj.refresh_repos() + user_obj.commit() self._restore("anotheruser", "testcases", "Fichier%20%40%20%3Croot%3E/", "1414921853") self.assertStatus('200 OK') self.assertInBody("Ajout d'info") @@ -354,7 +355,7 @@ def test_as_another_user(self): # Remove admin right admin = UserObject.get_user('admin') admin.role = UserObject.USER_ROLE - admin.add() + admin.commit() # Browse admin's repos self._restore("anotheruser", "testcases", "Fichier%20%40%20%3Croot%3E/", "1414921853") diff --git a/rdiffweb/controller/tests/test_page_settings.py b/rdiffweb/controller/tests/test_page_settings.py index a314ea3a..efca5e22 100644 --- a/rdiffweb/controller/tests/test_page_settings.py +++ b/rdiffweb/controller/tests/test_page_settings.py @@ -40,6 +40,7 @@ def test_as_another_user(self): user_obj = UserObject.add_user('anotheruser', 'password') user_obj.user_root = self.testcases user_obj.refresh_repos() + user_obj.commit() self.getPage("/settings/anotheruser/testcases") self.assertInBody("Character encoding") self.assertStatus('200 OK') @@ -47,7 +48,7 @@ def test_as_another_user(self): # Remove admin right admin = UserObject.get_user('admin') admin.role = UserObject.USER_ROLE - admin.add() + admin.commit() # Browse admin's repos self.getPage("/settings/anotheruser/testcases") @@ -81,7 +82,7 @@ def test_browser_with_failed_repo(self): # Given a failed repo admin = UserObject.get_user('admin') admin.user_root = '/invalid/' - admin.add() + admin.commit() # When querying the logs self.getPage("/settings/" + self.USERNAME + "/" + self.REPO) # Then the page is return with an error message diff --git a/rdiffweb/controller/tests/test_page_settings_remove_older.py b/rdiffweb/controller/tests/test_page_settings_remove_older.py index aa7447e5..b263b79e 100644 --- a/rdiffweb/controller/tests/test_page_settings_remove_older.py +++ b/rdiffweb/controller/tests/test_page_settings_remove_older.py @@ -52,6 +52,7 @@ def test_as_another_user(self): user_obj = UserObject.add_user('anotheruser', 'password') user_obj.user_root = self.testcases user_obj.refresh_repos() + user_obj.commit() self._remove_older('anotheruser', 'testcases', '1') self.assertStatus('200 OK') repo = RepoObject.query.filter(RepoObject.user == user_obj, RepoObject.repopath == self.REPO).first() @@ -60,7 +61,7 @@ def test_as_another_user(self): # Remove admin right admin = UserObject.get_user('admin') admin.role = UserObject.USER_ROLE - admin.add() + admin.commit() # Browse admin's repos self._remove_older('anotheruser', 'testcases', '2') diff --git a/rdiffweb/controller/tests/test_page_settings_set_encoding.py b/rdiffweb/controller/tests/test_page_settings_set_encoding.py index 850c3cb4..0b98af23 100644 --- a/rdiffweb/controller/tests/test_page_settings_set_encoding.py +++ b/rdiffweb/controller/tests/test_page_settings_set_encoding.py @@ -94,6 +94,7 @@ def test_as_another_user(self): user_obj = UserObject.add_user('anotheruser', 'password') user_obj.user_root = self.testcases user_obj.refresh_repos() + user_obj.commit() self._set_encoding('anotheruser', 'testcases', 'cp1252') self.assertStatus('200 OK') repo = RepoObject.query.filter(RepoObject.user == user_obj, RepoObject.repopath == self.REPO).first() @@ -102,7 +103,7 @@ def test_as_another_user(self): # Remove admin right admin = UserObject.get_user('admin') admin.role = UserObject.USER_ROLE - admin.add() + admin.commit() # Browse admin's repos self._set_encoding('anotheruser', 'testcases', 'utf-8') diff --git a/rdiffweb/controller/tests/test_page_status.py b/rdiffweb/controller/tests/test_page_status.py index f0b42793..af389c8d 100644 --- a/rdiffweb/controller/tests/test_page_status.py +++ b/rdiffweb/controller/tests/test_page_status.py @@ -39,7 +39,7 @@ def test_page_with_broken_repo(self): # Given a user's with broken repo userobj = UserObject.get_user('admin') userobj.user_root = '/invalid/' - userobj.add() + userobj.commit() # When browsing the status page self.getPage("/status/") # Then not error should be raised diff --git a/rdiffweb/core/login.py b/rdiffweb/core/login.py index aca9903f..607c8e68 100644 --- a/rdiffweb/core/login.py +++ b/rdiffweb/core/login.py @@ -85,7 +85,7 @@ def login(self, username, password): email=email, role=default_role, user_root=default_user_root, - ).add() + ).commit() except Exception: logger.error('fail to create new user', exc_info=1) if userobj is None: @@ -101,7 +101,7 @@ def login(self, username, password): userobj.email = email dirty = True if dirty: - userobj.add() + userobj.commit() self.bus.publish('user_login', userobj) return userobj diff --git a/rdiffweb/core/model/__init__.py b/rdiffweb/core/model/__init__.py index 2f324300..ea057b06 100644 --- a/rdiffweb/core/model/__init__.py +++ b/rdiffweb/core/model/__init__.py @@ -89,10 +89,10 @@ def add_column(column): for row in result: if row.repopath.startswith('/') or row.repopath.endswith('/'): row.repopath = row.repopath.strip('/') - row.add() + row.commit() if row.repopath == '.': row.repopath = '' - row.add() + row.commit() # Remove duplicates and nested repositories. result = RepoObject.query.order_by(RepoObject.userid, RepoObject.repopath).all() prev_repo = (None, None) diff --git a/rdiffweb/core/model/_repo.py b/rdiffweb/core/model/_repo.py index 73964cc3..a09c0fe1 100644 --- a/rdiffweb/core/model/_repo.py +++ b/rdiffweb/core/model/_repo.py @@ -96,7 +96,8 @@ def get_repo(cls, name, as_user=None, refresh=False): record = query.first() # If the repo is not found but refresh is requested if refresh and not record: - as_user.refresh_repos() + if as_user.refresh_repos(): + as_user.commit() record = query.first() # If repo is not found, raise an error if not record: @@ -176,7 +177,7 @@ def delete(self, path=b''): RdiffRepo.delete(self, path=path) # Remove entry from database after deleting files. # Otherwise, refresh will add this repo back. - super().delete() + return super().delete() @validates('encoding') def validate_encoding(self, key, value): diff --git a/rdiffweb/core/model/_session.py b/rdiffweb/core/model/_session.py index 8e96683d..34e4cc39 100644 --- a/rdiffweb/core/model/_session.py +++ b/rdiffweb/core/model/_session.py @@ -71,16 +71,18 @@ def _save(self, expiration_time): session.data = self._data session.data['_timeout'] = self.timeout session.expiration_time = expiration_time - session.add() + session.add().commit() def _delete(self): SessionObject.query.filter(SessionObject.id == self.id).delete() + SessionObject.session.commit() def clean_up(self): """Clean up expired sessions.""" try: now = self.now() SessionObject.query.filter(SessionObject.expiration_time < now).delete() + SessionObject.session.commit() except Exception: logger.error('fail to clean-up sessions', exc_info=1) finally: diff --git a/rdiffweb/core/model/_token.py b/rdiffweb/core/model/_token.py index fc9593a0..267fb0cb 100644 --- a/rdiffweb/core/model/_token.py +++ b/rdiffweb/core/model/_token.py @@ -17,6 +17,7 @@ import datetime import cherrypy +from cherrypy.process.plugins import SimplePlugin from sqlalchemy import Column, DateTime, Integer, String from sqlalchemy.orm import relationship from sqlalchemy.sql import func @@ -43,3 +44,33 @@ class Token(Base): @property def is_expired(self): return self.expiration_time is not None and self.expiration_time <= datetime.datetime.now() + + def accessed(self): + self.access_time = datetime.datetime.utcnow() + + +class TokenCleanup(SimplePlugin): + + execution_time = '23:00' + + def start(self): + self.bus.log('Start Token Clean Up plugin') + self.bus.publish('schedule_job', self.execution_time, self.clean_up) + + start.priority = 55 + + def stop(self): + self.bus.log('Stop Token Clean Up plugin') + self.bus.publish('unschedule_job', self.clean_up) + + stop.priority = 45 + + def clean_up(self): + Token.query.filter(Token.expiration_time <= datetime.datetime.now()).delete() + Token.session.commit() + + +cherrypy.token_cleanup = TokenCleanup(cherrypy.engine) +cherrypy.token_cleanup.subscribe() + +cherrypy.config.namespaces['token_cleanup'] = lambda key, value: setattr(cherrypy.token_cleanup, key, value) diff --git a/rdiffweb/core/model/_user.py b/rdiffweb/core/model/_user.py index d3d94bf1..1510be3c 100644 --- a/rdiffweb/core/model/_user.py +++ b/rdiffweb/core/model/_user.py @@ -14,7 +14,6 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import datetime import logging import os import secrets @@ -33,6 +32,7 @@ from rdiffweb.tools.i18n import ugettext as _ from ._repo import RepoObject +from ._session import SessionObject from ._sshkey import SshKey from ._token import Token @@ -127,6 +127,7 @@ def create_admin_user(cls, default_username, default_password): else: userobj.hash_password = hash_password('admin123') userobj.add() + return userobj @classmethod def add_user(cls, username, password=None, role=USER_ROLE, **attrs): @@ -177,9 +178,8 @@ def add_authorizedkey(self, key, comment=None): # Also look in database. logger.info("add key [%s] to [%s] database", key, self.username) try: - SshKey(userid=self.userid, fingerprint=key.fingerprint, key=key.getvalue()).add() + SshKey(userid=self.userid, fingerprint=key.fingerprint, key=key.getvalue()).add().flush() except IntegrityError: - SshKey.session.rollback() raise DuplicateSSHKeyError( _("Duplicate key. This key already exists or is associated to another user.") ) @@ -195,10 +195,10 @@ def add_access_token(self, name, expiration_time=None, length=16): token = ''.join(secrets.choice(string.ascii_lowercase) for i in range(length)) # Store hash token try: - obj = Token(userid=self.userid, name=name, hash_token=hash_password(token), expiration_time=expiration_time) - obj.add() + Token( + userid=self.userid, name=name, hash_token=hash_password(token), expiration_time=expiration_time + ).add().flush() except IntegrityError: - Token.session.rollback() raise ValueError(_("Duplicate token name: %s") % name) cherrypy.engine.publish('access_token_added', self, name) return token @@ -221,7 +221,7 @@ def delete(self, *args, **kwargs): RepoObject.query.filter(RepoObject.userid == self.userid).delete() Token.query.filter(Token.userid == self.userid).delete() # Delete ourself - Base.delete(self) + return Base.delete(self) def delete_authorizedkey(self, fingerprint): """ @@ -378,9 +378,17 @@ def set_password(self, password): msg += ' ' + ' '.join(suggestions) raise ValueError(msg) - logger.info("updating user password [%s]", self.username) + # Store password + logger.info("updating user password [%s] and revoke sessions", self.username) self.hash_password = hash_password(password) + # Revoke other session to force re-login + session_id = cherrypy.serving.session.id if getattr(cherrypy.serving, 'session', None) else None + SessionObject.query.filter( + SessionObject.username == self.username, + SessionObject.id != session_id, + ).delete() + def __eq__(self, other): return type(self) == type(other) and inspect(self).key == inspect(other).key @@ -395,14 +403,11 @@ def validate_access_token(self, token): Check if the given token matches. """ for access_token in Token.query.all(): - # If token expired. Let delete it. if access_token.is_expired: - access_token.delete() continue if check_password(token, access_token.hash_token): - # When it matches, let update the record. - access_token.access_time = datetime.datetime.utcnow - return True + # When it matches, return the record. + return access_token return False def validate_password(self, password): diff --git a/rdiffweb/core/model/tests/test_repo.py b/rdiffweb/core/model/tests/test_repo.py index 201b77a6..bc566175 100644 --- a/rdiffweb/core/model/tests/test_repo.py +++ b/rdiffweb/core/model/tests/test_repo.py @@ -33,7 +33,7 @@ def test_update_remove_duplicates(self): # Given a database with duplicate path userobj = UserObject.get_user(self.USERNAME) self.assertEqual(['broker-repo', 'testcases'], sorted([r.name for r in userobj.repo_objs])) - RepoObject(userid=userobj.userid, repopath='/testcases').add() + RepoObject(userid=userobj.userid, repopath='/testcases').add().commit() self.assertEqual(['/testcases', 'broker-repo', 'testcases'], sorted([r.name for r in userobj.repo_objs])) # When creating database cherrypy.tools.db.create_all() @@ -46,7 +46,7 @@ def test_update_remove_nested(self): userobj = UserObject.get_user(self.USERNAME) self.assertEqual(['broker-repo', 'testcases'], sorted([r.name for r in userobj.repo_objs])) RepoObject(userid=userobj.userid, repopath='testcases/home/admin/testcases').add() - RepoObject(userid=userobj.userid, repopath='/testcases/home/admin/data').add() + RepoObject(userid=userobj.userid, repopath='/testcases/home/admin/data').add().commit() self.assertEqual( ['/testcases/home/admin/data', 'broker-repo', 'testcases', 'testcases/home/admin/testcases'], sorted([r.name for r in userobj.repo_objs]), @@ -61,7 +61,7 @@ def test_update_repos_remove_slash(self): # Given a user with a repository named "/testcases" userobj = UserObject.get_user(self.USERNAME) RepoObject.query.filter(RepoObject.userid == userobj.userid).delete() - RepoObject(userid=userobj.userid, repopath='/testcases').add() + RepoObject(userid=userobj.userid, repopath='/testcases').add().commit() self.assertEqual(['/testcases'], sorted([r.name for r in userobj.repo_objs])) # When updating the database schema cherrypy.tools.db.create_all() @@ -73,7 +73,7 @@ def test_get_repo(self): user = UserObject.add_user('bernie', 'my-password') user.user_root = self.testcases user.refresh_repos() - + user.commit() # Get as bernie repo_obj = RepoObject.get_repo('bernie/testcases', user) self.assertEqual('testcases', repo_obj.name) @@ -94,10 +94,12 @@ def test_get_repo_as_other_user(self): user = UserObject.add_user('bernie', 'my-password') user.user_root = self.testcases user.refresh_repos() + user.commit() RepoObject.get_repo('bernie/testcases', user) # Get as otheruser other = UserObject.add_user('other') + other.commit() with self.assertRaises(AccessDeniedError): RepoObject.get_repo('bernie/testcases', other) @@ -105,10 +107,12 @@ def test_get_repo_as_admin(self): user = UserObject.add_user('bernie', 'my-password') user.user_root = self.testcases user.refresh_repos() + user.commit() # Get as admin other = UserObject.add_user('other') other.role = UserObject.ADMIN_ROLE + other.commit() repo_obj3 = RepoObject.get_repo('bernie/testcases', other) self.assertEqual('testcases', repo_obj3.name) self.assertEqual('bernie', repo_obj3.owner) @@ -128,7 +132,7 @@ def test_set_get_encoding(self): userobj = UserObject.get_user(self.USERNAME) repo_obj = RepoObject.query.filter(RepoObject.user == userobj, RepoObject.repopath == self.REPO).first() repo_obj.encoding = "cp1252" - repo_obj.add() + repo_obj.commit() repo_obj = RepoObject.query.filter(RepoObject.user == userobj, RepoObject.repopath == self.REPO).first() self.assertEqual("cp1252", repo_obj.encoding) # Check with invalid value. @@ -139,7 +143,7 @@ def test_set_get_maxage(self): userobj = UserObject.get_user(self.USERNAME) repo_obj = RepoObject.query.filter(RepoObject.user == userobj, RepoObject.repopath == self.REPO).first() repo_obj.maxage = 10 - repo_obj.add() + repo_obj.commit() repo_obj = RepoObject.query.filter(RepoObject.user == userobj, RepoObject.repopath == self.REPO).first() self.assertEqual(10, repo_obj.maxage) # Check with invalid value. @@ -150,7 +154,7 @@ def test_set_get_keepdays(self): userobj = UserObject.get_user(self.USERNAME) repo_obj = RepoObject.query.filter(RepoObject.user == userobj, RepoObject.repopath == self.REPO).first() repo_obj.keepdays = 10 - repo_obj.add() + repo_obj.commit() repo_obj = RepoObject.query.filter(RepoObject.user == userobj, RepoObject.repopath == self.REPO).first() self.assertEqual(10, repo_obj.keepdays) # Check with invalid value. @@ -169,7 +173,7 @@ def test_keepdays_default_value_from_init(self): # Given a User userobj = UserObject.get_user(self.USERNAME) # When creating a new repository - repo_obj = RepoObject(user=userobj, repopath='repopath').add() + repo_obj = RepoObject(user=userobj, repopath='repopath').add().commit() # New repo get created with keepdays == -1 self.assertEqual('-1', repo_obj._keepdays) self.assertEqual(-1, repo_obj.keepdays) @@ -178,7 +182,7 @@ def test_keepdays_empty_string(self): # Given a User userobj = UserObject.get_user(self.USERNAME) # When creating a new repository - repo_obj = RepoObject(user=userobj, repopath='repopath').add() + repo_obj = RepoObject(user=userobj, repopath='repopath').add().commit() RepoObject.session.execute( RepoObject.__table__.update().where(RepoObject.__table__.c.RepoID == repo_obj.repoid).values(keepdays='') ) @@ -199,6 +203,6 @@ def test_encoding_default_value_from_init(self): # Given a User userobj = UserObject.get_user(self.USERNAME) # When creating a new repository - repo_obj = RepoObject(user=userobj, repopath='repopath').add() + repo_obj = RepoObject(user=userobj, repopath='repopath').add().commit() # New repo get created with utf-8 self.assertEqual('utf-8', repo_obj.encoding) diff --git a/rdiffweb/core/model/tests/test_token.py b/rdiffweb/core/model/tests/test_token.py new file mode 100644 index 00000000..698a6fb2 --- /dev/null +++ b/rdiffweb/core/model/tests/test_token.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +# rdiffweb, A web interface to rdiff-backup repositories +# Copyright (C) 2012-2021 rdiffweb contributors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +Created on June 30, 2022 + +Module to test `user` model. + +@author: Patrik Dufresne +""" +import datetime +from io import open +from unittest.mock import MagicMock + +import cherrypy +import pkg_resources + +import rdiffweb.test +from rdiffweb.core.model import Token, UserObject + + +class TokenTest(rdiffweb.test.WebCase): + def _read_ssh_key(self): + """Readthe pub key from test packages""" + filename = pkg_resources.resource_filename('rdiffweb.core.tests', 'test_publickey_ssh_rsa.pub') + with open(filename, 'r', encoding='utf8') as f: + return f.readline() + + def _read_authorized_keys(self): + """Read the content of test_authorized_keys""" + filename = pkg_resources.resource_filename('rdiffweb.core.tests', 'test_authorized_keys') + with open(filename, 'r', encoding='utf8') as f: + return f.read() + + def setUp(self): + super().setUp() + self.listener = MagicMock() + cherrypy.engine.subscribe('access_token_added', self.listener.access_token_added, priority=50) + cherrypy.engine.subscribe('queue_mail', self.listener.queue_mail, priority=50) + + def tearDown(self): + cherrypy.engine.unsubscribe('access_token_added', self.listener.access_token_added) + cherrypy.engine.unsubscribe('queue_mail', self.listener.queue_mail) + return super().tearDown() + + def test_check_schedule(self): + # Given the application is started + # Then remove_older job should be schedule + self.assertEqual(1, len([job for job in cherrypy.scheduler.list_jobs() if job.name == 'clean_up'])) + + def test_clean_up_without_expired(self): + # Given a user with 3 Token + user = UserObject.get_user(self.USERNAME) + user.add_access_token('test1') + user.add_access_token('test2') + user.add_access_token('test3') + user.commit() + self.assertEqual(3, Token.query.count()) + # When running notification_job + cherrypy.token_cleanup.clean_up() + # Then token are not removed + self.assertEqual(3, Token.query.count()) + + def test_clean_up_with_expired(self): + # Given a user with 3 Token + user = UserObject.get_user(self.USERNAME) + user.add_access_token('test1') + user.add_access_token('test2') + user.add_access_token('test3') + for t in Token.query.all(): + t.expiration_time = datetime.datetime.now() + user.commit() + self.assertEqual(3, Token.query.count()) + # When running notification_job + cherrypy.token_cleanup.clean_up() + # Then token are not removed + self.assertEqual(0, Token.query.count()) + + def test_add_access_token(self): + # Given a user with an email + userobj = UserObject.get_user(self.USERNAME) + userobj.email = 'test@examples.com' + userobj.commit() + # When adding a new token + token = userobj.add_access_token('test') + userobj.commit() + # Then a new token get created + self.assertTrue(token) + tokenobj = Token.query.filter(Token.userid == userobj.userid).first() + self.assertTrue(tokenobj) + self.assertEqual(None, tokenobj.expiration_time) + self.assertEqual(None, tokenobj.access_time) + # Then an email is sent to the user. + self.listener.access_token_added.assert_called_once_with(userobj, 'test') + self.listener.queue_mail.assert_called_once() + + def test_add_access_token_duplicate_name(self): + # Given a user with an existing token + userobj = UserObject.get_user(self.USERNAME) + userobj.add_access_token('test') + userobj.commit() + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + # When adding a new token with the same name + with self.assertRaises(ValueError): + userobj.add_access_token('test') + userobj.commit() + userobj.rollback() + # Then token is not created + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + # Then an email is not sent. + self.listener.access_token_added.assert_called_once_with(userobj, 'test') + + def test_delete_access_token(self): + # Given a user with an existing token + userobj = UserObject.get_user(self.USERNAME) + userobj.add_access_token('test') + userobj.commit() + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + # When deleting an access token + userobj.delete_access_token('test') + userobj.commit() + # Then Token get deleted + self.assertEqual(0, Token.query.filter(Token.userid == userobj.userid).count()) + + def test_delete_access_token_invalid(self): + # Given a user with an existing token + userobj = UserObject.get_user(self.USERNAME) + userobj.add_access_token('test') + userobj.commit() + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + # When deleting an invalid access token + with self.assertRaises(ValueError): + userobj.delete_access_token('invalid') + # Then Token not deleted + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + + def test_delete_user_remove_access_tokens(self): + # Given a user with an existing token + userobj = UserObject.add_user('testuser', 'password') + userobj.commit() + userobj.add_access_token('test') + userobj.commit() + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + # When deleting the user + userobj.delete() + userobj.commit() + # Then Token get deleted + self.assertEqual(0, Token.query.filter(Token.userid == userobj.userid).count()) + + def test_verify_access_token(self): + # Given a user with an existing token + userobj = UserObject.get_user(self.USERNAME) + token = userobj.add_access_token('test') + userobj.commit() + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + # When validating the token + # Then token is valid + self.assertTrue(userobj.validate_access_token(token)) + + def test_verify_access_token_with_expired(self): + # Given a user with an existing token + userobj = UserObject.get_user(self.USERNAME) + token = userobj.add_access_token( + 'test', expiration_time=datetime.datetime.now() - datetime.timedelta(seconds=1) + ) + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + # When validating the token + # Then token is invalid + self.assertFalse(userobj.validate_access_token(token)) + # Then token is not removed + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + + def test_verify_access_token_with_invalid(self): + # Given a user with an existing token + userobj = UserObject.get_user(self.USERNAME) + userobj.add_access_token('test', expiration_time=datetime.datetime.now()) + userobj.commit() + self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) + # When validating the token + # Then token is invalid + self.assertFalse(userobj.validate_access_token('invalid')) diff --git a/rdiffweb/core/model/tests/test_user.py b/rdiffweb/core/model/tests/test_user.py index 4a57ded5..497e0738 100644 --- a/rdiffweb/core/model/tests/test_user.py +++ b/rdiffweb/core/model/tests/test_user.py @@ -21,7 +21,6 @@ @author: Patrik Dufresne """ -import datetime import os from io import StringIO, open from unittest.mock import MagicMock @@ -32,7 +31,7 @@ import rdiffweb.test from rdiffweb.core import authorizedkeys -from rdiffweb.core.model import DuplicateSSHKeyError, RepoObject, Token, UserObject +from rdiffweb.core.model import DuplicateSSHKeyError, RepoObject, UserObject from rdiffweb.core.passwd import check_password @@ -73,7 +72,7 @@ def tearDown(self): def test_add_user(self): """Add user to database.""" userobj = UserObject.add_user('joe') - self.assertIsNotNone(userobj) + userobj.commit() self.assertIsNotNone(UserObject.get_user('joe')) # Check if listener called self.listener.user_added.assert_called_once_with(userobj) @@ -87,7 +86,7 @@ def change_user_obj(userobj): self.listener.user_added.side_effect = change_user_obj # When adding user userobj = UserObject.add_user('joe') - self.assertIsNotNone(userobj) + userobj.commit() self.assertIsNotNone(UserObject.get_user('joe')) # Then lister get called self.listener.user_added.assert_called_once_with(userobj) @@ -96,7 +95,8 @@ def change_user_obj(userobj): def test_add_user_with_duplicate(self): """Add user to database.""" - UserObject.add_user('denise') + user = UserObject.add_user('denise') + user.commit() self.listener.user_added.reset_mock() with self.assertRaises(ValueError): UserObject.add_user('denise') @@ -106,6 +106,7 @@ def test_add_user_with_duplicate(self): def test_add_user_with_password(self): """Add user to database with password.""" userobj = UserObject.add_user('jo', 'password') + userobj.commit() self.assertIsNotNone(UserObject.get_user('jo')) # Check if listener called self.listener.user_added.assert_called_once_with(userobj) @@ -120,7 +121,8 @@ def test_users(self): # Check admin exists self.assertEqual(1, UserObject.query.count()) # Create user. - UserObject.add_user('annik') + user = UserObject.add_user('annik') + user.commit() users = UserObject.query.all() self.assertEqual(2, len(users)) self.assertEqual('annik', users[1].username) @@ -134,9 +136,11 @@ def test_get_user(self): user.role = UserObject.ADMIN_ROLE user.email = 'bernie@gmail.com' user.refresh_repos() + user.commit() self.assertEqual(['broker-repo', 'testcases'], sorted([r.name for r in user.repo_objs])) user.repo_objs[0].maxage = -1 user.repo_objs[1].maxage = 3 + user.commit() # Get user record. obj = UserObject.get_user('bernie') @@ -159,6 +163,7 @@ def test_get_user_with_invalid_user(self): def test_get_set(self): user = UserObject.add_user('larry', 'password') + user.add().commit() self.assertEqual('', user.email) self.assertEqual([], user.repo_objs) @@ -168,18 +173,19 @@ def test_get_set(self): user.user_root = self.testcases user.refresh_repos() + user.commit() self.listener.user_attr_changed.assert_called_with(user, {'user_root': ('', self.testcases)}) self.listener.user_attr_changed.reset_mock() user = UserObject.get_user('larry') user.role = UserObject.ADMIN_ROLE - user.add() + user.commit() self.listener.user_attr_changed.assert_called_with( user, {'role': (UserObject.USER_ROLE, UserObject.ADMIN_ROLE)} ) self.listener.user_attr_changed.reset_mock() user = UserObject.get_user('larry') user.email = 'larry@gmail.com' - user.add() + user.commit() self.listener.user_attr_changed.assert_called_with(user, {'email': ('', 'larry@gmail.com')}) self.listener.user_attr_changed.reset_mock() @@ -192,11 +198,12 @@ def test_get_set(self): def test_set_role_null(self): # Given a user user = UserObject.add_user('annik', 'password') + user.add().commit() # When trying to set the role to null user.role = None # Then an exception is raised with self.assertRaises(Exception): - user.add() + user.add().commit() @parameterized.expand( [ @@ -212,7 +219,7 @@ def test_is_admin(self, role, expected_is_admin): user = UserObject.add_user('annik', 'password') # When setting the role value user.role = role - user.add() + user.commit() # Then the is_admin value get updated too self.assertEqual(expected_is_admin, user.is_admin) @@ -230,16 +237,18 @@ def test_is_maintainer(self, role, expected_is_maintainer): user = UserObject.add_user('annik', 'password') # When setting the role value user.role = role - user.add() + user.commit() # Then the is_admin value get updated too self.assertEqual(expected_is_maintainer, user.is_maintainer) def test_set_password_update(self): # Given a user in database with a password userobj = UserObject.add_user('annik', 'password') + userobj.commit() self.listener.user_password_changed.reset_mock() # When updating the user's password userobj.set_password('new_password') + userobj.commit() # Then password is SSHA self.assertTrue(check_password('new_password', userobj.hash_password)) # Check if listener called @@ -248,9 +257,11 @@ def test_set_password_update(self): def test_delete_user(self): # Given an existing user in database userobj = UserObject.add_user('vicky') + userobj.commit() self.assertIsNotNone(UserObject.get_user('vicky')) # When deleting that user userobj.delete() + userobj.commit() # Then user it no longer in database self.assertIsNone(UserObject.get_user('vicky')) # Then listner was called @@ -259,6 +270,7 @@ def test_delete_user(self): def test_set_password_empty(self): """Expect error when trying to update password of invalid user.""" userobj = UserObject.add_user('john') + userobj.commit() with self.assertRaises(ValueError): self.assertFalse(userobj.set_password('')) @@ -286,6 +298,7 @@ def test_add_authorizedkey_without_file(self): # Add the key to the user userobj = UserObject.get_user(self.USERNAME) userobj.add_authorizedkey(key) + userobj.commit() # validate keys = list(userobj.authorizedkeys) @@ -298,9 +311,11 @@ def test_add_authorizedkey_duplicate(self): # Add the key to the user userobj = UserObject.get_user(self.USERNAME) userobj.add_authorizedkey(key) + userobj.commit() # Add the same key with self.assertRaises(DuplicateSSHKeyError): userobj.add_authorizedkey(key) + userobj.commit() def test_add_authorizedkey_with_file(self): """ @@ -316,6 +331,7 @@ def test_add_authorizedkey_with_file(self): # Read the pub key key = self._read_ssh_key() userobj.add_authorizedkey(key) + userobj.commit() # Validate with open(filename, 'r') as fh: @@ -341,6 +357,7 @@ def test_delete_authorizedkey_without_file(self): # Remove a key userobj.delete_authorizedkey("9a:f1:69:3c:bc:5a:cd:02:5e:33:bc:cd:c0:01:eb:4c") + userobj.commit() # Validate keys = list(userobj.authorizedkeys) @@ -376,6 +393,7 @@ def test_repo_objs(self): self.assertEqual(['broker-repo', 'testcases'], [r.name for r in repos]) # When deleting a repository empty list repos[1].delete() + repos[1].commit() # Then the repository is removed from the list. self.assertEqual(['broker-repo'], sorted([r.name for r in userobj.repo_objs])) @@ -383,10 +401,11 @@ def test_refresh_repos_without_delete(self): # Given a user with invalid repositories userobj = UserObject.get_user(self.USERNAME) RepoObject.query.delete() - RepoObject(userid=userobj.userid, repopath='invalid').add() + RepoObject(userid=userobj.userid, repopath='invalid').add().commit() self.assertEqual(['invalid'], sorted([r.name for r in userobj.repo_objs])) # When updating the repository list without deletion userobj.refresh_repos() + userobj.commit() # Then the list invlaid the invalid repo and new repos self.assertEqual(['broker-repo', 'invalid', 'testcases'], sorted([r.name for r in userobj.repo_objs])) @@ -394,10 +413,11 @@ def test_refresh_repos_with_delete(self): # Given a user with invalid repositories userobj = UserObject.get_user(self.USERNAME) RepoObject.query.delete() - RepoObject(userid=userobj.userid, repopath='invalid').add() + RepoObject(userid=userobj.userid, repopath='invalid').add().commit() self.assertEqual(['invalid'], sorted([r.name for r in userobj.repo_objs])) # When updating the repository list without deletion userobj.refresh_repos(delete=True) + userobj.commit() # Then the list invlaid the invalid repo and new repos userobj.expire() self.assertEqual(['broker-repo', 'testcases'], sorted([r.name for r in userobj.repo_objs])) @@ -408,102 +428,11 @@ def test_refresh_repos_with_single_repo(self): userobj.user_root = os.path.join(self.testcases, 'testcases') # When updating the repository list without deletion userobj.refresh_repos(delete=True) + userobj.commit() # Then the list invlaid the invalid repo and new repos userobj.expire() self.assertEqual([''], sorted([r.name for r in userobj.repo_objs])) - def test_add_access_token(self): - # Given a user with an email - userobj = UserObject.get_user(self.USERNAME) - userobj.email = 'test@examples.com' - userobj.add() - # When adding a new token - token = userobj.add_access_token('test') - # Then a new token get created - self.assertTrue(token) - tokenobj = Token.query.filter(Token.userid == userobj.userid).first() - self.assertTrue(tokenobj) - self.assertEqual(None, tokenobj.expiration_time) - self.assertEqual(None, tokenobj.access_time) - # Then an email is sent to the user. - self.listener.access_token_added.assert_called_once_with(userobj, 'test') - self.listener.queue_mail.assert_called_once() - - def test_add_access_token_duplicate_name(self): - # Given a user with an existing token - userobj = UserObject.get_user(self.USERNAME) - userobj.add_access_token('test') - self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) - # When adding a new token with the same name - with self.assertRaises(ValueError): - userobj.add_access_token('test') - # Then token is not created - self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) - # Then an email is not sent. - self.listener.access_token_added.assert_called_once_with(userobj, 'test') - - def test_delete_access_token(self): - # Given a user with an existing token - userobj = UserObject.get_user(self.USERNAME) - userobj.add_access_token('test') - self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) - # When deleting an access token - userobj.delete_access_token('test') - # Then Token get deleted - self.assertEqual(0, Token.query.filter(Token.userid == userobj.userid).count()) - - def test_delete_access_token_invalid(self): - # Given a user with an existing token - userobj = UserObject.get_user(self.USERNAME) - userobj.add_access_token('test') - self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) - # When deleting an invalid access token - with self.assertRaises(ValueError): - userobj.delete_access_token('invalid') - # Then Token not deleted - self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) - - def test_delete_user_remove_access_tokens(self): - # Given a user with an existing token - userobj = UserObject.add_user('testuser', 'password') - userobj.add_access_token('test') - self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) - # When deleting the user - userobj.delete() - # Then Token get deleted - self.assertEqual(0, Token.query.filter(Token.userid == userobj.userid).count()) - - def test_verify_access_token(self): - # Given a user with an existing token - userobj = UserObject.get_user(self.USERNAME) - token = userobj.add_access_token('test') - self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) - # When validating the token - # Then token is valid - self.assertTrue(userobj.validate_access_token(token)) - - def test_verify_access_token_with_expired(self): - # Given a user with an existing token - userobj = UserObject.get_user(self.USERNAME) - token = userobj.add_access_token( - 'test', expiration_time=datetime.datetime.now() - datetime.timedelta(seconds=1) - ) - self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) - # When validating the token - # Then token is invalid - self.assertFalse(userobj.validate_access_token(token)) - # Then token get removed - self.assertEqual(0, Token.query.filter(Token.userid == userobj.userid).count()) - - def test_verify_access_token_with_invalid(self): - # Given a user with an existing token - userobj = UserObject.get_user(self.USERNAME) - userobj.add_access_token('test', expiration_time=datetime.datetime.now()) - self.assertEqual(1, Token.query.filter(Token.userid == userobj.userid).count()) - # When validating the token - # Then token is invalid - self.assertFalse(userobj.validate_access_token('invalid')) - class UserObjectWithAdminPassword(rdiffweb.test.WebCase): @@ -523,7 +452,6 @@ def test_create_admin_user(self): self.assertEqual('{SSHA}wbSK4hlEX7mtGJplFi2oN6ABm6Y3Bo1e', userobj.hash_password) self.assertTrue(check_password('test', userobj.hash_password)) - def test_set_password(self): # Given admin-password is configure # When trying to update admin password # Then an exception is raised diff --git a/rdiffweb/core/tests/test_login.py b/rdiffweb/core/tests/test_login.py index 3604b7c9..65a0762d 100644 --- a/rdiffweb/core/tests/test_login.py +++ b/rdiffweb/core/tests/test_login.py @@ -45,6 +45,7 @@ class LoginTest(LoginAbstractTest): def test_login(self): # Given a valid user in database with a password userobj = UserObject.add_user('tom', 'password') + userobj.commit() # When trying to login with valid password login = cherrypy.engine.publish('login', 'tom', 'password') # Then login is successful @@ -53,7 +54,8 @@ def test_login(self): self.listener.user_login.assert_called_once_with(userobj) def test_login_with_invalid_password(self): - UserObject.add_user('jeff', 'password') + userobj = UserObject.add_user('jeff', 'password') + userobj.commit() self.assertFalse(any(cherrypy.engine.publish('login', 'jeff', 'invalid'))) # password is case sensitive self.assertFalse(any(cherrypy.engine.publish('login', 'jeff', 'Password'))) diff --git a/rdiffweb/core/tests/test_notification.py b/rdiffweb/core/tests/test_notification.py index 6fe44b49..3d82a503 100644 --- a/rdiffweb/core/tests/test_notification.py +++ b/rdiffweb/core/tests/test_notification.py @@ -53,10 +53,10 @@ def test_notification_job(self): # Set user config user = UserObject.get_user(self.USERNAME) user.email = 'test@test.com' - user.add() + user.commit() repo = RepoObject.query.filter(RepoObject.user == user, RepoObject.repopath == self.REPO).first() repo.maxage = 1 - repo.add() + repo.commit() # When running notification_job cherrypy.notification.notification_job() @@ -71,11 +71,11 @@ def test_notification_job_undefined_last_backup_date(self): # Given a valid user with a repository configured for notification user = UserObject.get_user(self.USERNAME) user.email = 'test@test.com' - user.add() + user.add().commit() # Given a repo with last_backup_date None repo = RepoObject.query.filter(RepoObject.user == user, RepoObject.repopath == 'broker-repo').first() repo.maxage = 1 - repo.add() + repo.add().commit() self.assertIsNone(repo.last_backup_date) # When Notification job is running @@ -92,10 +92,10 @@ def test_notification_job_without_notification(self): # Given a valid user with a repository configured without notification (-1) user = UserObject.get_user(self.USERNAME) user.email = 'test@test.com' - user.add() + user.add().commit() repo = RepoObject.query.filter(RepoObject.user == user, RepoObject.repopath == self.REPO).first() repo.maxage = -1 - repo.add() + repo.add().commit() # Call notification. cherrypy.notification.notification_job() @@ -123,13 +123,13 @@ def test_email_changed(self): # Given a user with an email address user = UserObject.get_user(self.USERNAME) user.email = 'original_email@test.com' - user.add() + user.add().commit() self.listener.queue_email.reset_mock() # When updating the user's email user = UserObject.get_user(self.USERNAME) user.email = 'email_changed@test.com' - user.add() + user.add().commit() # Then a email is queue to notify the user. self.listener.queue_email.assert_called_once_with( @@ -142,10 +142,12 @@ def test_email_updated_with_same_value(self): # Given a user with an email user = UserObject.get_user(self.USERNAME) user.email = 'email_changed@test.com' + user.add().commit() self.listener.queue_email.reset_mock() # When updating the user's email with the same value user.email = 'email_changed@test.com' + user.add().commit() # Then no email are sent to the user self.listener.queue_email.assert_not_called() @@ -154,10 +156,12 @@ def test_password_change_notification(self): # Given a user with a email. user = UserObject.get_user(self.USERNAME) user.email = 'password_change@test.com' + user.add().commit() self.listener.queue_email.reset_mock() # When updating the user password user.set_password('new_password') + user.add().commit() # Then a email is send to the user self.listener.queue_email.assert_called_once_with( @@ -171,10 +175,12 @@ def test_password_change_with_same_value(self): user = UserObject.get_user(self.USERNAME) user.email = 'password_change@test.com' user.set_password('new_password') + user.add().commit() self.listener.queue_email.reset_mock() # When updating the user password with the same value user.set_password('new_password') + user.add().commit() # Then an email is sent to the user self.listener.queue_email.assert_called_once_with( diff --git a/rdiffweb/core/tests/test_quota.py b/rdiffweb/core/tests/test_quota.py index b220af76..edd4556e 100644 --- a/rdiffweb/core/tests/test_quota.py +++ b/rdiffweb/core/tests/test_quota.py @@ -40,6 +40,7 @@ class QuotaPluginTest(test.WebCase): def test_get_disk_usage(self): # Given a user userobj = UserObject.add_user('bob') + userobj.commit() # When querying quota for a userobj result = cherrypy.engine.publish('get_disk_usage', userobj) # Then quota return a value @@ -48,6 +49,7 @@ def test_get_disk_usage(self): def test_get_disk_quota(self): # Given a user userobj = UserObject.add_user('bob') + userobj.commit() # When querying quota for a userobj result = cherrypy.engine.publish('get_disk_quota', userobj) # Then quota return a value @@ -56,6 +58,7 @@ def test_get_disk_quota(self): def test_set_disk_quota(self): # Given a used cmd userobj = UserObject.add_user('bob') + userobj.commit() # When querying quota for a userobj results = cherrypy.engine.publish('set_disk_quota', userobj, 98765) # Then quota return a value @@ -73,6 +76,7 @@ class QuotaPluginTestWithFailure(test.WebCase): def test_set_disk_quota_with_failure(self): # Given a user object userobj = UserObject.add_user('bob') + userobj.commit() # When settings the quota results = cherrypy.engine.publish('set_disk_quota', userobj, 98765) # Then False is returned @@ -108,6 +112,7 @@ def test_get_disk_usage_with_empty_user_root(self): # Given a user with an empty user_root. userobj = UserObject.add_user('bob') userobj.user_root = '' + userobj.commit() # When getting disk usage results = cherrypy.engine.publish('get_disk_usage', userobj) # Then default disk usage is return @@ -117,6 +122,7 @@ def test_get_disk_usage_with_invalid_user_root(self): # Given a user with an invalid user_root. userobj = UserObject.add_user('bob') userobj.user_root = 'invalid' + userobj.commit() # When getting disk usage results = cherrypy.engine.publish('get_disk_usage', userobj) # Then default disk usage is return diff --git a/rdiffweb/core/tests/test_rdw_templating.py b/rdiffweb/core/tests/test_rdw_templating.py index c677022f..c8d8bab3 100644 --- a/rdiffweb/core/tests/test_rdw_templating.py +++ b/rdiffweb/core/tests/test_rdw_templating.py @@ -22,7 +22,7 @@ from rdiffweb.core.librdiff import RdiffTime from rdiffweb.core.model import RepoObject, UserObject from rdiffweb.core.rdw_templating import _ParentEntry, attrib, do_format_lastupdated, list_parents, url_for -from rdiffweb.test import AppTestCase, WebCase +from rdiffweb.test import WebCase class TemplateManagerTest(unittest.TestCase): @@ -91,7 +91,7 @@ def test_do_format_lastupdated(self): self.assertEqual('4 years ago', do_format_lastupdated(RdiffTime(value=1452442324), now=1591978846)) -class ListParentsTest(AppTestCase): +class ListParentsTest(WebCase): def test_list_parents_with_root_dir(self): repo, path = RepoObject.get_repo_path(b'admin/testcases', as_user=UserObject.get_user('admin')) self.assertEqual(list_parents(repo, path), [_ParentEntry(path=b'', display_name='testcases')]) diff --git a/rdiffweb/rdw_app.py b/rdiffweb/rdw_app.py index 34a1e677..6bc5c77d 100644 --- a/rdiffweb/rdw_app.py +++ b/rdiffweb/rdw_app.py @@ -232,7 +232,8 @@ def __init__(self, cfg): os.environ["TMPDIR"] = self._tempdir # create user manager - UserObject.create_admin_user(cfg.admin_user, cfg.admin_password) + user = UserObject.create_admin_user(cfg.admin_user, cfg.admin_password) + user.commit() @property def currentuser(self): diff --git a/rdiffweb/templates/access_token_added.html b/rdiffweb/templates/access_token_added.html index 5142d38f..53b6bf5c 100644 --- a/rdiffweb/templates/access_token_added.html +++ b/rdiffweb/templates/access_token_added.html @@ -2,7 +2,7 @@ {% trans username=user.username %}Hey {{ username }},{% endtrans %} -

{% trans %}A new access token, named "{{ test }}", has been created.{% endtrans %}

+

{% trans %}A new access token, named "{{ name }}", has been created.{% endtrans %}

{% trans %}If you did not make this change and believe your account has been compromised, please contact your administrator.{% endtrans %}

diff --git a/rdiffweb/test.py b/rdiffweb/test.py index 2cd452c2..821b6cb8 100644 --- a/rdiffweb/test.py +++ b/rdiffweb/test.py @@ -51,55 +51,6 @@ def create_testcases_repo(app): return new -class AppTestCase(unittest.TestCase): - - REPO = 'testcases' - - USERNAME = 'admin' - - PASSWORD = 'admin123' - - default_config = {} - - app_class = RdiffwebApp - - @classmethod - def setup_class(cls): - if cls is AppTestCase: - raise unittest.SkipTest("%s is an abstract base class" % cls.__name__) - - @classmethod - def teardown_class(cls): - pass - - def setUp(self): - # Allow defining a custom database uri for testing. - self.database_dir = tempfile.mkdtemp(prefix='rdiffweb_tests_db_') - uri = os.path.join(self.database_dir, 'rdiffweb.tmp.db') - uri = os.environ.get('RDIFFWEB_TEST_DATABASE_URI', uri) - self.default_config['database-uri'] = uri - cfg = self.app_class.parse_args( - args=[], config_file_contents='\n'.join('%s=%s' % (k, v) for k, v in self.default_config.items()) - ) - # Create Application - self.app = self.app_class(cfg) - # Create repositories - self.testcases = create_testcases_repo(self.app) - # Register repository - admin_user = UserObject.get_user(self.USERNAME) - if admin_user: - admin_user.user_root = self.testcases - admin_user.refresh_repos() - - def tearDown(self): - if hasattr(self, 'database_dir'): - shutil.rmtree(self.database_dir) - delattr(self, 'database_dir') - if hasattr(self, 'testcases'): - shutil.rmtree(self.testcases) - delattr(self, 'testcases') - - class WebCase(helper.CPWebCase): """ Helper class for the rdiffweb test suite. @@ -152,13 +103,14 @@ def setUp(self): cherrypy.tools.db.drop_all() cherrypy.tools.db.create_all() # Create default admin - UserObject.create_admin_user(self.USERNAME, self.PASSWORD) + admin_user = UserObject.create_admin_user(self.USERNAME, self.PASSWORD) + admin_user.commit() # Create testcases repo self.testcases = create_testcases_repo(self.app) - admin_user = UserObject.get_user(self.USERNAME) if admin_user: admin_user.user_root = self.testcases admin_user.refresh_repos() + admin_user.commit() # Login to web application. if self.login: self._login() @@ -167,6 +119,7 @@ def tearDown(self): if hasattr(self, 'testcases'): shutil.rmtree(self.testcases) delattr(self, 'testcases') + cherrypy.tools.db.drop_all() @property def app(self): diff --git a/rdiffweb/tools/db.py b/rdiffweb/tools/db.py index 0aee0d07..76350b90 100644 --- a/rdiffweb/tools/db.py +++ b/rdiffweb/tools/db.py @@ -20,8 +20,9 @@ import logging import cherrypy -from sqlalchemy import create_engine, event +from sqlalchemy import create_engine, event, inspect from sqlalchemy.engine import Engine +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.orm import scoped_session, sessionmaker @@ -39,6 +40,31 @@ def _set_sqlite_journal_mode_wal(connection, connection_record): cursor.close() +def _get_model_changes(model): + """ + Return a dictionary containing changes made to the model since it was + fetched from the database. + + The dictionary is of the form {'property_name': [old_value, new_value]} + """ + state = inspect(model) + changes = {} + for attr in state.attrs: + hist = attr.history + if not hist.has_changes(): + continue + if isinstance(attr.value, (list, tuple)) or len(hist.deleted) > 1 or len(hist.added) > 1: + # If array, store array + changes[attr.key] = [hist.deleted, hist.added] + else: + # If primitive, store primitive + changes[attr.key] = [ + hist.deleted[0] if len(hist.deleted) >= 1 else None, + hist.added[0] if len(hist.added) >= 1 else None, + ] + return changes + + class Base: ''' Extends declarative base to provide convenience methods to models similar to @@ -53,35 +79,32 @@ class Base: changed = User.from_dict({}) # update record based on dict argument passed in and returns any keys changed ''' - def add(self, commit=True): + def add(self): """ Add current object to session. """ self.__class__.session.add(self) - if commit: - self.__class__.session.commit() return self - def delete(self, commit=True): - """ - Delete current object to session. - """ + def delete(self): self.__class__.session.delete(self) - if commit: - self.__class__.session.commit() return self - def merge(self, commit=True): - """ - Merge current object to session. - """ - self.__class__.session.merge(self) - if commit: - self.__class__.session.commit() + def commit(self): + self.__class__.session.commit() + return self + + def flush(self): + self.__class__.session.flush() return self def expire(self): self.__class__.session.expire(self) + return self + + def rollback(self): + self.__class__.session.rollback() + return self class BaseExtensions(DeclarativeMeta): @@ -117,6 +140,7 @@ def create_all(self): if debug: logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG) base.metadata.create_all() + self.get_session().commit() def drop_all(self): # Release opened sessions. @@ -124,6 +148,7 @@ def drop_all(self): # Drop all base = self.get_base() base.metadata.drop_all() + self.get_session().commit() def get_base(self): if self._base is None: @@ -140,13 +165,17 @@ def on_end_resource(self): if self._session is None: return try: - self._session.flush() - self._session.commit() - except Exception: - logger.exception('error trying to flush and commit session') + # When terminating, raise an error if objects are not commit. + if self._session.dirty or self._session.new or self._session.deleted: + changes = ', '.join([str(_get_model_changes(obj)) for obj in self._session.dirty]) + logger.exception( + 'session is dirty, some database object(s) are not commited, this indicate a bug in the application ' + 'dirty %s new %s deleted %s' % (changes, self._session.new, self._session.deleted) + ) + raise SQLAlchemyError('session is dirty') + finally: self._session.rollback() self._session.expunge_all() - finally: self._session.remove()