diff --git a/flexmeasures/api/v3_0/assets.py b/flexmeasures/api/v3_0/assets.py index 14c81edab..abd7a5e53 100644 --- a/flexmeasures/api/v3_0/assets.py +++ b/flexmeasures/api/v3_0/assets.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import json from flask import current_app from flask_classful import FlaskView, route +from flask_login import current_user from flask_security import auth_required from flask_json import as_json from marshmallow import fields @@ -18,6 +21,8 @@ from flexmeasures.api.common.schemas.users import AccountIdField from flexmeasures.utils.coding_utils import flatten_unique from flexmeasures.ui.utils.view_utils import set_session_variables +from flexmeasures.auth.policy import check_access +from werkzeug.exceptions import Forbidden, Unauthorized asset_schema = AssetSchema() @@ -38,21 +43,27 @@ class AssetAPI(FlaskView): @route("", methods=["GET"]) @use_kwargs( { - "account": AccountIdField( - data_key="account_id", load_default=AccountIdField.load_current + "account": AccountIdField(data_key="account_id", load_default=None), + }, + location="query", + ) + @use_kwargs( + { + "all_accessible": fields.Bool( + data_key="all_accessible", load_default=False ), }, location="query", ) - @permission_required_for_context("read", ctx_arg_name="account") @as_json - def index(self, account: Account): - """List all assets owned by a certain account. + def index(self, account: Account | None, all_accessible: bool): + """List all assets owned or accessible by a certain account. .. :quickref: Asset; Download asset list This endpoint returns all accessible assets for the account of the user. The `account_id` query parameter can be used to list assets from a different account. + The `all_accessible` query parameter can be used to list all the assets accessible by the requesting user. Defaults to `false`. **Example response** @@ -80,7 +91,29 @@ def index(self, account: Account): :status 403: INVALID_SENDER :status 422: UNPROCESSABLE_ENTITY """ - return assets_schema.dump(account.generic_assets), 200 + + if all_accessible: + accounts = [] + for _account in db.session.scalars(select(Account)).all(): + try: + check_access(_account, "read") + accounts.append(_account) + except (Forbidden, Unauthorized): + # re-raise exception if the account is provided + # but the requesting user has no read access to it. + if _account == account: + raise + else: + if account is None: + account = current_user.account + check_access(account, "read") + accounts = [account] + + assets = [] + for account in accounts: + assets.extend(account.generic_assets) + + return assets_schema.dump(assets), 200 @route("/public", methods=["GET"]) @as_json diff --git a/flexmeasures/api/v3_0/users.py b/flexmeasures/api/v3_0/users.py index 378bf5c31..769512ce4 100644 --- a/flexmeasures/api/v3_0/users.py +++ b/flexmeasures/api/v3_0/users.py @@ -1,11 +1,13 @@ from flask_classful import FlaskView, route from marshmallow import fields from sqlalchemy.exc import IntegrityError +from sqlalchemy import select from webargs.flaskparser import use_kwargs from flask_security import current_user, auth_required from flask_security.recoverable import send_reset_password_instructions from flask_json import as_json -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import Forbidden, Unauthorized +from flexmeasures.auth.policy import check_access from flexmeasures.data.models.audit_log import AuditLog from flexmeasures.data.models.user import User as UserModel, Account @@ -41,22 +43,22 @@ class UserAPI(FlaskView): @route("", methods=["GET"]) @use_kwargs( { - "account": AccountIdField( - data_key="account_id", load_default=AccountIdField.load_current - ), + "account": AccountIdField(data_key="account_id", load_default=None), "include_inactive": fields.Bool(load_default=False), }, location="query", ) - @permission_required_for_context("read", ctx_arg_name="account") @as_json def index(self, account: Account, include_inactive: bool = False): - """API endpoint to list all users of an account. + """API endpoint to list all users. + .. :quickref: User; Download user list This endpoint returns all accessible users. By default, only active users are returned. + The `account_id` query parameter can be used to filter the users of + a given account. The `include_inactive` query parameter can be used to also fetch inactive users. Accessible users are users in the same account as the current user. @@ -89,7 +91,25 @@ def index(self, account: Account, include_inactive: bool = False): :status 403: INVALID_SENDER :status 422: UNPROCESSABLE_ENTITY """ - users = get_users(account_name=account.name, only_active=not include_inactive) + + if account is not None: + check_access(account, "read") + accounts = [account] + else: + accounts = [] + for account in db.session.scalars(select(Account)).all(): + try: + check_access(account, "read") + accounts.append(account) + except (Forbidden, Unauthorized): + pass + + users = [] + for account in accounts: + users += get_users( + account_name=account.name, + only_active=not include_inactive, + ) return users_schema.dump(users), 200 @route("/") diff --git a/flexmeasures/data/schemas/generic_assets.py b/flexmeasures/data/schemas/generic_assets.py index 745020556..8306bb30d 100644 --- a/flexmeasures/data/schemas/generic_assets.py +++ b/flexmeasures/data/schemas/generic_assets.py @@ -44,7 +44,7 @@ class GenericAssetSchema(ma.SQLAlchemySchema): generic_asset_type_id = fields.Integer(required=True) attributes = JSON(required=False) parent_asset_id = fields.Int(required=False, allow_none=True) - child_assets = ma.Nested("GenericAssetSchema", many=True, dumb_only=True) + child_assets = ma.Nested("GenericAssetSchema", many=True, dump_only=True) class Meta: model = GenericAsset diff --git a/flexmeasures/ui/crud/assets.py b/flexmeasures/ui/crud/assets.py index 64e50bbb2..a4dd76abd 100644 --- a/flexmeasures/ui/crud/assets.py +++ b/flexmeasures/ui/crud/assets.py @@ -214,6 +214,26 @@ def get_assets_by_account(account_id: int | str | None) -> list[GenericAsset]: ] +def get_all_assets() -> list[GenericAsset]: + get_assets_response = ( + InternalApi() + .get(url_for("AssetAPI:index"), query={"all_accessible": True}) + .json() + ) + get_assets_response_public = InternalApi().get(url_for("AssetAPI:public")).json() + + assets = [] + if isinstance(get_assets_response, list): + assets.extend(get_assets_response) + + if isinstance(get_assets_response_public, list): + assets.extend(get_assets_response_public) + + asset_ids_filter = [GenericAsset.id.in_(ad["id"] for ad in assets)] + + return db.session.scalars(select(GenericAsset).where(*asset_ids_filter)).all() + + class AssetCrudUI(FlaskView): """ These views help us offer a Jinja2-based UI. @@ -231,14 +251,7 @@ def index(self, msg=""): List the user's assets. For admins, list across all accounts. """ - assets = [] - - if user_has_admin_access(current_user, "read"): - for account in db.session.scalars(select(Account)).all(): - assets += get_assets_by_account(account.id) - assets += get_assets_by_account(account_id=None) - else: - assets = get_assets_by_account(current_user.account_id) + assets = get_all_assets() return render_flexmeasures_template( "crud/assets.html", diff --git a/flexmeasures/ui/crud/users.py b/flexmeasures/ui/crud/users.py index acc3c720b..9da810c71 100644 --- a/flexmeasures/ui/crud/users.py +++ b/flexmeasures/ui/crud/users.py @@ -4,7 +4,6 @@ from flask import request, url_for from flask_classful import FlaskView -from flask_login import current_user from flask_security import login_required from flask_wtf import FlaskForm from wtforms import StringField, FloatField, DateTimeField, BooleanField @@ -92,6 +91,20 @@ def get_users_by_account( return users +def get_all_users(include_inactive: bool = False) -> list[User]: + get_users_response = InternalApi().get( + url_for( + "UserAPI:index", + include_inactive=include_inactive, + ) + ) + users = [ + process_internal_api_response(user, make_obj=True) + for user in get_users_response.json() + ] + return users + + class UserCrudUI(FlaskView): route_base = "/users" trailing_slash = False @@ -100,15 +113,8 @@ class UserCrudUI(FlaskView): def index(self): """/users""" include_inactive = request.args.get("include_inactive", "0") != "0" - users = [] - if current_user.has_role(ADMIN_ROLE) or current_user.has_role( - ADMIN_READER_ROLE - ): - accounts = db.session.scalars(select(Account)).all() - else: - accounts = [current_user.account] - for account in accounts: - users += get_users_by_account(account.id, include_inactive) + users = get_all_users(include_inactive) + return render_flexmeasures_template( "crud/users.html", users=users, include_inactive=include_inactive ) diff --git a/flexmeasures/ui/tests/conftest.py b/flexmeasures/ui/tests/conftest.py index d2b78f811..0a3f080e6 100644 --- a/flexmeasures/ui/tests/conftest.py +++ b/flexmeasures/ui/tests/conftest.py @@ -2,6 +2,7 @@ from flexmeasures.data.services.users import create_user from flexmeasures.ui.tests.utils import login, logout +from flexmeasures import Asset @pytest.fixture(scope="function") @@ -41,3 +42,21 @@ def setup_ui_test_data( account_name=setup_accounts["Prosumer"].name, user_roles=dict(name="admin", description="A site admin."), ) + + +@pytest.fixture +def assets_prosumer(db, setup_accounts, setup_generic_asset_types): + assets = [] + for name in ["TestAsset", "TestAsset2"]: + asset = Asset( + name=name, + generic_asset_type=setup_generic_asset_types["battery"], + owner=setup_accounts["Prosumer"], + latitude=70.4, + longitude=30.9, + ) + assets.append(asset) + + db.session.add_all(assets) + + return assets diff --git a/flexmeasures/ui/tests/test_asset_crud.py b/flexmeasures/ui/tests/test_asset_crud.py index 2e2c8aa3f..7d7446848 100644 --- a/flexmeasures/ui/tests/test_asset_crud.py +++ b/flexmeasures/ui/tests/test_asset_crud.py @@ -16,28 +16,25 @@ def test_assets_page_empty(db, client, requests_mock, as_prosumer_user1): - requests_mock.get(f"{api_path_assets}?account_id=1", status_code=200, json={}) - requests_mock.get(f"{api_path_assets}/public", status_code=200, json={}) + requests_mock.get(f"{api_path_assets}", status_code=200, json=[]) + requests_mock.get(f"{api_path_assets}/public", status_code=200, json=[]) asset_index = client.get(url_for("AssetCrudUI:index"), follow_redirects=True) assert asset_index.status_code == 200 def test_get_assets_by_account(db, client, requests_mock, as_prosumer_user1): mock_assets = mock_asset_response(multiple=True) - requests_mock.get( - f"{api_path_assets}?account_id=1", status_code=200, json=mock_assets - ) + requests_mock.get(f"{api_path_assets}", status_code=200, json=mock_assets) assert get_assets_by_account(1)[1].name == "TestAsset2" @pytest.mark.parametrize("use_owned_by", [False, True]) def test_assets_page_nonempty( - db, client, requests_mock, as_prosumer_user1, use_owned_by + db, client, requests_mock, as_prosumer_user1, use_owned_by, assets_prosumer ): mock_assets = mock_asset_response(multiple=True) - requests_mock.get( - f"{api_path_assets}?account_id=1", status_code=200, json=mock_assets - ) + requests_mock.get(f"{api_path_assets}", status_code=200, json=mock_assets) + requests_mock.get(f"{api_path_assets}/public", status_code=200, json=[]) if use_owned_by: asset_index = client.get( url_for("AssetCrudUI:owned_by", account_id=mock_assets[0]["account_id"]) diff --git a/flexmeasures/ui/tests/test_user_crud.py b/flexmeasures/ui/tests/test_user_crud.py index 62e075c5c..e3823581c 100644 --- a/flexmeasures/ui/tests/test_user_crud.py +++ b/flexmeasures/ui/tests/test_user_crud.py @@ -15,7 +15,7 @@ def test_get_users_by_account(client, requests_mock, as_prosumer_user1): requests_mock.get( - f"http://localhost//api/v3_0/users?account_id={current_user.account.id}", + "http://localhost//api/v3_0/users", status_code=200, json=mock_user_response(multiple=True), ) diff --git a/flexmeasures/ui/tests/test_views.py b/flexmeasures/ui/tests/test_views.py index c2d13536b..78d218915 100644 --- a/flexmeasures/ui/tests/test_views.py +++ b/flexmeasures/ui/tests/test_views.py @@ -21,9 +21,14 @@ def test_dashboard_responds_only_for_logged_in_users(client, as_prosumer_user1): def test_assets_responds(client, requests_mock, as_prosumer_user1): requests_mock.get( - "http://localhost//api/v3_0/assets?account_id=1", + "http://localhost//api/v3_0/assets", status_code=200, - json={}, + json=[], + ) + requests_mock.get( + "http://localhost//api/v3_0/assets/public", + status_code=200, + json=[], ) assets_page = client.get(url_for("AssetCrudUI:index"), follow_redirects=True) assert assets_page.status_code == 200 diff --git a/flexmeasures/ui/tests/utils.py b/flexmeasures/ui/tests/utils.py index c5658198d..6783da896 100644 --- a/flexmeasures/ui/tests/utils.py +++ b/flexmeasures/ui/tests/utils.py @@ -20,7 +20,7 @@ def logout(client): def mock_asset_response( - asset_id: int = 1, + asset_id: int = 2, account_id: int = 1, as_list: bool = True, multiple: bool = False, @@ -38,6 +38,7 @@ def mock_asset_response( if multiple: asset2 = copy.deepcopy(asset) asset2["name"] = "TestAsset2" + asset2["id"] += 1 asset_list.append(asset2) return asset_list return asset