diff --git a/.dockerignore b/.dockerignore index b124585869..c43e425204 100644 --- a/.dockerignore +++ b/.dockerignore @@ -25,7 +25,7 @@ grr/gui/static/bower_components grr/gui/static/node_modules grr/gui/static/tmp grr/var -grr-server*.tar.gz +grr_server*.tar.gz LICENSE README.md travis diff --git a/CHANGELOG.md b/CHANGELOG.md index ac5d196656..12286e05e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,8 +14,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Removed support for foreman rules using `uname` of an endpoint (this can be simulated by using 3 rules for system name, release and version). * GRR server Debian package is removed when github actions are updated. The - docker image and docker compose stack (see section "Added") are the + docker image and Docker Compose stack (see section "Added") are the recommended way of running GRR in the future. +* Removed the `provides` field from the `Artifact` message. This change has been + done in anticipation of the removal of the same field from the official GitHub + repository (ForensicArtifacts/artifacts#275). ### Added diff --git a/appveyor/e2e_tests/run_docker_compose_e2e_test.sh b/appveyor/e2e_tests/run_docker_compose_e2e_test.sh index 0d54dfe9e3..a06f9db031 100755 --- a/appveyor/e2e_tests/run_docker_compose_e2e_test.sh +++ b/appveyor/e2e_tests/run_docker_compose_e2e_test.sh @@ -1,18 +1,18 @@ #!/bin/bash # -# Runs the e2e tests in the docker-compose stack. +# Runs the e2e tests in the Docker Compose stack. # -# This script is executed in the grr docker container or in an +# This script is executed in the grr Docker container or in an # environment with the grr src and development environment # (grr-python-api, grr-test) available. And assumes the -# docker-compose stack to be running with exposed ports for +# Docker Compose stack to be running with exposed ports for # the admin API and GRR database. # # Running this test (from the main folder): -# - Start the docker compose stack with: -# $ docker-compose up +# - Start the Docker Compose stack with: +# $ docker compose up # -# - Build and run the GRR docker container and set the entrypoint +# - Build and run the GRR Docker container and set the entrypoint # to this script: # $ docker build -f ./Dockerfile . -t local-grr-container # $ docker run \ @@ -25,7 +25,7 @@ set -ex -# The IP address of the client inside the docker-compose stack. +# The IP address of the client inside the Docker Compose stack.
readonly CLIENT_IP=${1} readonly GRR_API="http://host.docker.internal:8000" diff --git a/appveyor/windows_templates/build_windows_templates.py b/appveyor/windows_templates/build_windows_templates.py index 5bfba269fc..456ae19f3e 100644 --- a/appveyor/windows_templates/build_windows_templates.py +++ b/appveyor/windows_templates/build_windows_templates.py @@ -188,8 +188,9 @@ def MakeProtoSdist(self): self.virtualenv_python64, "setup.py", "sdist", "--formats=zip", "--dist-dir=%s" % args.build_dir ]) - return glob.glob(os.path.join(args.build_dir, - "grr-response-proto-*.zip")).pop() + return glob.glob( + os.path.join(args.build_dir, "grr_response_proto-*.zip") + ).pop() def MakeCoreSdist(self): os.chdir(os.path.join(args.grr_src, "grr/core")) @@ -197,8 +198,9 @@ def MakeCoreSdist(self): self.virtualenv_python64, "setup.py", "sdist", "--formats=zip", "--dist-dir=%s" % args.build_dir, "--no-sync-artifacts" ]) - return glob.glob(os.path.join(args.build_dir, - "grr-response-core-*.zip")).pop() + return glob.glob( + os.path.join(args.build_dir, "grr_response_core-*.zip") + ).pop() def MakeClientSdist(self): os.chdir(os.path.join(args.grr_src, "grr/client/")) @@ -206,8 +208,9 @@ def MakeClientSdist(self): self.virtualenv_python64, "setup.py", "sdist", "--formats=zip", "--dist-dir=%s" % args.build_dir ]) - return glob.glob(os.path.join(args.build_dir, - "grr-response-client-*.zip")).pop() + return glob.glob( + os.path.join(args.build_dir, "grr_response_client-*.zip") + ).pop() def MakeClientBuilderSdist(self): os.chdir(os.path.join(args.grr_src, "grr/client_builder/")) @@ -216,8 +219,8 @@ def MakeClientBuilderSdist(self): "--dist-dir=%s" % args.build_dir ]) return glob.glob( - os.path.join(args.build_dir, - "grr-response-client-builder-*.zip")).pop() + os.path.join(args.build_dir, "grr_response_client_builder-*.zip") + ).pop() def InstallGRR(self, path): """Installs GRR.""" diff --git a/build_requirements.txt b/build_requirements.txt new file mode 100644 index 0000000000..068e0e8e40 --- /dev/null +++ b/build_requirements.txt @@ -0,0 +1,5 @@ +pip==24.0 +pytest==6.2.5 +pytest-xdist==2.2.1 +setuptools==69.5.1 +wheel==0.43.0 \ No newline at end of file diff --git a/colab/grr_colab/fs.py b/colab/grr_colab/fs.py index 9f1fa4d6b8..9ea6e9b553 100644 --- a/colab/grr_colab/fs.py +++ b/colab/grr_colab/fs.py @@ -13,7 +13,6 @@ from grr_colab import vfs from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 -from grr_response_server.flows.general import file_finder class FileSystem(object): @@ -208,9 +207,7 @@ def _collect_file(self, path: Text) -> None: args.action.action_type = flows_pb2.FileFinderAction.Action.DOWNLOAD try: - cff = self._client.CreateFlow( - name=file_finder.ClientFileFinder.__name__, args=args - ) + cff = self._client.CreateFlow(name='FileFinder', args=args) except api_errors.AccessForbiddenError as e: raise errors.ApprovalMissingError(self.id, e) diff --git a/colab/grr_colab/testing.py b/colab/grr_colab/testing.py index 0fe2bd1f44..6009cdb617 100644 --- a/colab/grr_colab/testing.py +++ b/colab/grr_colab/testing.py @@ -88,7 +88,20 @@ def wait_until_done(*args, **kwargs): actions = list(client_actions.REGISTRY.values()) client_mock = action_mocks.ActionMock(*actions) - flow_test_lib.FinishAllFlows(client_mock=client_mock) + flow_test_lib.FinishAllFlows( + client_mock=client_mock, + # Sometimes (e.g. 
during interrogation) some subflows fail (which + # can happen if we do not run with root privileges or if certain + # data is not available) but the flow can cope with this. To avoid + # tests hard failing in such scenarios, we disable checking flow + # errors. + # + # Note that we are still going to verify status of the root flow: + # the original `WaitUntilDone` (called right below) does this. If + # the flow ends in a state that is not `FINISHED` (e.g. because it + # crashed) the test is going to fail anyway. + check_flow_errors=False, + ) func(*args, **kwargs) return wait_until_done diff --git a/colab/grr_colab/vfs.py b/colab/grr_colab/vfs.py index 13e1d57b75..3d32abdbc2 100644 --- a/colab/grr_colab/vfs.py +++ b/colab/grr_colab/vfs.py @@ -128,7 +128,7 @@ def detach(self) -> None: # pytype: disable=signature-mismatch # overriding-re def readable(self) -> bool: return True - def read(self, size: int = -1) -> bytes: + def read(self, size: int = -1) -> bytes: # pytype: disable=signature-mismatch self._ensure_not_closed() size = size or -1 diff --git a/docker-compose.testing.yaml b/compose.testing.yaml similarity index 100% rename from docker-compose.testing.yaml rename to compose.testing.yaml diff --git a/docker-compose.yaml b/compose.yaml similarity index 56% rename from docker-compose.yaml rename to compose.yaml index d9b969e656..1964bf4ed7 100644 --- a/docker-compose.yaml +++ b/compose.yaml @@ -1,3 +1,4 @@ +version: "3.8" services: db: image: mysql:8.2 @@ -26,37 +27,50 @@ services: retries: 10 grr-admin-ui: - image: ghcr.io/google/grr:docker-compose-testing + image: ghcr.io/google/grr:latest container_name: grr-admin-ui hostname: admin-ui - restart: always depends_on: db: condition: service_healthy + fleetspeak-admin: + condition: service_started volumes: - - ./docker_config_files/server:/configs/ + - ./docker_config_files:/configs/ + # Mount a directory for the repacked client installers, so they + # can be used in the grr-client container which mounts the same volume. + - client_installers:/client_installers ports: - "8000:8000" expose: - "8000" networks: - server-network - command: - - -component - - admin_ui - - -config - - /configs/grr.server.yaml - - --verbose + entrypoint: [ + "/bin/bash", + "-c", + "/configs/server/repack_clients.sh && grr_server -component admin_ui -config /configs/server/grr.server.yaml --verbose" + ] + healthcheck: + # As soon as any files have been written to the /client_installer we + # assume the service is healthy. 
+ test: | + if [[ -z "$(ls /client_installers)" ]]; then + echo "Healthcheck: GRR client installer not available" + exit 1 + fi + timeout: 10s + retries: 10 grr-fleetspeak-frontend: - image: ghcr.io/google/grr:docker-compose-testing + image: ghcr.io/google/grr:latest container_name: grr-fleetspeak-frontend hostname: grr-fleetspeak-frontend depends_on: db: condition: service_healthy volumes: - - ./docker_config_files/server/:/configs/ + - ./docker_config_files:/configs expose: - "11111" restart: always @@ -66,11 +80,11 @@ - -component - frontend - -config - - /configs/grr.server.yaml + - /configs/server/grr.server.yaml - --verbose fleetspeak-admin: - image: ghcr.io/google/fleetspeak:cl-601031487 + image: ghcr.io/google/fleetspeak:latest container_name: fleetspeak-admin hostname: fleetspeak-admin depends_on: @@ -81,20 +95,20 @@ expose: - "4444" volumes: - - ./docker_config_files/server/:/configs/ + - ./docker_config_files:/configs entrypoint: [ "server", "-components_config", - "/configs/textservices/admin.components.config", + "/configs/server/textservices/admin.components.config", "-services_config", - "/configs/grr_frontend.service", + "/configs/server/grr_frontend.service", "-alsologtostderr", "-v", "1000" ] fleetspeak-frontend: - image: ghcr.io/google/fleetspeak:cl-601031487 + image: ghcr.io/google/fleetspeak:latest container_name: fleetspeak-frontend hostname: fleetspeak-frontend depends_on: @@ -106,23 +120,23 @@ - "4443" - "10000" volumes: - - ./docker_config_files/server/:/configs/ + - ./docker_config_files:/configs entrypoint: [ "server", "-components_config", - "/configs/textservices/frontend.components.config", + "/configs/server/textservices/frontend.components.config", "-services_config", - "/configs/grr_frontend.service", + "/configs/server/grr_frontend.service", "-alsologtostderr", "-v", "1000" ] grr-worker: - image: ghcr.io/google/grr:docker-compose-testing + image: ghcr.io/google/grr:latest container_name: grr-worker volumes: - - ./docker_config_files/server/:/configs/ + - ./docker_config_files:/configs hostname: grr-worker depends_on: db: @@ -134,27 +148,34 @@ - -component - worker - -config - - /configs/grr.server.yaml + - /configs/server/grr.server.yaml - --verbose grr-client: - image: ghcr.io/google/grr:docker-compose-testing + image: ubuntu:22.04 container_name: grr-client - restart: always depends_on: - - db - - fleetspeak-frontend + db: + condition: service_healthy + fleetspeak-frontend: + condition: service_started + grr-admin-ui: + # Service is healthy as soon as client installers are repacked. + condition: service_healthy volumes: - - ./docker_config_files/client/:/configs/ - # Mount the client_installers folder, to preserve - # the repacked templates across restarts. + - ./docker_config_files:/configs + # Mount the client_installers folder which contains the + # repacked templates written by the grr-admin-ui container - client_installers:/client_installers + # Mount the client_state volume to preserve the client's state + # including the client_id across restarts.
+ - client_state:/client_state networks: - server-network entrypoint: [ "/bin/bash", "-c", - "/configs/repack_install_client.sh && fleetspeak-client -config /configs/client.config" + "/configs/client/install_client.sh && fleetspeak-client -config /configs/client/client.config" ] healthcheck: test: | @@ -168,5 +189,6 @@ services: volumes: db_data: client_installers: + client_state: networks: server-network: diff --git a/debian/changelog b/debian/changelog deleted file mode 100644 index 024329f33c..0000000000 --- a/debian/changelog +++ /dev/null @@ -1,139 +0,0 @@ -grr-server (3.1.0-2) unstable; urgency=low - - * Update - - -- GRR development team Thu, 14 Apr 2016 18:16:26 -0700 - -grr-server (3.1.0-1) unstable; urgency=low - - * Update - - -- GRR development team Mon, 11 Apr 2016 00:00:00 -0700 - -grr-server (0.3.0-8) unstable; urgency=low - - * Update - - -- GRR development team Fri, 19 Feb 2016 00:00:00 -0700 - -grr-server (0.3.0-7) unstable; urgency=low - - * Update - - -- GRR development team Mon, 20 Jul 2015 13:40:26 -0700 - -grr-server (0.3.0-6) unstable; urgency=low - - * Update - - -- GRR development team Wed, 17 Mar 2015 00:00:00 +0100 - -grr-server (0.3.0-2) unstable; urgency=low - - * Fixed some small client bugs. - * Updated GRR to use Rekall v1.0.2, removed Volatility completely. - - -- GRR development team Wed, 30 Jun 2014 00:00:00 +0100 - -grr-server (0.3.0-1) unstable; urgency=low - - * Version change to make client versioning easier. - - -- GRR development team Tue, 26 Jun 2014 11:35:00 +0100 - -grr-server (0.2.10-1) unstable; urgency=low - - * Update - - -- GRR development team Tue, 15 Apr 2014 12:08:00 +0100 - -grr-server (0.2.9-1) unstable; urgency=low - - * Version correction - - -- GRR development team Fri, 06 Dec 2013 11:13:18 +0100 - -grr-server (0.2-9) unstable; urgency=low - - * Update - - -- GRR development team Sun, 20 Oct 2013 12:00:00 +0000 - -grr-server (0.2-8) unstable; urgency=low - - * Added Windows 32 bit client - * Updated deployment documentation - * Fix for TSK URN escaping issues - * Added support for renaming Windows executables during repacking - * Added verbose repacking to make debugging clients easier - * Bugfixes - - -- GRR development team Thu, 4 Apr 2013 12:00:00 +0200 - -grr-server (0.2-7) unstable; urgency=low - - * Bugfixes for build and deployment - * Fixed tests - * Added OSX client automatic packaging - * Resolved bad log paths for packaged Windows 64 bit client - * Fixed mongo index error - * Improved logging - * Added initctl_switch.sh helper script - * Updated winpmem drivers - - -- GRR development team Thu, 4 Apr 2013 12:00:00 +0200 - -grr-server (0.2-6) unstable; urgency=low - - * New Bootstrap UI - * Expanded Hunt UI for scheduling - * Config system complete overhaul - * Additional mysql datastore option (currently low performance) - * New client builds for OSX and Windows - * New pmem memory drivers - - -- GRR development team Sun, 31 Mar 2013 12:00:00 +0200 - -grr-server (0.2-5) unstable; urgency=low - - * Numerous bug fixes - * Memory drivers now ship included - * Volatility integration updates - * Hunt scheduling UI v1 ready - * Refactor of flow typing system for faster development/testing - - -- GRR development team Thurs, 31 Oct 2012 12:00:00 +0200 - -grr-server (0.2-4) unstable; urgency=low - - * Update for memory improvements - * Django 1.4 fixes - * Numerous bug fixes - - -- GRR development team Thurs, 31 Oct 2012 12:00:00 +0200 - -grr-server (0.2-3) unstable; urgency=low - - * Improved build and update mechanisms - * Add 
initial hunt scheduling UI - * Compatibility fixes for Django 1.4 including XSRF changes - * New flows GetProcessesBinariesVolatility, CacheGrep, Glob, SendFile - * Scheduling performance improvements - * Numerous bug fixes - - -- GRR development team Thurs, 31 Oct 2012 12:00:00 +0200 - -grr-server (0.2-2) unstable; urgency=low - - * Update build and update mechanisms to ease maintenance and install - * Include OSX client builds - * Add memory driver installation support - * Numerous bug fixes - - -- GRR development team Thurs, 18 Oct 2012 12:00:00 +0200 - -grr-server (0.2-1) unstable; urgency=low - - * Initial release - - -- GRR development team Fri, 28 Sep 2012 12:00:00 +0200 diff --git a/debian/compat b/debian/compat deleted file mode 100644 index ec635144f6..0000000000 --- a/debian/compat +++ /dev/null @@ -1 +0,0 @@ -9 diff --git a/debian/control b/debian/control deleted file mode 100644 index 59a6e66bdd..0000000000 --- a/debian/control +++ /dev/null @@ -1,15 +0,0 @@ -Source: grr-server -Section: misc -Priority: extra -Maintainer: GRR developers -Build-Depends: debhelper (>= 9), debhelper (>= 9.20160709) | dh-systemd (>= 1.5), dh-make, dh-virtualenv (>= 0.6), lib32z1, libc6-i386, python3-dev -Standards-Version: 3.8.3 -Homepage: https://github.com/google/grr - -Package: grr-server -Section: python -Architecture: any -Pre-Depends: -Depends: debhelper, dh-make, dpkg (>= 1.16.1), dpkg-dev, python3-dev, python3-mysqldb, rpm, systemd, zip -Description: GRR Rapid Response is an Incident Response Framework - GRR Rapid Response is an Incident Response Framework. diff --git a/debian/copyright b/debian/copyright deleted file mode 100644 index 22d58fe0d0..0000000000 --- a/debian/copyright +++ /dev/null @@ -1,49 +0,0 @@ -This work was packaged for Debian by: - - GRR development team on Fri, 28 Sep 2012 12:00:00 +0200 - -It was downloaded from: https://github.com/google/grr - -Upstream Author(s): - - GRR development team - -Copyright: - - Copyright 2011 Google Inc. - -License: - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -On Debian systems, the complete text of the Apache-2.0 License -can be found in `/usr/share/common-licenses/Apache-2.0'. - -The Debian packaging is: - - Copyright 2012 Google Inc. - -License: - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - diff --git a/debian/grr-server.default b/debian/grr-server.default deleted file mode 100644 index 1f7950a1bf..0000000000 --- a/debian/grr-server.default +++ /dev/null @@ -1,16 +0,0 @@ -# Configuration file for grr-response-server under Debian based systems. 
- -# The main location for the GRR virtualenv. If you want to run GRR from source, -# simply create a new virtualenv, pip install into it (possibly with the -e -# flag) and point this at the location of the virtualenv. All initd scripts will -# then start up that GRR server instead of the one shipped with the deb package. -GRR_PREFIX=/usr/share/grr-server - -# These args are appended to every command line invocation from /usr/bin/. The -# debian scripts will use this to force the GRR writeback location to -# /etc/grr/. If you want to store the local grr configuration file in another -# location, change the below. If you remove this override the GRR installation -# will write local configuration changes into a private location within the -# virtualenv (this way you can have multiple different GRR installations running -# at the same time). -GRR_EXTRA_ARGS=(--context 'Global Install Context') \ No newline at end of file diff --git a/debian/grr-server.service b/debian/grr-server.service deleted file mode 100644 index 966d9cd897..0000000000 --- a/debian/grr-server.service +++ /dev/null @@ -1,17 +0,0 @@ -# This service is actually a systemd target, but we are using a service since -# targets cannot be reloaded and we may want to implement reload in the future. - -[Unit] -Description=GRR Service -After=syslog.target network.target -Documentation=https://github.com/google/grr - -[Service] -Type=oneshot -RemainAfterExit=yes -ExecReload=/bin/systemctl --no-block reload grr-server@admin_ui.service grr-server@frontend.service grr-server@worker.service grr-server@worker2.service fleetspeak-server.service -ExecStart=/bin/systemctl --no-block start grr-server@admin_ui.service grr-server@frontend.service grr-server@worker.service grr-server@worker2.service fleetspeak-server.service -ExecStop=/bin/systemctl --no-block stop grr-server@admin_ui.service grr-server@frontend.service grr-server@worker.service grr-server@worker2.service fleetspeak-server.service - -[Install] -WantedBy=multi-user.target diff --git a/debian/grr-server@.service b/debian/grr-server@.service deleted file mode 100644 index 21cb72db34..0000000000 --- a/debian/grr-server@.service +++ /dev/null @@ -1,18 +0,0 @@ -[Unit] -Description=GRR %I -PartOf=grr-server.service -ReloadPropagatedFrom=grr-server.service -After=syslog.target network.target -Documentation=https://github.com/google/grr - -[Service] -Type=simple -PrivateTmp=true -Restart=on-failure -LimitNOFILE=65536 -Environment="MPLCONFIGDIR=/var/run/grr/tmp/%i" "PYTHON_EGG_CACHE=/var/run/grr/tmp/%i" -ExecStartPre=/bin/mkdir -p /var/run/grr/tmp/%i -ExecStart=/usr/bin/grr_server --component %i -p StatsStore.process_id=%i_%m - -[Install] -WantedBy=multi-user.target diff --git a/debian/postinst b/debian/postinst deleted file mode 100644 index bc5c103d69..0000000000 --- a/debian/postinst +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -# -# Post-installation script for the GRR server deb. -# -# Installs GRR in a virtualenv, sets it up as a systemd service, starts -# the service, then runs 'grr_config_updater initialize'. - -set -e - -# The token below is replaced with shellscript snippets generated -# by debhelper commands. 
See http://manpages.ubuntu.com/dh_installdeb - -#DEBHELPER# - -case "$1" in - configure) - adduser --system fleetspeak - groupadd --system -f fleetspeak - if [ "$DEBIAN_FRONTEND" != noninteractive ]; then - - echo "#################################################################" - echo "Running grr_config_updater initialize" - echo "To avoid this prompting, set DEBIAN_FRONTEND=noninteractive" - echo "#################################################################" - - grr_config_updater initialize - fi -esac - -echo "#################################################################" -echo "Install complete." -echo "If upgrading, make sure you read the release notes:" -echo "https://grr-doc.readthedocs.io/en/latest/release-notes.html" -echo "#################################################################" diff --git a/debian/preinst b/debian/preinst deleted file mode 100644 index 21608aac63..0000000000 --- a/debian/preinst +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -# -# Pre-installation script for the GRR server deb. -# -# Checks whether a MySQL process is running, and if not, gives the user -# and option to quit the installation. - -set -e - -case "${1}" in - install) - if [[ "${DEBIAN_FRONTEND}" == 'noninteractive' || ! -z "$(pgrep -x mysqld)" ]]; then - exit 0 - fi - MYSQL_WARNING=("" - "######################################################################################" - "GRR has failed to detect a running MySQL instance on this machine." - "This is ok if you plan on connecting to a remote MySQL instance." - "If you aren't though, we recommend you exit this installation and install MySQL first." - "FYI you can skip this check by setting DEBIAN_FRONTEND=noninteractive." - "######################################################################################" - "" - "Would you like to proceed with GRR's installation? [Yn]: ") - (IFS=$'\n';printf "${MYSQL_WARNING[*]}") - read REPLY - if [[ -z "${REPLY}" || "${REPLY}" == 'Y'* || "${REPLY}" == 'y'* ]]; then - exit 0 - elif [[ "${REPLY}" == 'N'* || "${REPLY}" == 'n'* ]]; then - ABORT_MESSAGE=("" - "#####################################################################################" - "Aborting installation. For instructions on how to set up MySQL, please see" - "https://grr-doc.readthedocs.io/en/latest/installing-grr-server/from-release-deb.html." - "#####################################################################################" - "" - "") - (IFS=$'\n';printf "${ABORT_MESSAGE[*]}") - exit 1 - else - echo "Invalid input: '${REPLY}'. Aborting installation." - exit 2 - fi - ;; - upgrade|abort-upgrade) - # Nothing to do if GRR is already installed. - ;; -esac - -# The token below is replaced with shellscript snippets generated -# by debhelper commands. See http://manpages.ubuntu.com/dh_installdeb - -#DEBHELPER# diff --git a/debian/rules b/debian/rules deleted file mode 100644 index f3d1be1b13..0000000000 --- a/debian/rules +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/make -f -# debian/rules that uses debhelper >= 8. - -# Uncomment this to turn on verbose mode (deb creation generates A LOT of -# log output). -# export DH_VERBOSE=1 - -# This has to be exported to make some magic below work. 
-export DH_OPTIONS -export DH_VIRTUALENV_INSTALL_ROOT=/usr/share/ - -%: - dh $@ --with python-virtualenv,systemd - -dh_python2: - - -override_dh_pysupport: - -override_dh_auto_clean: - -override_dh_auto_test: - -override_dh_installinit: - dh_installinit - bash grr/core/scripts/install_server_from_src.sh -i debian/grr-server - cp -v -r debian/grr-server/usr/share/grr-server/fleetspeak-server-bin/etc/fleetspeak-server debian/grr-server/etc - rm -r debian/grr-server/usr/share/grr-server/fleetspeak-server-bin/etc/fleetspeak-server - -override_dh_installdocs: - -override_dh_installmenu: - -override_dh_installmime: - -override_dh_installmodules: - -override_dh_installlogcheck: - -override_dh_installlogrotate: - -override_dh_installpam: - -override_dh_installppp: - -override_dh_installudev: - -override_dh_installwm: - -override_dh_installxfonts: - -override_dh_link: - dh_link etc/fleetspeak-server usr/share/grr-server/fleetspeak-server-bin/etc/fleetspeak-server - dh_link usr/share/grr-server/fleetspeak-server-bin/usr/bin/fleetspeak-server usr/bin/fleetspeak-server - dh_link usr/share/grr-server/fleetspeak-server-bin/usr/bin/fleetspeak-config usr/bin/fleetspeak-config - dh_link usr/share/grr-server/fleetspeak-server-bin/lib/systemd/system/fleetspeak-server.service lib/systemd/system/fleetspeak-server.service - -override_dh_gconf: - -override_dh_icons: - -override_dh_perl: - -override_dh_strip: - dh_strip --exclude=ffi - -# Removing fleetspeak-client-bin in this rule to conserve space. -# A fleetspeak client installation is obsolete in the server DEB. -override_dh_virtualenv: - dh_virtualenv --python python3 \ - --builtin-venv \ - --use-system-packages \ - --extra-pip-arg "--ignore-installed" \ - --extra-pip-arg "--no-cache-dir" \ - --extra-pip-arg "--no-index" \ - --extra-pip-arg "--find-links=${LOCAL_DEB_PYINDEX}" \ - --skip-install \ - --preinstall "${LOCAL_DEB_PYINDEX}/${API_SDIST}" \ - --preinstall "${LOCAL_DEB_PYINDEX}/${TEMPLATES_SDIST}" \ - --preinstall "${LOCAL_DEB_PYINDEX}/${CLIENT_BUILDER_SDIST}" \ - --preinstall "${LOCAL_DEB_PYINDEX}/${SERVER_SDIST}" - rm -r debian/grr-server/usr/share/grr-server/fleetspeak-client-bin - -override_dh_shlibdeps: - dh_shlibdeps -Xcygrpc.cpython diff --git a/devenv/config/grr-server.yaml b/devenv/config/grr-server.yaml index 86e87975d3..351c8775f0 100644 --- a/devenv/config/grr-server.yaml +++ b/devenv/config/grr-server.yaml @@ -42,22 +42,3 @@ FleetspeakFrontend Context: Server.fleetspeak_enabled: true Server.fleetspeak_server: localhost:4444 Server.initialized: true - -Frontend.certificate: | - -----BEGIN CERTIFICATE----- - MIICuTCCAaGgAwIBAgIBAjANBgkqhkiG9w0BAQsFADAgMREwDwYDVQQDDAhncnJf - dGVzdDELMAkGA1UEBhMCVVMwHhcNMjMwMTIyMTEyMTA2WhcNMzMwMTIwMTEyMTA2 - WjAOMQwwCgYDVQQDDANncnIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB - AQDRgX3/3lvGJ2wHO502LFBmNOdN3OHqeo8LNpam0wzDYKevZUpebcCl4aiqYU8g - t/Cd+F5TCOnjLHRore7c86yzI0cfk2ytP0bTCQsCR6AUXzlSt87J6x510wGgW5oB - pEfdTsBHl+bAm3dzJNA0TzNr2i4VfpV9/L1wEw+Se6lC/J74W+Tjm4cHFtIQcwGt - 547wBU3CN71XFMrV8LhaIT7FV4jOqiGZLCTSSR0143d9TOeEErxwXyqPMPhIF0Xm - ihdd9h6VHq/1L6B0qiKTsGnxdtb0KmBIgs/b9i33PEyAWnTTIcw1eK1ryAulCXjF - e7BDiOtVbz43AjSm1iwaWGM9AgMBAAGjEDAOMAwGA1UdEwEB/wQCMAAwDQYJKoZI - hvcNAQELBQADggEBAIt3sLvzluPDkWvNoDKnil9HQ8zBlP1sxMlwtCvTDZbIiTuM - IK+VL1KuNzGEhpeEbziSpN7ZDUT053xpPYnoZZgQlgLBiNmXJaoHOnj+WAewsK0j - vJm7mxLgqdjkXBVyc7jIE/yoZJihygjwDiA3YgvMj/lWZfqU6f57XJERnVDlEUAW - QP2YYDStQZvQuwdn/Lie3PfNTIgwkFRoFcrd4tQGrWhH7/pEfSetgeGJLbW56xSl - GCXRpNm584CHsx3JzkUNgpM6wl+Jc7arcy8uF6bqbQXOFVL2drgzOFWbI+RXS/74 - 
TGQFaywDGAmHCMx/vcLacmwycH8tEWVLFP1DbLo= - -----END CERTIFICATE----- diff --git a/docker/install_grr_from_gcs.sh b/docker/install_grr_from_gcs.sh index c7217df5fb..0e2783cf24 100755 --- a/docker/install_grr_from_gcs.sh +++ b/docker/install_grr_from_gcs.sh @@ -16,19 +16,19 @@ WORK_DIR=/tmp/docker_work_dir mkdir "${WORK_DIR}" cd "${WORK_DIR}" -mv -v "$INITIAL_DIR"/_artifacts/grr-server_*.tar.gz . +mv -v "$INITIAL_DIR"/_artifacts/grr_server_*.tar.gz . -tar xzf grr-server_*.tar.gz +tar xzf grr_server-*.tar.gz "${GRR_VENV}/bin/pip" install --no-index --no-cache-dir \ --find-links=grr/local_pypi \ - grr/local_pypi/grr-response-proto-*.zip \ - grr/local_pypi/grr-response-core-*.zip \ - grr/local_pypi/grr-response-client-*.zip \ - grr/local_pypi/grr-api-client-*.zip \ - grr/local_pypi/grr-response-server-*.zip \ - grr/local_pypi/grr-response-test-*.zip \ - grr/local_pypi/grr-response-templates-*.zip + grr/local_pypi/grr_response_proto-*.zip \ + grr/local_pypi/grr_response_core-*.zip \ + grr/local_pypi/grr_response_client-*.zip \ + grr/local_pypi/grr_api_client-*.zip \ + grr/local_pypi/grr_response_server-*.zip \ + grr/local_pypi/grr_response_test-*.zip \ + grr/local_pypi/grr_response_templates-*.zip cd "${INITIAL_DIR}" rm -rf "${WORK_DIR}" diff --git a/docker_config_files/client/client.config b/docker_config_files/client/client.config index 48838c15f3..47ec330a3d 100644 --- a/docker_config_files/client/client.config +++ b/docker_config_files/client/client.config @@ -1,11 +1,9 @@ server: "fleetspeak-frontend:4443" -# .-. -# (o.o) WARNING: Publicly stored key. For testing only. -# |=| NEVER reuse in production. -trusted_certs: "-----BEGIN CERTIFICATE-----\nMIIBhjCCASygAwIBAgIQbZTIkKIjOwVDH5kZDEwz+zAKBggqhkjOPQQDAjAjMSEw\nHwYDVQQDExhGbGVldHNwZWFrIEZsZWV0c3BlYWsgQ0EwHhcNMjQwMTEyMTQ1MTU0\nWhcNMzQwMTA5MTQ1MTU0WjAjMSEwHwYDVQQDExhGbGVldHNwZWFrIEZsZWV0c3Bl\nYWsgQ0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARcKcmCDpGj32sDzRUxBO9E\n9eNg92wGHYYbqHJ5DxqQWVyU8lmE7pPyrZAhVvAAIWQN5pL/MwGRDncOhAciseFW\no0IwQDAOBgNVHQ8BAf8EBAMCAoQwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU\nWl3keEC1M5wmeN/+sUTqrtOVgpIwCgYIKoZIzj0EAwIDSAAwRQIgGMUGaqhSEt4Q\n4SkeTjeU2lr4UpO5wCTRJ80SVENoZUICIQDL31xpZF25HQroy9ApHYuxn8C7oUES\n2RvOjey+9sHQzg==\n-----END CERTIFICATE-----\n" + +trusted_certs: "%TRUSTED_FLEETSPEAK_CERT%" client_label: "" filesystem_handler: { - configuration_directory:"/configs/" - state_file:"/tmp/fleetspeak-client.state" + configuration_directory:"/configs/client" + state_file:"/client_state/fleetspeak-client.state" } streaming:true diff --git a/docker_config_files/client/grr.client.yaml b/docker_config_files/client/grr.client.yaml index 4a67ff58a0..0bf8e128d6 100644 --- a/docker_config_files/client/grr.client.yaml +++ b/docker_config_files/client/grr.client.yaml @@ -1,32 +1,13 @@ Client.fleetspeak_enabled: true -ClientBuilder.fleetspeak_bundled: true -ClientBuilder.template_dir: /client_templates Client.server_urls: - fleetspeak-frontend Client.foreman_check_frequency: 10 # seconds +Config.directory: /configs/client + Logging.verbose: true Logging.engines: file,stderr Logging.path: /tmp/grr-client Logging.filename: /tmp/grr-client/grr-client.log -# .-. -# (o.o) WARNING: Publicly stored key. For testing only. -# |=| NEVER reuse in production. 
-Client.executable_signing_public_key: | - -----BEGIN PUBLIC KEY----- - MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAx6YQNUwITzi7l+biDnwv - n63Rg3vbfPZexL/0O1XzQw1Z7mFp3uHtnSrkgDmqYIDXwxDXvn8Ck+k8dYt8SZCc - Jq4Jd/YkJXaUiM2E/2Y+Gv33ioVaN7QRyVBGRldK7X6a9Z8tEBE8jF3mlzlO2Z16 - ZCgMLD1I6ZJpHfQFcDGJP7idHY1TVHJ7j9YG8PObi2k9r5E9UBg6DcFD3Rqg5CP/ - OUtE56B7VW3y8q49c8pw+ZfiQaXd11xMLuMOX9Brlsp/RqFC6wvM1RJc9oR08Bq8 - je7ZmTVuwGEUR8snL2eqPqhM1UAvelbEF4IVG9E7A043Fhh7qVPxVGqKSkgfwXS0 - 0QIDAQAB - -----END PUBLIC KEY----- - -Target:Linux: - ClientBuilder.fleetspeak_client_config: /configs/client.config -Target:Windows: - ClientBuilder.fleetspeak_client_config: /configs/client.config -Target:Darwin: - ClientBuilder.fleetspeak_client_config: /configs/client.config +Client.executable_signing_public_key: "%(/configs/public-key.pem|file)" diff --git a/docker_config_files/client/install_client.sh b/docker_config_files/client/install_client.sh new file mode 100755 index 0000000000..be7302b481 --- /dev/null +++ b/docker_config_files/client/install_client.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# +# This script is run when the client is started in the Docker Compose stack. +# It installs the provided debian package if no installers or fleetspeak-client +# binary are found. +# The client installers are repacked by the admin ui. +INSTALLERS_DIR="/client_installers" + +if ! command -v fleetspeak-client &> /dev/null + then + echo "**Installing Client from debian package." + dpkg -i ${INSTALLERS_DIR}/*.deb + else + echo "** Found fleetspeak-client binary, skipping install." +fi + +echo "** Completed client setup." diff --git a/docker_config_files/client/repack_install_client.sh b/docker_config_files/client/repack_install_client.sh deleted file mode 100755 index 9cee4d3e95..0000000000 --- a/docker_config_files/client/repack_install_client.sh +++ /dev/null @@ -1,38 +0,0 @@ -#! /bin/bash - -# GRR client docker compose initialization script. -# This script is run when the client is started in the -# docker-compose stack. It repacks the client using the -# provided configuration files and installs the resulting -# debian package if no installers or fleetspeak-client -# binary are found. -# -# This script assumes the client-config files -# (docker_config_files/client) to be mounted at /configs. - -# Template dir is initializes when building the image via -# the github actions, which also builds the templates. -TEMPLATE_DIR="/client_templates" -INSTALLERS_DIR="/client_installers" - - -if [[ -z "$(ls -A ${INSTALLERS_DIR})" ]] - then - echo "** Repacking clients." - grr_client_build repack_multiple \ - --templates ${TEMPLATE_DIR}/*/*.zip \ - --repack_configs /configs/grr.client.yaml \ - --output_dir ${INSTALLERS_DIR} - else - echo "** Found existing client installers dir, skipping repacking." -fi - -if ! command -v fleetspeak-client &> /dev/null - then - echo "**Installing Client from debian package." - dpkg -i ${INSTALLERS_DIR}/grr.client/*.deb - else - echo "** Found fleetspeak-client binary, skipping install." -fi - -echo "** Completed client setup." 
diff --git a/docker_config_files/client/textservices/grr_client.service b/docker_config_files/client/textservices/grr_client.service index ee6615fb67..05d3a5e113 100644 --- a/docker_config_files/client/textservices/grr_client.service +++ b/docker_config_files/client/textservices/grr_client.service @@ -2,10 +2,8 @@ name: "GRR" factory: "Daemon" config: { [type.googleapis.com/fleetspeak.daemonservice.Config]: { - argv: "python" - argv: "-m" - argv: "grr_response_client.client" - argv: "--secondary_configs" - argv: "/configs/grr.client.yaml" + argv: "grrd" + argv: "--config" + argv: "/configs/client/grr.client.yaml" } } diff --git a/docker_config_files/init_certs.sh b/docker_config_files/init_certs.sh new file mode 100755 index 0000000000..762f6f88e3 --- /dev/null +++ b/docker_config_files/init_certs.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# +# Script to generate a set of keys and certificates for GRR and Fleetspeak that +# replace the placeholders in the config files. +# +# Usage: +# ./init_certs.sh + +set -ex + +FILE_DIR="$(dirname "$(which "$0")")" + +# Generate key pair .pem files, which is linked in the GRR client and +# server configs (client.yaml, server.local.yaml). +openssl genrsa -out "$FILE_DIR/private-key.pem" +openssl rsa -in "$FILE_DIR/private-key.pem" -pubout -out "$FILE_DIR/public-key.pem" + +# Create a CA/trusted private key and cert for Fleetspeak. +openssl genrsa \ + -out "$FILE_DIR/fleetspeak-ca-key.pem" +openssl req -new -x509 -days 365 -subj "/CN=Fleetspeak CA"\ + -key "$FILE_DIR/fleetspeak-ca-key.pem" \ + -out "$FILE_DIR/fleetspeak-ca-cert.pem" \ + +# Create keys for CA signed key and cert for fleetspeak. Resulting files are +# also copied in the envoy container, see containers/envoy/Dockerfile). +openssl genrsa \ + -out "$FILE_DIR/fleetspeak-key.pem" +openssl req -new -x509 -days 365 \ + -subj "/CN=Fleetspeak CA" -addext "subjectAltName = DNS:fleetspeak-frontend" \ + -key "$FILE_DIR/fleetspeak-key.pem" \ + -out "$FILE_DIR/fleetspeak-cert.pem" \ + -CA "$FILE_DIR/fleetspeak-ca-cert.pem" \ + -CAkey "$FILE_DIR/fleetspeak-ca-key.pem" + +# Replace placeholders in fleetspeak and grr-client config files. 
+TRUSTED_FLEETSPEAK_CERT=$(sed ':a;N;$!ba;s/\n/\\\\n/g' "$FILE_DIR/fleetspeak-ca-cert.pem") +FLEETSPEAK_KEY=$(sed ':a;N;$!ba;s/\n/\\\\n/g' "$FILE_DIR/fleetspeak-key.pem") +FLEETSPEAK_CERT=$(sed ':a;N;$!ba;s/\n/\\\\n/g' "$FILE_DIR/fleetspeak-cert.pem") + +sed -i 's@%FLEETSPEAK_CERT%@'"$FLEETSPEAK_CERT"'@' "$FILE_DIR/server/textservices/frontend.components.config" +sed -i 's@%FLEETSPEAK_KEY%@'"$FLEETSPEAK_KEY"'@' "$FILE_DIR/server/textservices/frontend.components.config" +sed -i 's@%TRUSTED_FLEETSPEAK_CERT%@'"$TRUSTED_FLEETSPEAK_CERT"'@' "$FILE_DIR/client/client.config" diff --git a/docker_config_files/server/grr.server.yaml b/docker_config_files/server/grr.server.yaml index a4d6c642cc..3879125e25 100644 --- a/docker_config_files/server/grr.server.yaml +++ b/docker_config_files/server/grr.server.yaml @@ -1,14 +1,13 @@ -AdminUI.csrf_secret_key: KPK,_0a_xY&DTeiaokEdsH1uXGobNIhfrr67BTSLlPPv64_UE0nyn8QsD6 - nwNZ-C87mwVLkdrc77AKdoz12hxzmYXsBTT1bC#d7 -AdminUI.url: http://admin-ui:8000 -AdminUI.bind: 0.0.0.0 -AdminUI.use_precompiled_js: true - Server.initialized: true Server.fleetspeak_enabled: true Server.fleetspeak_server: fleetspeak-admin:4444 Server.fleetspeak_message_listen_address: grr-fleetspeak-frontend:11111 +AdminUI.csrf_secret_key: random_passphrase____PLEASE_REPLACE_WHEN_RUNNING_IN_PRODUCTION____ +AdminUI.url: http://admin-ui:8000 +AdminUI.bind: 0.0.0.0 +AdminUI.use_precompiled_js: true + API.DefaultRouter: ApiCallRouterWithoutChecks Mysql.host: mysql-host @@ -23,45 +22,18 @@ Mysql.username: grru Blobstore.implementation: DbBlobStore Database.implementation: MysqlDB -# .-. -# (o.o) WARNING: Publicly stored key. For testing only. -# |=| NEVER reuse in production. -Client.executable_signing_public_key: | - -----BEGIN PUBLIC KEY----- - MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAx6YQNUwITzi7l+biDnwv - n63Rg3vbfPZexL/0O1XzQw1Z7mFp3uHtnSrkgDmqYIDXwxDXvn8Ck+k8dYt8SZCc - Jq4Jd/YkJXaUiM2E/2Y+Gv33ioVaN7QRyVBGRldK7X6a9Z8tEBE8jF3mlzlO2Z16 - ZCgMLD1I6ZJpHfQFcDGJP7idHY1TVHJ7j9YG8PObi2k9r5E9UBg6DcFD3Rqg5CP/ - OUtE56B7VW3y8q49c8pw+ZfiQaXd11xMLuMOX9Brlsp/RqFC6wvM1RJc9oR08Bq8 - je7ZmTVuwGEUR8snL2eqPqhM1UAvelbEF4IVG9E7A043Fhh7qVPxVGqKSkgfwXS0 - 0QIDAQAB - -----END PUBLIC KEY----- +Client.executable_signing_public_key: "%(/configs/public-key.pem|file)" +PrivateKeys.executable_signing_private_key: "%(/configs/private-key.pem|file)" + +# Configuration for repacking client templates: +Client.fleetspeak_enabled: true +ClientBuilder.fleetspeak_bundled: true +ClientBuilder.template_dir: /client_templates +ClientBuilder.executables_dir: /client_installers -PrivateKeys.executable_signing_private_key: | - -----BEGIN RSA PRIVATE KEY----- - MIIEpAIBAAKCAQEAx6YQNUwITzi7l+biDnwvn63Rg3vbfPZexL/0O1XzQw1Z7mFp - 3uHtnSrkgDmqYIDXwxDXvn8Ck+k8dYt8SZCcJq4Jd/YkJXaUiM2E/2Y+Gv33ioVa - N7QRyVBGRldK7X6a9Z8tEBE8jF3mlzlO2Z16ZCgMLD1I6ZJpHfQFcDGJP7idHY1T - VHJ7j9YG8PObi2k9r5E9UBg6DcFD3Rqg5CP/OUtE56B7VW3y8q49c8pw+ZfiQaXd - 11xMLuMOX9Brlsp/RqFC6wvM1RJc9oR08Bq8je7ZmTVuwGEUR8snL2eqPqhM1UAv - elbEF4IVG9E7A043Fhh7qVPxVGqKSkgfwXS00QIDAQABAoIBAQCi51KEWoTRN4aC - PMcpcJVfYnH5Kj/+5/yN596957T1elhuFRhQ3+KFgrEuG191HMxxAzY23uXYkNBf - TTBdylxPh2R8eOAnnWk3cxLZXrDAT4gDhCoIF6sHq7Obw7CEtvB0CKy5VockNZ5o - uD8pe8CZJsA//MWYqHmTEkC5ugG2dlde7FcYHsqVU7NlGHhz5UqPpzrgvdTfnWwj - GOd2zL+BuUKbs8ZIVGEDbgtr8ILNN9MMK8nDioIB29SMWP/Jfb2Z7HSRkn2HK7Jf - bkv/eTJlOJnAlB5BbDDvQ8vUPgk0j0cMjcapoyoENGmbsgSvydG2O7RyBnkeGmud - vEExNZHBAoGBAPgGmD3A07pTYGzd7RytJJZ1u+so4IlWPg2Jp9p0WmP6D6vbB2dl - 1lIdtzII5hh/wbd2FNZJ5X2iV93gQsffRBGeOJ8b5No91q/EdmCZpFGu7LJQqWVO - 
1+Nft/xW6Kkog811KwYNgQpE241ZRCGoD/KzZpOfb9n+EW+hVYbjOfiZAoGBAM4R - S56AFXKHIoZQOgX1drsWr6DKDH8Za7BNsGT1nDi1ROmNZxzx8I9avF4ZSwUMmiXR - AXMY69CjqFFwTtWhrZ8UHhl5x7zWAffQdof4jKtdCJ8G4CyYDCZ31Cbi7Gfo4tUP - FmLmN59o3l69887y1vgyFnDevSGuCzJ9hJ1LSij5AoGAGKjvMhSd+ISZrblS/erp - HFyQVo015fHBMa9iFQJEinQuYrPgRJOHf5qcwEjKN91b8VW4NKYcPyWI/vJxMVYt - emL01jz7wAct9UPfUTN1dvmhZwlGDmCMbnrx3BD4CPmSQTdJE8z76311JtSdRYtk - KolTxZGwmUf9i8/KpSKqfOECgYB8Kj23TpQdw0FRTwv3RTV6e6vtpXEsMGQMAnPU - EY5FOSxB0hscfMeniVPRG0pxy2sieDJ4aL7Go6YrFBHcdaQJI3UTgqaQqR7cdHbH - bUNNiixErj7rf95qW2+w0rEB13i+Sm4Bv5gqbGT5D1nWC8ruGDgfYIbzwUwr6ye6 - I4CW+QKBgQC9xKPizqJoi375rDeLVSc/bN3fidyj+Ti87YQa9sDSyXxSF2uk2HUF - xCjMJcqyIOhPSze9wpip6edj8p6N3pvKEMLdFrRJR9Gkv/V9+kJffJbLwyH6Ta/x - v89V954580cna0V/lZYpZM/DDdhVv3hCaGIm+uAHA1mYtxzBBTKX3Q== - -----END RSA PRIVATE KEY----- +Target:Linux: + ClientBuilder.fleetspeak_client_config: /configs/client/client.config +Target:Windows: + ClientBuilder.fleetspeak_client_config: /configs/client/client.config +Target:Darwin: + ClientBuilder.fleetspeak_client_config: /configs/client/client.config diff --git a/docker_config_files/server/repack_clients.sh b/docker_config_files/server/repack_clients.sh new file mode 100755 index 0000000000..dd2311775c --- /dev/null +++ b/docker_config_files/server/repack_clients.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# This script repacks the client using the provided configuration +# and uploads the installers to the blobstore to make them available +# via the Web UI. +# +# In this Docker Compose example the folder the installers are stored +# is mounted in the grr-client container, which will install the debian +# installer on startup. +INSTALLERS_DIR="/client_installers" + + +if [[ -z "$(ls -A ${INSTALLERS_DIR})" ]] + then + echo "** Repacking clients." + grr_config_updater repack_clients \ + --secondary_configs /configs/server/grr.server.yaml + else + echo "** Found existing client installers dir, skipping repacking." +fi \ No newline at end of file diff --git a/docker_config_files/server/textservices/frontend.components.config b/docker_config_files/server/textservices/frontend.components.config index cc8e7e64f7..a3147cf0b7 100644 --- a/docker_config_files/server/textservices/frontend.components.config +++ b/docker_config_files/server/textservices/frontend.components.config @@ -2,11 +2,8 @@ mysql_data_source_name: "fleetspeak-user:fleetspeak-password@tcp(mysql-host:3306 https_config: < listen_address: "fleetspeak-frontend:4443" - # .-. - # (o.o) WARNING: Publicly stored key. For testing only. - # |=| NEVER reuse in production. 
- certificates: "-----BEGIN CERTIFICATE-----\nMIIBzjCCAXSgAwIBAgIRAJDHUJue7M1sw5dSJH9lrz4wCgYIKoZIzj0EAwIwIzEh\nMB8GA1UEAxMYRmxlZXRzcGVhayBGbGVldHNwZWFrIENBMB4XDTI0MDExMjE0NTE1\nNFoXDTI1MDExMTE0NTE1NFowJzElMCMGA1UEAxMcRmxlZXRzcGVhayBGbGVldHNw\nZWFrIFNlcnZlcjBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABLLeYLg2srhBEAlf\nCJbddIf0ops80/sATmwCsXk3Ly/skJCmz/Dk9T05IjolyNZbbG6clYgKTXSbA+OK\nHHOXqvSjgYQwgYEwDgYDVR0PAQH/BAQDAgKEMA8GA1UdEwEB/wQFMAMBAf8wHQYD\nVR0OBBYEFFzvvyzN5JaKm8/oXmdtNioscn46MB8GA1UdIwQYMBaAFFpd5HhAtTOc\nJnjf/rFE6q7TlYKSMB4GA1UdEQQXMBWCE2ZsZWV0c3BlYWstZnJvbnRlbmQwCgYI\nKoZIzj0EAwIDSAAwRQIgc2g8s657NA/8hQqKcPlMZRiFk3BsWvc9v5ztZoOG6PAC\nIQDW/geEXuIjaHN/Z6CEG3xAnBLW2ZOBd8ml50yMmHaVuA==\n-----END CERTIFICATE-----\n" - key: "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIJzmo0Nfbncsx7ql/oQmlBUuLF6/AcLe+P/a80cq/uIpoAoGCCqGSM49\nAwEHoUQDQgAEst5guDayuEEQCV8Ilt10h/SimzzT+wBObAKxeTcvL+yQkKbP8OT1\nPTkiOiXI1ltsbpyViApNdJsD44occ5eq9A==\n-----END EC PRIVATE KEY-----\n" + certificates: "%FLEETSPEAK_CERT%" + key: "%FLEETSPEAK_KEY%" > notification_listen_address: "fleetspeak-frontend:10000" notification_public_address: "fleetspeak-frontend:10000" diff --git a/docker_config_files/testing/grr.testing.yaml b/docker_config_files/testing/grr.testing.yaml index 678fa4b2b1..86b3237640 100644 --- a/docker_config_files/testing/grr.testing.yaml +++ b/docker_config_files/testing/grr.testing.yaml @@ -9,34 +9,5 @@ Mysql.database: grr Mysql.password: grrp Mysql.username: grru -# .-. -# (o.o) WARNING: Publicly stored key. For testing only. -# |=| NEVER reuse in production. PrivateKeys.executable_signing_private_key: | - -----BEGIN RSA PRIVATE KEY----- - MIIEpAIBAAKCAQEAx6YQNUwITzi7l+biDnwvn63Rg3vbfPZexL/0O1XzQw1Z7mFp - 3uHtnSrkgDmqYIDXwxDXvn8Ck+k8dYt8SZCcJq4Jd/YkJXaUiM2E/2Y+Gv33ioVa - N7QRyVBGRldK7X6a9Z8tEBE8jF3mlzlO2Z16ZCgMLD1I6ZJpHfQFcDGJP7idHY1T - VHJ7j9YG8PObi2k9r5E9UBg6DcFD3Rqg5CP/OUtE56B7VW3y8q49c8pw+ZfiQaXd - 11xMLuMOX9Brlsp/RqFC6wvM1RJc9oR08Bq8je7ZmTVuwGEUR8snL2eqPqhM1UAv - elbEF4IVG9E7A043Fhh7qVPxVGqKSkgfwXS00QIDAQABAoIBAQCi51KEWoTRN4aC - PMcpcJVfYnH5Kj/+5/yN596957T1elhuFRhQ3+KFgrEuG191HMxxAzY23uXYkNBf - TTBdylxPh2R8eOAnnWk3cxLZXrDAT4gDhCoIF6sHq7Obw7CEtvB0CKy5VockNZ5o - uD8pe8CZJsA//MWYqHmTEkC5ugG2dlde7FcYHsqVU7NlGHhz5UqPpzrgvdTfnWwj - GOd2zL+BuUKbs8ZIVGEDbgtr8ILNN9MMK8nDioIB29SMWP/Jfb2Z7HSRkn2HK7Jf - bkv/eTJlOJnAlB5BbDDvQ8vUPgk0j0cMjcapoyoENGmbsgSvydG2O7RyBnkeGmud - vEExNZHBAoGBAPgGmD3A07pTYGzd7RytJJZ1u+so4IlWPg2Jp9p0WmP6D6vbB2dl - 1lIdtzII5hh/wbd2FNZJ5X2iV93gQsffRBGeOJ8b5No91q/EdmCZpFGu7LJQqWVO - 1+Nft/xW6Kkog811KwYNgQpE241ZRCGoD/KzZpOfb9n+EW+hVYbjOfiZAoGBAM4R - S56AFXKHIoZQOgX1drsWr6DKDH8Za7BNsGT1nDi1ROmNZxzx8I9avF4ZSwUMmiXR - AXMY69CjqFFwTtWhrZ8UHhl5x7zWAffQdof4jKtdCJ8G4CyYDCZ31Cbi7Gfo4tUP - FmLmN59o3l69887y1vgyFnDevSGuCzJ9hJ1LSij5AoGAGKjvMhSd+ISZrblS/erp - HFyQVo015fHBMa9iFQJEinQuYrPgRJOHf5qcwEjKN91b8VW4NKYcPyWI/vJxMVYt - emL01jz7wAct9UPfUTN1dvmhZwlGDmCMbnrx3BD4CPmSQTdJE8z76311JtSdRYtk - KolTxZGwmUf9i8/KpSKqfOECgYB8Kj23TpQdw0FRTwv3RTV6e6vtpXEsMGQMAnPU - EY5FOSxB0hscfMeniVPRG0pxy2sieDJ4aL7Go6YrFBHcdaQJI3UTgqaQqR7cdHbH - bUNNiixErj7rf95qW2+w0rEB13i+Sm4Bv5gqbGT5D1nWC8ruGDgfYIbzwUwr6ye6 - I4CW+QKBgQC9xKPizqJoi375rDeLVSc/bN3fidyj+Ti87YQa9sDSyXxSF2uk2HUF - xCjMJcqyIOhPSze9wpip6edj8p6N3pvKEMLdFrRJR9Gkv/V9+kJffJbLwyH6Ta/x - v89V954580cna0V/lZYpZM/DDdhVv3hCaGIm+uAHA1mYtxzBBTKX3Q== - -----END RSA PRIVATE KEY----- + %(/configs/private-key.pem|file) diff --git a/grr/client/grr_response_client/actions.py b/grr/client/grr_response_client/actions.py index 669fe5e377..c81bcb4a09 100644 --- 
a/grr/client/grr_response_client/actions.py +++ b/grr/client/grr_response_client/actions.py @@ -8,7 +8,6 @@ from typing import NamedTuple from absl import flags - import psutil from grr_response_client.unprivileged import communication @@ -56,10 +55,20 @@ def cpu_used(self) -> _CpuUsed: end = self.proc.cpu_times() unprivileged_cpu_end = communication.TotalServerCpuTime() unprivileged_sys_end = communication.TotalServerSysTime() - return _CpuUsed((end.user - self.cpu_start.user + unprivileged_cpu_end - - self.unprivileged_cpu_start), - (end.system - self.cpu_start.system + unprivileged_sys_end - - self.unprivileged_sys_start)) + return _CpuUsed( + ( + end.user + - self.cpu_start.user + + unprivileged_cpu_end + - self.unprivileged_cpu_start + ), + ( + end.system + - self.cpu_start.system + + unprivileged_sys_end + - self.unprivileged_sys_start + ), + ) @property def total_cpu_used(self) -> float: @@ -84,6 +93,7 @@ class ActionPlugin(object): EnumerateInterfaces) as linux actions must accept and return the same rdfvalue types as their linux counterparts. """ + # The rdfvalue used to encode this message. in_rdfvalue = None @@ -112,10 +122,12 @@ def __init__(self, grr_worker=None): self.response_id = INITIAL_RESPONSE_ID self.cpu_used = None self.status = rdf_flows.GrrStatus( - status=rdf_flows.GrrStatus.ReturnedStatus.OK) + status=rdf_flows.GrrStatus.ReturnedStatus.OK + ) self._last_gc_run = rdfvalue.RDFDatetime.Now() self._gc_frequency = rdfvalue.Duration.From( - config.CONFIG["Client.gc_frequency"], rdfvalue.SECONDS) + config.CONFIG["Client.gc_frequency"], rdfvalue.SECONDS + ) self.cpu_times = _CpuTimes() self.cpu_limit = rdf_flows.GrrMessage().cpu_limit self.start_time = None @@ -142,22 +154,26 @@ def Execute(self, message): try: if self.message.args_rdf_name: if not self.in_rdfvalue: - raise RuntimeError("Did not expect arguments, got %s." % - self.message.args_rdf_name) + raise RuntimeError( + "Did not expect arguments, got %s." % self.message.args_rdf_name + ) if self.in_rdfvalue.__name__ != self.message.args_rdf_name: raise RuntimeError( - "Unexpected arg type %s != %s." % - (self.message.args_rdf_name, self.in_rdfvalue.__name__)) + "Unexpected arg type %s != %s." + % (self.message.args_rdf_name, self.in_rdfvalue.__name__) + ) args = self.message.payload # Only allow authenticated messages in the client if self._authentication_required and ( - self.message.auth_state != - rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED): - raise RuntimeError("Message for %s was not Authenticated." % - self.message.name) + self.message.auth_state + != rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED + ): + raise RuntimeError( + "Message for %s was not Authenticated." 
% self.message.name + ) self.cpu_times = _CpuTimes() self.cpu_limit = self.message.cpu_limit @@ -178,31 +194,46 @@ def Execute(self, message): except NetworkBytesExceededError as e: self.grr_worker.SendClientAlert("Network limit exceeded.") - self.SetStatus(rdf_flows.GrrStatus.ReturnedStatus.NETWORK_LIMIT_EXCEEDED, - "%r: %s" % (e, e), traceback.format_exc()) + self.SetStatus( + rdf_flows.GrrStatus.ReturnedStatus.NETWORK_LIMIT_EXCEEDED, + "%r: %s" % (e, e), + traceback.format_exc(), + ) except RuntimeExceededError as e: self.grr_worker.SendClientAlert("Runtime limit exceeded.") - self.SetStatus(rdf_flows.GrrStatus.ReturnedStatus.RUNTIME_LIMIT_EXCEEDED, - "%r: %s" % (e, e), traceback.format_exc()) + self.SetStatus( + rdf_flows.GrrStatus.ReturnedStatus.RUNTIME_LIMIT_EXCEEDED, + "%r: %s" % (e, e), + traceback.format_exc(), + ) except CPUExceededError as e: self.grr_worker.SendClientAlert("Cpu limit exceeded.") - self.SetStatus(rdf_flows.GrrStatus.ReturnedStatus.CPU_LIMIT_EXCEEDED, - "%r: %s" % (e, e), traceback.format_exc()) + self.SetStatus( + rdf_flows.GrrStatus.ReturnedStatus.CPU_LIMIT_EXCEEDED, + "%r: %s" % (e, e), + traceback.format_exc(), + ) # We want to report back all errors and map Python exceptions to # Grr Errors. except Exception as e: # pylint: disable=broad-except - self.SetStatus(rdf_flows.GrrStatus.ReturnedStatus.GENERIC_ERROR, - "%r: %s" % (e, e), traceback.format_exc()) + self.SetStatus( + rdf_flows.GrrStatus.ReturnedStatus.GENERIC_ERROR, + "%r: %s" % (e, e), + traceback.format_exc(), + ) if flags.FLAGS.pdb_post_mortem: pdb.post_mortem() if self.status.status != rdf_flows.GrrStatus.ReturnedStatus.OK: - logging.info("Job Error (%s): %s", self.__class__.__name__, - self.status.error_message) + logging.info( + "Job Error (%s): %s", + self.__class__.__name__, + self.status.error_message, + ) if self.status.backtrace: logging.debug(self.status.backtrace) @@ -240,8 +271,9 @@ def Run(self, unused_args): Raises: KeyError: if not implemented. """ - raise KeyError("Action %s not available on this platform." % - self.message.name) + raise KeyError( + "Action %s not available on this platform." % self.message.name + ) def SetStatus(self, status, message="", backtrace=None): """Set a status to report back to the server.""" @@ -255,10 +287,12 @@ def SetStatus(self, status, message="", backtrace=None): # some other method to communicate with well-known flows. The naming is also # confusing since sending messages to well-knows flows is not really replying # to anything. - def SendReply(self, - rdf_value=None, - session_id=None, - message_type=rdf_flows.GrrMessage.Type.MESSAGE): + def SendReply( + self, + rdf_value=None, + session_id=None, + message_type=rdf_flows.GrrMessage.Type.MESSAGE, + ): """Send response back to the server.""" # TODO(hanuszczak): This is pretty bad. Here we assume that if the session # id is not none we are "replying" to a well-known flow. 
If we are replying @@ -323,8 +357,11 @@ def Progress(self): return if self.runtime_limit and now - self.start_time > self.runtime_limit: - raise RuntimeExceededError("{} exceeded runtime limit of {}.".format( - type(self).__name__, self.runtime_limit)) + raise RuntimeExceededError( + "{} exceeded runtime limit of {}.".format( + type(self).__name__, self.runtime_limit + ) + ) ActionPlugin.last_progress_time = now @@ -346,7 +383,8 @@ def SyncTransactionLog(self): def ChargeBytesToSession(self, length): self.grr_worker.ChargeBytesToSession( - self.message.session_id, length, limit=self.network_bytes_limit) + self.message.session_id, length, limit=self.network_bytes_limit + ) @property def session_id(self): diff --git a/grr/client/grr_response_client/client_actions/action_test.py b/grr/client/grr_response_client/client_actions/action_test.py index 7b5c3ce8ba..ce742d1595 100644 --- a/grr/client/grr_response_client/client_actions/action_test.py +++ b/grr/client/grr_response_client/client_actions/action_test.py @@ -28,6 +28,7 @@ class ProgressAction(actions.ActionPlugin): """A mock action which just calls Progress.""" + in_rdfvalue = rdf_client.LogMessage out_rdfvalues = [rdf_client.LogMessage] @@ -57,7 +58,8 @@ def testReadBuffer(self): p = rdf_paths.PathSpec(path=path, pathtype=rdf_paths.PathSpec.PathType.OS) result = self.RunAction( standard.ReadBuffer, - rdf_client.BufferReference(pathspec=p, offset=100, length=10))[0] + rdf_client.BufferReference(pathspec=p, offset=100, length=10), + )[0] self.assertEqual(result.offset, 100) self.assertEqual(result.length, 10) @@ -66,8 +68,9 @@ def testReadBuffer(self): def testListDirectory(self): """Tests listing directories.""" p = rdf_paths.PathSpec(path=self.base_path, pathtype=0) - results = self.RunAction(standard.ListDirectory, - rdf_client_action.ListDirRequest(pathspec=p)) + results = self.RunAction( + standard.ListDirectory, rdf_client_action.ListDirRequest(pathspec=p) + ) # Find the number.txt file result = None for result in results: @@ -114,7 +117,8 @@ def ProcessIter(): def testRaisesWhenRuntimeLimitIsExceeded(self): message = rdf_flows.GrrMessage( name="ProgressAction", - runtime_limit_us=rdfvalue.Duration.From(9, rdfvalue.SECONDS)) + runtime_limit_us=rdfvalue.Duration.From(9, rdfvalue.SECONDS), + ) worker = mock.MagicMock() with test_lib.FakeTime(100): action = ProgressAction(worker) @@ -122,19 +126,22 @@ def testRaisesWhenRuntimeLimitIsExceeded(self): action.Execute(message) self.assertEqual(action.SendReply.call_count, 1) - self.assertEqual(action.SendReply.call_args[0][0].status, - "RUNTIME_LIMIT_EXCEEDED") + self.assertEqual( + action.SendReply.call_args[0][0].status, "RUNTIME_LIMIT_EXCEEDED" + ) self.assertEqual(worker.Heartbeat.call_count, 1) self.assertEqual(worker.SendClientAlert.call_count, 1) - self.assertEqual(worker.SendClientAlert.call_args[0][0], - "Runtime limit exceeded.") + self.assertEqual( + worker.SendClientAlert.call_args[0][0], "Runtime limit exceeded." 
+ ) def testDoesNotRaiseWhenFasterThanRuntimeLimit(self): message = rdf_flows.GrrMessage( name="ProgressAction", - runtime_limit_us=rdfvalue.Duration.From(16, rdfvalue.SECONDS)) + runtime_limit_us=rdfvalue.Duration.From(16, rdfvalue.SECONDS), + ) worker = mock.MagicMock() with test_lib.FakeTime(100): action = ProgressAction(worker) @@ -162,11 +169,15 @@ def testCPUAccounting(self): server_cpu_time = 1.0 server_sys_time = 1.1 stack.enter_context( - mock.patch.object(communication, "TotalServerCpuTime", - lambda: server_cpu_time)) + mock.patch.object( + communication, "TotalServerCpuTime", lambda: server_cpu_time + ) + ) stack.enter_context( - mock.patch.object(communication, "TotalServerSysTime", - lambda: server_sys_time)) + mock.patch.object( + communication, "TotalServerSysTime", lambda: server_sys_time + ) + ) process_cpu_time = 1.2 process_sys_time = 1.3 @@ -177,9 +188,9 @@ def __init__(self, pid=None): pass def cpu_times(self): # pylint: disable=invalid-name - return collections.namedtuple("pcputimes", - ["user", "system"])(process_cpu_time, - process_sys_time) + return collections.namedtuple("pcputimes", ["user", "system"])( + process_cpu_time, process_sys_time + ) stack.enter_context(mock.patch.object(psutil, "Process", FakeProcess)) @@ -202,10 +213,12 @@ def Run(self, *args): self.assertEqual(action.SendReply.call_count, 1) self.assertAlmostEqual( action.SendReply.call_args[0][0].cpu_time_used.user_cpu_time, - 42.0 - 1.0 + 10.0 - 1.2) + 42.0 - 1.0 + 10.0 - 1.2, + ) self.assertAlmostEqual( action.SendReply.call_args[0][0].cpu_time_used.system_cpu_time, - 43.0 - 1.1 + 11.0 - 1.3) + 43.0 - 1.1 + 11.0 - 1.3, + ) def testCPULimit(self): @@ -240,11 +253,13 @@ def cpu_times(self): # pylint: disable=g-bad-name self.assertEqual("CPU_LIMIT_EXCEEDED", reply.status) self.assertEqual(worker.SendClientAlert.call_count, 1) - self.assertEqual(worker.SendClientAlert.call_args[0][0], - "Cpu limit exceeded.") + self.assertEqual( + worker.SendClientAlert.call_args[0][0], "Cpu limit exceeded." + ) - @unittest.skipIf(platform.system() == "Windows", - "os.statvfs is not available on Windows") + @unittest.skipIf( + platform.system() == "Windows", "os.statvfs is not available on Windows" + ) def testStatFS(self): import posix # pylint: disable=g-import-not-at-top @@ -261,9 +276,18 @@ def testStatFS(self): f_namemax = 255 def MockStatFS(unused_path): - return posix.statvfs_result( - (f_bsize, f_frsize, f_blocks, f_bfree, f_bavail, f_files, f_ffree, - f_favail, f_flag, f_namemax)) + return posix.statvfs_result(( + f_bsize, + f_frsize, + f_blocks, + f_bfree, + f_bavail, + f_files, + f_ffree, + f_favail, + f_flag, + f_namemax, + )) def MockIsMount(path): """Only return True for the root path.""" @@ -274,18 +298,21 @@ def MockIsMount(path): # well). 
return path == "/" or path == b"/" - with utils.MultiStubber((os, "statvfs", MockStatFS), - (os.path, "ismount", MockIsMount)): + with utils.MultiStubber( + (os, "statvfs", MockStatFS), (os.path, "ismount", MockIsMount) + ): # This test assumes "/" is the mount point for /usr/bin results = self.RunAction( standard.StatFS, - rdf_client_action.StatFSRequest(path_list=["/usr/bin", "/"])) + rdf_client_action.StatFSRequest(path_list=["/usr/bin", "/"]), + ) self.assertLen(results, 2) # Both results should have mount_point as "/" - self.assertEqual(results[0].unixvolume.mount_point, - results[1].unixvolume.mount_point) + self.assertEqual( + results[0].unixvolume.mount_point, results[1].unixvolume.mount_point + ) result = results[0] self.assertEqual(result.bytes_per_sector, f_bsize) self.assertEqual(result.sectors_per_allocation_unit, 1) @@ -298,7 +325,8 @@ def MockIsMount(path): # Test we get a result even if one path is bad results = self.RunAction( standard.StatFS, - rdf_client_action.StatFSRequest(path_list=["/does/not/exist", "/"])) + rdf_client_action.StatFSRequest(path_list=["/does/not/exist", "/"]), + ) self.assertLen(results, 1) self.assertEqual(result.Name(), "/") diff --git a/grr/client/grr_response_client/client_actions/admin.py b/grr/client/grr_response_client/client_actions/admin.py index fa2f038e4b..9f60bdb721 100644 --- a/grr/client/grr_response_client/client_actions/admin.py +++ b/grr/client/grr_response_client/client_actions/admin.py @@ -30,6 +30,7 @@ class Echo(actions.ActionPlugin): """Returns a message to the server.""" + in_rdfvalue = rdf_client_action.EchoRequest out_rdfvalues = [rdf_client_action.EchoRequest] @@ -54,6 +55,7 @@ def Run(self, args): class GetPlatformInfo(actions.ActionPlugin): """Retrieves platform information.""" + out_rdfvalues = [rdf_client.Uname] def Run(self, unused_args): @@ -66,6 +68,7 @@ class Kill(actions.ActionPlugin): Used for testing process respawn. 
""" + out_rdfvalues = [rdf_flows.GrrMessage] def Run(self, unused_arg): @@ -85,24 +88,20 @@ def Run(self, unused_arg): class GetConfiguration(actions.ActionPlugin): """Retrieves the running configuration parameters.""" + in_rdfvalue = None out_rdfvalues = [rdf_protodict.Dict] - BLOCKED_PARAMETERS = ["Client.private_key"] - def Run(self, unused_arg): """Retrieve the configuration except for the blocked parameters.""" out = self.out_rdfvalues[0]() for descriptor in config.CONFIG.type_infos: - if descriptor.name in self.BLOCKED_PARAMETERS: - value = "[Redacted]" - else: - try: - value = config.CONFIG.Get(descriptor.name, default=None) - except (config_lib.Error, KeyError, AttributeError, ValueError) as e: - logging.info("Config reading error: %s", e) - continue + try: + value = config.CONFIG.Get(descriptor.name, default=None) + except (config_lib.Error, KeyError, AttributeError, ValueError) as e: + logging.info("Config reading error: %s", e) + continue if value is not None: out[descriptor.name] = value @@ -112,6 +111,7 @@ def Run(self, unused_arg): class GetLibraryVersions(actions.ActionPlugin): """Retrieves version information for installed libraries.""" + in_rdfvalue = None out_rdfvalues = [rdf_protodict.Dict] @@ -160,6 +160,7 @@ def Run(self, unused_arg): class UpdateConfiguration(actions.ActionPlugin): """Updates configuration parameters on the client.""" + in_rdfvalue = rdf_protodict.Dict UPDATABLE_FIELDS = {"Client.foreman_check_frequency", @@ -190,13 +191,16 @@ def Run(self, arg): smart_arg = {str(field): value for field, value in arg.items()} disallowed_fields = [ - field for field in smart_arg + field + for field in smart_arg if field not in UpdateConfiguration.UPDATABLE_FIELDS ] if disallowed_fields: - raise ValueError("Received an update request for restricted field(s) %s." - % ",".join(disallowed_fields)) + raise ValueError( + "Received an update request for restricted field(s) %s." + % ",".join(disallowed_fields) + ) if platform.system() != "Windows": # Check config validity before really applying the changes. This isn't @@ -237,11 +241,13 @@ def GetClientInformation() -> rdf_client.ClientInformation: build_time=config.CONFIG["Client.build_time"], labels=config.CONFIG.Get("Client.labels", default=None), timeline_btime_support=timeline.BTIME_SUPPORT, - sandbox_support=sandbox.IsSandboxInitialized()) + sandbox_support=sandbox.IsSandboxInitialized(), + ) class GetClientInfo(actions.ActionPlugin): """Obtains information about the GRR client installed.""" + out_rdfvalues = [rdf_client.ClientInformation] def Run(self, unused_args): @@ -259,16 +265,20 @@ def _CheckInterrogateTrigger(self) -> bool: interrogate_trigger_path = config.CONFIG["Client.interrogate_trigger_path"] if not interrogate_trigger_path: logging.info( - "Client.interrogate_trigger_path not set, skipping the check.") + "Client.interrogate_trigger_path not set, skipping the check." + ) return False if not os.path.exists(interrogate_trigger_path): - logging.info("Interrogate trigger file (%s) does not exist.", - interrogate_trigger_path) + logging.info( + "Interrogate trigger file (%s) does not exist.", + interrogate_trigger_path, + ) return False - logging.info("Interrogate trigger file exists: %s", - interrogate_trigger_path) + logging.info( + "Interrogate trigger file exists: %s", interrogate_trigger_path + ) # First try to remove the file and return True only if the removal # is successful. 
This is to prevent a permission error + a crash loop from @@ -278,7 +288,10 @@ def _CheckInterrogateTrigger(self) -> bool: except (OSError, IOError) as e: logging.exception( "Not triggering interrogate - failed to remove the " - "interrogate trigger file (%s): %s", interrogate_trigger_path, e) + "interrogate trigger file (%s): %s", + interrogate_trigger_path, + e, + ) return False return True @@ -300,4 +313,5 @@ def Run(self, unused_arg, ttl=None): response_id=0, request_id=0, message_type=rdf_flows.GrrMessage.Type.MESSAGE, - ttl=ttl) + ttl=ttl, + ) diff --git a/grr/client/grr_response_client/client_actions/admin_test.py b/grr/client/grr_response_client/client_actions/admin_test.py index 173f211f1a..db3c5583de 100644 --- a/grr/client/grr_response_client/client_actions/admin_test.py +++ b/grr/client/grr_response_client/client_actions/admin_test.py @@ -40,7 +40,7 @@ def testUpdateConfiguration(self): # Make sure the file is gone self.assertRaises(IOError, open, self.config_file) - location = [u"http://www.example1.com/", u"http://www.example2.com/"] + location = ["http://www.example1.com/", "http://www.example2.com/"] request = rdf_protodict.Dict() request["Client.server_urls"] = location request["Client.foreman_check_frequency"] = 3600 @@ -63,11 +63,11 @@ def testUpdateConfiguration(self): def testOnlyUpdatableFieldsAreUpdated(self): with test_lib.ConfigOverrider({ - "Client.server_urls": [u"http://something.com/"], - "Client.server_serial_number": 1 + "Client.server_urls": ["http://something.com/"], + "Client.server_serial_number": 1, }): - location = [u"http://www.example.com"] + location = ["http://www.example.com"] request = rdf_protodict.Dict() request["Client.server_urls"] = location request["Client.server_serial_number"] = 10 @@ -76,14 +76,15 @@ def testOnlyUpdatableFieldsAreUpdated(self): self.RunAction(admin.UpdateConfiguration, request) # Nothing was updated. - self.assertEqual(config.CONFIG["Client.server_urls"], - [u"http://something.com/"]) + self.assertEqual( + config.CONFIG["Client.server_urls"], ["http://something.com/"] + ) self.assertEqual(config.CONFIG["Client.server_serial_number"], 1) def testGetConfig(self): """Check GetConfig client action works.""" # Use UpdateConfig to generate a config. - location = [u"http://example.com/"] + location = ["http://example.com/"] request = rdf_protodict.Dict() request["Client.server_urls"] = location request["Client.foreman_check_frequency"] = 3600 @@ -130,7 +131,8 @@ def testDoesNotSendInterrogateRequestWhenConfigOptionNotSet(self): def testDoesNotSendInterrogateRequestWhenTriggerFileIsMissing(self): with test_lib.ConfigOverrider( - {"Client.interrogate_trigger_path": "/none/existingpath"}): + {"Client.interrogate_trigger_path": "/none/existingpath"} + ): results = self._RunAction() self.assertLen(results, 1) @@ -141,7 +143,8 @@ def testSendsInterrogateRequestWhenTriggerFileIsPresent(self): trigger_path = fd.name with test_lib.ConfigOverrider( - {"Client.interrogate_trigger_path": trigger_path}): + {"Client.interrogate_trigger_path": trigger_path} + ): results = self._RunAction() # Check that the trigger file got removed. 
@@ -155,7 +158,8 @@ def testInterrogateIsTriggeredOnlyOnceForOneTriggerFile(self): trigger_path = fd.name with test_lib.ConfigOverrider( - {"Client.interrogate_trigger_path": trigger_path}): + {"Client.interrogate_trigger_path": trigger_path} + ): results = self._RunAction() self.assertLen(results, 1) @@ -172,7 +176,8 @@ def testInterrogateNotRequestedIfTriggerFileCanNotBeRemoved(self, _): trigger_path = fd.name with test_lib.ConfigOverrider( - {"Client.interrogate_trigger_path": trigger_path}): + {"Client.interrogate_trigger_path": trigger_path} + ): results = self._RunAction() self.assertLen(results, 1) diff --git a/grr/client/grr_response_client/client_actions/cloud.py b/grr/client/grr_response_client/client_actions/cloud.py index 96044b0003..d2cfe51233 100644 --- a/grr/client/grr_response_client/client_actions/cloud.py +++ b/grr/client/grr_response_client/client_actions/cloud.py @@ -23,6 +23,7 @@ class GetCloudVMMetadata(actions.ActionPlugin): We make the regexes used to check that data customizable from the server side so we can adapt to minor changes without updating the client. """ + in_rdfvalue = rdf_cloud.CloudMetadataRequests out_rdfvalues = [rdf_cloud.CloudMetadataResponses] @@ -62,6 +63,7 @@ def GetMetaData(self, request): Args: request: CloudMetadataRequest object + Returns: rdf_cloud.CloudMetadataResponse object Raises: @@ -72,7 +74,8 @@ def GetMetaData(self, request): if request.timeout == 0: raise ValueError("Requests library can't handle timeout of 0") result = requests.request( - "GET", request.url, headers=request.headers, timeout=request.timeout) + "GET", request.url, headers=request.headers, timeout=request.timeout + ) # By default requests doesn't raise on HTTP error codes. result.raise_for_status() @@ -82,14 +85,16 @@ def GetMetaData(self, request): raise requests.RequestException(response=result) return rdf_cloud.CloudMetadataResponse( - label=request.label or request.url, text=result.text) + label=request.label or request.url, text=result.text + ) def GetAWSMetadataToken(self) -> str: """Get the session token for IMDSv2.""" result = requests.put( self.AMAZON_TOKEN_URL, headers=self.AMAZON_TOKEN_REQUEST_HEADERS, - timeout=1.0) + timeout=1.0, + ) result.raise_for_status() # Requests does not always raise an exception when an incorrect response @@ -128,8 +133,18 @@ def Run(self, args: rdf_cloud.CloudMetadataRequests): if not aws_metadata_token: aws_metadata_token = self.GetAWSMetadataToken() request.headers[self.AMAZON_TOKEN_HEADER] = aws_metadata_token - result_list.append(self.GetMetaData(request)) + + try: + result_list.append(self.GetMetaData(request)) + except requests.RequestException: + if request.ignore_http_errors: + continue + else: + raise + if result_list: self.SendReply( rdf_cloud.CloudMetadataResponses( - responses=result_list, instance_type=instance_type)) + responses=result_list, instance_type=instance_type + ) + ) diff --git a/grr/client/grr_response_client/client_actions/cloud_test.py b/grr/client/grr_response_client/client_actions/cloud_test.py index 91c679708f..5a8b7d1fa2 100644 --- a/grr/client/grr_response_client/client_actions/cloud_test.py +++ b/grr/client/grr_response_client/client_actions/cloud_test.py @@ -19,36 +19,44 @@ from grr.test_lib import test_lib -@unittest.skipIf(platform.system() == "Darwin", - "OS X cloud machines unsupported.") +@unittest.skipIf( + platform.system() == "Darwin", "OS X cloud machines unsupported." 
+) class GetCloudVMMetadataTest(client_test_lib.EmptyActionTest): ZONE_URL = "http://metadata.google.internal/computeMetadata/v1/instance/zone" - PROJ_URL = ("http://metadata.google.internal/computeMetadata/" - "v1/project/project-id") + PROJ_URL = ( + "http://metadata.google.internal/computeMetadata/v1/project/project-id" + ) def testBIOSCommandRaises(self): with mock.patch.multiple( cloud.GetCloudVMMetadata, LINUX_BIOS_VERSION_COMMAND="/bin/false", - WINDOWS_SERVICES_COMMAND=["cmd", "/C", "exit 1"]): + WINDOWS_SERVICES_COMMAND=["cmd", "/C", "exit 1"], + ): with self.assertRaises(subprocess.CalledProcessError): self.RunAction( - cloud.GetCloudVMMetadata, arg=rdf_cloud.CloudMetadataRequests()) + cloud.GetCloudVMMetadata, arg=rdf_cloud.CloudMetadataRequests() + ) def testNonMatchingBIOS(self): zone = mock.Mock(text="projects/123456789733/zones/us-central1-a") - arg = rdf_cloud.CloudMetadataRequests(requests=[ - rdf_cloud.CloudMetadataRequest( - bios_version_regex="Google", - instance_type="GOOGLE", - url=self.ZONE_URL, - label="zone", - headers={"Metadata-Flavor": "Google"}) - ]) + arg = rdf_cloud.CloudMetadataRequests( + requests=[ + rdf_cloud.CloudMetadataRequest( + bios_version_regex="Google", + instance_type="GOOGLE", + url=self.ZONE_URL, + label="zone", + headers={"Metadata-Flavor": "Google"}, + ) + ] + ) with mock.patch.object( cloud.GetCloudVMMetadata, "LINUX_BIOS_VERSION_COMMAND", - new=["/bin/echo", "Gaagle"]): + new=["/bin/echo", "Gaagle"], + ): with mock.patch.object(requests, "request") as mock_requests: mock_requests.side_effect = [zone] results = self.RunAction(cloud.GetCloudVMMetadata, arg=arg) @@ -58,30 +66,36 @@ def testNonMatchingBIOS(self): def testWindowsServiceQuery(self): project = mock.Mock(text="myproject") - scquery_output_path = os.path.join(config.CONFIG["Test.data_dir"], - "scquery_output.txt") + scquery_output_path = os.path.join( + config.CONFIG["Test.data_dir"], "scquery_output.txt" + ) with io.open(scquery_output_path, "rb") as filedesc: sc_query_output = filedesc.read() - arg = rdf_cloud.CloudMetadataRequests(requests=[ - rdf_cloud.CloudMetadataRequest( - bios_version_regex=".*amazon", - instance_type="AMAZON", - service_name_regex="SERVICE_NAME: AWSLiteAgent", - url="http://169.254.169.254/latest/meta-data/ami-id", - label="amazon-ami"), - rdf_cloud.CloudMetadataRequest( - bios_version_regex="Google", - instance_type="GOOGLE", - service_name_regex="SERVICE_NAME: GCEAgent", - url=self.PROJ_URL, - label="Google-project-id", - headers={"Metadata-Flavor": "Google"}) - ]) + arg = rdf_cloud.CloudMetadataRequests( + requests=[ + rdf_cloud.CloudMetadataRequest( + bios_version_regex=".*amazon", + instance_type="AMAZON", + service_name_regex="SERVICE_NAME: AWSLiteAgent", + url="http://169.254.169.254/latest/meta-data/ami-id", + label="amazon-ami", + ), + rdf_cloud.CloudMetadataRequest( + bios_version_regex="Google", + instance_type="GOOGLE", + service_name_regex="SERVICE_NAME: GCEAgent", + url=self.PROJ_URL, + label="Google-project-id", + headers={"Metadata-Flavor": "Google"}, + ), + ] + ) with mock.patch.object(platform, "system", return_value="Windows"): with mock.patch.object( - subprocess, "check_output", return_value=sc_query_output): + subprocess, "check_output", return_value=sc_query_output + ): with mock.patch.object(requests, "request") as mock_requests: mock_requests.side_effect = [project] results = self.RunAction(cloud.GetCloudVMMetadata, arg=arg) @@ -97,31 +111,39 @@ def testWindowsServiceQuery(self): def testMultipleBIOSMultipleURLs(self): ami = 
mock.Mock(text="ami-12345678") - arg = rdf_cloud.CloudMetadataRequests(requests=[ - rdf_cloud.CloudMetadataRequest( - bios_version_regex=".*amazon", - service_name_regex="SERVICE_NAME: AWSLiteAgent", - instance_type="AMAZON", - url="http://169.254.169.254/latest/meta-data/ami-id", - label="amazon-ami"), - rdf_cloud.CloudMetadataRequest( - bios_version_regex="Google", - service_name_regex="SERVICE_NAME: GCEAgent", - instance_type="GOOGLE", - url=self.PROJ_URL, - label="Google-project-id", - headers={"Metadata-Flavor": "Google"}) - ]) + arg = rdf_cloud.CloudMetadataRequests( + requests=[ + rdf_cloud.CloudMetadataRequest( + bios_version_regex=".*amazon", + service_name_regex="SERVICE_NAME: AWSLiteAgent", + instance_type="AMAZON", + url="http://169.254.169.254/latest/meta-data/ami-id", + label="amazon-ami", + ), + rdf_cloud.CloudMetadataRequest( + bios_version_regex="Google", + service_name_regex="SERVICE_NAME: GCEAgent", + instance_type="GOOGLE", + url=self.PROJ_URL, + label="Google-project-id", + headers={"Metadata-Flavor": "Google"}, + ), + ] + ) with mock.patch.multiple( cloud.GetCloudVMMetadata, LINUX_BIOS_VERSION_COMMAND=["/bin/echo", "4.2.amazon"], WINDOWS_SERVICES_COMMAND=[ - "cmd.exe", "/C", "echo SERVICE_NAME: AWSLiteAgent" - ]): + "cmd.exe", + "/C", + "echo SERVICE_NAME: AWSLiteAgent", + ], + ): with mock.patch.object( cloud.GetCloudVMMetadata, "GetAWSMetadataToken", - return_value="testtoken"): + return_value="testtoken", + ): with mock.patch.object(requests, "request") as mock_requests: mock_requests.side_effect = [ami] results = self.RunAction(cloud.GetCloudVMMetadata, arg=arg) @@ -138,29 +160,36 @@ def testMultipleBIOSMultipleURLs(self): def testMatchingBIOSMultipleURLs(self): zone = mock.Mock(text="projects/123456789733/zones/us-central1-a") project = mock.Mock(text="myproject") - arg = rdf_cloud.CloudMetadataRequests(requests=[ - rdf_cloud.CloudMetadataRequest( - bios_version_regex="Google", - service_name_regex="SERVICE_NAME: GCEAgent", - instance_type="GOOGLE", - url=self.ZONE_URL, - label="zone", - headers={"Metadata-Flavor": "Google"}), - rdf_cloud.CloudMetadataRequest( - bios_version_regex="Google", - service_name_regex="SERVICE_NAME: GCEAgent", - instance_type="GOOGLE", - url=self.PROJ_URL, - label="project-id", - headers={"Metadata-Flavor": "Google"}) - ]) + arg = rdf_cloud.CloudMetadataRequests( + requests=[ + rdf_cloud.CloudMetadataRequest( + bios_version_regex="Google", + service_name_regex="SERVICE_NAME: GCEAgent", + instance_type="GOOGLE", + url=self.ZONE_URL, + label="zone", + headers={"Metadata-Flavor": "Google"}, + ), + rdf_cloud.CloudMetadataRequest( + bios_version_regex="Google", + service_name_regex="SERVICE_NAME: GCEAgent", + instance_type="GOOGLE", + url=self.PROJ_URL, + label="project-id", + headers={"Metadata-Flavor": "Google"}, + ), + ] + ) with mock.patch.multiple( cloud.GetCloudVMMetadata, LINUX_BIOS_VERSION_COMMAND=["/bin/echo", "Google"], WINDOWS_SERVICES_COMMAND=[ - "cmd.exe", "/C", "echo SERVICE_NAME: GCEAgent" - ]): + "cmd.exe", + "/C", + "echo SERVICE_NAME: GCEAgent", + ], + ): with mock.patch.object(requests, "request") as mock_requests: mock_requests.side_effect = [zone, project] results = self.RunAction(cloud.GetCloudVMMetadata, arg=arg) @@ -178,20 +207,25 @@ def testMatchingBIOSMultipleURLs(self): def testAWSRequestWithMetadataResponse(self): instanceid = mock.Mock(text="i-001d78bb6472d9d3b") - arg = rdf_cloud.CloudMetadataRequests(requests=[ - rdf_cloud.CloudMetadataRequest( - bios_version_regex=".*amazon", - 
service_name_regex="SERVICE_NAME: AWSLiteAgent", - instance_type="AMAZON", - url="http://169.254.169.254/latest/meta-data/instance-id", - label="instance-id") - ]) + arg = rdf_cloud.CloudMetadataRequests( + requests=[ + rdf_cloud.CloudMetadataRequest( + bios_version_regex=".*amazon", + service_name_regex="SERVICE_NAME: AWSLiteAgent", + instance_type="AMAZON", + url="http://169.254.169.254/latest/meta-data/instance-id", + label="instance-id", + ) + ] + ) with mock.patch.multiple( cloud.GetCloudVMMetadata, LINUX_BIOS_VERSION_COMMAND=["/bin/echo", "4.2.amazon"], WINDOWS_SERVICES_COMMAND=[ - "cmd.exe", "/C", "echo SERVICE_NAME: AWSLiteAgent" + "cmd.exe", + "/C", + "echo SERVICE_NAME: AWSLiteAgent", ], GetAWSMetadataToken=mock.Mock(return_value="metadataaccesstoken"), ): @@ -223,6 +257,27 @@ def testNonUnicodeOutput(self): self.assertLen(results[0].responses, 1) self.assertEqual(results[0].responses[0].label, "https://foo.bar/quux") + @responses.activate + def testIgnoreHTTPErrors(self): + responses.add(responses.GET, "https://foo.bar/quux", status=404) + + request = rdf_cloud.CloudMetadataRequest() + request.bios_version_regex = "foo" + request.service_name_regex = "foo" + request.url = "https://foo.bar/quux" + request.ignore_http_errors = True + + args = rdf_cloud.CloudMetadataRequests() + args.requests.append(request) + + # We need to mock `subprocess.check_output` to bypass BIOS check that + # normally requires root privileges. + with mock.patch.object(subprocess, "check_output", return_value=b"foo"): + results = self.RunAction(cloud.GetCloudVMMetadata, args) + + # Results should be empty, but the action should not raise. + self.assertEmpty(results) + def main(argv): test_lib.main(argv) diff --git a/grr/client/grr_response_client/client_actions/file_finder.py b/grr/client/grr_response_client/client_actions/file_finder.py index 80b8ed78f0..8f643d6c67 100644 --- a/grr/client/grr_response_client/client_actions/file_finder.py +++ b/grr/client/grr_response_client/client_actions/file_finder.py @@ -6,12 +6,10 @@ from typing import Callable, Iterator, List, Text from grr_response_client import actions -from grr_response_client import client_utils from grr_response_client.client_actions.file_finder_utils import conditions from grr_response_client.client_actions.file_finder_utils import globbing from grr_response_client.client_actions.file_finder_utils import subactions from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import file_finder as rdf_file_finder from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr_response_core.lib.util import filesystem @@ -25,36 +23,6 @@ class _SkipFileException(Exception): pass -def FileFinderOSFromClient( - args: rdf_file_finder.FileFinderArgs) -> Iterator[rdf_client_fs.StatEntry]: - """This function expands paths from the args and returns related stat entries. - - Args: - args: A proto message with arguments for the file finder action. - - Yields: - Stat entries corresponding to the found files. 
- """ - stat_cache = filesystem.StatCache() - - opts = args.action.stat - - for path in GetExpandedPaths(args): - try: - content_conditions = conditions.ContentCondition.Parse(args.conditions) - for content_condition in content_conditions: - with io.open(path, "rb") as fd: - result = list(content_condition.Search(fd)) - if not result: - raise _SkipFileException() - stat = stat_cache.Get(path, follow_symlink=opts.resolve_links) - stat_entry = client_utils.StatEntryFromStatPathSpec( - stat, ext_attrs=opts.collect_ext_attrs) - yield stat_entry - except _SkipFileException: - pass - - class FileFinderOS(actions.ActionPlugin): """The file finder implementation using the OS file api.""" @@ -65,15 +33,19 @@ def Run(self, args: rdf_file_finder.FileFinderArgs): if args.pathtype != rdf_paths.PathSpec.PathType.OS: raise ValueError( "FileFinderOS can only be used with OS paths, got {}".format( - args.pathspec)) + args.pathspec + ) + ) self.stat_cache = filesystem.StatCache() action = self._ParseAction(args) self._metadata_conditions = list( - conditions.MetadataCondition.Parse(args.conditions)) + conditions.MetadataCondition.Parse(args.conditions) + ) self._content_conditions = list( - conditions.ContentCondition.Parse(args.conditions)) + conditions.ContentCondition.Parse(args.conditions) + ) for path in GetExpandedPaths(args, heartbeat_cb=self.Progress): self.Progress() @@ -86,8 +58,9 @@ def Run(self, args: rdf_file_finder.FileFinderArgs): except _SkipFileException: pass - def _ParseAction(self, - args: rdf_file_finder.FileFinderArgs) -> subactions.Action: + def _ParseAction( + self, args: rdf_file_finder.FileFinderArgs + ) -> subactions.Action: action_type = args.action.action_type if action_type == rdf_file_finder.FileFinderAction.Action.STAT: return subactions.StatAction(self, args.action.stat) @@ -103,8 +76,9 @@ def _GetStat(self, filepath, follow_symlink=True): except OSError: raise _SkipFileException() - def _Validate(self, args: rdf_file_finder.FileFinderArgs, - filepath: Text) -> List[rdf_client.BufferReference]: + def _Validate( + self, args: rdf_file_finder.FileFinderArgs, filepath: Text + ) -> List[rdf_client.BufferReference]: matches = [] stat = self._GetStat(filepath, follow_symlink=bool(args.follow_links)) self._ValidateRegularity(stat, args, filepath) @@ -156,8 +130,12 @@ def _ValidateContent(self, stat, filepath, matches): raise _SkipFileException() for content_condition in self._content_conditions: - with io.open(filepath, "rb") as fd: - result = list(content_condition.Search(fd)) + try: + with io.open(filepath, "rb") as fd: + result = list(content_condition.Search(fd)) + except OSError: + logging.error("Error reading: %s", filepath) + raise _SkipFileException() from None if not result: raise _SkipFileException() matches.extend(result) @@ -165,7 +143,8 @@ def _ValidateContent(self, stat, filepath, matches): def GetExpandedPaths( args: rdf_file_finder.FileFinderArgs, - heartbeat_cb: Callable[[], None] = _NoOp) -> Iterator[Text]: + heartbeat_cb: Callable[[], None] = _NoOp, +) -> Iterator[Text]: """Expands given path patterns. 
Args: @@ -185,7 +164,8 @@ def GetExpandedPaths( raise ValueError("Unsupported path type: ", args.pathtype) opts = globbing.PathOpts( - follow_links=args.follow_links, xdev=args.xdev, pathtype=pathtype) + follow_links=args.follow_links, xdev=args.xdev, pathtype=pathtype + ) for path in args.paths: for expanded_path in globbing.ExpandPath(str(path), opts, heartbeat_cb): diff --git a/grr/client/grr_response_client/client_actions/file_finder_test.py b/grr/client/grr_response_client/client_actions/file_finder_test.py index f76add5e75..b1c5837704 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_test.py +++ b/grr/client/grr_response_client/client_actions/file_finder_test.py @@ -37,17 +37,19 @@ def setUp(self): def _GetRelativeResults(self, raw_results, base_path=None): base_path = base_path or self.base_path return [ - result.stat_entry.pathspec.path[len(base_path) + 1:] + result.stat_entry.pathspec.path[len(base_path) + 1 :] for result in raw_results ] - def _RunFileFinder(self, - paths, - action, - conditions=None, - follow_links=True, - process_non_regular_files=True, - **kw): + def _RunFileFinder( + self, + paths, + action, + conditions=None, + follow_links=True, + process_non_regular_files=True, + **kw, + ): return self.RunAction( client_file_finder.FileFinderOS, arg=rdf_file_finder.FileFinderArgs( @@ -56,20 +58,24 @@ def _RunFileFinder(self, conditions=conditions, process_non_regular_files=process_non_regular_files, follow_links=follow_links, - **kw)) + **kw, + ), + ) def testFileFinder(self): paths = [self.base_path + "/*"] results = self._RunFileFinder(paths, self.stat_action) self.assertEqual( - self._GetRelativeResults(results), os.listdir(self.base_path)) + self._GetRelativeResults(results), os.listdir(self.base_path) + ) profiles_path = os.path.join(self.base_path, "profiles/v1.0") paths = [os.path.join(self.base_path, "profiles/v1.0") + "/*"] results = self._RunFileFinder(paths, self.stat_action) self.assertEqual( self._GetRelativeResults(results, base_path=profiles_path), - os.listdir(profiles_path)) + os.listdir(profiles_path), + ) def testNonExistentPath(self): paths = [self.base_path + "/does/not/exist/**"] @@ -81,8 +87,9 @@ def testRecursiveGlobCallsProgressWithoutMatches(self): progress = mock.MagicMock() - with mock.patch.object(client_file_finder.FileFinderOS, "Progress", - progress): + with mock.patch.object( + client_file_finder.FileFinderOS, "Progress", progress + ): results = self._RunFileFinder(paths, self.stat_action) self.assertEmpty(results) @@ -189,17 +196,20 @@ def testLinksAndContent(self): paths = [self.temp_dir + "/**"] condition = rdf_file_finder.FileFinderCondition.ContentsLiteralMatch( - literal=b"sometext") + literal=b"sometext" + ) results = self._RunFileFinder( - paths, self.stat_action, conditions=[condition], follow_links=True) + paths, self.stat_action, conditions=[condition], follow_links=True + ) self.assertLen(results, 2) relative_results = self._GetRelativeResults(results, base_path=test_dir) self.assertIn("lnk_target/contents", relative_results) self.assertIn("lnk/contents", relative_results) results = self._RunFileFinder( - paths, self.stat_action, conditions=[condition], follow_links=False) + paths, self.stat_action, conditions=[condition], follow_links=False + ) self.assertLen(results, 1) self.assertEqual(results[0].stat_entry.pathspec.path, contents) @@ -276,20 +286,24 @@ def _PrepareTimestampedFiles(self): return test_dir - def RunAndCheck(self, - paths, - action=None, - conditions=None, - expected=None, - unexpected=None, 
- base_path=None, - **kw): + def RunAndCheck( + self, + paths, + action=None, + conditions=None, + expected=None, + unexpected=None, + base_path=None, + **kw, + ): action = action or self.stat_action raw_results = self._RunFileFinder( - paths, action, conditions=conditions, **kw) + paths, action, conditions=conditions, **kw + ) relative_results = self._GetRelativeResults( - raw_results, base_path=base_path) + raw_results, base_path=base_path + ) for f in unexpected: self.assertNotIn(f, relative_results) @@ -305,12 +319,15 @@ def testLiteralMatchCondition(self): bytes_after = 20 condition = rdf_file_finder.FileFinderCondition.ContentsLiteralMatch( - literal=literal, bytes_before=bytes_before, bytes_after=bytes_after) + literal=literal, bytes_before=bytes_before, bytes_after=bytes_after + ) raw_results = self._RunFileFinder( - paths, self.stat_action, conditions=[condition]) + paths, self.stat_action, conditions=[condition] + ) relative_results = self._GetRelativeResults( - raw_results, base_path=searching_path) + raw_results, base_path=searching_path + ) self.assertLen(relative_results, 1) self.assertIn("auth.log", relative_results) self.assertLen(raw_results[0].matches, 1) @@ -321,8 +338,9 @@ def testLiteralMatchCondition(self): self.assertLen(buffer_ref.data, bytes_before + len(literal) + bytes_after) self.assertEqual( - orig_data[buffer_ref.offset:buffer_ref.offset + buffer_ref.length], - buffer_ref.data) + orig_data[buffer_ref.offset : buffer_ref.offset + buffer_ref.length], + buffer_ref.data, + ) def testLiteralMatchConditionAllHits(self): searching_path = os.path.join(self.base_path, "searching") @@ -336,15 +354,18 @@ def testLiteralMatchConditionAllHits(self): literal=literal, mode="ALL_HITS", bytes_before=bytes_before, - bytes_after=bytes_after) + bytes_after=bytes_after, + ) raw_results = self._RunFileFinder( - paths, self.stat_action, conditions=[condition]) + paths, self.stat_action, conditions=[condition] + ) self.assertLen(raw_results, 1) self.assertLen(raw_results[0].matches, 6) for buffer_ref in raw_results[0].matches: self.assertEqual( - buffer_ref.data[bytes_before:bytes_before + len(literal)], literal) + buffer_ref.data[bytes_before : bytes_before + len(literal)], literal + ) def testLiteralMatchConditionLargeFile(self): paths = [os.path.join(self.base_path, "new_places.sqlite")] @@ -357,10 +378,12 @@ def testLiteralMatchConditionLargeFile(self): literal=literal, mode="ALL_HITS", bytes_before=bytes_before, - bytes_after=bytes_after) + bytes_after=bytes_after, + ) raw_results = self._RunFileFinder( - paths, self.stat_action, conditions=[condition]) + paths, self.stat_action, conditions=[condition] + ) self.assertLen(raw_results, 1) self.assertLen(raw_results[0].matches, 1) buffer_ref = raw_results[0].matches[0] @@ -368,23 +391,27 @@ def testLiteralMatchConditionLargeFile(self): fd.seek(buffer_ref.offset) self.assertEqual(buffer_ref.data, fd.read(buffer_ref.length)) self.assertEqual( - buffer_ref.data[bytes_before:bytes_before + len(literal)], literal) + buffer_ref.data[bytes_before : bytes_before + len(literal)], literal + ) def testRegexMatchCondition(self): searching_path = os.path.join(self.base_path, "searching") paths = [searching_path + "/{dpkg.log,dpkg_false.log,auth.log}"] - regex = br"pa[nm]_o?unix\(s{2}h" + regex = rb"pa[nm]_o?unix\(s{2}h" bytes_before = 10 bytes_after = 20 condition = rdf_file_finder.FileFinderCondition.ContentsRegexMatch( - regex=regex, bytes_before=bytes_before, bytes_after=bytes_after) + regex=regex, bytes_before=bytes_before, 
bytes_after=bytes_after + ) raw_results = self._RunFileFinder( - paths, self.stat_action, conditions=[condition]) + paths, self.stat_action, conditions=[condition] + ) relative_results = self._GetRelativeResults( - raw_results, base_path=searching_path) + raw_results, base_path=searching_path + ) self.assertLen(relative_results, 1) self.assertIn("auth.log", relative_results) self.assertLen(raw_results[0].matches, 1) @@ -394,14 +421,15 @@ def testRegexMatchCondition(self): orig_data = filedesc.read() self.assertEqual( - orig_data[buffer_ref.offset:buffer_ref.offset + buffer_ref.length], - buffer_ref.data) + orig_data[buffer_ref.offset : buffer_ref.offset + buffer_ref.length], + buffer_ref.data, + ) def testRegexMatchConditionAllHits(self): searching_path = os.path.join(self.base_path, "searching") paths = [searching_path + "/{dpkg.log,dpkg_false.log,auth.log}"] - regex = br"mydo....\.com" + regex = rb"mydo....\.com" bytes_before = 10 bytes_after = 20 @@ -409,16 +437,19 @@ def testRegexMatchConditionAllHits(self): regex=regex, mode="ALL_HITS", bytes_before=bytes_before, - bytes_after=bytes_after) + bytes_after=bytes_after, + ) raw_results = self._RunFileFinder( - paths, self.stat_action, conditions=[condition]) + paths, self.stat_action, conditions=[condition] + ) self.assertLen(raw_results, 1) self.assertLen(raw_results[0].matches, 6) for buffer_ref in raw_results[0].matches: needle = b"mydomain.com" - self.assertEqual(buffer_ref.data[bytes_before:bytes_before + len(needle)], - needle) + self.assertEqual( + buffer_ref.data[bytes_before : bytes_before + len(needle)], needle + ) def testContentMatchIgnoreDirsWildcard(self): with temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath: @@ -435,18 +466,36 @@ def testContentMatchIgnoreDirsWildcard(self): filedesc.write(b"thudfoobaz") condition = rdf_file_finder.FileFinderCondition.ContentsLiteralMatch( - literal=b"fooba") + literal=b"fooba" + ) results = self._RunFileFinder( paths=[os.path.join(temp_dirpath, "*")], action=rdf_file_finder.FileFinderAction.Stat(), - conditions=[condition]) + conditions=[condition], + ) result_paths = [result.stat_entry.pathspec.path for result in results] - self.assertCountEqual(result_paths, [ - os.path.join(temp_dirpath, "quux"), - os.path.join(temp_dirpath, "thud"), - ]) + self.assertCountEqual( + result_paths, + [ + os.path.join(temp_dirpath, "quux"), + os.path.join(temp_dirpath, "thud"), + ], + ) + + def testContentMatchIgnoresReadError(self): + condition = rdf_file_finder.FileFinderCondition.ContentsRegexMatch( + regex=b"\\d+" + ) + with temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath: + results = self._RunFileFinder( + paths=[os.path.join(temp_dirpath, "/**4/nonexistent")], + action=rdf_file_finder.FileFinderAction.Stat(), + conditions=[condition], + ) + + self.assertEmpty(results) def testContentMatchIgnoreDirsRecursive(self): with temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath: @@ -469,18 +518,23 @@ def testContentMatchIgnoreDirsRecursive(self): filedesc.write(b"456") condition = rdf_file_finder.FileFinderCondition.ContentsRegexMatch( - regex=b"\\d+") + regex=b"\\d+" + ) results = self._RunFileFinder( paths=[os.path.join(temp_dirpath, "**", "*")], action=rdf_file_finder.FileFinderAction.Stat(), - conditions=[condition]) + conditions=[condition], + ) result_paths = [result.stat_entry.pathspec.path for result in results] - self.assertCountEqual(result_paths, [ - os.path.join(temp_dirpath, "foo", "bar", "norf"), - os.path.join(temp_dirpath, "foo", "bar", "ztesch"), - ]) + 
self.assertCountEqual( + result_paths, + [ + os.path.join(temp_dirpath, "foo", "bar", "norf"), + os.path.join(temp_dirpath, "foo", "bar", "ztesch"), + ], + ) def testHashAction(self): paths = [os.path.join(self.base_path, "win_hello.exe")] @@ -491,15 +545,19 @@ def testHashAction(self): res = results[0] data = open(paths[0], "rb").read() self.assertLen(data, res.hash_entry.num_bytes) - self.assertEqual(res.hash_entry.md5.HexDigest(), - hashlib.md5(data).hexdigest()) - self.assertEqual(res.hash_entry.sha1.HexDigest(), - hashlib.sha1(data).hexdigest()) - self.assertEqual(res.hash_entry.sha256.HexDigest(), - hashlib.sha256(data).hexdigest()) + self.assertEqual( + res.hash_entry.md5.HexDigest(), hashlib.md5(data).hexdigest() + ) + self.assertEqual( + res.hash_entry.sha1.HexDigest(), hashlib.sha1(data).hexdigest() + ) + self.assertEqual( + res.hash_entry.sha256.HexDigest(), hashlib.sha256(data).hexdigest() + ) hash_action = rdf_file_finder.FileFinderAction.Hash( - max_size=100, oversized_file_policy="SKIP") + max_size=100, oversized_file_policy="SKIP" + ) results = self._RunFileFinder(paths, hash_action) self.assertLen(results, 1) @@ -507,19 +565,23 @@ def testHashAction(self): self.assertFalse(res.HasField("hash")) hash_action = rdf_file_finder.FileFinderAction.Hash( - max_size=100, oversized_file_policy="HASH_TRUNCATED") + max_size=100, oversized_file_policy="HASH_TRUNCATED" + ) results = self._RunFileFinder(paths, hash_action) self.assertLen(results, 1) res = results[0] data = open(paths[0], "rb").read()[:100] self.assertLen(data, res.hash_entry.num_bytes) - self.assertEqual(res.hash_entry.md5.HexDigest(), - hashlib.md5(data).hexdigest()) - self.assertEqual(res.hash_entry.sha1.HexDigest(), - hashlib.sha1(data).hexdigest()) - self.assertEqual(res.hash_entry.sha256.HexDigest(), - hashlib.sha256(data).hexdigest()) + self.assertEqual( + res.hash_entry.md5.HexDigest(), hashlib.md5(data).hexdigest() + ) + self.assertEqual( + res.hash_entry.sha1.HexDigest(), hashlib.sha1(data).hexdigest() + ) + self.assertEqual( + res.hash_entry.sha256.HexDigest(), hashlib.sha256(data).hexdigest() + ) def testHashDirectory(self): action = rdf_file_finder.FileFinderAction.Hash() @@ -546,7 +608,8 @@ def testDownloadActionDefault(self): args = rdf_file_finder.FileFinderArgs( action=action, paths=[os.path.join(self.base_path, "win_hello.exe")], - process_non_regular_files=True) + process_non_regular_files=True, + ) transfer_store = MockTransferStore() executor = ClientActionExecutor() @@ -561,11 +624,13 @@ def testDownloadActionDefault(self): def testDownloadActionSkip(self): action = rdf_file_finder.FileFinderAction.Download( - max_size=0, oversized_file_policy="SKIP") + max_size=0, oversized_file_policy="SKIP" + ) args = rdf_file_finder.FileFinderArgs( action=action, paths=[os.path.join(self.base_path, "win_hello.exe")], - process_non_regular_files=True) + process_non_regular_files=True, + ) transfer_store = MockTransferStore() executor = ClientActionExecutor() @@ -579,11 +644,13 @@ def testDownloadActionSkip(self): def testDownloadActionTruncate(self): action = rdf_file_finder.FileFinderAction.Download( - max_size=42, oversized_file_policy="DOWNLOAD_TRUNCATED") + max_size=42, oversized_file_policy="DOWNLOAD_TRUNCATED" + ) args = rdf_file_finder.FileFinderArgs( action=action, paths=[os.path.join(self.base_path, "win_hello.exe")], - process_non_regular_files=True) + process_non_regular_files=True, + ) transfer_store = MockTransferStore() executor = ClientActionExecutor() @@ -598,11 +665,13 @@ def 
testDownloadActionTruncate(self): def testDownloadActionHash(self): action = rdf_file_finder.FileFinderAction.Download( - max_size=42, oversized_file_policy="HASH_TRUNCATED") + max_size=42, oversized_file_policy="HASH_TRUNCATED" + ) args = rdf_file_finder.FileFinderArgs( action=action, paths=[os.path.join(self.base_path, "win_hello.exe")], - process_non_regular_files=True) + process_non_regular_files=True, + ) transfer_store = MockTransferStore() executor = ClientActionExecutor() @@ -635,9 +704,11 @@ def testStatExtFlags(self): def testStatExtAttrs(self): with temp.AutoTempFilePath() as temp_filepath: filesystem_test_lib.SetExtAttr( - temp_filepath, name=b"user.foo", value=b"norf") + temp_filepath, name=b"user.foo", value=b"norf" + ) filesystem_test_lib.SetExtAttr( - temp_filepath, name=b"user.bar", value=b"quux") + temp_filepath, name=b"user.bar", value=b"quux" + ) action = rdf_file_finder.FileFinderAction.Stat(collect_ext_attrs=True) results = self._RunFileFinder([temp_filepath], action) @@ -724,10 +795,13 @@ def testLinkStat(self): paths = [lnk] link_size = os.lstat(lnk).st_size target_size = os.stat(lnk).st_size - for expected_size, resolve_links in [(link_size, False), - (target_size, True)]: + for expected_size, resolve_links in [ + (link_size, False), + (target_size, True), + ]: stat_action = rdf_file_finder.FileFinderAction.Stat( - resolve_links=resolve_links) + resolve_links=resolve_links + ) results = self._RunFileFinder(paths, stat_action) self.assertLen(results, 1) res = results[0] @@ -746,15 +820,19 @@ def testLinkStatWithProcessNonRegularFilesSetToFalse(self): paths = [lnk] link_size = os.lstat(lnk).st_size target_size = os.stat(lnk).st_size - for expected_size, resolve_links in [(link_size, False), - (target_size, True)]: + for expected_size, resolve_links in [ + (link_size, False), + (target_size, True), + ]: stat_action = rdf_file_finder.FileFinderAction.Stat( - resolve_links=resolve_links) + resolve_links=resolve_links + ) results = self._RunFileFinder( paths, stat_action, follow_links=False, - process_non_regular_files=False) + process_non_regular_files=False, + ) self.assertLen(results, 1) res = results[0] self.assertEqual(res.stat_entry.st_size, expected_size) @@ -775,25 +853,29 @@ def testModificationTimeCondition(self): paths = [os.path.join(temp_dirpath, "{{{}}}".format(",".join(files)))] condition = rdf_file_finder.FileFinderCondition.ModificationTime( - max_last_modified_time=change_time) + max_last_modified_time=change_time + ) self.RunAndCheck( paths, conditions=[condition], expected=files[:2], unexpected=files[2:], - base_path=temp_dirpath) + base_path=temp_dirpath, + ) # Now just the file from 2022. condition = rdf_file_finder.FileFinderCondition.ModificationTime( - min_last_modified_time=change_time) + min_last_modified_time=change_time + ) self.RunAndCheck( paths, conditions=[condition], expected=files[2:], unexpected=files[:2], - base_path=temp_dirpath) + base_path=temp_dirpath, + ) def testAccessTimeCondition(self): with temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath: @@ -811,25 +893,29 @@ def testAccessTimeCondition(self): # Check we can get the normal files. condition = rdf_file_finder.FileFinderCondition.AccessTime( - max_last_access_time=change_time) + max_last_access_time=change_time + ) self.RunAndCheck( paths, conditions=[condition], expected=files[:2], unexpected=files[2:], - base_path=temp_dirpath) + base_path=temp_dirpath, + ) # Now just the file from 2022. 
condition = rdf_file_finder.FileFinderCondition.AccessTime( - min_last_access_time=change_time) + min_last_access_time=change_time + ) self.RunAndCheck( paths, conditions=[condition], expected=files[2:], unexpected=files[:2], - base_path=temp_dirpath) + base_path=temp_dirpath, + ) # TODO(hanuszczak): Add tests for change metadata time conditions. @@ -847,7 +933,8 @@ def testSizeCondition(self): conditions=[condition], expected=["auth.log"], unexpected=["dpkg.log", "dpkg_false.log"], - base_path=test_dir) + base_path=test_dir, + ) condition = rdf_file_finder.FileFinderCondition.Size(max_file_size=700) @@ -856,7 +943,8 @@ def testSizeCondition(self): conditions=[condition], expected=["dpkg.log", "dpkg_false.log"], unexpected=["auth.log"], - base_path=test_dir) + base_path=test_dir, + ) def testXDEV(self): test_dir = os.path.join(self.temp_dir, "xdev_test") @@ -910,30 +998,37 @@ def MyStat(path): with utils.MultiStubber( (os, "stat", MyStat), - (globbing, "_GetAllowedDevices", MyGetAllowedDevices)): + (globbing, "_GetAllowedDevices", MyGetAllowedDevices), + ): paths = [test_dir + "/**5"] self.RunAndCheck( paths, expected=[ - "local_dev", "local_dev/local_file", "net_dev", "net_dev/net_file" + "local_dev", + "local_dev/local_file", + "net_dev", + "net_dev/net_file", ], unexpected=[], base_path=test_dir, - xdev="ALWAYS") + xdev="ALWAYS", + ) self.RunAndCheck( paths, expected=["local_dev", "local_dev/local_file", "net_dev"], unexpected=["net_dev/net_file"], base_path=test_dir, - xdev="LOCAL") + xdev="LOCAL", + ) self.RunAndCheck( paths, expected=["local_dev", "net_dev"], unexpected=["local_dev/local_file", "net_dev/net_file"], base_path=test_dir, - xdev="NEVER") + xdev="NEVER", + ) # TODO(hanuszczak): Revist this class after refactoring the GRR client worker @@ -952,9 +1047,9 @@ def RegisterWellKnownFlow(self, wkf): def Execute(self, action_cls, args): responses = list() - def SendReply(value, - session_id=None, - message_type=rdf_flows.GrrMessage.Type.MESSAGE): + def SendReply( + value, session_id=None, message_type=rdf_flows.GrrMessage.Type.MESSAGE + ): if message_type != rdf_flows.GrrMessage.Type.MESSAGE: return @@ -963,7 +1058,8 @@ def SendReply(value, name=action_cls.__name__, payload=value, auth_state="AUTHENTICATED", - session_id=session_id) + session_id=session_id, + ) self.wkfs[str(session_id)].ProcessMessage(message) else: responses.append(value) @@ -972,7 +1068,8 @@ def SendReply(value, name=action_cls.__name__, payload=args, auth_state="AUTHENTICATED", - session_id=rdfvalue.SessionID()) + session_id=rdfvalue.SessionID(), + ) action = action_cls(grr_worker=worker_mocks.FakeClientWorker()) action.SendReply = SendReply @@ -1020,7 +1117,8 @@ def Retrieve(self, blobdesc): blob = RichBlob( data=self.blobs[chunk.digest], offset=chunk.offset, - length=chunk.length) + length=chunk.length, + ) blobs.append(blob) blobs.sort(key=lambda blob: blob.offset) diff --git a/grr/client/grr_response_client/client_actions/file_finder_utils/conditions.py b/grr/client/grr_response_client/client_actions/file_finder_utils/conditions.py index e52fc38291..2cf9ab0039 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_utils/conditions.py +++ b/grr/client/grr_response_client/client_actions/file_finder_utils/conditions.py @@ -176,8 +176,11 @@ def Parse(conditions): OVERLAP_SIZE = 1024 * 1024 CHUNK_SIZE = 10 * 1024 * 1024 - def Scan(self, fd, - matcher: "Matcher") -> Iterator[rdf_client.BufferReference]: + def Scan( + self, + fd, + matcher: "Matcher", + ) -> Iterator[rdf_client.BufferReference]: 
"""Scans given file searching for occurrences of given pattern. Args: @@ -188,7 +191,8 @@ def Scan(self, fd, `BufferReference` objects pointing to file parts with matching content. """ streamer = streaming.Streamer( - chunk_size=self.CHUNK_SIZE, overlap_size=self.OVERLAP_SIZE) + chunk_size=self.CHUNK_SIZE, overlap_size=self.OVERLAP_SIZE + ) offset = self.params.start_offset amount = self.params.length @@ -199,9 +203,8 @@ def Scan(self, fd, ctx_data = chunk.data[ctx_begin:ctx_end] yield rdf_client.BufferReference( - offset=chunk.offset + ctx_begin, - length=len(ctx_data), - data=ctx_data) + offset=chunk.offset + ctx_begin, length=len(ctx_data), data=ctx_data + ) if self.params.mode == self.params.Mode.FIRST_HIT: return diff --git a/grr/client/grr_response_client/client_actions/file_finder_utils/conditions_test.py b/grr/client/grr_response_client/client_actions/file_finder_utils/conditions_test.py index a94600434b..a4a251652a 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_utils/conditions_test.py +++ b/grr/client/grr_response_client/client_actions/file_finder_utils/conditions_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import io import os import platform @@ -129,8 +128,10 @@ def Touch(self, mode, date): self.assertEqual(result, 0) -class ModificationTimeConditionTest(MetadataConditionTestMixin, - absltest.TestCase): +class ModificationTimeConditionTest( + MetadataConditionTestMixin, + absltest.TestCase, +): def testDefault(self): params = rdf_file_finder.FileFinderCondition() diff --git a/grr/client/grr_response_client/client_actions/file_finder_utils/globbing.py b/grr/client/grr_response_client/client_actions/file_finder_utils/globbing.py index 90cfc6eb9c..a96536bd30 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_utils/globbing.py +++ b/grr/client/grr_response_client/client_actions/file_finder_utils/globbing.py @@ -8,7 +8,7 @@ import platform import re import stat -from typing import Callable, Iterator, Optional, Text, Iterable +from typing import Callable, Iterable, Iterator, Optional, Text import psutil @@ -41,7 +41,8 @@ def __init__( follow_links: bool = False, xdev: Optional[rdf_structs.EnumNamedValue] = None, pathtype: Optional[rdf_structs.EnumNamedValue] = None, - implementation_type: Optional[rdf_structs.EnumNamedValue] = None): + implementation_type: Optional[rdf_structs.EnumNamedValue] = None, + ): self.follow_links = follow_links self.pathtype = pathtype or rdf_paths.PathSpec.PathType.OS self.implementation_type = implementation_type @@ -93,8 +94,9 @@ def _Generate(self, dirpath, depth): if depth > self.max_depth: return - for item in _ListDir(dirpath, self.opts.pathtype, - self.opts.implementation_type): + for item in _ListDir( + dirpath, self.opts.pathtype, self.opts.implementation_type + ): itempath = os.path.join(dirpath, item) yield itempath @@ -111,8 +113,10 @@ def _Recurse(self, path, depth): # Can happen for links pointing to non existent files/directories. 
return - if (self._allowed_devices is not _XDEV_ALL_ALLOWED and - stat_entry.st_dev not in self._allowed_devices): + if ( + self._allowed_devices is not _XDEV_ALL_ALLOWED + and stat_entry.st_dev not in self._allowed_devices + ): return if not stat.S_ISDIR(stat_entry.st_mode): @@ -125,7 +129,8 @@ def _Recurse(self, path, depth): elif self.opts.pathtype == rdf_paths.PathSpec.PathType.REGISTRY: pathspec = rdf_paths.PathSpec( - path=path, pathtype=rdf_paths.PathSpec.PathType.REGISTRY) + path=path, pathtype=rdf_paths.PathSpec.PathType.REGISTRY + ) try: with vfs.VFSOpen(pathspec) as filedesc: if not filedesc.IsDirectory(): @@ -153,7 +158,8 @@ def _Recurse(self, path, depth): def __repr__(self): return "RecursiveComponent(max_depth={}, opts={!r})".format( - self.max_depth, self.opts) + self.max_depth, self.opts + ) class GlobComponent(PathComponent): @@ -213,7 +219,8 @@ def _GenerateLiteralMatch(self, dirpath: Text) -> Optional[Text]: pathspec = rdf_paths.PathSpec( path=new_path, pathtype=self.opts.pathtype, - implementation_type=self.opts.implementation_type) + implementation_type=self.opts.implementation_type, + ) try: with vfs.VFSOpen(pathspec) as filedesc: if filedesc.path == "/" and new_path != "/": @@ -240,8 +247,9 @@ def Generate(self, dirpath): yield os.path.join(dirpath, literal_match) return - for item in _ListDir(dirpath, self.opts.pathtype, - self.opts.implementation_type): + for item in _ListDir( + dirpath, self.opts.pathtype, self.opts.implementation_type + ): if self.regex.match(item): yield os.path.join(dirpath, item) @@ -316,8 +324,9 @@ def ParsePathItem(item, opts=None): return RecursiveComponent(max_depth=max_depth, opts=opts) -def ParsePath(path: Text, - opts: Optional[PathOpts] = None) -> Iterator[PathComponent]: +def ParsePath( + path: Text, opts: Optional[PathOpts] = None +) -> Iterator[PathComponent]: """Parses given path into a stream of `PathComponent` instances. Args: @@ -348,9 +357,11 @@ def ParsePath(path: Text, yield component -def ExpandPath(path: Text, - opts: Optional[PathOpts] = None, - heartbeat_cb: Callable[[], None] = _NoOp): +def ExpandPath( + path: Text, + opts: Optional[PathOpts] = None, + heartbeat_cb: Callable[[], None] = _NoOp, +): """Applies all expansion mechanisms to the given path. Args: @@ -386,7 +397,7 @@ def ExpandGroups(path): offset = 0 for match in PATH_GROUP_REGEX.finditer(path): - chunks.append([path[offset:match.start()]]) + chunks.append([path[offset : match.start()]]) chunks.append(match.group("alts").split(",")) offset = match.end() @@ -396,9 +407,11 @@ def ExpandGroups(path): yield "".join(prod) -def ExpandGlobs(path: Text, - opts: Optional[PathOpts] = None, - heartbeat_cb: Callable[[], None] = _NoOp): +def ExpandGlobs( + path: Text, + opts: Optional[PathOpts] = None, + heartbeat_cb: Callable[[], None] = _NoOp, +): """Performs glob expansion on a given path. Path can contain regular glob elements (such as `**`, `*`, `?`, `[a-z]`). 
For @@ -453,14 +466,16 @@ def _ExpandComponents(basepath, components, index=0, heartbeat_cb=_NoOp): yield basepath return for childpath in components[index].Generate(basepath): - for path in _ExpandComponents(childpath, components, index + 1, - heartbeat_cb): + for path in _ExpandComponents( + childpath, components, index + 1, heartbeat_cb + ): yield path def _ListDir( - dirpath: str, pathtype: rdf_paths.PathSpec.PathType, - implementation_type: rdf_paths.PathSpec.ImplementationType + dirpath: str, + pathtype: rdf_paths.PathSpec.PathType, + implementation_type: rdf_paths.PathSpec.ImplementationType, ) -> Iterable[str]: """Returns children of a given directory. @@ -483,7 +498,8 @@ def _ListDir( return [] pathspec = rdf_paths.PathSpec( - path=dirpath, pathtype=pathtype, implementation_type=implementation_type) + path=dirpath, pathtype=pathtype, implementation_type=implementation_type + ) childpaths = [] try: with vfs.VFSOpen(pathspec) as filedesc: diff --git a/grr/client/grr_response_client/client_actions/file_finder_utils/globbing_test.py b/grr/client/grr_response_client/client_actions/file_finder_utils/globbing_test.py index 2817ed08b7..babfe1d63e 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_utils/globbing_test.py +++ b/grr/client/grr_response_client/client_actions/file_finder_utils/globbing_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import contextlib import io import os @@ -33,35 +32,47 @@ def testSimple(self): with DirHierarchy(filepaths) as hierarchy: results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("foo",)), - hierarchy(("foo", "0")), - hierarchy(("foo", "1")), - hierarchy(("foo", "bar")), - hierarchy(("foo", "bar", "0")), - hierarchy(("baz",)), - hierarchy(("baz", "0")), - hierarchy(("baz", "1")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo",)), + hierarchy(("foo", "0")), + hierarchy(("foo", "1")), + hierarchy(("foo", "bar")), + hierarchy(("foo", "bar", "0")), + hierarchy(("baz",)), + hierarchy(("baz", "0")), + hierarchy(("baz", "1")), + ], + ) results = list(component.Generate(hierarchy(("foo",)))) - self.assertCountEqual(results, [ - hierarchy(("foo", "0")), - hierarchy(("foo", "1")), - hierarchy(("foo", "bar")), - hierarchy(("foo", "bar", "0")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "0")), + hierarchy(("foo", "1")), + hierarchy(("foo", "bar")), + hierarchy(("foo", "bar", "0")), + ], + ) results = list(component.Generate(hierarchy(("baz",)))) - self.assertCountEqual(results, [ - hierarchy(("baz", "0")), - hierarchy(("baz", "1")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("baz", "0")), + hierarchy(("baz", "1")), + ], + ) results = list(component.Generate(hierarchy(("foo", "bar")))) - self.assertCountEqual(results, [ - hierarchy(("foo", "bar", "0")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "bar", "0")), + ], + ) def testMaxDepth(self): filepaths = [ @@ -91,7 +102,8 @@ def testMaxDepth(self): @unittest.skipIf( platform.system() == "Windows", - reason="Symlinks are not available on Windows") + reason="Symlinks are not available on Windows", + ) def testFollowLinks(self): filepaths = [ ("foo", "0"), @@ -112,49 +124,58 @@ def testFollowLinks(self): # It should resolve two links and recur to linked directories. 
results = list(component.Generate(hierarchy(("quux",)))) - self.assertCountEqual(results, [ - hierarchy(("quux", "0")), - hierarchy(("quux", "bar")), - hierarchy(("quux", "bar", "0")), - hierarchy(("quux", "baz")), - hierarchy(("quux", "baz", "0")), - hierarchy(("quux", "baz", "1")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("quux", "0")), + hierarchy(("quux", "bar")), + hierarchy(("quux", "bar", "0")), + hierarchy(("quux", "baz")), + hierarchy(("quux", "baz", "0")), + hierarchy(("quux", "baz", "1")), + ], + ) # It should resolve symlinks recursively. results = list(component.Generate(hierarchy(("norf",)))) - self.assertCountEqual(results, [ - hierarchy(("norf", "0")), - hierarchy(("norf", "quux")), - hierarchy(("norf", "quux", "0")), - hierarchy(("norf", "quux", "bar")), - hierarchy(("norf", "quux", "bar", "0")), - hierarchy(("norf", "quux", "baz")), - hierarchy(("norf", "quux", "baz", "0")), - hierarchy(("norf", "quux", "baz", "1")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("norf", "0")), + hierarchy(("norf", "quux")), + hierarchy(("norf", "quux", "0")), + hierarchy(("norf", "quux", "bar")), + hierarchy(("norf", "quux", "bar", "0")), + hierarchy(("norf", "quux", "baz")), + hierarchy(("norf", "quux", "baz", "0")), + hierarchy(("norf", "quux", "baz", "1")), + ], + ) opts = globbing.PathOpts(follow_links=False) component = globbing.RecursiveComponent(opts=opts) # It should list symlinks but should not recur to linked directories. results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("foo",)), - hierarchy(("foo", "0")), - hierarchy(("foo", "bar")), - hierarchy(("foo", "bar", "0")), - hierarchy(("foo", "baz")), - hierarchy(("foo", "baz", "0")), - hierarchy(("foo", "baz", "1")), - hierarchy(("quux",)), - hierarchy(("quux", "0")), - hierarchy(("quux", "bar")), - hierarchy(("quux", "baz")), - hierarchy(("norf",)), - hierarchy(("norf", "0")), - hierarchy(("norf", "quux")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo",)), + hierarchy(("foo", "0")), + hierarchy(("foo", "bar")), + hierarchy(("foo", "bar", "0")), + hierarchy(("foo", "baz")), + hierarchy(("foo", "baz", "0")), + hierarchy(("foo", "baz", "1")), + hierarchy(("quux",)), + hierarchy(("quux", "0")), + hierarchy(("quux", "bar")), + hierarchy(("quux", "baz")), + hierarchy(("norf",)), + hierarchy(("norf", "0")), + hierarchy(("norf", "quux")), + ], + ) def testInvalidDirpath(self): component = globbing.RecursiveComponent() @@ -176,16 +197,22 @@ def testLiterals(self): component = globbing.GlobComponent("foo") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("foo",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo",)), + ], + ) component = globbing.GlobComponent("bar") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("bar",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("bar",)), + ], + ) def testStar(self): filepaths = [ @@ -199,20 +226,26 @@ def testStar(self): component = globbing.GlobComponent("*") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("foo",)), - hierarchy(("bar",)), - hierarchy(("baz",)), - hierarchy(("quux",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo",)), + hierarchy(("bar",)), + hierarchy(("baz",)), + hierarchy(("quux",)), + ], + ) component = globbing.GlobComponent("ba*") results = list(component.Generate(hierarchy(()))) - 
self.assertCountEqual(results, [ - hierarchy(("bar",)), - hierarchy(("baz",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("bar",)), + hierarchy(("baz",)), + ], + ) def testQuestionmark(self): filepaths = [ @@ -226,10 +259,13 @@ def testQuestionmark(self): with DirHierarchy(filepaths) as hierarchy: results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("bar",)), - hierarchy(("baz",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("bar",)), + hierarchy(("baz",)), + ], + ) def testSimpleClass(self): filepaths = [ @@ -243,10 +279,13 @@ def testSimpleClass(self): with DirHierarchy(filepaths) as hierarchy: results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("baz",)), - hierarchy(("bar",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("baz",)), + hierarchy(("bar",)), + ], + ) def testRangeClass(self): filepaths = [ @@ -261,27 +300,36 @@ def testRangeClass(self): component = globbing.GlobComponent("[a-z]*") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("foo",)), - hierarchy(("bar",)), - hierarchy(("quux42",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo",)), + hierarchy(("bar",)), + hierarchy(("quux42",)), + ], + ) component = globbing.GlobComponent("[0-9]*") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("8AR",)), - hierarchy(("4815162342",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("8AR",)), + hierarchy(("4815162342",)), + ], + ) component = globbing.GlobComponent("*[0-9]") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("4815162342",)), - hierarchy(("quux42",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("4815162342",)), + hierarchy(("quux42",)), + ], + ) def testMultiRangeClass(self): filepaths = [ @@ -295,11 +343,14 @@ def testMultiRangeClass(self): with DirHierarchy(filepaths) as hierarchy: results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("f00",)), - hierarchy(("b4R",)), - hierarchy(("quux",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("f00",)), + hierarchy(("b4R",)), + hierarchy(("quux",)), + ], + ) def testComplementationClass(self): filepaths = [ @@ -312,10 +363,13 @@ def testComplementationClass(self): component = globbing.GlobComponent("*[!0-9]*") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("foo",)), - hierarchy(("bar",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo",)), + hierarchy(("bar",)), + ], + ) # TODO(hanuszczak): This test should be split into multiple cases. 
def testCornerCases(self): @@ -336,11 +390,14 @@ def testCornerCases(self): component = globbing.GlobComponent("[][-]") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("[",)), - hierarchy(("-",)), - hierarchy(("]",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("[",)), + hierarchy(("-",)), + hierarchy(("]",)), + ], + ) component = globbing.GlobComponent("[!]f-]*") @@ -356,14 +413,18 @@ def testCornerCases(self): component = globbing.GlobComponent("[*?]") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("*",)), - hierarchy(("?",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("*",)), + hierarchy(("?",)), + ], + ) @unittest.skipIf( platform.system() == "Windows", - reason="Windows disallows usage of whitespace-only paths") + reason="Windows disallows usage of whitespace-only paths", + ) def testWhitespace(self): filepaths = [ ("foo bar",), @@ -375,10 +436,13 @@ def testWhitespace(self): with DirHierarchy(filepaths) as hierarchy: results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("foo bar",)), - hierarchy((" ",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo bar",)), + hierarchy((" ",)), + ], + ) def testCaseInsensivity(self): filepaths = [ @@ -391,22 +455,31 @@ def testCaseInsensivity(self): with DirHierarchy(filepaths) as hierarchy: component = globbing.GlobComponent("b*") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("BAR",)), - hierarchy(("BaZ",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("BAR",)), + hierarchy(("BaZ",)), + ], + ) component = globbing.GlobComponent("quux") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("qUuX",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("qUuX",)), + ], + ) component = globbing.GlobComponent("FoO") results = list(component.Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("foo",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo",)), + ], + ) def testUnicodeGlobbing(self): filepaths = [ @@ -416,20 +489,29 @@ def testUnicodeGlobbing(self): with DirHierarchy(filepaths) as hierarchy: results = list(globbing.GlobComponent("ścieżka").Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("ścieżka",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("ścieżka",)), + ], + ) results = list(globbing.GlobComponent("dróżka").Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("dróżka",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("dróżka",)), + ], + ) results = list(globbing.GlobComponent("*żka").Generate(hierarchy(()))) - self.assertCountEqual(results, [ - hierarchy(("ścieżka",)), - hierarchy(("dróżka",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("ścieżka",)), + hierarchy(("dróżka",)), + ], + ) def testUnicodeSubfolderGlobbing(self): filepaths = [ @@ -441,10 +523,13 @@ def testUnicodeSubfolderGlobbing(self): with DirHierarchy(filepaths) as hierarchy: results = list(component.Generate(hierarchy(("zbiór",)))) - self.assertCountEqual(results, [ - hierarchy(("zbiór", "podścieżka")), - hierarchy(("zbiór", "poddróżka")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("zbiór", "podścieżka")), + hierarchy(("zbiór", "poddróżka")), + ], + ) class CurrentComponentTest(absltest.TestCase): @@ -539,22 +624,28 @@ def 
testSimple(self): path = os.path.join("foo", "**", "ba*") components = list(globbing.ParsePath(path)) - self.assertAreInstances(components, [ - globbing.GlobComponent, - globbing.RecursiveComponent, - globbing.GlobComponent, - ]) + self.assertAreInstances( + components, + [ + globbing.GlobComponent, + globbing.RecursiveComponent, + globbing.GlobComponent, + ], + ) path = os.path.join("foo", os.path.curdir, "bar", "baz", os.path.pardir) components = list(globbing.ParsePath(path)) - self.assertAreInstances(components, [ - globbing.GlobComponent, - globbing.CurrentComponent, - globbing.GlobComponent, - globbing.GlobComponent, - globbing.ParentComponent, - ]) + self.assertAreInstances( + components, + [ + globbing.GlobComponent, + globbing.CurrentComponent, + globbing.GlobComponent, + globbing.GlobComponent, + globbing.ParentComponent, + ], + ) def testMultiRecursive(self): path = os.path.join("foo", "**", "bar", "**", "baz") @@ -569,36 +660,45 @@ def testSimple(self): path = "fooba{r,z}" results = list(globbing.ExpandGroups(path)) - self.assertCountEqual(results, [ - "foobar", - "foobaz", - ]) + self.assertCountEqual( + results, + [ + "foobar", + "foobaz", + ], + ) def testMultiple(self): path = os.path.join("f{o,0}o{bar,baz}", "{quux,norf}") results = list(globbing.ExpandGroups(path)) - self.assertCountEqual(results, [ - os.path.join("foobar", "quux"), - os.path.join("foobar", "norf"), - os.path.join("foobaz", "quux"), - os.path.join("foobaz", "norf"), - os.path.join("f0obar", "quux"), - os.path.join("f0obar", "norf"), - os.path.join("f0obaz", "quux"), - os.path.join("f0obaz", "norf"), - ]) + self.assertCountEqual( + results, + [ + os.path.join("foobar", "quux"), + os.path.join("foobar", "norf"), + os.path.join("foobaz", "quux"), + os.path.join("foobaz", "norf"), + os.path.join("f0obar", "quux"), + os.path.join("f0obar", "norf"), + os.path.join("f0obaz", "quux"), + os.path.join("f0obaz", "norf"), + ], + ) def testMany(self): path = os.path.join("foo{bar,baz,quux,norf}thud") results = list(globbing.ExpandGroups(path)) - self.assertCountEqual(results, [ - os.path.join("foobarthud"), - os.path.join("foobazthud"), - os.path.join("fooquuxthud"), - os.path.join("foonorfthud"), - ]) + self.assertCountEqual( + results, + [ + os.path.join("foobarthud"), + os.path.join("foobazthud"), + os.path.join("fooquuxthud"), + os.path.join("foonorfthud"), + ], + ) def testEmpty(self): path = os.path.join("foo{}bar") @@ -652,11 +752,14 @@ def testWildcards(self): path = hierarchy(("*", "ba?", "0")) results = list(globbing.ExpandGlobs(path)) - self.assertCountEqual(results, [ - hierarchy(("foo", "bar", "0")), - hierarchy(("quux", "bar", "0")), - hierarchy(("quux", "baz", "0")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "bar", "0")), + hierarchy(("quux", "bar", "0")), + hierarchy(("quux", "baz", "0")), + ], + ) def testRecursion(self): filepaths = [ @@ -670,11 +773,14 @@ def testRecursion(self): path = hierarchy(("foo", "**", "0")) results = list(globbing.ExpandGlobs(path)) - self.assertCountEqual(results, [ - hierarchy(("foo", "bar", "baz", "0")), - hierarchy(("foo", "bar", "0")), - hierarchy(("foo", "quux", "0")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "bar", "baz", "0")), + hierarchy(("foo", "bar", "0")), + hierarchy(("foo", "quux", "0")), + ], + ) def testMixed(self): filepaths = [ @@ -692,15 +798,18 @@ def testMixed(self): path = hierarchy(("**", "ba?", "[0-2]")) results = list(globbing.ExpandGlobs(path)) - self.assertCountEqual(results, [ - hierarchy(("foo", 
"bar", "0")), - hierarchy(("norf", "bar", "0")), - hierarchy(("norf", "baz", "0")), - hierarchy(("norf", "baz", "1")), - hierarchy(("quux", "bar", "0")), - hierarchy(("quux", "baz", "1")), - hierarchy(("quux", "baz", "2")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "bar", "0")), + hierarchy(("norf", "bar", "0")), + hierarchy(("norf", "baz", "0")), + hierarchy(("norf", "baz", "1")), + hierarchy(("quux", "bar", "0")), + hierarchy(("quux", "baz", "1")), + hierarchy(("quux", "baz", "2")), + ], + ) def testEmpty(self): with self.assertRaises(ValueError): @@ -721,18 +830,24 @@ def testCurrent(self): path = hierarchy(("foo", os.path.curdir, "bar", "*")) results = list(globbing.ExpandGlobs(path)) - self.assertCountEqual(results, [ - hierarchy(("foo", "bar", "0")), - hierarchy(("foo", "bar", "1")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "bar", "0")), + hierarchy(("foo", "bar", "1")), + ], + ) path = hierarchy((os.path.curdir, "*", "bar", "0")) results = list(globbing.ExpandGlobs(path)) - self.assertCountEqual(results, [ - hierarchy(("foo", "bar", "0")), - hierarchy(("quux", "bar", "0")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "bar", "0")), + hierarchy(("quux", "bar", "0")), + ], + ) def testParent(self): filepaths = [ @@ -746,19 +861,25 @@ def testParent(self): path = hierarchy(("foo", "*")) results = list(globbing.ExpandGlobs(path)) - self.assertCountEqual(results, [ - hierarchy(("foo", "0")), - hierarchy(("foo", "1")), - hierarchy(("foo", "bar")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "0")), + hierarchy(("foo", "1")), + hierarchy(("foo", "bar")), + ], + ) path = hierarchy(("foo", os.path.pardir, "*")) results = list(globbing.ExpandGlobs(path)) - self.assertCountEqual(results, [ - hierarchy(("foo",)), - hierarchy(("bar",)), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo",)), + hierarchy(("bar",)), + ], + ) class ExpandPathTest(absltest.TestCase): @@ -776,21 +897,27 @@ def testGlobAndGroup(self): with DirHierarchy(filepaths) as hierarchy: path = hierarchy(("foo", "ba{r,z}", "*")) results = list(globbing.ExpandPath(path)) - self.assertCountEqual(results, [ - hierarchy(("foo", "bar", "0")), - hierarchy(("foo", "bar", "1")), - hierarchy(("foo", "baz", "0")), - hierarchy(("foo", "baz", "1")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "bar", "0")), + hierarchy(("foo", "bar", "1")), + hierarchy(("foo", "baz", "0")), + hierarchy(("foo", "baz", "1")), + ], + ) path = hierarchy(("foo", "ba*", "{0,1}")) results = list(globbing.ExpandPath(path)) - self.assertCountEqual(results, [ - hierarchy(("foo", "bar", "0")), - hierarchy(("foo", "bar", "1")), - hierarchy(("foo", "baz", "0")), - hierarchy(("foo", "baz", "1")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "bar", "0")), + hierarchy(("foo", "bar", "1")), + hierarchy(("foo", "baz", "0")), + hierarchy(("foo", "baz", "1")), + ], + ) def testRecursiveAndGroup(self): filepaths = [ @@ -803,28 +930,34 @@ def testRecursiveAndGroup(self): with DirHierarchy(filepaths) as hierarchy: path = hierarchy(("foo", "**")) results = list(globbing.ExpandPath(path)) - self.assertCountEqual(results, [ - hierarchy(("foo", "0")), - hierarchy(("foo", "1")), - hierarchy(("foo", "bar")), - hierarchy(("foo", "baz")), - hierarchy(("foo", "bar", "0")), - hierarchy(("foo", "baz", "quux")), - hierarchy(("foo", "baz", "quux", "0")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo", "0")), + hierarchy(("foo", "1")), + 
hierarchy(("foo", "bar")), + hierarchy(("foo", "baz")), + hierarchy(("foo", "bar", "0")), + hierarchy(("foo", "baz", "quux")), + hierarchy(("foo", "baz", "quux", "0")), + ], + ) path = hierarchy(("foo", "{.,**}")) results = list(globbing.ExpandPath(path)) - self.assertCountEqual(results, [ - hierarchy(("foo",)), - hierarchy(("foo", "0")), - hierarchy(("foo", "1")), - hierarchy(("foo", "bar")), - hierarchy(("foo", "baz")), - hierarchy(("foo", "bar", "0")), - hierarchy(("foo", "baz", "quux")), - hierarchy(("foo", "baz", "quux", "0")), - ]) + self.assertCountEqual( + results, + [ + hierarchy(("foo",)), + hierarchy(("foo", "0")), + hierarchy(("foo", "1")), + hierarchy(("foo", "bar")), + hierarchy(("foo", "baz")), + hierarchy(("foo", "bar", "0")), + hierarchy(("foo", "baz", "quux")), + hierarchy(("foo", "baz", "quux", "0")), + ], + ) class DirHierarchyContext(object): @@ -859,7 +992,8 @@ def __call__(self, components: Sequence[Text]) -> Text: @contextlib.contextmanager def DirHierarchy( - filepaths: Sequence[Sequence[Text]]) -> Iterator[DirHierarchyContext]: + filepaths: Sequence[Sequence[Text]], +) -> Iterator[DirHierarchyContext]: """A context manager that setups a fake directory hierarchy. Args: diff --git a/grr/client/grr_response_client/client_actions/file_finder_utils/subactions.py b/grr/client/grr_response_client/client_actions/file_finder_utils/subactions.py index c47236d6dd..0422f22872 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_utils/subactions.py +++ b/grr/client/grr_response_client/client_actions/file_finder_utils/subactions.py @@ -52,7 +52,8 @@ def Execute(self, filepath, result): stat = stat_cache.Get(filepath, follow_symlink=self.opts.resolve_links) result.stat_entry = client_utils.StatEntryFromStatPathSpec( - stat, ext_attrs=self.opts.collect_ext_attrs) + stat, ext_attrs=self.opts.collect_ext_attrs + ) class HashAction(Action): @@ -74,7 +75,8 @@ def __init__(self, flow, opts): def Execute(self, filepath, result): stat = self.flow.stat_cache.Get(filepath, follow_symlink=True) result.stat_entry = client_utils.StatEntryFromStatPathSpec( - stat, ext_attrs=self.opts.collect_ext_attrs) + stat, ext_attrs=self.opts.collect_ext_attrs + ) if stat.IsDirectory(): return @@ -110,7 +112,8 @@ def __init__(self, flow, opts): def Execute(self, filepath, result): stat = self.flow.stat_cache.Get(filepath, follow_symlink=True) result.stat_entry = client_utils.StatEntryFromStatPathSpec( - stat, ext_attrs=self.opts.collect_ext_attrs) + stat, ext_attrs=self.opts.collect_ext_attrs + ) if stat.IsDirectory(): return diff --git a/grr/client/grr_response_client/client_actions/file_finder_utils/subactions_test.py b/grr/client/grr_response_client/client_actions/file_finder_utils/subactions_test.py index a5875f0a3d..2db68b78e0 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_utils/subactions_test.py +++ b/grr/client/grr_response_client/client_actions/file_finder_utils/subactions_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl.testing import absltest # TODO(hanuszczak): Implement basic unit tests for subactions. 
diff --git a/grr/client/grr_response_client/client_actions/file_finder_utils/uploading.py b/grr/client/grr_response_client/client_actions/file_finder_utils/uploading.py index a2ca5dd96d..1c1be5ee69 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_utils/uploading.py +++ b/grr/client/grr_response_client/client_actions/file_finder_utils/uploading.py @@ -47,7 +47,8 @@ def UploadFilePath(self, filepath, offset=0, amount=None): A `BlobImageDescriptor` object. """ return self._UploadChunkStream( - self._streamer.StreamFilePath(filepath, offset=offset, amount=amount)) + self._streamer.StreamFilePath(filepath, offset=offset, amount=amount) + ) def UploadFile(self, fd, offset=0, amount=None): """Uploads chunks of a given file descriptor to the transfer store flow. @@ -62,7 +63,8 @@ def UploadFile(self, fd, offset=0, amount=None): A `BlobImageDescriptor` object. """ return self._UploadChunkStream( - self._streamer.StreamFile(fd, offset=offset, amount=amount)) + self._streamer.StreamFile(fd, offset=offset, amount=amount) + ) def _UploadChunkStream(self, chunk_stream): chunks = [] @@ -70,7 +72,8 @@ def _UploadChunkStream(self, chunk_stream): chunks.append(self._UploadChunk(chunk)) return rdf_client_fs.BlobImageDescriptor( - chunks=chunks, chunk_size=self._streamer.chunk_size) + chunks=chunks, chunk_size=self._streamer.chunk_size + ) def _UploadChunk(self, chunk): """Uploads a single chunk to the transfer store flow. @@ -89,10 +92,12 @@ def _UploadChunk(self, chunk): return rdf_client_fs.BlobImageChunkDescriptor( digest=hashlib.sha256(chunk.data).digest(), offset=chunk.offset, - length=len(chunk.data)) + length=len(chunk.data), + ) def _CompressedDataBlob(chunk): return rdf_protodict.DataBlob( data=zlib.compress(chunk.data), - compression=rdf_protodict.DataBlob.CompressionType.ZCOMPRESSION) + compression=rdf_protodict.DataBlob.CompressionType.ZCOMPRESSION, + ) diff --git a/grr/client/grr_response_client/client_actions/file_finder_utils/uploading_test.py b/grr/client/grr_response_client/client_actions/file_finder_utils/uploading_test.py index 3d89030be1..36540ebca6 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_utils/uploading_test.py +++ b/grr/client/grr_response_client/client_actions/file_finder_utils/uploading_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import collections import hashlib import io diff --git a/grr/client/grr_response_client/client_actions/file_finder_utils/vfs_subactions.py b/grr/client/grr_response_client/client_actions/file_finder_utils/vfs_subactions.py index 7c117d24df..c575f89506 100644 --- a/grr/client/grr_response_client/client_actions/file_finder_utils/vfs_subactions.py +++ b/grr/client/grr_response_client/client_actions/file_finder_utils/vfs_subactions.py @@ -21,8 +21,11 @@ def __init__(self, action: actions.ActionPlugin): self._action = action @abc.abstractmethod - def __call__(self, stat_entry: rdf_client_fs.StatEntry, - fd: vfs.VFSHandler) -> rdf_file_finder.FileFinderResult: + def __call__( + self, + stat_entry: rdf_client_fs.StatEntry, + fd: vfs.VFSHandler, + ) -> rdf_file_finder.FileFinderResult: """Executes the action on a given file. Args: @@ -46,12 +49,16 @@ class StatAction(Action): def __init__( self, flow, - opts: Optional[rdf_file_finder.FileFinderStatActionOptions] = None): + opts: Optional[rdf_file_finder.FileFinderStatActionOptions] = None, + ): super().__init__(flow) del opts # Unused. 
- def __call__(self, stat_entry: rdf_client_fs.StatEntry, - fd: vfs.VFSHandler) -> rdf_file_finder.FileFinderResult: + def __call__( + self, + stat_entry: rdf_client_fs.StatEntry, + fd: vfs.VFSHandler, + ) -> rdf_file_finder.FileFinderResult: return rdf_file_finder.FileFinderResult(stat_entry=stat_entry) @@ -67,8 +74,11 @@ def __init__(self, flow, opts: rdf_file_finder.FileFinderHashActionOptions): super().__init__(flow) self._opts = opts - def __call__(self, stat_entry: rdf_client_fs.StatEntry, - fd: vfs.VFSHandler) -> rdf_file_finder.FileFinderResult: + def __call__( + self, + stat_entry: rdf_client_fs.StatEntry, + fd: vfs.VFSHandler, + ) -> rdf_file_finder.FileFinderResult: result = StatAction(self._action)(stat_entry, fd) # stat_entry.st_mode has StatMode type. @@ -86,7 +96,8 @@ def __call__(self, stat_entry: rdf_client_fs.StatEntry, stat_entry, fd, max_size=int(self._opts.max_size), - progress=self._action.Progress) + progress=self._action.Progress, + ) # else: Skip due to OversizedFilePolicy.SKIP. return result @@ -100,13 +111,19 @@ class DownloadAction(Action): file. """ - def __init__(self, flow, - opts: rdf_file_finder.FileFinderDownloadActionOptions): + def __init__( + self, + flow, + opts: rdf_file_finder.FileFinderDownloadActionOptions, + ): super().__init__(flow) self._opts = opts - def __call__(self, stat_entry: rdf_client_fs.StatEntry, - fd: vfs.VFSHandler) -> rdf_file_finder.FileFinderResult: + def __call__( + self, + stat_entry: rdf_client_fs.StatEntry, + fd: vfs.VFSHandler, + ) -> rdf_file_finder.FileFinderResult: result = StatAction(self._action)(stat_entry, fd) # stat_entry.st_mode has StatMode type. @@ -121,7 +138,8 @@ def __call__(self, stat_entry: rdf_client_fs.StatEntry, result.transferred_file = self._UploadFilePath(fd, truncate=truncate) elif policy == self._opts.OversizedFilePolicy.HASH_TRUNCATED: result.hash_entry = _HashEntry( - stat_entry, fd, self._action.Progress, max_size=max_size) + stat_entry, fd, self._action.Progress, max_size=max_size + ) # else: Skip due to OversizedFilePolicy.SKIP. 
return result @@ -131,14 +149,17 @@ def _UploadFilePath(self, fd, truncate): chunk_size = self._opts.chunk_size uploader = uploading.TransferStoreUploader( - self._action, chunk_size=chunk_size) + self._action, chunk_size=chunk_size + ) return uploader.UploadFile(fd, amount=max_size) -def _HashEntry(stat_entry: rdf_client_fs.StatEntry, - fd: vfs.VFSHandler, - progress: Callable[[], None], - max_size: Optional[int] = None) -> Optional[rdf_crypto.Hash]: +def _HashEntry( + stat_entry: rdf_client_fs.StatEntry, + fd: vfs.VFSHandler, + progress: Callable[[], None], + max_size: Optional[int] = None, +) -> Optional[rdf_crypto.Hash]: hasher = client_utils_common.MultiHasher(progress=progress) try: hasher.HashFile(fd, max_size or stat_entry.st_size) diff --git a/grr/client/grr_response_client/client_actions/file_fingerprint.py b/grr/client/grr_response_client/client_actions/file_fingerprint.py index a1dfceea05..6a9c8a8da9 100644 --- a/grr/client/grr_response_client/client_actions/file_fingerprint.py +++ b/grr/client/grr_response_client/client_actions/file_fingerprint.py @@ -23,6 +23,7 @@ def _GetNextInterval(self): class FingerprintFile(standard.ReadBuffer): """Apply a set of fingerprinting methods to a file.""" + in_rdfvalue = rdf_client_action.FingerprintRequest out_rdfvalues = [rdf_client_action.FingerprintResponse] @@ -34,15 +35,18 @@ class FingerprintFile(standard.ReadBuffer): _fingerprint_types = { rdf_client_action.FingerprintTuple.Type.FPT_GENERIC: ( - fingerprint.Fingerprinter.EvalGeneric), + fingerprint.Fingerprinter.EvalGeneric + ), rdf_client_action.FingerprintTuple.Type.FPT_PE_COFF: ( - fingerprint.Fingerprinter.EvalPecoff), + fingerprint.Fingerprinter.EvalPecoff + ), } def Run(self, args): """Fingerprint a file.""" with vfs.VFSOpen( - args.pathspec, progress_callback=self.Progress) as file_obj: + args.pathspec, progress_callback=self.Progress + ) as file_obj: fingerprinter = Fingerprinter(self.Progress, file_obj) response = rdf_client_action.FingerprintResponse() response.pathspec = file_obj.pathspec @@ -63,7 +67,8 @@ def Run(self, args): response.matching_types.append(finger.fp_type) else: raise RuntimeError( - "Encountered unknown fingerprint type. %s" % finger.fp_type) + "Encountered unknown fingerprint type. %s" % finger.fp_type + ) # Structure of the results is a list of dicts, each containing the # name of the hashing method, hashes for enabled hash algorithms, @@ -88,6 +93,7 @@ def Run(self, args): signed_data = result.GetItem("SignedData", []) for data in signed_data: response.hash.signed_data.Append( - revision=data[0], cert_type=data[1], certificate=data[2]) + revision=data[0], cert_type=data[1], certificate=data[2] + ) self.SendReply(response) diff --git a/grr/client/grr_response_client/client_actions/file_fingerprint_test.py b/grr/client/grr_response_client/client_actions/file_fingerprint_test.py index ce819c77f5..d7efdafbc8 100644 --- a/grr/client/grr_response_client/client_actions/file_fingerprint_test.py +++ b/grr/client/grr_response_client/client_actions/file_fingerprint_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - # Copyright 2010 Google Inc. All Rights Reserved. 
"""Test client vfs.""" @@ -24,19 +23,21 @@ def testHashFile(self): p = rdf_paths.PathSpec(path=path, pathtype=rdf_paths.PathSpec.PathType.OS) result = self.RunAction( file_fingerprint.FingerprintFile, - rdf_client_action.FingerprintRequest(pathspec=p)) + rdf_client_action.FingerprintRequest(pathspec=p), + ) types = result[0].matching_types fingers = {} for f in result[0].results: fingers[f["name"]] = f generic_sha256 = fingers["generic"]["sha256"] - self.assertEqual(generic_sha256, - hashlib.sha256(open(path, "rb").read()).digest()) + self.assertEqual( + generic_sha256, hashlib.sha256(open(path, "rb").read()).digest() + ) # Make sure all fingers are listed in types and vice versa. t_map = { rdf_client_action.FingerprintTuple.Type.FPT_GENERIC: "generic", - rdf_client_action.FingerprintTuple.Type.FPT_PE_COFF: "pecoff" + rdf_client_action.FingerprintTuple.Type.FPT_PE_COFF: "pecoff", } ti_map = dict((v, k) for k, v in t_map.items()) for t in types: @@ -54,7 +55,8 @@ def testMissingFile(self): IOError, self.RunAction, file_fingerprint.FingerprintFile, - rdf_client_action.FingerprintRequest(pathspec=p)) + rdf_client_action.FingerprintRequest(pathspec=p), + ) def main(argv): diff --git a/grr/client/grr_response_client/client_actions/large_file.py b/grr/client/grr_response_client/client_actions/large_file.py index 0a6f3efdd0..7b6827feba 100644 --- a/grr/client/grr_response_client/client_actions/large_file.py +++ b/grr/client/grr_response_client/client_actions/large_file.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with action for large file collection.""" + from typing import Iterator from grr_response_client import actions diff --git a/grr/client/grr_response_client/client_actions/large_file_test.py b/grr/client/grr_response_client/client_actions/large_file_test.py index 2cf5ea3fa6..512e4a988c 100644 --- a/grr/client/grr_response_client/client_actions/large_file_test.py +++ b/grr/client/grr_response_client/client_actions/large_file_test.py @@ -23,9 +23,13 @@ class CollectLargeFileTest(absltest.TestCase): def setUp(self): super().setUp() - vfs_patcher = mock.patch.object(vfs, "VFS_HANDLERS", { - rdf_paths.PathSpec.PathType.OS: files.File, - }) + vfs_patcher = mock.patch.object( + vfs, + "VFS_HANDLERS", + { + rdf_paths.PathSpec.PathType.OS: files.File, + }, + ) vfs_patcher.start() self.addCleanup(vfs_patcher.stop) diff --git a/grr/client/grr_response_client/client_actions/linux/linux.py b/grr/client/grr_response_client/client_actions/linux/linux.py index 3615c11411..baea5a412c 100644 --- a/grr/client/grr_response_client/client_actions/linux/linux.py +++ b/grr/client/grr_response_client/client_actions/linux/linux.py @@ -7,10 +7,9 @@ import io import os import pwd -import time - from grr_response_client import actions +from grr_response_client.client_actions import osx_linux from grr_response_client.client_actions import standard from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils @@ -33,6 +32,7 @@ class Sockaddrll(ctypes.Structure): """The sockaddr_ll struct.""" + _fields_ = [ ("sll_family", ctypes.c_ushort), ("sll_protocol", ctypes.c_ushort), @@ -56,6 +56,7 @@ class Sockaddrll(ctypes.Structure): class Sockaddrin(ctypes.Structure): """The sockaddr_in struct.""" + _fields_ = [ ("sin_family", ctypes.c_ubyte), ("sin_port", ctypes.c_ushort), @@ -63,6 +64,7 @@ class Sockaddrin(ctypes.Structure): ("sin_zero", ctypes.c_char * 8) ] # pyformat: disable + # struct sockaddr_in6 { # unsigned short int sin6_family; /* AF_INET6 */ # __be16 sin6_port; /* Transport layer port 
# */ @@ -74,6 +76,7 @@ class Sockaddrin(ctypes.Structure): class Sockaddrin6(ctypes.Structure): """The sockaddr_in6 struct.""" + _fields_ = [ ("sin6_family", ctypes.c_ubyte), ("sin6_port", ctypes.c_ushort), @@ -82,6 +85,7 @@ class Sockaddrin6(ctypes.Structure): ("sin6_scope_id", ctypes.c_ubyte * 4) ] # pyformat: disable + # struct ifaddrs *ifa_next; /* Pointer to next struct */ # char *ifa_name; /* Interface name */ # u_int ifa_flags; /* Interface flags */ @@ -133,7 +137,8 @@ def EnumerateInterfacesFromClient(args): ip4 = bytes(list(data.contents.sin_addr)) address_type = rdf_client_network.NetworkAddress.Family.INET address = rdf_client_network.NetworkAddress( - address_type=address_type, packed_bytes=ip4) + address_type=address_type, packed_bytes=ip4 + ) addresses.setdefault(ifname, []).append(address) if iffamily == 0x11: # AF_PACKET @@ -146,7 +151,8 @@ def EnumerateInterfacesFromClient(args): ip6 = bytes(list(data.contents.sin6_addr)) address_type = rdf_client_network.NetworkAddress.Family.INET6 address = rdf_client_network.NetworkAddress( - address_type=address_type, packed_bytes=ip6) + address_type=address_type, packed_bytes=ip6 + ) addresses.setdefault(ifname, []).append(address) except ValueError: # Some interfaces don't have a iffamily and will raise a null pointer @@ -170,6 +176,7 @@ def EnumerateInterfacesFromClient(args): class EnumerateInterfaces(actions.ActionPlugin): """Enumerates all MAC addresses on this system.""" + out_rdfvalues = [rdf_client_network.Interface] def Run(self, args): @@ -180,6 +187,7 @@ def Run(self, args): class GetInstallDate(actions.ActionPlugin): """Estimate the install date of this system.""" + out_rdfvalues = [rdf_protodict.DataBlob, rdfvalue.RDFDatetime] def Run(self, unused_args): @@ -189,6 +197,7 @@ def Run(self, unused_args): class UtmpStruct(utils.Struct): """Parse wtmp file from utmp.h.""" + _fields = [ ("h", "ut_type"), ("i", "ut_pid"), @@ -225,7 +234,8 @@ def EnumerateUsersFromClient(args): last_login = 0 result = rdf_client.User( - username=username, last_logon=last_login * 1000000) + username=username, last_logon=last_login * 1000000 + ) try: pwdict = pwd.getpwnam(username) @@ -248,6 +258,7 @@ class EnumerateUsers(actions.ActionPlugin): allow for the metadata (homedir) expansion to occur on the client, where we have access to LDAP. """ + # Client versions 3.0.7.1 and older used to return KnowledgeBaseUser. # KnowledgeBaseUser was renamed to User. out_rdfvalues = [rdf_client.User, rdf_client.KnowledgeBaseUser] @@ -293,7 +304,8 @@ def EnumerateFilesystemsFromClient(args): for filename in filenames: for device, fs_type, mnt_point in CheckMounts(filename): yield rdf_client_fs.Filesystem( - mount_point=mnt_point, type=fs_type, device=device) + mount_point=mnt_point, type=fs_type, device=device + ) class EnumerateFilesystems(actions.ActionPlugin): @@ -302,6 +314,7 @@ class EnumerateFilesystems(actions.ActionPlugin): Filesystems picked from: https://www.kernel.org/doc/Documentation/filesystems/ """ + out_rdfvalues = [rdf_client_fs.Filesystem] def Run(self, args): @@ -314,6 +327,7 @@ class EnumerateRunningServices(actions.ActionPlugin): TODO(user): This is a placeholder and needs to be implemented. """ + in_rdfvalue = None out_rdfvalues = [None] @@ -326,73 +340,19 @@ class UpdateAgent(standard.ExecuteBinaryCommand): def ProcessFile(self, path, args): if path.endswith(".deb"): - self._InstallDeb(path, args) + self._InstallDeb(path) elif path.endswith(".rpm"): self._InstallRpm(path) else: - raise ValueError("Unknown suffix for file %s." 
% path) + raise ValueError(f"Unknown suffix for file {path}.") - def _InstallDeb(self, path, args): - pid = os.fork() - if pid == 0: - # This is the child that will become the installer process. - - # We call os.setsid here to become the session leader of this new session - # and the process group leader of the new process group so we don't get - # killed when the main process exits. - try: - os.setsid() - except OSError: - # This only works if the process is running as root. - pass - - env = os.environ.copy() - env.pop("LD_LIBRARY_PATH", None) - env.pop("PYTHON_PATH", None) - - cmd = "/usr/bin/dpkg" - cmd_args = [cmd, "-i", path] - - os.execve(cmd, cmd_args, env) - else: - # The installer will run in the background and kill the main process - # so we just wait. If something goes wrong, the nanny will restart the - # service after a short while and the client will come back to life. - time.sleep(1000) + def _InstallDeb(self, path): + osx_linux.RunInstallerCmd(["/usr/bin/dpkg", "-i", path]) def _InstallRpm(self, path): - """Client update for rpm based distros. - - Upgrading rpms is a bit more tricky than upgrading deb packages since there - is a preinstall script that kills the running GRR daemon and, thus, also - the installer process. We need to make sure we detach the child process - properly and therefore cannot use client_utils_common.Execute(). - - Args: - path: Path to the .rpm. - """ - - pid = os.fork() - if pid == 0: - # This is the child that will become the installer process. - - cmd = "/bin/rpm" - cmd_args = [cmd, "-U", "--replacepkgs", "--replacefiles", path] - - # We need to clean the environment or rpm will fail - similar to the - # use_client_context=False parameter. - env = os.environ.copy() - env.pop("LD_LIBRARY_PATH", None) - env.pop("PYTHON_PATH", None) - - # This call doesn't return. - os.execve(cmd, cmd_args, env) - - else: - # The installer will run in the background and kill the main process - # so we just wait. If something goes wrong, the nanny will restart the - # service after a short while and the client will come back to life. - time.sleep(1000) + # Note: --replacepkgs --replacefiles are equivalent to --force. 
+ cmd = ["/bin/rpm", "-U", "--replacepkgs", "--replacefiles", path] + osx_linux.RunInstallerCmd(cmd) def _ParseWtmp(): @@ -410,7 +370,7 @@ def _ParseWtmp(): for offset in range(0, len(wtmp), wtmp_struct_size): try: - record = UtmpStruct(wtmp[offset:offset + wtmp_struct_size]) + record = UtmpStruct(wtmp[offset : offset + wtmp_struct_size]) except utils.ParsingError: break diff --git a/grr/client/grr_response_client/client_actions/linux/linux_test.py b/grr/client/grr_response_client/client_actions/linux/linux_test.py index 63ac7b2de6..d6df441414 100644 --- a/grr/client/grr_response_client/client_actions/linux/linux_test.py +++ b/grr/client/grr_response_client/client_actions/linux/linux_test.py @@ -20,16 +20,23 @@ class LinuxOnlyTest(client_test_lib.EmptyActionTest): def testEnumerateUsersLinux(self): """Enumerate users from the wtmp file.""" - def MockedOpen(requested_path, mode="rb"): + def MockedOpen(requested_path, mode="rb", buffering=-1): try: - fixture_path = os.path.join(self.base_path, "VFSFixture", - requested_path.lstrip("/")) - return builtins.open.old_target(fixture_path, mode) + fixture_path = os.path.join( + self.base_path, "VFSFixture", requested_path.lstrip("/") + ) + return builtins.open.old_target( + fixture_path, mode=mode, buffering=buffering + ) except IOError: - return builtins.open.old_target(requested_path, mode) - - with utils.MultiStubber((builtins, "open", MockedOpen), - (glob, "glob", lambda x: ["/var/log/wtmp"])): + return builtins.open.old_target( + requested_path, mode=mode, buffering=buffering + ) + + with utils.MultiStubber( + (builtins, "open", MockedOpen), + (glob, "glob", lambda x: ["/var/log/wtmp"]), + ): results = self.RunAction(linux.EnumerateUsers) found = 0 @@ -66,7 +73,8 @@ def MockCheckMounts(unused_filename): expected = rdf_client_fs.Filesystem( mount_point="/", type="ext4", - device="/dev/mapper/dhcp--100--104--9--24--vg-root") + device="/dev/mapper/dhcp--100--104--9--24--vg-root", + ) self.assertLen(results, 2) for result in results: diff --git a/grr/client/grr_response_client/client_actions/memory.py b/grr/client/grr_response_client/client_actions/memory.py index 1bd7fb65ac..047db16c26 100644 --- a/grr/client/grr_response_client/client_actions/memory.py +++ b/grr/client/grr_response_client/client_actions/memory.py @@ -10,12 +10,15 @@ import platform import re import shutil +from typing import Any from typing import Callable from typing import Dict +from typing import IO from typing import Iterable from typing import Iterator from typing import List from typing import Optional +from typing import Sequence import psutil import yara @@ -36,8 +39,14 @@ from grr_response_core.lib.rdfvalues import paths as rdf_paths -def ProcessIterator(pids, process_regex_string, cmdline_regex_string, - ignore_grr_process, error_list): +def ProcessIterator( + pids: Iterable[int], + process_regex_string: Optional[str], + cmdline_regex_string: Optional[str], + ignore_grr_process: bool, + ignore_parent_processes: bool, + error_list: list[rdf_memory.ProcessMemoryError], +) -> Iterator[psutil.Process]: """Yields all (psutil-) processes that match certain criteria. Args: @@ -48,6 +57,8 @@ def ProcessIterator(pids, process_regex_string, cmdline_regex_string, cmdline_regex_string: If given, only processes whose cmdline matches the regex are returned. ignore_grr_process: If True, the grr process itself will not be returned. + ignore_parent_processes: Whether to skip scanning all parent processes of + the GRR agent. 
error_list: All errors while handling processes are appended to this list. Type is repeated ProcessMemoryError. @@ -55,10 +66,12 @@ def ProcessIterator(pids, process_regex_string, cmdline_regex_string, psutils.Process objects matching all criteria. """ pids = set(pids) + + ignore_pids: set[int] = set() if ignore_grr_process: - grr_pid = psutil.Process().pid - else: - grr_pid = -1 + ignore_pids.add(psutil.Process().pid) + if ignore_parent_processes: + ignore_pids.update(_.pid for _ in psutil.Process().parents()) if process_regex_string: process_regex = re.compile(process_regex_string) @@ -76,14 +89,20 @@ def ProcessIterator(pids, process_regex_string, cmdline_regex_string, try: process_iterator.append(psutil.Process(pid=pid)) except Exception as e: # pylint: disable=broad-except - error_list.Append( + error_list.append( rdf_memory.ProcessMemoryError( - process=rdf_client.Process(pid=pid), error=str(e))) + process=rdf_client.Process(pid=pid), + error=str(e), + ) + ) else: process_iterator = psutil.process_iter() for p in process_iterator: + if p.pid in ignore_pids: + continue + try: process_name = p.name() except ( @@ -113,12 +132,26 @@ def ProcessIterator(pids, process_regex_string, cmdline_regex_string, if cmdline_regex and not cmdline_regex.search(" ".join(cmdline)): continue - if p.pid == grr_pid: - continue - yield p +def _ShouldIncludeError( + policy: rdf_memory.YaraProcessScanRequest.ErrorPolicy, + error: rdf_memory.ProcessMemoryError, +) -> bool: + """Returns whether the error should be included in the flow response.""" + + if policy == rdf_memory.YaraProcessScanRequest.ErrorPolicy.NO_ERRORS: + return False + + if policy == rdf_memory.YaraProcessScanRequest.ErrorPolicy.CRITICAL_ERRORS: + msg = error.error.lower() + return "failed to open process" not in msg and "access denied" not in msg + + # Fall back to including all errors. + return True + + class YaraWrapperError(Exception): pass @@ -135,16 +168,18 @@ class YaraWrapper(abc.ABC): """Wraps the Yara library.""" @abc.abstractmethod - def Match(self, process, chunks: Iterable[streaming.Chunk], - deadline: rdfvalue.RDFDatetime, - progress: Callable[[], None]) -> Iterator[rdf_memory.YaraMatch]: + def Match( + self, + process, + chunks: Iterable[streaming.Chunk], + deadline: rdfvalue.RDFDatetime, + ) -> Iterator[rdf_memory.YaraMatch]: """Matches the rules in this instance against a chunk of process memory. Args: process: A process opened by `client_utils.OpenProcessForMemoryAccess`. chunks: Chunks to match. The chunks doesn't have `data` set. deadline: Deadline for the match. - progress: Progress callback. Yields: Matches matching the rules. @@ -172,29 +207,42 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: class DirectYaraWrapper(YaraWrapper): """Wrapper for the YARA library.""" - def __init__(self, rules_str: str): + def __init__( + self, rules_str: str, progress: Callable[[], None], context_window: int + ): """Constructor. Args: rules_str: The YARA rules represented as string. + progress: A progress callback + context_window: The amount of bytes to store before and after the match. 
""" self._rules_str = rules_str self._rules: Optional[yara.Rules] = None - - def Match(self, process, chunks: Iterable[streaming.Chunk], - deadline: rdfvalue.RDFDatetime, - progress: Callable[[], None]) -> Iterator[rdf_memory.YaraMatch]: + self._progress = progress + self._context_window: int = context_window + + def Match( + self, + process, + chunks: Iterable[streaming.Chunk], + deadline: rdfvalue.RDFDatetime, + ) -> Iterator[rdf_memory.YaraMatch]: for chunk in chunks: for match in self._MatchChunk(process, chunk, deadline): yield match - progress() + self._progress() def _MatchChunk( - self, process, chunk: streaming.Chunk, - deadline: rdfvalue.RDFDatetime) -> Iterator[rdf_memory.YaraMatch]: + self, + process, + chunk: streaming.Chunk, + deadline: rdfvalue.RDFDatetime, + ) -> Iterator[rdf_memory.YaraMatch]: """Matches one chunk of memory.""" timeout_secs = (deadline - rdfvalue.RDFDatetime.Now()).ToInt( - rdfvalue.SECONDS) + rdfvalue.SECONDS + ) if self._rules is None: self._rules = yara.compile(source=self._rules_str) data = process.ReadBytes(chunk.offset, chunk.amount) @@ -209,7 +257,9 @@ def _MatchChunk( for offset, _, s in m.strings: if offset + len(s) > chunk.overlap: # We haven't seen this match before. - rdf_match = rdf_memory.YaraMatch.FromLibYaraMatch(m) + rdf_match = rdf_memory.YaraMatch.FromLibYaraMatch( + m, data, self._context_window + ) for string_match in rdf_match.string_matches: string_match.offset += chunk.offset yield rdf_match @@ -249,16 +299,20 @@ def Open(self) -> None: for psutil_process in self._psutil_processes: try: process = stack.enter_context( - client_utils.OpenProcessForMemoryAccess(psutil_process.pid)) + client_utils.OpenProcessForMemoryAccess(psutil_process.pid) + ) except Exception as e: # pylint: disable=broad-except # OpenProcessForMemoryAccess can raise any exception upon error. 
self._pid_to_exception[psutil_process.pid] = e continue - self._pid_to_serializable_file_descriptor[ - psutil_process.pid] = process.serialized_file_descriptor + self._pid_to_serializable_file_descriptor[psutil_process.pid] = ( + process.serialized_file_descriptor + ) file_descriptors.append( communication.FileDescriptor.FromSerialized( - process.serialized_file_descriptor, communication.Mode.READ)) + process.serialized_file_descriptor, communication.Mode.READ + ) + ) self._server = memory_server.CreateMemoryServer(file_descriptors) self._server.Start() self._client = memory_client.Client(self._server.Connect()) @@ -270,11 +324,16 @@ def Close(self) -> None: def ContainsPid(self, pid: int) -> bool: return pid in self._pids - def Match(self, process, chunks: Iterable[streaming.Chunk], - deadline: rdfvalue.RDFDatetime, - progress: Callable[[], None]) -> Iterator[rdf_memory.YaraMatch]: + def Match( + self, + process, + chunks: Iterable[streaming.Chunk], + deadline: rdfvalue.RDFDatetime, + context_window: int = 0, + ) -> Iterator[rdf_memory.YaraMatch]: timeout_secs = (deadline - rdfvalue.RDFDatetime.Now()).ToInt( - rdfvalue.SECONDS) + rdfvalue.SECONDS + ) if self._client is None: raise ValueError("Client not instantiated.") if not self._rules_uploaded: @@ -292,45 +351,57 @@ def Match(self, process, chunks: Iterable[streaming.Chunk], chunk.offset: chunk.offset + chunk.overlap for chunk in chunks } response = self._client.ProcessScan( - self._pid_to_serializable_file_descriptor[process.pid], chunks_pb, - timeout_secs) + self._pid_to_serializable_file_descriptor[process.pid], + chunks_pb, + timeout_secs, + context_window, + ) if response.status == memory_pb2.ProcessScanResponse.Status.NO_ERROR: - return self._ScanResultToYaraMatches(response.scan_result, - overlap_end_map) - elif (response.status == - memory_pb2.ProcessScanResponse.Status.TOO_MANY_MATCHES): - raise TooManyMatchesError() + return self._ScanResultToYaraMatches( + response.scan_result, overlap_end_map + ) elif ( - response.status == memory_pb2.ProcessScanResponse.Status.TIMEOUT_ERROR): + response.status + == memory_pb2.ProcessScanResponse.Status.TOO_MANY_MATCHES + ): + raise TooManyMatchesError() + elif response.status == memory_pb2.ProcessScanResponse.Status.TIMEOUT_ERROR: raise YaraTimeoutError() else: raise YaraWrapperError() def _ScanResultToYaraMatches( - self, scan_result: memory_pb2.ScanResult, - overlap_end_map: Dict[int, int]) -> Iterator[rdf_memory.YaraMatch]: + self, scan_result: memory_pb2.ScanResult, overlap_end_map: Dict[int, int] + ) -> Iterator[rdf_memory.YaraMatch]: """Converts a scan result from protobuf to RDF.""" for rule_match in scan_result.scan_match: rdf_match = self._RuleMatchToYaraMatch(rule_match) - for string_match, rdf_string_match in zip(rule_match.string_matches, - rdf_match.string_matches): - if rdf_string_match.offset + len( - rdf_string_match.data) > overlap_end_map[string_match.chunk_offset]: + for string_match, rdf_string_match in zip( + rule_match.string_matches, rdf_match.string_matches + ): + if ( + rdf_string_match.offset + len(rdf_string_match.data) + > overlap_end_map[string_match.chunk_offset] + ): yield rdf_match break def _RuleMatchToYaraMatch( - self, rule_match: memory_pb2.RuleMatch) -> rdf_memory.YaraMatch: + self, rule_match: memory_pb2.RuleMatch + ) -> rdf_memory.YaraMatch: result = rdf_memory.YaraMatch() if rule_match.HasField("rule_name"): result.rule_name = rule_match.rule_name for string_match in rule_match.string_matches: result.string_matches.append( - 
self._StringMatchToYaraStringMatch(string_match)) + self._StringMatchToYaraStringMatch(string_match) + ) return result def _StringMatchToYaraStringMatch( - self, string_match: memory_pb2.StringMatch) -> rdf_memory.YaraStringMatch: + self, string_match: memory_pb2.StringMatch + ) -> rdf_memory.YaraStringMatch: + """Builds a YaraStringMatch from a StringMatch proto object.""" result = rdf_memory.YaraStringMatch() if string_match.HasField("string_id"): result.string_id = string_match.string_id @@ -338,6 +409,8 @@ def _StringMatchToYaraStringMatch( result.offset = string_match.offset if string_match.HasField("data"): result.data = string_match.data + if string_match.HasField("context"): + result.context = string_match.context return result @@ -353,32 +426,43 @@ class BatchedUnprivilegedYaraWrapper(YaraWrapper): # Windows has a limit of 10k handles per process. BATCH_SIZE = 512 - def __init__(self, rules_str: str, psutil_processes: List[psutil.Process]): + def __init__( + self, + rules_str: str, + psutil_processes: List[psutil.Process], + context_window: Optional[int] = None, + ): """Constructor. Args: rules_str: The YARA rules represented as string. psutil_processes: List of processes that can be scanned using `Match`. + context_window: Amount of bytes surrounding the match to return. """ self._batches: List[UnprivilegedYaraWrapper] = [] for i in range(0, len(psutil_processes), self.BATCH_SIZE): - process_batch = psutil_processes[i:i + self.BATCH_SIZE] + process_batch = psutil_processes[i : i + self.BATCH_SIZE] self._batches.append(UnprivilegedYaraWrapper(rules_str, process_batch)) self._current_batch = self._batches.pop(0) - - def Match(self, process, chunks: Iterable[streaming.Chunk], - deadline: rdfvalue.RDFDatetime, - progress: Callable[[], None]) -> Iterator[rdf_memory.YaraMatch]: + self._context_window = context_window or 0 + + def Match( + self, + process, + chunks: Iterable[streaming.Chunk], + deadline: rdfvalue.RDFDatetime, + ) -> Iterator[rdf_memory.YaraMatch]: if not self._current_batch.ContainsPid(process.pid): while True: if not self._batches: raise ValueError( "`_batches` is empty. " "Processes must be passed to `Match` in the same order as they " - "appear in `psutil_processes`.") + "appear in `psutil_processes`." 
+ ) if self._batches[0].ContainsPid(process.pid): break self._batches.pop(0) @@ -386,7 +470,9 @@ def Match(self, process, chunks: Iterable[streaming.Chunk], self._current_batch = self._batches.pop(0) self._current_batch.Open() - yield from self._current_batch.Match(process, chunks, deadline, progress) + yield from self._current_batch.Match( + process, chunks, deadline, self._context_window + ) def Open(self) -> None: self._current_batch.Open() @@ -395,23 +481,52 @@ def Close(self) -> None: self._current_batch.Close() -class YaraProcessScan(actions.ActionPlugin): - """Scans the memory of a number of processes using Yara.""" - in_rdfvalue = rdf_memory.YaraProcessScanRequest - out_rdfvalues = [rdf_memory.YaraProcessScanResponse] +class YaraScanRequestMatcher: + """Applies the yara matching function to a process under constraints of a scan_request.""" MAX_BATCH_SIZE_CHUNKS = 100 - def __init__(self, grr_worker=None): - super().__init__(grr_worker=grr_worker) - self._yara_wrapper = None + def __init__(self, yara_wrapper: YaraWrapper) -> None: + self._yara_wrapper = yara_wrapper + + def GetMatchesForProcess( + self, + psutil_process: psutil.Process, + scan_request: rdf_memory.YaraProcessScanRequest, + ) -> Sequence[rdf_memory.YaraMatch]: + """Scans the memory of a process, applies scan_request constraints.""" + + if scan_request.per_process_timeout: + deadline = rdfvalue.RDFDatetime.Now() + scan_request.per_process_timeout + else: + deadline = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( + 1, rdfvalue.WEEKS + ) + + process = client_utils.OpenProcessForMemoryAccess(pid=psutil_process.pid) + with process: + matches = [] + + try: + for chunks in self._BatchIterateRegions(process, scan_request): + for m in self._ScanRegion(process, chunks, deadline): + matches.append(m) + if 0 < scan_request.max_results_per_process <= len(matches): + return matches + except TooManyMatchesError: + # We need to report this as a hit, not an error. + return matches + + return matches def _ScanRegion( - self, process, chunks: Iterable[streaming.Chunk], - deadline: rdfvalue.RDFDatetime) -> Iterator[rdf_memory.YaraMatch]: + self, + process, + chunks: Iterable[streaming.Chunk], + deadline: rdfvalue.RDFDatetime, + ) -> Iterator[rdf_memory.YaraMatch]: assert self._yara_wrapper is not None - yield from self._yara_wrapper.Match(process, chunks, deadline, - self.Progress) + yield from self._yara_wrapper.Match(process, chunks, deadline) # Windows has 1000-2000 regions per process. # There a lot of small regions consiting of 1 chunk only. 
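Scanning is driven in batches of memory chunks rather than region by region: as the next hunk shows, _BatchIterateRegions flushes a batch once it hits MAX_BATCH_SIZE_CHUNKS or a byte threshold taken from the scan request. A simplified, self-contained sketch of that batching strategy follows; the Chunk dataclass and the fixed byte budget are stand-ins for streaming.Chunk and the request-derived limit.

from dataclasses import dataclass
from typing import Iterable, Iterator, List


@dataclass
class Chunk:
  offset: int
  amount: int  # Number of bytes this chunk covers.


MAX_BATCH_SIZE_CHUNKS = 100  # Mirrors the class constant above.
MAX_BATCH_SIZE_BYTES = 50 << 20  # Hypothetical stand-in for the request limit.


def BatchChunks(chunks: Iterable[Chunk]) -> Iterator[List[Chunk]]:
  """Groups chunks so that each batch stays within the count and byte caps."""
  batch: List[Chunk] = []
  batch_bytes = 0
  for chunk in chunks:
    batch.append(chunk)
    batch_bytes += chunk.amount
    if len(batch) >= MAX_BATCH_SIZE_CHUNKS or batch_bytes >= MAX_BATCH_SIZE_BYTES:
      yield batch
      batch, batch_bytes = [], 0
  if batch:  # Flush the final partial batch.
    yield batch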
@@ -422,9 +537,11 @@ def _ScanRegion( def _BatchIterateRegions( self, process, scan_request: rdf_memory.YaraProcessScanRequest ) -> Iterator[List[streaming.Chunk]]: + """Iterates over regions of a process.""" streamer = streaming.Streamer( chunk_size=scan_request.chunk_size, - overlap_size=scan_request.overlap_size) + overlap_size=scan_request.overlap_size, + ) batch = [] batch_size_bytes = 0 for region in client_utils.MemoryRegions(process, scan_request): @@ -432,69 +549,73 @@ def _BatchIterateRegions( for chunk in chunks: batch.append(chunk) batch_size_bytes += chunk.amount - if (len(batch) >= self.MAX_BATCH_SIZE_CHUNKS or - batch_size_bytes >= scan_request.chunk_size): + if ( + len(batch) >= self.MAX_BATCH_SIZE_CHUNKS + or batch_size_bytes >= scan_request.chunk_size + ): yield batch batch = [] batch_size_bytes = 0 if batch: yield batch - def _GetMatches(self, psutil_process, scan_request): - if scan_request.per_process_timeout: - deadline = rdfvalue.RDFDatetime.Now() + scan_request.per_process_timeout - else: - deadline = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( - 1, rdfvalue.WEEKS) - process = client_utils.OpenProcessForMemoryAccess(pid=psutil_process.pid) - with process: - matches = [] - - try: - for chunks in self._BatchIterateRegions(process, scan_request): - for m in self._ScanRegion(process, chunks, deadline): - matches.append(m) - if 0 < scan_request.max_results_per_process <= len(matches): - return matches - except TooManyMatchesError: - # We need to report this as a hit, not an error. - return matches +class YaraProcessScan(actions.ActionPlugin): + """Scans the memory of a number of processes using Yara.""" - return matches + in_rdfvalue = rdf_memory.YaraProcessScanRequest + out_rdfvalues = [rdf_memory.YaraProcessScanResponse] # We don't want individual response messages to get too big so we send # multiple responses for 100 processes each. _RESULTS_PER_RESPONSE = 100 - def _ScanProcess(self, process, scan_request, scan_response): + def __init__(self, grr_worker=None): + super().__init__(grr_worker=grr_worker) + self._yara_process_matcher = None + + def _ScanProcess( + self, + process: psutil.Process, + scan_request: rdf_memory.YaraProcessScanRequest, + scan_response: rdf_memory.YaraProcessScanResponse, + matcher: YaraScanRequestMatcher, + ) -> None: rdf_process = rdf_client.Process.FromPsutilProcess(process) start_time = rdfvalue.RDFDatetime.Now() try: - matches = self._GetMatches(process, scan_request) + matches = matcher.GetMatchesForProcess(process, scan_request) scan_time = rdfvalue.RDFDatetime.Now() - start_time scan_time_us = scan_time.ToInt(rdfvalue.MICROSECONDS) except YaraTimeoutError: - scan_response.errors.Append( - rdf_memory.ProcessMemoryError( - process=rdf_process, - error="Scanning timed out (%s)." % - (rdfvalue.RDFDatetime.Now() - start_time))) + err = rdf_memory.ProcessMemoryError( + process=rdf_process, + error="Scanning timed out (%s)." 
+ % (rdfvalue.RDFDatetime.Now() - start_time), + ) + if _ShouldIncludeError(scan_request.include_errors_in_results, err): + scan_response.errors.Append(err) return except Exception as e: # pylint: disable=broad-except - scan_response.errors.Append( - rdf_memory.ProcessMemoryError(process=rdf_process, error=str(e))) + err = rdf_memory.ProcessMemoryError(process=rdf_process, error=str(e)) + if _ShouldIncludeError(scan_request.include_errors_in_results, err): + scan_response.errors.Append(err) return if matches: scan_response.matches.Append( rdf_memory.YaraProcessScanMatch( - process=rdf_process, match=matches, scan_time_us=scan_time_us)) + process=rdf_process, match=matches, scan_time_us=scan_time_us + ) + ) else: - scan_response.misses.Append( - rdf_memory.YaraProcessScanMiss( - process=rdf_process, scan_time_us=scan_time_us)) + if scan_request.include_misses_in_results: + scan_response.misses.Append( + rdf_memory.YaraProcessScanMiss( + process=rdf_process, scan_time_us=scan_time_us + ) + ) def _SaveSignatureShard(self, scan_request): """Writes a YaraSignatureShard received from the server to disk. @@ -510,14 +631,19 @@ def _SaveSignatureShard(self, scan_request): def GetShardName(shard_index, num_shards): return "shard_%02d_of_%02d" % (shard_index, num_shards) - signature_dir = os.path.join(tempfiles.GetDefaultGRRTempDirectory(), - "Sig_%s" % self.session_id.Basename()) + signature_dir = os.path.join( + tempfiles.GetDefaultGRRTempDirectory(), + "Sig_%s" % self.session_id.Basename(), + ) # Create the temporary directory and set permissions, if it does not exist. tempfiles.EnsureTempDirIsSane(signature_dir) shard_path = os.path.join( signature_dir, - GetShardName(scan_request.signature_shard.index, - scan_request.num_signature_shards)) + GetShardName( + scan_request.signature_shard.index, + scan_request.num_signature_shards, + ), + ) with io.open(shard_path, "wb") as f: f.write(scan_request.signature_shard.payload) @@ -541,7 +667,8 @@ def GetShardName(shard_index, num_shards): def Run(self, args): if args.yara_signature or not args.signature_shard.payload: raise ValueError( - "A Yara signature shard is required, and not the full signature.") + "A Yara signature shard is required, and not the full signature." + ) if args.num_signature_shards == 1: # Skip saving to disk if there is just one shard. @@ -556,33 +683,55 @@ def Run(self, args): scan_request.yara_signature = yara_signature scan_response = rdf_memory.YaraProcessScanResponse() processes = list( - ProcessIterator(scan_request.pids, scan_request.process_regex, - scan_request.cmdline_regex, - scan_request.ignore_grr_process, scan_response.errors)) + ProcessIterator( + scan_request.pids, + scan_request.process_regex, + scan_request.cmdline_regex, + scan_request.ignore_grr_process, + scan_request.ignore_parent_processes, + scan_response.errors, + ) + ) if not processes: - scan_response.errors.Append( - rdf_memory.ProcessMemoryError(error="No matching processes to scan.") + err = rdf_memory.ProcessMemoryError( + error="No matching processes to scan." 
) + if _ShouldIncludeError(scan_request.include_errors_in_results, err): + scan_response.errors.Append(err) self.SendReply(scan_response) return if self._UseSandboxing(args): - self._yara_wrapper: YaraWrapper = BatchedUnprivilegedYaraWrapper( - str(scan_request.yara_signature), processes) + yara_wrapper: YaraWrapper = BatchedUnprivilegedYaraWrapper( + str(scan_request.yara_signature), + processes, + scan_request.context_window, + ) else: - self._yara_wrapper: YaraWrapper = DirectYaraWrapper( - str(scan_request.yara_signature)) + yara_wrapper: YaraWrapper = DirectYaraWrapper( + str(scan_request.yara_signature), + self.Progress, + scan_request.context_window, + ) - with self._yara_wrapper: + with yara_wrapper: + matcher = YaraScanRequestMatcher(yara_wrapper) for process in processes: self.Progress() num_results = ( - len(scan_response.errors) + len(scan_response.matches) + - len(scan_response.misses)) + len(scan_response.errors) + + len(scan_response.matches) + + len(scan_response.misses) + ) if num_results >= self._RESULTS_PER_RESPONSE: self.SendReply(scan_response) scan_response = rdf_memory.YaraProcessScanResponse() - self._ScanProcess(process, scan_request, scan_response) + self._ScanProcess( + process, + scan_request, + scan_response, + matcher, + ) self.SendReply(scan_response) @@ -590,11 +739,15 @@ def _UseSandboxing(self, args: rdf_memory.YaraProcessScanRequest) -> bool: # Memory sandboxing is currently not supported on macOS. if platform.system() == "Darwin": return False - if (args.implementation_type == - rdf_memory.YaraProcessScanRequest.ImplementationType.DIRECT): + if ( + args.implementation_type + == rdf_memory.YaraProcessScanRequest.ImplementationType.DIRECT + ): return False - elif (args.implementation_type == - rdf_memory.YaraProcessScanRequest.ImplementationType.SANDBOX): + elif ( + args.implementation_type + == rdf_memory.YaraProcessScanRequest.ImplementationType.SANDBOX + ): return True else: return config.CONFIG["Client.use_memory_sandboxing"] @@ -602,7 +755,7 @@ def _UseSandboxing(self, args: rdf_memory.YaraProcessScanRequest) -> bool: def _PrioritizeRegions( regions: Iterable[rdf_memory.ProcessMemoryRegion], - prioritize_offsets: Iterable[int] + prioritize_offsets: Iterable[int], ) -> Iterable[rdf_memory.ProcessMemoryRegion]: """Returns reordered `regions` to prioritize regions containing offsets. 
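The hunk above only reformats the signature of `_PrioritizeRegions`; per its docstring, the function reorders memory regions so that regions containing previously matched offsets are scanned first. A minimal standalone sketch of that idea, using a hypothetical `Region` tuple instead of `rdf_memory.ProcessMemoryRegion`, looks roughly like this:

# Minimal sketch, not GRR's implementation: approximates the reordering that
# the _PrioritizeRegions docstring describes. `Region` is a hypothetical
# stand-in for rdf_memory.ProcessMemoryRegion.
from typing import Iterable, List, NamedTuple


class Region(NamedTuple):
  start: int
  size: int


def PrioritizeRegions(
    regions: Iterable[Region],
    prioritize_offsets: Iterable[int],
) -> List[Region]:
  """Moves regions containing any of the given offsets to the front."""
  offsets = list(prioritize_offsets)
  prio: List[Region] = []
  nonprio: List[Region] = []
  for region in regions:
    end = region.start + region.size
    if any(region.start <= offset < end for offset in offsets):
      prio.append(region)
    else:
      nonprio.append(region)
  # A stable partition: relative order inside each group is preserved.
  return prio + nonprio


# Offset 0x2500 falls inside the second region, so it is moved to the front.
assert PrioritizeRegions(
    [Region(0x1000, 0x1000), Region(0x2000, 0x1000)], [0x2500]
) == [Region(0x2000, 0x1000), Region(0x1000, 0x1000)]

The real function (its return statement is visible in the next hunk) additionally appends whatever is left in the region iterator, i.e. `prio_regions + nonprio_regions + list(all_regions)`; the sketch above drops that detail.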
@@ -666,8 +819,9 @@ def _PrioritizeRegions( return prio_regions + nonprio_regions + list(all_regions) # pytype: disable=bad-return-type -def _ApplySizeLimit(regions: Iterable[rdf_memory.ProcessMemoryRegion], - size_limit: int) -> List[rdf_memory.ProcessMemoryRegion]: +def _ApplySizeLimit( + regions: Iterable[rdf_memory.ProcessMemoryRegion], size_limit: int +) -> List[rdf_memory.ProcessMemoryRegion]: """Truncates regions so that the total size stays in size_limit.""" total_size = 0 regions_in_limit = [] @@ -682,10 +836,15 @@ def _ApplySizeLimit(regions: Iterable[rdf_memory.ProcessMemoryRegion], class YaraProcessDump(actions.ActionPlugin): """Dumps a process to disk and returns pathspecs for GRR to pick up.""" + in_rdfvalue = rdf_memory.YaraProcessDumpArgs out_rdfvalues = [rdf_memory.YaraProcessDumpResponse] - def _SaveMemDumpToFile(self, fd, chunks): + def _SaveMemDumpToFile( + self, + fd: IO[bytes], + chunks: Iterator[streaming.Chunk], + ) -> int: bytes_written = 0 for chunk in chunks: @@ -697,7 +856,11 @@ def _SaveMemDumpToFile(self, fd, chunks): return bytes_written - def _SaveMemDumpToFilePath(self, filename, chunks): + def _SaveMemDumpToFilePath( + self, + filename: str, + chunks: Iterator[streaming.Chunk], + ) -> int: with open(filename, "wb") as fd: bytes_written = self._SaveMemDumpToFile(fd, chunks) @@ -710,19 +873,30 @@ def _SaveMemDumpToFilePath(self, filename, chunks): return bytes_written - def _SaveRegionToDirectory(self, psutil_process, process, region, tmp_dir, - streamer): + def _SaveRegionToDirectory( + self, + psutil_process: psutil.Process, + process: Any, # Each platform uses a specific type without common base. + region: rdf_memory.ProcessMemoryRegion, + tmp_dir: tempfiles.TemporaryDirectory, + streamer: streaming.Streamer, + ) -> Optional[rdf_paths.PathSpec]: end = region.start + region.size # _ReplaceDumpPathspecsWithMultiGetFilePathspec in DumpProcessMemory # flow asserts that MemoryRegions can be uniquely identified by their # file's basename. - filename = "%s_%d_%x_%x.tmp" % (psutil_process.name(), psutil_process.pid, - region.start, end) + filename = "%s_%d_%x_%x.tmp" % ( + psutil_process.name(), + psutil_process.pid, + region.start, + end, + ) filepath = os.path.join(tmp_dir.path, filename) chunks = streamer.StreamMemory( - process, offset=region.start, amount=region.dumped_size) + process, offset=region.start, amount=region.dumped_size + ) bytes_written = self._SaveMemDumpToFilePath(filepath, chunks) if not bytes_written: @@ -734,9 +908,14 @@ def _SaveRegionToDirectory(self, psutil_process, process, region, tmp_dir, canonical_path = "/" + canonical_path return rdf_paths.PathSpec( - path=canonical_path, pathtype=rdf_paths.PathSpec.PathType.TMPFILE) - - def DumpProcess(self, psutil_process, args): + path=canonical_path, pathtype=rdf_paths.PathSpec.PathType.TMPFILE + ) + + def DumpProcess( + self, + psutil_process: psutil.Process, + args: rdf_memory.YaraProcessScanRequest, + ) -> rdf_memory.YaraProcessDumpInformation: response = rdf_memory.YaraProcessDumpInformation() response.process = rdf_client.Process.FromPsutilProcess(psutil_process) streamer = streaming.Streamer(chunk_size=args.chunk_size) @@ -751,8 +930,9 @@ def DumpProcess(self, psutil_process, args): total_regions = len(regions) regions = _ApplySizeLimit(regions, args.size_limit) if len(regions) < total_regions: - response.error = ("Byte limit exceeded. Writing {} of {} " - "regions.").format(len(regions), total_regions) + response.error = ( + "Byte limit exceeded. Writing {} of {} regions." 
+ ).format(len(regions), total_regions) else: for region in regions: region.dumped_size = region.size @@ -762,25 +942,36 @@ def DumpProcess(self, psutil_process, args): with tempfiles.TemporaryDirectory(cleanup=False) as tmp_dir: for region in regions: self.Progress() - pathspec = self._SaveRegionToDirectory(psutil_process, process, - region, tmp_dir, streamer) + pathspec = self._SaveRegionToDirectory( + psutil_process, process, region, tmp_dir, streamer + ) if pathspec is not None: region.file = pathspec response.memory_regions.Append(region) return response - def Run(self, args): + def Run( + self, + args: rdf_memory.YaraProcessScanRequest, + ) -> None: if args.prioritize_offsets and len(args.pids) != 1: raise ValueError( "Supplied prioritize_offsets {} for PIDs {} in YaraProcessDump. " - "Required exactly one PID.".format(args.prioritize_offsets, - args.pids)) + "Required exactly one PID.".format(args.prioritize_offsets, args.pids) + ) result = rdf_memory.YaraProcessDumpResponse() - - for p in ProcessIterator(args.pids, args.process_regex, None, - args.ignore_grr_process, result.errors): + errors = [] + + for p in ProcessIterator( + args.pids, + args.process_regex, + None, + args.ignore_grr_process, + args.ignore_parent_processes, + errors, + ): self.Progress() start = rdfvalue.RDFDatetime.Now() @@ -793,9 +984,13 @@ def Run(self, args): # Limit exceeded, we bail out early. break except Exception as e: # pylint: disable=broad-except - result.errors.Append( + errors.append( rdf_memory.ProcessMemoryError( - process=rdf_client.Process.FromPsutilProcess(p), error=str(e))) + process=rdf_client.Process.FromPsutilProcess(p), + error=str(e), + ) + ) continue + result.errors = errors self.SendReply(result) diff --git a/grr/client/grr_response_client/client_actions/memory_test.py b/grr/client/grr_response_client/client_actions/memory_test.py index 8b39d9f96a..e4563333f8 100644 --- a/grr/client/grr_response_client/client_actions/memory_test.py +++ b/grr/client/grr_response_client/client_actions/memory_test.py @@ -25,6 +25,16 @@ def setUp(self): config_overrider.Start() self.addCleanup(config_overrider.Stop) + patcher = mock.patch.object( + psutil, + "process_iter", + return_value=[ + Process(0, "foo"), + ], + ) + patcher.start() + self.addCleanup(patcher.stop) + def testSignatureShards_Multiple(self): requests = [ rdf_memory.YaraProcessScanRequest( @@ -32,18 +42,24 @@ def testSignatureShards_Multiple(self): index=0, payload=b"123" ), num_signature_shards=3, + include_misses_in_results=True, + include_errors_in_results=rdf_memory.YaraProcessScanRequest.ErrorPolicy.ALL_ERRORS, ), rdf_memory.YaraProcessScanRequest( signature_shard=rdf_memory.YaraSignatureShard( index=1, payload=b"456" ), num_signature_shards=3, + include_misses_in_results=True, + include_errors_in_results=rdf_memory.YaraProcessScanRequest.ErrorPolicy.ALL_ERRORS, ), rdf_memory.YaraProcessScanRequest( signature_shard=rdf_memory.YaraSignatureShard( index=2, payload=b"789" ), num_signature_shards=3, + include_misses_in_results=True, + include_errors_in_results=rdf_memory.YaraProcessScanRequest.ErrorPolicy.ALL_ERRORS, ), ] flow_id = "01234567" @@ -91,6 +107,8 @@ def testSignatureShards_Single(self): scan_request = rdf_memory.YaraProcessScanRequest( signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), num_signature_shards=1, + include_misses_in_results=True, + include_errors_in_results=rdf_memory.YaraProcessScanRequest.ErrorPolicy.ALL_ERRORS, ) results = self.ExecuteAction( @@ -111,12 +129,14 @@ def 
testRaisesWhenNoMatchingProcesses(self): signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), num_signature_shards=1, process_regex=invalid_process_regex, + include_misses_in_results=True, + include_errors_in_results=rdf_memory.YaraProcessScanRequest.ErrorPolicy.ALL_ERRORS, ) results = self.ExecuteAction( memory.YaraProcessScan, arg=scan_request, session_id=session_id ) - print(results) + self.assertGreater(len(results), 1) self.assertIsInstance(results[0], rdf_memory.YaraProcessScanResponse) self.assertLen(results[0].errors, 1) @@ -125,6 +145,142 @@ def testRaisesWhenNoMatchingProcesses(self): results[0].errors[0].error, "No matching processes to scan." ) + def testCanExcludesMisses(self): + scan_request = rdf_memory.YaraProcessScanRequest( + signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), + num_signature_shards=1, + include_misses_in_results=False, + include_errors_in_results=( + rdf_memory.YaraProcessScanRequest.ErrorPolicy.ALL_ERRORS + ), + ) + yara_process_scan = memory.YaraProcessScan() + scan_response = rdf_memory.YaraProcessScanResponse() + + process_mock = Process(123, "cmd") + mock_process_matcher = mock.create_autospec(memory.YaraScanRequestMatcher) + mock_process_matcher.GetMatchesForProcess.return_value = [] + + yara_process_scan._ScanProcess( + process_mock, scan_request, scan_response, mock_process_matcher + ) + mock_process_matcher.GetMatchesForProcess.assert_called_once_with( + process_mock, scan_request + ) + self.assertEmpty(scan_response.misses) + self.assertEmpty(scan_response.errors) + self.assertEmpty(scan_response.matches) + + def testCanExcludeErrors_WhenNoProcessesAreFound(self): + scan_request = rdf_memory.YaraProcessScanRequest( + signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), + num_signature_shards=1, + include_misses_in_results=True, + include_errors_in_results=( + rdf_memory.YaraProcessScanRequest.ErrorPolicy.NO_ERRORS + ), + ) + with mock.patch.object( + memory, + "ProcessIterator", + return_value=(), + ): + results = self.ExecuteAction( + memory.YaraProcessScan, + arg=scan_request, + session_id="C.0123456789abcdef/01234567", + ) + + self.assertLen(results, 2) + self.assertIsInstance(results[0], rdf_memory.YaraProcessScanResponse) + self.assertEmpty(results[0].misses) + self.assertEmpty(results[0].errors) + self.assertEmpty(results[0].matches) + self.assertIsInstance(results[1], rdf_flows.GrrStatus) + + def testProcessScanCanExcludeErrors_WhenTimeoutErrorIsRaised(self): + scan_request = rdf_memory.YaraProcessScanRequest( + signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), + num_signature_shards=1, + include_misses_in_results=True, + include_errors_in_results=( + rdf_memory.YaraProcessScanRequest.ErrorPolicy.NO_ERRORS + ), + ) + + yara_process_scan = memory.YaraProcessScan() + scan_response = rdf_memory.YaraProcessScanResponse() + + process_mock = Process(123, "cmd") + mock_process_matcher = mock.create_autospec(memory.YaraScanRequestMatcher) + mock_process_matcher.GetMatchesForProcess.side_effect = TimeoutError() + + yara_process_scan._ScanProcess( + process_mock, scan_request, scan_response, mock_process_matcher + ) + mock_process_matcher.GetMatchesForProcess.assert_called_once_with( + process_mock, scan_request + ) + self.assertEmpty(scan_response.misses) + self.assertEmpty(scan_response.errors) + self.assertEmpty(scan_response.matches) + + def testCanExcludeErrors_WhenExceptionIsRaised(self): + scan_request = rdf_memory.YaraProcessScanRequest( + 
signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), + num_signature_shards=1, + include_misses_in_results=True, + include_errors_in_results=( + rdf_memory.YaraProcessScanRequest.ErrorPolicy.NO_ERRORS + ), + ) + yara_process_scan = memory.YaraProcessScan() + scan_response = rdf_memory.YaraProcessScanResponse() + + process_mock = Process(123, "cmd") + mock_process_matcher = mock.create_autospec(memory.YaraScanRequestMatcher) + mock_process_matcher.GetMatchesForProcess.side_effect = Exception("any") + + yara_process_scan._ScanProcess( + process_mock, scan_request, scan_response, mock_process_matcher + ) + mock_process_matcher.GetMatchesForProcess.assert_called_once_with( + process_mock, scan_request + ) + self.assertEmpty(scan_response.misses) + self.assertEmpty(scan_response.errors) + self.assertEmpty(scan_response.matches) + + def testCanExcludeNonCriticalErrors(self): + scan_request = rdf_memory.YaraProcessScanRequest( + signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), + num_signature_shards=1, + include_misses_in_results=True, + include_errors_in_results=( + rdf_memory.YaraProcessScanRequest.ErrorPolicy.CRITICAL_ERRORS + ), + ) + yara_process_scan = memory.YaraProcessScan() + scan_response = rdf_memory.YaraProcessScanResponse() + + process_mock = Process(123, "cmd") + mock_process_matcher = mock.create_autospec(memory.YaraScanRequestMatcher) + # Filtering errors based on lower case string matching, so adding some + # capital letters to `access denied`. + mock_process_matcher.GetMatchesForProcess.side_effect = Exception( + "Any ACCESS Denied is a non-critical error" + ) + + yara_process_scan._ScanProcess( + process_mock, scan_request, scan_response, mock_process_matcher + ) + mock_process_matcher.GetMatchesForProcess.assert_called_once_with( + process_mock, scan_request + ) + self.assertEmpty(scan_response.misses) + self.assertEmpty(scan_response.errors) + self.assertEmpty(scan_response.matches) + def R(start, size): """Returns a new ProcessMemoryRegion with the given start and size.""" @@ -228,6 +384,7 @@ def GetProcessIteratorPids( process_regex_string=None, cmdline_regex_string=None, ignore_grr_process=False, + ignore_parent_processes=False, ): return [ p.pid @@ -236,6 +393,7 @@ def GetProcessIteratorPids( process_regex_string, cmdline_regex_string, ignore_grr_process, + ignore_parent_processes, [], ) ] @@ -283,17 +441,131 @@ def testCmdlineRegex(self): signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), num_signature_shards=1, cmdline_regex="svchost.exe -k def", + include_misses_in_results=True, + include_errors_in_results=rdf_memory.YaraProcessScanRequest.ErrorPolicy.ALL_ERRORS, + ) + yara_process_scan = memory.YaraProcessScan() + scan_response = rdf_memory.YaraProcessScanResponse() + + process_mock = Process(1, "cmd") + mock_process_matcher = mock.create_autospec(memory.YaraScanRequestMatcher) + mock_process_matcher.GetMatchesForProcess.return_value = [ + rdf_memory.YaraMatch() + ] + + yara_process_scan._ScanProcess( + process_mock, scan_request, scan_response, mock_process_matcher + ) + mock_process_matcher.GetMatchesForProcess.assert_called_once_with( + process_mock, scan_request + ) + self.assertEmpty(scan_response.misses) + self.assertEmpty(scan_response.errors) + self.assertLen(scan_response.matches, 1) + self.assertEqual(scan_response.matches[0].process.pid, 1) + + +class ProcessIteratorTest(absltest.TestCase): + + def testNoIgnores(self): + iterator = memory.ProcessIterator( + pids=[], + process_regex_string=None, 
+ cmdline_regex_string=None, + ignore_grr_process=False, + ignore_parent_processes=False, + error_list=[], + ) + + pids = set(_.pid for _ in iterator) + self.assertIn(os.getpid(), pids) + self.assertIn(os.getppid(), pids) + + def testIgnoreGRRProcess(self): + iterator = memory.ProcessIterator( + pids=[], + process_regex_string=None, + cmdline_regex_string=None, + ignore_grr_process=True, + ignore_parent_processes=False, + error_list=[], + ) + + pids = set(_.pid for _ in iterator) + self.assertNotEmpty(pids) + self.assertNotIn(os.getpid(), pids) + + def testIgnoreParentProcesses(self): + iterator = memory.ProcessIterator( + pids=[], + process_regex_string=None, + cmdline_regex_string=None, + ignore_grr_process=False, + ignore_parent_processes=True, + error_list=[], + ) + + pids = set(_.pid for _ in iterator) + self.assertNotEmpty(pids) + self.assertNotIn(os.getppid(), pids) + + +class ParametersTest(client_test_lib.EmptyActionTest): + + def setUp(self): + super().setUp() + patcher = mock.patch.object( + psutil, + "process_iter", + return_value=[ + Process(3, "foo"), + ], + ) + patcher.start() + self.addCleanup(patcher.stop) + + def testContextWindowDefaultValue(self): + scan_request = rdf_memory.YaraProcessScanRequest( + signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), + num_signature_shards=1, ) + with mock.patch.object( + memory.YaraScanRequestMatcher, + "GetMatchesForProcess", + return_value=[ + rdf_memory.YaraMatch( + string_matches=[ + rdf_memory.YaraStringMatch( + string_id="$", + offset=0, + data=b"bla", + context=b"blablabla", + ) + ] + ) + ], + ) as mock_get_matches: + results = self.ExecuteAction(memory.YaraProcessScan, arg=scan_request) + _, scan_request = mock_get_matches.call_args_list[0][0] + self.assertEqual( + results[0].matches[0].match[0].string_matches[0].context, b"blablabla" + ) + self.assertEqual(scan_request.context_window, 50) + def testContextWindowCustomValue(self): + scan_request = rdf_memory.YaraProcessScanRequest( + signature_shard=rdf_memory.YaraSignatureShard(index=0, payload=b"123"), + num_signature_shards=1, + context_window=100, + ) with mock.patch.object( - memory.YaraProcessScan, - "_GetMatches", + memory.YaraScanRequestMatcher, + "GetMatchesForProcess", return_value=[rdf_memory.YaraMatch()], - ): - results = self.ExecuteAction(memory.YaraProcessScan, arg=scan_request) - self.assertLen(results, 2) - self.assertLen(results[0].matches, 1) - self.assertEqual(results[0].matches[0].process.pid, 1) + ) as mock_get_matches: + self.ExecuteAction(memory.YaraProcessScan, arg=scan_request) + _, scan_request = mock_get_matches.call_args_list[0][0] + self.assertEqual(scan_request.context_window, 100) if __name__ == "__main__": diff --git a/grr/client/grr_response_client/client_actions/network.py b/grr/client/grr_response_client/client_actions/network.py index e298013f94..36f1fa03fe 100644 --- a/grr/client/grr_response_client/client_actions/network.py +++ b/grr/client/grr_response_client/client_actions/network.py @@ -2,6 +2,7 @@ """Get Information about network states.""" import logging + import psutil from grr_response_client import actions @@ -37,8 +38,9 @@ def ListNetworkConnectionsFromClient(args): if conn.status: res.state = conn.status except ValueError: - logging.warning("Encountered unknown connection status (%s).", - conn.status) + logging.warning( + "Encountered unknown connection status (%s).", conn.status + ) res.local_address.ip, res.local_address.port = conn.laddr if conn.raddr: @@ -49,6 +51,7 @@ def 
ListNetworkConnectionsFromClient(args): class ListNetworkConnections(actions.ActionPlugin): """Gather open network connection stats.""" + in_rdfvalue = rdf_client_action.ListNetworkConnectionsArgs out_rdfvalues = [rdf_client_network.NetworkConnection] diff --git a/grr/client/grr_response_client/client_actions/network_test.py b/grr/client/grr_response_client/client_actions/network_test.py index d5ceb9c6b6..0bb9b30d22 100644 --- a/grr/client/grr_response_client/client_actions/network_test.py +++ b/grr/client/grr_response_client/client_actions/network_test.py @@ -15,7 +15,8 @@ class NetstatActionTest(client_test_lib.EmptyActionTest): def testListNetworkConnections(self): result = self.RunAction( network.ListNetworkConnections, - arg=rdf_client_action.ListNetworkConnectionsArgs()) + arg=rdf_client_action.ListNetworkConnectionsArgs(), + ) for r in result: self.assertTrue(r.process_name) self.assertTrue(r.local_address) @@ -23,7 +24,8 @@ def testListNetworkConnections(self): def testListNetworkConnectionsFilter(self): result = self.RunAction( network.ListNetworkConnections, - arg=rdf_client_action.ListNetworkConnectionsArgs(listening_only=True)) + arg=rdf_client_action.ListNetworkConnectionsArgs(listening_only=True), + ) for r in result: self.assertTrue(r.process_name) self.assertTrue(r.local_address) diff --git a/grr/client/grr_response_client/client_actions/operating_system.py b/grr/client/grr_response_client/client_actions/operating_system.py index 9cbbf44030..2b32d8bbd3 100644 --- a/grr/client/grr_response_client/client_actions/operating_system.py +++ b/grr/client/grr_response_client/client_actions/operating_system.py @@ -9,12 +9,15 @@ # These imports populate the Action registry if platform.system() == "Linux": from grr_response_client.client_actions.linux import linux + submodule = linux elif platform.system() == "Windows": from grr_response_client.client_actions.windows import windows + submodule = windows elif platform.system() == "Darwin": from grr_response_client.client_actions.osx import osx + submodule = osx else: raise RuntimeError("Unknown platform.system() {!r}".format(platform.system())) diff --git a/grr/client/grr_response_client/client_actions/osquery.py b/grr/client/grr_response_client/client_actions/osquery.py index 59ca7d1713..6433a91dc2 100644 --- a/grr/client/grr_response_client/client_actions/osquery.py +++ b/grr/client/grr_response_client/client_actions/osquery.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with client action for talking with osquery.""" + import json import logging import os @@ -48,12 +49,17 @@ def Process( self, args: rdf_osquery.OsqueryArgs ) -> Iterator[rdf_osquery.OsqueryResult]: if not config.CONFIG["Osquery.path"]: - raise RuntimeError("The `Osquery` action invoked on a client without " - "osquery path specified.") + raise RuntimeError( + "The `Osquery` action invoked on a client without " + "osquery path specified." + ) if not os.path.exists(config.CONFIG["Osquery.path"]): - raise RuntimeError("The `Osquery` action invoked on a client where " - "osquery executable is not available.") + raise RuntimeError( + "The `Osquery` action invoked on a client where " + "the specified osquery executable " + f"({config.CONFIG['Osquery.path']!r}) is not available." 
+ ) if not args.query: raise ValueError("The `Osquery` was invoked with an empty query.") @@ -72,8 +78,9 @@ def Process( yield rdf_osquery.OsqueryResult(table=chunk) -def ChunkTable(table: rdf_osquery.OsqueryTable, - max_chunk_size: int) -> Iterator[rdf_osquery.OsqueryTable]: +def ChunkTable( + table: rdf_osquery.OsqueryTable, max_chunk_size: int +) -> Iterator[rdf_osquery.OsqueryTable]: """Chunks given table into multiple smaller ones. Tables that osquery yields can be arbitrarily large. Because GRR's messages @@ -172,8 +179,9 @@ def ParseHeader(table: Any) -> rdf_osquery.OsqueryHeader: return result -def ParseRow(header: rdf_osquery.OsqueryHeader, - row: Any) -> rdf_osquery.OsqueryRow: +def ParseRow( + header: rdf_osquery.OsqueryHeader, row: Any +) -> rdf_osquery.OsqueryRow: """Parses a single row of osquery output. Args: diff --git a/grr/client/grr_response_client/client_actions/osquery_test.py b/grr/client/grr_response_client/client_actions/osquery_test.py index 73fcae7e03..7f4d6bbba0 100644 --- a/grr/client/grr_response_client/client_actions/osquery_test.py +++ b/grr/client/grr_response_client/client_actions/osquery_test.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with client action for talking with osquery.""" + import hashlib import io import json @@ -30,8 +31,9 @@ def _Query(query: Text, **kwargs) -> List[rdf_osquery.OsqueryResult]: return list(osquery.Osquery().Process(args)) -@skip.Unless(lambda: config.CONFIG["Osquery.path"], - "osquery path not specified") +@skip.Unless( + lambda: config.CONFIG["Osquery.path"], "osquery path not specified" +) class OsqueryTest(absltest.TestCase): @classmethod @@ -131,16 +133,20 @@ def testFile(self): table = results[0].table self.assertLen(table.rows, 3) self.assertEqual( - list(table.Column("path")), [ + list(table.Column("path")), + [ os.path.join(dirpath, "abc"), os.path.join(dirpath, "def"), os.path.join(dirpath, "ghi"), - ]) + ], + ) self.assertEqual(list(table.Column("size")), ["3", "6", "4"]) # TODO(hanuszczak): https://github.com/osquery/osquery/issues/4150 - @skip.If(platform.system() == "Windows", - "osquery ignores files with unicode characters.") + @skip.If( + platform.system() == "Windows", + "osquery ignores files with unicode characters.", + ) def testFileUnicode(self): with temp.AutoTempFilePath(prefix="zółć", suffix="💰") as filepath: with io.open(filepath, "wb") as filedesc: @@ -306,19 +312,28 @@ def testSingleRowChunks(self): self.assertLen(chunks, 3) self.assertEqual(chunks[0].query, table.query) self.assertEqual(chunks[0].header, table.header) - self.assertEqual(chunks[0].rows, [ - rdf_osquery.OsqueryRow(values=["ABC", "DEF", "GHI"]), - ]) + self.assertEqual( + chunks[0].rows, + [ + rdf_osquery.OsqueryRow(values=["ABC", "DEF", "GHI"]), + ], + ) self.assertEqual(chunks[1].query, table.query) self.assertEqual(chunks[1].header, table.header) - self.assertEqual(chunks[1].rows, [ - rdf_osquery.OsqueryRow(values=["JKL", "MNO", "PQR"]), - ]) + self.assertEqual( + chunks[1].rows, + [ + rdf_osquery.OsqueryRow(values=["JKL", "MNO", "PQR"]), + ], + ) self.assertEqual(chunks[2].query, table.query) self.assertEqual(chunks[2].header, table.header) - self.assertEqual(chunks[2].rows, [ - rdf_osquery.OsqueryRow(values=["RST", "UVW", "XYZ"]), - ]) + self.assertEqual( + chunks[2].rows, + [ + rdf_osquery.OsqueryRow(values=["RST", "UVW", "XYZ"]), + ], + ) def testMultiRowChunks(self): table = rdf_osquery.OsqueryTable() @@ -339,24 +354,33 @@ def testMultiRowChunks(self): self.assertLen(chunks, 3) self.assertEqual(chunks[0].query, 
table.query) self.assertEqual(chunks[0].header, table.header) - self.assertEqual(chunks[0].rows, [ - rdf_osquery.OsqueryRow(values=["A", "B", "C"]), - rdf_osquery.OsqueryRow(values=["D", "E", "F"]), - rdf_osquery.OsqueryRow(values=["G", "H", "I"]), - ]) + self.assertEqual( + chunks[0].rows, + [ + rdf_osquery.OsqueryRow(values=["A", "B", "C"]), + rdf_osquery.OsqueryRow(values=["D", "E", "F"]), + rdf_osquery.OsqueryRow(values=["G", "H", "I"]), + ], + ) self.assertEqual(chunks[1].query, table.query) self.assertEqual(chunks[1].header, table.header) - self.assertEqual(chunks[1].rows, [ - rdf_osquery.OsqueryRow(values=["J", "K", "L"]), - rdf_osquery.OsqueryRow(values=["M", "N", "O"]), - rdf_osquery.OsqueryRow(values=["P", "Q", "R"]), - ]) + self.assertEqual( + chunks[1].rows, + [ + rdf_osquery.OsqueryRow(values=["J", "K", "L"]), + rdf_osquery.OsqueryRow(values=["M", "N", "O"]), + rdf_osquery.OsqueryRow(values=["P", "Q", "R"]), + ], + ) self.assertEqual(chunks[2].query, table.query) self.assertEqual(chunks[2].header, table.header) - self.assertEqual(chunks[2].rows, [ - rdf_osquery.OsqueryRow(values=["S", "T", "U"]), - rdf_osquery.OsqueryRow(values=["V", "W", "X"]), - ]) + self.assertEqual( + chunks[2].rows, + [ + rdf_osquery.OsqueryRow(values=["S", "T", "U"]), + rdf_osquery.OsqueryRow(values=["V", "W", "X"]), + ], + ) def testMultiByteStrings(self): table = rdf_osquery.OsqueryTable() @@ -369,12 +393,15 @@ def testMultiByteStrings(self): chunks = list(osquery.ChunkTable(table, max_chunk_size=10)) self.assertLen(chunks, 3) - self.assertEqual(chunks[0].rows, - [rdf_osquery.OsqueryRow(values=["🐔", "🐓"])]) - self.assertEqual(chunks[1].rows, - [rdf_osquery.OsqueryRow(values=["🐣", "🐤"])]) - self.assertEqual(chunks[2].rows, - [rdf_osquery.OsqueryRow(values=["🐥", "🦆"])]) + self.assertEqual( + chunks[0].rows, [rdf_osquery.OsqueryRow(values=["🐔", "🐓"])] + ) + self.assertEqual( + chunks[1].rows, [rdf_osquery.OsqueryRow(values=["🐣", "🐤"])] + ) + self.assertEqual( + chunks[2].rows, [rdf_osquery.OsqueryRow(values=["🐥", "🦆"])] + ) class ParseTableTest(absltest.TestCase): @@ -505,11 +532,14 @@ def testSimple(self): header.columns.append(rdf_osquery.OsqueryColumn(name="bar")) header.columns.append(rdf_osquery.OsqueryColumn(name="baz")) - row = osquery.ParseRow(header, { - "foo": "quux", - "bar": "norf", - "baz": "thud", - }) + row = osquery.ParseRow( + header, + { + "foo": "quux", + "bar": "norf", + "baz": "thud", + }, + ) self.assertEqual(row.values, ["quux", "norf", "thud"]) diff --git a/grr/client/grr_response_client/client_actions/osx/osx.py b/grr/client/grr_response_client/client_actions/osx/osx.py index 15e957bea1..32278334d1 100644 --- a/grr/client/grr_response_client/client_actions/osx/osx.py +++ b/grr/client/grr_response_client/client_actions/osx/osx.py @@ -15,14 +15,13 @@ import pytsk3 from grr_response_client import actions -from grr_response_client import client_utils_common from grr_response_client import client_utils_osx +from grr_response_client.client_actions import osx_linux from grr_response_client.client_actions import standard from grr_response_client.osx import objc from grr_response_core.lib import rdfvalue from grr_response_core.lib.parsers import osx_launchd from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import client_action as rdf_client_action from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import client_network as rdf_client_network from grr_response_core.lib.rdfvalues 
import protodict as rdf_protodict @@ -106,6 +105,7 @@ class Sockaddr(ctypes.Structure): class Sockaddrdl(ctypes.Structure): """The sockaddr_dl struct.""" + _fields_ = [ ("sdl_len", ctypes.c_ubyte), ("sdl_family", ctypes.c_ubyte), @@ -129,6 +129,7 @@ class Sockaddrdl(ctypes.Structure): class Sockaddrin(ctypes.Structure): """The sockaddr_in struct.""" + _fields_ = [ ("sin_len", ctypes.c_ubyte), ("sin_family", sa_family_t), @@ -150,6 +151,7 @@ class Sockaddrin(ctypes.Structure): class Sockaddrin6(ctypes.Structure): """The sockaddr_in6 struct.""" + _fields_ = [ ("sin6_len", ctypes.c_ubyte), ("sin6_family", sa_family_t), @@ -159,6 +161,7 @@ class Sockaddrin6(ctypes.Structure): ("sin6_scope_id", ctypes.c_uint32), ] + # struct ifaddrs *ifa_next; /* Pointer to next struct */ # char *ifa_name; /* Interface name */ # u_int ifa_flags; /* Interface flags */ @@ -247,7 +250,9 @@ def ParseIfaddrs(ifaddrs): nlen = sockaddrdl.contents.sdl_nlen alen = sockaddrdl.contents.sdl_alen - iface.mac_address = bytes(sockaddrdl.contents.sdl_data[nlen:nlen + alen]) + iface.mac_address = bytes( + sockaddrdl.contents.sdl_data[nlen : nlen + alen] + ) else: raise ValueError("Unexpected socket address family: %s" % iffamily) @@ -271,6 +276,7 @@ def EnumerateInterfacesFromClient(args): class EnumerateInterfaces(actions.ActionPlugin): """Enumerate all MAC addresses of all NICs.""" + out_rdfvalues = [rdf_client_network.Interface] def Run(self, args): @@ -280,6 +286,7 @@ def Run(self, args): class GetInstallDate(actions.ActionPlugin): """Estimate the install date of this system.""" + out_rdfvalues = [rdf_protodict.DataBlob] def Run(self, unused_args): @@ -300,7 +307,8 @@ def EnumerateFilesystemsFromClient(args): yield rdf_client_fs.Filesystem( device=fs_struct.f_mntfromname, mount_point=fs_struct.f_mntonname, - type=fs_struct.f_fstypename) + type=fs_struct.f_fstypename, + ) drive_re = re.compile("r?disk[0-9].*") for drive in os.listdir("/dev"): @@ -320,7 +328,8 @@ def EnumerateFilesystemsFromClient(args): offset = volume.start * vol_inf.info.block_size yield rdf_client_fs.Filesystem( device="{path}:{offset}".format(path=path, offset=offset), - type="partition") + type="partition", + ) except (IOError, RuntimeError): continue @@ -328,6 +337,7 @@ def EnumerateFilesystemsFromClient(args): class EnumerateFilesystems(actions.ActionPlugin): """Enumerate all unique filesystems local to the system.""" + out_rdfvalues = [rdf_client_fs.Filesystem] def Run(self, args): @@ -348,7 +358,8 @@ def CreateServiceProto(job): label=job.get("Label"), program=job.get("Program"), sessiontype=job.get("LimitLoadToSessionType"), - ondemand=bool(job["OnDemand"])) + ondemand=bool(job["OnDemand"]), + ) if job["LastExitStatus"] is not None: service.lastexitstatus = int(job["LastExitStatus"]) @@ -399,8 +410,9 @@ def OSXEnumerateRunningServicesFromClient(args): if version_array[:2] < [10, 6]: raise UnsupportedOSVersionError( - "ServiceManagement API unsupported on < 10.6. This client is %s" % - osx_version.VersionString()) + "ServiceManagement API unsupported on < 10.6. 
This client is %s" + % osx_version.VersionString() + ) launchd_list = GetRunningLaunchDaemons() @@ -412,6 +424,7 @@ def OSXEnumerateRunningServicesFromClient(args): class OSXEnumerateRunningServices(actions.ActionPlugin): """Enumerate all running launchd jobs.""" + in_rdfvalue = None out_rdfvalues = [rdf_client.OSXServiceInformation] @@ -424,23 +437,5 @@ class UpdateAgent(standard.ExecuteBinaryCommand): """Updates the GRR agent to a new version.""" def ProcessFile(self, path, args): - - cmd = "/usr/sbin/installer" - cmd_args = ["-pkg", path, "-target", "/"] - time_limit = args.time_limit - - res = client_utils_common.Execute( - cmd, cmd_args, time_limit=time_limit, bypass_allowlist=True) - (stdout, stderr, status, time_used) = res - - # Limit output to 10MB so our response doesn't get too big. - stdout = stdout[:10 * 1024 * 1024] - stderr = stderr[:10 * 1024 * 1024] - - self.SendReply( - rdf_client_action.ExecuteBinaryResponse( - stdout=stdout, - stderr=stderr, - exit_status=status, - # We have to return microseconds. - time_used=int(1e6 * time_used))) + cmd = ["/usr/sbin/installer", "-pkg", path, "-target", "/"] + osx_linux.RunInstallerCmd(cmd) diff --git a/grr/client/grr_response_client/client_actions/osx/osx_test.py b/grr/client/grr_response_client/client_actions/osx/osx_test.py index 51058c93bb..9b0576d511 100644 --- a/grr/client/grr_response_client/client_actions/osx/osx_test.py +++ b/grr/client/grr_response_client/client_actions/osx/osx_test.py @@ -37,8 +37,8 @@ def testFileSystemEnumeration64Bit(self): """Ensure we can enumerate file systems successfully.""" path = os.path.join(self.base_path, "osx_fsdata") results = self.osx.client_utils_osx.ParseFileSystemsStruct( - self.osx.client_utils_osx.StatFS64Struct, 7, - open(path, "rb").read()) + self.osx.client_utils_osx.StatFS64Struct, 7, open(path, "rb").read() + ) self.assertLen(results, 7) self.assertEqual(results[0].f_fstypename, b"hfs") self.assertEqual(results[0].f_mntonname, b"/") @@ -59,22 +59,26 @@ def ValidResponseProtoSingle(self, proto): @mock.patch( "grr_response_client.client_utils_osx." - "OSXVersion") + "OSXVersion" + ) def testOSXEnumerateRunningServicesAll(self, osx_version_mock): version_value_mock = mock.Mock() version_value_mock.VersionAsMajorMinor.return_value = [10, 7] osx_version_mock.return_value = version_value_mock with mock.patch.object( - self.osx, "GetRunningLaunchDaemons") as get_running_launch_daemons_mock: - with mock.patch.object(self.osx.OSXEnumerateRunningServices, - "SendReply") as send_reply_mock: + self.osx, "GetRunningLaunchDaemons" + ) as get_running_launch_daemons_mock: + with mock.patch.object( + self.osx.OSXEnumerateRunningServices, "SendReply" + ) as send_reply_mock: get_running_launch_daemons_mock.return_value = osx_launchd_testdata.JOBS action = self.osx.OSXEnumerateRunningServices(None) - num_results = len( - osx_launchd_testdata.JOBS) - osx_launchd_testdata.FILTERED_COUNT + num_results = ( + len(osx_launchd_testdata.JOBS) - osx_launchd_testdata.FILTERED_COUNT + ) action.Run(None) @@ -86,16 +90,19 @@ def testOSXEnumerateRunningServicesAll(self, osx_version_mock): @mock.patch( "grr_response_client.client_utils_osx." 
- "OSXVersion") + "OSXVersion" + ) def testOSXEnumerateRunningServicesSingle(self, osx_version_mock): version_value_mock = mock.Mock() version_value_mock.VersionAsMajorMinor.return_value = [10, 7, 1] osx_version_mock.return_value = version_value_mock with mock.patch.object( - self.osx, "GetRunningLaunchDaemons") as get_running_launch_daemons_mock: - with mock.patch.object(self.osx.OSXEnumerateRunningServices, - "SendReply") as send_reply_mock: + self.osx, "GetRunningLaunchDaemons" + ) as get_running_launch_daemons_mock: + with mock.patch.object( + self.osx.OSXEnumerateRunningServices, "SendReply" + ) as send_reply_mock: get_running_launch_daemons_mock.return_value = osx_launchd_testdata.JOB @@ -116,7 +123,8 @@ def testOSXEnumerateRunningServicesSingle(self, osx_version_mock): @mock.patch( "grr_response_client.client_utils_osx." - "OSXVersion") + "OSXVersion" + ) def testOSXEnumerateRunningServicesVersionError(self, osx_version_mock): version_value_mock = mock.Mock() version_value_mock.VersionAsMajorMinor.return_value = [10, 5, 1] @@ -167,7 +175,8 @@ def testSingleIpv4(self): ifaddr = self.osx.Ifaddrs() ifaddr.ifa_name = ctypes.create_string_buffer("foo".encode("utf-8")) ifaddr.ifa_addr = ctypes.cast( - ctypes.pointer(sockaddrin), ctypes.POINTER(self.osx.Sockaddr)) + ctypes.pointer(sockaddrin), ctypes.POINTER(self.osx.Sockaddr) + ) results = list(self.osx.ParseIfaddrs(ctypes.pointer(ifaddr))) self.assertLen(results, 1) @@ -188,7 +197,8 @@ def testSingleIpv6(self): ifaddr = self.osx.Ifaddrs() ifaddr.ifa_name = ctypes.create_string_buffer("bar".encode("utf-8")) ifaddr.ifa_addr = ctypes.cast( - ctypes.pointer(sockaddrin), ctypes.POINTER(self.osx.Sockaddr)) + ctypes.pointer(sockaddrin), ctypes.POINTER(self.osx.Sockaddr) + ) results = list(self.osx.ParseIfaddrs(ctypes.pointer(ifaddr))) self.assertLen(results, 1) @@ -205,14 +215,15 @@ def testSingleMac(self): sockaddrdl = self.osx.Sockaddrdl() sockaddrdl.sdl_family = self.osx.AF_LINK - sockaddrdl.sdl_data[0:len(name + mac)] = list(bytes(name + mac)) + sockaddrdl.sdl_data[0 : len(name + mac)] = list(bytes(name + mac)) sockaddrdl.sdl_nlen = len(name) sockaddrdl.sdl_alen = len(mac) ifaddr = self.osx.Ifaddrs() ifaddr.ifa_name = ctypes.create_string_buffer(name) ifaddr.ifa_addr = ctypes.cast( - ctypes.pointer(sockaddrdl), ctypes.POINTER(self.osx.Sockaddr)) + ctypes.pointer(sockaddrdl), ctypes.POINTER(self.osx.Sockaddr) + ) results = list(self.osx.ParseIfaddrs(ctypes.pointer(ifaddr))) self.assertLen(results, 1) @@ -229,7 +240,7 @@ def testMultiple(self): foo_sockaddrdl = self.osx.Sockaddrdl() foo_sockaddrdl.sdl_family = self.osx.AF_LINK - foo_sockaddrdl.sdl_data[0:len(foo_mac)] = list(bytes(foo_mac)) + foo_sockaddrdl.sdl_data[0 : len(foo_mac)] = list(bytes(foo_mac)) foo_sockaddrdl.sdl_nlen = 0 foo_sockaddrdl.sdl_alen = len(foo_mac) @@ -242,7 +253,7 @@ def testMultiple(self): bar_sockaddrdl = self.osx.Sockaddrdl() bar_sockaddrdl.sdl_family = self.osx.AF_LINK - bar_sockaddrdl.sdl_data[0:len(foo_mac)] = list(bytes(bar_mac)) + bar_sockaddrdl.sdl_data[0 : len(foo_mac)] = list(bytes(bar_mac)) bar_sockaddrdl.sdl_nlen = 0 bar_sockaddrdl.sdl_alen = len(bar_mac) @@ -250,28 +261,32 @@ def testMultiple(self): ifaddr.ifa_next = None ifaddr.ifa_name = ctypes.create_string_buffer(b"foo") ifaddr.ifa_addr = ctypes.cast( - ctypes.pointer(foo_sockaddrin), ctypes.POINTER(self.osx.Sockaddr)) + ctypes.pointer(foo_sockaddrin), ctypes.POINTER(self.osx.Sockaddr) + ) ifnext = ifaddr ifaddr = self.osx.Ifaddrs() ifaddr.ifa_next = ctypes.pointer(ifnext) ifaddr.ifa_name = 
ctypes.create_string_buffer(b"foo") ifaddr.ifa_addr = ctypes.cast( - ctypes.pointer(foo_sockaddrdl), ctypes.POINTER(self.osx.Sockaddr)) + ctypes.pointer(foo_sockaddrdl), ctypes.POINTER(self.osx.Sockaddr) + ) ifnext = ifaddr ifaddr = self.osx.Ifaddrs() ifaddr.ifa_next = ctypes.pointer(ifnext) ifaddr.ifa_name = ctypes.create_string_buffer(b"bar") ifaddr.ifa_addr = ctypes.cast( - ctypes.pointer(bar_sockaddrdl), ctypes.POINTER(self.osx.Sockaddr)) + ctypes.pointer(bar_sockaddrdl), ctypes.POINTER(self.osx.Sockaddr) + ) ifnext = ifaddr ifaddr = self.osx.Ifaddrs() ifaddr.ifa_next = ctypes.pointer(ifnext) ifaddr.ifa_name = ctypes.create_string_buffer(b"bar") ifaddr.ifa_addr = ctypes.cast( - ctypes.pointer(bar_sockaddrin), ctypes.POINTER(self.osx.Sockaddr)) + ctypes.pointer(bar_sockaddrin), ctypes.POINTER(self.osx.Sockaddr) + ) expected_foo_iface = rdf_client_network.Interface( ifname="foo", @@ -279,8 +294,10 @@ def testMultiple(self): addresses=[ rdf_client_network.NetworkAddress( address_type=rdf_client_network.NetworkAddress.Family.INET, - packed_bytes=foo_ipv4), - ]) + packed_bytes=foo_ipv4, + ), + ], + ) expected_bar_iface = rdf_client_network.Interface( ifname="bar", @@ -288,8 +305,10 @@ def testMultiple(self): addresses=[ rdf_client_network.NetworkAddress( address_type=rdf_client_network.NetworkAddress.Family.INET6, - packed_bytes=bar_ipv6), - ]) + packed_bytes=bar_ipv6, + ), + ], + ) results = list(self.osx.ParseIfaddrs(ctypes.pointer(ifaddr))) self.assertCountEqual(results, [expected_foo_iface, expected_bar_iface]) diff --git a/grr/client/grr_response_client/client_actions/osx_linux.py b/grr/client/grr_response_client/client_actions/osx_linux.py new file mode 100644 index 0000000000..a53f325d24 --- /dev/null +++ b/grr/client/grr_response_client/client_actions/osx_linux.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +"""Client action utils common to macOS and Linux.""" + +import logging +import os +import subprocess +import time + +from grr_response_client.client_actions import tempfiles +from grr_response_core.lib.rdfvalues import client_action as rdf_client_action + + +def RunInstallerCmd(cmd: list[str]) -> rdf_client_action.ExecuteBinaryResponse: + """Run an installer process that is expected to kill the grr client.""" + # Remove env vars pointing into the bundled pyinstaller directory to prevent + # system installers from loading grr libs. + env = os.environ.copy() + env.pop("LD_LIBRARY_PATH", None) + env.pop("PYTHON_PATH", None) + logging.info("Executing %s", " ".join(cmd)) + start = time.monotonic() + stdout_filename = None + stderr_filename = None + try: + with tempfiles.CreateGRRTempFile( + filename="GRRInstallStdout.txt" + ) as stdout_file: + stdout_filename = stdout_file.name + with tempfiles.CreateGRRTempFile( + filename="GRRInstallStderr.txt" + ) as stderr_file: + stderr_filename = stderr_file.name + p = subprocess.run( + cmd, + env=env, + start_new_session=True, + stdin=subprocess.DEVNULL, + stdout=stdout_file, + stderr=stderr_file, + check=False, + ) + logging.error("Installer ran, but the old GRR client is still running") + return rdf_client_action.ExecuteBinaryResponse( + # Limit output to fit within 2MiB fleetspeak message limit. + stdout=stdout_file.read(512 * 1024), + stderr=stderr_file.read(512 * 1024), + exit_status=p.returncode, + # We have to return microseconds. + time_used=int(1e6 * (time.monotonic() - start)), + ) + finally: + # Clean up log files. It's unlikely that apt/dpkg/installer will produce + # more output than the cap above. 
+ for filename in (stdout_filename, stderr_filename): + if filename is not None: + try: + os.remove(filename) + except OSError: + pass diff --git a/grr/client/grr_response_client/client_actions/read_low_level.py b/grr/client/grr_response_client/client_actions/read_low_level.py index b65428b18b..65785ad49c 100644 --- a/grr/client/grr_response_client/client_actions/read_low_level.py +++ b/grr/client/grr_response_client/client_actions/read_low_level.py @@ -52,8 +52,10 @@ def Run(self, args: rdf_read_low_level.ReadLowLevelRequest) -> None: # Make sure we limit the size of our output. if args.length > _READ_BYTES_LIMIT: - raise RuntimeError(f"Can not read buffers this large " - f"({args.length} > {_READ_BYTES_LIMIT} bytes).") + raise RuntimeError( + "Can not read buffers this large " + f"({args.length} > {_READ_BYTES_LIMIT} bytes)." + ) # TODO: Update `blob_size` when `sector_block_size` is set. # `blob_size` must be a multiple of `sector_block_size` so that reads start @@ -78,7 +80,7 @@ def Run(self, args: rdf_read_low_level.ReadLowLevelRequest) -> None: # Discard data that we read unnecessarily due to alignment. # Refer to `_AlignArgs` documentation for more details. if is_first_chunk: - data = data[self._pre_padding:] + data = data[self._pre_padding :] is_first_chunk = False # Upload the blobs to blobstore using `TransferStore`. Save the buffer @@ -88,7 +90,8 @@ def Run(self, args: rdf_read_low_level.ReadLowLevelRequest) -> None: # in order to avoid `InvalidBlobOffsetError` when storing the blobs as # a file in `file_store`. reference_offset = ( - current_offset - self._pre_padding if current_offset else 0) + current_offset - self._pre_padding if current_offset else 0 + ) self._StoreDataAndHash(data, reference_offset) current_offset = current_offset + read_size bytes_left_to_read -= read_size @@ -105,7 +108,8 @@ def _StoreDataAndHash(self, data: AnyStr, offset: int) -> None: data_blob = rdf_protodict.DataBlob( data=zlib.compress(data), - compression=rdf_protodict.DataBlob.CompressionType.ZCOMPRESSION) + compression=rdf_protodict.DataBlob.CompressionType.ZCOMPRESSION, + ) # Ensure that the buffer is counted against this response. Check network # send limit. @@ -114,20 +118,24 @@ def _StoreDataAndHash(self, data: AnyStr, offset: int) -> None: # Now return the data to the server into the special TransferStore well # known flow. self.grr_worker.SendReply( - data_blob, session_id=rdfvalue.SessionID(flow_name="TransferStore")) + data_blob, session_id=rdfvalue.SessionID(flow_name="TransferStore") + ) # Now report the hash of this blob to our flow as well as the offset and # length. 
digest = hashlib.sha256(data).digest() buffer_reference = rdf_client.BufferReference( - offset=offset, length=len(data), data=digest) + offset=offset, length=len(data), data=digest + ) self._partial_file_hash.update(data) partial_file_hash = self._partial_file_hash.digest() self.SendReply( rdf_read_low_level.ReadLowLevelResult( - blob=buffer_reference, accumulated_hash=partial_file_hash)) + blob=buffer_reference, accumulated_hash=partial_file_hash + ) + ) def GetPrePadding(args: rdf_read_low_level.ReadLowLevelRequest) -> int: @@ -144,8 +152,10 @@ def GetPrePadding(args: rdf_read_low_level.ReadLowLevelRequest) -> int: return args.offset % block_size -def AlignArgs(args: rdf_read_low_level.ReadLowLevelRequest, - pre_padding: int) -> rdf_read_low_level.ReadLowLevelRequest: +def AlignArgs( + args: rdf_read_low_level.ReadLowLevelRequest, + pre_padding: int, +) -> rdf_read_low_level.ReadLowLevelRequest: """Aligns the offset and updates the length according to the pre_padding. It returns a copy of the flow arguments with the aligned offset value, diff --git a/grr/client/grr_response_client/client_actions/read_low_level_test.py b/grr/client/grr_response_client/client_actions/read_low_level_test.py index a56bd67761..23dcab6214 100644 --- a/grr/client/grr_response_client/client_actions/read_low_level_test.py +++ b/grr/client/grr_response_client/client_actions/read_low_level_test.py @@ -20,7 +20,8 @@ def testReadsOneAlignedChunk(self): temp_file.write_bytes(b"123456") request1 = rdf_read_low_level.ReadLowLevelRequest( - path=temp_file.full_path, length=1) # offset should default to 0 + path=temp_file.full_path, length=1 + ) # offset should default to 0 # We call ExecuteAction rather than RunAction because we need the # the ClientAction `message` set in order to call `ChargeBytesToSession`. @@ -44,7 +45,8 @@ def testReadsOneMisalignedChunk(self): temp_file.write_bytes(b"123456") request23 = rdf_read_low_level.ReadLowLevelRequest( - path=temp_file.full_path, length=2, offset=1) + path=temp_file.full_path, length=2, offset=1 + ) # We call ExecuteAction rather than RunAction because we need the # the ClientAction `message` set in order to call `ChargeBytesToSession`. @@ -58,7 +60,8 @@ def testReadsOneMisalignedChunk(self): self.assertEqual(0, results[0].blob.offset) self.assertEqual(hashlib.sha256(b"23").digest(), results[0].blob.data) self.assertEqual( - hashlib.sha256(b"23").digest(), results[0].accumulated_hash) + hashlib.sha256(b"23").digest(), results[0].accumulated_hash + ) self.assertIsInstance(results[1], rdf_flows.GrrStatus) self.assertEqual(rdf_flows.GrrStatus.ReturnedStatus.OK, results[1].status) @@ -72,7 +75,8 @@ def testReadsMultipleMisalignedChunks(self): # 2 and 3. # TODO: Update test when blob size is also "aligned" request23_2blobs = rdf_read_low_level.ReadLowLevelRequest( - path=temp_file.full_path, length=2, offset=1, blob_size=1) + path=temp_file.full_path, length=2, offset=1, blob_size=1 + ) # We call ExecuteAction rather than RunAction because we need the # the ClientAction `message` set in order to call `ChargeBytesToSession`. 
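The misaligned-read tests above boil down to the pre-padding arithmetic behind `GetPrePadding` and `AlignArgs`: reads of raw devices must start on a block boundary, so the requested offset is rounded down and the extra bytes are dropped from the first chunk. A self-contained sketch of that arithmetic follows (hypothetical helper names and a hypothetical 512-byte block size; the real helpers operate on a `ReadLowLevelRequest`):

def Align(offset: int, length: int, block_size: int = 512):
  """Returns (aligned_offset, padded_length, pre_padding)."""
  pre_padding = offset % block_size
  return offset - pre_padding, length + pre_padding, pre_padding


def ReadAligned(data: bytes, offset: int, length: int) -> bytes:
  """Reads `length` bytes at `offset` using only block-aligned reads."""
  aligned_offset, padded_length, pre_padding = Align(offset, length)
  raw = data[aligned_offset:aligned_offset + padded_length]
  return raw[pre_padding:]  # Drop the bytes read only for alignment.


# Mirrors testReadsOneMisalignedChunk: offset 1, length 2 over b"123456".
assert ReadAligned(b"123456", offset=1, length=2) == b"23"

For reads split across multiple blobs, only the first chunk carries the padding, and later chunks report offsets shifted back by the pre-padding (the 'corrected' offset asserted in testReadsMultipleMisalignedChunks).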
@@ -92,7 +96,8 @@ def testReadsMultipleMisalignedChunks(self): self.assertEqual(1, results[1].blob.offset) # 'corrected' offset self.assertEqual(hashlib.sha256(b"3").digest(), results[1].blob.data) self.assertEqual( - hashlib.sha256(b"23").digest(), results[1].accumulated_hash) + hashlib.sha256(b"23").digest(), results[1].accumulated_hash + ) self.assertIsInstance(results[2], rdf_flows.GrrStatus) self.assertEqual(rdf_flows.GrrStatus.ReturnedStatus.OK, results[2].status) diff --git a/grr/client/grr_response_client/client_actions/registry_init.py b/grr/client/grr_response_client/client_actions/registry_init.py index e15e255b53..bbcd0ef4d7 100644 --- a/grr/client/grr_response_client/client_actions/registry_init.py +++ b/grr/client/grr_response_client/client_actions/registry_init.py @@ -26,8 +26,9 @@ def RegisterClientActions(): """Registers all client actions.""" - client_actions.Register("CheckFreeGRRTempSpace", - tempfiles.CheckFreeGRRTempSpace) + client_actions.Register( + "CheckFreeGRRTempSpace", tempfiles.CheckFreeGRRTempSpace + ) client_actions.Register("CollectLargeFile", large_file.CollectLargeFileAction) client_actions.Register("DeleteGRRTempFiles", tempfiles.DeleteGRRTempFiles) client_actions.Register("Echo", admin.Echo) @@ -50,8 +51,9 @@ def RegisterClientActions(): client_actions.Register("HashFile", standard.HashFile) client_actions.Register("Kill", admin.Kill) client_actions.Register("ListDirectory", standard.ListDirectory) - client_actions.Register("ListNetworkConnections", - network.ListNetworkConnections) + client_actions.Register( + "ListNetworkConnections", network.ListNetworkConnections + ) client_actions.Register("ListProcesses", standard.ListProcesses) client_actions.Register("Osquery", osquery.Osquery) client_actions.Register("ReadBuffer", standard.ReadBuffer) @@ -72,8 +74,9 @@ def RegisterClientActions(): client_actions.Register("Dummy", dummy.Dummy) client_actions.Register("EnumerateFilesystems", linux.EnumerateFilesystems) client_actions.Register("EnumerateInterfaces", linux.EnumerateInterfaces) - client_actions.Register("EnumerateRunningServices", - linux.EnumerateRunningServices) + client_actions.Register( + "EnumerateRunningServices", linux.EnumerateRunningServices + ) client_actions.Register("EnumerateUsers", linux.EnumerateUsers) client_actions.Register("GetInstallDate", linux.GetInstallDate) client_actions.Register("UpdateAgent", linux.UpdateAgent) @@ -82,8 +85,9 @@ def RegisterClientActions(): from grr_response_client.client_actions.windows import windows # pylint: disable=g-import-not-at-top client_actions.Register("Dummy", win_dummy.Dummy) - client_actions.Register("EnumerateFilesystems", - windows.EnumerateFilesystems) + client_actions.Register( + "EnumerateFilesystems", windows.EnumerateFilesystems + ) client_actions.Register("EnumerateInterfaces", windows.EnumerateInterfaces) client_actions.Register("GetInstallDate", windows.GetInstallDate) client_actions.Register("WmiQuery", windows.WmiQuery) @@ -92,9 +96,11 @@ def RegisterClientActions(): elif platform.system() == "Darwin": from grr_response_client.client_actions.osx import osx # pylint: disable=g-import-not-at-top + client_actions.Register("EnumerateFilesystems", osx.EnumerateFilesystems) client_actions.Register("EnumerateInterfaces", osx.EnumerateInterfaces) client_actions.Register("GetInstallDate", osx.GetInstallDate) - client_actions.Register("OSXEnumerateRunningServices", - osx.OSXEnumerateRunningServices) + client_actions.Register( + "OSXEnumerateRunningServices", osx.OSXEnumerateRunningServices + 
) client_actions.Register("UpdateAgent", osx.UpdateAgent) diff --git a/grr/client/grr_response_client/client_actions/searching.py b/grr/client/grr_response_client/client_actions/searching.py index 3c4119470e..984301c6b6 100644 --- a/grr/client/grr_response_client/client_actions/searching.py +++ b/grr/client/grr_response_client/client_actions/searching.py @@ -16,6 +16,7 @@ class Find(actions.ActionPlugin): """Recurses through a directory returning files which match conditions.""" + in_rdfvalue = rdf_client_fs.FindSpec out_rdfvalues = [rdf_client_fs.FindSpec, rdf_client_fs.StatEntry] @@ -39,8 +40,9 @@ def ListDirectory(self, pathspec, depth=0): self.SetStatus(rdf_flows.GrrStatus.ReturnedStatus.IOERROR, e) else: # Can't open the directory we're searching, ignore the directory. - logging.info("Find failed to ListDirectory for %s. Err: %s", pathspec, - e) + logging.info( + "Find failed to ListDirectory for %s. Err: %s", pathspec, e + ) return # If we are not supposed to cross devices, and don't know yet @@ -67,7 +69,8 @@ def TestFileContent(self, file_stat): data = b"" with vfs.VFSOpen( - file_stat.pathspec, progress_callback=self.Progress) as fd: + file_stat.pathspec, progress_callback=self.Progress + ) as fd: # Only read this much data from the file. while fd.Tell() < self.request.max_data: data_read = fd.read(1024000) @@ -107,8 +110,9 @@ def BuildChecks(self, request): def FilterTimestamp(file_stat, request=request): return file_stat.HasField("st_mtime") and ( - file_stat.st_mtime < request.start_time or - file_stat.st_mtime > request.end_time) + file_stat.st_mtime < request.start_time + or file_stat.st_mtime > request.end_time + ) result.append(FilterTimestamp) @@ -116,8 +120,9 @@ def FilterTimestamp(file_stat, request=request): def FilterSize(file_stat, request=request): return file_stat.HasField("st_size") and ( - file_stat.st_size < request.min_file_size or - file_stat.st_size > request.max_file_size) + file_stat.st_size < request.min_file_size + or file_stat.st_size > request.max_file_size + ) result.append(FilterSize) @@ -186,6 +191,7 @@ def Run(self, request): class Grep(actions.ActionPlugin): """Search a file for a pattern.""" + in_rdfvalue = rdf_client_fs.GrepSpec out_rdfvalues = [rdf_client.BufferReference] @@ -268,7 +274,6 @@ def Run(self, args): Raises: RuntimeError: No search pattern has been given in the request. - """ fd = vfs.VFSOpen(args.target, progress_callback=self.Progress) fd.Seek(args.start_offset) @@ -291,13 +296,16 @@ def Run(self, args): while fd.Tell() < args.start_offset + args.length: # Base size to read is at most the buffer size. - to_read = min(args.length, self.BUFF_SIZE, - args.start_offset + args.length - fd.Tell()) + to_read = min( + args.length, + self.BUFF_SIZE, + args.start_offset + args.length - fd.Tell(), + ) # Read some more data for the snippet. to_read += self.ENVELOPE_SIZE - postscript_size read_data = fd.Read(to_read) - data = data[-postscript_size - self.ENVELOPE_SIZE:] + read_data + data = data[-postscript_size - self.ENVELOPE_SIZE :] + read_data postscript_size = max(0, self.ENVELOPE_SIZE - (to_read - len(read_data))) data_size = len(data) - preamble_size - postscript_size @@ -305,7 +313,7 @@ def Run(self, args): if data_size == 0 and postscript_size == 0: break - for (start, end) in find_func(data): + for start, end in find_func(data): # Ignore hits in the preamble. 
if end <= preamble_size: continue @@ -328,17 +336,22 @@ def Run(self, args): offset=base_offset + start - preamble_size, data=out_data, length=len(out_data), - pathspec=fd.pathspec)) + pathspec=fd.pathspec, + ) + ) if args.mode == rdf_client_fs.GrepSpec.Mode.FIRST_HIT: return if hits >= self.HIT_LIMIT: msg = utils.Xor( - b"This Grep has reached the maximum number of hits" - b" (%d)." % self.HIT_LIMIT, self.xor_out_key) + b"This Grep has reached the maximum number of hits (%d)." + % self.HIT_LIMIT, + self.xor_out_key, + ) self.SendReply( - rdf_client.BufferReference(offset=0, data=msg, length=len(msg))) + rdf_client.BufferReference(offset=0, data=msg, length=len(msg)) + ) return self.Progress() diff --git a/grr/client/grr_response_client/client_actions/searching_test.py b/grr/client/grr_response_client/client_actions/searching_test.py index 2ed2e68177..fd0d4de191 100644 --- a/grr/client/grr_response_client/client_actions/searching_test.py +++ b/grr/client/grr_response_client/client_actions/searching_test.py @@ -24,6 +24,7 @@ class MockVFSHandlerFind(vfs.VFSHandler): This is used to create the /mock2/ client vfs branch which is utilized in the below tests. """ + supported_pathtype = rdf_paths.PathSpec.PathType.OS filesystem = { @@ -37,7 +38,7 @@ class MockVFSHandlerFind(vfs.VFSHandler): "/mock2/directory1/directory2/file.mp3": b"MP3 movie", "/mock2/directory3": ["file1.txt", "long_file.text"], "/mock2/directory3/file1.txt": b"A text file", - "/mock2/directory3/long_file.text": (b"space " * 100000 + b"A Secret") + "/mock2/directory3/long_file.text": b"space " * 100000 + b"A Secret", } def __init__(self, base_fd, handlers, pathspec=None, progress_callback=None): @@ -45,7 +46,8 @@ def __init__(self, base_fd, handlers, pathspec=None, progress_callback=None): base_fd, handlers=handlers, pathspec=pathspec, - progress_callback=progress_callback) + progress_callback=progress_callback, + ) self.pathspec.Append(pathspec) self.path = self.pathspec.CollapsePath() @@ -62,7 +64,7 @@ def Read(self, length): if isinstance(self.content, list): raise IOError() - result = self.content[self.offset:self.offset + length] + result = self.content[self.offset : self.offset + length] self.offset = min(self.size, self.offset + len(result)) return result @@ -154,8 +156,9 @@ def setUp(self): super().setUp() # Install the mock - vfs_overrider = vfs_test_lib.VFSOverrider(rdf_paths.PathSpec.PathType.OS, - MockVFSHandlerFind) + vfs_overrider = vfs_test_lib.VFSOverrider( + rdf_paths.PathSpec.PathType.OS, MockVFSHandlerFind + ) vfs_overrider.Start() self.addCleanup(vfs_overrider.Stop) @@ -163,18 +166,28 @@ def testFindAction(self): """Test the find action.""" # First get all the files at once pathspec = rdf_paths.PathSpec( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) request = rdf_client_fs.FindSpec(pathspec=pathspec, path_regex=".") result = self.RunAction(searching.Find, request) - self.assertCountEqual([r.pathspec.Basename() for r in result], [ - "file1.txt", "file2.txt", "file.jpg", "file.mp3", "directory2", - "directory1", "directory3" - ]) + self.assertCountEqual( + [r.pathspec.Basename() for r in result], + [ + "file1.txt", + "file2.txt", + "file.jpg", + "file.mp3", + "directory2", + "directory1", + "directory3", + ], + ) def testFindAction2(self): """Test the find action path regex.""" pathspec = rdf_paths.PathSpec( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) request 
= rdf_client_fs.FindSpec(pathspec=pathspec, path_regex=".*mp3") all_files = self.RunAction(searching.Find, request) @@ -185,9 +198,11 @@ def testFindAction3(self): """Test the find action data regex.""" # First get all the files at once pathspec = rdf_paths.PathSpec( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) request = rdf_client_fs.FindSpec( - pathspec=pathspec, data_regex=b"Secret", cross_devs=True) + pathspec=pathspec, data_regex=b"Secret", cross_devs=True + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 2) @@ -198,9 +213,11 @@ def testFindSizeLimits(self): """Test the find action size limits.""" # First get all the files at once request = rdf_client_fs.FindSpec( - min_file_size=4, max_file_size=15, cross_devs=True) + min_file_size=4, max_file_size=15, cross_devs=True + ) request.pathspec.Append( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) results = self.RunAction(searching.Find, request) all_files = [r.pathspec.Basename() for r in results] @@ -215,7 +232,8 @@ def testNoFilters(self): """Test the we get all files with no filters in place.""" # First get all the files at once pathspec = rdf_paths.PathSpec( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) request = rdf_client_fs.FindSpec(pathspec=pathspec, cross_devs=True) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 9) @@ -223,14 +241,17 @@ def testNoFilters(self): def testFindActionCrossDev(self): """Test that devices boundaries don't get crossed, also by default.""" pathspec = rdf_paths.PathSpec( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) request = rdf_client_fs.FindSpec( - pathspec=pathspec, cross_devs=True, path_regex=".") + pathspec=pathspec, cross_devs=True, path_regex="." + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 9) request = rdf_client_fs.FindSpec( - pathspec=pathspec, cross_devs=False, path_regex=".") + pathspec=pathspec, cross_devs=False, path_regex="." 
+ ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 7) @@ -242,12 +263,14 @@ def testPermissionFilter(self): """Test filtering based on file/folder permission happens correctly.""" pathspec = rdf_paths.PathSpec( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) # Look for files that match exact permissions request = rdf_client_fs.FindSpec( - pathspec=pathspec, path_regex=".", perm_mode=0o644, cross_devs=True) + pathspec=pathspec, path_regex=".", perm_mode=0o644, cross_devs=True + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 2) @@ -265,7 +288,8 @@ def testPermissionFilter(self): path_regex=".", perm_mode=0o4002, perm_mask=0o7002, - cross_devs=True) + cross_devs=True, + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 2) @@ -282,7 +306,8 @@ def testPermissionFilter(self): path_regex=".", perm_mode=0o0100001, perm_mask=0o0100001, - cross_devs=True) + cross_devs=True, + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 2) @@ -299,7 +324,8 @@ def testPermissionFilter(self): path_regex=".", perm_mode=0o0040010, perm_mask=0o0040010, - cross_devs=True) + cross_devs=True, + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 3) @@ -311,12 +337,14 @@ def testUIDFilter(self): """Test filtering based on uid happens correctly.""" pathspec = rdf_paths.PathSpec( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) # Look for files that have uid of 60 request = rdf_client_fs.FindSpec( - pathspec=pathspec, path_regex=".", uid=60, cross_devs=True) + pathspec=pathspec, path_regex=".", uid=60, cross_devs=True + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 2) @@ -328,7 +356,8 @@ def testUIDFilter(self): # Look for files that have uid of 0 request = rdf_client_fs.FindSpec( - pathspec=pathspec, path_regex=".", uid=0, cross_devs=True) + pathspec=pathspec, path_regex=".", uid=0, cross_devs=True + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 3) @@ -340,12 +369,14 @@ def testGIDFilter(self): """Test filtering based on gid happens correctly.""" pathspec = rdf_paths.PathSpec( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) # Look for files that have gid of 500 request = rdf_client_fs.FindSpec( - pathspec=pathspec, path_regex=".", gid=500, cross_devs=True) + pathspec=pathspec, path_regex=".", gid=500, cross_devs=True + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 2) @@ -357,7 +388,8 @@ def testGIDFilter(self): # Look for files that have uid of 900 request = rdf_client_fs.FindSpec( - pathspec=pathspec, path_regex=".", gid=900, cross_devs=True) + pathspec=pathspec, path_regex=".", gid=900, cross_devs=True + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 2) @@ -370,12 +402,14 @@ def testUIDAndGIDFilter(self): """Test filtering based on combination of uid and gid happens correctly.""" pathspec = rdf_paths.PathSpec( - path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/mock2/", pathtype=rdf_paths.PathSpec.PathType.OS + ) # Look for files that have uid of 90 and gid of 500 request = rdf_client_fs.FindSpec( - pathspec=pathspec, path_regex=".", uid=90, gid=500, cross_devs=True) + pathspec=pathspec, path_regex=".", 
uid=90, gid=500, cross_devs=True + ) all_files = self.RunAction(searching.Find, request) self.assertEmpty(all_files) @@ -383,7 +417,8 @@ def testUIDAndGIDFilter(self): # Look for files that have uid of 50 and gid of 500 request = rdf_client_fs.FindSpec( - pathspec=pathspec, path_regex=".", uid=50, gid=500, cross_devs=True) + pathspec=pathspec, path_regex=".", uid=50, gid=500, cross_devs=True + ) all_files = self.RunAction(searching.Find, request) self.assertLen(all_files, 2) @@ -399,21 +434,26 @@ def testExtAttrsCollection(self): with temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath: foo_filepath = temp.TempFilePath(dir=temp_dirpath) filesystem_test_lib.SetExtAttr( - foo_filepath, name="user.quux", value="foo") + foo_filepath, name="user.quux", value="foo" + ) bar_filepath = temp.TempFilePath(dir=temp_dirpath) filesystem_test_lib.SetExtAttr( - bar_filepath, name="user.quux", value="bar") + bar_filepath, name="user.quux", value="bar" + ) baz_filepath = temp.TempFilePath(dir=temp_dirpath) filesystem_test_lib.SetExtAttr( - baz_filepath, name="user.quux", value="baz") + baz_filepath, name="user.quux", value="baz" + ) request = rdf_client_fs.FindSpec( pathspec=rdf_paths.PathSpec( - path=temp_dirpath, pathtype=rdf_paths.PathSpec.PathType.OS), + path=temp_dirpath, pathtype=rdf_paths.PathSpec.PathType.OS + ), path_glob="*", - collect_ext_attrs=True) + collect_ext_attrs=True, + ) hits = self.RunAction(searching.Find, request) @@ -437,8 +477,9 @@ def setUp(self): super().setUp() # Install the mock - vfs_overrider = vfs_test_lib.VFSOverrider(rdf_paths.PathSpec.PathType.OS, - MockVFSHandlerFind) + vfs_overrider = vfs_test_lib.VFSOverrider( + rdf_paths.PathSpec.PathType.OS, MockVFSHandlerFind + ) vfs_overrider.Start() self.addCleanup(vfs_overrider.Stop) self.filename = "/mock2/directory1/grepfile.txt" @@ -450,17 +491,40 @@ def testGrep(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"10", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = os.path.join(self.base_path, "numbers.txt") request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 0 result = self.RunAction(searching.Grep, request) hits = [x.offset for x in result] - self.assertEqual(hits, [ - 18, 288, 292, 296, 300, 304, 308, 312, 316, 320, 324, 329, 729, 1129, - 1529, 1929, 2329, 2729, 3129, 3529, 3888 - ]) + self.assertEqual( + hits, + [ + 18, + 288, + 292, + 296, + 300, + 304, + 308, + 312, + 316, + 320, + 324, + 329, + 729, + 1129, + 1529, + 1929, + 2329, + 2729, + 3129, + 3529, + 3888, + ], + ) for x in result: self.assertIn(b"10", utils.Xor(x.data, self.XOR_OUT_KEY)) self.assertEqual(request.target.path, x.pathspec.path) @@ -475,14 +539,38 @@ def testGrepRegex(self): start_offset=0, target=rdf_paths.PathSpec( path=os.path.join(self.base_path, "numbers.txt"), - pathtype=rdf_paths.PathSpec.PathType.OS)) + pathtype=rdf_paths.PathSpec.PathType.OS, + ), + ) result = self.RunAction(searching.Grep, request) hits = [x.offset for x in result] - self.assertEqual(hits, [ - 18, 288, 292, 296, 300, 304, 308, 312, 316, 320, 324, 329, 729, 1129, - 1529, 1929, 2329, 2729, 3129, 3529, 3888 - ]) + self.assertEqual( + hits, + [ + 18, + 288, + 292, + 296, + 300, + 304, + 308, + 312, + 316, + 320, + 324, + 329, + 729, + 1129, + 1529, + 1929, + 2329, + 2729, + 3129, + 3529, + 3888, + ], + ) for x in result: self.assertIn(b"10", utils.Xor(x.data, self.XOR_OUT_KEY)) @@ -494,7 +582,8 @@ def testGrepLength(self): request = 
rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 0 @@ -506,7 +595,8 @@ def testGrepLength(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 0 @@ -523,7 +613,8 @@ def testGrepOffset(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 0 @@ -535,7 +626,8 @@ def testGrepOffset(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 5 @@ -548,7 +640,8 @@ def testGrepOffset(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 11 @@ -564,7 +657,8 @@ def testOffsetAndLength(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 11 @@ -582,7 +676,8 @@ def testSecondBuffer(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 0 @@ -602,7 +697,8 @@ def testBufferBoundaries(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 0 @@ -624,7 +720,8 @@ def testSnippetSize(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 0 @@ -648,7 +745,8 @@ def testGrepEverywhere(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 0 @@ -658,7 +756,7 @@ def testGrepEverywhere(self): result = self.RunAction(searching.Grep, request) self.assertLen(result, 1) self.assertEqual(result[0].offset, offset) - expected = data[max(0, offset - 
10):offset + 3 + 10] + expected = data[max(0, offset - 10) : offset + 3 + 10] self.assertLen(expected, result[0].length) self.assertEqual(utils.Xor(result[0].data, self.XOR_OUT_KEY), expected) @@ -672,7 +770,8 @@ def testHitLimit(self): request = rdf_client_fs.GrepSpec( literal=utils.Xor(b"HIT", self.XOR_IN_KEY), xor_in_key=self.XOR_IN_KEY, - xor_out_key=self.XOR_OUT_KEY) + xor_out_key=self.XOR_OUT_KEY, + ) request.target.path = self.filename request.target.pathtype = rdf_paths.PathSpec.PathType.OS request.start_offset = 0 diff --git a/grr/client/grr_response_client/client_actions/standard.py b/grr/client/grr_response_client/client_actions/standard.py index de9d89c72c..0c6238e390 100644 --- a/grr/client/grr_response_client/client_actions/standard.py +++ b/grr/client/grr_response_client/client_actions/standard.py @@ -33,6 +33,7 @@ class ReadBuffer(actions.ActionPlugin): """Reads a buffer from a file and returns it to a server callback.""" + in_rdfvalue = rdf_client.BufferReference out_rdfvalues = [rdf_client.BufferReference] @@ -57,11 +58,14 @@ def Run(self, args): # Now return the data to the server self.SendReply( rdf_client.BufferReference( - offset=offset, data=data, length=len(data), pathspec=fd.pathspec)) + offset=offset, data=data, length=len(data), pathspec=fd.pathspec + ) + ) class TransferBuffer(actions.ActionPlugin): """Reads a buffer from a file and returns it to the server efficiently.""" + in_rdfvalue = rdf_client.BufferReference out_rdfvalues = [rdf_client.BufferReference] @@ -72,13 +76,12 @@ def Run(self, args): raise RuntimeError("Can not read buffers this large.") data = vfs.ReadVFS( - args.pathspec, - args.offset, - args.length, - progress_callback=self.Progress) + args.pathspec, args.offset, args.length, progress_callback=self.Progress + ) result = rdf_protodict.DataBlob( data=zlib.compress(data), - compression=rdf_protodict.DataBlob.CompressionType.ZCOMPRESSION) + compression=rdf_protodict.DataBlob.CompressionType.ZCOMPRESSION, + ) digest = hashlib.sha256(data).digest() @@ -89,17 +92,21 @@ def Run(self, args): # Now return the data to the server into the special TransferStore well # known flow. self.grr_worker.SendReply( - result, session_id=rdfvalue.SessionID(flow_name="TransferStore")) + result, session_id=rdfvalue.SessionID(flow_name="TransferStore") + ) # Now report the hash of this blob to our flow as well as the offset and # length. self.SendReply( rdf_client.BufferReference( - offset=args.offset, length=len(data), data=digest)) + offset=args.offset, length=len(data), data=digest + ) + ) class HashBuffer(actions.ActionPlugin): """Hash a buffer from a file and returns it to the server efficiently.""" + in_rdfvalue = rdf_client.BufferReference out_rdfvalues = [rdf_client.BufferReference] @@ -117,11 +124,14 @@ def Run(self, args): # length. 
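The TransferBuffer hunk above keeps its original shape after reformatting: the chunk itself is compressed and handed to the well-known TransferStore flow, while the requesting flow only ever receives the offset, length and SHA-256 digest. A stdlib-only sketch of that split, with plain dicts standing in for the DataBlob and BufferReference rdfvalues:

import hashlib
import zlib


def transfer_buffer_sketch(data: bytes, offset: int):
  # Payload for the TransferStore flow: the compressed bytes themselves.
  blob = {"data": zlib.compress(data), "compression": "ZCOMPRESSION"}
  # Reply to the requesting flow: no file content, just a 32-byte digest
  # plus the offset/length needed to reference the chunk later.
  reply = {"offset": offset, "length": len(data),
           "data": hashlib.sha256(data).digest()}
  return blob, reply


blob, reply = transfer_buffer_sketch(b"A" * 4096, offset=0)
assert len(reply["data"]) == 32 and len(blob["data"]) < 4096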
self.SendReply( rdf_client.BufferReference( - offset=args.offset, length=len(data), data=digest)) + offset=args.offset, length=len(data), data=digest + ) + ) class HashFile(actions.ActionPlugin): """Hash an entire file using multiple algorithms.""" + in_rdfvalue = rdf_client_action.FingerprintRequest out_rdfvalues = [rdf_client_action.FingerprintResponse] @@ -143,14 +153,14 @@ def Run(self, args): hash_object = hasher.GetHashObject() response = rdf_client_action.FingerprintResponse( - pathspec=fd.pathspec, - bytes_read=hash_object.num_bytes, - hash=hash_object) + pathspec=fd.pathspec, bytes_read=hash_object.num_bytes, hash=hash_object + ) self.SendReply(response) class ListDirectory(ReadBuffer): """Lists all the files in a directory.""" + in_rdfvalue = rdf_client_action.ListDirRequest out_rdfvalues = [rdf_client_fs.StatEntry] @@ -186,7 +196,8 @@ def Run(self, args): fd = vfs.VFSOpen(args.pathspec, progress_callback=self.Progress) stat_entry = fd.Stat( - ext_attrs=args.collect_ext_attrs, follow_symlink=args.follow_symlink) + ext_attrs=args.collect_ext_attrs, follow_symlink=args.follow_symlink + ) self.SendReply(stat_entry) except (IOError, OSError) as error: @@ -210,8 +221,8 @@ def ExecuteCommandFromClient(command): (stdout, stderr, status, time_used) = res # Limit output to 10MB so our response doesn't get too big. - stdout = stdout[:10 * 1024 * 1024] - stderr = stderr[:10 * 1024 * 1024] + stdout = stdout[: 10 * 1024 * 1024] + stderr = stderr[: 10 * 1024 * 1024] yield rdf_client_action.ExecuteResponse( request=command, @@ -219,7 +230,8 @@ def ExecuteCommandFromClient(command): stderr=stderr, exit_status=status, # We have to return microseconds. - time_used=int(1e6 * time_used)) + time_used=int(1e6 * time_used), + ) class ExecuteCommand(actions.ActionPlugin): @@ -247,6 +259,7 @@ class ExecuteBinaryCommand(actions.ActionPlugin): NOTE: If the binary is too large to fit inside a single request, the request will have the more_data flag enabled, indicating more data is coming. """ + in_rdfvalue = rdf_client_action.ExecuteBinaryRequest out_rdfvalues = [rdf_client_action.ExecuteBinaryResponse] @@ -259,7 +272,8 @@ def WriteBlobToFile(self, request): mode = "r+b" temp_file = tempfiles.CreateGRRTempFile( - filename=request.write_path, mode=mode) + filename=request.write_path, mode=mode + ) with temp_file: path = temp_file.name temp_file.seek(0, 2) @@ -283,7 +297,8 @@ def Run(self, args): """Run.""" # Verify the executable blob. args.executable.Verify( - config.CONFIG["Client.executable_signing_public_key"]) + config.CONFIG["Client.executable_signing_public_key"] + ) path = self.WriteBlobToFile(args) @@ -294,12 +309,13 @@ def Run(self, args): def ProcessFile(self, path, args): res = client_utils_common.Execute( - path, args.args, args.time_limit, bypass_allowlist=True) + path, args.args, args.time_limit, bypass_allowlist=True + ) (stdout, stderr, status, time_used) = res # Limit output to 10MB so our response doesn't get too big. - stdout = stdout[:10 * 1024 * 1024] - stderr = stderr[:10 * 1024 * 1024] + stdout = stdout[: 10 * 1024 * 1024] + stderr = stderr[: 10 * 1024 * 1024] self.SendReply( rdf_client_action.ExecuteBinaryResponse( @@ -307,7 +323,9 @@ def ProcessFile(self, path, args): stderr=stderr, exit_status=status, # We have to return microseconds. 
- time_used=int(1e6 * time_used))) + time_used=int(1e6 * time_used), + ) + ) class ExecutePython(actions.ActionPlugin): @@ -319,6 +337,7 @@ class ExecutePython(actions.ActionPlugin): This is protected by CONFIG[PrivateKeys.executable_signing_private_key], which should be stored offline and well protected. """ + in_rdfvalue = rdf_client_action.ExecutePythonRequest out_rdfvalues = [rdf_client_action.ExecutePythonResponse] @@ -327,7 +346,8 @@ def Run(self, args): time_start = rdfvalue.RDFDatetime.Now() args.python_code.Verify( - config.CONFIG["Client.executable_signing_public_key"]) + config.CONFIG["Client.executable_signing_public_key"] + ) # The execed code can assign to this variable if it wants to return data. logging.debug("exec for python code %s", args.python_code.data[0:100]) @@ -353,8 +373,9 @@ def Run(self, args): time_used = rdfvalue.RDFDatetime.Now() - time_start self.SendReply( rdf_client_action.ExecutePythonResponse( - time_used=time_used.ToInt(rdfvalue.MICROSECONDS), - return_val=output)) + time_used=time_used.ToInt(rdfvalue.MICROSECONDS), return_val=output + ) + ) # TODO(hanuszczak): This class has been moved out of `ExecutePython::Run`. The @@ -377,6 +398,7 @@ def write(self, text: Text): # pylint: disable=invalid-name class Segfault(actions.ActionPlugin): """This action is just for debugging. It induces a segfault.""" + in_rdfvalue = None out_rdfvalues = [None] @@ -402,6 +424,7 @@ def ListProcessesFromClient(args): class ListProcesses(actions.ActionPlugin): """This action lists all the processes running on a machine.""" + in_rdfvalue = None out_rdfvalues = [rdf_client.Process] @@ -448,7 +471,8 @@ def StatFSFromClient(args): sectors_per_allocation_unit=1, total_allocation_units=st.f_blocks, actual_available_allocation_units=st.f_bavail, - unixvolume=unix) + unixvolume=unix, + ) class StatFS(actions.ActionPlugin): @@ -459,6 +483,7 @@ class StatFS(actions.ActionPlugin): Note that a statvfs call for a network filesystem (e.g. NFS) that is unavailable, e.g. due to no network, will result in the call blocking. """ + in_rdfvalue = rdf_client_action.StatFSRequest out_rdfvalues = [rdf_client_fs.Volume] diff --git a/grr/client/grr_response_client/client_actions/standard_test.py b/grr/client/grr_response_client/client_actions/standard_test.py index f16d7af27d..0915447473 100644 --- a/grr/client/grr_response_client/client_actions/standard_test.py +++ b/grr/client/grr_response_client/client_actions/standard_test.py @@ -34,7 +34,8 @@ class TestExecutePython(client_test_lib.EmptyActionTest): def setUp(self): super().setUp() self.signing_key = config.CONFIG[ - "PrivateKeys.executable_signing_private_key"] + "PrivateKeys.executable_signing_private_key" + ] def testExecutePython(self): """Test the basic ExecutePython action.""" @@ -132,15 +133,23 @@ def testExecuteModifiedPython(self): request = rdf_client_action.ExecutePythonRequest(python_code=signed_blob) # Should raise since the code has been modified. - self.assertRaises(rdf_crypto.VerificationError, self.RunAction, - standard.ExecutePython, request) + self.assertRaises( + rdf_crypto.VerificationError, + self.RunAction, + standard.ExecutePython, + request, + ) # Lets also adjust the hash. 
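The ExecutePython and ExecuteBinaryCommand hunks above all pass through the same gate: the payload is a SignedBlob produced offline with PrivateKeys.executable_signing_private_key, and the client calls Verify() against Client.executable_signing_public_key before anything runs. A minimal sketch of that round trip, reusing only calls visible in this diff; the import paths are assumed from the usual GRR source layout:

from grr_response_core import config
from grr_response_core.lib.rdfvalues import client_action as rdf_client_action
from grr_response_core.lib.rdfvalues import crypto as rdf_crypto


def make_signed_python_request(python_code: str):
  # Server/offline side: sign the code with the private signing key.
  signing_key = config.CONFIG["PrivateKeys.executable_signing_private_key"]
  signed_blob = rdf_crypto.SignedBlob()
  signed_blob.Sign(python_code.encode("utf-8"), signing_key)
  return rdf_client_action.ExecutePythonRequest(python_code=signed_blob)


def check_request(request) -> None:
  # Client side (mirrors ExecutePython.Run): raises rdf_crypto.VerificationError
  # if either the data or the digest was tampered with, so modified code is
  # rejected before it ever reaches exec().
  request.python_code.Verify(
      config.CONFIG["Client.executable_signing_public_key"]
  )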
signed_blob.digest = hashlib.sha256(signed_blob.data).digest() request = rdf_client_action.ExecutePythonRequest(python_code=signed_blob) - self.assertRaises(rdf_crypto.VerificationError, self.RunAction, - standard.ExecutePython, request) + self.assertRaises( + rdf_crypto.VerificationError, + self.RunAction, + standard.ExecutePython, + request, + ) # Make sure the code never ran. self.assertEqual(sys.TEST_VAL, "original") @@ -152,8 +161,9 @@ def testExecuteBrokenPython(self): signed_blob.Sign(python_code.encode("utf-8"), self.signing_key) request = rdf_client_action.ExecutePythonRequest(python_code=signed_blob) - self.assertRaises(ValueError, self.RunAction, standard.ExecutePython, - request) + self.assertRaises( + ValueError, self.RunAction, standard.ExecutePython, request + ) def testExecuteBinary(self): """Test the basic ExecuteBinaryCommand action.""" @@ -166,14 +176,16 @@ def testExecuteBinary(self): signed_blob.Sign(open(cmd, "rb").read(), self.signing_key) request = rdf_client_action.ExecuteBinaryRequest( - executable=signed_blob, args=args, write_path="ablob") + executable=signed_blob, args=args, write_path="ablob" + ) result = self.RunAction(standard.ExecuteBinaryCommand, request)[0] if platform.system() != "Windows": # Windows time resolution is too coarse. self.assertGreater(result.time_used, 0) - self.assertEqual("foobar{}".format(os.linesep).encode("utf-8"), - result.stdout) + self.assertEqual( + "foobar{}".format(os.linesep).encode("utf-8"), result.stdout + ) def testReturnVals(self): """Test return values.""" @@ -194,8 +206,12 @@ def testWrongKey(self): signed_blob = rdf_crypto.SignedBlob() signed_blob.Sign(python_code.encode("utf-8"), signing_key) request = rdf_client_action.ExecutePythonRequest(python_code=signed_blob) - self.assertRaises(rdf_crypto.VerificationError, self.RunAction, - standard.ExecutePython, request) + self.assertRaises( + rdf_crypto.VerificationError, + self.RunAction, + standard.ExecutePython, + request, + ) def testArgs(self): """Test passing arguments.""" @@ -210,7 +226,8 @@ def testArgs(self): signed_blob.Sign(python_code.encode("utf-8"), self.signing_key) pdict = rdf_protodict.Dict({"test": "dict_arg", 43: "dict_arg2"}) request = rdf_client_action.ExecutePythonRequest( - python_code=signed_blob, py_args=pdict) + python_code=signed_blob, py_args=pdict + ) result = self.RunAction(standard.ExecutePython, request)[0] self.assertEqual(result.return_val, "dict_arg") self.assertEqual(sys.TEST_VAL, "dict_arg2") @@ -219,15 +236,18 @@ def testArgs(self): class GetFileStatTest(client_test_lib.EmptyActionTest): # TODO: - @unittest.skipIf(platform.system() == "Windows", - "Skipping due to temp file locking issues.") + @unittest.skipIf( + platform.system() == "Windows", + "Skipping due to temp file locking issues.", + ) def testStatSize(self): with temp.AutoTempFilePath() as temp_filepath: with io.open(temp_filepath, "wb") as temp_file: temp_file.write(b"123456") pathspec = rdf_paths.PathSpec( - path=temp_filepath, pathtype=rdf_paths.PathSpec.PathType.OS) + path=temp_filepath, pathtype=rdf_paths.PathSpec.PathType.OS + ) request = rdf_client_action.GetFileStatRequest(pathspec=pathspec) results = self.RunAction(standard.GetFileStat, request) @@ -238,13 +258,16 @@ def testStatSize(self): def testStatExtAttrsEnabled(self): with temp.AutoTempFilePath() as temp_filepath: filesystem_test_lib.SetExtAttr( - temp_filepath, name="user.foo", value="bar") + temp_filepath, name="user.foo", value="bar" + ) pathspec = rdf_paths.PathSpec( - path=temp_filepath, 
pathtype=rdf_paths.PathSpec.PathType.OS) + path=temp_filepath, pathtype=rdf_paths.PathSpec.PathType.OS + ) request = rdf_client_action.GetFileStatRequest( - pathspec=pathspec, collect_ext_attrs=True) + pathspec=pathspec, collect_ext_attrs=True + ) results = self.RunAction(standard.GetFileStat, request) self.assertLen(results, 1) @@ -255,13 +278,16 @@ def testStatExtAttrsEnabled(self): def testStatExtAttrsDisabled(self): with temp.AutoTempFilePath() as temp_filepath: filesystem_test_lib.SetExtAttr( - temp_filepath, name="user.foo", value="bar") + temp_filepath, name="user.foo", value="bar" + ) pathspec = rdf_paths.PathSpec( - path=temp_filepath, pathtype=rdf_paths.PathSpec.PathType.OS) + path=temp_filepath, pathtype=rdf_paths.PathSpec.PathType.OS + ) request = rdf_client_action.GetFileStatRequest( - pathspec=pathspec, collect_ext_attrs=False) + pathspec=pathspec, collect_ext_attrs=False + ) results = self.RunAction(standard.GetFileStat, request) self.assertLen(results, 1) @@ -321,7 +347,8 @@ class TestNetworkByteLimits(client_test_lib.EmptyActionTest): def setUp(self): super().setUp() pathspec = rdf_paths.PathSpec( - path="/nothing", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/nothing", pathtype=rdf_paths.PathSpec.PathType.OS + ) self.buffer_ref = rdf_client.BufferReference(pathspec=pathspec, length=5000) self.data = b"X" * 500 @@ -336,7 +363,8 @@ def testTransferNetworkByteLimitError(self): name="TransferBuffer", payload=self.buffer_ref, network_bytes_limit=300, - generate_task_id=True) + generate_task_id=True, + ) # We just get a client alert and a status message back. responses = self.transfer_buf.HandleMessage(message) @@ -346,22 +374,25 @@ def testTransferNetworkByteLimitError(self): status = responses[1].payload self.assertIn("Action exceeded network send limit", str(status.backtrace)) - self.assertEqual(status.status, - rdf_flows.GrrStatus.ReturnedStatus.NETWORK_LIMIT_EXCEEDED) + self.assertEqual( + status.status, rdf_flows.GrrStatus.ReturnedStatus.NETWORK_LIMIT_EXCEEDED + ) def testTransferNetworkByteLimit(self): message = rdf_flows.GrrMessage( name="TransferBuffer", payload=self.buffer_ref, network_bytes_limit=900, - generate_task_id=True) + generate_task_id=True, + ) responses = self.transfer_buf.HandleMessage(message) for response in responses: if isinstance(response, rdf_flows.GrrStatus): - self.assertEqual(response.payload.status, - rdf_flows.GrrStatus.ReturnedStatus.OK) + self.assertEqual( + response.payload.status, rdf_flows.GrrStatus.ReturnedStatus.OK + ) def main(argv): diff --git a/grr/client/grr_response_client/client_actions/tempfiles.py b/grr/client/grr_response_client/client_actions/tempfiles.py index 36dc114bdc..802be0ccc7 100644 --- a/grr/client/grr_response_client/client_actions/tempfiles.py +++ b/grr/client/grr_response_client/client_actions/tempfiles.py @@ -9,7 +9,6 @@ import stat import sys import tempfile -import threading from typing import Text import psutil @@ -75,8 +74,10 @@ def EnsureTempDirIsSane(directory): # Make directory 700 before we write the file if sys.platform == "win32": from grr_response_client import client_utils_windows # pylint: disable=g-import-not-at-top - client_utils_windows.WinChmod(directory, - ["FILE_GENERIC_READ", "FILE_GENERIC_WRITE"]) + + client_utils_windows.WinChmod( + directory, ["FILE_GENERIC_READ", "FILE_GENERIC_WRITE"] + ) else: os.chmod(directory, stat.S_IXUSR | stat.S_IRUSR | stat.S_IWUSR) @@ -92,8 +93,8 @@ class TemporaryDirectory(object): Attributes: path: The path to the temporary directory. 
cleanup: A boolean to delete the directory when exiting. Note that if an - exception is raised, the directory will be deleted regardless of the - value of cleanup. + exception is raised, the directory will be deleted regardless of the value + of cleanup. """ def __init__(self, cleanup=True): @@ -112,7 +113,7 @@ def __exit__(self, exc_type, unused_exc_value, unused_traceback): shutil.rmtree(self.path, ignore_errors=True) -def CreateGRRTempFile(filename=None, lifetime=0, mode="w+b", suffix=""): +def CreateGRRTempFile(filename=None, mode="w+b"): """Open file with GRR prefix in directory to allow easy deletion. Missing parent dirs will be created. If an existing directory is specified @@ -124,19 +125,11 @@ def CreateGRRTempFile(filename=None, lifetime=0, mode="w+b", suffix=""): the caller doesn't specify a directory on windows we use the directory we are executing from as a safe default. - If lifetime is specified a housekeeping thread is created to delete the file - after lifetime seconds. Files won't be deleted by default. - Args: filename: The name of the file to use. Note that setting both filename and - directory name is not allowed. - - lifetime: time in seconds before we should delete this tempfile. - + directory name is not allowed. mode: The mode to open the file. - suffix: optional suffix to use for the temp file - Returns: Python file object @@ -152,24 +145,19 @@ def CreateGRRTempFile(filename=None, lifetime=0, mode="w+b", suffix=""): prefix = config.CONFIG.Get("Client.tempfile_prefix") if filename is None: outfile = tempfile.NamedTemporaryFile( - prefix=prefix, suffix=suffix, dir=directory, delete=False) + prefix=prefix, dir=directory, delete=False + ) else: if filename.startswith("/") or filename.startswith("\\"): raise ValueError("Filename must be relative") - if suffix: - filename = "%s.%s" % (filename, suffix) - outfile = open(os.path.join(directory, filename), mode) - if lifetime > 0: - cleanup = threading.Timer(lifetime, DeleteGRRTempFile, (outfile.name,)) - cleanup.start() - # Fix perms on the file, since this code is used for writing executable blobs # we apply RWX. if sys.platform == "win32": from grr_response_client import client_utils_windows # pylint: disable=g-import-not-at-top + client_utils_windows.WinChmod(outfile.name, ["FILE_ALL_ACCESS"]) else: os.chmod(outfile.name, stat.S_IXUSR | stat.S_IRUSR | stat.S_IWUSR) @@ -177,7 +165,7 @@ def CreateGRRTempFile(filename=None, lifetime=0, mode="w+b", suffix=""): return outfile -def CreateGRRTempFileVFS(filename=None, lifetime=0, mode="w+b", suffix=""): +def CreateGRRTempFileVFS(filename=None, mode="w+b"): """Creates a GRR VFS temp file. This function is analogous to CreateGRRTempFile but returns an open VFS handle @@ -185,22 +173,17 @@ def CreateGRRTempFileVFS(filename=None, lifetime=0, mode="w+b", suffix=""): Args: filename: The name of the file to use. Note that setting both filename and - directory name is not allowed. - - lifetime: time in seconds before we should delete this tempfile. - + directory name is not allowed. mode: The mode to open the file. - suffix: optional suffix to use for the temp file - Returns: An open file handle to the new file and the corresponding pathspec. 
""" - fd = CreateGRRTempFile( - filename=filename, lifetime=lifetime, mode=mode, suffix=suffix) + fd = CreateGRRTempFile(filename=filename, mode=mode) pathspec = rdf_paths.PathSpec( - path=fd.name, pathtype=rdf_paths.PathSpec.PathType.TMPFILE) + path=fd.name, pathtype=rdf_paths.PathSpec.PathType.TMPFILE + ) return fd, pathspec @@ -246,9 +229,12 @@ def DeleteGRRTempFile(path): GetTempDirForRoot(root) for root in config.CONFIG["Client.tempdir_roots"] ] if not _CheckIfPathIsValidForDeletion( - path, prefix=prefix, directories=directories): - msg = ("Can't delete temp file %s. Filename must start with %s " - "or lie within any of %s.") + path, prefix=prefix, directories=directories + ): + msg = ( + "Can't delete temp file %s. Filename must start with %s " + "or lie within any of %s." + ) raise ErrorNotTempFile(msg % (path, prefix, ";".join(directories))) if os.path.exists(path): @@ -261,6 +247,7 @@ def DeleteGRRTempFile(path): class DeleteGRRTempFiles(actions.ActionPlugin): """Delete all the GRR temp files in a directory.""" + in_rdfvalue = rdf_paths.PathSpec out_rdfvalues = [rdf_client.LogMessage] @@ -273,8 +260,9 @@ def Run(self, args): If path is a regular file and starts with Client.tempfile_prefix delete it. Args: - args: pathspec pointing to directory containing temp files to be - deleted, or a single file to be deleted. + args: pathspec pointing to directory containing temp files to be deleted, + or a single file to be deleted. + Returns: deleted: array of filename strings that were deleted Raises: @@ -346,4 +334,5 @@ def Run(self, args): path = GetDefaultGRRTempDirectory() total, used, free, _ = psutil.disk_usage(path) self.SendReply( - rdf_client_fs.DiskUsage(path=path, total=total, used=used, free=free)) + rdf_client_fs.DiskUsage(path=path, total=total, used=used, free=free) + ) diff --git a/grr/client/grr_response_client/client_actions/tempfiles_test.py b/grr/client/grr_response_client/client_actions/tempfiles_test.py index 87f8aee1d4..b9fb4d6062 100644 --- a/grr/client/grr_response_client/client_actions/tempfiles_test.py +++ b/grr/client/grr_response_client/client_actions/tempfiles_test.py @@ -27,11 +27,12 @@ def setUp(self): # For this test it has to be different from the temp directory # so we create a new one. 
self.client_tempdir = tempfile.mkdtemp( - dir=config.CONFIG.Get("Client.tempdir_roots")[0]) + dir=config.CONFIG.Get("Client.tempdir_roots")[0] + ) tempdir_overrider = test_lib.ConfigOverrider({ "Client.tempdir_roots": [os.path.dirname(self.client_tempdir)], - "Client.grr_tempdir": os.path.basename(self.client_tempdir) + "Client.grr_tempdir": os.path.basename(self.client_tempdir), }) tempdir_overrider.Start() self.addCleanup(tempdir_overrider.Stop) @@ -53,8 +54,9 @@ def testCreateAndDelete(self): with io.open(os.path.join(self.temp_dir, "notatmpfile"), "wb") as fd: fd.write(b"something") self.assertTrue(os.path.exists(fd.name)) - self.assertRaises(tempfiles.ErrorNotTempFile, tempfiles.DeleteGRRTempFile, - fd.name) + self.assertRaises( + tempfiles.ErrorNotTempFile, tempfiles.DeleteGRRTempFile, fd.name + ) self.assertTrue(os.path.exists(fd.name)) def testWrongOwnerGetsFixed(self): @@ -98,7 +100,7 @@ def setUp(self): os.makedirs(self.dirname) tempdir_overrider = test_lib.ConfigOverrider({ "Client.tempdir_roots": [os.path.dirname(self.dirname)], - "Client.grr_tempdir": os.path.basename(self.dirname) + "Client.grr_tempdir": os.path.basename(self.dirname), }) tempdir_overrider.Start() self.addCleanup(tempdir_overrider.Stop) @@ -114,7 +116,8 @@ def setUp(self): self.assertTrue(os.path.exists(self.temp_fd2.name)) self.pathspec = rdf_paths.PathSpec( - path=self.dirname, pathtype=rdf_paths.PathSpec.PathType.OS) + path=self.dirname, pathtype=rdf_paths.PathSpec.PathType.OS + ) def _SetUpTempDirStructure(self, grr_tempdir="grr_temp"): temproot1 = utils.JoinPath(self.temp_dir, "del_test1") @@ -146,21 +149,26 @@ def _SetUpTempDirStructure(self, grr_tempdir="grr_temp"): self.assertTrue(os.path.exists(file2)) self.assertTrue(os.path.exists(not_a_grr_file1)) self.assertTrue(os.path.exists(not_a_grr_file2)) - return ([temproot1, temproot2, temproot3], [tempdir1, tempdir2], [tempdir3], - [file1, file2], [not_a_grr_file1, not_a_grr_file1]) + return ( + [temproot1, temproot2, temproot3], + [tempdir1, tempdir2], + [tempdir3], + [file1, file2], + [not_a_grr_file1, not_a_grr_file1], + ) def testDeleteMultipleRoots(self): temp_dir = "grr_temp" test_data = self._SetUpTempDirStructure(temp_dir) roots, _, invalid_temp_dirs, temp_files, other_files = test_data - with test_lib.ConfigOverrider({ - "Client.tempdir_roots": roots, - "Client.grr_tempdir": temp_dir - }): + with test_lib.ConfigOverrider( + {"Client.tempdir_roots": roots, "Client.grr_tempdir": temp_dir} + ): - result = self.RunAction(tempfiles.DeleteGRRTempFiles, - rdf_paths.PathSpec()) + result = self.RunAction( + tempfiles.DeleteGRRTempFiles, rdf_paths.PathSpec() + ) self.assertLen(result, 1) log = result[0].data for f in temp_files: @@ -178,21 +186,24 @@ def testDeleteFilesInRoot(self): test_data = self._SetUpTempDirStructure(temp_dir) roots, _, _, temp_files, other_files = test_data - with test_lib.ConfigOverrider({ - "Client.tempdir_roots": roots, - "Client.grr_tempdir": temp_dir - }): + with test_lib.ConfigOverrider( + {"Client.tempdir_roots": roots, "Client.grr_tempdir": temp_dir} + ): for f in temp_files: - result = self.RunAction(tempfiles.DeleteGRRTempFiles, - rdf_paths.PathSpec(path=f)) + result = self.RunAction( + tempfiles.DeleteGRRTempFiles, rdf_paths.PathSpec(path=f) + ) self.assertLen(result, 1) self.assertIn(f, result[0].data) for f in other_files: - self.assertRaises(tempfiles.ErrorNotTempFile, self.RunAction, - tempfiles.DeleteGRRTempFiles, - rdf_paths.PathSpec(path=f)) + self.assertRaises( + tempfiles.ErrorNotTempFile, + self.RunAction, + 
tempfiles.DeleteGRRTempFiles, + rdf_paths.PathSpec(path=f), + ) def testDeleteGRRTempFilesInDirectory(self): result = self.RunAction(tempfiles.DeleteGRRTempFiles, self.pathspec)[0] @@ -204,7 +215,8 @@ def testDeleteGRRTempFilesInDirectory(self): def testDeleteGRRTempFilesSpecificPath(self): self.pathspec = rdf_paths.PathSpec( - path=self.temp_fd.name, pathtype=rdf_paths.PathSpec.PathType.OS) + path=self.temp_fd.name, pathtype=rdf_paths.PathSpec.PathType.OS + ) result = self.RunAction(tempfiles.DeleteGRRTempFiles, self.pathspec)[0] self.assertTrue(os.path.exists(self.not_tempfile)) self.assertFalse(os.path.exists(self.temp_fd.name)) @@ -214,9 +226,14 @@ def testDeleteGRRTempFilesSpecificPath(self): def testDeleteGRRTempFilesPathDoesNotExist(self): self.pathspec = rdf_paths.PathSpec( - path="/does/not/exist", pathtype=rdf_paths.PathSpec.PathType.OS) - self.assertRaises(tempfiles.ErrorBadPath, self.RunAction, - tempfiles.DeleteGRRTempFiles, self.pathspec) + path="/does/not/exist", pathtype=rdf_paths.PathSpec.PathType.OS + ) + self.assertRaises( + tempfiles.ErrorBadPath, + self.RunAction, + tempfiles.DeleteGRRTempFiles, + self.pathspec, + ) def testOneFileFails(self): # Sneak in a non existing file. diff --git a/grr/client/grr_response_client/client_actions/timeline.py b/grr/client/grr_response_client/client_actions/timeline.py index 73aa163d4c..aaa712f78f 100644 --- a/grr/client/grr_response_client/client_actions/timeline.py +++ b/grr/client/grr_response_client/client_actions/timeline.py @@ -4,7 +4,6 @@ import hashlib import os import stat as stat_mode - from typing import Iterator from typing import Optional diff --git a/grr/client/grr_response_client/client_actions/timeline_test.py b/grr/client/grr_response_client/client_actions/timeline_test.py index c915f90533..46bc768cb5 100644 --- a/grr/client/grr_response_client/client_actions/timeline_test.py +++ b/grr/client/grr_response_client/client_actions/timeline_test.py @@ -126,22 +126,26 @@ def testNestedDirectories(self): self.assertLen(entries, 7) paths = [_.path.decode("utf-8") for _ in entries] - self.assertCountEqual(paths, [ - os.path.join(root_dirpath), - os.path.join(root_dirpath, "foo"), - os.path.join(root_dirpath, "foo", "bar"), - os.path.join(root_dirpath, "foo", "baz"), - os.path.join(root_dirpath, "quux"), - os.path.join(root_dirpath, "quux", "norf"), - os.path.join(root_dirpath, "quux", "norf", "thud"), - ]) + self.assertCountEqual( + paths, + [ + os.path.join(root_dirpath), + os.path.join(root_dirpath, "foo"), + os.path.join(root_dirpath, "foo", "bar"), + os.path.join(root_dirpath, "foo", "baz"), + os.path.join(root_dirpath, "quux"), + os.path.join(root_dirpath, "quux", "norf"), + os.path.join(root_dirpath, "quux", "norf", "thud"), + ], + ) for entry in entries: self.assertTrue(stat_mode.S_ISDIR(entry.mode)) @skip.If( platform.system() == "Windows", - reason="Symlinks are not supported on Windows.") + reason="Symlinks are not supported on Windows.", + ) def testSymlinks(self): with temp.AutoTempDirPath(remove_non_empty=True) as root_dirpath: sub_dirpath = os.path.join(root_dirpath, "foo", "bar", "baz") @@ -155,13 +159,16 @@ def testSymlinks(self): self.assertLen(entries, 5) paths = [_.path.decode("utf-8") for _ in entries] - self.assertEqual(paths, [ - os.path.join(root_dirpath), - os.path.join(root_dirpath, "foo"), - os.path.join(root_dirpath, "foo", "bar"), - os.path.join(root_dirpath, "foo", "bar", "baz"), - os.path.join(root_dirpath, "foo", "bar", "baz", "quux") - ]) + self.assertEqual( + paths, + [ + 
os.path.join(root_dirpath), + os.path.join(root_dirpath, "foo"), + os.path.join(root_dirpath, "foo", "bar"), + os.path.join(root_dirpath, "foo", "bar", "baz"), + os.path.join(root_dirpath, "foo", "bar", "baz", "quux"), + ], + ) for entry in entries[:-1]: self.assertTrue(stat_mode.S_ISDIR(entry.mode)) diff --git a/grr/client/grr_response_client/client_actions/vfs_file_finder.py b/grr/client/grr_response_client/client_actions/vfs_file_finder.py index 1d6318618d..ed3d0c3699 100644 --- a/grr/client/grr_response_client/client_actions/vfs_file_finder.py +++ b/grr/client/grr_response_client/client_actions/vfs_file_finder.py @@ -1,7 +1,7 @@ #!/usr/bin/env python """The file finder client action.""" -from typing import Callable, Text, Iterator +from typing import Callable, Iterator, Text from grr_response_client import actions from grr_response_client import client_utils @@ -27,9 +27,11 @@ class VfsFileFinder(actions.ActionPlugin): def Run(self, args: rdf_file_finder.FileFinderArgs): action = self._ParseAction(args) content_conditions = list( - conditions.ContentCondition.Parse(args.conditions)) + conditions.ContentCondition.Parse(args.conditions) + ) metadata_conditions = list( - conditions.MetadataCondition.Parse(args.conditions)) + conditions.MetadataCondition.Parse(args.conditions) + ) for path in _GetExpandedPaths(args, heartbeat_cb=self.Progress): self.Progress() @@ -56,7 +58,9 @@ def Run(self, args: rdf_file_finder.FileFinderArgs): self.SendReply(result) def _ParseAction( - self, args: rdf_file_finder.FileFinderArgs) -> vfs_subactions.Action: + self, + args: rdf_file_finder.FileFinderArgs, + ) -> vfs_subactions.Action: action_type = args.action.action_type if action_type == rdf_file_finder.FileFinderAction.Action.HASH: return vfs_subactions.HashAction(self, args.action.hash) @@ -99,7 +103,8 @@ def _GetExpandedPaths( opts = globbing.PathOpts( follow_links=args.follow_links, pathtype=args.pathtype, - implementation_type=implementation_type) + implementation_type=implementation_type, + ) for path in args.paths: for expanded_path in globbing.ExpandPath(str(path), opts, heartbeat_cb): @@ -119,6 +124,7 @@ def RegistryKeyFromClient(args: rdf_file_finder.FileFinderArgs): """ for path in _GetExpandedPaths(args): pathspec = rdf_paths.PathSpec( - path=path, pathtype=rdf_paths.PathSpec.PathType.REGISTRY) + path=path, pathtype=rdf_paths.PathSpec.PathType.REGISTRY + ) with vfs.VFSOpen(pathspec) as file_obj: yield file_obj.Stat() diff --git a/grr/client/grr_response_client/client_actions/vfs_file_finder_test.py b/grr/client/grr_response_client/client_actions/vfs_file_finder_test.py index d1893873bb..0be63913b3 100644 --- a/grr/client/grr_response_client/client_actions/vfs_file_finder_test.py +++ b/grr/client/grr_response_client/client_actions/vfs_file_finder_test.py @@ -5,7 +5,7 @@ import hashlib import os import platform -from typing import List, Tuple, Optional +from typing import List, Optional, Tuple import unittest from unittest import mock import zlib @@ -18,7 +18,6 @@ from grr_response_client.client_actions import vfs_file_finder from grr_response_client.client_actions.file_finder_utils import globbing from grr_response_client.vfs_handlers import files - from grr_response_core import config from grr_response_core.lib import config_lib from grr_response_core.lib.rdfvalues import file_finder as rdf_file_finder @@ -82,7 +81,7 @@ def _GroupItemsByType(iterable): def _RunFileFinder( - args: rdf_file_finder.FileFinderArgs + args: rdf_file_finder.FileFinderArgs, ) -> List[rdf_file_finder.FileFinderResult]: 
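The _RunFileFinder helper being defined here simply feeds a FileFinderArgs into the VfsFileFinder action and collects its SendReply calls; outside of tests the request is built the same way. A small sketch using only names that already appear in this diff (the glob and the literal are illustrative values):

from grr_response_core.lib.rdfvalues import file_finder as rdf_file_finder

args = rdf_file_finder.FileFinderArgs(
    paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/**/a*"],  # recursive glob
    pathtype="REGISTRY",
    conditions=[
        # A path must satisfy every condition before the action is applied.
        rdf_file_finder.FileFinderCondition.ContentsLiteralMatch(literal=b"lol"),
    ],
    action=rdf_file_finder.FileFinderAction(action_type="STAT"),
)
# VfsFileFinder.Run expands each glob, skips paths whose conditions do not
# match, and emits one FileFinderResult (stat entry plus matches) per hit.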
results = [] @@ -118,7 +117,9 @@ def testStatDoesNotFailForInaccessiblePath(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SAM/SAM/FOOBAR"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertEmpty(results) @@ -127,11 +128,15 @@ def testCaseInsensitivitiy(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/AaA"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertLen(results, 1) - self.assertEqual(results[0].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa") + self.assertEqual( + results[0].stat_entry.pathspec.path, + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", + ) self.assertEqual(results[0].stat_entry.pathspec.pathtype, "REGISTRY") def testStatExactPath(self): @@ -139,11 +144,15 @@ def testStatExactPath(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertLen(results, 1) - self.assertEqual(results[0].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa") + self.assertEqual( + results[0].stat_entry.pathspec.path, + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", + ) self.assertEqual(results[0].stat_entry.pathspec.pathtype, "REGISTRY") self.assertEqual(results[0].stat_entry.st_size, 6) @@ -152,11 +161,15 @@ def testStatExactPathInWindowsNativeFormat(self): rdf_file_finder.FileFinderArgs( paths=[r"HKEY_LOCAL_MACHINE\SOFTWARE\GRR_TEST\aaa"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertLen(results, 1) - self.assertEqual(results[0].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa") + self.assertEqual( + results[0].stat_entry.pathspec.path, + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", + ) self.assertEqual(results[0].stat_entry.pathspec.pathtype, "REGISTRY") self.assertEqual(results[0].stat_entry.st_size, 6) @@ -167,12 +180,15 @@ def testStatLongUnicodeName(self): "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/{}".format(_LONG_KEY) ], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertLen(results, 1) self.assertEqual( results[0].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/{}".format(_LONG_KEY)) + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/{}".format(_LONG_KEY), + ) self.assertEqual(results[0].stat_entry.pathspec.pathtype, "REGISTRY") def testStatKeyWithDefaultValue(self): @@ -180,11 +196,15 @@ def testStatKeyWithDefaultValue(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertLen(results, 1) - self.assertEqual(results[0].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1") + self.assertEqual( + results[0].stat_entry.pathspec.path, + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1", + ) self.assertEqual(results[0].stat_entry.pathspec.pathtype, "REGISTRY") 
self.assertEqual(results[0].stat_entry.st_size, 13) @@ -193,12 +213,16 @@ def testDownloadExactPath(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="DOWNLOAD"))) + action=rdf_file_finder.FileFinderAction(action_type="DOWNLOAD"), + ) + ) self.assertLen(results, 2) self.assertEqual(_DecodeDataBlob(results[0]), "lolcat") - self.assertEqual(results[1].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa") + self.assertEqual( + results[1].stat_entry.pathspec.path, + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", + ) def testDownloadUnicode(self): results = _RunFileFinder( @@ -207,23 +231,29 @@ def testDownloadUnicode(self): "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/{}".format(_LONG_KEY) ], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="DOWNLOAD"))) + action=rdf_file_finder.FileFinderAction(action_type="DOWNLOAD"), + ) + ) self.assertLen(results, 2) res_by_type = _GroupItemsByType(results) self.assertEqual( - _DecodeDataBlob(res_by_type["DataBlob"][0]), _LONG_STRING_VALUE) + _DecodeDataBlob(res_by_type["DataBlob"][0]), _LONG_STRING_VALUE + ) self.assertEqual( res_by_type["FileFinderResult"][0].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/{}".format(_LONG_KEY)) + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/{}".format(_LONG_KEY), + ) def testDownloadDword(self): results = _RunFileFinder( rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="DOWNLOAD"))) + action=rdf_file_finder.FileFinderAction(action_type="DOWNLOAD"), + ) + ) self.assertLen(results, 2) res_by_type = _GroupItemsByType(results) @@ -231,21 +261,28 @@ def testDownloadDword(self): self.assertEqual(_DecodeDataBlob(res_by_type["DataBlob"][0]), "4294967295") self.assertEqual( res_by_type["FileFinderResult"][0].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba") + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba", + ) def testDownloadGlob(self): results = _RunFileFinder( rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/a*"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="DOWNLOAD"))) + action=rdf_file_finder.FileFinderAction(action_type="DOWNLOAD"), + ) + ) self.assertLen(results, 4) - self.assertEqual(results[1].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa") + self.assertEqual( + results[1].stat_entry.pathspec.path, + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", + ) self.assertEqual(_DecodeDataBlob(results[0]), "lolcat") - self.assertEqual(results[3].stat_entry.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba") + self.assertEqual( + results[3].stat_entry.pathspec.path, + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba", + ) self.assertEqual(_DecodeDataBlob(results[2]), "4294967295") def testHashExactPath(self): @@ -253,16 +290,24 @@ def testHashExactPath(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction.Hash())) + action=rdf_file_finder.FileFinderAction.Hash(), + ) + ) self.assertLen(results, 1) self.assertEqual(results[0].hash_entry.num_bytes, 6) - self.assertEqual(results[0].hash_entry.md5.HexDigest(), - hashlib.md5(b"lolcat").hexdigest()) - self.assertEqual(results[0].hash_entry.sha1.HexDigest(), - hashlib.sha1(b"lolcat").hexdigest()) - 
self.assertEqual(results[0].hash_entry.sha256.HexDigest(), - hashlib.sha256(b"lolcat").hexdigest()) + self.assertEqual( + results[0].hash_entry.md5.HexDigest(), + hashlib.md5(b"lolcat").hexdigest(), + ) + self.assertEqual( + results[0].hash_entry.sha1.HexDigest(), + hashlib.sha1(b"lolcat").hexdigest(), + ) + self.assertEqual( + results[0].hash_entry.sha256.HexDigest(), + hashlib.sha256(b"lolcat").hexdigest(), + ) def testHashSkipExactPath(self): results = _RunFileFinder( @@ -270,7 +315,10 @@ def testHashSkipExactPath(self): paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa"], pathtype="REGISTRY", action=rdf_file_finder.FileFinderAction.Hash( - max_size=5, oversized_file_policy="SKIP"))) + max_size=5, oversized_file_policy="SKIP" + ), + ) + ) self.assertLen(results, 1) self.assertFalse(results[0].HasField("hash")) @@ -280,32 +328,48 @@ def testHashTruncateExactPath(self): paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa"], pathtype="REGISTRY", action=rdf_file_finder.FileFinderAction.Hash( - max_size=5, oversized_file_policy="HASH_TRUNCATED"))) + max_size=5, oversized_file_policy="HASH_TRUNCATED" + ), + ) + ) self.assertLen(results, 1) - self.assertEqual(results[0].hash_entry.md5.HexDigest(), - hashlib.md5(b"lolca").hexdigest()) - self.assertEqual(results[0].hash_entry.sha1.HexDigest(), - hashlib.sha1(b"lolca").hexdigest()) - self.assertEqual(results[0].hash_entry.sha256.HexDigest(), - hashlib.sha256(b"lolca").hexdigest()) + self.assertEqual( + results[0].hash_entry.md5.HexDigest(), hashlib.md5(b"lolca").hexdigest() + ) + self.assertEqual( + results[0].hash_entry.sha1.HexDigest(), + hashlib.sha1(b"lolca").hexdigest(), + ) + self.assertEqual( + results[0].hash_entry.sha256.HexDigest(), + hashlib.sha256(b"lolca").hexdigest(), + ) def testStatSingleGlob(self): results = _RunFileFinder( rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/a*"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) - - self.assertCountEqual([res.stat_entry.pathspec.path for res in results], [ - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba", - ]) - - self.assertCountEqual([res.stat_entry.pathspec.path for res in results], [ - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba", - ]) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) + + self.assertCountEqual( + [res.stat_entry.pathspec.path for res in results], + [ + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba", + ], + ) + + self.assertCountEqual( + [res.stat_entry.pathspec.path for res in results], + [ + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba", + ], + ) self.assertEqual(results[0].stat_entry.pathspec.pathtype, "REGISTRY") self.assertEqual(results[1].stat_entry.pathspec.pathtype, "REGISTRY") @@ -314,7 +378,9 @@ def testQuestionMarkMatchesOneCharacterOnly(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/a?"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertEmpty(results) def testQuestionMarkIsWildcard(self): @@ -322,19 +388,26 @@ def testQuestionMarkIsWildcard(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/a?a"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) - - 
self.assertCountEqual([res.stat_entry.pathspec.path for res in results], [ - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba", - ]) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) + + self.assertCountEqual( + [res.stat_entry.pathspec.path for res in results], + [ + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aba", + ], + ) def testStatEmptyGlob(self): results = _RunFileFinder( rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/nonexistent*"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertEmpty(results) def testStatNonExistentPath(self): @@ -342,7 +415,9 @@ def testStatNonExistentPath(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/nonexistent"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertEmpty(results) def testStatRecursiveGlobDefaultLevel(self): @@ -350,42 +425,57 @@ def testStatRecursiveGlobDefaultLevel(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/**/aaa"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) - self.assertCountEqual([res.stat_entry.pathspec.path for res in results], [ - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/aaa", - ]) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) + self.assertCountEqual( + [res.stat_entry.pathspec.path for res in results], + [ + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/aaa", + ], + ) def testStatRecursiveGlobCustomLevel(self): results = _RunFileFinder( rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/**4/aaa"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) - self.assertCountEqual([res.stat_entry.pathspec.path for res in results], [ - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/4/aaa", - ]) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) + self.assertCountEqual( + [res.stat_entry.pathspec.path for res in results], + [ + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/4/aaa", + ], + ) def testStatRecursiveGlobAndRegularGlob(self): results = _RunFileFinder( rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/**4/a*"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) - self.assertCountEqual([res.stat_entry.pathspec.path for res in results], [ - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/aba", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/aba", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/aaa", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/aba", - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/4/aaa", - 
"/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/4/aba", - ]) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) + self.assertCountEqual( + [res.stat_entry.pathspec.path for res in results], + [ + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/aba", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/aba", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/aba", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/4/aaa", + "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/1/2/3/4/aba", + ], + ) def testRecursiveGlobCallsProgressWithoutMatches(self): progress = mock.MagicMock() @@ -395,7 +485,9 @@ def testRecursiveGlobCallsProgressWithoutMatches(self): rdf_file_finder.FileFinderArgs( paths=["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/**4/nonexistent"], pathtype="REGISTRY", - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertEmpty(results) # progress.call_count should rise linearly to the number of keys and @@ -409,11 +501,16 @@ def testMetadataConditionMatch(self): pathtype="REGISTRY", conditions=[ rdf_file_finder.FileFinderCondition.Size( - min_file_size=6, max_file_size=6) + min_file_size=6, max_file_size=6 + ) ], - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) - self.assertCountEqual([res.stat_entry.pathspec.path for res in results], - ["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa"]) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) + self.assertCountEqual( + [res.stat_entry.pathspec.path for res in results], + ["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa"], + ) def testSkipsIfMetadataConditionDoesNotMatch(self): results = _RunFileFinder( @@ -422,11 +519,15 @@ def testSkipsIfMetadataConditionDoesNotMatch(self): pathtype="REGISTRY", conditions=[ rdf_file_finder.FileFinderCondition.Size( - min_file_size=6, max_file_size=6), + min_file_size=6, max_file_size=6 + ), rdf_file_finder.FileFinderCondition.Size( - min_file_size=0, max_file_size=0), + min_file_size=0, max_file_size=0 + ), ], - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertEmpty(results) def testContentConditionMatch(self): @@ -436,13 +537,19 @@ def testContentConditionMatch(self): pathtype="REGISTRY", conditions=[ rdf_file_finder.FileFinderCondition.ContentsLiteralMatch( - literal=b"lol") + literal=b"lol" + ) ], - action=rdf_file_finder.FileFinderAction(action_type="STAT"))) - self.assertCountEqual([res.stat_entry.pathspec.path for res in results], - ["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa"]) - self.assertCountEqual([match.data for match in results[0].matches], - [b"lol"]) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) + self.assertCountEqual( + [res.stat_entry.pathspec.path for res in results], + ["/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa"], + ) + self.assertCountEqual( + [match.data for match in results[0].matches], [b"lol"] + ) def testSkipsIfContentConditionDoesNotMatch(self): results = _RunFileFinder( @@ -451,20 +558,28 @@ def testSkipsIfContentConditionDoesNotMatch(self): pathtype="REGISTRY", conditions=[ rdf_file_finder.FileFinderCondition.ContentsLiteralMatch( - literal=b"lol"), + literal=b"lol" + ), rdf_file_finder.FileFinderCondition.ContentsLiteralMatch( - literal=b"foo") + literal=b"foo" + ), ], - 
action=rdf_file_finder.FileFinderAction(action_type="STAT"))) + action=rdf_file_finder.FileFinderAction(action_type="STAT"), + ) + ) self.assertEmpty(results) def testGlobbingKeyDoesNotYieldDuplicates(self): opts = globbing.PathOpts(pathtype=rdf_paths.PathSpec.PathType.REGISTRY) results = globbing.ExpandGlobs( - r"HKEY_LOCAL_MACHINE\SOFTWARE\GRR_TEST\*\aaa", opts) - self.assertCountEqual(results, [ - r"HKEY_LOCAL_MACHINE\SOFTWARE\GRR_TEST\1\aaa", - ]) + r"HKEY_LOCAL_MACHINE\SOFTWARE\GRR_TEST\*\aaa", opts + ) + self.assertCountEqual( + results, + [ + r"HKEY_LOCAL_MACHINE\SOFTWARE\GRR_TEST\1\aaa", + ], + ) class OsTest(absltest.TestCase): @@ -481,9 +596,12 @@ def testRecursiveRegexMatch(self) -> None: pathtype=rdf_paths.PathSpec.PathType.OS, conditions=[ rdf_file_finder.FileFinderCondition.ContentsRegexMatch( - regex=b"bar[0-9]+"), + regex=b"bar[0-9]+" + ), ], - action=rdf_file_finder.FileFinderAction.Stat())) + action=rdf_file_finder.FileFinderAction.Stat(), + ) + ) self.assertLen(results, 1) self.assertEqual(results[0].matches[0].data, b"bar123") files.FlushHandleCache() @@ -508,21 +626,25 @@ def _MockGetRawDevice(self, path: str) -> Tuple[rdf_paths.PathSpec, str]: pathspec = rdf_paths.PathSpec( path=ntfs_img_path, pathtype=rdf_paths.PathSpec.PathType.OS, - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL) + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ) if platform.system() == "Windows": - return (pathspec, path[len(self._root):]) + return (pathspec, path[len(self._root) :]) else: return (pathspec, path) def testListRootDirectory(self): with mock.patch.object( - client_utils, "GetRawDevice", new=self._MockGetRawDevice): + client_utils, "GetRawDevice", new=self._MockGetRawDevice + ): results = _RunFileFinder( rdf_file_finder.FileFinderArgs( paths=[self._paths_expr], pathtype=self.pathtype, - action=rdf_file_finder.FileFinderAction.Stat())) + action=rdf_file_finder.FileFinderAction.Stat(), + ) + ) names = [ result.stat_entry.pathspec.nested_path.path for result in results ] @@ -533,21 +655,27 @@ def testImplementationType(self) -> None: orig_vfs_open = vfs.VFSOpen def MockVfsOpen(pathspec, *args, **kwargs): - self.assertEqual(pathspec.implementation_type, - rdf_paths.PathSpec.ImplementationType.DIRECT) + self.assertEqual( + pathspec.implementation_type, + rdf_paths.PathSpec.ImplementationType.DIRECT, + ) return orig_vfs_open(pathspec, *args, **kwargs) with contextlib.ExitStack() as stack: stack.enter_context(mock.patch.object(vfs, "VFSOpen", new=MockVfsOpen)) stack.enter_context( mock.patch.object( - client_utils, "GetRawDevice", new=self._MockGetRawDevice)) + client_utils, "GetRawDevice", new=self._MockGetRawDevice + ) + ) _RunFileFinder( rdf_file_finder.FileFinderArgs( paths=[self._paths_expr], pathtype=self.pathtype, implementation_type=rdf_paths.PathSpec.ImplementationType.DIRECT, - action=rdf_file_finder.FileFinderAction.Stat())) + action=rdf_file_finder.FileFinderAction.Stat(), + ) + ) files.FlushHandleCache() diff --git a/grr/client/grr_response_client/client_actions/windows/pipes.py b/grr/client/grr_response_client/client_actions/windows/pipes.py index c34a7aedaf..a1b6b45392 100644 --- a/grr/client/grr_response_client/client_actions/windows/pipes.py +++ b/grr/client/grr_response_client/client_actions/windows/pipes.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with an action for collecting named pipes.""" + import contextlib import logging import os @@ -51,7 +52,9 @@ def ListNamedPipes() -> Iterator[rdf_client.NamedPipe]: # # 
https://docs.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-getnamedpipeclientcomputernamew # pytype: disable=module-attr - GetNamedPipeClientComputerNameW = ctypes.windll.kernel32.GetNamedPipeClientComputerNameW # pylint: disable=invalid-name + GetNamedPipeClientComputerNameW = ( # pylint: disable=invalid-name + ctypes.windll.kernel32.GetNamedPipeClientComputerNameW + ) # pytype: enable=module-attr GetNamedPipeClientComputerNameW.argtypes = [ ctypes.wintypes.HANDLE, @@ -66,8 +69,9 @@ def ListNamedPipes() -> Iterator[rdf_client.NamedPipe]: pipe.name = name try: - handle = win32file.CreateFile(f"\\\\.\\pipe\\{name}", 0, 0, None, - win32file.OPEN_EXISTING, 0, None) + handle = win32file.CreateFile( + f"\\\\.\\pipe\\{name}", 0, 0, None, win32file.OPEN_EXISTING, 0, None + ) except win32file.error as error: # There might be some permission issues. We log the error and skip getting # pipe details, but still yield a result with at least the name filled-in. diff --git a/grr/client/grr_response_client/client_actions/windows/pipes_test.py b/grr/client/grr_response_client/client_actions/windows/pipes_test.py index dd7c3a4020..896b2527f9 100644 --- a/grr/client/grr_response_client/client_actions/windows/pipes_test.py +++ b/grr/client/grr_response_client/client_actions/windows/pipes_test.py @@ -136,6 +136,7 @@ def testPid(self) -> None: class NamedPipeSpec: """A class with named pipe specification.""" + name: str open_mode: Optional[int] = None pipe_mode: Optional[int] = None diff --git a/grr/client/grr_response_client/client_actions/windows/windows.py b/grr/client/grr_response_client/client_actions/windows/windows.py index 34ac6eb173..0b7f732aeb 100644 --- a/grr/client/grr_response_client/client_actions/windows/windows.py +++ b/grr/client/grr_response_client/client_actions/windows/windows.py @@ -7,6 +7,11 @@ """ import binascii +import logging +import os +import subprocess +import time +import winreg import pythoncom import win32api @@ -14,11 +19,11 @@ import win32file import win32service import win32serviceutil -import winreg import wmi from grr_response_client import actions from grr_response_client.client_actions import standard +from grr_response_client.client_actions import tempfiles from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client_action as rdf_client_action from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs @@ -29,21 +34,32 @@ # Properties to remove from results sent to the server. # These properties are included with nearly every WMI object and use space. 
IGNORE_PROPS = [ - "CSCreationClassName", "CreationClassName", "OSName", "OSCreationClassName", - "WindowsVersion", "CSName", "__NAMESPACE", "__SERVER", "__PATH" + "CSCreationClassName", + "CreationClassName", + "OSName", + "OSCreationClassName", + "WindowsVersion", + "CSName", + "__NAMESPACE", + "__SERVER", + "__PATH", ] class GetInstallDate(actions.ActionPlugin): """Estimate the install date of this system.""" + out_rdfvalues = [rdf_protodict.DataBlob] def Run(self, unused_args): """Estimate the install date of this system.""" # Don't use winreg.KEY_WOW64_64KEY since it breaks on Windows 2000 - subkey = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, - "Software\\Microsoft\\Windows NT\\CurrentVersion", - 0, winreg.KEY_READ) + subkey = winreg.OpenKey( + winreg.HKEY_LOCAL_MACHINE, + "Software\\Microsoft\\Windows NT\\CurrentVersion", + 0, + winreg.KEY_READ, + ) install_date = winreg.QueryValueEx(subkey, "InstallDate") self.SendReply(rdfvalue.RDFDatetime.FromSecondsSinceEpoch(install_date[0])) @@ -60,16 +76,18 @@ def EnumerateInterfacesFromClient(args): del args # Unused. pythoncom.CoInitialize() - for interface in (wmi.WMI().Win32_NetworkAdapterConfiguration() or []): + for interface in wmi.WMI().Win32_NetworkAdapterConfiguration() or []: addresses = [] for ip_address in interface.IPAddress or []: addresses.append( - rdf_client_network.NetworkAddress(human_readable_address=ip_address)) + rdf_client_network.NetworkAddress(human_readable_address=ip_address) + ) response = rdf_client_network.Interface(ifname=interface.Description) if interface.MACAddress: response.mac_address = binascii.unhexlify( - interface.MACAddress.replace(":", "")) + interface.MACAddress.replace(":", "") + ) if addresses: response.addresses = addresses @@ -82,6 +100,7 @@ class EnumerateInterfaces(actions.ActionPlugin): Win32_NetworkAdapterConfiguration definition: http://msdn.microsoft.com/en-us/library/aa394217(v=vs.85).aspx """ + out_rdfvalues = [rdf_client_network.Interface] def Run(self, args): @@ -103,14 +122,13 @@ def EnumerateFilesystemsFromClient(args): continue yield rdf_client_fs.Filesystem( - device=volume, - mount_point="/%s:/" % drive[0], - type=fs_type, - label=label) + device=volume, mount_point="/%s:/" % drive[0], type=fs_type, label=label + ) class EnumerateFilesystems(actions.ActionPlugin): """Enumerate all unique filesystems local to the system.""" + out_rdfvalues = [rdf_client_fs.Filesystem] def Run(self, args): @@ -120,12 +138,14 @@ def Run(self, args): def QueryService(svc_name): """Query service and get its config.""" - hscm = win32service.OpenSCManager(None, None, - win32service.SC_MANAGER_ALL_ACCESS) + hscm = win32service.OpenSCManager( + None, None, win32service.SC_MANAGER_ALL_ACCESS + ) result = None try: - hs = win32serviceutil.SmartOpenService(hscm, svc_name, - win32service.SERVICE_ALL_ACCESS) + hs = win32serviceutil.SmartOpenService( + hscm, svc_name, win32service.SERVICE_ALL_ACCESS + ) result = win32service.QueryServiceConfig(hs) win32service.CloseServiceHandle(hs) finally: @@ -148,6 +168,7 @@ def WmiQueryFromClient(args): class WmiQuery(actions.ActionPlugin): """Runs a WMI query and returns the results to a server callback.""" + in_rdfvalue = rdf_client_action.WMIRequest out_rdfvalues = [rdf_protodict.Dict] @@ -177,15 +198,17 @@ def RunWMIQuery(query, baseobj=r"winmgmts:\root\cimv2"): try: query_results = wmi_obj.ExecQuery(query) except pythoncom.com_error as e: - raise RuntimeError("Failed to run WMI query \'%s\' err was %s" % (query, e)) + raise RuntimeError( + "Failed to run WMI query '%s' 
err was %s" % (query, e) + ) from e # Extract results from the returned COMObject and return dicts. try: for result in query_results: response = rdf_protodict.Dict() - properties = ( - list(result.Properties_) + - list(getattr(result, "SystemProperties_", []))) + properties = list(result.Properties_) + list( + getattr(result, "SystemProperties_", []) + ) for prop in properties: if prop.Name not in IGNORE_PROPS: @@ -196,11 +219,48 @@ def RunWMIQuery(query, baseobj=r"winmgmts:\root\cimv2"): yield response except pythoncom.com_error as e: - raise RuntimeError("WMI query data error on query \'%s\' err was %s" % - (e, query)) + raise RuntimeError( + "WMI query data error on query '%s' err was %s" % (e, query) + ) from e class UpdateAgent(standard.ExecuteBinaryCommand): """Updates the GRR agent to a new version.""" - # For Windows this is just an alias to ExecuteBinaryCommand. + def ProcessFile(self, path, args): + if path.endswith(".msi"): + self._InstallMsi(path) + else: + raise ValueError(f"Unknown suffix for file {path}.") + + def _InstallMsi(self, path: bytes): + # misexec won't log to stdout/stderr. Write to a log file insetad. + with tempfiles.CreateGRRTempFile(filename="GRRInstallLog.txt") as f: + log_path = f.name + + try: + start = time.monotonic() + cmd = ["msiexec", "/i", path, "/qn", "/l*", log_path] + # Detach from process group and console session to help ensure the child + # process won't die when the parent process dies. + creationflags = ( + subprocess.DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP + ) + p = subprocess.run(cmd, check=False, creationflags=creationflags) + + with open(log_path, "rb") as f: + # Limit output to fit within 2MiB fleetspeak message limit. + msiexec_log_output = f.read(512 * 1024) + finally: + os.remove(log_path) + logging.error("Installer ran, but the old GRR client is still running") + + self.SendReply( + rdf_client_action.ExecuteBinaryResponse( + stdout=b"", + stderr=msiexec_log_output, + exit_status=p.returncode, + # We have to return microseconds. 
+ time_used=int(1e6 * (time.monotonic() - start)), + ) + ) diff --git a/grr/client/grr_response_client/client_actions/windows/windows_test.py b/grr/client/grr_response_client/client_actions/windows/windows_test.py index 4c39c2ee18..cc28733995 100644 --- a/grr/client/grr_response_client/client_actions/windows/windows_test.py +++ b/grr/client/grr_response_client/client_actions/windows/windows_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from unittest import mock from absl.testing import absltest @@ -24,8 +23,9 @@ def testEnumerateInterfaces(self): self.assertNotEmpty(address.human_readable_address) found_address = True if not found_address: - self.fail("Not a single address found in EnumerateInterfaces {}".format( - replies)) + self.fail( + "Not a single address found in EnumerateInterfaces {}".format(replies) + ) def testEnumerateInterfacesMock(self): # Stub out wmi.WMI().Win32_NetworkAdapterConfiguration() @@ -44,11 +44,15 @@ def testEnumerateInterfacesMock(self): interface = replies[0] self.assertLen(interface.addresses, 4) addresses = [x.human_readable_address for x in interface.addresses] - self.assertCountEqual(addresses, [ - "192.168.1.20", "ffff::ffff:aaaa:1111:aaaa", - "dddd:0:8888:6666:bbbb:aaaa:eeee:bbbb", - "dddd:0:8888:6666:bbbb:aaaa:ffff:bbbb" - ]) + self.assertCountEqual( + addresses, + [ + "192.168.1.20", + "ffff::ffff:aaaa:1111:aaaa", + "dddd:0:8888:6666:bbbb:aaaa:eeee:bbbb", + "dddd:0:8888:6666:bbbb:aaaa:ffff:bbbb", + ], + ) def testRunWMI(self): result_list = list(windows.RunWMIQuery("SELECT * FROM Win32_logicalDisk")) diff --git a/grr/client/grr_response_client/client_logging.py b/grr/client/grr_response_client/client_logging.py index c5fc1b5228..268d26d4db 100644 --- a/grr/client/grr_response_client/client_logging.py +++ b/grr/client/grr_response_client/client_logging.py @@ -15,7 +15,8 @@ "verbose", default=False, help="Turn on verbose logging.", - allow_override=True) + allow_override=True, +) class PreLoggingMemoryHandler(handlers.BufferingHandler): @@ -30,7 +31,7 @@ def flush(self): This is called when the buffer is really full, we just just drop one oldest message. """ - self.buffer = self.buffer[-self.capacity:] + self.buffer = self.buffer[-self.capacity :] class RobustSysLogHandler(handlers.SysLogHandler): @@ -133,7 +134,8 @@ def LogInit(): # The root logger. logger = logging.getLogger() memory_handlers = [ - m for m in logger.handlers + m + for m in logger.handlers if m.__class__.__name__ == "PreLoggingMemoryHandler" ] diff --git a/grr/client/grr_response_client/client_main.py b/grr/client/grr_response_client/client_main.py index 8581bfd9f8..87a0775951 100644 --- a/grr/client/grr_response_client/client_main.py +++ b/grr/client/grr_response_client/client_main.py @@ -17,23 +17,29 @@ from grr_response_core.config import contexts -_INSTALL = flags.DEFINE_bool("install", False, - "Specify this to install the client.") +_INSTALL = flags.DEFINE_bool( + "install", False, "Specify this to install the client." +) _BREAK_ON_START = flags.DEFINE_bool( - "break_on_start", False, + "break_on_start", + False, "If True break into a pdb shell immediately on startup. 
This" - " can be used for debugging the client manually.") + " can be used for debugging the client manually.", +) _DEBUG_CLIENT_ACTIONS = flags.DEFINE_bool( - "debug_client_actions", False, - "If True break into a pdb shell before executing any client" - " action.") + "debug_client_actions", + False, + "If True break into a pdb shell before executing any client action.", +) _REMOTE_DEBUGGING_PORT = flags.DEFINE_integer( - "remote_debugging_port", 0, + "remote_debugging_port", + 0, "If set to a non-zero port, pydevd is started to allow remote debugging " - "(e.g. using PyCharm).") + "(e.g. using PyCharm).", +) def _start_remote_debugging(port): @@ -48,12 +54,14 @@ def _start_remote_debugging(port): port=port, stdoutToServer=True, stderrToServer=True, - suspend=_BREAK_ON_START.value) + suspend=_BREAK_ON_START.value, + ) except ImportError: print( "pydevd is required for remote debugging. Please follow the PyCharm" "manual or run `pip install pydevd-pycharm` to install.", - file=sys.stderr) + file=sys.stderr, + ) def main(unused_args): @@ -65,8 +73,9 @@ def main(unused_args): pdb.set_trace() # Allow per platform configuration. - config.CONFIG.AddContext(contexts.CLIENT_CONTEXT, - "Context applied when we run the client process.") + config.CONFIG.AddContext( + contexts.CLIENT_CONTEXT, "Context applied when we run the client process." + ) client_startup.ClientInit() @@ -80,11 +89,14 @@ def main(unused_args): # initialization makes only sense if we run from a proper installation. # This is the case if this is a PyInstaller binary. sandbox.InitSandbox( - "{}_{}".format(config.CONFIG["Client.name"], - config.CONFIG["Source.version_string"]), - [config.CONFIG["Client.install_path"]]) + "{}_{}".format( + config.CONFIG["Client.name"], config.CONFIG["Source.version_string"] + ), + [config.CONFIG["Client.install_path"]], + ) fleetspeak_client.GRRFleetspeakClient().Run() + if __name__ == "__main__": app.run(main) diff --git a/grr/client/grr_response_client/client_test.py b/grr/client/grr_response_client/client_test.py index 07d19d6835..d172da1502 100644 --- a/grr/client/grr_response_client/client_test.py +++ b/grr/client/grr_response_client/client_test.py @@ -23,11 +23,14 @@ class MockAction(actions.ActionPlugin): def Run(self, message): self.SendReply( rdf_client_action.EchoRequest( - data="Received Message: %s. Data %s" % (message.data, "x" * 100))) + data="Received Message: %s. 
Data %s" % (message.data, "x" * 100) + ) + ) class RaiseAction(actions.ActionPlugin): """A mock action which raises an error.""" + in_rdfvalue = rdf_client.LogMessage out_rdfvalues = [rdf_client.LogMessage] @@ -44,6 +47,7 @@ def LoadCertificates(self): class BasicContextTests(test_lib.GRRBaseTest): """Test the GRR contexts.""" + to_test_context = TestedContext def setUp(self): @@ -62,10 +66,12 @@ def testHandleMessage(self): auth_state=rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED, payload=args, request_id=1, - generate_task_id=True) + generate_task_id=True, + ) - with mock.patch.object(client_actions, "REGISTRY", - {"MockAction": MockAction}): + with mock.patch.object( + client_actions, "REGISTRY", {"MockAction": MockAction} + ): self.context.HandleMessage(message) # Check the response - one data and one status @@ -86,10 +92,12 @@ def testHandleError(self): session_id=self.session_id, auth_state=rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED, request_id=1, - generate_task_id=True) + generate_task_id=True, + ) - with mock.patch.object(client_actions, "REGISTRY", - {"RaiseAction": RaiseAction}): + with mock.patch.object( + client_actions, "REGISTRY", {"RaiseAction": RaiseAction} + ): self.context.HandleMessage(message) # Check the response - one data and one status @@ -112,10 +120,12 @@ def testUnauthenticated(self): session_id=self.session_id, auth_state=rdf_flows.GrrMessage.AuthorizationState.UNAUTHENTICATED, request_id=1, - generate_task_id=True) + generate_task_id=True, + ) - with mock.patch.object(client_actions, "REGISTRY", - {"MockAction": MockAction}): + with mock.patch.object( + client_actions, "REGISTRY", {"MockAction": MockAction} + ): self.context.HandleMessage(message) # We expect to receive an GrrStatus to indicate an exception was diff --git a/grr/client/grr_response_client/client_utils.py b/grr/client/grr_response_client/client_utils.py index 06b08b4cda..365a8f8cfe 100644 --- a/grr/client/grr_response_client/client_utils.py +++ b/grr/client/grr_response_client/client_utils.py @@ -30,7 +30,9 @@ OpenProcessForMemoryAccess = _client_utils.OpenProcessForMemoryAccess TransactionLog = _client_utils.TransactionLog VerifyFileOwner = _client_utils.VerifyFileOwner -CreateProcessFromSerializedFileDescriptor = _client_utils.CreateProcessFromSerializedFileDescriptor +CreateProcessFromSerializedFileDescriptor = ( + _client_utils.CreateProcessFromSerializedFileDescriptor +) # pylint: enable=g-bad-name @@ -61,9 +63,9 @@ def StatEntryFromPath( return StatEntryFromStat(stat, pathspec, ext_attrs=ext_attrs) -def StatEntryFromStat(stat: filesystem.Stat, - pathspec: rdf_paths.PathSpec, - ext_attrs: bool = True) -> rdf_client_fs.StatEntry: +def StatEntryFromStat( + stat: filesystem.Stat, pathspec: rdf_paths.PathSpec, ext_attrs: bool = True +) -> rdf_client_fs.StatEntry: """Build a stat entry object from a given stat object. 
Args: @@ -102,17 +104,20 @@ def StatEntryFromStat(stat: filesystem.Stat, return result -def StatEntryFromStatPathSpec(stat: filesystem.Stat, - ext_attrs: bool) -> rdf_client_fs.StatEntry: +def StatEntryFromStatPathSpec( + stat: filesystem.Stat, ext_attrs: bool +) -> rdf_client_fs.StatEntry: pathspec = rdf_paths.PathSpec( pathtype=rdf_paths.PathSpec.PathType.OS, path=LocalPathToCanonicalPath(stat.GetPath()), - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL) + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ) return StatEntryFromStat(stat, pathspec, ext_attrs=ext_attrs) def StatResultFromStatEntry( - stat_entry: rdf_client_fs.StatEntry) -> os.stat_result: + stat_entry: rdf_client_fs.StatEntry, +) -> os.stat_result: """Returns a `os.stat_result` with most information from `StatEntry`. This is a lossy conversion, only the 10 first stat_result fields are diff --git a/grr/client/grr_response_client/client_utils_common.py b/grr/client/grr_response_client/client_utils_common.py index 8fe78ec9cb..8b91d80afe 100644 --- a/grr/client/grr_response_client/client_utils_common.py +++ b/grr/client/grr_response_client/client_utils_common.py @@ -10,7 +10,6 @@ import threading import time - from grr_response_client.local import binary_whitelist from grr_response_core import config from grr_response_core.lib import constants @@ -28,13 +27,15 @@ def HandleAlarm(process): pass -def Execute(cmd, - args, - time_limit=-1, - bypass_allowlist=False, - daemon=False, - use_client_context=False, - cwd=None): +def Execute( + cmd, + args, + time_limit=-1, + bypass_allowlist=False, + daemon=False, + use_client_context=False, + cwd=None, +): """Executes commands on the client. This function is the only place where commands will be executed @@ -58,8 +59,9 @@ def Execute(cmd, """ if not bypass_allowlist and not IsExecutionAllowed(cmd, args): # Allowlist doesn't contain this cmd/arg pair - logging.info("Execution disallowed by allowlist: %s %s.", cmd, - " ".join(args)) + logging.info( + "Execution disallowed by allowlist: %s %s.", cmd, " ".join(args) + ) return (b"", b"Execution disallowed by allowlist.", -1, -1) if daemon: @@ -75,11 +77,13 @@ def Execute(cmd, # This only works if the process is running as root. 
pass _Execute( - cmd, args, time_limit, use_client_context=use_client_context, cwd=cwd) + cmd, args, time_limit, use_client_context=use_client_context, cwd=cwd + ) os._exit(0) # pylint: disable=protected-access else: return _Execute( - cmd, args, time_limit, use_client_context=use_client_context, cwd=cwd) + cmd, args, time_limit, use_client_context=use_client_context, cwd=cwd + ) def _Execute(cmd, args, time_limit=-1, use_client_context=False, cwd=None): @@ -100,7 +104,8 @@ def _Execute(cmd, args, time_limit=-1, use_client_context=False, cwd=None): stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, - cwd=cwd) + cwd=cwd, + ) alarm = None if time_limit > 0: @@ -155,8 +160,10 @@ def IsExecutionAllowed(cmd, args): ("driverquery.exe", ["/v"]), ("ipconfig.exe", ["/all"]), ("netsh.exe", ["advfirewall", "firewall", "show", "rule", "name=all"]), - ("netsh.exe", - ["advfirewall", "monitor", "show", "firewall", "rule", "name=all"]), + ( + "netsh.exe", + ["advfirewall", "monitor", "show", "firewall", "rule", "name=all"], + ), ("tasklist.exe", ["/SVC"]), ("tasklist.exe", ["/v"]), ] @@ -166,6 +173,7 @@ def IsExecutionAllowed(cmd, args): ("/bin/echo", ["1"]), ("/bin/mount", []), ("/bin/rpm", ["-qa"]), + ("/bin/rpm", ["--query", "--all"]), ("/bin/sleep", ["10"]), ("/sbin/auditctl", ["-l"]), ("/sbin/ifconfig", ["-a"]), @@ -191,17 +199,11 @@ def IsExecutionAllowed(cmd, args): ("/usr/sbin/arp", ["-a"]), ("/usr/sbin/kextstat", []), ("/usr/sbin/system_profiler", ["-xml", "SPHardwareDataType"]), - ("/usr/libexec/firmwarecheckers/eficheck/eficheck", ["--version"]), - ("/usr/libexec/firmwarecheckers/eficheck/eficheck", - ["--generate-hashes"]), - ("/usr/libexec/firmwarecheckers/eficheck/eficheck", - ["--save", "-b", "firmware.bin"]), - ("/usr/libexec/firmwarecheckers/ethcheck/ethcheck", ["--show-hashes"]), ] else: allowlist = [] - for (allowed_cmd, allowed_args) in allowlist: + for allowed_cmd, allowed_args in allowlist: if cmd == allowed_cmd and args == allowed_args: return True diff --git a/grr/client/grr_response_client/client_utils_linux.py b/grr/client/grr_response_client/client_utils_linux.py index 4374a41703..d40be7fbe6 100644 --- a/grr/client/grr_response_client/client_utils_linux.py +++ b/grr/client/grr_response_client/client_utils_linux.py @@ -8,7 +8,6 @@ from grr_response_client import client_utils_osx_linux from grr_response_client.linux import process - from grr_response_core.lib import utils from grr_response_core.lib.rdfvalues import paths as rdf_paths @@ -20,7 +19,9 @@ VerifyFileOwner = client_utils_osx_linux.VerifyFileOwner TransactionLog = client_utils_osx_linux.TransactionLog -CreateProcessFromSerializedFileDescriptor = process.Process.CreateFromSerializedFileDescriptor +CreateProcessFromSerializedFileDescriptor = ( + process.Process.CreateFromSerializedFileDescriptor +) # pylint: enable=invalid-name @@ -88,12 +89,14 @@ def GetRawDevice(path): result.pathtype = rdf_paths.PathSpec.PathType.OS else: logging.error( - "Filesystem %s is not supported. Supported filesystems " - "are %s", fs_type, SUPPORTED_FILESYSTEMS) + "Filesystem %s is not supported. 
Supported filesystems are %s", + fs_type, + SUPPORTED_FILESYSTEMS, + ) result.pathtype = rdf_paths.PathSpec.PathType.UNSET # Drop the mount point - path = utils.NormalizePath(path[len(mount_point):]) + path = utils.NormalizePath(path[len(mount_point) :]) result.mount_point = mount_point return result, path @@ -110,4 +113,5 @@ def MemoryRegions(proc, options): skip_executable_regions=options.skip_executable_regions, skip_mapped_files=options.skip_mapped_files, skip_readonly_regions=options.skip_readonly_regions, - skip_shared_regions=options.skip_shared_regions) + skip_shared_regions=options.skip_shared_regions, + ) diff --git a/grr/client/grr_response_client/client_utils_osx.py b/grr/client/grr_response_client/client_utils_osx.py index 1e177e6a94..f051fcf19d 100644 --- a/grr/client/grr_response_client/client_utils_osx.py +++ b/grr/client/grr_response_client/client_utils_osx.py @@ -7,7 +7,6 @@ import os import platform - from grr_response_client import client_utils_osx_linux from grr_response_client.osx import objc from grr_response_client.osx import process @@ -22,7 +21,9 @@ VerifyFileOwner = client_utils_osx_linux.VerifyFileOwner TransactionLog = client_utils_osx_linux.TransactionLog -CreateProcessFromSerializedFileDescriptor = process.Process.CreateFromSerializedFileDescriptor +CreateProcessFromSerializedFileDescriptor = ( + process.Process.CreateFromSerializedFileDescriptor +) # pylint: enable=invalid-name @@ -49,11 +50,13 @@ def FindProxies(): return ["http://%s:%d/" % (proxy, port)] cf_auto_enabled = sc.CFDictRetrieve( - settings, "kSCPropNetProxiesProxyAutoConfigEnable") + settings, "kSCPropNetProxiesProxyAutoConfigEnable" + ) if cf_auto_enabled and bool(sc.CFNumToInt32(cf_auto_enabled)): - cfurl = sc.CFDictRetrieve(settings, - "kSCPropNetProxiesProxyAutoConfigURLString") + cfurl = sc.CFDictRetrieve( + settings, "kSCPropNetProxiesProxyAutoConfigURLString" + ) if cfurl: unused_url = sc.CFStringToPystring(cfurl) # TODO(amoser): Auto config is enabled, what is the plan here? @@ -75,13 +78,15 @@ def GetMountpoints(): for filesys in GetFileSystems(): devices[filesys.f_mntonname.decode("utf-8")] = ( filesys.f_mntfromname.decode("utf-8"), - filesys.f_fstypename.decode("utf-8")) + filesys.f_fstypename.decode("utf-8"), + ) return devices class StatFSStruct(utils.Struct): """Parse filesystems getfsstat.""" + _fields = [ ("h", "f_otype;"), ("h", "f_oflags;"), @@ -108,6 +113,7 @@ class StatFSStruct(utils.Struct): class StatFS64Struct(utils.Struct): """Parse filesystems getfsstat for 64 bit.""" + _fields = [ (" limit: raise actions.NetworkBytesExceededError( - "Action exceeded network send limit.") + "Action exceeded network send limit." + ) def HandleMessage(self, message): """Entry point for processing jobs. @@ -239,7 +240,7 @@ def Sleep(self, timeout): """Sleeps the calling thread with heartbeat.""" # Split a long sleep interval into 1 second intervals so we can heartbeat. while timeout > 0: - time.sleep(min(1., timeout)) + time.sleep(min(1.0, timeout)) timeout -= 1 # If the output queue is full, we are ready to do a post - no # point in waiting. 
@@ -256,14 +257,16 @@ def OnStartup(self): if last_request: status = rdf_flows.GrrStatus( status=rdf_flows.GrrStatus.ReturnedStatus.CLIENT_KILLED, - error_message="Client killed during transaction") + error_message="Client killed during transaction", + ) self.SendReply( status, request_id=last_request.request_id, response_id=1, session_id=last_request.session_id, - message_type=rdf_flows.GrrMessage.Type.STATUS) + message_type=rdf_flows.GrrMessage.Type.STATUS, + ) self.transaction_log.Clear() @@ -292,12 +295,14 @@ def run(self): self.SendReply( rdf_flows.GrrStatus( status=rdf_flows.GrrStatus.ReturnedStatus.GENERIC_ERROR, - error_message=utils.SmartUnicode(e)), + error_message=utils.SmartUnicode(e), + ), request_id=message.request_id, response_id=1, session_id=message.session_id, task_id=message.task_id, - message_type=rdf_flows.GrrMessage.Type.STATUS) + message_type=rdf_flows.GrrMessage.Type.STATUS, + ) if flags.FLAGS.pdb_post_mortem: pdb.post_mortem() diff --git a/grr/client/grr_response_client/comms_test.py b/grr/client/grr_response_client/comms_test.py index 9d349db902..78939b0cd2 100644 --- a/grr/client/grr_response_client/comms_test.py +++ b/grr/client/grr_response_client/comms_test.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Test for client comms.""" + from absl import app from grr_response_client import comms diff --git a/grr/client/grr_response_client/distro_entry.py b/grr/client/grr_response_client/distro_entry.py index dbac673b03..2455dc7b21 100644 --- a/grr/client/grr_response_client/distro_entry.py +++ b/grr/client/grr_response_client/distro_entry.py @@ -14,4 +14,5 @@ def FleetspeakClient(): def FleetspeakClientWrapper(): from grr_response_client import fleetspeak_client_wrapper + app.run(fleetspeak_client_wrapper.main) diff --git a/grr/client/grr_response_client/fleetspeak_client.py b/grr/client/grr_response_client/fleetspeak_client.py index de5304eb4b..9b9ddd8f55 100644 --- a/grr/client/grr_response_client/fleetspeak_client.py +++ b/grr/client/grr_response_client/fleetspeak_client.py @@ -82,10 +82,12 @@ class GRRFleetspeakClient(object): def __init__(self): self._fs = fs_client.FleetspeakConnection( - version=config.CONFIG["Source.version_string"]) + version=config.CONFIG["Source.version_string"] + ) self._sender_queue = queue.Queue( - maxsize=GRRFleetspeakClient._SENDER_QUEUE_MAXSIZE) + maxsize=GRRFleetspeakClient._SENDER_QUEUE_MAXSIZE + ) self._threads = {} @@ -94,7 +96,8 @@ def __init__(self): # threading.Thread here. out_queue = _FleetspeakQueueForwarder(self._sender_queue) worker = self._threads["Worker"] = comms.GRRClientWorker( - out_queue=out_queue, heart_beat_cb=self._fs.Heartbeat, client=self) + out_queue=out_queue, heart_beat_cb=self._fs.Heartbeat, client=self + ) # TODO(user): this is an ugly way of passing the heartbeat callback to # the queue. Refactor the heartbeat callback initialization logic so that # this won't be needed. @@ -137,7 +140,8 @@ def Run(self): ] if dead_threads: raise FatalError( - "These threads are dead: %r. Shutting down..." % dead_threads) + "These threads are dead: %r. Shutting down..." 
% dead_threads + ) time.sleep(10) def _ForemanOp(self): @@ -156,20 +160,27 @@ def _SendMessages(self, grr_msgs, background=False): fs_msg = fs_common_pb2.Message( message_type="MessageList", destination=fs_common_pb2.Address(service_name="GRR"), - background=background) + background=background, + ) fs_msg.data.Pack(message_list.AsPrimitiveProto()) for grr_msg in grr_msgs: - if (grr_msg.session_id is None or grr_msg.request_id is None or - grr_msg.response_id is None): + if ( + grr_msg.session_id is None + or grr_msg.request_id is None + or grr_msg.response_id is None + ): continue # Place all ids in a single annotation, instead of having separate # annotations for the flow-id, request-id and response-id. This reduces # overall size of the annotations by half (~60 bytes to ~30 bytes). annotation = fs_msg.annotations.entries.add() annotation.key = _DATA_IDS_ANNOTATION_KEY - annotation.value = "%s:%d:%d" % (grr_msg.session_id.Basename(), - grr_msg.request_id, grr_msg.response_id) + annotation.value = "%s:%d:%d" % ( + grr_msg.session_id.Basename(), + grr_msg.request_id, + grr_msg.response_id, + ) if fs_msg.annotations.ByteSize() >= _MAX_ANNOTATIONS_BYTES: break @@ -212,7 +223,8 @@ def _ReceiveOp(self): if not received_type.endswith("GrrMessage"): raise ValueError( "Unexpected proto type received through Fleetspeak: %r; expected " - "grr.GrrMessage." % received_type) + "grr.GrrMessage." % received_type + ) grr_msg = rdf_flows.GrrMessage.FromSerializedBytes(fs_msg.data.value) # Authentication is ensured by Fleetspeak. diff --git a/grr/client/grr_response_client/fleetspeak_client_test.py b/grr/client/grr_response_client/fleetspeak_client_test.py index d61e2a3c80..08677516dd 100644 --- a/grr/client/grr_response_client/fleetspeak_client_test.py +++ b/grr/client/grr_response_client/fleetspeak_client_test.py @@ -64,7 +64,8 @@ def testSendMessagesWithAnnotations(self, mock_worker_class, mock_conn_class): session_id="%s/%s" % (client_id, flow_id), name="TestClientAction", request_id=2, - response_id=len(grr_messages) + 1) + response_id=len(grr_messages) + 1, + ) annotation = expected_annotations.entries.add() annotation.key = fleetspeak_client._DATA_IDS_ANNOTATION_KEY annotation.value = "%s:2:%d" % (flow_id, len(grr_messages) + 1) @@ -76,15 +77,18 @@ def testSendMessagesWithAnnotations(self, mock_worker_class, mock_conn_class): session_id="%s/%s" % (client_id, flow_id), name="TestClientAction", request_id=3, - response_id=1) + response_id=1, + ) grr_messages.append(extra_message) client._sender_queue.put(extra_message) self.assertLess( - len(grr_messages), fleetspeak_client._MAX_MSG_LIST_MSG_COUNT) + len(grr_messages), fleetspeak_client._MAX_MSG_LIST_MSG_COUNT + ) self.assertLess( sum(len(x.SerializeToBytes()) for x in grr_messages), - fleetspeak_client._MAX_MSG_LIST_BYTES) + fleetspeak_client._MAX_MSG_LIST_BYTES, + ) client._SendOp() diff --git a/grr/client/grr_response_client/fleetspeak_client_wrapper.py b/grr/client/grr_response_client/fleetspeak_client_wrapper.py index f33dc66331..629ed8db99 100644 --- a/grr/client/grr_response_client/fleetspeak_client_wrapper.py +++ b/grr/client/grr_response_client/fleetspeak_client_wrapper.py @@ -39,18 +39,21 @@ def TmpPath(*args): return os.path.join(tmp_dir, *args) server_config_dir = package.ResourcePath( - "fleetspeak-server-bin", "fleetspeak-server-bin/etc/fleetspeak-server") + "fleetspeak-server-bin", "fleetspeak-server-bin/etc/fleetspeak-server" + ) if not os.path.exists(server_config_dir): raise Error( f"Fleetspeak server config dir not found: 
{server_config_dir}. " - "Please make sure `grr_config_updater initialize` has been run.") + "Please make sure `grr_config_updater initialize` has been run." + ) client_config_name = { "Linux": "linux_client.config", "Windows": "windows_client.config", "Darwin": "darwin_client.config", } - client_config_path = os.path.join(server_config_dir, - client_config_name[platform.system()]) + client_config_path = os.path.join( + server_config_dir, client_config_name[platform.system()] + ) with open(client_config_path, "r") as f: client_config = text_format.Parse(f.read(), fs_cli_config_pb2.Config()) if client_config.HasField("filesystem_handler"): @@ -62,7 +65,8 @@ def TmpPath(*args): # re-runs of this command. Otherwise the client ID of the client would # change at each re-run. client_config.filesystem_handler.state_file = os.path.join( - config.CONFIG["Logging.path"], "fleetspeak-client.state") + config.CONFIG["Logging.path"], "fleetspeak-client.state" + ) with open(TmpPath("config"), "w") as f: f.write(text_format.MessageToString(client_config)) return TmpPath("config") @@ -73,8 +77,9 @@ def _CreateServiceConfig(config_dir: str) -> None: service_config_path = config.CONFIG["ClientBuilder.fleetspeak_config_path"] with open(service_config_path, "r") as f: data = config.CONFIG.InterpolateValue(f.read()) - service_config = text_format.Parse(data, - fs_system_pb2.ClientServiceConfig()) + service_config = text_format.Parse( + data, fs_system_pb2.ClientServiceConfig() + ) daemon_config = fs_daemon_config_pb2.Config() service_config.config.Unpack(daemon_config) del daemon_config.argv[:] @@ -83,8 +88,9 @@ def _CreateServiceConfig(config_dir: str) -> None: ]) service_config.config.Pack(daemon_config) utils.EnsureDirExists(os.path.join(config_dir, "textservices")) - with open(os.path.join(config_dir, "textservices", "GRR.textproto"), - "w") as f: + with open( + os.path.join(config_dir, "textservices", "GRR.textproto"), "w" + ) as f: f.write(text_format.MessageToString(service_config)) @@ -93,13 +99,14 @@ def _RunClient(tmp_dir: str) -> None: config_path = _CreateClientConfig(tmp_dir) _CreateServiceConfig(tmp_dir) fleetspeak_client = package.ResourcePath( - "fleetspeak-client-bin", - "fleetspeak-client-bin/usr/bin/fleetspeak-client") + "fleetspeak-client-bin", "fleetspeak-client-bin/usr/bin/fleetspeak-client" + ) if not fleetspeak_client or not os.path.exists(fleetspeak_client): raise Error( f"Fleetspeak client binary not found: {fleetspeak_client}." "Please make sure that the package `fleetspeak-client-bin` has been " - "installed.") + "installed." + ) command = [ fleetspeak_client, "--logtostderr", diff --git a/grr/client/grr_response_client/gcs.py b/grr/client/grr_response_client/gcs.py index 76949df9c3..6ad001d901 100644 --- a/grr/client/grr_response_client/gcs.py +++ b/grr/client/grr_response_client/gcs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Utilities for working with Google Cloud Storage.""" + import datetime from typing import Callable from typing import IO @@ -72,6 +73,7 @@ class Opts: progress_interval: An upper bound on the time between the calls of the progress function (in seconds). """ + chunk_size: int = 8 * 1024 * 1024 # 8 MiB. 
retry_chunk_attempts: int = 30 @@ -229,7 +231,8 @@ def PutChunk(): self.uri, data=chunk, headers=headers, - timeout=opts.progress_interval) + timeout=opts.progress_interval, + ) except exceptions.RequestException as error: raise RequestError("Chunk transmission failure") from error diff --git a/grr/client/grr_response_client/grr_fs_client.py b/grr/client/grr_response_client/grr_fs_client.py index e59edb5159..3698a0ab47 100644 --- a/grr/client/grr_response_client/grr_fs_client.py +++ b/grr/client/grr_response_client/grr_fs_client.py @@ -1,6 +1,5 @@ #!/usr/bin/env python -r"""This is the GRR client for Fleetspeak enabled installations. -""" +r"""This is the GRR client for Fleetspeak enabled installations.""" from absl import app diff --git a/grr/client/grr_response_client/installer.py b/grr/client/grr_response_client/installer.py index 0a56472212..66e0f73990 100644 --- a/grr/client/grr_response_client/installer.py +++ b/grr/client/grr_response_client/installer.py @@ -50,8 +50,10 @@ def RunInstaller(): # the configuration from the flag and ignore the Config.writeback # location. config.CONFIG.Initialize(filename=flags.FLAGS.config, reset=True) - config.CONFIG.AddContext(contexts.INSTALLER_CONTEXT, - "Context applied when we run the client installer.") + config.CONFIG.AddContext( + contexts.INSTALLER_CONTEXT, + "Context applied when we run the client installer.", + ) if installers is None: logging.info("No installers found for %s.", sys.platform) diff --git a/grr/client/grr_response_client/linux/__init__.py b/grr/client/grr_response_client/linux/__init__.py index 03330d66c2..1ec8e6481c 100644 --- a/grr/client/grr_response_client/linux/__init__.py +++ b/grr/client/grr_response_client/linux/__init__.py @@ -1,3 +1,2 @@ #!/usr/bin/env python """Client linux-specific module root.""" - diff --git a/grr/client/grr_response_client/linux/client_utils_linux_test.py b/grr/client/grr_response_client/linux/client_utils_linux_test.py index 2878726b83..c0f4eb8577 100644 --- a/grr/client/grr_response_client/linux/client_utils_linux_test.py +++ b/grr/client/grr_response_client/linux/client_utils_linux_test.py @@ -44,22 +44,42 @@ def testLinuxGetRawDevice(self): with contextlib.ExitStack() as stack: stack.enter_context( mock.patch.object( - client_utils_linux, "MOUNTPOINT_CACHE", new=[0, None])) + client_utils_linux, "MOUNTPOINT_CACHE", new=[0, None] + ) + ) mountpoints = client_utils_linux.GetMountpoints(proc_mounts) stack.enter_context( mock.patch.object( - client_utils_linux, "GetMountpoints", return_value=mountpoints)) + client_utils_linux, "GetMountpoints", return_value=mountpoints + ) + ) for filename, expected_device, expected_path, device_type in [ - ("/etc/passwd", "/dev/mapper/root", "/etc/passwd", - rdf_paths.PathSpec.PathType.OS), - ("/usr/local/bin/ls", "/dev/mapper/usr", "/bin/ls", - rdf_paths.PathSpec.PathType.OS), - ("/proc/net/sys", "none", "/net/sys", - rdf_paths.PathSpec.PathType.UNSET), - ("/home/user/test.txt", "server.nfs:/vol/home", "/test.txt", - rdf_paths.PathSpec.PathType.UNSET) + ( + "/etc/passwd", + "/dev/mapper/root", + "/etc/passwd", + rdf_paths.PathSpec.PathType.OS, + ), + ( + "/usr/local/bin/ls", + "/dev/mapper/usr", + "/bin/ls", + rdf_paths.PathSpec.PathType.OS, + ), + ( + "/proc/net/sys", + "none", + "/net/sys", + rdf_paths.PathSpec.PathType.UNSET, + ), + ( + "/home/user/test.txt", + "server.nfs:/vol/home", + "/test.txt", + rdf_paths.PathSpec.PathType.UNSET, + ), ]: raw_pathspec, path = client_utils_linux.GetRawDevice(filename) @@ -92,9 +112,11 @@ def testEmpty(self): def 
testMany(self): with temp.AutoTempFilePath() as temp_filepath: filesystem_test_lib.SetExtAttr( - temp_filepath, name=b"user.foo", value=b"bar") + temp_filepath, name=b"user.foo", value=b"bar" + ) filesystem_test_lib.SetExtAttr( - temp_filepath, name=b"user.quux", value=b"norf") + temp_filepath, name=b"user.quux", value=b"norf" + ) attrs = list(client_utils_linux.GetExtAttrs(temp_filepath)) @@ -113,7 +135,8 @@ def testIncorrectFilePath(self): def testAttrChangeAfterListing(self, listxattr): with temp.AutoTempFilePath() as temp_filepath: filesystem_test_lib.SetExtAttr( - temp_filepath, name=b"user.bar", value=b"baz") + temp_filepath, name=b"user.bar", value=b"baz" + ) attrs = list(client_utils_linux.GetExtAttrs(temp_filepath)) diff --git a/grr/client/grr_response_client/linux/process.py b/grr/client/grr_response_client/linux/process.py index c7e6e03648..428055d2aa 100644 --- a/grr/client/grr_response_client/linux/process.py +++ b/grr/client/grr_response_client/linux/process.py @@ -3,7 +3,6 @@ This code is based on the memorpy project: https://github.com/n1nj4sec/memorpy - """ import ctypes @@ -21,8 +20,9 @@ def Errcheck(ret, func, args): del args if ret == -1: raise OSError( - "Error in %s: %s" % (func.__name__, - os.strerror(ctypes.get_errno() or errno.EPERM))) + "Error in %s: %s" + % (func.__name__, os.strerror(ctypes.get_errno() or errno.EPERM)) + ) return ret @@ -54,7 +54,8 @@ class Process(object): maps_re = re.compile( r"([0-9A-Fa-f]+)-([0-9A-Fa-f]+)\s+([-rwpsx]+)\s+" - r"([0-9A-Fa-f]+)\s+([0-9A-Fa-f]+:[0-9A-Fa-f]+)\s+([0-9]+)\s*(.*)") + r"([0-9A-Fa-f]+)\s+([0-9A-Fa-f]+:[0-9A-Fa-f]+)\s+([0-9]+)\s*(.*)" + ) def __init__(self, pid=None, mem_fd=None): """Creates a process for reading memory.""" @@ -95,11 +96,13 @@ def __enter__(self): def __exit__(self, exc_type=None, exc_val=None, exc_tb=None): self.Close() - def Regions(self, - skip_mapped_files=False, - skip_shared_regions=False, - skip_executable_regions=False, - skip_readonly_regions=False): + def Regions( + self, + skip_mapped_files=False, + skip_shared_regions=False, + skip_executable_regions=False, + skip_readonly_regions=False, + ): """Returns an iterator over the readable regions for this process.""" try: maps_file = open("/proc/" + str(self.pid) + "/maps", "r") @@ -134,7 +137,8 @@ def Regions(self, size=end - start, is_readable=True, is_writable=is_writable, - is_executable=is_executable) + is_executable=is_executable, + ) def ReadBytes(self, address, num_bytes): lseek64(self.mem_file, address, os.SEEK_SET) @@ -149,5 +153,7 @@ def serialized_file_descriptor(self) -> int: @classmethod def CreateFromSerializedFileDescriptor( - cls, serialized_file_descriptor: int) -> "Process": + cls, + serialized_file_descriptor: int, + ) -> "Process": return Process(mem_fd=serialized_file_descriptor) diff --git a/grr/client/grr_response_client/linux/process_test.py b/grr/client/grr_response_client/linux/process_test.py index a1183e273b..300e7b4b59 100644 --- a/grr/client/grr_response_client/linux/process_test.py +++ b/grr/client/grr_response_client/linux/process_test.py @@ -41,22 +41,29 @@ def MockedOpen(requested_path, mode="rb"): raise OSError("Error in open.") - with utils.MultiStubber((builtins, "open", MockedOpen), - (process, "open64", MockedOpen64)): + with utils.MultiStubber( + (builtins, "open", MockedOpen), (process, "open64", MockedOpen64) + ): with process.Process(pid=100) as proc: self.assertLen(list(proc.Regions()), 32) self.assertLen(list(proc.Regions(skip_mapped_files=True)), 10) 
self.assertLen(list(proc.Regions(skip_shared_regions=True)), 31) self.assertEqual( - len(list(proc.Regions(skip_executable_regions=True))), 27) + len(list(proc.Regions(skip_executable_regions=True))), 27 + ) self.assertEqual( - len(list(proc.Regions(skip_readonly_regions=True))), 10) + len(list(proc.Regions(skip_readonly_regions=True))), 10 + ) self.assertEqual( len( list( proc.Regions( - skip_executable_regions=True, - skip_shared_regions=True))), 26) + skip_executable_regions=True, skip_shared_regions=True + ) + ) + ), + 26, + ) def main(argv): diff --git a/grr/client/grr_response_client/linux/registry_init.py b/grr/client/grr_response_client/linux/registry_init.py index 198ab3c24a..14832b5f34 100644 --- a/grr/client/grr_response_client/linux/registry_init.py +++ b/grr/client/grr_response_client/linux/registry_init.py @@ -1,5 +1,4 @@ #!/usr/bin/env python """This module contains linux specific client code.""" - # These need to register plugins so, pylint: disable=unused-import diff --git a/grr/client/grr_response_client/osx/__init__.py b/grr/client/grr_response_client/osx/__init__.py index 06227e947a..a7aa955fda 100644 --- a/grr/client/grr_response_client/osx/__init__.py +++ b/grr/client/grr_response_client/osx/__init__.py @@ -1,3 +1,2 @@ #!/usr/bin/env python """Client osx-specific module root.""" - diff --git a/grr/client/grr_response_client/osx/installers.py b/grr/client/grr_response_client/osx/installers.py index e72501e027..0a258a82e7 100644 --- a/grr/client/grr_response_client/osx/installers.py +++ b/grr/client/grr_response_client/osx/installers.py @@ -31,7 +31,8 @@ def Run(): packaged_config = config.CONFIG.MakeNewConfig() packaged_config.Initialize( - filename=installer_config, parser=config_parser.YamlConfigFileParser) + filename=installer_config, parser=config_parser.YamlConfigFileParser + ) new_config = config.CONFIG.MakeNewConfig() new_config.SetWriteBack(config.CONFIG["Config.writeback"]) diff --git a/grr/client/grr_response_client/osx/objc.py b/grr/client/grr_response_client/osx/objc.py index ea8207d03c..93ded025f0 100644 --- a/grr/client/grr_response_client/osx/objc.py +++ b/grr/client/grr_response_client/osx/objc.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Interface to Objective C libraries on OS X.""" + import ctypes import ctypes.util from typing import Text @@ -95,6 +96,7 @@ def _SetCTypesForLibrary(libname, fn_table): Args: libname: Library name string fn_table: List of (function, [arg types], return types) tuples + Returns: ctypes.CDLL with types set according to fn_table Raises: @@ -103,7 +105,7 @@ def _SetCTypesForLibrary(libname, fn_table): lib = LoadLibrary(libname) # We need to define input / output parameters for all functions we use - for (function, args, result) in fn_table: + for function, args, result in fn_table: f = getattr(lib, function) f.argtypes = args f.restype = result @@ -170,8 +172,9 @@ def IntToCFNumber(self, num): if not isinstance(num, int): raise TypeError('CFNumber can only be created from int') c_num = ctypes.c_int64(num) - cf_number = self.dll.CFNumberCreate(CF_DEFAULT_ALLOCATOR, INT64, - ctypes.byref(c_num)) + cf_number = self.dll.CFNumberCreate( + CF_DEFAULT_ALLOCATOR, INT64, ctypes.byref(c_num) + ) return cf_number def CFNumToInt32(self, num): @@ -191,14 +194,16 @@ def CFDictRetrieve(self, dictionary, key): return self.dll.CFDictionaryGetValue(dictionary, ptr) def PyStringToCFString(self, pystring): - return self.dll.CFStringCreateWithCString(CF_DEFAULT_ALLOCATOR, - pystring.encode('utf8'), UTF8) + return 
self.dll.CFStringCreateWithCString( + CF_DEFAULT_ALLOCATOR, pystring.encode('utf8'), UTF8 + ) def WrapCFTypeInPython(self, obj): """Package a CoreFoundation object in a Python wrapper. Args: obj: The CoreFoundation object. + Returns: One of CFBoolean, CFNumber, CFString, CFDictionary, CFArray. Raises: @@ -228,8 +233,9 @@ class SystemConfiguration(Foundation): def __init__(self): super().__init__() - self.cftable.append(('SCDynamicStoreCopyProxies', [ctypes.c_void_p], - ctypes.c_void_p)) + self.cftable.append( + ('SCDynamicStoreCopyProxies', [ctypes.c_void_p], ctypes.c_void_p) + ) self.dll = _SetCTypesForLibrary('SystemConfiguration', self.cftable) @@ -245,7 +251,8 @@ def __init__(self): super().__init__() self.cftable.append( # Only available 10.6 and later - ('SMCopyAllJobDictionaries', [ctypes.c_void_p], ctypes.c_void_p),) + ('SMCopyAllJobDictionaries', [ctypes.c_void_p], ctypes.c_void_p), + ) self.dll = _SetCTypesForLibrary('ServiceManagement', self.cftable) @@ -253,8 +260,8 @@ def SMGetJobDictionaries(self, domain='kSMDomainSystemLaunchd'): """Copy all Job Dictionaries from the ServiceManagement. Args: - domain: The name of a constant in Foundation referencing the domain. - Will copy all launchd services by default. + domain: The name of a constant in Foundation referencing the domain. Will + copy all launchd services by default. Returns: A marshalled python list of dicts containing the job dictionaries. @@ -320,7 +327,8 @@ def __init__(self, obj=0): self.ref = ctypes.c_void_p(self.IntToCFNumber(obj)) else: raise TypeError( - 'CFNumber initializer must be python int or objc CFNumber.') + 'CFNumber initializer must be python int or objc CFNumber.' + ) def __int__(self): return self.value @@ -375,9 +383,9 @@ def __getitem__(self, index): if not isinstance(index, int): raise TypeError('index must be an integer') if (index < 0) or (index >= len(self)): - raise IndexError('index must be between {0} and {1}'.format( - 0, - len(self) - 1)) + raise IndexError( + 'index must be between {0} and {1}'.format(0, len(self) - 1) + ) obj = self.dll.CFArrayGetValueAtIndex(self.ref, index) return self.WrapCFTypeInPython(obj) @@ -410,7 +418,8 @@ def __getitem__(self, key): cftype_key = key else: raise TypeError( - 'CFDictionary wrapper only supports string, int and objc values') + 'CFDictionary wrapper only supports string, int and objc values' + ) obj = ctypes.c_void_p(self.dll.CFDictionaryGetValue(self, cftype_key)) # Detect null pointers and avoid crashing WrapCFTypeInPython @@ -431,7 +440,8 @@ def get(self, key, default='', stringify=True): key: string. Dictionary key to look up. default: string. Return this value if key not found. stringify: bool. Force all return values to string for compatibility - reasons. + reasons. + Returns: python-wrapped CF object or default if not found. 
""" diff --git a/grr/client/grr_response_client/osx/objc_test.py b/grr/client/grr_response_client/osx/objc_test.py index 24f6cf07c7..6e8bdc7514 100644 --- a/grr/client/grr_response_client/osx/objc_test.py +++ b/grr/client/grr_response_client/osx/objc_test.py @@ -35,8 +35,9 @@ def testSetCTypesForLibraryLibNotFound(self, find_library_mock): @mock.patch("ctypes.util.find_library") @mock.patch("ctypes.cdll.LoadLibrary") - def testLoadLibraryUsesWellKnownPathAsFallback(self, load_library_mock, - find_library_mock): + def testLoadLibraryUsesWellKnownPathAsFallback( + self, load_library_mock, find_library_mock + ): mock_cdll = mock.Mock() find_library_mock.return_value = None load_library_mock.side_effect = [OSError("not found"), mock_cdll] @@ -45,13 +46,15 @@ def testLoadLibraryUsesWellKnownPathAsFallback(self, load_library_mock, self.assertGreaterEqual(load_library_mock.call_count, 1) load_library_mock.assert_called_with( - "/System/Library/Frameworks/Foobazzle.framework/Foobazzle") + "/System/Library/Frameworks/Foobazzle.framework/Foobazzle" + ) self.assertIs(result, mock_cdll) @mock.patch("ctypes.util.find_library") @mock.patch("ctypes.cdll.LoadLibrary") - def testLoadLibraryTriesLoadingSharedLoadedLibrary(self, load_library_mock, - find_library_mock): + def testLoadLibraryTriesLoadingSharedLoadedLibrary( + self, load_library_mock, find_library_mock + ): mock_cdll = mock.Mock() def _LoadLibrary(libpath): diff --git a/grr/client/grr_response_client/osx/process.py b/grr/client/grr_response_client/osx/process.py index cb969b40b9..7269fd7842 100644 --- a/grr/client/grr_response_client/osx/process.py +++ b/grr/client/grr_response_client/osx/process.py @@ -3,7 +3,6 @@ This code is based on the memorpy project: https://github.com/n1nj4sec/memorpy - """ import ctypes @@ -79,22 +78,27 @@ def Close(self): def Open(self): self.task = ctypes.c_uint32() self.mytask = libc.mach_task_self() - ret = libc.task_for_pid(self.mytask, ctypes.c_int(self.pid), - ctypes.pointer(self.task)) + ret = libc.task_for_pid( + self.mytask, ctypes.c_int(self.pid), ctypes.pointer(self.task) + ) if ret: if ret == 5: # Most likely this means access denied. This is not perfect # but there is no way to find out. raise process_error.ProcessError( - "Access denied (task_for_pid returned 5).") + "Access denied (task_for_pid returned 5)." + ) raise process_error.ProcessError( - "task_for_pid failed with error code : %s" % ret) - - def Regions(self, - skip_executable_regions=False, - skip_shared_regions=False, - skip_readonly_regions=False): + "task_for_pid failed with error code : %s" % ret + ) + + def Regions( + self, + skip_executable_regions=False, + skip_shared_regions=False, + skip_readonly_regions=False, + ): """Iterates over the readable regions for this process. We use mach_vm_region_recurse here to get a fine grained view of @@ -135,11 +139,14 @@ def Regions(self, while True: c_depth = ctypes.c_uint32(depth) - r = libc.mach_vm_region_recurse(self.task, ctypes.pointer(address), - ctypes.pointer(mapsize), - ctypes.pointer(c_depth), - ctypes.pointer(sub_info), - ctypes.pointer(count)) + r = libc.mach_vm_region_recurse( + self.task, + ctypes.pointer(address), + ctypes.pointer(mapsize), + ctypes.pointer(c_depth), + ctypes.pointer(sub_info), + ctypes.pointer(count), + ) # If we get told "invalid address", we have crossed into kernel land... 
if r == 1: @@ -160,7 +167,9 @@ def Regions(self, continue if skip_shared_regions and sub_info.share_mode in [ - SM_COW, SM_SHARED, SM_TRUESHARED + SM_COW, + SM_SHARED, + SM_TRUESHARED, ]: address.value += mapsize.value continue @@ -183,7 +192,8 @@ def Regions(self, size=mapsize.value, is_readable=True, is_executable=is_executable, - is_writable=is_writable) + is_writable=is_writable, + ) address.value += mapsize.value def ReadBytes(self, address, num_bytes): @@ -191,9 +201,13 @@ def ReadBytes(self, address, num_bytes): pdata = ctypes.c_void_p(0) data_cnt = ctypes.c_uint32(0) - ret = libc.mach_vm_read(self.task, ctypes.c_ulonglong(address), - ctypes.c_longlong(num_bytes), ctypes.pointer(pdata), - ctypes.pointer(data_cnt)) + ret = libc.mach_vm_read( + self.task, + ctypes.c_ulonglong(address), + ctypes.c_longlong(num_bytes), + ctypes.pointer(pdata), + ctypes.pointer(data_cnt), + ) if ret: raise process_error.ProcessError("Error in mach_vm_read, ret=%s" % ret) buf = ctypes.string_at(pdata.value, data_cnt.value) @@ -206,6 +220,7 @@ def serialized_file_descriptor(self) -> int: @classmethod def CreateFromSerializedFileDescriptor( - cls, serialized_file_descriptor: int) -> "Process": + cls, serialized_file_descriptor: int + ) -> "Process": del serialized_file_descriptor # Unused return NotImplementedError() diff --git a/grr/client/grr_response_client/streaming.py b/grr/client/grr_response_client/streaming.py index 4594d677cd..1655bbc0e3 100644 --- a/grr/client/grr_response_client/streaming.py +++ b/grr/client/grr_response_client/streaming.py @@ -95,7 +95,7 @@ def Stream(self, reader: "Reader", amount=None): while amount > 0: # We need `len(data)` here because overlap size can be 0. - overlap = data[len(data) - self.overlap_size:] + overlap = data[len(data) - self.overlap_size :] new = reader.Read(min(self.chunk_size - self.overlap_size, amount)) if not new: @@ -139,7 +139,8 @@ def StreamRanges(self, offset: int, amount: int) -> "Iterator[Chunk]": offset=chunk_start, amount=(chunk_end - chunk_start), overlap=(pos - chunk_start), - data=None) + data=None, + ) pos = chunk_end @@ -152,11 +153,13 @@ class Chunk(object): overlap: A number of bytes this chunk shares with the previous one. """ - def __init__(self, - offset: Optional[int] = None, - data: Optional[bytes] = None, - overlap: int = 0, - amount: Optional[int] = None): + def __init__( + self, + offset: Optional[int] = None, + data: Optional[bytes] = None, + overlap: int = 0, + amount: Optional[int] = None, + ): if offset is None: raise ValueError("chunk offset must be specified") if data is None and amount is None: @@ -171,8 +174,10 @@ def __init__(self, self.amount = amount def __repr__(self): - return (f"Chunk") + return ( + f"Chunk" + ) # TODO(hanuszczak): This function is beyond the scope of this module. 
It is # used in only one place [1] and should probably be moved there as well as diff --git a/grr/client/grr_response_client/streaming_test.py b/grr/client/grr_response_client/streaming_test.py index 5acca78889..5c86bf8010 100644 --- a/grr/client/grr_response_client/streaming_test.py +++ b/grr/client/grr_response_client/streaming_test.py @@ -171,7 +171,7 @@ def Stream(self, streamer, data): def Result(amount=available_data, offset=0): amount = min(amount, available_data - offset) for chunk in streamer.StreamRanges(offset, amount): - chunk.data = data[chunk.offset:chunk.offset + chunk.amount] + chunk.data = data[chunk.offset : chunk.offset + chunk.amount] yield chunk return Result @@ -252,7 +252,7 @@ def __init__(self, memory): self.memory = memory def ReadBytes(self, address, num_bytes): - return self.memory[address:address + num_bytes] + return self.memory[address : address + num_bytes] class ChunkTest(absltest.TestCase): diff --git a/grr/client/grr_response_client/time.py b/grr/client/grr_response_client/time.py index f54b820714..bdf20999a8 100644 --- a/grr/client/grr_response_client/time.py +++ b/grr/client/grr_response_client/time.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with agent-specific time utilities.""" + import math import time from typing import Callable diff --git a/grr/client/grr_response_client/unprivileged/communication.py b/grr/client/grr_response_client/unprivileged/communication.py index 230cd5df16..9ae2110915 100644 --- a/grr/client/grr_response_client/unprivileged/communication.py +++ b/grr/client/grr_response_client/unprivileged/communication.py @@ -8,7 +8,8 @@ import platform import struct import subprocess -from typing import NamedTuple, Callable, Optional, List, BinaryIO, Set +from typing import BinaryIO, Callable, List, NamedTuple, Optional, Set + import psutil from grr_response_client.unprivileged import sandbox @@ -156,10 +157,12 @@ class FileDescriptor: _handle: Optional[int] = None _mode: Optional[Mode] = None - def __init__(self, - file_descriptor: Optional[int] = None, - handle: Optional[int] = None, - mode: Optional[Mode] = None): + def __init__( + self, + file_descriptor: Optional[int] = None, + handle: Optional[int] = None, + mode: Optional[Mode] = None, + ): self._file_descriptor = file_descriptor self._handle = handle self._mode = mode @@ -178,6 +181,7 @@ def ToFileDescriptor(self) -> int: if platform.system() == "Windows": import msvcrt # pylint: disable=g-import-not-at-top + if self._mode == Mode.READ: mode = os.O_RDONLY elif self._mode == Mode.WRITE: @@ -247,7 +251,8 @@ def FromSerialized(cls, pipe_input: int, pipe_output: int) -> "Channel": """Creates a channel from serialized pipe file descriptors.""" return Channel( FileDescriptor.FromSerialized(pipe_input, Mode.READ), - FileDescriptor.FromSerialized(pipe_output, Mode.WRITE)) + FileDescriptor.FromSerialized(pipe_output, Mode.WRITE), + ) ArgsFactory = Callable[[Channel], List[str]] @@ -265,9 +270,11 @@ class SubprocessServer(Server): _started_instances: Set["SubprocessServer"] = set() - def __init__(self, - args_factory: ArgsFactory, - extra_file_descriptors: Optional[List[FileDescriptor]] = None): + def __init__( + self, + args_factory: ArgsFactory, + extra_file_descriptors: Optional[List[FileDescriptor]] = None, + ): """Constructor. 
Args: @@ -303,14 +310,18 @@ def Start(self) -> None: from grr_response_client.unprivileged.windows import process # pytype: disable=import-error # pylint: enable=g-import-not-at-top args = self._args_factory( - Channel(pipe_input=input_r_fd_obj, pipe_output=output_w_fd_obj)) + Channel(pipe_input=input_r_fd_obj, pipe_output=output_w_fd_obj) + ) extra_handles = [fd.ToHandle() for fd in self._extra_file_descriptors] self._process_win = process.Process( - args, [input_r_fd_obj.ToHandle(), - output_w_fd_obj.ToHandle()] + extra_handles) + args, + [input_r_fd_obj.ToHandle(), output_w_fd_obj.ToHandle()] + + extra_handles, + ) else: args = self._args_factory( - Channel(pipe_input=input_r_fd_obj, pipe_output=output_w_fd_obj)) + Channel(pipe_input=input_r_fd_obj, pipe_output=output_w_fd_obj) + ) extra_fds = [ fd.ToFileDescriptor() for fd in self._extra_file_descriptors ] @@ -346,12 +357,14 @@ def Connect(self) -> Connection: @classmethod def TotalCpuTime(cls) -> float: return SubprocessServer._past_instances_total_cpu_time + sum( - [instance.cpu_time for instance in cls._started_instances]) + [instance.cpu_time for instance in cls._started_instances] + ) @classmethod def TotalSysTime(cls) -> float: return SubprocessServer._past_instances_total_sys_time + sum( - [instance.sys_time for instance in cls._started_instances]) + [instance.sys_time for instance in cls._started_instances] + ) @property def cpu_time(self) -> float: @@ -375,8 +388,12 @@ def _psutil_process(self) -> psutil.Process: raise ValueError("Can't determine process.") -def Main(channel: Channel, connection_handler: ConnectionHandler, user: str, - group: str) -> None: +def Main( + channel: Channel, + connection_handler: ConnectionHandler, + user: str, + group: str, +) -> None: """The entry point of the server process. Args: @@ -388,11 +405,11 @@ def Main(channel: Channel, connection_handler: ConnectionHandler, user: str, sandbox.EnterSandbox(user, group) assert channel.pipe_input is not None and channel.pipe_output is not None with os.fdopen( - channel.pipe_input.ToFileDescriptor(), "rb", - buffering=False) as pipe_input: + channel.pipe_input.ToFileDescriptor(), "rb", buffering=False + ) as pipe_input: with os.fdopen( - channel.pipe_output.ToFileDescriptor(), "wb", - buffering=False) as pipe_output: + channel.pipe_output.ToFileDescriptor(), "wb", buffering=False + ) as pipe_output: transport = PipeTransport(pipe_input, pipe_output) connection = Connection(transport) connection_handler(connection) diff --git a/grr/client/grr_response_client/unprivileged/communication_test.py b/grr/client/grr_response_client/unprivileged/communication_test.py index b2242ec0d2..011c27f74e 100644 --- a/grr/client/grr_response_client/unprivileged/communication_test.py +++ b/grr/client/grr_response_client/unprivileged/communication_test.py @@ -6,6 +6,7 @@ from typing import List import unittest from unittest import mock + import sys from absl.testing import absltest import psutil @@ -49,8 +50,9 @@ def testCommunication(self): server.Stop() - @unittest.skipIf(platform.system() == "Windows", - "psutil is not used on Windows.") + @unittest.skipIf( + platform.system() == "Windows", "psutil is not used on Windows." 
+ ) def testTotalServerCpuSysTime_usesPsutilProcess(self): _FakeCpuTimes = collections.namedtuple("FakeCpuTimes", ["user", "system"]) @@ -69,10 +71,12 @@ def cpu_times(self): # pylint: disable=invalid-name with communication.SubprocessServer(_MakeArgs): pass - self.assertAlmostEqual(communication.TotalServerCpuTime() - init_cpu_time, - 42.0) - self.assertAlmostEqual(communication.TotalServerSysTime() - init_sys_time, - 43.0) + self.assertAlmostEqual( + communication.TotalServerCpuTime() - init_cpu_time, 42.0 + ) + self.assertAlmostEqual( + communication.TotalServerSysTime() - init_sys_time, 43.0 + ) @unittest.skipIf(platform.system() != "Windows", "Windows only test.") def testTotalServerCpuSysTime_usesWin32Process(self): @@ -81,24 +85,27 @@ def _MockGetProcessTimes(handle): del handle # Unused. return { "UserTime": 42 * 10 * 1000 * 1000, - "KernelTime": 43 * 10 * 1000 * 1000 + "KernelTime": 43 * 10 * 1000 * 1000, } # pytype: disable=import-error import win32process # pylint: disable=g-import-not-at-top # pytype: enable=import-error - with mock.patch.object(win32process, "GetProcessTimes", - _MockGetProcessTimes): + with mock.patch.object( + win32process, "GetProcessTimes", _MockGetProcessTimes + ): init_cpu_time = communication.TotalServerCpuTime() init_sys_time = communication.TotalServerSysTime() with communication.SubprocessServer(_MakeArgs): pass - self.assertAlmostEqual(communication.TotalServerCpuTime() - init_cpu_time, - 42.0) - self.assertAlmostEqual(communication.TotalServerSysTime() - init_sys_time, - 43.0) + self.assertAlmostEqual( + communication.TotalServerCpuTime() - init_cpu_time, 42.0 + ) + self.assertAlmostEqual( + communication.TotalServerSysTime() - init_sys_time, 43.0 + ) def testCpuSysTime_addsUpMultipleProcesses(self): @@ -132,10 +139,12 @@ def sys_time(self): server1.fake_sys_time = 3.0 server2.fake_sys_time = 4.0 - self.assertAlmostEqual(communication.TotalServerCpuTime() - init_cpu_time, - 1.0 + 2.0) - self.assertAlmostEqual(communication.TotalServerSysTime() - init_sys_time, - 3.0 + 4.0) + self.assertAlmostEqual( + communication.TotalServerCpuTime() - init_cpu_time, 1.0 + 2.0 + ) + self.assertAlmostEqual( + communication.TotalServerSysTime() - init_sys_time, 3.0 + 4.0 + ) server1.fake_cpu_time = 5.0 server2.fake_cpu_time = 6.0 @@ -143,10 +152,12 @@ def sys_time(self): server1.fake_sys_time = 7.0 server2.fake_sys_time = 8.0 - self.assertAlmostEqual(communication.TotalServerCpuTime() - init_cpu_time, - 5.0 + 6.0) - self.assertAlmostEqual(communication.TotalServerSysTime() - init_sys_time, - 7.0 + 8.0) + self.assertAlmostEqual( + communication.TotalServerCpuTime() - init_cpu_time, 5.0 + 6.0 + ) + self.assertAlmostEqual( + communication.TotalServerSysTime() - init_sys_time, 7.0 + 8.0 + ) server1.Stop() server2.Stop() @@ -159,22 +170,28 @@ def sys_time(self): server1.fake_sys_time = 9.0 server2.fake_sys_time = 9.0 - self.assertAlmostEqual(communication.TotalServerCpuTime() - init_cpu_time, - 5.0 + 6.0) - self.assertAlmostEqual(communication.TotalServerSysTime() - init_sys_time, - 7.0 + 8.0) - - @unittest.skipIf(platform.system() != "Linux" and - platform.system() != "Darwin", "Unix only test.") + self.assertAlmostEqual( + communication.TotalServerCpuTime() - init_cpu_time, 5.0 + 6.0 + ) + self.assertAlmostEqual( + communication.TotalServerSysTime() - init_sys_time, 7.0 + 8.0 + ) + + @unittest.skipIf( + platform.system() != "Linux" and platform.system() != "Darwin", + "Unix only test.", + ) def testMain_entersSandbox(self): with mock.patch.object(sandbox, 
"EnterSandbox") as mock_enter_sandbox: input_fd = os.open("/dev/null", os.O_RDONLY) output_file = os.open("/dev/null", os.O_WRONLY) channel = communication.Channel( communication.FileDescriptor.FromFileDescriptor(input_fd), - communication.FileDescriptor.FromFileDescriptor(output_file)) - communication.Main(channel, lambda connection: None, "fooUser", - "barGroup") + communication.FileDescriptor.FromFileDescriptor(output_file), + ) + communication.Main( + channel, lambda connection: None, "fooUser", "barGroup" + ) mock_enter_sandbox.assert_called_with("fooUser", "barGroup") diff --git a/grr/client/grr_response_client/unprivileged/echo_server.py b/grr/client/grr_response_client/unprivileged/echo_server.py index c671366d4d..b007a1ad69 100644 --- a/grr/client/grr_response_client/unprivileged/echo_server.py +++ b/grr/client/grr_response_client/unprivileged/echo_server.py @@ -16,17 +16,21 @@ def Handler(connection: communication.Connection): while True: recv_result = connection.Recv() connection.Send( - communication.Message(recv_result.data + b"x", - recv_result.attachment + b"x")) + communication.Message( + recv_result.data + b"x", recv_result.attachment + b"x" + ) + ) def main(argv): communication.Main( communication.Channel.FromSerialized( - pipe_input=int(argv[1]), pipe_output=int(argv[2])), + pipe_input=int(argv[1]), pipe_output=int(argv[2]) + ), Handler, user="", - group="") + group="", + ) if __name__ == "__main__": diff --git a/grr/client/grr_response_client/unprivileged/filesystem/client.py b/grr/client/grr_response_client/unprivileged/filesystem/client.py index e582f170f1..583ee19fe0 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/client.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/client.py @@ -2,7 +2,7 @@ """Unprivileged filesystem RPC client code.""" import abc -from typing import TypeVar, Generic, Optional, Sequence, BinaryIO, Tuple +from typing import BinaryIO, Generic, Optional, Sequence, Tuple, TypeVar from grr_response_client.unprivileged import communication from grr_response_client.unprivileged.proto import filesystem_pb2 @@ -13,6 +13,7 @@ class Error(Exception): """Base class for exceptions in this module.""" + pass @@ -39,7 +40,8 @@ def __init__(self, connection: communication.Connection): def Send(self, request: filesystem_pb2.Request, attachment: bytes) -> None: self._connection.Send( - communication.Message(request.SerializeToString(), attachment)) + communication.Message(request.SerializeToString(), attachment) + ) def Recv(self) -> Tuple[filesystem_pb2.Response, bytes]: raw_response, attachment = self._connection.Recv() @@ -100,21 +102,25 @@ def Run(self, request: RequestType) -> ResponseType: packed_response, attachment = self._connection.Recv() if packed_response.HasField('device_data_request'): device_data_request = packed_response.device_data_request - data = self._device.Read(device_data_request.offset, - device_data_request.size) + data = self._device.Read( + device_data_request.offset, device_data_request.size + ) device_data = filesystem_pb2.DeviceData() request = filesystem_pb2.Request(device_data=device_data) self._connection.Send(request, data) elif packed_response.HasField('exception'): - raise OperationError(packed_response.exception.message, - packed_response.exception.formatted_exception) + raise OperationError( + packed_response.exception.message, + packed_response.exception.formatted_exception, + ) else: response = self.UnpackResponse(packed_response) self.MergeResponseAttachment(response, attachment) return 
response - def MergeResponseAttachment(self, response: ResponseType, - attachment: bytes) -> None: + def MergeResponseAttachment( + self, response: ResponseType, attachment: bytes + ) -> None: """Merges an attachment back into the response.""" pass @@ -129,106 +135,133 @@ def PackRequest(self, request: RequestType) -> filesystem_pb2.Request: pass -class InitHandler(OperationHandler[filesystem_pb2.InitRequest, - filesystem_pb2.InitResponse]): +class InitHandler( + OperationHandler[filesystem_pb2.InitRequest, filesystem_pb2.InitResponse] +): """Implements the Init RPC.""" def UnpackResponse( - self, response: filesystem_pb2.Response) -> filesystem_pb2.InitResponse: + self, response: filesystem_pb2.Response + ) -> filesystem_pb2.InitResponse: return response.init_response def PackRequest( - self, request: filesystem_pb2.InitRequest) -> filesystem_pb2.Request: + self, request: filesystem_pb2.InitRequest + ) -> filesystem_pb2.Request: return filesystem_pb2.Request(init_request=request) -class OpenHandler(OperationHandler[filesystem_pb2.OpenRequest, - filesystem_pb2.OpenResponse]): +class OpenHandler( + OperationHandler[filesystem_pb2.OpenRequest, filesystem_pb2.OpenResponse] +): """Implements the Open RPC.""" def UnpackResponse( - self, response: filesystem_pb2.Response) -> filesystem_pb2.OpenResponse: + self, response: filesystem_pb2.Response + ) -> filesystem_pb2.OpenResponse: return response.open_response def PackRequest( - self, request: filesystem_pb2.OpenRequest) -> filesystem_pb2.Request: + self, request: filesystem_pb2.OpenRequest + ) -> filesystem_pb2.Request: return filesystem_pb2.Request(open_request=request) -class ReadHandler(OperationHandler[filesystem_pb2.ReadRequest, - filesystem_pb2.ReadResponse]): +class ReadHandler( + OperationHandler[filesystem_pb2.ReadRequest, filesystem_pb2.ReadResponse] +): """Implements the Read RPC.""" def UnpackResponse( - self, response: filesystem_pb2.Response) -> filesystem_pb2.ReadResponse: + self, response: filesystem_pb2.Response + ) -> filesystem_pb2.ReadResponse: return response.read_response def PackRequest( - self, request: filesystem_pb2.ReadRequest) -> filesystem_pb2.Request: + self, request: filesystem_pb2.ReadRequest + ) -> filesystem_pb2.Request: return filesystem_pb2.Request(read_request=request) - def MergeResponseAttachment(self, response: filesystem_pb2.ReadResponse, - attachment: bytes) -> None: + def MergeResponseAttachment( + self, response: filesystem_pb2.ReadResponse, attachment: bytes + ) -> None: response.data = attachment -class StatHandler(OperationHandler[filesystem_pb2.StatRequest, - filesystem_pb2.StatResponse]): +class StatHandler( + OperationHandler[filesystem_pb2.StatRequest, filesystem_pb2.StatResponse] +): """Implements the Stat RPC.""" def UnpackResponse( - self, response: filesystem_pb2.Response) -> filesystem_pb2.StatResponse: + self, response: filesystem_pb2.Response + ) -> filesystem_pb2.StatResponse: return response.stat_response def PackRequest( - self, request: filesystem_pb2.StatRequest) -> filesystem_pb2.Request: + self, request: filesystem_pb2.StatRequest + ) -> filesystem_pb2.Request: return filesystem_pb2.Request(stat_request=request) -class ListFilesHandler(OperationHandler[filesystem_pb2.ListFilesRequest, - filesystem_pb2.ListFilesResponse]): +class ListFilesHandler( + OperationHandler[ + filesystem_pb2.ListFilesRequest, filesystem_pb2.ListFilesResponse + ] +): """Implements the ListFiles RPC.""" def UnpackResponse( - self, - response: filesystem_pb2.Response) -> filesystem_pb2.ListFilesResponse: + 
self, response: filesystem_pb2.Response + ) -> filesystem_pb2.ListFilesResponse: return response.list_files_response def PackRequest( - self, request: filesystem_pb2.ListFilesRequest) -> filesystem_pb2.Request: + self, request: filesystem_pb2.ListFilesRequest + ) -> filesystem_pb2.Request: return filesystem_pb2.Request(list_files_request=request) -class ListNamesHandler(OperationHandler[filesystem_pb2.ListNamesRequest, - filesystem_pb2.ListNamesResponse]): +class ListNamesHandler( + OperationHandler[ + filesystem_pb2.ListNamesRequest, filesystem_pb2.ListNamesResponse + ] +): """Implements the ListNames RPC.""" def UnpackResponse( - self, - response: filesystem_pb2.Response) -> filesystem_pb2.ListNamesResponse: + self, response: filesystem_pb2.Response + ) -> filesystem_pb2.ListNamesResponse: return response.list_names_response def PackRequest( - self, request: filesystem_pb2.ListNamesRequest) -> filesystem_pb2.Request: + self, request: filesystem_pb2.ListNamesRequest + ) -> filesystem_pb2.Request: return filesystem_pb2.Request(list_names_request=request) -class CloseHandler(OperationHandler[filesystem_pb2.CloseRequest, - filesystem_pb2.CloseResponse]): +class CloseHandler( + OperationHandler[filesystem_pb2.CloseRequest, filesystem_pb2.CloseResponse] +): """Implements the Close RPC.""" def UnpackResponse( - self, response: filesystem_pb2.Response) -> filesystem_pb2.CloseResponse: + self, response: filesystem_pb2.Response + ) -> filesystem_pb2.CloseResponse: return response.close_response def PackRequest( - self, request: filesystem_pb2.CloseRequest) -> filesystem_pb2.Request: + self, request: filesystem_pb2.CloseRequest + ) -> filesystem_pb2.Request: return filesystem_pb2.Request(close_request=request) class LookupCaseInsensitiveHandler( - OperationHandler[filesystem_pb2.LookupCaseInsensitiveRequest, - filesystem_pb2.LookupCaseInsensitiveResponse]): + OperationHandler[ + filesystem_pb2.LookupCaseInsensitiveRequest, + filesystem_pb2.LookupCaseInsensitiveResponse, + ] +): """Implements the LookupCaseInsensitive RPC.""" def UnpackResponse( @@ -245,8 +278,13 @@ def PackRequest( class File: """Wraps a remote file_id.""" - def __init__(self, connection: ConnectionWrapper, device: Device, - file_id: int, inode: int): + def __init__( + self, + connection: ConnectionWrapper, + device: Device, + file_id: int, + inode: int, + ): self._connection = connection self._device = device self._file_id = file_id @@ -254,7 +292,8 @@ def __init__(self, connection: ConnectionWrapper, device: Device, def Read(self, offset: int, size: int) -> bytes: request = filesystem_pb2.ReadRequest( - file_id=self._file_id, offset=offset, size=size) + file_id=self._file_id, offset=offset, size=size + ) response = ReadHandler(self._connection, self._device).Run(request) return response.data @@ -283,9 +322,11 @@ def inode(self) -> int: def LookupCaseInsensitive(self, name: str) -> Optional[str]: request = filesystem_pb2.LookupCaseInsensitiveRequest( - file_id=self._file_id, name=name) - response = LookupCaseInsensitiveHandler(self._connection, - self._device).Run(request) + file_id=self._file_id, name=name + ) + response = LookupCaseInsensitiveHandler(self._connection, self._device).Run( + request + ) if response.HasField('name'): return response.name return None @@ -299,26 +340,34 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: class StaleInodeError(Error): """The inode provided to open a file is stale / outdated.""" + pass class Client: """Client for the RPC filesystem service.""" - def __init__(self, connection: 
communication.Connection, - implementation_type: filesystem_pb2.ImplementationType, - device: Device): + def __init__( + self, + connection: communication.Connection, + implementation_type: filesystem_pb2.ImplementationType, + device: Device, + ): self._connection = ConnectionWrapper(connection) self._device = device device_file_descriptor = device.file_descriptor if device_file_descriptor is None: serialized_device_file_descriptor = None else: - serialized_device_file_descriptor = communication.FileDescriptor.FromFileDescriptor( - device_file_descriptor).Serialize() + serialized_device_file_descriptor = ( + communication.FileDescriptor.FromFileDescriptor( + device_file_descriptor + ).Serialize() + ) request = filesystem_pb2.InitRequest( implementation_type=implementation_type, - serialized_device_file_descriptor=serialized_device_file_descriptor) + serialized_device_file_descriptor=serialized_device_file_descriptor, + ) InitHandler(self._connection, self._device).Run(request) def __enter__(self) -> 'Client': @@ -330,16 +379,12 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: def Close(self): pass - def Open(self, - path: str, - stream_name: Optional[str] = None) -> File: + def Open(self, path: str, stream_name: Optional[str] = None) -> File: """Opens a file.""" request = filesystem_pb2.OpenRequest(path=path, stream_name=stream_name) return self._Open(request) - def OpenByInode(self, - inode: int, - stream_name: Optional[str] = None) -> File: + def OpenByInode(self, inode: int, stream_name: Optional[str] = None) -> File: """Opens a file by inode.""" request = filesystem_pb2.OpenRequest(inode=inode, stream_name=stream_name) return self._Open(request) @@ -350,13 +395,15 @@ def _Open(self, request: filesystem_pb2.OpenRequest) -> File: raise StaleInodeError() elif response.status != filesystem_pb2.OpenResponse.Status.NO_ERROR: raise IOError(f'Open RPC returned status {response.status}.') - return File(self._connection, self._device, response.file_id, - response.inode) + return File( + self._connection, self._device, response.file_id, response.inode + ) def CreateFilesystemClient( connection: communication.Connection, implementation_type: filesystem_pb2.ImplementationType, - device: Device) -> Client: + device: Device, +) -> Client: """Creates a filesystem client.""" return Client(connection, implementation_type, device) diff --git a/grr/client/grr_response_client/unprivileged/filesystem/filesystem.py b/grr/client/grr_response_client/unprivileged/filesystem/filesystem.py index eab53b7ddf..11c0755d6c 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/filesystem.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/filesystem.py @@ -2,19 +2,20 @@ """Common code and abstractions for filesystem implementations.""" import abc - -from typing import Dict, Optional, Iterable +from typing import Dict, Iterable, Optional from grr_response_client.unprivileged.proto import filesystem_pb2 class Error(Exception): """Base class for filesystem error.""" + pass class StaleInodeError(Error): """The inode provided to open a file is stale / outdated.""" + pass @@ -75,6 +76,7 @@ def LookupCaseInsensitive(self, name: str) -> Optional[str]: Args: name: Case-insensitive name to match. + Returns: the case-literal name or None if the case-insensitive name couldn't be found. 
""" diff --git a/grr/client/grr_response_client/unprivileged/filesystem/ntfs.py b/grr/client/grr_response_client/unprivileged/filesystem/ntfs.py index bfed28a90f..abd3441246 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/ntfs.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/ntfs.py @@ -2,7 +2,7 @@ """pyfsntfs implementation of a filesystem.""" import os -from typing import Optional, Iterable +from typing import Iterable, Optional import pyfsntfs from grr_response_client.unprivileged.filesystem import filesystem from grr_response_client.unprivileged.proto import filesystem_pb2 @@ -35,9 +35,12 @@ def tell(self) -> int: class NtfsFile(filesystem.File): """pyfsntfs implementation of File.""" - def __init__(self, filesystem_obj: filesystem.Filesystem, - fd: pyfsntfs.file_entry, - data_stream: Optional[pyfsntfs.data_stream]): + def __init__( + self, + filesystem_obj: filesystem.Filesystem, + fd: pyfsntfs.file_entry, + data_stream: Optional[pyfsntfs.data_stream], + ): super().__init__(filesystem_obj) self.fd = fd self.data_stream = data_stream @@ -117,7 +120,8 @@ def Inode(self) -> int: def _get_data_stream( entry: pyfsntfs.file_entry, - stream_name: Optional[str]) -> Optional[pyfsntfs.data_stream]: + stream_name: Optional[str], +) -> Optional[pyfsntfs.data_stream]: """Returns a data stream by name, or the default data stream.""" if stream_name is None: if entry.has_default_data_stream(): diff --git a/grr/client/grr_response_client/unprivileged/filesystem/ntfs_image_test_lib.py b/grr/client/grr_response_client/unprivileged/filesystem/ntfs_image_test_lib.py index df0bb1c8e5..1785352aea 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/ntfs_image_test_lib.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/ntfs_image_test_lib.py @@ -27,22 +27,35 @@ HIDDEN_FILE_TXT_FILE_REF = 562949953421388 CHINESE_FILE_FILE_REF = 844424930132045 - # Default StatEntry.ntfs values for files and directories S_DEFAULT_FILE = filesystem_pb2.StatEntry.Ntfs( - is_directory=False, flags=stat.FILE_ATTRIBUTE_ARCHIVE) # pytype: disable=module-attr + is_directory=False, flags=stat.FILE_ATTRIBUTE_ARCHIVE # pytype: disable=module-attr +) S_DEFAULT_DIR = filesystem_pb2.StatEntry.Ntfs( - is_directory=True, flags=stat.FILE_ATTRIBUTE_ARCHIVE) # pytype: disable=module-attr + is_directory=True, flags=stat.FILE_ATTRIBUTE_ARCHIVE # pytype: disable=module-attr +) S_MODE_ALL = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO S_MODE_DIR = stat.S_IFDIR | S_MODE_ALL S_MODE_DEFAULT = stat.S_IFREG | S_MODE_ALL S_MODE_READ_ONLY = ( - stat.S_IFREG | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH | stat.S_IXUSR - | stat.S_IXGRP | stat.S_IXOTH) + stat.S_IFREG + | stat.S_IRUSR + | stat.S_IRGRP + | stat.S_IROTH + | stat.S_IXUSR + | stat.S_IXGRP + | stat.S_IXOTH +) S_MODE_HIDDEN = ( - stat.S_IFREG | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH | stat.S_IXUSR - | stat.S_IXGRP | stat.S_IXOTH) + stat.S_IFREG + | stat.S_IWUSR + | stat.S_IWGRP + | stat.S_IWOTH + | stat.S_IXUSR + | stat.S_IXGRP + | stat.S_IXOTH +) def _FormatTimestamp(timestamp: timestamp_pb2.Timestamp) -> str: @@ -52,12 +65,8 @@ def _FormatTimestamp(timestamp: timestamp_pb2.Timestamp) -> str: def _ParseTimestamp(s: str) -> timestamp_pb2.Timestamp: default = datetime.datetime( # pylint: disable=g-tzinfo-datetime - time.gmtime().tm_year, - 1, - 1, - 0, - 0, - tzinfo=dateutil.tz.tzutc()) + time.gmtime().tm_year, 1, 1, 0, 0, tzinfo=dateutil.tz.tzutc() + ) dt = dateutil.parser.parse(s, default=default) result = timestamp_pb2.Timestamp() 
result.FromDatetime(dt) @@ -74,7 +83,8 @@ class NtfsImageTest(absltest.TestCase, abc.ABC): @abc.abstractmethod def _ExpectedStatEntry( - self, st: filesystem_pb2.StatEntry) -> filesystem_pb2.StatEntry: + self, st: filesystem_pb2.StatEntry + ) -> filesystem_pb2.StatEntry: """Fixes an expected StatEntry for the respective implementation.""" pass @@ -94,14 +104,18 @@ def setUpClass(cls): cls._exit_stack = contextlib.ExitStack() ntfs_image = cls._exit_stack.enter_context( - open(os.path.join(config.CONFIG["Test.data_dir"], "ntfs.img"), "rb")) + open(os.path.join(config.CONFIG["Test.data_dir"], "ntfs.img"), "rb") + ) cls._server = server.CreateFilesystemServer(ntfs_image.fileno()) cls._server.Start() cls._client = cls._exit_stack.enter_context( - client.CreateFilesystemClient(cls._server.Connect(), - cls._IMPLEMENTATION_TYPE, - client.FileDevice(ntfs_image))) + client.CreateFilesystemClient( + cls._server.Connect(), + cls._IMPLEMENTATION_TYPE, + client.FileDevice(ntfs_image), + ) + ) @classmethod def tearDownClass(cls): @@ -114,7 +128,7 @@ def testRead(self): data = file_obj.Read(offset=0, size=50) self.assertEqual( data, - b"1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n13\n14\n15\n16\n17\n18\n19\n20" + b"1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n13\n14\n15\n16\n17\n18\n19\n20", ) def testNonExistent(self): @@ -135,7 +149,8 @@ def testRead_PastTheEnd(self): def testOpenByInode(self): with self._client.OpenByInode( - inode=self._FileRefToInode(A_B1_C1_D_FILE_REF)) as file_obj: + inode=self._FileRefToInode(A_B1_C1_D_FILE_REF) + ) as file_obj: self.assertEqual(file_obj.Read(0, 100), b"foo\n") def testOpenByInode_stale(self): @@ -175,7 +190,8 @@ def testListFiles(self): st_uid=48, st_nlink=1, st_mode=S_MODE_DIR, - )), + ) + ), self._ExpectedStatEntry( filesystem_pb2.StatEntry( name="ads", @@ -189,7 +205,8 @@ def testListFiles(self): st_uid=48, st_nlink=1, st_mode=S_MODE_DIR, - )), + ) + ), self._ExpectedStatEntry( filesystem_pb2.StatEntry( name="hidden_file.txt", @@ -201,13 +218,15 @@ def testListFiles(self): ntfs=filesystem_pb2.StatEntry.Ntfs( is_directory=False, flags=stat.FILE_ATTRIBUTE_ARCHIVE # pytype: disable=module-attr - | stat.FILE_ATTRIBUTE_HIDDEN), # pytype: disable=module-attr + | stat.FILE_ATTRIBUTE_HIDDEN, # pytype: disable=module-attr + ), st_size=0, st_gid=0, st_uid=48, st_nlink=1, st_mode=S_MODE_HIDDEN, - )), + ) + ), self._ExpectedStatEntry( filesystem_pb2.StatEntry( name="numbers.txt", @@ -222,7 +241,8 @@ def testListFiles(self): st_uid=48, st_nlink=1, st_mode=S_MODE_DEFAULT, - )), + ) + ), self._ExpectedStatEntry( filesystem_pb2.StatEntry( name="read_only_file.txt", @@ -234,13 +254,15 @@ def testListFiles(self): ntfs=filesystem_pb2.StatEntry.Ntfs( is_directory=False, flags=stat.FILE_ATTRIBUTE_ARCHIVE # pytype: disable=module-attr - | stat.FILE_ATTRIBUTE_READONLY), # pytype: disable=module-attr + | stat.FILE_ATTRIBUTE_READONLY, # pytype: disable=module-attr + ), st_size=0, st_gid=0, st_uid=48, st_nlink=1, st_mode=S_MODE_READ_ONLY, - )), + ) + ), self._ExpectedStatEntry( filesystem_pb2.StatEntry( name="入乡随俗 海外春节别样过法.txt", @@ -255,7 +277,8 @@ def testListFiles(self): st_uid=48, st_nlink=1, st_mode=S_MODE_DEFAULT, - )), + ) + ), ] self.assertEqual(files, expected_files) @@ -286,7 +309,8 @@ def testListFiles_alternateDataStreams(self): st_uid=48, st_nlink=1, st_mode=S_MODE_DEFAULT, - )), + ) + ), self._ExpectedStatEntry( filesystem_pb2.StatEntry( name="ads.txt", @@ -302,7 +326,8 @@ def testListFiles_alternateDataStreams(self): st_uid=48, st_nlink=1, st_mode=S_MODE_DEFAULT, - )), + ) + ), 
self._ExpectedStatEntry( filesystem_pb2.StatEntry( name="ads.txt", @@ -318,7 +343,8 @@ def testListFiles_alternateDataStreams(self): st_uid=48, st_nlink=1, st_mode=S_MODE_DEFAULT, - )), + ) + ), ] self.assertEqual(files, expected_files) @@ -334,22 +360,26 @@ def testOpen_alternateDataStreams(self): self.assertEqual(file_obj.Read(0, 100), b"Foo.\n") with self._client.Open( - path=self._Path("\\ads\\ads.txt"), stream_name="one") as file_obj: + path=self._Path("\\ads\\ads.txt"), stream_name="one" + ) as file_obj: self.assertEqual(file_obj.Read(0, 100), b"Bar..\n") with self._client.Open( - path=self._Path("\\ads\\ads.txt"), stream_name="two") as file_obj: + path=self._Path("\\ads\\ads.txt"), stream_name="two" + ) as file_obj: self.assertEqual(file_obj.Read(0, 100), b"Baz...\n") def testOpen_alternateDataStreams_invalid(self): with self.assertRaises(client.OperationError): self._client.Open( - path=self._Path("\\ads\\ads.txt"), stream_name="invalid") + path=self._Path("\\ads\\ads.txt"), stream_name="invalid" + ) def testStat_alternateDataStreams(self): with self._client.Open( - path=self._Path("\\ads\\ads.txt"), stream_name="one") as file_obj: + path=self._Path("\\ads\\ads.txt"), stream_name="one" + ) as file_obj: s = file_obj.Stat() self.assertEqual(s.name, "ads.txt") self.assertEqual(s.stream_name, "one") @@ -360,8 +390,8 @@ def testStat_alternateDataStreams(self): def testOpenByInode_alternateDataStreams(self): with self._client.OpenByInode( - inode=self._FileRefToInode(ADS_ADS_TXT_FILE_REF), - stream_name="one") as file_obj: + inode=self._FileRefToInode(ADS_ADS_TXT_FILE_REF), stream_name="one" + ) as file_obj: self.assertEqual(file_obj.Read(0, 100), b"Bar..\n") def testListFiles_alternateDataStreams_fileOnly(self): @@ -385,7 +415,8 @@ def testListFiles_alternateDataStreams_fileOnly(self): st_uid=48, st_nlink=1, st_mode=S_MODE_DEFAULT, - )), + ) + ), self._ExpectedStatEntry( filesystem_pb2.StatEntry( name="ads.txt", @@ -401,17 +432,21 @@ def testListFiles_alternateDataStreams_fileOnly(self): st_uid=48, st_nlink=1, st_mode=S_MODE_DEFAULT, - )), + ) + ), ] self.assertEqual(files, expected_files) def testReadUnicode(self): - with self._client.Open(path=self._Path("\\入乡随俗 海外春节别样过法.txt")) as file_obj: + with self._client.Open( + path=self._Path("\\入乡随俗 海外春节别样过法.txt") + ) as file_obj: expected = "Chinese news\n中国新闻\n".encode("utf-8") self.assertEqual(file_obj.Read(0, 100), expected) def testRead_fromDirectoryRaises(self): - with self.assertRaisesRegex(client.OperationError, - "Attempting to read from a directory"): + with self.assertRaisesRegex( + client.OperationError, "Attempting to read from a directory" + ): with self._client.Open(path=self._Path("\\a")) as file_obj: file_obj.Read(offset=0, size=1) diff --git a/grr/client/grr_response_client/unprivileged/filesystem/ntfs_test.py b/grr/client/grr_response_client/unprivileged/filesystem/ntfs_test.py index b5336ff6e4..2e06de0608 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/ntfs_test.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/ntfs_test.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import contextlib from unittest import mock + from absl.testing import absltest from grr_response_client.unprivileged import test_lib @@ -14,7 +15,8 @@ class NtfsTestBase(ntfs_image_test_lib.NtfsImageTest): _IMPLEMENTATION_TYPE = filesystem_pb2.NTFS def _ExpectedStatEntry( - self, st: filesystem_pb2.StatEntry) -> filesystem_pb2.StatEntry: + self, st: filesystem_pb2.StatEntry + ) -> filesystem_pb2.StatEntry: st.ClearField("st_mode") 
st.ClearField("st_nlink") st.ClearField("st_uid") @@ -38,7 +40,8 @@ def setUp(self): # The FileDevice won't return a file descriptor. stack.enter_context( - mock.patch.object(client.FileDevice, "file_descriptor", None)) + mock.patch.object(client.FileDevice, "file_descriptor", None) + ) class NtfsWithFileDescriptorSharingTest(NtfsTestBase): diff --git a/grr/client/grr_response_client/unprivileged/filesystem/server.py b/grr/client/grr_response_client/unprivileged/filesystem/server.py index 99b71ab6c4..41df9dd472 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/server.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/server.py @@ -9,10 +9,13 @@ def CreateFilesystemServer( - device_file_descriptor: Optional[int] = None) -> communication.Server: + device_file_descriptor: Optional[int] = None, +) -> communication.Server: extra_file_descriptors = [] if device_file_descriptor is not None: extra_file_descriptors.append( - communication.FileDescriptor.FromFileDescriptor(device_file_descriptor)) - return server.CreateServer(extra_file_descriptors, - interface_registry.Interface.FILESYSTEM) + communication.FileDescriptor.FromFileDescriptor(device_file_descriptor) + ) + return server.CreateServer( + extra_file_descriptors, interface_registry.Interface.FILESYSTEM + ) diff --git a/grr/client/grr_response_client/unprivileged/filesystem/server_lib.py b/grr/client/grr_response_client/unprivileged/filesystem/server_lib.py index 9795594e5a..2ad8b18543 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/server_lib.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/server_lib.py @@ -5,7 +5,7 @@ import os import sys import traceback -from typing import TypeVar, Generic, Optional, Tuple +from typing import Generic, Optional, Tuple, TypeVar from grr_response_client.unprivileged import communication from grr_response_client.unprivileged.filesystem import filesystem from grr_response_client.unprivileged.filesystem import ntfs @@ -15,11 +15,13 @@ class Error(Exception): """Base class for exceptions in this module.""" + pass class DispatchError(Error): """Error while dispatching a request.""" + pass @@ -42,7 +44,8 @@ def __init__(self, connection: communication.Connection): def Send(self, response: filesystem_pb2.Response, attachment: bytes) -> None: self._connection.Send( - communication.Message(response.SerializeToString(), attachment)) + communication.Message(response.SerializeToString(), attachment) + ) def Recv(self) -> Tuple[filesystem_pb2.Request, bytes]: raw_request, attachment = self._connection.Recv() @@ -59,9 +62,11 @@ def __init__(self, connection: ConnectionWrapper): def Read(self, offset: int, size: int) -> bytes: device_data_request = filesystem_pb2.DeviceDataRequest( - offset=offset, size=size) + offset=offset, size=size + ) self._connection.Send( - filesystem_pb2.Response(device_data_request=device_data_request), b'') + filesystem_pb2.Response(device_data_request=device_data_request), b'' + ) _, attachment = self._connection.Recv() return attachment @@ -88,8 +93,12 @@ class OperationHandler(abc.ABC, Generic[RequestType, ResponseType]): common to most RPCs. 
""" - def __init__(self, state: State, request: filesystem_pb2.Request, - connection: ConnectionWrapper): + def __init__( + self, + state: State, + request: filesystem_pb2.Request, + connection: ConnectionWrapper, + ): self._state = state self._request = request self._connection = connection @@ -123,18 +132,20 @@ def ExtractResponseAttachment(self, response: ResponseType) -> bytes: return b'' -class InitHandler(OperationHandler[filesystem_pb2.InitRequest, - filesystem_pb2.InitResponse]): +class InitHandler( + OperationHandler[filesystem_pb2.InitRequest, filesystem_pb2.InitResponse] +): """Implements the Init operation.""" def HandleOperation( - self, state: State, - request: filesystem_pb2.InitRequest) -> filesystem_pb2.InitResponse: + self, state: State, request: filesystem_pb2.InitRequest + ) -> filesystem_pb2.InitResponse: if request.HasField('serialized_device_file_descriptor'): device = FileDevice( communication.FileDescriptor.FromSerialized( - request.serialized_device_file_descriptor, - communication.Mode.READ).ToFileDescriptor()) + request.serialized_device_file_descriptor, communication.Mode.READ + ).ToFileDescriptor() + ) else: device = self.CreateDevice() @@ -144,30 +155,35 @@ def HandleOperation( state.filesystem = tsk.TskFilesystem(device) else: raise DispatchError( - f'Bad implementation type: {request.implementation_type}') + f'Bad implementation type: {request.implementation_type}' + ) return filesystem_pb2.InitResponse() def PackResponse( - self, response: filesystem_pb2.InitResponse) -> filesystem_pb2.Response: + self, response: filesystem_pb2.InitResponse + ) -> filesystem_pb2.Response: return filesystem_pb2.Response(init_response=response) def UnpackRequest( - self, request: filesystem_pb2.Request) -> filesystem_pb2.InitRequest: + self, request: filesystem_pb2.Request + ) -> filesystem_pb2.InitRequest: return request.init_request -class OpenHandler(OperationHandler[filesystem_pb2.OpenRequest, - filesystem_pb2.OpenResponse]): +class OpenHandler( + OperationHandler[filesystem_pb2.OpenRequest, filesystem_pb2.OpenResponse] +): """Implements the Open operation.""" def HandleOperation( - self, state: State, - request: filesystem_pb2.OpenRequest) -> filesystem_pb2.OpenResponse: + self, state: State, request: filesystem_pb2.OpenRequest + ) -> filesystem_pb2.OpenResponse: path = request.path if request.HasField('path') else None inode = request.inode if request.HasField('inode') else None - stream_name = request.stream_name if request.HasField( - 'stream_name') else None + stream_name = ( + request.stream_name if request.HasField('stream_name') else None + ) assert state.filesystem is not None if inode is None: file_obj = state.filesystem.Open(path, stream_name) @@ -176,69 +192,83 @@ def HandleOperation( file_obj = state.filesystem.OpenByInode(inode, stream_name) except filesystem.StaleInodeError: return filesystem_pb2.OpenResponse( - status=filesystem_pb2.OpenResponse.Status.STALE_INODE) + status=filesystem_pb2.OpenResponse.Status.STALE_INODE + ) file_id = state.files.Add(file_obj) return filesystem_pb2.OpenResponse( status=filesystem_pb2.OpenResponse.Status.NO_ERROR, file_id=file_id, - inode=file_obj.Inode()) + inode=file_obj.Inode(), + ) def PackResponse( - self, response: filesystem_pb2.OpenResponse) -> filesystem_pb2.Response: + self, response: filesystem_pb2.OpenResponse + ) -> filesystem_pb2.Response: return filesystem_pb2.Response(open_response=response) def UnpackRequest( - self, request: filesystem_pb2.Request) -> filesystem_pb2.OpenRequest: + self, request: 
filesystem_pb2.Request + ) -> filesystem_pb2.OpenRequest: return request.open_request -class ReadHandler(OperationHandler[filesystem_pb2.ReadRequest, - filesystem_pb2.ReadResponse]): +class ReadHandler( + OperationHandler[filesystem_pb2.ReadRequest, filesystem_pb2.ReadResponse] +): """Implements the Read operation.""" def HandleOperation( - self, state: State, - request: filesystem_pb2.ReadRequest) -> filesystem_pb2.ReadResponse: + self, state: State, request: filesystem_pb2.ReadRequest + ) -> filesystem_pb2.ReadResponse: file = state.files.Get(request.file_id) data = file.Read(offset=request.offset, size=request.size) return filesystem_pb2.ReadResponse(data=data) def PackResponse( - self, response: filesystem_pb2.ReadResponse) -> filesystem_pb2.Response: + self, response: filesystem_pb2.ReadResponse + ) -> filesystem_pb2.Response: return filesystem_pb2.Response(read_response=response) - def ExtractResponseAttachment(self, - response: filesystem_pb2.ReadResponse) -> bytes: + def ExtractResponseAttachment( + self, response: filesystem_pb2.ReadResponse + ) -> bytes: attachment = response.data response.ClearField('data') return attachment def UnpackRequest( - self, request: filesystem_pb2.Request) -> filesystem_pb2.ReadRequest: + self, request: filesystem_pb2.Request + ) -> filesystem_pb2.ReadRequest: return request.read_request -class StatHandler(OperationHandler[filesystem_pb2.StatRequest, - filesystem_pb2.StatResponse]): +class StatHandler( + OperationHandler[filesystem_pb2.StatRequest, filesystem_pb2.StatResponse] +): """Implements the Stat operation.""" def HandleOperation( - self, state: State, - request: filesystem_pb2.StatRequest) -> filesystem_pb2.StatResponse: + self, state: State, request: filesystem_pb2.StatRequest + ) -> filesystem_pb2.StatResponse: file_obj = state.files.Get(request.file_id) return filesystem_pb2.StatResponse(entry=file_obj.Stat()) def PackResponse( - self, response: filesystem_pb2.StatResponse) -> filesystem_pb2.Response: + self, response: filesystem_pb2.StatResponse + ) -> filesystem_pb2.Response: return filesystem_pb2.Response(stat_response=response) def UnpackRequest( - self, request: filesystem_pb2.Request) -> filesystem_pb2.StatRequest: + self, request: filesystem_pb2.Request + ) -> filesystem_pb2.StatRequest: return request.stat_request -class ListFilesHandler(OperationHandler[filesystem_pb2.ListFilesRequest, - filesystem_pb2.ListFilesResponse]): +class ListFilesHandler( + OperationHandler[ + filesystem_pb2.ListFilesRequest, filesystem_pb2.ListFilesResponse + ] +): """Implements the ListFiles operation.""" def HandleOperation( @@ -248,17 +278,21 @@ def HandleOperation( return filesystem_pb2.ListFilesResponse(entries=file_obj.ListFiles()) def PackResponse( - self, - response: filesystem_pb2.ListFilesResponse) -> filesystem_pb2.Response: + self, response: filesystem_pb2.ListFilesResponse + ) -> filesystem_pb2.Response: return filesystem_pb2.Response(list_files_response=response) def UnpackRequest( - self, request: filesystem_pb2.Request) -> filesystem_pb2.ListFilesRequest: + self, request: filesystem_pb2.Request + ) -> filesystem_pb2.ListFilesRequest: return request.list_files_request -class ListNamesHandler(OperationHandler[filesystem_pb2.ListNamesRequest, - filesystem_pb2.ListNamesResponse]): +class ListNamesHandler( + OperationHandler[ + filesystem_pb2.ListNamesRequest, filesystem_pb2.ListNamesResponse + ] +): """Implements the ListNames operation.""" def HandleOperation( @@ -268,39 +302,46 @@ def HandleOperation( return 
filesystem_pb2.ListNamesResponse(names=file_obj.ListNames()) def PackResponse( - self, - response: filesystem_pb2.ListNamesResponse) -> filesystem_pb2.Response: + self, response: filesystem_pb2.ListNamesResponse + ) -> filesystem_pb2.Response: return filesystem_pb2.Response(list_names_response=response) def UnpackRequest( - self, request: filesystem_pb2.Request) -> filesystem_pb2.ListNamesRequest: + self, request: filesystem_pb2.Request + ) -> filesystem_pb2.ListNamesRequest: return request.list_names_request -class CloseHandler(OperationHandler[filesystem_pb2.CloseRequest, - filesystem_pb2.CloseResponse]): +class CloseHandler( + OperationHandler[filesystem_pb2.CloseRequest, filesystem_pb2.CloseResponse] +): """Implements the Close operation.""" def HandleOperation( - self, state: State, - request: filesystem_pb2.CloseRequest) -> filesystem_pb2.CloseResponse: + self, state: State, request: filesystem_pb2.CloseRequest + ) -> filesystem_pb2.CloseResponse: file_obj = state.files.Get(request.file_id) file_obj.Close() state.files.Remove(request.file_id) return filesystem_pb2.CloseResponse() def PackResponse( - self, response: filesystem_pb2.CloseResponse) -> filesystem_pb2.Response: + self, response: filesystem_pb2.CloseResponse + ) -> filesystem_pb2.Response: return filesystem_pb2.Response(close_response=response) def UnpackRequest( - self, request: filesystem_pb2.Request) -> filesystem_pb2.CloseRequest: + self, request: filesystem_pb2.Request + ) -> filesystem_pb2.CloseRequest: return request.close_request class LookupCaseInsensitiveHandler( - OperationHandler[filesystem_pb2.LookupCaseInsensitiveRequest, - filesystem_pb2.LookupCaseInsensitiveResponse]): + OperationHandler[ + filesystem_pb2.LookupCaseInsensitiveRequest, + filesystem_pb2.LookupCaseInsensitiveResponse, + ] +): """Implements the LookupCaseInsensitive operation.""" def HandleOperation( @@ -358,7 +399,8 @@ def DispatchWrapped(connection: ConnectionWrapper) -> None: except: # pylint: disable=bare-except exception = filesystem_pb2.Exception( message=str(sys.exc_info()[1]), - formatted_exception=traceback.format_exc()) + formatted_exception=traceback.format_exc(), + ) connection.Send(filesystem_pb2.Response(exception=exception), b'') diff --git a/grr/client/grr_response_client/unprivileged/filesystem/tsk.py b/grr/client/grr_response_client/unprivileged/filesystem/tsk.py index 0a82ace163..defb974927 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/tsk.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/tsk.py @@ -3,7 +3,7 @@ import logging import stat -from typing import Optional, Iterator +from typing import Iterator, Optional import pytsk3 from grr_response_client.unprivileged.filesystem import filesystem from grr_response_client.unprivileged.proto import filesystem_pb2 @@ -104,9 +104,13 @@ def get_size(self) -> int: # pylint: disable=g-bad-name class TskFile(filesystem.File): """TSK implementation of File.""" - def __init__(self, filesystem_obj: filesystem.Filesystem, - fs_info: pytsk3.FS_Info, fd: pytsk3.File, - data_stream: Optional[pytsk3.Attribute]): + def __init__( + self, + filesystem_obj: filesystem.Filesystem, + fs_info: pytsk3.FS_Info, + fd: pytsk3.File, + data_stream: Optional[pytsk3.Attribute], + ): super().__init__(filesystem_obj) self._fs_info = fs_info self._fd = fd @@ -135,9 +139,12 @@ def Read(self, offset: int, size: int) -> bytes: if self._data_stream is None: return self._fd.read_random(offset, available) else: - return self._fd.read_random(offset, available, - 
self._data_stream.info.type, - self._data_stream.info.id) + return self._fd.read_random( + offset, + available, + self._data_stream.info.type, + self._data_stream.info.id, + ) def Close(self) -> None: pass @@ -262,14 +269,19 @@ def _IsDirectory(self) -> bool: return self._fd.info.meta.type == pytsk3.TSK_FS_META_TYPE_DIR -def _GetDataStream(fd: pytsk3.File, - stream_name: Optional[str]) -> Optional[pytsk3.Attribute]: +def _GetDataStream( + fd: pytsk3.File, stream_name: Optional[str] +) -> Optional[pytsk3.Attribute]: + """Get data stream from a file.""" + if stream_name is None: return None for attribute in fd: - if (attribute.info.name is not None and - _DecodeName(attribute.info.name) == stream_name and - attribute.info.type == pytsk3.TSK_FS_ATTR_TYPE_NTFS_DATA): + if ( + attribute.info.name is not None + and _DecodeName(attribute.info.name) == stream_name + and attribute.info.type == pytsk3.TSK_FS_ATTR_TYPE_NTFS_DATA + ): return attribute raise IOError(f"Failed to open data stream {stream_name}.") diff --git a/grr/client/grr_response_client/unprivileged/filesystem/tsk_test.py b/grr/client/grr_response_client/unprivileged/filesystem/tsk_test.py index a1b3b8154d..9135f2bbfd 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/tsk_test.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/tsk_test.py @@ -11,7 +11,8 @@ class TskTest(ntfs_image_test_lib.NtfsImageTest): _IMPLEMENTATION_TYPE = filesystem_pb2.TSK def _ExpectedStatEntry( - self, st: filesystem_pb2.StatEntry) -> filesystem_pb2.StatEntry: + self, st: filesystem_pb2.StatEntry + ) -> filesystem_pb2.StatEntry: """Clears the fields which are not returned by TSK.""" if st.HasField("ntfs"): st.ClearField("ntfs") diff --git a/grr/client/grr_response_client/unprivileged/filesystem/vfs.py b/grr/client/grr_response_client/unprivileged/filesystem/vfs.py index 2cf759ea5d..5766b4f53f 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/vfs.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/vfs.py @@ -3,7 +3,7 @@ import contextlib import stat -from typing import Any, Callable, Dict, Iterator, Optional, Text, Type, Tuple, NamedTuple +from typing import Any, Callable, Dict, Iterator, NamedTuple, Optional, Text, Tuple, Type from grr_response_client import client_utils from grr_response_client.unprivileged import communication @@ -34,8 +34,9 @@ def KillObject(self, obj: utils.TimeBasedCacheEntry) -> None: MOUNT_CACHE = MountCache() -def _ConvertStatEntry(entry: filesystem_pb2.StatEntry, - pathspec: rdf_paths.PathSpec) -> rdf_client_fs.StatEntry: +def _ConvertStatEntry( + entry: filesystem_pb2.StatEntry, pathspec: rdf_paths.PathSpec +) -> rdf_client_fs.StatEntry: """Converts a stat entry from a filesystem_pb2 protobuf to RDF.""" st = rdf_client_fs.StatEntry() st.pathspec = pathspec.Copy() @@ -80,9 +81,11 @@ def _ConvertStatEntry(entry: filesystem_pb2.StatEntry, class VFSHandlerDevice(client.Device): """A device implementation backed by a VFSHandler.""" - def __init__(self, - vfs_handler: vfs_base.VFSHandler, - device_file_descriptor: Optional[int] = None): + def __init__( + self, + vfs_handler: vfs_base.VFSHandler, + device_file_descriptor: Optional[int] = None, + ): super().__init__() self._vfs_handler = vfs_handler self._device_file_descriptor = device_file_descriptor @@ -101,16 +104,19 @@ class UnprivilegedFileBase(vfs_base.VFSHandler): implementation_type = filesystem_pb2.UNDEFINED - def __init__(self, - base_fd: Optional[vfs_base.VFSHandler], - handlers: Dict[Any, Type[vfs_base.VFSHandler]], - 
pathspec: Optional[rdf_paths.PathSpec] = None, - progress_callback: Optional[Callable[[], None]] = None): + def __init__( + self, + base_fd: Optional[vfs_base.VFSHandler], + handlers: Dict[Any, Type[vfs_base.VFSHandler]], + pathspec: Optional[rdf_paths.PathSpec] = None, + progress_callback: Optional[Callable[[], None]] = None, + ): super().__init__( base_fd, handlers=handlers, pathspec=pathspec, - progress_callback=progress_callback) + progress_callback=progress_callback, + ) # self.pathspec is initialized to a copy of base_fd @@ -125,7 +131,8 @@ def __init__(self, self.pathspec.last.path = last_path elif not base_fd.IsDirectory(): cache_key = base_fd.pathspec.SerializeToBytes() + str( - self.implementation_type).encode("utf-8") + self.implementation_type + ).encode("utf-8") try: self.client = MOUNT_CACHE.Get(cache_key).client except KeyError: @@ -135,19 +142,26 @@ def __init__(self, server_obj = server.CreateFilesystemServer() stack.enter_context(server_obj) self.client = stack.enter_context( - client.CreateFilesystemClient(server_obj.Connect(), - self.implementation_type, - VFSHandlerDevice(base_fd))) + client.CreateFilesystemClient( + server_obj.Connect(), + self.implementation_type, + VFSHandlerDevice(base_fd), + ) + ) else: with open(device_path, "rb") as device_file: server_obj = server.CreateFilesystemServer(device_file.fileno()) stack.enter_context(server_obj) self.client = stack.enter_context( client.CreateFilesystemClient( - server_obj.Connect(), self.implementation_type, - VFSHandlerDevice(base_fd, device_file.fileno()))) - MOUNT_CACHE.Put(cache_key, - MountCacheItem(server=server_obj, client=self.client)) + server_obj.Connect(), + self.implementation_type, + VFSHandlerDevice(base_fd, device_file.fileno()), + ) + ) + MOUNT_CACHE.Put( + cache_key, MountCacheItem(server=server_obj, client=self.client) + ) # Transfer ownership of resources to MOUNT_CACHE. stack.pop_all() self.pathspec.Append(pathspec) @@ -170,7 +184,8 @@ def __init__(self, # We have to find the corresponding case literal stream name in this # case ourselves. self.fd, self.pathspec.last.stream_name = ( - self._OpenStreamCaseInsensitive(pathspec)) + self._OpenStreamCaseInsensitive(pathspec) + ) else: self.fd = self._OpenPathSpec(pathspec) except client.OperationError as e: @@ -213,7 +228,8 @@ def _ToClientPath(self, path: str) -> str: return path def _OpenStreamCaseInsensitive( - self, pathspec: rdf_paths.PathSpec) -> Tuple[client.File, str]: + self, pathspec: rdf_paths.PathSpec + ) -> Tuple[client.File, str]: """Opens a stream by pathspec with a case-insensitvie stream name. Args: @@ -226,12 +242,14 @@ def _OpenStreamCaseInsensitive( file_pathspec = pathspec.Copy() file_pathspec.stream_name = None result = pathspec.Copy() - result.stream_name = self._GetStreamNameCaseLiteral(file_pathspec, - stream_name) + result.stream_name = self._GetStreamNameCaseLiteral( + file_pathspec, stream_name + ) return self._OpenPathSpec(result), result.stream_name - def _GetStreamNameCaseLiteral(self, file_pathspec: rdf_paths.PathSpec, - stream_name_case_insensitive: str) -> str: + def _GetStreamNameCaseLiteral( + self, file_pathspec: rdf_paths.PathSpec, stream_name_case_insensitive: str + ) -> str: """Returns the case literal stream name. 
Args: @@ -245,16 +263,18 @@ def _GetStreamNameCaseLiteral(self, file_pathspec: rdf_paths.PathSpec, result = file_obj.LookupCaseInsensitive(stream_name_case_insensitive) if result is not None: return result - raise IOError(f"Failed to open stream {stream_name_case_insensitive} in " - f"{file_pathspec}.") + raise IOError( + f"Failed to open stream {stream_name_case_insensitive} in " + f"{file_pathspec}." + ) @property def size(self) -> int: return self._stat_result.st_size - def Stat(self, - ext_attrs: bool = False, - follow_symlink: bool = True) -> rdf_client_fs.StatEntry: + def Stat( + self, ext_attrs: bool = False, follow_symlink: bool = True + ) -> rdf_client_fs.StatEntry: return self._stat_result def Read(self, length: int) -> bytes: @@ -267,8 +287,10 @@ def Read(self, length: int) -> bytes: def IsDirectory(self) -> bool: return (self._stat_result.st_mode & stat.S_IFDIR) != 0 - def ListFiles(self, # pytype: disable=signature-mismatch # overriding-return-type-checks - ext_attrs: bool = False) -> Iterator[rdf_client_fs.StatEntry]: + def ListFiles( + self, # pytype: disable=signature-mismatch # overriding-return-type-checks + ext_attrs: bool = False, + ) -> Iterator[rdf_client_fs.StatEntry]: del ext_attrs # Unused. self._CheckIsDirectory() @@ -290,8 +312,9 @@ def ListNames(self) -> Iterator[Text]: # pytype: disable=signature-mismatch # def _CheckIsDirectory(self) -> None: if not self.IsDirectory(): - raise IOError("{} is not a directory".format( - self.pathspec.CollapsePath())) + raise IOError( + "{} is not a directory".format(self.pathspec.CollapsePath()) + ) def _CheckIsFile(self) -> None: if self.IsDirectory(): @@ -302,7 +325,8 @@ def Close(self) -> None: self.fd.Close() def MatchBestComponentName( - self, component: str, pathtype: rdf_paths.PathSpec) -> rdf_paths.PathSpec: + self, component: str, pathtype: rdf_paths.PathSpec + ) -> rdf_paths.PathSpec: fd = self.OpenAsContainer(pathtype) assert self.fd is not None @@ -313,7 +337,8 @@ def MatchBestComponentName( if fd.supported_pathtype != self.pathspec.pathtype: new_pathspec = rdf_paths.PathSpec( - path=component, pathtype=fd.supported_pathtype) + path=component, pathtype=fd.supported_pathtype + ) else: new_pathspec = self.pathspec.last.Copy() new_pathspec.path = component @@ -327,12 +352,15 @@ def Open( component: rdf_paths.PathSpec, handlers: Dict[Any, Type[vfs_base.VFSHandler]], pathspec: Optional[rdf_paths.PathSpec] = None, - progress_callback: Optional[Callable[[], None]] = None + progress_callback: Optional[Callable[[], None]] = None, ) -> Optional[vfs_base.VFSHandler]: # A Pathspec which starts with NTFS means we need to resolve the mount # point at runtime. - if (fd is None and component.pathtype == cls.supported_pathtype and - pathspec is not None): + if ( + fd is None + and component.pathtype == cls.supported_pathtype + and pathspec is not None + ): # We are the top level handler. This means we need to check the system # mounts to work out the exact mount point and device we need to # open. 
We then modify the pathspec so we get nested in the raw @@ -365,7 +393,8 @@ def Open( component=component, handlers=handlers, pathspec=pathspec, - progress_callback=progress_callback) + progress_callback=progress_callback, + ) class UnprivilegedNtfsFile(UnprivilegedFileBase): diff --git a/grr/client/grr_response_client/unprivileged/filesystem/vfs_ntfs_test.py b/grr/client/grr_response_client/unprivileged/filesystem/vfs_ntfs_test.py index d98d90fd80..9e0ba672e5 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/vfs_ntfs_test.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/vfs_ntfs_test.py @@ -20,14 +20,20 @@ def setUp(self): self.addCleanup(stack.close) stack.enter_context( - mock.patch.dict(client_vfs.VFS_HANDLERS, { - vfs.UnprivilegedNtfsFile.supported_pathtype: - vfs.UnprivilegedNtfsFile, - })) + mock.patch.dict( + client_vfs.VFS_HANDLERS, + { + vfs.UnprivilegedNtfsFile.supported_pathtype: ( + vfs.UnprivilegedNtfsFile + ), + }, + ) + ) class VfsNtfsWithFileDescriptorSharingTest(VfsNtfsTestBase): """Test variant sharing the device file descriptor with the server.""" + pass diff --git a/grr/client/grr_response_client/unprivileged/filesystem/vfs_tsk_test.py b/grr/client/grr_response_client/unprivileged/filesystem/vfs_tsk_test.py index 1a4a51dcfa..194dcc8bd5 100644 --- a/grr/client/grr_response_client/unprivileged/filesystem/vfs_tsk_test.py +++ b/grr/client/grr_response_client/unprivileged/filesystem/vfs_tsk_test.py @@ -19,9 +19,15 @@ def setUp(self): self.addCleanup(stack.close) stack.enter_context( - mock.patch.dict(client_vfs.VFS_HANDLERS, { - vfs.UnprivilegedTskFile.supported_pathtype: vfs.UnprivilegedTskFile, - })) + mock.patch.dict( + client_vfs.VFS_HANDLERS, + { + vfs.UnprivilegedTskFile.supported_pathtype: ( + vfs.UnprivilegedTskFile + ), + }, + ) + ) def setUpModule(): diff --git a/grr/client/grr_response_client/unprivileged/interface_registry.py b/grr/client/grr_response_client/unprivileged/interface_registry.py index 01ecd369a4..841f9373c1 100644 --- a/grr/client/grr_response_client/unprivileged/interface_registry.py +++ b/grr/client/grr_response_client/unprivileged/interface_registry.py @@ -30,7 +30,8 @@ class Error(Exception): def GetConnectionHandlerForInterfaceString( - interface_str: str) -> communication.ConnectionHandler: + interface_str: str, +) -> communication.ConnectionHandler: """Returns the connection handler for the respective interface.""" try: interface = Interface(interface_str) diff --git a/grr/client/grr_response_client/unprivileged/linux/sandbox_test.py b/grr/client/grr_response_client/unprivileged/linux/sandbox_test.py index 190a1de3a0..b31d2cdcae 100644 --- a/grr/client/grr_response_client/unprivileged/linux/sandbox_test.py +++ b/grr/client/grr_response_client/unprivileged/linux/sandbox_test.py @@ -14,8 +14,10 @@ from grr_response_client.unprivileged import sandbox -@unittest.skipIf(platform.system() != "Linux" or os.getuid() != 0, - "Skipping Linux-only root test.") +@unittest.skipIf( + platform.system() != "Linux" or os.getuid() != 0, + "Skipping Linux-only root test.", +) class LinuxSandboxTest(absltest.TestCase): @classmethod diff --git a/grr/client/grr_response_client/unprivileged/memory/client.py b/grr/client/grr_response_client/unprivileged/memory/client.py index 5c2c65a050..321bd33c4d 100644 --- a/grr/client/grr_response_client/unprivileged/memory/client.py +++ b/grr/client/grr_response_client/unprivileged/memory/client.py @@ -2,7 +2,7 @@ """Unprivileged memory RPC client code.""" import abc -from typing import 
TypeVar, Generic, Iterable +from typing import Generic, Iterable, TypeVar from grr_response_client.unprivileged import communication from grr_response_client.unprivileged.proto import memory_pb2 @@ -16,7 +16,8 @@ def __init__(self, connection: communication.Connection): def Send(self, request: memory_pb2.Request) -> None: self._connection.Send( - communication.Message(request.SerializeToString(), b"")) + communication.Message(request.SerializeToString(), b"") + ) def Recv(self) -> memory_pb2.Response: raw_response, _ = self._connection.Recv() @@ -27,6 +28,7 @@ def Recv(self) -> memory_pb2.Response: class Error(Exception): """Base class for exceptions in this module.""" + pass @@ -60,8 +62,10 @@ def Run(self, request: RequestType) -> ResponseType: packed_response = self._connection.Recv() if packed_response.HasField("exception"): - raise OperationError(packed_response.exception.message, - packed_response.exception.formatted_exception) + raise OperationError( + packed_response.exception.message, + packed_response.exception.formatted_exception, + ) else: response = self.UnpackResponse(packed_response) return response @@ -78,30 +82,38 @@ def PackRequest(self, request: RequestType) -> memory_pb2.Request: class UploadSignatureHandler( - OperationHandler[memory_pb2.UploadSignatureRequest, - memory_pb2.UploadSignatureResponse]): + OperationHandler[ + memory_pb2.UploadSignatureRequest, memory_pb2.UploadSignatureResponse + ] +): """Implements the UploadSignature RPC.""" def UnpackResponse( - self, - response: memory_pb2.Response) -> memory_pb2.UploadSignatureResponse: + self, response: memory_pb2.Response + ) -> memory_pb2.UploadSignatureResponse: return response.upload_signature_response def PackRequest( - self, request: memory_pb2.UploadSignatureRequest) -> memory_pb2.Request: + self, request: memory_pb2.UploadSignatureRequest + ) -> memory_pb2.Request: return memory_pb2.Request(upload_signature_request=request) -class ProcessScanHandler(OperationHandler[memory_pb2.ProcessScanRequest, - memory_pb2.ProcessScanResponse]): +class ProcessScanHandler( + OperationHandler[ + memory_pb2.ProcessScanRequest, memory_pb2.ProcessScanResponse + ] +): """Implements the ProcessScan RPC.""" def UnpackResponse( - self, response: memory_pb2.Response) -> memory_pb2.ProcessScanResponse: + self, response: memory_pb2.Response + ) -> memory_pb2.ProcessScanResponse: return response.process_scan_response - def PackRequest(self, - request: memory_pb2.ProcessScanRequest) -> memory_pb2.Request: + def PackRequest( + self, request: memory_pb2.ProcessScanRequest + ) -> memory_pb2.Request: return memory_pb2.Request(process_scan_request=request) @@ -116,9 +128,13 @@ def UploadSignature(self, yara_signature: str): request = memory_pb2.UploadSignatureRequest(yara_signature=yara_signature) UploadSignatureHandler(self._connection).Run(request) - def ProcessScan(self, serialized_file_descriptor: int, - chunks: Iterable[memory_pb2.Chunk], - timeout_seconds: int) -> memory_pb2.ProcessScanResponse: + def ProcessScan( + self, + serialized_file_descriptor: int, + chunks: Iterable[memory_pb2.Chunk], + timeout_seconds: int, + context_window: int, + ) -> memory_pb2.ProcessScanResponse: """Scans process memory. Args: @@ -126,6 +142,7 @@ def ProcessScan(self, serialized_file_descriptor: int, memory. The file descriptor must be accessible by the server process. chunks: Chunks (offset, size) to scan. timeout_seconds: Timeout in seconds. + context_window: Amount of bytes surrounding the match to return. Returns: A `ScanResult` proto. 
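To make the extended client API concrete, here is a minimal usage sketch (not part of the patch): it assumes a connected memory client and an already-serialized process file descriptor, and only `Client.ProcessScan`, `memory_pb2.Chunk` and the new `context_window` argument are taken from the code above.

    # Illustrative sketch, assuming `memory_client` was built via
    # client.CreateMemoryClient(...) and `serialized_fd` comes from an opened
    # process (both set up elsewhere, as in the test below).
    from grr_response_client.unprivileged.proto import memory_pb2

    def ScanChunk(memory_client, serialized_fd, offset, size):
      response = memory_client.ProcessScan(
          serialized_fd,
          [memory_pb2.Chunk(offset=offset, size=size)],
          timeout_seconds=60,
          context_window=50,  # return 50 bytes around each string match
      )
      for rule_match in response.scan_result.scan_match:
        for string_match in rule_match.string_matches:
          # string_match.context now carries the surrounding bytes.
          print(rule_match.rule_name, string_match.offset, string_match.context)
      return response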
@@ -133,9 +150,10 @@ def ProcessScan(self, serialized_file_descriptor: int, request = memory_pb2.ProcessScanRequest( serialized_file_descriptor=serialized_file_descriptor, chunks=chunks, - timeout_seconds=timeout_seconds) - response = ProcessScanHandler(self._connection).Run(request) - return response + timeout_seconds=timeout_seconds, + context_window=context_window, + ) + return ProcessScanHandler(self._connection).Run(request) def CreateMemoryClient(connection: communication.Connection) -> Client: diff --git a/grr/client/grr_response_client/unprivileged/memory/memory_test.py b/grr/client/grr_response_client/unprivileged/memory/memory_test.py index e5291f0ea5..110b159db0 100644 --- a/grr/client/grr_response_client/unprivileged/memory/memory_test.py +++ b/grr/client/grr_response_client/unprivileged/memory/memory_test.py @@ -13,6 +13,10 @@ from grr_response_client.unprivileged.proto import memory_pb2 _SEARCH_STRING = b"I am a test string, just for testing!!!!" +_EXPECTED_CONTEXT = ( + b'"Just for testing."\n strings:\n $s1 = "I am a test string,' + b' just for testing!!!!"\n co' +) _SIGNATURE = """ rule test_rule { @@ -26,8 +30,10 @@ """ -@unittest.skipIf(platform.system() == "Darwin", - "Sandboxed memory scanning is not yet supported on OSX.") +@unittest.skipIf( + platform.system() == "Darwin", + "Sandboxed memory scanning is not yet supported on OSX.", +) class MemoryTest(absltest.TestCase): def setUp(self): @@ -36,15 +42,17 @@ def setUp(self): self.addCleanup(stack.close) self._process = stack.enter_context( - client_utils.OpenProcessForMemoryAccess(os.getpid())) + client_utils.OpenProcessForMemoryAccess(os.getpid()) + ) self._process.Open() - self._process_file_descriptor = ( - communication.FileDescriptor.FromSerialized( - self._process.serialized_file_descriptor, communication.Mode.READ)) + self._process_file_descriptor = communication.FileDescriptor.FromSerialized( + self._process.serialized_file_descriptor, communication.Mode.READ + ) self._server = stack.enter_context( - server.CreateMemoryServer([self._process_file_descriptor])) + server.CreateMemoryServer([self._process_file_descriptor]) + ) self._client = client.Client(self._server.Connect()) def testProcessScan(self): @@ -55,19 +63,25 @@ def testProcessScan(self): for region in self._process.Regions(): streamer = streaming.Streamer( - chunk_size=1024 * 1024, overlap_size=32 * 1024) + chunk_size=1024 * 1024, overlap_size=32 * 1024 + ) for chunk in streamer.StreamRanges(region.start, region.size): response = self._client.ProcessScan( self._process_file_descriptor.Serialize(), - [memory_pb2.Chunk(offset=chunk.offset, size=chunk.amount)], 60) - self.assertEqual(response.status, - memory_pb2.ProcessScanResponse.Status.NO_ERROR) + [memory_pb2.Chunk(offset=chunk.offset, size=chunk.amount)], + 60, + 50, + ) + self.assertEqual( + response.status, memory_pb2.ProcessScanResponse.Status.NO_ERROR + ) all_scan_matches.extend(response.scan_result.scan_match) self.assertTrue(all_scan_matches) found_in_actual_memory_count = 0 + expected_context_found = False for scan_match in all_scan_matches: self.assertEqual(scan_match.rule_name, "test_rule") for string_match in scan_match.string_matches: @@ -75,15 +89,19 @@ def testProcessScan(self): self.assertEqual(string_match.data, _SEARCH_STRING) # Check that the reported result resides in memory of the # scanned process. 
- actual_memory = self._process.ReadBytes(string_match.offset, - len(string_match.data)) + actual_memory = self._process.ReadBytes( + string_match.offset, len(string_match.data) + ) # Since copies of the string might be in dynamic memory, we won't be # able to read back every match. We'll check that at least one of the # reads succeeds later. if actual_memory == _SEARCH_STRING: found_in_actual_memory_count += 1 + if string_match.context == _EXPECTED_CONTEXT: + expected_context_found = True self.assertTrue(found_in_actual_memory_count) + self.assertTrue(expected_context_found) def setUpModule() -> None: diff --git a/grr/client/grr_response_client/unprivileged/memory/server.py b/grr/client/grr_response_client/unprivileged/memory/server.py index 024786433d..13a02a1e2d 100644 --- a/grr/client/grr_response_client/unprivileged/memory/server.py +++ b/grr/client/grr_response_client/unprivileged/memory/server.py @@ -9,7 +9,8 @@ def CreateMemoryServer( - process_file_descriptors: List[communication.FileDescriptor] + process_file_descriptors: List[communication.FileDescriptor], ) -> communication.Server: - return server.CreateServer(process_file_descriptors, - interface_registry.Interface.MEMORY) + return server.CreateServer( + process_file_descriptors, interface_registry.Interface.MEMORY + ) diff --git a/grr/client/grr_response_client/unprivileged/memory/server_lib.py b/grr/client/grr_response_client/unprivileged/memory/server_lib.py index 5a622ca99b..b62b7b4938 100644 --- a/grr/client/grr_response_client/unprivileged/memory/server_lib.py +++ b/grr/client/grr_response_client/unprivileged/memory/server_lib.py @@ -5,7 +5,7 @@ import sys import time import traceback -from typing import TypeVar, Generic, Optional, Tuple +from typing import Generic, Optional, Tuple, TypeVar import yara from grr_response_client import client_utils from grr_response_client.unprivileged import communication @@ -14,11 +14,13 @@ class Error(Exception): """Base class for exceptions in this module.""" + pass class DispatchError(Error): """Error while dispatching a request.""" + pass @@ -37,7 +39,8 @@ def __init__(self, connection: communication.Connection): def Send(self, response: memory_pb2.Response) -> None: self._connection.Send( - communication.Message(response.SerializeToString(), b"")) + communication.Message(response.SerializeToString(), b"") + ) def Recv(self) -> memory_pb2.Request: raw_request, _ = self._connection.Recv() @@ -53,8 +56,12 @@ def Recv(self) -> memory_pb2.Request: class OperationHandler(abc.ABC, Generic[RequestType, ResponseType]): """Base class for RPC handlers.""" - def __init__(self, state: State, request: memory_pb2.Request, - connection: ConnectionWrapper): + def __init__( + self, + state: State, + request: memory_pb2.Request, + connection: ConnectionWrapper, + ): self._state = state self._request = request self._connection = connection @@ -81,8 +88,10 @@ def UnpackRequest(self, request: memory_pb2.Request) -> RequestType: class UploadSignatureHandler( - OperationHandler[memory_pb2.UploadSignatureRequest, - memory_pb2.UploadSignatureResponse]): + OperationHandler[ + memory_pb2.UploadSignatureRequest, memory_pb2.UploadSignatureResponse + ] +): """Implements the UploadSignature operation.""" def HandleOperation( @@ -92,74 +101,110 @@ def HandleOperation( return memory_pb2.UploadSignatureResponse() def PackResponse( - self, - response: memory_pb2.UploadSignatureResponse) -> memory_pb2.Response: + self, response: memory_pb2.UploadSignatureResponse + ) -> memory_pb2.Response: return 
memory_pb2.Response(upload_signature_response=response) def UnpackRequest( - self, request: memory_pb2.Request) -> memory_pb2.UploadSignatureRequest: + self, request: memory_pb2.Request + ) -> memory_pb2.UploadSignatureRequest: return request.upload_signature_request def _YaraStringMatchToProto( - offset: int, value: Tuple[int, str, bytes]) -> memory_pb2.StringMatch: + offset: int, value: Tuple[int, str, bytes], context: bytes +) -> memory_pb2.StringMatch: return memory_pb2.StringMatch( chunk_offset=offset, offset=offset + value[0], string_id=value[1], - data=value[2]) + data=value[2], + context=context, + ) + + +def _YaraMatchToProto( + offset: int, value: "yara.Match", data: bytes, context_window: int +) -> memory_pb2.RuleMatch: + """Converts a yara.Match to a memory_pb2.RuleMatch. + Args: + offset: The offset (within data) where the match is located. + value: The libyara Match object. + data: The data segment where the match occurred. + context_window: The amount of bytes around the match to return. -def _YaraMatchToProto(offset: int, value: "yara.Match") -> memory_pb2.RuleMatch: + Returns: + The resulting memory_pb2.RuleMatch object. + """ result = memory_pb2.RuleMatch(rule_name=value.rule) for yara_string_match in value.strings: - result.string_matches.append( - _YaraStringMatchToProto(offset, yara_string_match)) + context = b"" + if context_window: + match_offset = yara_string_match[0] + context = data[ + match_offset - context_window : match_offset + context_window + ] + match = _YaraStringMatchToProto(offset, yara_string_match, context) + result.string_matches.append(match) return result -class ProcessScanHandler(OperationHandler[memory_pb2.ProcessScanRequest, - memory_pb2.ProcessScanResponse]): +class ProcessScanHandler( + OperationHandler[ + memory_pb2.ProcessScanRequest, memory_pb2.ProcessScanResponse + ] +): """Implements the ProcessScan operation.""" def HandleOperation( - self, state: State, - request: memory_pb2.ProcessScanRequest) -> memory_pb2.ProcessScanResponse: + self, state: State, request: memory_pb2.ProcessScanRequest + ) -> memory_pb2.ProcessScanResponse: if state.yara_rules is None: raise Error("Rules have not been set.") deadline = time.time() + request.timeout_seconds with client_utils.CreateProcessFromSerializedFileDescriptor( - request.serialized_file_descriptor) as process: # pytype: disable=wrong-arg-count # attribute-variable-annotations + request.serialized_file_descriptor + ) as process: # pytype: disable=wrong-arg-count # attribute-variable-annotations result = memory_pb2.ScanResult() for chunk in request.chunks: data = process.ReadBytes(chunk.offset, chunk.size) try: timeout_secs = int(max(deadline - time.time(), 0)) for yara_match in state.yara_rules.match( - data=data, timeout=timeout_secs): - result.scan_match.append( - _YaraMatchToProto(chunk.offset, yara_match)) + data=data, timeout=timeout_secs + ): + match_proto = _YaraMatchToProto( + chunk.offset, yara_match, data, request.context_window + ) + result.scan_match.append(match_proto) except yara.TimeoutError as e: return memory_pb2.ProcessScanResponse( - status=memory_pb2.ProcessScanResponse.Status.TIMEOUT_ERROR) + status=memory_pb2.ProcessScanResponse.Status.TIMEOUT_ERROR + ) except yara.Error as e: # Yara internal error 30 is too many hits. 
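For clarity, a small worked example of the context slicing performed in `_YaraMatchToProto` above (illustration only, not part of the patch): the window is taken relative to the match offset within the scanned chunk, so the trailing half of the window starts at the match itself.

    # data: the chunk handed to yara; the match starts at offset 100.
    data = b"A" * 100 + b"NEEDLE" + b"B" * 100
    match_offset = 100
    context_window = 10
    context = data[match_offset - context_window : match_offset + context_window]
    # 10 bytes before the match plus the 10 bytes starting at the match offset:
    assert context == b"A" * 10 + b"NEEDLE" + b"B" * 4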
if "internal error: 30" in str(e): return memory_pb2.ProcessScanResponse( - status=memory_pb2.ProcessScanResponse.Status.TOO_MANY_MATCHES) + status=memory_pb2.ProcessScanResponse.Status.TOO_MANY_MATCHES + ) else: return memory_pb2.ProcessScanResponse( - status=memory_pb2.ProcessScanResponse.Status.GENERIC_ERROR) + status=memory_pb2.ProcessScanResponse.Status.GENERIC_ERROR + ) return memory_pb2.ProcessScanResponse( scan_result=result, - status=memory_pb2.ProcessScanResponse.Status.NO_ERROR) + status=memory_pb2.ProcessScanResponse.Status.NO_ERROR, + ) def PackResponse( - self, response: memory_pb2.ProcessScanResponse) -> memory_pb2.Response: + self, response: memory_pb2.ProcessScanResponse + ) -> memory_pb2.Response: return memory_pb2.Response(process_scan_response=response) def UnpackRequest( - self, request: memory_pb2.Request) -> memory_pb2.ProcessScanRequest: + self, request: memory_pb2.Request + ) -> memory_pb2.ProcessScanRequest: return request.process_scan_request @@ -182,7 +227,8 @@ def DispatchWrapped(connection: ConnectionWrapper) -> None: except: # pylint: disable=bare-except exception = memory_pb2.Exception( message=str(sys.exc_info()[1]), - formatted_exception=traceback.format_exc()) + formatted_exception=traceback.format_exc(), + ) connection.Send(memory_pb2.Response(exception=exception)) diff --git a/grr/client/grr_response_client/unprivileged/osx/sandbox_test.py b/grr/client/grr_response_client/unprivileged/osx/sandbox_test.py index 23100c9f92..013ccc11bb 100644 --- a/grr/client/grr_response_client/unprivileged/osx/sandbox_test.py +++ b/grr/client/grr_response_client/unprivileged/osx/sandbox_test.py @@ -12,8 +12,10 @@ from grr_response_client.unprivileged import sandbox -@unittest.skipIf(platform.system() != "Darwin" or os.getuid() != 0, - "Skipping OSX-only root test.") +@unittest.skipIf( + platform.system() != "Darwin" or os.getuid() != 0, + "Skipping OSX-only root test.", +) class OSXSandboxTest(absltest.TestCase): @classmethod diff --git a/grr/client/grr_response_client/unprivileged/proto/memory.proto b/grr/client/grr_response_client/unprivileged/proto/memory.proto index 81ca8d6585..7b1a9f4bbb 100644 --- a/grr/client/grr_response_client/unprivileged/proto/memory.proto +++ b/grr/client/grr_response_client/unprivileged/proto/memory.proto @@ -14,6 +14,9 @@ message StringMatch { // Original offset of the chunk that has been matched. optional uint64 chunk_offset = 4; + + // Context bytes around the match. + optional bytes context = 5; } message RuleMatch { @@ -54,6 +57,9 @@ message ProcessScanRequest { // Timeout in seconds for the scan. 
optional uint64 timeout_seconds = 4; + // Context window + optional uint64 context_window = 6; + reserved 2, 3; } diff --git a/grr/client/grr_response_client/unprivileged/server.py b/grr/client/grr_response_client/unprivileged/server.py index da92fa989b..cd545c099a 100644 --- a/grr/client/grr_response_client/unprivileged/server.py +++ b/grr/client/grr_response_client/unprivileged/server.py @@ -9,8 +9,9 @@ from grr_response_core import config -def _MakeServerArgs(channel: communication.Channel, - interface: interface_registry.Interface) -> List[str]: +def _MakeServerArgs( + channel: communication.Channel, interface: interface_registry.Interface +) -> List[str]: """Returns the args to run the unprivileged server command.""" assert channel.pipe_input is not None and channel.pipe_output is not None named_flags = [ @@ -40,8 +41,10 @@ def _MakeServerArgs(channel: communication.Channel, def CreateServer( extra_file_descriptors: List[communication.FileDescriptor], - interface: interface_registry.Interface) -> communication.Server: + interface: interface_registry.Interface, +) -> communication.Server: server = communication.SubprocessServer( lambda channel: _MakeServerArgs(channel, interface), - extra_file_descriptors) + extra_file_descriptors, + ) return server diff --git a/grr/client/grr_response_client/unprivileged/server_main_lib.py b/grr/client/grr_response_client/unprivileged/server_main_lib.py index f34dd8205c..c1e6fa98f5 100644 --- a/grr/client/grr_response_client/unprivileged/server_main_lib.py +++ b/grr/client/grr_response_client/unprivileged/server_main_lib.py @@ -7,21 +7,28 @@ from grr_response_client.unprivileged import interface_registry flags.DEFINE_integer( - "unprivileged_server_pipe_input", -1, - "The file descriptor of the input pipe used for communication.") + "unprivileged_server_pipe_input", + -1, + "The file descriptor of the input pipe used for communication.", +) flags.DEFINE_integer( - "unprivileged_server_pipe_output", -1, - "The file descriptor of the output pipe used for communication.") + "unprivileged_server_pipe_output", + -1, + "The file descriptor of the output pipe used for communication.", +) -flags.DEFINE_string("unprivileged_server_interface", "", - "The name of the RPC interface used.") +flags.DEFINE_string( + "unprivileged_server_interface", "", "The name of the RPC interface used." +) -flags.DEFINE_string("unprivileged_user", "", - "Name of user to run unprivileged server as.") +flags.DEFINE_string( + "unprivileged_user", "", "Name of user to run unprivileged server as." +) -flags.DEFINE_string("unprivileged_group", "", - "Name of group to run unprivileged server as.") +flags.DEFINE_string( + "unprivileged_group", "", "Name of group to run unprivileged server as." 
+) def main(argv): @@ -29,7 +36,11 @@ def main(argv): communication.Main( communication.Channel.FromSerialized( pipe_input=flags.FLAGS.unprivileged_server_pipe_input, - pipe_output=flags.FLAGS.unprivileged_server_pipe_output), + pipe_output=flags.FLAGS.unprivileged_server_pipe_output, + ), interface_registry.GetConnectionHandlerForInterfaceString( - flags.FLAGS.unprivileged_server_interface), - flags.FLAGS.unprivileged_user, flags.FLAGS.unprivileged_group) + flags.FLAGS.unprivileged_server_interface + ), + flags.FLAGS.unprivileged_user, + flags.FLAGS.unprivileged_group, + ) diff --git a/grr/client/grr_response_client/vfs.py b/grr/client/grr_response_client/vfs.py index 11dac7b1c0..450cec2939 100644 --- a/grr/client/grr_response_client/vfs.py +++ b/grr/client/grr_response_client/vfs.py @@ -47,16 +47,19 @@ def Init(): VFS_HANDLERS[files.File.supported_pathtype] = files.File VFS_HANDLERS[files.TempFile.supported_pathtype] = files.TempFile if config.CONFIG["Client.use_filesystem_sandboxing"]: - VFS_HANDLERS[unprivileged_vfs.UnprivilegedNtfsFile - .supported_pathtype] = unprivileged_vfs.UnprivilegedNtfsFile - VFS_HANDLERS[unprivileged_vfs.UnprivilegedTskFile - .supported_pathtype] = unprivileged_vfs.UnprivilegedTskFile + VFS_HANDLERS[unprivileged_vfs.UnprivilegedNtfsFile.supported_pathtype] = ( + unprivileged_vfs.UnprivilegedNtfsFile + ) + VFS_HANDLERS[unprivileged_vfs.UnprivilegedTskFile.supported_pathtype] = ( + unprivileged_vfs.UnprivilegedTskFile + ) else: VFS_HANDLERS[sleuthkit.TSKFile.supported_pathtype] = sleuthkit.TSKFile VFS_HANDLERS[ntfs.NTFSFile.supported_pathtype] = ntfs.NTFSFile if vfs_registry is not None: - VFS_HANDLERS[vfs_registry.RegistryFile - .supported_pathtype] = vfs_registry.RegistryFile + VFS_HANDLERS[vfs_registry.RegistryFile.supported_pathtype] = ( + vfs_registry.RegistryFile + ) VFS_HANDLERS_DIRECT.update(VFS_HANDLERS) VFS_HANDLERS_DIRECT[sleuthkit.TSKFile.supported_pathtype] = sleuthkit.TSKFile @@ -76,48 +79,61 @@ def Init(): except ValueError: raise ValueError( "Badly formatted vfs virtual root: %s. Correct format is " - "os:/path/to/virtual_root" % vfs_virtualroot) + "os:/path/to/virtual_root" % vfs_virtualroot + ) handler_string = handler_string.upper() handler = rdf_paths.PathSpec.PathType.enum_dict.get(handler_string) if handler is None: raise ValueError( "VFSHandler {} could not be registered, because it was not found in" - " PathSpec.PathType {}".format(handler_string, - rdf_paths.PathSpec.PathType.enum_dict)) + " PathSpec.PathType {}".format( + handler_string, rdf_paths.PathSpec.PathType.enum_dict + ) + ) # We need some translation here, TSK needs an OS virtual root base. For # every other handler we can just keep the type the same. 
- if handler in (rdf_paths.PathSpec.PathType.TSK, - rdf_paths.PathSpec.PathType.NTFS): + if handler in ( + rdf_paths.PathSpec.PathType.TSK, + rdf_paths.PathSpec.PathType.NTFS, + ): base_type = rdf_paths.PathSpec.PathType.OS else: base_type = handler _VFS_VIRTUALROOTS[handler] = rdf_paths.PathSpec( - path=root, pathtype=base_type, is_virtualroot=True) + path=root, pathtype=base_type, is_virtualroot=True + ) def _GetVfsHandlers( - pathspec: rdf_paths.PathSpec) -> Dict[Any, Type[vfs_base.VFSHandler]]: + pathspec: rdf_paths.PathSpec, +) -> Dict[Any, Type[vfs_base.VFSHandler]]: """Returns the table of VFS handlers for the given pathspec.""" for i, element in enumerate(pathspec): if element.HasField("implementation_type") and i != 0: raise ValueError( "implementation_type must be set on the top-level component of " - "a pathspec.") - if (pathspec.implementation_type == - rdf_paths.PathSpec.ImplementationType.DIRECT): + "a pathspec." + ) + if ( + pathspec.implementation_type + == rdf_paths.PathSpec.ImplementationType.DIRECT + ): return VFS_HANDLERS_DIRECT - elif (pathspec.implementation_type == - rdf_paths.PathSpec.ImplementationType.SANDBOX): + elif ( + pathspec.implementation_type + == rdf_paths.PathSpec.ImplementationType.SANDBOX + ): return VFS_HANDLERS_SANDBOX else: return VFS_HANDLERS -def VFSOpen(pathspec: rdf_paths.PathSpec, - progress_callback: Optional[Callable[[], None]] = None - ) -> VFSHandler: +def VFSOpen( + pathspec: rdf_paths.PathSpec, + progress_callback: Optional[Callable[[], None]] = None, +) -> VFSHandler: """Expands pathspec to return an expanded Path. A pathspec is a specification of how to access the file by recursively opening @@ -196,7 +212,6 @@ def VFSOpen(pathspec: rdf_paths.PathSpec, Raises: IOError: if one of the path components can not be opened. - """ # Initialize the dictionary of VFS handlers lazily, if not yet done. if not VFS_HANDLERS: @@ -213,8 +228,11 @@ def VFSOpen(pathspec: rdf_paths.PathSpec, # it to the incoming pathspec except if the pathspec is explicitly # marked as containing a virtual root already or if it isn't marked but # the path already contains the virtual root. - if (not vroot or pathspec.is_virtualroot or - pathspec.CollapsePath().startswith(vroot.CollapsePath())): + if ( + not vroot + or pathspec.is_virtualroot + or pathspec.CollapsePath().startswith(vroot.CollapsePath()) + ): # No virtual root but opening changes the pathspec so we always work on a # copy. working_pathspec = pathspec.Copy() @@ -242,7 +260,8 @@ def VFSOpen(pathspec: rdf_paths.PathSpec, component=component, handlers=dict(handlers), pathspec=working_pathspec, - progress_callback=progress_callback) + progress_callback=progress_callback, + ) # If the handler uses `client_utils.GetRawDevice`, it will rewrite # `working_pathspec`, adding 3 new entries (only the first 2 matter). @@ -250,8 +269,10 @@ def VFSOpen(pathspec: rdf_paths.PathSpec, # the new top-level entry and remove it from the original entry (which is # now modified at index 1). 
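To illustrate how the handler tables above are selected, a minimal sketch (assumptions: a configured client and a hypothetical Windows path; only `vfs.VFSOpen`, `rdf_paths.PathSpec` and the `implementation_type` field come from the code being patched):

    # Setting implementation_type on the top-level component picks the handler
    # table: DIRECT forces the in-process TSK/NTFS handlers, SANDBOX forces the
    # unprivileged handlers, and leaving it unset uses the default VFS_HANDLERS.
    from grr_response_client import vfs
    from grr_response_core.lib.rdfvalues import paths as rdf_paths

    pathspec = rdf_paths.PathSpec(
        path="C:/Windows/System32/notepad.exe",   # hypothetical target
        pathtype=rdf_paths.PathSpec.PathType.NTFS,
        implementation_type=rdf_paths.PathSpec.ImplementationType.SANDBOX,
    )
    fd = vfs.VFSOpen(pathspec)
    stat_entry = fd.Stat()
    data = fd.Read(4096)
    fd.Close()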
- if (orig_component.HasField("implementation_type") and - len(working_pathspec) >= len(orig_working_pathspec) + 2): + if ( + orig_component.HasField("implementation_type") + and len(working_pathspec) >= len(orig_working_pathspec) + 2 + ): working_pathspec.implementation_type = orig_component.implementation_type working_pathspec[1].implementation_type = None diff --git a/grr/client/grr_response_client/vfs_handlers/base.py b/grr/client/grr_response_client/vfs_handlers/base.py index 5a5959f78b..bacc782041 100644 --- a/grr/client/grr_response_client/vfs_handlers/base.py +++ b/grr/client/grr_response_client/vfs_handlers/base.py @@ -28,6 +28,7 @@ def __init__(self, pathtype): class VFSHandler(IO[bytes], metaclass=abc.ABCMeta): """Base class for handling objects in the VFS.""" + supported_pathtype = rdf_paths.PathSpec.PathType.UNSET # Should this handler be auto-registered? @@ -146,7 +147,8 @@ def OpenAsContainer(self, pathtype): base_fd=self, handlers=self._handlers, pathspec=pathspec, - progress_callback=self.progress_callback) + progress_callback=self.progress_callback, + ) def MatchBestComponentName(self, component, pathtype): """Returns the name of the component which matches best our base listing. @@ -177,7 +179,8 @@ def MatchBestComponentName(self, component, pathtype): if fd.supported_pathtype != self.pathspec.pathtype: new_pathspec = rdf_paths.PathSpec( - path=component, pathtype=fd.supported_pathtype) + path=component, pathtype=fd.supported_pathtype + ) else: new_pathspec = self.pathspec.last.Copy() new_pathspec.path = component @@ -299,8 +302,9 @@ def Open(cls, fd, component, handlers, pathspec=None, progress_callback=None): for i, path_component in enumerate(path_components): try: if fd: - new_pathspec = fd.MatchBestComponentName(path_component, - component.pathtype) + new_pathspec = fd.MatchBestComponentName( + path_component, component.pathtype + ) else: new_pathspec = component.Copy() new_pathspec.path = path_component @@ -320,7 +324,8 @@ def Open(cls, fd, component, handlers, pathspec=None, progress_callback=None): base_fd=fd, handlers=handlers, pathspec=new_pathspec, - progress_callback=progress_callback) + progress_callback=progress_callback, + ) except IOError as e: # Can not open the first component, we must raise here. if i <= 1: @@ -336,7 +341,8 @@ def Open(cls, fd, component, handlers, pathspec=None, progress_callback=None): pathspec.Insert( 0, path=utils.JoinPath(*path_components[i:]), - pathtype=rdf_paths.PathSpec.PathType.TSK) + pathtype=rdf_paths.PathSpec.PathType.TSK, + ) break return fd diff --git a/grr/client/grr_response_client/vfs_handlers/files.py b/grr/client/grr_response_client/vfs_handlers/files.py index d4015a8635..1b8b1a2345 100644 --- a/grr/client/grr_response_client/vfs_handlers/files.py +++ b/grr/client/grr_response_client/vfs_handlers/files.py @@ -7,8 +7,7 @@ import re import sys import threading - -from typing import Text, Optional +from typing import Optional, Text from grr_response_client import client_utils from grr_response_client.vfs_handlers import base as vfs_base @@ -111,15 +110,17 @@ def __init__(self, base_fd, handlers, pathspec=None, progress_callback=None): base_fd, handlers=handlers, pathspec=pathspec, - progress_callback=progress_callback) + progress_callback=progress_callback, + ) if base_fd is None: self.pathspec.Append(pathspec) # We can stack on another directory, which means we concatenate their # directory with ours. 
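As a small aside on the stacking rule described in the comment above (illustration only; the concrete paths are invented, `utils.JoinPath` is the helper the code uses):

    # When a File handler is stacked on a directory handler, the nested path is
    # joined onto the last pathspec component instead of appending a new one.
    from grr_response_core.lib import utils

    base_dir = "/mnt/evidence"      # last.path of the base (directory) handler
    nested = "logs/syslog"          # path of the pathspec being stacked on top
    joined = utils.JoinPath(base_dir, nested)
    # joined == "/mnt/evidence/logs/syslog" and becomes self.pathspec.last.path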
elif base_fd.IsDirectory(): - self.pathspec.last.path = utils.JoinPath(self.pathspec.last.path, - pathspec.path) + self.pathspec.last.path = utils.JoinPath( + self.pathspec.last.path, pathspec.path + ) else: raise IOError("File handler can not be stacked on another handler.") @@ -199,7 +200,7 @@ def FileHacks(self): elif re.match(r"/*\\\\.\\[^\\]+\\?$", self.path) is not None: # Special case windows devices cant seek to the end so just lie about # the size - self.size = 0x7fffffffffffffff + self.size = 0x7FFFFFFFFFFFFFFF # Windows raw devices can be opened in two incompatible modes. With a # trailing \ they look like a directory, but without they are the raw @@ -213,7 +214,7 @@ def FileHacks(self): # On Mac, raw disk devices are also not seekable to the end and have no # size so we use the same approach as on Windows. if re.match("/dev/r?disk.*", self.path): - self.size = 0x7fffffffffffffff + self.size = 0x7FFFFFFFFFFFFFFF self.alignment = 512 def _GetDepth(self, path): @@ -263,7 +264,8 @@ def Stat( follow_symlink: bool = True, ) -> rdf_client_fs.StatEntry: return self._Stat( - self.path, ext_attrs=ext_attrs, follow_symlink=follow_symlink) + self.path, ext_attrs=ext_attrs, follow_symlink=follow_symlink + ) def _Stat( self, @@ -290,7 +292,8 @@ def _Stat( local_path, self.pathspec, ext_attrs=ext_attrs, - follow_symlink=follow_symlink) + follow_symlink=follow_symlink, + ) # Is this a symlink? If so we need to note the real location of the file. try: @@ -357,7 +360,8 @@ def GetMountPoint(self, path=None): path string of the mount point """ path = os.path.abspath( - client_utils.CanonicalPathToLocalPath(path or self.path)) + client_utils.CanonicalPathToLocalPath(path or self.path) + ) while not os.path.ismount(path): path = os.path.dirname(path) @@ -371,4 +375,5 @@ def native_path(self) -> Optional[str]: class TempFile(File): """GRR temporary files on the client.""" + supported_pathtype = rdf_paths.PathSpec.PathType.TMPFILE diff --git a/grr/client/grr_response_client/vfs_handlers/files_test.py b/grr/client/grr_response_client/vfs_handlers/files_test.py index cbd751f3e0..604b502f84 100644 --- a/grr/client/grr_response_client/vfs_handlers/files_test.py +++ b/grr/client/grr_response_client/vfs_handlers/files_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import io import os import platform diff --git a/grr/client/grr_response_client/vfs_handlers/ntfs.py b/grr/client/grr_response_client/vfs_handlers/ntfs.py index 41bb66bf97..ffab33374d 100644 --- a/grr/client/grr_response_client/vfs_handlers/ntfs.py +++ b/grr/client/grr_response_client/vfs_handlers/ntfs.py @@ -25,7 +25,8 @@ def _GetAlternateDataStreamCaseInsensitive( - fd: pyfsntfs.file_entry, name: Text) -> Optional[pyfsntfs.data_stream]: + fd: pyfsntfs.file_entry, name: Text +) -> Optional[pyfsntfs.data_stream]: name = name.lower() for data_stream in fd.alternate_data_streams: if data_stream.name.lower() == name: @@ -37,16 +38,19 @@ class NTFSFile(vfs_base.VFSHandler): supported_pathtype = rdf_paths.PathSpec.PathType.NTFS - def __init__(self, - base_fd: Optional[vfs_base.VFSHandler], - handlers: Dict[Any, Type[vfs_base.VFSHandler]], - pathspec: Optional[rdf_paths.PathSpec] = None, - progress_callback: Optional[Callable[[], None]] = None): + def __init__( + self, + base_fd: Optional[vfs_base.VFSHandler], + handlers: Dict[Any, Type[vfs_base.VFSHandler]], + pathspec: Optional[rdf_paths.PathSpec] = None, + progress_callback: Optional[Callable[[], None]] = None, + ): super().__init__( base_fd, handlers=handlers, pathspec=pathspec, - 
progress_callback=progress_callback) + progress_callback=progress_callback, + ) # self.pathspec is initialized to a copy of base_fd @@ -97,13 +101,18 @@ def __init__(self, if pathspec is not None and pathspec.HasField("stream_name"): if pathspec.path_options == rdf_paths.PathSpec.Options.CASE_LITERAL: self.data_stream = self.fd.get_alternate_data_stream_by_name( - pathspec.stream_name) + pathspec.stream_name + ) else: self.data_stream = _GetAlternateDataStreamCaseInsensitive( - self.fd, pathspec.stream_name) + self.fd, pathspec.stream_name + ) if self.data_stream is None: - raise IOError("Failed to open data stream {} in {}.".format( - pathspec.stream_name, path)) + raise IOError( + "Failed to open data stream {} in {}.".format( + pathspec.stream_name, path + ) + ) self.pathspec.last.stream_name = self.data_stream.name else: if self.fd.has_default_data_stream(): @@ -123,9 +132,9 @@ def __init__(self, else: self.size = 0 - def Stat(self, - ext_attrs: bool = False, - follow_symlink: bool = True) -> rdf_client_fs.StatEntry: + def Stat( + self, ext_attrs: bool = False, follow_symlink: bool = True + ) -> rdf_client_fs.StatEntry: return self._Stat(self.fd, self.data_stream, self.pathspec.Copy()) def Read(self, length: int) -> bytes: @@ -139,8 +148,10 @@ def Read(self, length: int) -> bytes: def IsDirectory(self) -> bool: return self.fd.has_directory_entries_index() - def ListFiles(self, # pytype: disable=signature-mismatch # overriding-return-type-checks - ext_attrs: bool = False) -> Iterable[rdf_client_fs.StatEntry]: + def ListFiles( # pytype: disable=signature-mismatch # overriding-return-type-checks + self, + ext_attrs: bool = False, + ) -> Iterable[rdf_client_fs.StatEntry]: del ext_attrs # Unused. self._CheckIsDirectory() @@ -165,8 +176,9 @@ def ListNames(self) -> Iterable[Text]: # pytype: disable=signature-mismatch # def _CheckIsDirectory(self) -> None: if not self.IsDirectory(): - raise IOError("{} is not a directory".format( - self.pathspec.CollapsePath())) + raise IOError( + "{} is not a directory".format(self.pathspec.CollapsePath()) + ) def _CheckIsFile(self) -> None: if self.IsDirectory(): @@ -182,13 +194,17 @@ def _Stat( st.pathspec = pathspec st.st_atime = rdfvalue.RDFDatetimeSeconds.FromDatetime( - entry.get_access_time()) + entry.get_access_time() + ) st.st_mtime = rdfvalue.RDFDatetimeSeconds.FromDatetime( - entry.get_modification_time()) + entry.get_modification_time() + ) st.st_btime = rdfvalue.RDFDatetimeSeconds.FromDatetime( - entry.get_creation_time()) + entry.get_creation_time() + ) st.st_ctime = rdfvalue.RDFDatetimeSeconds.FromDatetime( - entry.get_entry_modification_time()) + entry.get_entry_modification_time() + ) if entry.has_directory_entries_index(): st.st_mode = stat.S_IFDIR else: @@ -210,13 +226,15 @@ def Open( component: rdf_paths.PathSpec, handlers: Dict[Any, Type[vfs_base.VFSHandler]], pathspec: Optional[rdf_paths.PathSpec] = None, - progress_callback: Optional[Callable[[], None]] = None + progress_callback: Optional[Callable[[], None]] = None, ) -> Optional[vfs_base.VFSHandler]: # A Pathspec which starts with NTFS means we need to resolve the mount # point at runtime. - if (fd is None and - component.pathtype == rdf_paths.PathSpec.PathType.NTFS and - pathspec is not None): + if ( + fd is None + and component.pathtype == rdf_paths.PathSpec.PathType.NTFS + and pathspec is not None + ): # We are the top level handler. This means we need to check the system # mounts to work out the exact mount point and device we need to # open. 
We then modify the pathspec so we get nested in the raw @@ -243,11 +261,13 @@ def Open( # This is necessary so that component.path is ignored. elif component.HasField("inode"): return NTFSFile( - fd, handlers, component, progress_callback=progress_callback) + fd, handlers, component, progress_callback=progress_callback + ) else: return super(NTFSFile, cls).Open( fd=fd, component=component, handlers=handlers, pathspec=pathspec, - progress_callback=progress_callback) + progress_callback=progress_callback, + ) diff --git a/grr/client/grr_response_client/vfs_handlers/ntfs_image_test_lib.py b/grr/client/grr_response_client/vfs_handlers/ntfs_image_test_lib.py index 0969d42a85..476adb1eda 100644 --- a/grr/client/grr_response_client/vfs_handlers/ntfs_image_test_lib.py +++ b/grr/client/grr_response_client/vfs_handlers/ntfs_image_test_lib.py @@ -26,8 +26,8 @@ CHINESE_FILE_FILE_REF = 844424930132045 # Default st_mode flags for files and directories -S_DEFAULT_FILE = (stat.S_IFREG | stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) -S_DEFAULT_DIR = (stat.S_IFDIR | stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) +S_DEFAULT_FILE = stat.S_IFREG | stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO +S_DEFAULT_DIR = stat.S_IFDIR | stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO class NTFSImageTest(absltest.TestCase, abc.ABC): @@ -41,15 +41,18 @@ def _FileRefToInode(self, file_ref: int) -> int: @abc.abstractmethod def _ExpectedStatEntry( - self, st: rdf_client_fs.StatEntry) -> rdf_client_fs.StatEntry: + self, st: rdf_client_fs.StatEntry + ) -> rdf_client_fs.StatEntry: """Fixes an expected StatEntry for the respective implementation.""" pass - def _GetNTFSPathSpec(self, - path, - file_ref=None, - path_options=None, - stream_name=None): + def _GetNTFSPathSpec( + self, + path, + file_ref=None, + path_options=None, + stream_name=None, + ): # ntfs.img is an NTFS formatted filesystem containing: # -rwxrwxrwx 1 root root 4 Mar 4 15:00 ./a/b1/c1/d # -rwxrwxrwx 1 root root 3893 Mar 3 21:10 ./numbers.txt @@ -77,7 +80,9 @@ def _GetNTFSPathSpec(self, pathtype=self.PATH_TYPE, inode=inode, path_options=path_options, - stream_name=stream_name)) + stream_name=stream_name, + ), + ) def testNTFSNestedFile(self): pathspec = self._GetNTFSPathSpec("/a/b1/c1/d") @@ -86,8 +91,12 @@ def testNTFSNestedFile(self): result = fd.Stat() self.assertEqual( result.pathspec, - self._GetNTFSPathSpec("/a/b1/c1/d", A_B1_C1_D_FILE_REF, - rdf_paths.PathSpec.Options.CASE_LITERAL)) + self._GetNTFSPathSpec( + "/a/b1/c1/d", + A_B1_C1_D_FILE_REF, + rdf_paths.PathSpec.Options.CASE_LITERAL, + ), + ) def testNTFSOpenByInode(self): pathspec = self._GetNTFSPathSpec("/a/b1/c1/d") @@ -98,8 +107,11 @@ def testNTFSOpenByInode(self): fd2 = vfs.VFSOpen(fd.pathspec) self.assertEqual(fd2.Read(100), b"foo\n") - pathspec = self._GetNTFSPathSpec("/ignored", fd.pathspec.last.inode, - rdf_paths.PathSpec.Options.CASE_LITERAL) + pathspec = self._GetNTFSPathSpec( + "/ignored", + fd.pathspec.last.inode, + rdf_paths.PathSpec.Options.CASE_LITERAL, + ) fd3 = vfs.VFSOpen(pathspec) self.assertEqual(fd3.Read(100), b"foo\n") @@ -116,8 +128,12 @@ def testNTFSStat(self): s = fd.Stat() self.assertEqual( s.pathspec, - self._GetNTFSPathSpec("/numbers.txt", NUMBERS_TXT_FILE_REF, - rdf_paths.PathSpec.Options.CASE_LITERAL)) + self._GetNTFSPathSpec( + "/numbers.txt", + NUMBERS_TXT_FILE_REF, + rdf_paths.PathSpec.Options.CASE_LITERAL, + ), + ) self.assertEqual(str(s.st_atime), "2020-03-03 20:10:46") self.assertEqual(str(s.st_mtime), "2020-03-03 20:10:46") self.assertEqual(str(s.st_btime), "2020-03-03 16:46:00") 
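For orientation, the NTFS image tests build their pathspecs as a nested pair: an OS-level component that points at the image (or resolved raw device) with an NTFS component appended. A minimal sketch of that shape, with a made-up image location (the field names come from the patch; everything else is illustrative):

    # Illustrative only: an OS component for the image file with an NTFS
    # component nested inside it; stream_name selects an alternate data stream.
    from grr_response_client import vfs
    from grr_response_core.lib.rdfvalues import paths as rdf_paths

    pathspec = rdf_paths.PathSpec(
        path="/tmp/ntfs.img",                    # hypothetical image location
        pathtype=rdf_paths.PathSpec.PathType.OS,
    )
    pathspec.Append(
        rdf_paths.PathSpec(
            path="/ads/ads.txt",
            pathtype=rdf_paths.PathSpec.PathType.NTFS,
            stream_name="one",
        )
    )
    fd = vfs.VFSOpen(pathspec)
    content = fd.Read(100)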
@@ -161,123 +177,171 @@ def testNTFSListFiles(self): pathspec=self._GetNTFSPathSpec( "/a", file_ref=A_FILE_REF, - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL), + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ), st_atime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-03-03 16:48:16"), + "2020-03-03 16:48:16" + ), st_btime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-03-03 16:47:43"), + "2020-03-03 16:47:43" + ), st_mtime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-03-03 16:47:50"), + "2020-03-03 16:47:50" + ), st_ctime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-03-03 16:47:50"), + "2020-03-03 16:47:50" + ), st_mode=S_DEFAULT_DIR, st_gid=0, st_uid=48, st_nlink=1, - )), + ) + ), self._ExpectedStatEntry( rdf_client_fs.StatEntry( pathspec=self._GetNTFSPathSpec( "/ads", file_ref=ADS_FILE_REF, - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL), + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ), st_atime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 14:57:02"), + "2020-04-07 14:57:02" + ), st_btime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:23:07"), + "2020-04-07 13:23:07" + ), st_mtime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 14:56:47"), + "2020-04-07 14:56:47" + ), st_ctime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 14:56:47"), + "2020-04-07 14:56:47" + ), st_mode=S_DEFAULT_DIR, st_gid=0, st_uid=48, st_nlink=1, - )), + ) + ), self._ExpectedStatEntry( rdf_client_fs.StatEntry( pathspec=self._GetNTFSPathSpec( "/hidden_file.txt", file_ref=HIDDEN_FILE_TXT_FILE_REF, - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL), + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ), st_atime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-08 20:14:38"), + "2020-04-08 20:14:38" + ), st_btime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-08 20:14:38"), + "2020-04-08 20:14:38" + ), st_mtime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-08 20:14:38"), + "2020-04-08 20:14:38" + ), st_ctime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-08 20:15:07"), - st_mode=(stat.S_IFREG | stat.S_IWUSR | stat.S_IWGRP - | stat.S_IWOTH - | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH), + "2020-04-08 20:15:07" + ), + st_mode=( + stat.S_IFREG + | stat.S_IWUSR + | stat.S_IWGRP + | stat.S_IWOTH + | stat.S_IXUSR + | stat.S_IXGRP + | stat.S_IXOTH + ), st_size=0, st_gid=0, st_uid=48, st_nlink=1, - )), + ) + ), self._ExpectedStatEntry( rdf_client_fs.StatEntry( pathspec=self._GetNTFSPathSpec( "/numbers.txt", file_ref=NUMBERS_TXT_FILE_REF, - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL), + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ), st_atime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-03-03 20:10:46"), + "2020-03-03 20:10:46" + ), st_btime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-03-03 16:46:00"), + "2020-03-03 16:46:00" + ), st_mtime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-03-03 20:10:46"), + "2020-03-03 20:10:46" + ), st_ctime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-03-03 20:10:46"), + "2020-03-03 20:10:46" + ), st_mode=S_DEFAULT_FILE, st_size=3893, st_gid=0, st_uid=48, st_nlink=1, - )), + ) + ), self._ExpectedStatEntry( rdf_client_fs.StatEntry( pathspec=self._GetNTFSPathSpec( "/read_only_file.txt", file_ref=READ_ONLY_FILE_TXT_FILE_REF, - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL), + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ), 
st_atime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-08 20:14:33"), + "2020-04-08 20:14:33" + ), st_btime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-08 20:14:33"), + "2020-04-08 20:14:33" + ), st_mtime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-08 20:14:33"), + "2020-04-08 20:14:33" + ), st_ctime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-08 20:14:55"), - st_mode=(stat.S_IFREG | stat.S_IRUSR | stat.S_IRGRP - | stat.S_IROTH - | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH), + "2020-04-08 20:14:55" + ), + st_mode=( + stat.S_IFREG + | stat.S_IRUSR + | stat.S_IRGRP + | stat.S_IROTH + | stat.S_IXUSR + | stat.S_IXGRP + | stat.S_IXOTH + ), st_size=0, st_gid=0, st_uid=48, st_nlink=1, - )), + ) + ), self._ExpectedStatEntry( rdf_client_fs.StatEntry( pathspec=self._GetNTFSPathSpec( "/入乡随俗 海外春节别样过法.txt", file_ref=CHINESE_FILE_FILE_REF, - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL), + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ), st_atime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-06-10 13:34:36"), + "2020-06-10 13:34:36" + ), st_btime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-06-10 13:34:36"), + "2020-06-10 13:34:36" + ), st_mtime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-06-10 13:34:36"), + "2020-06-10 13:34:36" + ), st_ctime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-06-10 13:34:36"), + "2020-06-10 13:34:36" + ), st_mode=S_DEFAULT_FILE, st_size=26, st_gid=0, st_uid=48, st_nlink=1, - )), + ) + ), ] self.assertEqual(files, expected_files) @@ -293,63 +357,81 @@ def testNTFSListFiles_alternateDataStreams(self): pathspec=self._GetNTFSPathSpec( "/ads/ads.txt", file_ref=ADS_ADS_TXT_FILE_REF, - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL), + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ), st_atime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:48:51"), + "2020-04-07 13:48:51" + ), st_btime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:18:53"), + "2020-04-07 13:18:53" + ), st_mtime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:48:56"), + "2020-04-07 13:48:56" + ), st_ctime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:48:56"), + "2020-04-07 13:48:56" + ), st_mode=S_DEFAULT_FILE, st_size=5, st_gid=0, st_uid=48, st_nlink=1, - )), + ) + ), self._ExpectedStatEntry( rdf_client_fs.StatEntry( pathspec=self._GetNTFSPathSpec( "/ads/ads.txt", file_ref=ADS_ADS_TXT_FILE_REF, path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, - stream_name="one"), + stream_name="one", + ), st_atime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:48:51"), + "2020-04-07 13:48:51" + ), st_btime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:18:53"), + "2020-04-07 13:18:53" + ), st_mtime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:48:56"), + "2020-04-07 13:48:56" + ), st_ctime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:48:56"), + "2020-04-07 13:48:56" + ), st_mode=S_DEFAULT_FILE, st_size=6, st_gid=0, st_uid=48, st_nlink=1, - )), + ) + ), self._ExpectedStatEntry( rdf_client_fs.StatEntry( pathspec=self._GetNTFSPathSpec( "/ads/ads.txt", file_ref=ADS_ADS_TXT_FILE_REF, path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, - stream_name="two"), + stream_name="two", + ), st_atime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:48:51"), + "2020-04-07 13:48:51" + ), st_btime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 
13:18:53"), + "2020-04-07 13:18:53" + ), st_mtime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:48:56"), + "2020-04-07 13:48:56" + ), st_ctime=rdfvalue.RDFDatetimeSeconds.FromHumanReadable( - "2020-04-07 13:48:56"), + "2020-04-07 13:48:56" + ), st_mode=S_DEFAULT_FILE, st_size=7, st_gid=0, st_uid=48, st_nlink=1, - )), + ) + ), ] self.assertEqual(files, expected_files) @@ -370,7 +452,8 @@ def testNTFSOpen_alternateDataStreams(self): pathspec = self._GetNTFSPathSpec( "/ads/ads.txt", stream_name="ONE", - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL) + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ) vfs.VFSOpen(pathspec) pathspec = self._GetNTFSPathSpec("/ads/ads.txt", stream_name="two") @@ -385,7 +468,8 @@ def testNTFSOpen_alternateDataStreams(self): pathspec = self._GetNTFSPathSpec( "/ads/ads.txt", stream_name="TWO", - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL) + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ) vfs.VFSOpen(pathspec) def testNTFSStat_alternateDataStreams(self): @@ -399,7 +483,9 @@ def testNTFSStat_alternateDataStreams(self): "/ads/ads.txt", ADS_ADS_TXT_FILE_REF, stream_name="one", - path_options=rdf_paths.PathSpec.Options.CASE_LITERAL)) + path_options=rdf_paths.PathSpec.Options.CASE_LITERAL, + ), + ) self.assertEqual(str(s.st_atime), "2020-04-07 13:48:51") self.assertEqual(str(s.st_mtime), "2020-04-07 13:48:56") self.assertEqual(str(s.st_btime), "2020-04-07 13:18:53") @@ -407,7 +493,8 @@ def testNTFSStat_alternateDataStreams(self): def testNTFSOpenByInode_alternateDataStreams(self): pathspec = self._GetNTFSPathSpec( - "/ignore", file_ref=ADS_ADS_TXT_FILE_REF, stream_name="ONE") + "/ignore", file_ref=ADS_ADS_TXT_FILE_REF, stream_name="ONE" + ) fd = vfs.VFSOpen(pathspec) self.assertEqual(fd.Read(100), b"Bar..\n") diff --git a/grr/client/grr_response_client/vfs_handlers/ntfs_test_lib.py b/grr/client/grr_response_client/vfs_handlers/ntfs_test_lib.py index a268ffae1b..85695abec9 100644 --- a/grr/client/grr_response_client/vfs_handlers/ntfs_test_lib.py +++ b/grr/client/grr_response_client/vfs_handlers/ntfs_test_lib.py @@ -20,7 +20,8 @@ def _FileRefToInode(self, file_ref: int) -> int: return file_ref def _ExpectedStatEntry( - self, st: rdf_client_fs.StatEntry) -> rdf_client_fs.StatEntry: + self, st: rdf_client_fs.StatEntry + ) -> rdf_client_fs.StatEntry: # libfsntfs doesn't report these fields. 
st.st_gid = None st.st_uid = None @@ -39,16 +40,17 @@ def testNTFSReadUnicode(self): with open(path, "w", encoding="utf-8") as f: f.write(file_data) pathspec = rdf_paths.PathSpec( - path=path, pathtype=rdf_paths.PathSpec.PathType.NTFS) + path=path, pathtype=rdf_paths.PathSpec.PathType.NTFS + ) fd = vfs.VFSOpen(pathspec) self.assertEqual(fd.Read(100).decode("utf-8"), file_data) def testGlobComponentGenerate(self): opts = globbing.PathOpts(pathtype=rdf_paths.PathSpec.PathType.NTFS) - paths = globbing.GlobComponent(u"Windows", opts=opts).Generate("C:\\") - self.assertEqual(list(paths), [u"C:\\Windows"]) + paths = globbing.GlobComponent("Windows", opts=opts).Generate("C:\\") + self.assertEqual(list(paths), ["C:\\Windows"]) def testGlobbingExpandPath(self): opts = globbing.PathOpts(pathtype=rdf_paths.PathSpec.PathType.NTFS) paths = globbing.ExpandPath("C:/Windows/System32/notepad.exe", opts=opts) - self.assertEqual(list(paths), [u"C:\\Windows\\System32\\notepad.exe"]) + self.assertEqual(list(paths), ["C:\\Windows\\System32\\notepad.exe"]) diff --git a/grr/client/grr_response_client/vfs_handlers/registry_test.py b/grr/client/grr_response_client/vfs_handlers/registry_test.py index 0c2c8a1c87..e88f82f75e 100644 --- a/grr/client/grr_response_client/vfs_handlers/registry_test.py +++ b/grr/client/grr_response_client/vfs_handlers/registry_test.py @@ -3,6 +3,7 @@ import platform import unittest + from absl import app from absl.testing import absltest @@ -61,10 +62,13 @@ def testFileStat(self): fd = vfs.VFSOpen( rdf_paths.PathSpec( path=r"/HKEY_LOCAL_MACHINE\SOFTWARE\GRR_TEST\aaa", - pathtype="REGISTRY")) + pathtype="REGISTRY", + ) + ) stat = fd.Stat() - self.assertIn(stat.pathspec.path, - "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa") + self.assertIn( + stat.pathspec.path, "/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa" + ) self.assertEqual(stat.pathspec.pathtype, "REGISTRY") self.assertEqual(stat.st_size, 6) @@ -72,7 +76,9 @@ def testFileRead(self): fd = vfs.VFSOpen( rdf_paths.PathSpec( path=r"/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/aaa", - pathtype="REGISTRY")) + pathtype="REGISTRY", + ) + ) self.assertEqual(fd.Read(-1), b"lolcat") self.assertEqual(fd.Stat().registry_data.GetValue(), "lolcat") @@ -80,28 +86,36 @@ def testFileReadLongUnicodeValue(self): fd = vfs.VFSOpen( rdf_paths.PathSpec( path=r"/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/{}".format(_LONG_KEY), - pathtype="REGISTRY")) + pathtype="REGISTRY", + ) + ) self.assertEqual(fd.Read(-1).decode("utf-8"), _LONG_STRING_VALUE) def testReadMinDword(self): fd = vfs.VFSOpen( rdf_paths.PathSpec( path=r"/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/mindword", - pathtype="REGISTRY")) + pathtype="REGISTRY", + ) + ) self.assertEqual(fd.value, 0) def testReadMaxDword(self): fd = vfs.VFSOpen( rdf_paths.PathSpec( path=r"/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/maxdword", - pathtype="REGISTRY")) + pathtype="REGISTRY", + ) + ) self.assertEqual(fd.value, 0xFFFFFFFF) def testReadAnyDword(self): fd = vfs.VFSOpen( rdf_paths.PathSpec( path=r"/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/dword42", - pathtype="REGISTRY")) + pathtype="REGISTRY", + ) + ) self.assertEqual(fd.value, 42) self.assertEqual(fd.Stat().registry_data.GetValue(), 42) @@ -109,14 +123,18 @@ def testReadMaxDwordAsString(self): fd = vfs.VFSOpen( rdf_paths.PathSpec( path=r"/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/maxdword", - pathtype="REGISTRY")) + pathtype="REGISTRY", + ) + ) self.assertEqual(fd.Read(-1), b"4294967295") def testListNamesDoesNotListKeyAndValueOfSameNameTwice(self): fd = vfs.VFSOpen( rdf_paths.PathSpec( 
path=r"/HKEY_LOCAL_MACHINE/SOFTWARE/GRR_TEST/listnametest", - pathtype="REGISTRY")) + pathtype="REGISTRY", + ) + ) self.assertCountEqual(fd.ListNames(), ["bar", "baz"]) diff --git a/grr/client/grr_response_client/vfs_handlers/sleuthkit.py b/grr/client/grr_response_client/vfs_handlers/sleuthkit.py index 23c5a94e3d..f70d2429a1 100644 --- a/grr/client/grr_response_client/vfs_handlers/sleuthkit.py +++ b/grr/client/grr_response_client/vfs_handlers/sleuthkit.py @@ -56,7 +56,7 @@ def get_size(self): # pylint: disable=g-bad-name # Windows is unable to report the true size of the raw device and allows # arbitrary reading past the end - so we lie here to force tsk to read it # anyway - return 10 ** 12 + return 10**12 class TSKFile(vfs_base.VFSHandler): @@ -88,9 +88,7 @@ class TSKFile(vfs_base.VFSHandler): } # Files we won't return in directories. - _IGNORE_FILES = [ - "$OrphanFiles" # Special TSK dir that invokes processing. - ] + _IGNORE_FILES = ["$OrphanFiles"] # Special TSK dir that invokes processing. # The file like object we read our image from tsk_raw_device = None @@ -101,8 +99,14 @@ class TSKFile(vfs_base.VFSHandler): # This is all bits that define the type of the file in the stat mode. Equal to # 0b1111000000000000. stat_type_mask = ( - stat.S_IFREG | stat.S_IFDIR | stat.S_IFLNK | stat.S_IFBLK - | stat.S_IFCHR | stat.S_IFIFO | stat.S_IFSOCK) + stat.S_IFREG + | stat.S_IFDIR + | stat.S_IFLNK + | stat.S_IFBLK + | stat.S_IFCHR + | stat.S_IFIFO + | stat.S_IFSOCK + ) def __init__(self, base_fd, handlers, pathspec=None, progress_callback=None): """Use TSK to read the pathspec. @@ -122,7 +126,8 @@ def __init__(self, base_fd, handlers, pathspec=None, progress_callback=None): base_fd, handlers=handlers, pathspec=pathspec, - progress_callback=progress_callback) + progress_callback=progress_callback, + ) if self.base_fd is None: raise IOError("TSK driver must have a file base.") @@ -159,7 +164,8 @@ def __init__(self, base_fd, handlers, pathspec=None, progress_callback=None): self.fs = self.filesystem.fs except KeyError: self.img = MyImgInfo( - fd=self.tsk_raw_device, progress_callback=progress_callback) + fd=self.tsk_raw_device, progress_callback=progress_callback + ) self.fs = pytsk3.FS_Info(self.img, 0) self.filesystem = CachedFilesystem(self.fs, self.img) @@ -173,8 +179,9 @@ def __init__(self, base_fd, handlers, pathspec=None, progress_callback=None): # NTFS_ID is only required when reading ADSs. If it's not provided, we # just get the first attribute with matching type. 
if pathspec.HasField("ntfs_id"): - self.tsk_attribute = self.GetAttribute(pathspec.ntfs_type, - pathspec.ntfs_id) + self.tsk_attribute = self.GetAttribute( + pathspec.ntfs_type, pathspec.ntfs_id + ) else: self.tsk_attribute = self.GetAttribute(pathspec.ntfs_type) @@ -247,7 +254,14 @@ def MakeStatResponse(self, tsk_file, tsk_attribute=None, append_name=None): if meta: response.st_ino = meta.addr for attribute in [ - "mode", "nlink", "uid", "gid", "size", "atime", "mtime", "ctime", + "mode", + "nlink", + "uid", + "gid", + "size", + "atime", + "mtime", + "ctime", ]: try: value = int(getattr(meta, attribute)) @@ -269,8 +283,9 @@ def MakeStatResponse(self, tsk_file, tsk_attribute=None, append_name=None): if append_name is not None: # Append the name to the most inner pathspec - child_pathspec.last.path = utils.JoinPath(child_pathspec.last.path, - append_name) + child_pathspec.last.path = utils.JoinPath( + child_pathspec.last.path, append_name + ) child_pathspec.last.inode = meta.addr if tsk_attribute is not None: @@ -318,12 +333,16 @@ def Read(self, length): # NTFS_ID is only required when reading ADSs. If it's is not provided, # we just let pytsk use the default. if self.pathspec.last.HasField("ntfs_id"): - data = self.fd.read_random(self.offset, available, - self.pathspec.last.ntfs_type, - self.pathspec.last.ntfs_id) + data = self.fd.read_random( + self.offset, + available, + self.pathspec.last.ntfs_type, + self.pathspec.last.ntfs_id, + ) else: - data = self.fd.read_random(self.offset, available, - self.pathspec.last.ntfs_type) + data = self.fd.read_random( + self.offset, available, self.pathspec.last.ntfs_type + ) except RuntimeError as e: raise IOError(e) @@ -361,11 +380,13 @@ def ListFiles(self, ext_attrs=None): # Now send back additional named attributes for the ADS. for attribute in f: if attribute.info.type in [ - pytsk3.TSK_FS_ATTR_TYPE_NTFS_DATA, pytsk3.TSK_FS_ATTR_TYPE_DEFAULT + pytsk3.TSK_FS_ATTR_TYPE_NTFS_DATA, + pytsk3.TSK_FS_ATTR_TYPE_DEFAULT, ]: if attribute.info.name: yield self.MakeStatResponse( - f, append_name=name, tsk_attribute=attribute) + f, append_name=name, tsk_attribute=attribute + ) except AttributeError: pass @@ -417,7 +438,8 @@ def Open(cls, fd, component, handlers, pathspec=None, progress_callback=None): # If an inode is specified, just use it directly. elif component.HasField("inode"): return TSKFile( - fd, handlers, component, progress_callback=progress_callback) + fd, handlers, component, progress_callback=progress_callback + ) # Otherwise do the usual case folding. 
else: @@ -426,4 +448,5 @@ def Open(cls, fd, component, handlers, pathspec=None, progress_callback=None): component=component, handlers=handlers, pathspec=pathspec, - progress_callback=progress_callback) + progress_callback=progress_callback, + ) diff --git a/grr/client/grr_response_client/vfs_handlers/sleuthkit_test.py b/grr/client/grr_response_client/vfs_handlers/sleuthkit_test.py index 3715d7ed1b..e67cc0fe1d 100644 --- a/grr/client/grr_response_client/vfs_handlers/sleuthkit_test.py +++ b/grr/client/grr_response_client/vfs_handlers/sleuthkit_test.py @@ -4,6 +4,7 @@ import collections import platform import unittest + from absl import app from absl.testing import absltest @@ -19,19 +20,20 @@ class TSKWindowsGlobbingTest(vfs_test_lib.VfsTestCase, absltest.TestCase): def testGlobComponentGenerateWorksWithTSK(self): opts = globbing.PathOpts(pathtype=rdf_paths.PathSpec.PathType.TSK) - paths = globbing.GlobComponent(u"Windows", opts=opts).Generate("C:\\") - self.assertEqual(list(paths), [u"C:\\Windows"]) + paths = globbing.GlobComponent("Windows", opts=opts).Generate("C:\\") + self.assertEqual(list(paths), ["C:\\Windows"]) def testGlobbingExpandPathWorksWithTSK(self): opts = globbing.PathOpts(pathtype=rdf_paths.PathSpec.PathType.TSK) paths = globbing.ExpandPath("C:/Windows/System32/notepad.exe", opts=opts) - self.assertEqual(list(paths), [u"C:\\Windows\\System32\\notepad.exe"]) + self.assertEqual(list(paths), ["C:\\Windows\\System32\\notepad.exe"]) def testListNamesNoDuplicates(self): fd = vfs.VFSOpen( rdf_paths.PathSpec( - path="C:/Windows/System32", - pathtype=rdf_paths.PathSpec.PathType.TSK)) + path="C:/Windows/System32", pathtype=rdf_paths.PathSpec.PathType.TSK + ) + ) names = fd.ListNames() counts = collections.Counter(names) duplicates = [(name, count) for name, count in counts.items() if count > 1] diff --git a/grr/client/grr_response_client/vfs_handlers/tsk_test_lib.py b/grr/client/grr_response_client/vfs_handlers/tsk_test_lib.py index 334c4d685a..9c799050e1 100644 --- a/grr/client/grr_response_client/vfs_handlers/tsk_test_lib.py +++ b/grr/client/grr_response_client/vfs_handlers/tsk_test_lib.py @@ -14,5 +14,6 @@ def _FileRefToInode(self, file_ref: int) -> int: return file_ref & ~(0xFFFF << 48) def _ExpectedStatEntry( - self, st: rdf_client_fs.StatEntry) -> rdf_client_fs.StatEntry: + self, st: rdf_client_fs.StatEntry + ) -> rdf_client_fs.StatEntry: return st diff --git a/grr/client/grr_response_client/vfs_test.py b/grr/client/grr_response_client/vfs_test.py index 435323ef2f..85cd1671e0 100644 --- a/grr/client/grr_response_client/vfs_test.py +++ b/grr/client/grr_response_client/vfs_test.py @@ -22,7 +22,7 @@ def _CreateNestedPathSpec( self, path: str, implementation_type: Optional[rdf_structs.EnumNamedValue], - path_options: Optional[rdf_structs.EnumNamedValue] = None + path_options: Optional[rdf_structs.EnumNamedValue] = None, ) -> rdf_paths.PathSpec: ntfs_img_path = os.path.join(config.CONFIG["Test.data_dir"], "ntfs.img") @@ -35,11 +35,14 @@ def _CreateNestedPathSpec( path=path, pathtype=rdf_paths.PathSpec.PathType.NTFS, path_options=path_options, - )) + ), + ) def _CheckHasImplementationType( - self, pathspec: rdf_paths.PathSpec, - implementation_type: rdf_paths.PathSpec.ImplementationType) -> None: + self, + pathspec: rdf_paths.PathSpec, + implementation_type: rdf_paths.PathSpec.ImplementationType, + ) -> None: if implementation_type is None: self.assertFalse(pathspec.HasField("implementation_type")) else: @@ -49,8 +52,10 @@ def _CheckHasImplementationType( 
self.assertFalse(component.HasField("implementation_type")) def _OpenAndCheckImplementationType( - self, pathspec: rdf_paths.PathSpec, - implementation_type: rdf_paths.PathSpec.ImplementationType) -> None: + self, + pathspec: rdf_paths.PathSpec, + implementation_type: rdf_paths.PathSpec.ImplementationType, + ) -> None: with vfs.VFSOpen(pathspec) as f: self._CheckHasImplementationType(f.pathspec, implementation_type) self._CheckHasImplementationType(f.Stat().pathspec, implementation_type) @@ -67,49 +72,64 @@ def testVfsOpen_default_nestedPath(self): def testVfsOpen_direct_nestedPath(self): pathspec = self._CreateNestedPathSpec( - "/", rdf_paths.PathSpec.ImplementationType.DIRECT) + "/", rdf_paths.PathSpec.ImplementationType.DIRECT + ) self._OpenAndCheckImplementationType( - pathspec, rdf_paths.PathSpec.ImplementationType.DIRECT) + pathspec, rdf_paths.PathSpec.ImplementationType.DIRECT + ) def testVfsOpen_direct_caseLiteral_nestedPath(self): pathspec = self._CreateNestedPathSpec( - "/", rdf_paths.PathSpec.ImplementationType.DIRECT, - rdf_paths.PathSpec.Options.CASE_LITERAL) + "/", + rdf_paths.PathSpec.ImplementationType.DIRECT, + rdf_paths.PathSpec.Options.CASE_LITERAL, + ) self._OpenAndCheckImplementationType( - pathspec, rdf_paths.PathSpec.ImplementationType.DIRECT) + pathspec, rdf_paths.PathSpec.ImplementationType.DIRECT + ) def testVfsOpen_sandbox_nestedPath(self): pathspec = self._CreateNestedPathSpec( - "/", rdf_paths.PathSpec.ImplementationType.SANDBOX) + "/", rdf_paths.PathSpec.ImplementationType.SANDBOX + ) self._OpenAndCheckImplementationType( - pathspec, rdf_paths.PathSpec.ImplementationType.SANDBOX) + pathspec, rdf_paths.PathSpec.ImplementationType.SANDBOX + ) def testVfsOpen_default_rawPath(self): with mock.patch.object( - client_utils, "GetRawDevice", new=self._MockGetRawDevice): + client_utils, "GetRawDevice", new=self._MockGetRawDevice + ): pathspec = rdf_paths.PathSpec( - path="/", pathtype=rdf_paths.PathSpec.PathType.NTFS) + path="/", pathtype=rdf_paths.PathSpec.PathType.NTFS + ) self._OpenAndCheckImplementationType(pathspec, None) def testVfsOpen_direct_rawPath(self): with mock.patch.object( - client_utils, "GetRawDevice", new=self._MockGetRawDevice): + client_utils, "GetRawDevice", new=self._MockGetRawDevice + ): pathspec = rdf_paths.PathSpec( path="/", pathtype=rdf_paths.PathSpec.PathType.NTFS, - implementation_type=rdf_paths.PathSpec.ImplementationType.DIRECT) + implementation_type=rdf_paths.PathSpec.ImplementationType.DIRECT, + ) self._OpenAndCheckImplementationType( - pathspec, rdf_paths.PathSpec.ImplementationType.DIRECT) + pathspec, rdf_paths.PathSpec.ImplementationType.DIRECT + ) def testVfsOpen_sandbox_rawPath(self): with mock.patch.object( - client_utils, "GetRawDevice", new=self._MockGetRawDevice): + client_utils, "GetRawDevice", new=self._MockGetRawDevice + ): pathspec = rdf_paths.PathSpec( path="/", pathtype=rdf_paths.PathSpec.PathType.NTFS, - implementation_type=rdf_paths.PathSpec.ImplementationType.SANDBOX) + implementation_type=rdf_paths.PathSpec.ImplementationType.SANDBOX, + ) self._OpenAndCheckImplementationType( - pathspec, rdf_paths.PathSpec.ImplementationType.SANDBOX) + pathspec, rdf_paths.PathSpec.ImplementationType.SANDBOX + ) def setUpModule() -> None: diff --git a/grr/client/grr_response_client/windows/__init__.py b/grr/client/grr_response_client/windows/__init__.py index 42da1790a5..482c39f5d5 100644 --- a/grr/client/grr_response_client/windows/__init__.py +++ b/grr/client/grr_response_client/windows/__init__.py @@ -1,3 +1,2 @@ #!/usr/bin/env 
python """Client windows-specific module root.""" - diff --git a/grr/client/grr_response_client/windows/installers.py b/grr/client/grr_response_client/windows/installers.py index 8b14f3bf70..af87670ac0 100644 --- a/grr/client/grr_response_client/windows/installers.py +++ b/grr/client/grr_response_client/windows/installers.py @@ -43,9 +43,11 @@ flags.DEFINE_string( - "interpolate_fleetspeak_service_config", "", + "interpolate_fleetspeak_service_config", + "", "If set, only interpolate a fleetspeak service config. " - "The value is a path to a file to interpolate (rewrite).") + "The value is a path to a file to interpolate (rewrite).", +) def _StartService(service_name): @@ -60,8 +62,9 @@ def _StartService(service_name): logging.info("Service '%s' started.", service_name) except pywintypes.error as e: if getattr(e, "winerror", None) == winerror.ERROR_SERVICE_DOES_NOT_EXIST: - logging.debug("Tried to start '%s', but the service is not installed.", - service_name) + logging.debug( + "Tried to start '%s', but the service is not installed.", service_name + ) else: logging.exception("Encountered error trying to start '%s':", service_name) @@ -89,8 +92,9 @@ def _StopService(service_name, service_binary_name=None): status = win32serviceutil.QueryServiceStatus(service_name)[1] except pywintypes.error as e: if getattr(e, "winerror", None) == winerror.ERROR_SERVICE_DOES_NOT_EXIST: - logging.debug("Tried to stop '%s', but the service is not installed.", - service_name) + logging.debug( + "Tried to stop '%s', but the service is not installed.", service_name + ) else: logging.exception("Unable to query status of service '%s':", service_name) return @@ -119,7 +123,8 @@ def _StopService(service_name, service_binary_name=None): ["taskkill", "/im", "%s*" % service_binary_name, "/f"], shell=True, stdin=subprocess.PIPE, - stderr=subprocess.PIPE) + stderr=subprocess.PIPE, + ) logging.debug("%s", output) @@ -135,47 +140,59 @@ def _RemoveService(service_name): logging.info("Service '%s' removed.", service_name) except pywintypes.error as e: if getattr(e, "winerror", None) == winerror.ERROR_SERVICE_DOES_NOT_EXIST: - logging.debug("Tried to remove '%s', but the service is not installed.", - service_name) + logging.debug( + "Tried to remove '%s', but the service is not installed.", + service_name, + ) else: logging.exception("Unable to remove service '%s':", service_name) -def _CreateService(service_name: str, description: str, - command_line: str) -> None: +def _CreateService( + service_name: str, + description: str, + command_line: str, +) -> None: """Creates a Windows service.""" logging.info("Creating service '%s'.", service_name) with contextlib.ExitStack() as stack: - hscm = win32service.OpenSCManager(None, None, - win32service.SC_MANAGER_ALL_ACCESS) + hscm = win32service.OpenSCManager( + None, None, win32service.SC_MANAGER_ALL_ACCESS + ) stack.callback(win32service.CloseServiceHandle, hscm) - hs = win32service.CreateService(hscm, service_name, service_name, - win32service.SERVICE_ALL_ACCESS, - win32service.SERVICE_WIN32_OWN_PROCESS, - win32service.SERVICE_AUTO_START, - win32service.SERVICE_ERROR_NORMAL, - command_line, None, 0, None, None, None) + hs = win32service.CreateService( + hscm, + service_name, + service_name, + win32service.SERVICE_ALL_ACCESS, + win32service.SERVICE_WIN32_OWN_PROCESS, + win32service.SERVICE_AUTO_START, + win32service.SERVICE_ERROR_NORMAL, + command_line, + None, + 0, + None, + None, + None, + ) stack.callback(win32service.CloseServiceHandle, hs) service_failure_actions = { - 
"ResetPeriod": - SERVICE_RESET_FAIL_COUNT_DELAY_SEC, - "RebootMsg": - u"", - "Command": - u"", + "ResetPeriod": SERVICE_RESET_FAIL_COUNT_DELAY_SEC, + "RebootMsg": "", + "Command": "", "Actions": [ (win32service.SC_ACTION_RESTART, SERVICE_RESTART_DELAY_MSEC), (win32service.SC_ACTION_RESTART, SERVICE_RESTART_DELAY_MSEC), (win32service.SC_ACTION_RESTART, SERVICE_RESTART_DELAY_MSEC), - ] + ], } win32service.ChangeServiceConfig2( - hs, win32service.SERVICE_CONFIG_FAILURE_ACTIONS, - service_failure_actions) - win32service.ChangeServiceConfig2(hs, - win32service.SERVICE_CONFIG_DESCRIPTION, - description) + hs, win32service.SERVICE_CONFIG_FAILURE_ACTIONS, service_failure_actions + ) + win32service.ChangeServiceConfig2( + hs, win32service.SERVICE_CONFIG_DESCRIPTION, description + ) logging.info("Successfully created service '%s'.", service_name) @@ -190,15 +207,18 @@ def _OpenRegkey(key_path): def _CheckForWow64(): """Checks to ensure we are not running on a Wow64 system.""" if win32process.IsWow64Process(): - raise RuntimeError("Will not install a 32 bit client on a 64 bit system. " - "Please use the correct client.") + raise RuntimeError( + "Will not install a 32 bit client on a 64 bit system. " + "Please use the correct client." + ) def _StopPreviousService(): """Stops the Windows service hosting the GRR process.""" _StopService( service_name=config.CONFIG["Nanny.service_name"], - service_binary_name=config.CONFIG["Nanny.service_binary_name"]) + service_binary_name=config.CONFIG["Nanny.service_binary_name"], + ) _StopService(service_name=config.CONFIG["Client.fleetspeak_service_name"]) @@ -212,8 +232,11 @@ def _DeleteGrrFleetspeakService(): regkey = _OpenRegkey(key_path) try: winreg.DeleteValue(regkey, config.CONFIG["Client.name"]) - logging.info("Deleted value '%s' of key '%s'.", - config.CONFIG["Client.name"], key_path) + logging.info( + "Deleted value '%s' of key '%s'.", + config.CONFIG["Client.name"], + key_path, + ) except OSError as e: # Windows will raise a no-such-file-or-directory error if # GRR's config hasn't been written to the registry yet. @@ -295,8 +318,11 @@ def _CopyToSystemDir(): """ executable_directory = os.path.dirname(sys.executable) install_path = config.CONFIG["Client.install_path"] - logging.info("Installing binaries %s -> %s", executable_directory, - config.CONFIG["Client.install_path"]) + logging.info( + "Installing binaries %s -> %s", + executable_directory, + config.CONFIG["Client.install_path"], + ) # Recursively copy the temp directory to the installation directory. for root, dirs, files in os.walk(executable_directory): @@ -331,20 +357,25 @@ def _CopyToSystemDir(): # Options for the legacy (non-Fleetspeak) GRR installation that should get # deleted when installing Fleetspeak-enabled GRR clients. 
_LEGACY_OPTIONS = frozenset( - itertools.chain(_NANNY_OPTIONS, - ["Nanny.status", "Nanny.heartbeat", "Client.labels"])) + itertools.chain( + _NANNY_OPTIONS, ["Nanny.status", "Nanny.heartbeat", "Client.labels"] + ) +) def _DeleteLegacyConfigOptions(registry_key_uri): """Deletes config values in the registry for legacy GRR installations.""" key_spec = regconfig.ParseRegistryURI(registry_key_uri) try: - regkey = winreg.OpenKeyEx(key_spec.winreg_hive, key_spec.path, 0, - winreg.KEY_ALL_ACCESS) + regkey = winreg.OpenKeyEx( + key_spec.winreg_hive, key_spec.path, 0, winreg.KEY_ALL_ACCESS + ) except OSError as e: if e.errno == errno.ENOENT: - logging.info("Skipping legacy config purge for non-existent key: %s.", - registry_key_uri) + logging.info( + "Skipping legacy config purge for non-existent key: %s.", + registry_key_uri, + ) return else: raise @@ -362,32 +393,39 @@ def _DeleteLegacyConfigOptions(registry_key_uri): def _IsFleetspeakBundled(): return os.path.exists( - os.path.join(config.CONFIG["Client.install_path"], - "fleetspeak-client.exe")) + os.path.join( + config.CONFIG["Client.install_path"], "fleetspeak-client.exe" + ) + ) def _InstallBundledFleetspeak(): - fleetspeak_client = os.path.join(config.CONFIG["Client.install_path"], - "fleetspeak-client.exe") - fleetspeak_config = os.path.join(config.CONFIG["Client.install_path"], - "fleetspeak-client.config") + fleetspeak_client = os.path.join( + config.CONFIG["Client.install_path"], "fleetspeak-client.exe" + ) + fleetspeak_config = os.path.join( + config.CONFIG["Client.install_path"], "fleetspeak-client.config" + ) _RemoveService(config.CONFIG["Client.fleetspeak_service_name"]) _CreateService( service_name=config.CONFIG["Client.fleetspeak_service_name"], description="Fleetspeak communication agent.", - command_line=f"\"{fleetspeak_client}\" -config \"{fleetspeak_config}\"") + command_line=f'"{fleetspeak_client}" -config "{fleetspeak_config}"', + ) def _MaybeInterpolateFleetspeakServiceConfig(): """Interpolates the fleetspeak service config if present.""" fleetspeak_unsigned_config_path = os.path.join( config.CONFIG["Client.install_path"], - config.CONFIG["Client.fleetspeak_unsigned_config_fname"]) + config.CONFIG["Client.fleetspeak_unsigned_config_fname"], + ) template_path = f"{fleetspeak_unsigned_config_path}.in" if not os.path.exists(template_path): return - _InterpolateFleetspeakServiceConfig(template_path, - fleetspeak_unsigned_config_path) + _InterpolateFleetspeakServiceConfig( + template_path, fleetspeak_unsigned_config_path + ) def _InterpolateFleetspeakServiceConfig(src_path: str, dst_path: str) -> None: @@ -407,9 +445,15 @@ def _WriteGrrFleetspeakService(): regkey = _OpenRegkey(key_path) fleetspeak_unsigned_config_path = os.path.join( config.CONFIG["Client.install_path"], - config.CONFIG["Client.fleetspeak_unsigned_config_fname"]) - winreg.SetValueEx(regkey, config.CONFIG["Client.name"], 0, winreg.REG_SZ, - fleetspeak_unsigned_config_path) + config.CONFIG["Client.fleetspeak_unsigned_config_fname"], + ) + winreg.SetValueEx( + regkey, + config.CONFIG["Client.name"], + 0, + winreg.REG_SZ, + fleetspeak_unsigned_config_path, + ) def _Run(): @@ -418,7 +462,8 @@ def _Run(): if flags.FLAGS.interpolate_fleetspeak_service_config: _InterpolateFleetspeakServiceConfig( flags.FLAGS.interpolate_fleetspeak_service_config, - flags.FLAGS.interpolate_fleetspeak_service_config) + flags.FLAGS.interpolate_fleetspeak_service_config, + ) fs_service = config.CONFIG["Client.fleetspeak_service_name"] _StopService(service_name=fs_service) 
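(Aside.) The _DeleteLegacyConfigOptions helper above boils down to standard winreg calls: open the client's configuration key with KEY_ALL_ACCESS and delete the values named in _LEGACY_OPTIONS, ignoring values that were never written. A stripped-down, Windows-only sketch; the key path is illustrative, mirroring the GRR_TEST hives used in the registry tests earlier in this diff:

    import winreg

    KEY_PATH = r"SOFTWARE\GRR_TEST"  # illustrative; the real path comes from config
    key = winreg.OpenKeyEx(
        winreg.HKEY_LOCAL_MACHINE, KEY_PATH, 0, winreg.KEY_ALL_ACCESS
    )
    for option in ("Nanny.status", "Nanny.heartbeat", "Client.labels"):
      try:
        winreg.DeleteValue(key, option)
      except FileNotFoundError:
        pass  # Value was never written, nothing to purge.
    winreg.CloseKey(key)
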
_StartService(service_name=fs_service) diff --git a/grr/client/grr_response_client/windows/installers_test.py b/grr/client/grr_response_client/windows/installers_test.py index 92680e9603..16f8818a4a 100644 --- a/grr/client/grr_response_client/windows/installers_test.py +++ b/grr/client/grr_response_client/windows/installers_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import platform import unittest from unittest import mock @@ -33,8 +32,9 @@ def _GetAllRegistryKeyValues(key): return values -@unittest.skipIf(platform.system() != "Windows", - "Windows-only functionality being tested.") +@unittest.skipIf( + platform.system() != "Windows", "Windows-only functionality being tested." +) class InstallerTest(absltest.TestCase): @classmethod @@ -47,17 +47,20 @@ def setUpClass(cls): def tearDownClass(cls): super(InstallerTest, cls).tearDownClass() - winreg.DeleteKeyEx(winreg.HKEY_LOCAL_MACHINE, _TEST_KEY_PATH, - winreg.KEY_ALL_ACCESS, 0) + winreg.DeleteKeyEx( + winreg.HKEY_LOCAL_MACHINE, _TEST_KEY_PATH, winreg.KEY_ALL_ACCESS, 0 + ) @mock.patch.object(installers, "_LEGACY_OPTIONS", frozenset(["bar"])) def testDeleteLegacyConfigOptions(self): - key = winreg.OpenKeyEx(winreg.HKEY_LOCAL_MACHINE, _TEST_KEY_PATH, 0, - winreg.KEY_ALL_ACCESS) + key = winreg.OpenKeyEx( + winreg.HKEY_LOCAL_MACHINE, _TEST_KEY_PATH, 0, winreg.KEY_ALL_ACCESS + ) winreg.SetValueEx(key, "foo", 0, winreg.REG_SZ, "foo-value") winreg.SetValueEx(key, "bar", 0, winreg.REG_SZ, "bar-value") installers._DeleteLegacyConfigOptions( - "reg://HKEY_LOCAL_MACHINE/{}".format(_TEST_KEY_PATH)) + "reg://HKEY_LOCAL_MACHINE/{}".format(_TEST_KEY_PATH) + ) remaining_values = _GetAllRegistryKeyValues(key) self.assertDictEqual(remaining_values, {"foo": "foo-value"}) diff --git a/grr/client/grr_response_client/windows/process.py b/grr/client/grr_response_client/windows/process.py index 8bd6bab700..9e7e61a160 100644 --- a/grr/client/grr_response_client/windows/process.py +++ b/grr/client/grr_response_client/windows/process.py @@ -3,7 +3,6 @@ This code is based on the memorpy project: https://github.com/n1nj4sec/memorpy - """ import ctypes @@ -46,37 +45,44 @@ class SECURITY_DESCRIPTOR(ctypes.Structure): - _fields_ = [("SID", wintypes.DWORD), ("group", wintypes.DWORD), - ("dacl", wintypes.DWORD), ("sacl", wintypes.DWORD), - ("test", wintypes.DWORD)] + _fields_ = [ + ("SID", wintypes.DWORD), + ("group", wintypes.DWORD), + ("dacl", wintypes.DWORD), + ("sacl", wintypes.DWORD), + ("test", wintypes.DWORD), + ] PSECURITY_DESCRIPTOR = ctypes.POINTER(SECURITY_DESCRIPTOR) class SYSTEM_INFO(ctypes.Structure): - _fields_ = [("wProcessorArchitecture", - wintypes.WORD), ("wReserved", wintypes.WORD), ("dwPageSize", - wintypes.DWORD), - ("lpMinimumApplicationAddress", - wintypes.LPVOID), ("lpMaximumApplicationAddress", - wintypes.LPVOID), ("dwActiveProcessorMask", - wintypes.WPARAM), - ("dwNumberOfProcessors", wintypes.DWORD), ("dwProcessorType", - wintypes.DWORD), - ("dwAllocationGranularity", - wintypes.DWORD), ("wProcessorLevel", - wintypes.WORD), ("wProcessorRevision", - wintypes.WORD)] + _fields_ = [ + ("wProcessorArchitecture", wintypes.WORD), + ("wReserved", wintypes.WORD), + ("dwPageSize", wintypes.DWORD), + ("lpMinimumApplicationAddress", wintypes.LPVOID), + ("lpMaximumApplicationAddress", wintypes.LPVOID), + ("dwActiveProcessorMask", wintypes.WPARAM), + ("dwNumberOfProcessors", wintypes.DWORD), + ("dwProcessorType", wintypes.DWORD), + ("dwAllocationGranularity", wintypes.DWORD), + ("wProcessorLevel", wintypes.WORD), + ("wProcessorRevision", wintypes.WORD), + 
] class MEMORY_BASIC_INFORMATION(ctypes.Structure): - _fields_ = [("BaseAddress", ctypes.c_void_p), ("AllocationBase", - ctypes.c_void_p), - ("AllocationProtect", - wintypes.DWORD), ("RegionSize", - ctypes.c_size_t), ("State", wintypes.DWORD), - ("Protect", wintypes.DWORD), ("Type", wintypes.DWORD)] + _fields_ = [ + ("BaseAddress", ctypes.c_void_p), + ("AllocationBase", ctypes.c_void_p), + ("AllocationProtect", wintypes.DWORD), + ("RegionSize", ctypes.c_size_t), + ("State", wintypes.DWORD), + ("Protect", wintypes.DWORD), + ("Type", wintypes.DWORD), + ] CloseHandle = kernel32.CloseHandle @@ -84,14 +90,19 @@ class MEMORY_BASIC_INFORMATION(ctypes.Structure): ReadProcessMemory = kernel32.ReadProcessMemory ReadProcessMemory.argtypes = [ - wintypes.HANDLE, wintypes.LPCVOID, wintypes.LPVOID, ctypes.c_size_t, - ctypes.POINTER(ctypes.c_size_t) + wintypes.HANDLE, + wintypes.LPCVOID, + wintypes.LPVOID, + ctypes.c_size_t, + ctypes.POINTER(ctypes.c_size_t), ] VirtualQueryEx = kernel32.VirtualQueryEx VirtualQueryEx.argtypes = [ - wintypes.HANDLE, wintypes.LPCVOID, - ctypes.POINTER(MEMORY_BASIC_INFORMATION), ctypes.c_size_t + wintypes.HANDLE, + wintypes.LPCVOID, + ctypes.POINTER(MEMORY_BASIC_INFORMATION), + ctypes.c_size_t, ] VirtualQueryEx.restype = ctypes.c_size_t @@ -138,10 +149,12 @@ def Open(self): self.h_process = self._existing_handle else: self.h_process = kernel32.OpenProcess( - PROCESS_VM_READ | PROCESS_QUERY_INFORMATION, 0, self.pid) + PROCESS_VM_READ | PROCESS_QUERY_INFORMATION, 0, self.pid + ) if not self.h_process: raise process_error.ProcessError( - "Failed to open process (pid %d)." % self.pid) + "Failed to open process (pid %d)." % self.pid + ) if self.Is64bit(): si = self.GetNativeSystemInfo() @@ -176,16 +189,19 @@ def GetNativeSystemInfo(self): def VirtualQueryEx(self, address): mbi = MEMORY_BASIC_INFORMATION() - res = VirtualQueryEx(self.h_process, address, ctypes.byref(mbi), - ctypes.sizeof(mbi)) + res = VirtualQueryEx( + self.h_process, address, ctypes.byref(mbi), ctypes.sizeof(mbi) + ) if not res: raise process_error.ProcessError("Error VirtualQueryEx: 0x%08X" % address) return mbi - def Regions(self, - skip_special_regions=False, - skip_executable_regions=False, - skip_readonly_regions=False): + def Regions( + self, + skip_special_regions=False, + skip_executable_regions=False, + skip_readonly_regions=False, + ): """Returns an iterator over the readable regions for this process.""" offset = self.min_addr @@ -202,22 +218,40 @@ def Regions(self, start=offset, size=mbi.RegionSize, is_readable=True, - is_writable=bool(protect - & (PAGE_EXECUTE_READWRITE | PAGE_READWRITE - | PAGE_EXECUTE_WRITECOPY | PAGE_WRITECOPY)), - is_executable=bool(protect & (PAGE_EXECUTE | PAGE_EXECUTE_READ - | PAGE_EXECUTE_READWRITE - | PAGE_EXECUTE_WRITECOPY))) + is_writable=bool( + protect + & ( + PAGE_EXECUTE_READWRITE + | PAGE_READWRITE + | PAGE_EXECUTE_WRITECOPY + | PAGE_WRITECOPY + ) + ), + is_executable=bool( + protect + & ( + PAGE_EXECUTE + | PAGE_EXECUTE_READ + | PAGE_EXECUTE_READWRITE + | PAGE_EXECUTE_WRITECOPY + ) + ), + ) is_special = ( - protect & PAGE_NOCACHE or protect & PAGE_WRITECOMBINE or - protect & PAGE_GUARD) + protect & PAGE_NOCACHE + or protect & PAGE_WRITECOMBINE + or protect & PAGE_GUARD + ) offset += chunk - if (state & MEM_FREE or state & MEM_RESERVE or - (skip_special_regions and is_special) or - (skip_executable_regions and region.is_executable) or - (skip_readonly_regions and not region.is_writable)): + if ( + state & MEM_FREE + or state & MEM_RESERVE + or (skip_special_regions and 
is_special) + or (skip_executable_regions and region.is_executable) + or (skip_readonly_regions and not region.is_writable) + ): continue yield region @@ -227,16 +261,17 @@ def ReadBytes(self, address, num_bytes): address = int(address) buf = ctypes.create_string_buffer(num_bytes) bytesread = ctypes.c_size_t(0) - res = ReadProcessMemory(self.h_process, address, buf, num_bytes, - ctypes.byref(bytesread)) + res = ReadProcessMemory( + self.h_process, address, buf, num_bytes, ctypes.byref(bytesread) + ) if res == 0: err = ctypes.GetLastError() if err == 299: # Only part of ReadProcessMemory has been done, let's return it. - return buf.raw[:bytesread.value] + return buf.raw[: bytesread.value] raise process_error.ProcessError("Error in ReadProcessMemory: %d" % err) - return buf.raw[:bytesread.value] + return buf.raw[: bytesread.value] @property def serialized_file_descriptor(self) -> int: @@ -244,5 +279,6 @@ def serialized_file_descriptor(self) -> int: @classmethod def CreateFromSerializedFileDescriptor( - cls, serialized_file_descriptor: int) -> "Process": + cls, serialized_file_descriptor: int + ) -> "Process": return Process(handle=serialized_file_descriptor) diff --git a/grr/client/grr_response_client/windows/regconfig.py b/grr/client/grr_response_client/windows/regconfig.py index a32af07cc7..28e6059953 100644 --- a/grr/client/grr_response_client/windows/regconfig.py +++ b/grr/client/grr_response_client/windows/regconfig.py @@ -12,7 +12,6 @@ import logging from typing import Any, Dict, Text from urllib import parse as urlparse - import winreg from grr_response_core.lib import config_parser @@ -20,7 +19,8 @@ class RegistryKeySpec( - collections.namedtuple("RegistryKey", ["hive", "winreg_hive", "path"])): + collections.namedtuple("RegistryKey", ["hive", "winreg_hive", "path"]) +): __slots__ = () def __str__(self): @@ -32,7 +32,8 @@ def ParseRegistryURI(uri): return RegistryKeySpec( hive=url.netloc, winreg_hive=getattr(winreg, url.netloc), - path=url.path.replace("/", "\\").lstrip("\\")) + path=url.path.replace("/", "\\").lstrip("\\"), + ) class RegistryConfigParser(config_parser.GRRConfigParser): @@ -58,9 +59,12 @@ def __init__(self, config_path: str) -> None: def _AccessRootKey(self): if self._root_key is None: # Don't use winreg.KEY_WOW64_64KEY since it breaks on Windows 2000 - self._root_key = winreg.CreateKeyEx(self._key_spec.winreg_hive, - self._key_spec.path, 0, - winreg.KEY_ALL_ACCESS) + self._root_key = winreg.CreateKeyEx( + self._key_spec.winreg_hive, + self._key_spec.path, + 0, + winreg.KEY_ALL_ACCESS, + ) self.parsed = self._key_spec.path return self._root_key @@ -99,8 +103,9 @@ def SaveData(self, raw_data: Dict[str, Any]) -> None: str_value = value.decode("ascii") else: str_value = str(value) - winreg.SetValueEx(self._AccessRootKey(), key, 0, winreg.REG_SZ, - str_value) + winreg.SetValueEx( + self._AccessRootKey(), key, 0, winreg.REG_SZ, str_value + ) finally: # Make sure changes hit the disk. diff --git a/grr/client/makefile.py b/grr/client/makefile.py index 490649e3d6..1ae018b4ca 100644 --- a/grr/client/makefile.py +++ b/grr/client/makefile.py @@ -9,15 +9,14 @@ parser = argparse.ArgumentParser() parser.add_argument( - "--clean", - action="store_true", - default=False, - help="Clean compiled protos.") + "--clean", action="store_true", default=False, help="Clean compiled protos." 
+) parser.add_argument( "--mypy-protobuf", default="", - help="A path to the mypy protobuf generator plugin.") + help="A path to the mypy protobuf generator plugin.", +) args = parser.parse_args() diff --git a/grr/client/setup.py b/grr/client/setup.py index ebc535a5df..00d824cc19 100644 --- a/grr/client/setup.py +++ b/grr/client/setup.py @@ -49,9 +49,11 @@ def compile_protos(): """Builds necessary assets from sources.""" # Using Popen to effectively suppress the output of the command below - no # need to fill in the logs with protoc's help. - p = subprocess.Popen([sys.executable, "-m", "grpc_tools.protoc", "--help"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + p = subprocess.Popen( + [sys.executable, "-m", "grpc_tools.protoc", "--help"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) p.communicate() # If protoc is not installed, install it. This seems to be the only reliable # way to make sure that grpcio-tools gets intalled, no matter which Python @@ -62,9 +64,9 @@ def compile_protos(): # version. Otherwise latest protobuf library will be installed with # grpcio-tools and then uninstalled when grr-response-proto's setup.py runs # and reinstalled to the version required by grr-response-proto. - subprocess.check_call([ - sys.executable, "-m", "pip", "install", GRPCIO, GRPCIO_TOOLS, PROTOBUF - ]) + subprocess.check_call( + [sys.executable, "-m", "pip", "install", GRPCIO, GRPCIO_TOOLS, PROTOBUF] + ) # If there's no makefile, we're likely installing from an sdist, # so there's no need to compile the protos (they should be already @@ -73,8 +75,9 @@ def compile_protos(): return # Only compile protobufs if we're inside GRR source tree. - subprocess.check_call([sys.executable, "makefile.py", "--clean"], - cwd=THIS_DIRECTORY) + subprocess.check_call( + [sys.executable, "makefile.py", "--clean"], cwd=THIS_DIRECTORY + ) VERSION = get_config() @@ -89,7 +92,8 @@ def make_release_tree(self, base_dir, files): if os.path.exists(sdist_version_ini): os.unlink(sdist_version_ini) shutil.copy( - os.path.join(THIS_DIRECTORY, "../../version.ini"), sdist_version_ini) + os.path.join(THIS_DIRECTORY, "../../version.ini"), sdist_version_ini + ) def run(self): compile_protos() diff --git a/grr/client_builder/grr_response_client_builder/build_helpers.py b/grr/client_builder/grr_response_client_builder/build_helpers.py index 5c7d8d1b41..804e640e20 100644 --- a/grr/client_builder/grr_response_client_builder/build_helpers.py +++ b/grr/client_builder/grr_response_client_builder/build_helpers.py @@ -242,24 +242,10 @@ def ValidateEndConfig(config_obj, errors_fatal=True, context=None): if not url.startswith("http"): errors.append("Bad Client.server_urls specified %s" % url) - certificate = config_obj.GetRaw( - "CA.certificate", default=None, context=context) - if certificate is None or not certificate.startswith("-----BEGIN CERTIF"): - errors.append("CA certificate missing from config.") - - key_data = config_obj.GetRaw( - "Client.executable_signing_public_key", default=None, context=context) - if key_data is None: + if not config_obj.Get( + "Client.executable_signing_public_key", context=context + ): errors.append("Missing Client.executable_signing_public_key.") - elif not key_data.startswith("-----BEGIN PUBLIC"): - errors.append("Invalid Client.executable_signing_public_key: %s" % key_data) - else: - rdf_crypto.RSAPublicKey.FromHumanReadable(key_data) - - for bad_opt in ["Client.private_key"]: - if config_obj.Get(bad_opt, context=context, default=""): - errors.append("Client cert in conf, this should be 
empty at deployment" - " %s" % bad_opt) if errors_fatal and errors: for error in errors: @@ -271,13 +257,16 @@ def ValidateEndConfig(config_obj, errors_fatal=True, context=None): # Config options that have to make it to a deployable binary. _CONFIG_SECTIONS = [ - "CA", "Client", "ClientRepacker", "Logging", "Config", "Nanny", "Osquery", - "Installer", "Template" + "Client", + "ClientRepacker", + "Logging", + "Config", + "Nanny", + "Osquery", + "Installer", + "Template", ] -# Config options that should never make it to a deployable binary. -_SKIP_OPTION_LIST = ["Client.private_key"] - def GetClientConfig(context, validate=True, deploy_timestamp=True): """Generates the client config file for inclusion in deployable binaries.""" @@ -299,9 +288,6 @@ def GetClientConfig(context, validate=True, deploy_timestamp=True): while contexts.CLIENT_BUILD_CONTEXT in client_context: client_context.remove(contexts.CLIENT_BUILD_CONTEXT) for descriptor in sorted(config.CONFIG.type_infos, key=lambda x: x.name): - if descriptor.name in _SKIP_OPTION_LIST: - continue - if descriptor.section in _CONFIG_SECTIONS: value = config.CONFIG.GetRaw( descriptor.name, context=client_context, default=None) diff --git a/grr/client_builder/grr_response_client_builder/repacking.py b/grr/client_builder/grr_response_client_builder/repacking.py index 55261a0bd2..7be901f289 100644 --- a/grr/client_builder/grr_response_client_builder/repacking.py +++ b/grr/client_builder/grr_response_client_builder/repacking.py @@ -2,13 +2,13 @@ """Client repacking library.""" import getpass +import glob import logging import os import platform import sys import zipfile - from grr_response_client_builder import build from grr_response_client_builder import build_helpers from grr_response_client_builder import signing @@ -223,24 +223,29 @@ def RepackTemplate(self, return result_path - def RepackAllTemplates(self, upload=False): - """Repack all the templates in ClientBuilder.template_dir.""" - for template in os.listdir(config.CONFIG["ClientBuilder.template_dir"]): - template_path = os.path.join(config.CONFIG["ClientBuilder.template_dir"], - template) - + def RepackAllTemplates(self, upload: bool = False): + """Repack all the templates in ClientBuilder.template_dir including subfolders.""" + + template_dir = config.CONFIG.Get("ClientBuilder.template_dir", default="") + template_paths = glob.glob( + os.path.join(template_dir, "**/*.zip"), recursive=True + ) + executables_dir = config.CONFIG.Get( + "ClientBuilder.executables_dir", default="." + ) + for template_path in template_paths: self.RepackTemplate( template_path, - os.path.join(config.CONFIG["ClientBuilder.executables_dir"], - "installers"), - upload=upload) + executables_dir, + upload=upload, + ) # If it's windows also repack a debug version. if template_path.endswith(".exe.zip") or template_path.endswith( ".msi.zip"): print("Repacking as debug installer: %s." 
% template_path) self.RepackTemplate( template_path, - os.path.join(config.CONFIG["ClientBuilder.executables_dir"], - "installers"), + executables_dir, upload=upload, - context=["DebugClientBuild Context"]) + context=["DebugClientBuild Context"], + ) diff --git a/grr/client_builder/grr_response_client_builder/repacking_test.py b/grr/client_builder/grr_response_client_builder/repacking_test.py index 49dcfa74cc..289b3e4b00 100644 --- a/grr/client_builder/grr_response_client_builder/repacking_test.py +++ b/grr/client_builder/grr_response_client_builder/repacking_test.py @@ -43,22 +43,18 @@ def testRepackAll(self): with test_lib.ConfigOverrider({ "ClientBuilder.executables_dir": new_dir, - "ClientBuilder.unzipsfx_stub_dir": new_dir }): repacking.TemplateRepacker().RepackAllTemplates() - self.assertEqual( - len(glob.glob(os.path.join(new_dir, "installers/*.deb"))), 2) - self.assertEqual( - len(glob.glob(os.path.join(new_dir, "installers/*.rpm"))), 2) - self.assertEqual( - len(glob.glob(os.path.join(new_dir, "installers/*.exe"))), 4) - self.assertEqual( - len(glob.glob(os.path.join(new_dir, "installers/*.pkg"))), 1) + self.assertLen(glob.glob(os.path.join(new_dir, "*.deb")), 2) + self.assertLen(glob.glob(os.path.join(new_dir, "*.rpm")), 2) + self.assertLen(glob.glob(os.path.join(new_dir, "*.exe")), 4) + self.assertLen(glob.glob(os.path.join(new_dir, "*.pkg")), 1) # Validate the config appended to the OS X package. zf = zipfile.ZipFile( - glob.glob(os.path.join(new_dir, "installers/*.pkg")).pop(), mode="r") + glob.glob(os.path.join(new_dir, "*.pkg")).pop(), mode="r" + ) fd = zf.open("config.yaml") # We can't load the included build.yaml because the package hasn't been diff --git a/grr/client_builder/grr_response_client_builder/windows_msi_test.py b/grr/client_builder/grr_response_client_builder/windows_msi_test.py index a96bb8d861..486ca2ebb3 100644 --- a/grr/client_builder/grr_response_client_builder/windows_msi_test.py +++ b/grr/client_builder/grr_response_client_builder/windows_msi_test.py @@ -42,25 +42,6 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAMQpeVjrxmf6nPmsjHjULWhLmquSgTDK GpJgTFkTIAgX0Ih5lxoFB5TUjUfJFbBkSmKQPRA/IyuLBtCLQgwkTNkCAwEAAQ== -----END PUBLIC KEY----- -CA.certificate: | - -----BEGIN CERTIFICATE----- - MIIC2zCCAcOgAwIBAgIBATANBgkqhkiG9w0BAQsFADAgMREwDwYDVQQDDAhncnJf - dGVzdDELMAkGA1UEBhMCVVMwHhcNMjEwMTE5MjAwNjQ1WhcNMzEwMTE4MjAwNjQ1 - WjAOMQwwCgYDVQQDDANncnIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB - AQDCYKqomwMxTPsipJtYpzxIbb0okr+NdKZipqLVjqt+LTtHBFuMwRxtX2ZG7l+6 - EiadZjh4tk+PNk9Lq5ZFfGjpJ/mLLpPXkdcZnjToseCTYdM0dsnQ0q1hIA6chRwU - mvTU81rlexNsslthjGUHfNdeWwIPtfvEW9/GtV8f3eeIo7e5h4Nco97N2bj6alPZ - 5ASThtCUK0GAm9qfTwi+UZaLNZlUPbj7OSdbc/5ieosF9CAuAXNHAqQY5IfkLYun - w+Ma6oDYbfSB0EV450tJATwNprNLgg9fyABz3sDEFWJ7+H0eRQ0nOQLCHHvhduEP - hdX6LzsaUH0WBiqgyq2prCenAgMBAAGjMjAwMA8GA1UdEwEB/wQFMAMBAf8wHQYD - VR0OBBYEFH2xv8xuBK6Vxarntzu5WwqowKbxMA0GCSqGSIb3DQEBCwUAA4IBAQB1 - JKXAglrc4ZYY04ZRyodpKzVXext7grpbRpen1+NigObYQb1ZGuaYXvpr8HiB6yGm - wx8BUrO0n5wzJi7ZRktwrBWdTseYRX6ztHF0+2pBnzkCF06zM597wwv49aUaySVV - BfHLR7TqF7QrQNeUMMjprADM3yNuuUGhLtlDZszgUTMLowxK3WM0A4niKhLaeGRb - E+i02f9gjMQBdhkFxZ/r3LhgXvwtb7xy+1JuvJlTmWpPWDivLScODtTq+/US6lnw - d7yf65zi20ufC5fh4oxc2stFLYlI0+MvTfj9f0sJbfJLSYj+8/jvRub0nAJQEyl7 - 6H+n8+lmRu0iE0dFPB+z - -----END CERTIFICATE----- """ PFX_BASE64_FILE = """ diff --git a/grr/config/grr_response_templates/upload.sh b/grr/config/grr_response_templates/upload.sh index 1c953ab2e5..58c8a1a4e7 100755 --- a/grr/config/grr_response_templates/upload.sh +++ 
b/grr/config/grr_response_templates/upload.sh @@ -10,7 +10,7 @@ set -e VERSION=$1 -RELEASE_NAME="grr-response-templates-${VERSION}" +RELEASE_NAME="grr_response_templates-${VERSION}" RELEASE_TAR="${RELEASE_NAME}.tar.gz" RELEASE_FILE="dist/${RELEASE_TAR}" if [[ $# -eq 1 ]] ; then diff --git a/grr/core/grr_response_core/config/artifacts.py b/grr/core/grr_response_core/config/artifacts.py index 9b47029cdb..5ada59c8af 100644 --- a/grr/core/grr_response_core/config/artifacts.py +++ b/grr/core/grr_response_core/config/artifacts.py @@ -20,10 +20,10 @@ " dependencies.") config_lib.DEFINE_list( - "Artifacts.non_kb_interrogate_artifacts", [ - "WMILogicalDisks", "RootDiskVolumeUsage", "WMIComputerSystemProduct", - "LinuxHardwareInfo", "OSXSPHardwareDataType" - ], "Non-knowledge-base artifacts collected during Interrogate flows.") + "Artifacts.non_kb_interrogate_artifacts", + ["WMILogicalDisks", "RootDiskVolumeUsage"], + "Non-knowledge-base artifacts collected during Interrogate flows.", +) config_lib.DEFINE_list( "Artifacts.knowledge_base_additions", [], diff --git a/grr/core/grr_response_core/config/build.py b/grr/core/grr_response_core/config/build.py index b71f4aff8f..d1755a97ac 100644 --- a/grr/core/grr_response_core/config/build.py +++ b/grr/core/grr_response_core/config/build.py @@ -297,66 +297,12 @@ def FromString(self, string): default="%(executables@grr-response-core|resource)", description="The path to the grr executables directory.")) -config_lib.DEFINE_option( - PathTypeInfo( - name="ClientBuilder.unzipsfx_stub_dir", - must_exist=False, - default=("%(ClientBuilder.executables_dir)/%(Client.platform)" - "/templates/unzipsfx"), - description="The directory that contains the zip self extracting " - "stub.")) - -config_lib.DEFINE_option( - PathTypeInfo( - name="ClientBuilder.unzipsfx_stub", - must_exist=False, - default=( - "%(ClientBuilder.unzipsfx_stub_dir)/unzipsfx-%(Client.arch).exe"), - description="The full path to the zip self extracting stub.")) - config_lib.DEFINE_string( name="ClientBuilder.config_filename", default="%(Client.binary_name).yaml", help=("The name of the configuration file which will be embedded in the " "deployable binary.")) -config_lib.DEFINE_string( - name="ClientBuilder.autorun_command_line", - default=("%(Client.binary_name) --install " - "--config %(ClientBuilder.config_filename)"), - help=("The command that the installer will execute after " - "unpacking the package.")) - -config_lib.DEFINE_list( - name="ClientBuilder.plugins", - default=[], - help="Plugins that will copied to the client installation file and run when" - "the client is running.") - -config_lib.DEFINE_string( - name="ClientBuilder.client_logging_filename", - default="%(Logging.path)/%(Client.name)_log.txt", - help="Filename for logging, to be copied to Client section in the client " - "that gets built.") - -config_lib.DEFINE_string( - name="ClientBuilder.client_logging_path", - default="/tmp", - help="Filename for logging, to be copied to Client section in the client " - "that gets built.") - -config_lib.DEFINE_list( - name="ClientBuilder.client_logging_engines", - default=["stderr", "file"], - help="Enabled logging engines, to be copied to Logging.engines in client " - "configuration.") - -config_lib.DEFINE_string( - name="ClientBuilder.client_installer_logfile", - default="%(Logging.path)/%(Client.name)_installer.txt", - help="Logfile for logging the client installation process, to be copied to" - " Installer.logfile in client built.") - config_lib.DEFINE_string( 
name="ClientBuilder.maintainer", default="GRR ", @@ -465,12 +411,6 @@ def FromString(self, string): ], help="Platforms that will be built by client_build buildandrepack") -config_lib.DEFINE_list( - name="ClientBuilder.BuildTargets", - default=[], - help="List of context names that should be built by " - "buildandrepack") - config_lib.DEFINE_string( "ClientBuilder.rpm_signing_key_public_keyfile", default="/etc/alternatives/grr_rpm_signing_key", diff --git a/grr/core/grr_response_core/config/client.py b/grr/core/grr_response_core/config/client.py index 74cb11a8ba..9eb3276c81 100644 --- a/grr/core/grr_response_core/config/client.py +++ b/grr/core/grr_response_core/config/client.py @@ -165,20 +165,6 @@ help="The registry key where client configuration " "will be stored.") -# Client Cryptographic options. Here we define defaults for key values. -config_lib.DEFINE_semantic_value( - rdf_crypto.RSAPrivateKey, - "Client.private_key", - help="Client private key in pem format. If not provided this " - "will be generated by the enrollment process.", -) - -config_lib.DEFINE_semantic_value( - rdf_crypto.RDFX509Cert, - "CA.certificate", - help="Trusted CA certificate in X509 pem format", -) - config_lib.DEFINE_semantic_value( rdf_crypto.RSAPublicKey, "Client.executable_signing_public_key", diff --git a/grr/core/grr_response_core/config/server.py b/grr/core/grr_response_core/config/server.py index a69ab6af8a..50d726db64 100644 --- a/grr/core/grr_response_core/config/server.py +++ b/grr/core/grr_response_core/config/server.py @@ -5,7 +5,6 @@ from grr_response_core import version from grr_response_core.lib import config_lib from grr_response_core.lib import rdfvalue -from grr_response_core.lib.rdfvalues import crypto as rdf_crypto from grr_response_core.lib.rdfvalues import paths as rdf_paths VERSION = version.Version() @@ -56,25 +55,9 @@ config_lib.DEFINE_string("Worker.smtp_password", None, "Password for the smtp connection.") -# Server Cryptographic settings. -config_lib.DEFINE_semantic_value( - rdf_crypto.RSAPrivateKey, - "PrivateKeys.ca_key", - help="CA private key. Used to sign for client enrollment.") - -config_lib.DEFINE_semantic_value( - rdf_crypto.RSAPrivateKey, - "PrivateKeys.server_key", - help="Private key for the front end server.") - config_lib.DEFINE_integer("Server.rsa_key_length", 2048, "The length of the server rsa key in bits.") -config_lib.DEFINE_semantic_value( - rdf_crypto.RDFX509Cert, - "Frontend.certificate", - help="An X509 certificate for the frontend server.") - config_lib.DEFINE_bool("Cron.active", False, "Set to true to run a cron thread on this binary.") diff --git a/grr/core/grr_response_core/config/test.py b/grr/core/grr_response_core/config/test.py index 538dcf782f..415382a4f8 100644 --- a/grr/core/grr_response_core/config/test.py +++ b/grr/core/grr_response_core/config/test.py @@ -27,9 +27,6 @@ config_lib.DEFINE_integer("Test.remote_pdb_port", 2525, "Remote debugger port.") -config_lib.DEFINE_string("PrivateKeys.ca_key_raw_data", "", - "For testing purposes.") - config_lib.DEFINE_integer("SharedMemoryDB.port", 0, "Port used to connect to SharedMemoryDB server.") diff --git a/grr/core/grr_response_core/lib/artifact_utils.py b/grr/core/grr_response_core/lib/artifact_utils.py index 1639c8d43d..145e07054e 100644 --- a/grr/core/grr_response_core/lib/artifact_utils.py +++ b/grr/core/grr_response_core/lib/artifact_utils.py @@ -5,11 +5,8 @@ intended to end up as an independent library. 
""" - import re from typing import Iterable -from typing import Text - from grr_response_core.lib import interpolation from grr_response_core.lib.rdfvalues import structs as rdf_structs @@ -30,7 +27,7 @@ class ArtifactProcessingError(Error): class KbInterpolationMissingAttributesError(Error): """An exception class for missing knowledgebase attributes.""" - def __init__(self, attrs: Iterable[Text]) -> None: + def __init__(self, attrs: Iterable[str]) -> None: message = "Some attributes could not be located in the knowledgebase: {}" message = message.format(", ".join(attrs)) super().__init__(message) @@ -41,7 +38,7 @@ def __init__(self, attrs: Iterable[Text]) -> None: class KbInterpolationUnknownAttributesError(Error): """An exception class for non-existing knowledgebase attributes.""" - def __init__(self, attrs: Iterable[Text]) -> None: + def __init__(self, attrs: Iterable[str]) -> None: message = "Some attributes are not part of the knowledgebase: {}" message = message.format(", ".join(attrs)) super().__init__(message) @@ -108,8 +105,10 @@ def InterpolateKbAttributes(pattern, knowledge_base): for scope_id in interpolator.Scopes(): scope_name = str(scope_id).lower() - if not (scope_name in kb_cls.type_infos and - isinstance(kb_cls.type_infos[scope_name], rdf_structs.ProtoList)): + if not ( + scope_name in kb_cls.type_infos + and isinstance(kb_cls.type_infos[scope_name], rdf_structs.ProtoList) + ): unknown_attr_names.add(scope_name) continue @@ -250,11 +249,12 @@ def ExpandWindowsEnvironmentVariables(data_string, knowledge_base): components = [] offset = 0 for match in win_environ_regex.finditer(data_string): - components.append(data_string[offset:match.start()]) + components.append(data_string[offset : match.start()]) # KB environment variables are prefixed with environ_. - kb_value = getattr(knowledge_base, "environ_%s" % match.group(1).lower(), - None) + kb_value = getattr( + knowledge_base, "environ_%s" % match.group(1).lower(), None + ) if isinstance(kb_value, str) and kb_value: components.append(kb_value) else: @@ -265,10 +265,9 @@ def ExpandWindowsEnvironmentVariables(data_string, knowledge_base): return "".join(components) -def ExpandWindowsUserEnvironmentVariables(data_string, - knowledge_base, - sid=None, - username=None): +def ExpandWindowsUserEnvironmentVariables( + data_string, knowledge_base, sid=None, username=None +): r"""Take a string and expand windows user environment variables based. 
Args: @@ -284,7 +283,7 @@ def ExpandWindowsUserEnvironmentVariables(data_string, components = [] offset = 0 for match in win_environ_regex.finditer(data_string): - components.append(data_string[offset:match.start()]) + components.append(data_string[offset : match.start()]) kb_user = knowledge_base.GetUser(sid=sid, username=username) kb_value = None if kb_user: diff --git a/grr/core/grr_response_core/lib/casing.py b/grr/core/grr_response_core/lib/casing.py index c4b0a499a9..5aece1376a 100644 --- a/grr/core/grr_response_core/lib/casing.py +++ b/grr/core/grr_response_core/lib/casing.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Functions to convert strings between different case styles.""" + import re diff --git a/grr/core/grr_response_core/lib/casing_test.py b/grr/core/grr_response_core/lib/casing_test.py index fd00d341e0..0c2c4570d1 100644 --- a/grr/core/grr_response_core/lib/casing_test.py +++ b/grr/core/grr_response_core/lib/casing_test.py @@ -34,16 +34,18 @@ def testSnakeToCamelWorksOnStringsWithPrefixOrSuffixUnderscores(self): self.assertEqual("aBCD", casing.SnakeToCamel("_a_b_c_d_")) self.assertEqual("aBCD", casing.SnakeToCamel("____a_b_c_d____")) self.assertEqual("aaBbCcDd", casing.SnakeToCamel("____aa_bb_cc_dd____")) - self.assertEqual("aaaBbbCccDdd", - casing.SnakeToCamel("____aaa_bbb_ccc_ddd____")) + self.assertEqual( + "aaaBbbCccDdd", casing.SnakeToCamel("____aaa_bbb_ccc_ddd____") + ) def testSnakeToCamelWorksOnStringsWithMultipleUnderscoresBetweenWords(self): self.assertEqual("aBCD", casing.SnakeToCamel("a__b__c__d")) self.assertEqual("aBCD", casing.SnakeToCamel("a____b____c____d")) self.assertEqual("aBCD", casing.SnakeToCamel("___a___b___c___d___")) self.assertEqual("aaBbCcDd", casing.SnakeToCamel("___aa___bb___cc___dd___")) - self.assertEqual("aaaBbbCccDdd", - casing.SnakeToCamel("___aaa___bbb___ccc___ddd___")) + self.assertEqual( + "aaaBbbCccDdd", casing.SnakeToCamel("___aaa___bbb___ccc___ddd___") + ) def testSnakeToCamelWorksOnStringsWithUppercaseLetters(self): self.assertEqual("a", casing.SnakeToCamel("A")) @@ -71,18 +73,21 @@ def testCamelToSnakeWorksOnRegularStrings(self): self.assertEqual("a_b_c_d", casing.CamelToSnake("aBCD")) def testCamelToSnakeWorksOnStringsWithUppercaseLettersOnly(self): - self.assertEqual("t_h_i_s_i_s_a_s_n_a_k_e", - casing.CamelToSnake("THISISASNAKE")) - self.assertEqual("a_s_n_a_k_e_t_h_i_s_i_s", - casing.CamelToSnake("ASNAKETHISIS")) + self.assertEqual( + "t_h_i_s_i_s_a_s_n_a_k_e", casing.CamelToSnake("THISISASNAKE") + ) + self.assertEqual( + "a_s_n_a_k_e_t_h_i_s_i_s", casing.CamelToSnake("ASNAKETHISIS") + ) self.assertEqual("a_b_c_d", casing.CamelToSnake("ABCD")) self.assertEqual("a", casing.CamelToSnake("A")) def testCamelToSnakeWorksOnStringsWithUnicodeCharacters(self): self.assertEqual("ą_ć_ę", casing.CamelToSnake("ąĆĘ")) self.assertEqual("ąąąa_ććća_ęęęa", casing.CamelToSnake("ąąąaĆććaĘęęa")) - self.assertEqual("ą_ą_ąa_ć_ć_ća_ę_ę_ęa", - casing.CamelToSnake("ĄĄĄaĆĆĆaĘĘĘa")) + self.assertEqual( + "ą_ą_ąa_ć_ć_ća_ę_ę_ęa", casing.CamelToSnake("ĄĄĄaĆĆĆaĘĘĘa") + ) self.assertEqual("ą_ą_ą_ć_ć_ć_ę_ę_ę", casing.CamelToSnake("ĄĄĄĆĆĆĘĘĘ")) self.assertEqual("ą", casing.CamelToSnake("Ą")) diff --git a/grr/core/grr_response_core/lib/communicator.py b/grr/core/grr_response_core/lib/communicator.py index b0bfc0c504..559b89f619 100644 --- a/grr/core/grr_response_core/lib/communicator.py +++ b/grr/core/grr_response_core/lib/communicator.py @@ -14,7 +14,8 @@ GRR_DECODING_ERROR = metrics.Counter("grr_decoding_error") GRR_DECRYPTION_ERROR = 
metrics.Counter("grr_decryption_error") GRR_LEGACY_CLIENT_DECRYPTION_ERROR = metrics.Counter( - "grr_legacy_client_decryption_error") + "grr_legacy_client_decryption_error" +) GRR_RSA_OPERATIONS = metrics.Counter("grr_rsa_operations") @@ -46,8 +47,9 @@ def __init__(self, message): super().__init__(message) -class Cipher(object): +class Cipher: """Holds keying information.""" + cipher_name = "aes_128_cbc" key_size = 128 iv_size = 128 @@ -70,7 +72,8 @@ def __init__(self, source, private_key, remote_public_key): key=rdf_crypto.EncryptionKey.GenerateKey(length=self.key_size), metadata_iv=rdf_crypto.EncryptionKey.GenerateKey(length=self.key_size), hmac_key=rdf_crypto.EncryptionKey.GenerateKey(length=self.key_size), - hmac_type="FULL_HMAC") + hmac_type="FULL_HMAC", + ) serialized_cipher = self.cipher.SerializeToBytes() @@ -85,7 +88,8 @@ def __init__(self, source, private_key, remote_public_key): # Encrypt the metadata block symmetrically. _, self.encrypted_cipher_metadata = self.Encrypt( - self.cipher_metadata.SerializeToBytes(), self.cipher.metadata_iv) + self.cipher_metadata.SerializeToBytes(), self.cipher.metadata_iv + ) def Encrypt(self, data, iv=None): """Symmetrically encrypt the data using the optional iv.""" @@ -112,18 +116,23 @@ class ReceivedCipher(Cipher): """A cipher which we received from our peer.""" # pylint: disable=super-init-not-called - def __init__(self, response_comms: rdf_flows.ClientCommunication, - private_key): + def __init__( + self, response_comms: rdf_flows.ClientCommunication, private_key + ): self.private_key = private_key self.response_comms = response_comms if response_comms.api_version not in [3]: - raise DecryptionError("Unsupported api version: %s, expected 3." % - response_comms.api_version) + raise DecryptionError( + "Unsupported api version: %s, expected 3." + % response_comms.api_version + ) if not response_comms.encrypted_cipher: - logging.error("No encrypted_cipher. orig_request=%s", - str(response_comms.orig_request)[:1000]) + logging.error( + "No encrypted_cipher. orig_request=%s", + str(response_comms.orig_request)[:1000], + ) # The message is not encrypted. We do not allow unencrypted # messages: raise DecryptionError("Message is not encrypted.") @@ -131,16 +140,20 @@ def __init__(self, response_comms: rdf_flows.ClientCommunication, try: # The encrypted_cipher contains the session key, iv and hmac_key. self.serialized_cipher = private_key.Decrypt( - response_comms.encrypted_cipher) + response_comms.encrypted_cipher + ) # If we get here we have the session keys. self.cipher = rdf_flows.CipherProperties.FromSerializedBytes( - self.serialized_cipher) + self.serialized_cipher + ) # Check the key lengths. - if (len(self.cipher.key) * 8 != self.key_size or - len(self.cipher.metadata_iv) * 8 != self.iv_size or - len(self.cipher.hmac_key) * 8 != self.key_size): + if ( + len(self.cipher.key) * 8 != self.key_size + or len(self.cipher.metadata_iv) * 8 != self.iv_size + or len(self.cipher.hmac_key) * 8 != self.key_size + ): raise DecryptionError("Invalid cipher.") self.VerifyHMAC() @@ -150,9 +163,11 @@ def __init__(self, response_comms: rdf_flows.ClientCommunication, # digest of the serialized CipherProperties(). It is stored inside the # encrypted payload. 
serialized_metadata = self.Decrypt( - response_comms.encrypted_cipher_metadata, self.cipher.metadata_iv) + response_comms.encrypted_cipher_metadata, self.cipher.metadata_iv + ) self.cipher_metadata = rdf_flows.CipherMetadata.FromSerializedBytes( - serialized_metadata) + serialized_metadata + ) except (rdf_crypto.InvalidSignature, rdf_crypto.CipherError) as e: if "Ciphertext length must be equal to key size" in str(e): @@ -191,7 +206,6 @@ def VerifyHMAC(self): Returns: True - """ return self._VerifyHMAC(self.response_comms) @@ -209,7 +223,6 @@ def _VerifyHMAC(self, comms=None): Returns: True - """ # Check the encrypted message integrity using HMAC. if self.hmac_type == "SIMPLE_HMAC": @@ -217,10 +230,11 @@ def _VerifyHMAC(self, comms=None): digest = comms.hmac elif self.hmac_type == "FULL_HMAC": msg = b"".join([ - comms.encrypted, comms.encrypted_cipher, + comms.encrypted, + comms.encrypted_cipher, comms.encrypted_cipher_metadata, comms.packet_iv.SerializeToBytes(), - struct.pack(" Text: - precondition.AssertType(data, Text) + def Filter(self, data: str) -> str: + precondition.AssertType(data, str) return data class Literal(ConfigFilter): """A filter which does not interpolate.""" + name = "literal" class Lower(ConfigFilter): name = "lower" - def Filter(self, data: Text) -> Text: - precondition.AssertType(data, Text) + def Filter(self, data: str) -> str: + precondition.AssertType(data, str) return data.lower() class Upper(ConfigFilter): name = "upper" - def Filter(self, data: Text) -> Text: - precondition.AssertType(data, Text) + def Filter(self, data: str) -> str: + precondition.AssertType(data, str) return data.upper() class Filename(ConfigFilter): name = "file" - def Filter(self, data: Text) -> Text: - precondition.AssertType(data, Text) + def Filter(self, data: str) -> str: + precondition.AssertType(data, str) try: with io.open(data, "r") as fd: return fd.read() # pytype: disable=bad-return-type @@ -175,8 +186,8 @@ def Filter(self, data: Text) -> Text: class OptionalFile(ConfigFilter): name = "optionalfile" - def Filter(self, data: Text) -> Text: - precondition.AssertType(data, Text) + def Filter(self, data: str) -> str: + precondition.AssertType(data, str) try: with io.open(data, "r") as fd: return fd.read() # pytype: disable=bad-return-type @@ -189,8 +200,8 @@ class FixPathSeparator(ConfigFilter): name = "fixpathsep" - def Filter(self, data: Text) -> Text: - precondition.AssertType(data, Text) + def Filter(self, data: str) -> str: + precondition.AssertType(data, str) if platform.system() == "Windows": # This will fix "X:\", and might add extra slashes to other paths, but # this is OK. @@ -201,33 +212,36 @@ def Filter(self, data: Text) -> Text: class Env(ConfigFilter): """Interpolate environment variables.""" + name = "env" - def Filter(self, data: Text) -> Text: - precondition.AssertType(data, Text) + def Filter(self, data: str) -> str: + precondition.AssertType(data, str) return os.environ.get(data.upper(), "") class Expand(ConfigFilter): """Expands the input as a configuration parameter.""" + name = "expand" - def Filter(self, data: Text) -> Text: - precondition.AssertType(data, Text) + def Filter(self, data: str) -> str: + precondition.AssertType(data, str) interpolated = _CONFIG.InterpolateValue(data) # TODO(hanuszczak): This assertion should not be necessary but since the # whole configuration system is one gigantic spaghetti, we can never be sure # what is being returned. 
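# Illustrative sketch (not part of the patch): a new filter in the same style as
# the ConfigFilter subclasses above. The name "reverse" and the class are
# hypothetical; like the built-ins, a subclass is picked up by the ConfigFilter
# registry and can then be applied inside a config value as %(some text|reverse).
class Reverse(ConfigFilter):
  """A hypothetical filter that reverses its argument."""

  name = "reverse"

  def Filter(self, data: str) -> str:
    precondition.AssertType(data, str)
    return data[::-1]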
- precondition.AssertType(data, Text) - return cast(Text, interpolated) + precondition.AssertType(data, str) + return cast(str, interpolated) class Flags(ConfigFilter): """Get the parameter from the flags.""" + name = "flags" - def Filter(self, data: Text): - precondition.AssertType(data, Text) + def Filter(self, data: str): + precondition.AssertType(data, str) try: logging.debug("Overriding config option with flags.FLAGS.%s", data) attribute = getattr(flags.FLAGS, data) @@ -235,11 +249,11 @@ def Filter(self, data: Text): # should not be needed. This is just a quick hack to fix prod. if isinstance(attribute, bytes): attribute = attribute.decode("utf-8") - elif not isinstance(attribute, Text): + elif not isinstance(attribute, str): attribute = str(attribute) # TODO(hanuszczak): See TODO comment in the `Expand` filter. - precondition.AssertType(attribute, Text) - return cast(Text, attribute) + precondition.AssertType(attribute, str) + return cast(str, attribute) except AttributeError as e: raise FilterError(e) @@ -250,23 +264,23 @@ class Resource(ConfigFilter): The format of the directive is "path/to/resource@package_name". If package_name is not provided we use grr-resource-core by default. """ + name = "resource" default_package = "grr-response-core" - def Filter(self, filename_spec: Text) -> Text: + def Filter(self, data: str) -> str: """Use pkg_resources to find the path to the required resource.""" - if "@" in filename_spec: - file_path, package_name = filename_spec.split("@") + if "@" in data: + file_path, package_name = data.split("@") else: - file_path, package_name = filename_spec, Resource.default_package + file_path, package_name = data, Resource.default_package resource_path = package.ResourcePath(package_name, file_path) if resource_path is not None: return resource_path # pylint: disable=unreachable - raise FilterError( - "Unable to find resource %s while interpolating: " % filename_spec) + raise FilterError("Unable to find resource %s while interpolating: " % data) # pylint: enable=unreachable @@ -280,21 +294,23 @@ class ModulePath(ConfigFilter): Caveat: This will raise if the module is not a physically present on disk (e.g. pyinstaller bundle). """ + name = "module_path" - def Filter(self, name: Text) -> Text: + def Filter(self, data: str) -> str: try: - return package.ModulePath(name) - except ImportError: + return package.ModulePath(data) + except ImportError as e: message = ( - "Config parameter module_path expansion %r can not be imported." % - name) + "Config parameter module_path expansion %r can not be imported." + % data + ) # This exception will typically be caught by the expansion engine and # be silently swallowed. traceback.print_exc() logging.error(message) - raise FilterError(message) + raise FilterError(message) from e class StringInterpolator(lexer.Lexer): @@ -326,22 +342,18 @@ class StringInterpolator(lexer.Lexer): tokens = [ # When in literal mode, only allow to escape } lexer.Token("Literal", r"\\[^{}]", "AppendArg", None), - # Allow escaping of special characters lexer.Token(None, r"\\(.)", "Escape", None), - # Literal sequence is %{....}. Literal states can not be nested further, # i.e. we include anything until the next }. It is still possible to # escape } if this character needs to be inserted literally. lexer.Token("Literal", r"\}", "EndLiteralExpression,PopState", None), lexer.Token("Literal", r"[^}\\]+", "AppendArg", None), lexer.Token(None, r"\%\{", "StartExpression,PushState", "Literal"), - # Expansion sequence is %(....) 
lexer.Token(None, r"\%\(", "StartExpression", None), lexer.Token(None, r"\|([a-zA-Z_-]+)\)", "Filter", None), lexer.Token(None, r"\)", "ExpandArg", None), - # Glob up as much data as possible to increase efficiency here. lexer.Token(None, r"[^()%{}|\\]+", "AppendArg", None), lexer.Token(None, r".", "AppendArg", None), @@ -353,15 +365,17 @@ class StringInterpolator(lexer.Lexer): "\\)": ")", "\\{": "{", "\\}": "}", - "\\%": "%" + "\\%": "%", } - def __init__(self, - data, - config, - default_section="", - parameter=None, - context=None): + def __init__( + self, + data, + config, + default_section="", + parameter=None, + context=None, + ): self.stack = [""] self.default_section = default_section self.parameter = parameter @@ -385,8 +399,10 @@ def StartExpression(self, **_): def EndLiteralExpression(self, **_): if len(self.stack) <= 1: - raise lexer.ParseError("Unbalanced literal sequence: Can not expand '%s'" - % self.processed_buffer) + raise lexer.ParseError( + "Unbalanced literal sequence: Can not expand '%s'" + % self.processed_buffer + ) arg = self.stack.pop(-1) self.stack[-1] += arg @@ -404,7 +420,7 @@ def Filter(self, match=None, **_): if not filter_object.sensitive_arg: logging.debug("Applying filter %s for %s.", filter_name, arg) arg = filter_object().Filter(arg) - precondition.AssertType(arg, Text) + precondition.AssertType(arg, str) self.stack[-1] += arg @@ -414,7 +430,8 @@ def ExpandArg(self, **_): # exactly match the number of (. if len(self.stack) <= 1: raise lexer.ParseError( - "Unbalanced parenthesis: Can not expand '%s'" % self.processed_buffer) + "Unbalanced parenthesis: Can not expand '%s'" % self.processed_buffer + ) # This is the full parameter name: e.g. Logging.path parameter_name = self.stack.pop(-1) @@ -426,7 +443,8 @@ def ExpandArg(self, **_): final_value = "" type_info_obj = ( - self.config.FindTypeInfo(parameter_name) or type_info.String()) + self.config.FindTypeInfo(parameter_name) or type_info.String() + ) # Encode the interpolated string according to its type. self.stack[-1] += type_info_obj.ToString(final_value) @@ -477,16 +495,18 @@ def __init__(self): def DeclareBuiltIns(self): """Declare built in options internal to the config system.""" self.DEFINE_list( - "Config.includes", [], + "Config.includes", + [], "List of additional config files to include. Files are " "processed recursively depth-first, later values " - "override earlier ones.") + "override earlier ones.", + ) def __str__(self): # List all the files we read from. message = "" for filename in self.files: - message += " file=\"%s\" " % filename + message += ' file="%s" ' % filename return "<%s %s>" % (self.__class__.__name__, message) @@ -553,8 +573,8 @@ def SetWriteBack(self, filename, rename_invalid_writeback=True): Args: filename: A filename which will receive updates. The file is parsed first and merged into the raw data from this object. - rename_invalid_writeback: Whether to rename the writeback file if - it cannot be parsed. + rename_invalid_writeback: Whether to rename the writeback file if it + cannot be parsed. 
""" try: self.writeback = self.LoadSecondaryConfig(filename) @@ -643,7 +663,8 @@ def AddContext(self, context_string, description=None): if context_string not in self.context: if context_string not in self.valid_contexts: raise InvalidContextError( - "Invalid context specified: %s" % context_string) + "Invalid context specified: %s" % context_string + ) self.context.append(context_string) self.context_descriptions[context_string] = description @@ -667,7 +688,8 @@ def SetRaw(self, name, value): logging.warning("Attempting to modify a read only config object.") if name in self.constants: raise ConstModificationError( - "Attempting to modify constant value %s" % name) + "Attempting to modify constant value %s" % name + ) self.writeback_data[name] = value self.FlushCache() @@ -689,22 +711,25 @@ def Set(self, name, value): # If the configuration system has a write back location we use it, # otherwise we use the primary configuration object. if self.writeback is None: - logging.warning("Attempting to modify a read only config object for %s.", - name) + logging.warning( + "Attempting to modify a read only config object for %s.", name + ) if name in self.constants: raise ConstModificationError( - "Attempting to modify constant value %s" % name) + "Attempting to modify constant value %s" % name + ) writeback_data = self.writeback_data # Check if the new value conforms with the type_info. if value is not None: - if isinstance(value, Text): + if isinstance(value, str): value = self.EscapeString(value) if isinstance(value, bytes): - raise ValueError("Setting config option %s to bytes is not allowed" % - name) + raise ValueError( + "Setting config option %s to bytes is not allowed" % name + ) writeback_data[name] = value self.FlushCache() @@ -718,14 +743,16 @@ def Write(self): if self.writeback: self.writeback.SaveData(self.writeback_data) else: - raise RuntimeError("Attempting to write a configuration without a " - "writeback location.") + raise RuntimeError( + "Attempting to write a configuration without a writeback location." + ) def Persist(self, config_option): """Stores in the writeback.""" if not self.writeback: - raise RuntimeError("Attempting to write a configuration without a " - "writeback location.") + raise RuntimeError( + "Attempting to write a configuration without a writeback location." + ) writeback_raw_value = dict(self.writeback.ReadData()).get(config_option) raw_value = None @@ -766,7 +793,8 @@ def AddOption(self, descriptor, constant=False): """ if self.initialized: raise AlreadyInitializedError( - "Config was already initialized when defining %s" % descriptor.name) + "Config was already initialized when defining %s" % descriptor.name + ) descriptor.section = descriptor.name.split(".")[0] if descriptor.name in self.type_infos: @@ -791,8 +819,10 @@ def FormatHelp(self): try: result += " Current Value: %s\n" % self.Get(descriptor.name) except Exception as e: # pylint:disable=broad-except - result += " Current Value: %s (Error: %s)\n" % (self.GetRaw( - descriptor.name), e) + result += " Current Value: %s (Error: %s)\n" % ( + self.GetRaw(descriptor.name), + e, + ) return result def PrintHelp(self): @@ -816,8 +846,10 @@ def MergeData(self, merge_data: Dict[Any, Any], raw_data=None): # Find the descriptor for this field. descriptor = self.type_infos.get(k) if descriptor is None: - msg = ("Missing config definition for %s. This option is likely " - "deprecated or renamed. Check the release notes." % k) + msg = ( + "Missing config definition for %s. 
This option is likely " + "deprecated or renamed. Check the release notes." % k + ) if flags.FLAGS.disallow_missing_config_definitions: raise MissingConfigDefinitionError(msg) @@ -828,15 +860,14 @@ def MergeData(self, merge_data: Dict[Any, Any], raw_data=None): # value (e.g. via Set()), break loudly. if self.initialized and k in self.constants: raise ConstModificationError( - "Attempting to modify constant value %s" % k) + "Attempting to modify constant value %s" % k + ) raw_data[k] = v def LoadSecondaryConfig( - self, - filename=None, - parser=None, - process_includes=True) -> config_parser.GRRConfigParser: + self, filename=None, parser=None, process_includes=True + ) -> config_parser.GRRConfigParser: """Loads an additional configuration file. The configuration system has the concept of a single Primary configuration @@ -861,7 +892,6 @@ def LoadSecondaryConfig( Raises: ValueError: if both filename and parser arguments are None. ConfigFileNotFound: If a specified included file was not found. - """ if filename: # Maintain a stack of config file locations in loaded order. @@ -885,7 +915,8 @@ def LoadSecondaryConfig( if not filename: raise ConfigFileNotFound( "While loading %s: Unable to include a relative path (%s) " - "from a config without a filename" % (filename, file_to_load)) + "from a config without a filename" % (filename, file_to_load) + ) # If the included path is relative, we take it as relative to the # current path of the config. @@ -896,8 +927,9 @@ def LoadSecondaryConfig( try: clone_parser.ReadData() except config_parser.ReadDataError as e: - raise ConfigFileNotFound("Unable to load include file %s" % - file_to_load) from e + raise ConfigFileNotFound( + "Unable to load include file %s" % file_to_load + ) from e self.MergeData(clone.raw_data) self.files.extend(clone.files) @@ -915,7 +947,9 @@ def Initialize( must_exist: bool = False, process_includes: bool = True, parser: Type[ - config_parser.GRRConfigParser] = config_parser.IniConfigFileParser): + config_parser.GRRConfigParser + ] = config_parser.IniConfigFileParser, + ): """Initializes the config manager. This method is used to add more config options to the manager. 
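# Illustrative sketch (not part of the patch): the writeback flow implemented by
# SetWriteBack/Set/Write above, where mutations are staged against the writeback
# file while the primary configuration stays read-only. The option name and file
# path are hypothetical; the call sequence mirrors the writeback test that
# appears later in this patch.
from grr_response_core.lib import config_lib

def _example_config() -> config_lib.GrrConfigManager:
  conf = config_lib.GrrConfigManager()
  conf.DEFINE_string("Example.token", "", "A locally persisted value.")
  return conf

conf = _example_config()
conf.SetWriteBack("/tmp/grr-writeback.yaml")
conf.Set("Example.token", "abc123")  # staged in the writeback data
conf.Write()                         # persisted to the writeback file

reloaded = _example_config()
reloaded.Initialize(filename="/tmp/grr-writeback.yaml")
reloaded["Example.token"]  # "abc123"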
The config @@ -951,26 +985,31 @@ def Initialize( if issubclass(parser, config_parser.GRRConfigFileParser): self.parser = self.LoadSecondaryConfig( parser=config_parser.FileParserDataWrapper(fd.read(), parser("")), - process_includes=process_includes) + process_includes=process_includes, + ) else: raise TypeError("Trying to read from FD with a non-file parser.") elif filename is not None: self.parser = self.LoadSecondaryConfig( - filename, process_includes=process_includes) + filename, process_includes=process_includes + ) try: self.parser.ReadData() except config_parser.ReadDataError as e: if must_exist: - raise ConfigFormatError("Unable to parse config file %s" % - filename) from e + raise ConfigFormatError( + "Unable to parse config file %s" % filename + ) from e elif data is not None: if issubclass(parser, config_parser.GRRConfigFileParser): self.parser = self.LoadSecondaryConfig( parser=config_parser.FileParserDataWrapper( - data.encode("utf-8"), parser("")), - process_includes=process_includes) + data.encode("utf-8"), parser("") + ), + process_includes=process_includes, + ) else: raise TypeError("Trying to parse bytes with a non-file parser.") @@ -1020,11 +1059,13 @@ def Get(self, name, default=utils.NotAValue, context=None): """ if not self.initialized: if name not in self.constants: - raise RuntimeError("Error while retrieving %s: " - "Configuration hasn't been initialized yet." % name) + raise RuntimeError( + "Error while retrieving %s: " + "Configuration hasn't been initialized yet." % name + ) if context: # Make sure it's not just a string and is iterable. - if (isinstance(context, str) or not isinstance(context, abc.Iterable)): + if isinstance(context, str) or not isinstance(context, abc.Iterable): raise ValueError("context should be a list, got %r" % context) calc_context = context @@ -1040,7 +1081,8 @@ def Get(self, name, default=utils.NotAValue, context=None): type_info_obj = self.FindTypeInfo(name) _, return_value = self._GetValue( - name, context=calc_context, default=default) + name, context=calc_context, default=default + ) # If we returned the specified default, we just return it here. if return_value is default: @@ -1051,7 +1093,8 @@ def Get(self, name, default=utils.NotAValue, context=None): return_value, default_section=name.split(".")[0], type_info_obj=type_info_obj, - context=calc_context) + context=calc_context, + ) except (lexer.ParseError, ValueError) as e: # We failed to parse the value, but a default was specified, so we just # return that. @@ -1101,7 +1144,8 @@ def _ResolveContext(self, context, name, raw_data, path=None): # Recurse into the new context configuration. for context_raw_data, value, new_path in self._ResolveContext( - context, name, context_raw_data, path=path + [element]): + context, name, context_raw_data, path=path + [element] + ): yield context_raw_data, value, new_path def _GetValue(self, name, context, default=utils.NotAValue): @@ -1138,9 +1182,12 @@ def _GetValue(self, name, context, default=utils.NotAValue): value = matches[-1][1] container = matches[-1][0] - if (len(matches) >= 2 and len(matches[-1][2]) == len(matches[-2][2]) and - matches[-1][2] != matches[-2][2] and - matches[-1][1] != matches[-2][1]): + if ( + len(matches) >= 2 + and len(matches[-1][2]) == len(matches[-2][2]) + and matches[-1][2] != matches[-2][2] + and matches[-1][1] != matches[-2][1] + ): # This warning specifies that there is an ambiguous match, the config # attempts to find the most specific value e.g. 
if you have a value # for X.y in context A,B,C, and a value for X.y in D,B it should choose @@ -1149,8 +1196,13 @@ def _GetValue(self, name, context, default=utils.NotAValue): # one and displays this warning. logging.warning( "Ambiguous configuration for key %s: " - "Contexts of equal length: %s (%s) and %s (%s)", name, - matches[-1][2], matches[-1][1], matches[-2][2], matches[-2][1]) + "Contexts of equal length: %s (%s) and %s (%s)", + name, + matches[-1][2], + matches[-1][1], + matches[-2][2], + matches[-2][1], + ) # If there is a writeback location this overrides any previous # values. @@ -1175,21 +1227,24 @@ def FindTypeInfo(self, name): return result - def InterpolateValue(self, - value, - type_info_obj=type_info.String(), - default_section=None, - context=None): + def InterpolateValue( + self, + value, + type_info_obj=type_info.String(), + default_section=None, + context=None, + ): """Interpolate the value and parse it with the appropriate type.""" # It is only possible to interpolate strings... - if isinstance(value, Text): + if isinstance(value, str): try: value = StringInterpolator( value, self, default_section=default_section, parameter=type_info_obj.name, - context=context).Parse() + context=context, + ).Parse() except InterpolationError as e: # TODO(hanuszczak): This is a quick hack to not refactor too much while # working on Python 3 compatibility. But this is bad and exceptions @@ -1204,7 +1259,8 @@ def InterpolateValue(self, if isinstance(value, list): value = [ self.InterpolateValue( - v, default_section=default_section, context=context) + v, default_section=default_section, context=context + ) for v in value ] @@ -1217,11 +1273,13 @@ def GetSections(self): return result - def MatchBuildContext(self, - target_os, - target_arch, - target_package, - context=None): + def MatchBuildContext( + self, + target_os, + target_arch, + target_package, + context=None, + ): """Return true if target_platforms matches the supplied parameters. Used by buildanddeploy to determine what clients need to be built. 
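# Illustrative sketch (not part of the patch): the context resolution described
# above, where _GetValue prefers the value attached to the most specific
# (deepest) matching context. Names and values are hypothetical; the API usage
# mirrors the context tests later in this patch.
from grr_response_core.lib import config_lib
from grr_response_core.lib import config_parser

conf = config_lib.GrrConfigManager()
conf.DEFINE_integer("Example.port", 0, "A per-platform value.")
conf.DEFINE_context("Client Context")
conf.DEFINE_context("Platform:Windows")
conf.Initialize(
    parser=config_parser.YamlConfigFileParser,
    data="""
Example.port: 1111

Client Context:
  Example.port: 2222

  Platform:Windows:
    Example.port: 3333
""",
)

conf.Get("Example.port")                                  # 1111
conf.Get("Example.port", context=["Client Context"])      # 2222
conf.Get("Example.port",
         context=["Client Context", "Platform:Windows"])  # 3333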
@@ -1238,8 +1296,11 @@ def MatchBuildContext(self, """ for spec in self.Get("ClientBuilder.target_platforms", context=context): spec_os, arch, package_name = spec.split("_") - if (spec_os == target_os and arch == target_arch and - package_name == target_package): + if ( + spec_os == target_os + and arch == target_arch + and package_name == target_package + ): return True return False @@ -1248,39 +1309,47 @@ def DEFINE_bool(self, name, default, help, constant=False): """A helper for defining boolean options.""" self.AddOption( type_info.Bool(name=name, default=default, description=help), - constant=constant) + constant=constant, + ) def DEFINE_float(self, name, default, help, constant=False): """A helper for defining float options.""" self.AddOption( type_info.Float(name=name, default=default, description=help), - constant=constant) + constant=constant, + ) def DEFINE_integer(self, name, default, help, constant=False): """A helper for defining integer options.""" self.AddOption( type_info.Integer(name=name, default=default, description=help), - constant=constant) + constant=constant, + ) def DEFINE_string(self, name, default, help, constant=False): """A helper for defining string options.""" self.AddOption( type_info.String(name=name, default=default or "", description=help), - constant=constant) + constant=constant, + ) def DEFINE_choice(self, name, default, choices, help, constant=False): """A helper for defining choice string options.""" self.AddOption( type_info.Choice( - name=name, default=default, choices=choices, description=help), - constant=constant) + name=name, default=default, choices=choices, description=help + ), + constant=constant, + ) def DEFINE_multichoice(self, name, default, choices, help, constant=False): """Choose multiple options from a list.""" self.AddOption( type_info.MultiChoice( - name=name, default=default, choices=choices, description=help), - constant=constant) + name=name, default=default, choices=choices, description=help + ), + constant=constant, + ) def DEFINE_integer_list(self, name, default, help, constant=False): """A helper for defining lists of integer options.""" @@ -1289,8 +1358,10 @@ def DEFINE_integer_list(self, name, default, help, constant=False): name=name, default=default, description=help, - validator=type_info.Integer()), - constant=constant) + validator=type_info.Integer(), + ), + constant=constant, + ) def DEFINE_list(self, name, default, help, constant=False): """A helper for defining lists of strings options.""" @@ -1299,32 +1370,37 @@ def DEFINE_list(self, name, default, help, constant=False): name=name, default=default, description=help, - validator=type_info.String()), - constant=constant) + validator=type_info.String(), + ), + constant=constant, + ) def DEFINE_constant_string(self, name, default, help): """A helper for defining constant strings.""" self.AddOption( type_info.String(name=name, default=default or "", description=help), - constant=True) + constant=True, + ) def DEFINE_semantic_value(self, semantic_type, name, default=None, help=""): if issubclass(semantic_type, rdf_structs.RDFStruct): - raise ValueError("DEFINE_semantic_value should be used for types based " - "on primitives.") + raise ValueError( + "DEFINE_semantic_value should be used for types based on primitives." 
+ ) self.AddOption( type_info.RDFValueType( - rdfclass=semantic_type, - name=name, - default=default, - description=help)) + rdfclass=semantic_type, name=name, default=default, description=help + ) + ) - def DEFINE_semantic_enum(self, - enum_container: rdf_structs.EnumContainer, - name: str, - default: Optional[rdf_structs.EnumNamedValue] = None, - help: str = "") -> None: + def DEFINE_semantic_enum( + self, + enum_container: rdf_structs.EnumContainer, + name: str, + default: Optional[rdf_structs.EnumNamedValue] = None, + help: str = "", + ) -> None: if not isinstance(enum_container, rdf_structs.EnumContainer): raise ValueError("enum_container must be an EnumContainer.") @@ -1333,19 +1409,21 @@ def DEFINE_semantic_enum(self, enum_container=enum_container, name=name, default=default, - description=help)) + description=help, + ) + ) def DEFINE_semantic_struct(self, semantic_type, name, default=None, help=""): if not issubclass(semantic_type, rdf_structs.RDFStruct): - raise ValueError("DEFINE_semantic_struct should be used for types based " - "on structs.") + raise ValueError( + "DEFINE_semantic_struct should be used for types based on structs." + ) self.AddOption( type_info.RDFStructDictType( - rdfclass=semantic_type, - name=name, - default=default, - description=help)) + rdfclass=semantic_type, name=name, default=default, description=help + ) + ) def DEFINE_context(self, name): return self.DefineContext(name) @@ -1413,16 +1491,19 @@ def DEFINE_semantic_value(semantic_type, name, default=None, help=""): _CONFIG.DEFINE_semantic_value(semantic_type, name, default=default, help=help) -def DEFINE_semantic_enum(semantic_enum: rdf_structs.EnumContainer, - name: str, - default: Optional[rdf_structs.EnumNamedValue] = None, - help: str = "") -> None: +def DEFINE_semantic_enum( + semantic_enum: rdf_structs.EnumContainer, + name: str, + default: Optional[rdf_structs.EnumNamedValue] = None, + help: str = "", +) -> None: _CONFIG.DEFINE_semantic_enum(semantic_enum, name, default=default, help=help) def DEFINE_semantic_struct(semantic_type, name, default=None, help=""): _CONFIG.DEFINE_semantic_struct( - semantic_type, name, default=default, help=help) + semantic_type, name, default=default, help=help + ) def DEFINE_option(type_descriptor): @@ -1441,13 +1522,15 @@ def DEFINE_context(name): # pylint: enable=g-bad-name -def LoadConfig(config_obj, - config_file=None, - config_fd=None, - secondary_configs=None, - contexts=None, - reset=False, - parser=config_parser.IniConfigFileParser): +def LoadConfig( + config_obj, + config_file=None, + config_fd=None, + secondary_configs=None, + contexts=None, + reset=False, + parser=config_parser.IniConfigFileParser, +): """Initialize a ConfigManager with the specified options. Args: @@ -1514,7 +1597,8 @@ def ParseConfigCommandLine(rename_invalid_writeback=True): if _CONFIG["Config.writeback"]: _CONFIG.SetWriteBack( _CONFIG["Config.writeback"], - rename_invalid_writeback=rename_invalid_writeback) + rename_invalid_writeback=rename_invalid_writeback, + ) # Does the user want to dump help? 
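# Illustrative sketch (not part of the patch): the "<os>_<arch>_<package>" specs
# consumed by MatchBuildContext above. The target list here is hypothetical; the
# spec format matches the build-context test later in this patch.
from grr_response_core.lib import config_lib

conf = config_lib.GrrConfigManager()
conf.DEFINE_list(
    "ClientBuilder.target_platforms",
    ["linux_amd64_deb", "windows_amd64_exe"],
    "Platforms to build client templates for.",
)
conf.Initialize(data="")

conf.MatchBuildContext("linux", "amd64", "deb")   # True
conf.MatchBuildContext("windows", "i386", "exe")  # False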
We do this after the config system is # initialized so the user can examine what we think the value of all the diff --git a/grr/core/grr_response_core/lib/config_lib_test.py b/grr/core/grr_response_core/lib/config_lib_test.py index afb5e24587..05c6bd1afd 100644 --- a/grr/core/grr_response_core/lib/config_lib_test.py +++ b/grr/core/grr_response_core/lib/config_lib_test.py @@ -42,7 +42,8 @@ def testParsing(self): parser=config_parser.YamlConfigFileParser, data=""" Section2.test: 2 - """) + """, + ) conf.DEFINE_string("Section2.test", "", "A string") conf.DEFINE_context("Client Context") @@ -67,7 +68,8 @@ def testParsing(self): Section1.test: 5 Section1.test2: 2 -""") +""", + ) self.assertEqual(conf["Section1.test"], 2) @@ -77,8 +79,10 @@ def testParsing(self): self.assertEqual( conf.Get( - "Section1.test_list", context=["Client Context", - "Windows Context"]), ["x", "y"]) + "Section1.test_list", context=["Client Context", "Windows Context"] + ), + ["x", "y"], + ) # Test that contexts affect option selection. self.assertEqual(conf.Get("Section1.test", context=["Client Context"]), 6) @@ -114,7 +118,8 @@ def testConflictingContexts(self): Extra Context: Section1.test: 15 -""") +""", + ) # Without contexts. self.assertEqual(conf.Get("Section1.test"), 2) @@ -128,12 +133,17 @@ def testConflictingContexts(self): # since they are added last. self.assertEqual( conf.Get( - "Section1.test", context=["Client Context", "Platform:Windows"]), - 10) + "Section1.test", context=["Client Context", "Platform:Windows"] + ), + 10, + ) self.assertEqual( conf.Get( - "Section1.test", context=["Platform:Windows", "Client Context"]), 6) + "Section1.test", context=["Platform:Windows", "Client Context"] + ), + 6, + ) def testRemoveContext(self): """Test that conflicting contexts are resolved by precedence.""" @@ -159,7 +169,8 @@ def testRemoveContext(self): Extra Context: Section1.test: 15 -""") +""", + ) # Should be defaults, no contexts added self.assertEqual(conf.Get("Section1.test"), 2) @@ -196,7 +207,8 @@ def testContextApplied(self): data=""" Client Context: Section1.test: 6 -""") +""", + ) # Should be defaults, no contexts added self.assertFalse(conf.ContextApplied("Client Context")) @@ -225,7 +237,8 @@ def testBackslashes(self): Section1.parameter3: | \%(a\\b\\c\\d\) -""") +""", + ) self.assertEqual(conf.Get("Section1.parameter"), "a\\b\\c\\d") self.assertEqual(conf.Get("Section1.parameter2"), "a\\b\\c\\d\\e") @@ -233,13 +246,15 @@ def testBackslashes(self): def testSemanticValueType(self): conf = config_lib.GrrConfigManager() - conf.DEFINE_semantic_value(rdfvalue.DurationSeconds, "Section1.foobar", - None, "Sample help.") + conf.DEFINE_semantic_value( + rdfvalue.DurationSeconds, "Section1.foobar", None, "Sample help." + ) conf.Initialize( parser=config_parser.YamlConfigFileParser, data=""" Section1.foobar: 6d -""") +""", + ) value = conf.Get("Section1.foobar") self.assertIsInstance(value, rdfvalue.DurationSeconds) @@ -248,8 +263,9 @@ def testSemanticValueType(self): def testSemanticStructType(self): conf = config_lib.GrrConfigManager() - conf.DEFINE_semantic_struct(rdf_file_finder.FileFinderArgs, - "Section1.foobar", [], "Sample help.") + conf.DEFINE_semantic_struct( + rdf_file_finder.FileFinderArgs, "Section1.foobar", [], "Sample help." 
+ ) conf.Initialize( parser=config_parser.YamlConfigFileParser, data=""" @@ -258,7 +274,8 @@ def testSemanticStructType(self): - "a/b" - "b/c" pathtype: "TSK" -""") +""", + ) values = conf.Get("Section1.foobar") self.assertIsInstance(values, rdf_file_finder.FileFinderArgs) @@ -271,9 +288,11 @@ def testSemanticEnum(self): conf.DEFINE_semantic_enum( enum_container=rdf_paths.PathSpec.PathType, name="Foo.Bar", - default=rdf_paths.PathSpec.PathType.TSK) + default=rdf_paths.PathSpec.PathType.TSK, + ) conf.Initialize( - parser=config_parser.YamlConfigFileParser, data="Foo.Bar: NTFS") + parser=config_parser.YamlConfigFileParser, data="Foo.Bar: NTFS" + ) value = conf.Get("Foo.Bar") self.assertIsInstance(value, rdf_structs.EnumNamedValue) @@ -286,7 +305,8 @@ def testSemanticEnum_defaultValue(self): conf.DEFINE_semantic_enum( enum_container=rdf_paths.PathSpec.PathType, name="Foo.Bar", - default=rdf_paths.PathSpec.PathType.TSK) + default=rdf_paths.PathSpec.PathType.TSK, + ) conf.Initialize(parser=config_parser.YamlConfigFileParser, data="") value = conf.Get("Foo.Bar") @@ -299,9 +319,11 @@ def testSemanticEnum_invalidValue(self): conf.DEFINE_semantic_enum( enum_container=rdf_paths.PathSpec.PathType, name="Foo.Bar", - default=rdf_paths.PathSpec.PathType.TSK) + default=rdf_paths.PathSpec.PathType.TSK, + ) conf.Initialize( - parser=config_parser.YamlConfigFileParser, data="Foo.Bar: Invalid") + parser=config_parser.YamlConfigFileParser, data="Foo.Bar: Invalid" + ) with self.assertRaises(ValueError): conf.Get("Foo.Bar") @@ -333,12 +355,18 @@ def testInit(self): self.assertEqual( conf.Get( "MemoryDriver.device_path", - context=("Client Context", "Platform:Linux")), "/dev/pmem") + context=("Client Context", "Platform:Linux"), + ), + "/dev/pmem", + ) self.assertEqual( conf.Get( "MemoryDriver.device_path", - context=("Client Context", "Platform:Windows")), r"\\.\pmem") + context=("Client Context", "Platform:Windows"), + ), + r"\\.\pmem", + ) def testSet(self): """Test setting options.""" @@ -371,33 +399,33 @@ def testSave(self): def testQuotes(self): conf = config_lib.GrrConfigManager() - conf.DEFINE_string(name="foo.bar", default="\"baz\"", help="Bar.") - conf.DEFINE_string(name="foo.quux", default="\"%(foo.bar)\"", help="Quux.") + conf.DEFINE_string(name="foo.bar", default='"baz"', help="Bar.") + conf.DEFINE_string(name="foo.quux", default='"%(foo.bar)"', help="Quux.") conf.Initialize(data="") - self.assertEqual(conf["foo.bar"], "\"baz\"") - self.assertEqual(conf["foo.quux"], "\"\"baz\"\"") + self.assertEqual(conf["foo.bar"], '"baz"') + self.assertEqual(conf["foo.quux"], '""baz""') def testWritebackQuotes(self): def Config(): conf = config_lib.GrrConfigManager() conf.DEFINE_string(name="foo.bar", default="", help="Bar.") - conf.DEFINE_string(name="foo.baz", default="\"%(foo.bar)\"", help="Baz.") + conf.DEFINE_string(name="foo.baz", default='"%(foo.bar)"', help="Baz.") return conf with temp.AutoTempFilePath(suffix=".yaml") as confpath: writeback_conf = Config() writeback_conf.SetWriteBack(confpath) - writeback_conf.Set("foo.bar", "\"quux\"") + writeback_conf.Set("foo.bar", '"quux"') writeback_conf.Write() loaded_conf = Config() loaded_conf.Initialize(filename=confpath) - self.assertEqual(loaded_conf["foo.bar"], "\"quux\"") - self.assertEqual(loaded_conf["foo.baz"], "\"\"quux\"\"") + self.assertEqual(loaded_conf["foo.bar"], '"quux"') + self.assertEqual(loaded_conf["foo.baz"], '""quux""') def _SetupConfig(self, value): conf = config_lib.GrrConfigManager() @@ -480,11 +508,13 @@ def testFileFilters(self): conf = 
config_lib.GrrConfigManager() conf.DEFINE_string("Valid.file", "%%(%s|file)" % filename, "test") - conf.DEFINE_string("Valid.optionalfile", "%%(%s|optionalfile)" % filename, - "test") + conf.DEFINE_string( + "Valid.optionalfile", "%%(%s|optionalfile)" % filename, "test" + ) conf.DEFINE_string("Invalid.file", "%(notafile|file)", "test") - conf.DEFINE_string("Invalid.optionalfile", "%(notafile|optionalfile)", - "test") + conf.DEFINE_string( + "Invalid.optionalfile", "%(notafile|optionalfile)", "test" + ) conf.Initialize(data="") @@ -509,8 +539,9 @@ def testErrorDetection(self): # This should raise since the config file is incorrect. errors = conf.Validate("Section1") - self.assertIn("Invalid value val2 for Integer", - str(errors["Section1.test"])) + self.assertIn( + "Invalid value val2 for Integer", str(errors["Section1.test"]) + ) def testCopyConfig(self): """Check we can copy a config and use it without affecting the old one.""" @@ -569,9 +600,9 @@ def testKeyConfigOptions(self): """) errors = conf.Validate(["Client"]) self.assertEqual(errors, {}) - self.assertIsInstance(conf["Client.executable_signing_public_key"], - rdf_crypto.RSAPublicKey) - self.assertIsInstance(conf["Client.private_key"], rdf_crypto.RSAPrivateKey) + self.assertIsInstance( + conf["Client.executable_signing_public_key"], rdf_crypto.RSAPublicKey + ) def testGet(self): conf = config_lib.GrrConfigManager() @@ -630,8 +661,9 @@ def testAddOption(self): # The default value is invalid. errors = conf.Validate("Section1") - self.assertIn("Invalid value string for Integer", - str(errors["Section1.broken_int"])) + self.assertIn( + "Invalid value string for Integer", str(errors["Section1.broken_int"]) + ) # Section not specified: self.assertRaises(config_lib.UnknownOption, conf.__getitem__, "a") @@ -672,10 +704,18 @@ def testConstants(self): # Once the config file is loaded and initialized, modification of constant # values is an error. 
- self.assertRaises(config_lib.ConstModificationError, conf.Set, - "Section1.const", "New string") - self.assertRaises(config_lib.ConstModificationError, conf.SetRaw, - "Section1.const", "New string") + self.assertRaises( + config_lib.ConstModificationError, + conf.Set, + "Section1.const", + "New string", + ) + self.assertRaises( + config_lib.ConstModificationError, + conf.SetRaw, + "Section1.const", + "New string", + ) @flagsaver.flagsaver(disallow_missing_config_definitions=True) def testBadConfigRaises(self): @@ -691,8 +731,9 @@ def testBadConfigRaises(self): def testBadFilterRaises(self): """Checks that bad filter directive raise.""" conf = config_lib.GrrConfigManager() - conf.DEFINE_string("Section1.foo6", "%(somefile@somepackage|resource)", - "test") + conf.DEFINE_string( + "Section1.foo6", "%(somefile@somepackage|resource)", "test" + ) conf.DEFINE_string("Section1.foo1", "%(Section1.foo6)/bar", "test") conf.Initialize(data="") @@ -710,8 +751,9 @@ def testConfigOptionsDefined(self): conf = config.CONFIG.MakeNewConfig() # Check our actual config validates - configpath = package.ResourcePath("grr-response-core", - "install_data/etc/grr-server.yaml") + configpath = package.ResourcePath( + "grr-response-core", "install_data/etc/grr-server.yaml" + ) conf.Initialize(filename=configpath) def _DefineStringName(self, conf, name): @@ -720,9 +762,17 @@ def _DefineStringName(self, conf, name): def testUnbalancedParenthesis(self): conf = config_lib.GrrConfigManager() name_list = [ - "Section1.foobar", "Section1.foo", "Section1.foo1", "Section1.foo2", - "Section1.foo3", "Section1.foo4", "Section1.foo5", "Section1.foo6", - "Section1.interpolation1", "Section1.interpolation2", "Section1.literal" + "Section1.foobar", + "Section1.foo", + "Section1.foo1", + "Section1.foo2", + "Section1.foo3", + "Section1.foo4", + "Section1.foo5", + "Section1.foo6", + "Section1.interpolation1", + "Section1.interpolation2", + "Section1.literal", ] for name in name_list: self._DefineStringName(conf, name) @@ -760,19 +810,22 @@ def testUnbalancedParenthesis(self): # Test direct access. self.assertEqual(conf["Section1.foo"], "X") - self.assertRaises(config_lib.ConfigFormatError, conf.__getitem__, - "Section1.foo1") + self.assertRaises( + config_lib.ConfigFormatError, conf.__getitem__, "Section1.foo1" + ) - self.assertRaises(config_lib.ConfigFormatError, conf.__getitem__, - "Section1.foo2") + self.assertRaises( + config_lib.ConfigFormatError, conf.__getitem__, "Section1.foo2" + ) self.assertEqual(conf["Section1.foo3"], "foo)") # Test literal expansion. self.assertEqual(conf["Section1.foo4"], "%(hello)") - self.assertRaises(config_lib.ConfigFormatError, conf.__getitem__, - "Section1.foo5") + self.assertRaises( + config_lib.ConfigFormatError, conf.__getitem__, "Section1.foo5" + ) self.assertEqual(conf["Section1.foo6"], "foo)") @@ -794,20 +847,23 @@ def testUnbalancedParenthesis(self): self.assertEqual(conf["Section1.foo6"], "foo)") # A complex regex which gets literally expanded. 
- self.assertEqual(conf["Section1.literal"], - r"aff4:/C\.(?P.{1,16}?)($|/.*)") + self.assertEqual( + conf["Section1.literal"], r"aff4:/C\.(?P.{1,16}?)($|/.*)" + ) def testDataTypes(self): conf = config_lib.GrrConfigManager() conf.DEFINE_float("Section1.float", 0, "A float") conf.Initialize( - parser=config_parser.YamlConfigFileParser, data="Section1.float: abc") + parser=config_parser.YamlConfigFileParser, data="Section1.float: abc" + ) errors = conf.Validate("Section1") self.assertIn("Invalid value abc for Float", str(errors["Section1.float"])) self.assertRaises(config_lib.ConfigFormatError, conf.Get, "Section1.float") conf.Initialize( - parser=config_parser.YamlConfigFileParser, data="Section1.float: 2") + parser=config_parser.YamlConfigFileParser, data="Section1.float: 2" + ) # Should have no errors now. Validate should normalize the value to a float. self.assertEqual(conf.Validate("Section1"), {}) @@ -819,7 +875,8 @@ def testDataTypes(self): conf.DEFINE_list("Section1.list", default=[], help="A list") conf.DEFINE_list("Section1.list2", default=["a", "2"], help="A list") conf.Initialize( - parser=config_parser.YamlConfigFileParser, data="Section1.int: 2.0") + parser=config_parser.YamlConfigFileParser, data="Section1.int: 2.0" + ) errors = conf.Validate("Section1") @@ -828,7 +885,8 @@ def testDataTypes(self): # A string can be coerced to an int if it makes sense: conf.Initialize( - parser=config_parser.YamlConfigFileParser, data="Section1.int: '2'") + parser=config_parser.YamlConfigFileParser, data="Section1.int: '2'" + ) conf.Validate("Section1") self.assertEqual(type(conf.Get("Section1.int")), int) @@ -888,7 +946,8 @@ def testConfigFileInclusion(self): # Using filename conf = self._GetNewConf() conf.Initialize( - parser=config_parser.YamlConfigFileParser, filename=configone) + parser=config_parser.YamlConfigFileParser, filename=configone + ) self._CheckConf(conf) # Using fd with no fd.name should raise because there is no way to resolve @@ -899,7 +958,8 @@ def testConfigFileInclusion(self): config_lib.ConfigFileNotFound, conf.Initialize, parser=config_parser.YamlConfigFileParser, - fd=fd) + fd=fd, + ) # Using data conf = self._GetNewConf() @@ -907,7 +967,8 @@ def testConfigFileInclusion(self): config_lib.ConfigFileNotFound, conf.Initialize, parser=config_parser.YamlConfigFileParser, - data=one) + data=one, + ) def testConfigFileInclusionCanBeTurnedOff(self): one = r""" @@ -934,7 +995,8 @@ def testConfigFileInclusionCanBeTurnedOff(self): conf.Initialize( parser=config_parser.YamlConfigFileParser, filename=configone, - process_includes=False) + process_includes=False, + ) self.assertFalse(conf.Get("SecondaryFileIncluded")) self.assertEqual(conf.Get("Section1.int"), 1) @@ -948,16 +1010,17 @@ def testConfigFileIncludeAbsolutePaths(self): with io.open(configone, "w") as fd: fd.write(one) - absolute_include = (r""" + absolute_include = r""" Config.includes: - %s Section1.int: 2 -""" % configone) +""" % configone conf = self._GetNewConf() conf.Initialize( - parser=config_parser.YamlConfigFileParser, data=absolute_include) + parser=config_parser.YamlConfigFileParser, data=absolute_include + ) self.assertEqual(conf["Section1.int"], 1) relative_include = r""" @@ -972,7 +1035,8 @@ def testConfigFileIncludeAbsolutePaths(self): config_lib.ConfigFileNotFound, conf.Initialize, parser=config_parser.YamlConfigFileParser, - data=relative_include) + data=relative_include, + ) # If we write it to a file it should work though. 
configtwo = os.path.join(temp_dir, "2.yaml") @@ -980,7 +1044,8 @@ def testConfigFileIncludeAbsolutePaths(self): fd.write(relative_include) conf.Initialize( - parser=config_parser.YamlConfigFileParser, filename=configtwo) + parser=config_parser.YamlConfigFileParser, filename=configtwo + ) self.assertEqual(conf["Section1.int"], 1) def testConfigFileInclusionWindowsPaths(self): @@ -1017,8 +1082,9 @@ def MockedWindowsOpen(filename, _=None): # testing. # # We need to also use the nt path manipulation modules. - with utils.MultiStubber((io, "open", MockedWindowsOpen), - (os, "path", ntpath)): + with utils.MultiStubber( + (io, "open", MockedWindowsOpen), (os, "path", ntpath) + ): conf = self._GetNewConf() conf.Initialize(filename=ntpath.join(config_path, "1.yaml")) self.assertEqual(conf["Section1.int"], 2) @@ -1048,7 +1114,8 @@ def testConfigFileInclusionWithContext(self): # Without specifying the context the includes are not processed. conf = self._GetNewConf() conf.Initialize( - parser=config_parser.YamlConfigFileParser, filename=configone) + parser=config_parser.YamlConfigFileParser, filename=configone + ) self.assertEqual(conf["Section1.int"], 1) # Only one config is loaded. @@ -1058,7 +1125,8 @@ def testConfigFileInclusionWithContext(self): conf = self._GetNewConf() conf.AddContext("Client Context") conf.Initialize( - parser=config_parser.YamlConfigFileParser, filename=configone) + parser=config_parser.YamlConfigFileParser, filename=configone + ) # Both config files were loaded. Note that load order is important and # well defined. @@ -1089,10 +1157,12 @@ def testMatchBuildContext(self): conf.DEFINE_context("Test3 Context") conf.Initialize(parser=config_parser.YamlConfigFileParser, data=context) conf.AddContext("Test1 Context") - result_map = [(("linux", "amd64", "deb"), True), - (("linux", "i386", "deb"), True), - (("windows", "amd64", "exe"), True), - (("windows", "i386", "exe"), False)] + result_map = [ + (("linux", "amd64", "deb"), True), + (("linux", "i386", "deb"), True), + (("windows", "amd64", "exe"), True), + (("windows", "i386", "exe"), False), + ] for result in result_map: self.assertEqual(conf.MatchBuildContext(*result[0]), result[1]) @@ -1118,8 +1188,8 @@ def testNoUnicodeWriting(self): conf = config.CONFIG.MakeNewConfig() config_file = os.path.join(self.temp_dir, "writeback.yaml") conf.SetWriteBack(config_file) - conf.DEFINE_string("NewSection1.new_option1", u"Default Value", "Help") - conf.Set(str("NewSection1.new_option1"), u"New Value1") + conf.DEFINE_string("NewSection1.new_option1", "Default Value", "Help") + conf.Set(str("NewSection1.new_option1"), "New Value1") conf.Write() data = io.open(config_file).read() diff --git a/grr/core/grr_response_core/lib/config_parser.py b/grr/core/grr_response_core/lib/config_parser.py index 7f2266d7b9..cbe6966b58 100644 --- a/grr/core/grr_response_core/lib/config_parser.py +++ b/grr/core/grr_response_core/lib/config_parser.py @@ -43,8 +43,10 @@ def __init__(self, config_path: str) -> None: self._config_path = config_path def __str__(self) -> str: - return "<%s config_path=\"%s\">" % (self.__class__.__name__, - self._config_path) + return '<%s config_path="%s">' % ( + self.__class__.__name__, + self._config_path, + ) @property def config_path(self) -> str: @@ -103,8 +105,9 @@ def RawDataToBytes(self, raw_data: Dict[str, Any]) -> bytes: def RawDataFromBytes(self, b: bytes) -> Dict[str, Any]: raise NotImplementedError() - def SaveDataToFD(self, raw_data: Dict[str, Any], - fd: io.BufferedWriter) -> None: + def SaveDataToFD( + self, raw_data: 
Dict[str, Any], fd: io.BufferedWriter + ) -> None: fd.write(self.RawDataToBytes(raw_data)) def ReadDataFromFD(self, fd: BinaryIO) -> Dict[str, Any]: @@ -131,10 +134,12 @@ def SaveData(self, raw_data: Dict[str, Any]) -> None: config_file.write(self.RawDataToBytes(raw_data)) except OSError as e: - logging.exception("Unable to write config file %s: %s.", self.config_path, - e) + logging.exception( + "Unable to write config file %s: %s.", self.config_path, e + ) raise SaveDataError( - f"Unable to write config file {self.config_path}: {e}.") from e + f"Unable to write config file {self.config_path}: {e}." + ) from e def ReadData(self) -> Dict[str, Any]: if not self.config_path: diff --git a/grr/core/grr_response_core/lib/config_parser_test.py b/grr/core/grr_response_core/lib/config_parser_test.py index 0c5ef3bad0..a75aec633e 100644 --- a/grr/core/grr_response_core/lib/config_parser_test.py +++ b/grr/core/grr_response_core/lib/config_parser_test.py @@ -4,7 +4,6 @@ import os import platform import stat - from typing import Any, Dict from absl.testing import absltest @@ -145,9 +144,12 @@ def testRaisesOnSave(self): def testForwardsDataToNestedParserOnRead(self): p = config_parser.FileParserDataWrapper(b"foo", StubFileParser("")) - self.assertEqual(p.ReadData(), { - "from_bytes": b"foo", - }) + self.assertEqual( + p.ReadData(), + { + "from_bytes": b"foo", + }, + ) if __name__ == "__main__": diff --git a/grr/core/grr_response_core/lib/config_testing_lib.py b/grr/core/grr_response_core/lib/config_testing_lib.py index c3aeef5728..ea7e2184fa 100644 --- a/grr/core/grr_response_core/lib/config_testing_lib.py +++ b/grr/core/grr_response_core/lib/config_testing_lib.py @@ -1,12 +1,10 @@ #!/usr/bin/env python """Helper library for config testing.""" - import copy import logging from unittest import mock - from grr_response_core import config from grr_response_core.lib import config_lib from grr_response_core.lib import utils @@ -17,12 +15,8 @@ class BuildConfigTestsBase(test_lib.GRRBaseTest): """Base for config functionality tests.""" exceptions = [ - # Server configuration files do not normally have valid client keys. - "Client.private_key", # Those keys are maybe passphrase protected so we need to skip. 
- "PrivateKeys.ca_key", "PrivateKeys.executable_signing_private_key", - "PrivateKeys.server_key", ] # For all the resource filters to work you need the grr-response-templates @@ -39,8 +33,9 @@ def ValidateConfig(self, config_file=None): conf_obj = config.CONFIG.MakeNewConfig() conf_obj.Initialize(filename=config_file, reset=True) - with utils.MultiStubber((config, "CONFIG", conf_obj), - (config_lib, "_CONFIG", conf_obj)): + with utils.MultiStubber( + (config, "CONFIG", conf_obj), (config_lib, "_CONFIG", conf_obj) + ): all_sections = conf_obj.GetSections() errors = conf_obj.Validate(sections=all_sections) @@ -51,8 +46,9 @@ def ValidateConfigs(self, configs): for filter_name in self.disabled_filters: test_filter_map[filter_name] = config_lib.ConfigFilter - with mock.patch.object(config_lib.ConfigFilter, "classes_by_name", - test_filter_map): + with mock.patch.object( + config_lib.ConfigFilter, "classes_by_name", test_filter_map + ): for config_file in configs: errors = self.ValidateConfig(config_file) @@ -65,5 +61,6 @@ def ValidateConfigs(self, configs): logging.info("%s:", config_entry) logging.info("%s", error) - self.fail("Validation of %s returned errors: %s" % (config_file, - errors)) + self.fail( + "Validation of %s returned errors: %s" % (config_file, errors) + ) diff --git a/grr/core/grr_response_core/lib/config_validator_base.py b/grr/core/grr_response_core/lib/config_validator_base.py index 0d1691ab9d..90f57f6bfa 100644 --- a/grr/core/grr_response_core/lib/config_validator_base.py +++ b/grr/core/grr_response_core/lib/config_validator_base.py @@ -9,6 +9,7 @@ class PrivateConfigValidator(metaclass=MetaclassRegistry): """Use this class to sanity check private config options at repack time.""" + __abstract = True # pylint: disable=g-bad-name def ValidateEndConfig(self, conf, context, errors_fatal=True): diff --git a/grr/core/grr_response_core/lib/constants.py b/grr/core/grr_response_core/lib/constants.py index b77d31ceb7..693ca7b105 100644 --- a/grr/core/grr_response_core/lib/constants.py +++ b/grr/core/grr_response_core/lib/constants.py @@ -4,30 +4,41 @@ # Special folders we want to report back for each user. The format here is: # registry key, folder name (relative to ProfileImagePath), protobuf name. 
profile_folders = [ - ("AppData", "Application Data", - "app_data"), ("Cache", "AppData\\Local\\Microsoft\\Windows\\" - "Temporary Internet Files", - "cache"), ("Cookies", "Cookies", - "cookies"), ("Desktop", "Desktop", "desktop"), - ("Favorites", "Favorites", - "favorites"), ("History", "AppData\\Local\\Microsoft\\Windows\\History", - "history"), ("Local AppData", "AppData\\Roaming", - "local_app_data"), ("My Music", "Music", - "my_music"), - ("My Pictures", "Pictures", - "my_pictures"), ("My Video", "Videos", - "my_video"), ("NetHood", "NetHood", - "net_hood"), ("Personal", "Documents", - "personal"), - ("PrintHood", "PrintHood", - "print_hood"), ("Programs", "AppData\\Roaming\\Microsoft\\Windows\\" - "Start Menu\\Programs", - "programs"), ("Recent", "Recent", - "recent"), ("SendTo", "SendTo", "send_to"), - ("Start Menu", "AppData\\Roaming\\Microsoft\\Windows\\Start Menu", - "start_menu"), ("Startup", "AppData\\Roaming\\Microsoft\\Windows\\" - "Start Menu\\Programs\\Startup", - "startup"), ("Templates", "Templates", "templates") + ("AppData", "Application Data", "app_data"), + ( + "Cache", + "AppData\\Local\\Microsoft\\Windows\\Temporary Internet Files", + "cache", + ), + ("Cookies", "Cookies", "cookies"), + ("Desktop", "Desktop", "desktop"), + ("Favorites", "Favorites", "favorites"), + ("History", "AppData\\Local\\Microsoft\\Windows\\History", "history"), + ("Local AppData", "AppData\\Roaming", "local_app_data"), + ("My Music", "Music", "my_music"), + ("My Pictures", "Pictures", "my_pictures"), + ("My Video", "Videos", "my_video"), + ("NetHood", "NetHood", "net_hood"), + ("Personal", "Documents", "personal"), + ("PrintHood", "PrintHood", "print_hood"), + ( + "Programs", + "AppData\\Roaming\\Microsoft\\Windows\\Start Menu\\Programs", + "programs", + ), + ("Recent", "Recent", "recent"), + ("SendTo", "SendTo", "send_to"), + ( + "Start Menu", + "AppData\\Roaming\\Microsoft\\Windows\\Start Menu", + "start_menu", + ), + ( + "Startup", + "AppData\\Roaming\\Microsoft\\Windows\\Start Menu\\Programs\\Startup", + "startup", + ), + ("Templates", "Templates", "templates"), ] CLIENT_MAX_BUFFER_SIZE = 640 * 1024 diff --git a/grr/core/grr_response_core/lib/factory.py b/grr/core/grr_response_core/lib/factory.py index bcf3b459d8..938724fb75 100644 --- a/grr/core/grr_response_core/lib/factory.py +++ b/grr/core/grr_response_core/lib/factory.py @@ -45,10 +45,12 @@ def __init__(self, cls: Type[T]): self._cls: Type[T] = cls self._entries: Dict[str, _FactoryEntry[T]] = {} - def Register(self, - name: Text, - cls: Type[T], - constructor: Optional[Callable[[], T]] = None): + def Register( + self, + name: Text, + cls: Type[T], + constructor: Optional[Callable[[], T]] = None, + ): """Registers a new constructor in the factory. Args: diff --git a/grr/core/grr_response_core/lib/fingerprint.py b/grr/core/grr_response_core/lib/fingerprint.py index aca296d1b1..0995302d94 100644 --- a/grr/core/grr_response_core/lib/fingerprint.py +++ b/grr/core/grr_response_core/lib/fingerprint.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - # Copyright 2011 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,7 +37,7 @@ # pylint: enable=g-bad-name -class Finger(object): +class Finger: """A Finger defines how to hash a file to get specific fingerprints. 
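# Illustrative sketch (not part of the patch): the usual driving sequence for
# the Fingerprinter class defined below in fingerprint.py, as also exercised by
# fingerprint_test.py further down. The input path is hypothetical.
import io
from grr_response_core.lib import fingerprint

with io.open("/tmp/sample.exe", "rb") as f:
  fp = fingerprint.Fingerprinter(f)
  fp.EvalGeneric()       # md5/sha1/sha256/sha512 over the whole file
  fp.EvalPecoff()        # authenticode-relevant ranges, if the file is PE/COFF
  results = fp.HashIt()  # one result dict per evaluated finger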
The Finger contains one or more hash functions, a set of ranges in the @@ -101,7 +100,7 @@ def HashBlock(self, block): hasher.update(block) -class Fingerprinter(object): +class Fingerprinter: """Compute different types of cryptographic hashes over a file. Depending on type of file and mode of invocation, filetype-specific or @@ -120,8 +119,12 @@ class Fingerprinter(object): """ BLOCK_SIZE = 1000000 - GENERIC_HASH_CLASSES = (hashlib.md5, hashlib.sha1, hashlib.sha256, - hashlib.sha512) + GENERIC_HASH_CLASSES = ( + hashlib.md5, + hashlib.sha1, + hashlib.sha256, + hashlib.sha512, + ) AUTHENTICODE_HASH_CLASSES = (hashlib.md5, hashlib.sha1) def __init__(self, file_obj): @@ -179,9 +182,11 @@ def _HashBlock(self, block, start, end): expected_range = finger.CurrentRange() if expected_range is None: continue - if (start > expected_range.start or - (start == expected_range.start and end > expected_range.end) or - (start < expected_range.start and end > expected_range.start)): + if ( + start > expected_range.start + or (start == expected_range.start and end > expected_range.end) + or (start < expected_range.start and end > expected_range.start) + ): raise RuntimeError('Cutting across fingers.') if start == expected_range.start: finger.HashBlock(block) @@ -221,8 +226,11 @@ def HashIt(self): res = {} leftover = finger.CurrentRange() if leftover: - if (len(finger.ranges) > 1 or leftover.start != self.filelength or - leftover.end != self.filelength): + if ( + len(finger.ranges) > 1 + or leftover.start != self.filelength + or leftover.end != self.filelength + ): raise RuntimeError('Non-empty range remains.') res.update(finger.metadata) for hasher in finger.hashers: @@ -243,10 +251,10 @@ def EvalGeneric(self, hashers=None): is passed through a pre-defined (or user defined) set of hash functions. Args: - hashers: An iterable of hash classes (e.g. out of hashlib) which will - be instantiated for use. If hashers is not provided, or is - provided as 'None', the default hashers will get used. To - invoke this without hashers, provide an empty list. + hashers: An iterable of hash classes (e.g. out of hashlib) which will be + instantiated for use. If hashers is not provided, or is provided as + 'None', the default hashers will get used. To invoke this without + hashers, provide an empty list. Returns: Always True, as all files are 'generic' files. @@ -301,11 +309,11 @@ def _PecoffHeaderParser(self): self.file.seek(optional_header_offset, os.SEEK_SET) buf = self.file.read(2) image_magic = struct.unpack(' self.filelength): + if ( + length == 0 + or start < optional_header_offset + optional_header_size + or start + length > self.filelength + ): # The location of the SignedData blob is just wrong (or there is none). # Ignore it -- everything else we did still makes sense. return extents @@ -346,7 +359,7 @@ def _CollectSignedData(self, extent): # If the entire blob is smaller than its header, bail out. return signed_data b_cert = buf[8:dw_length] - buf = buf[(dw_length + 7) & 0x7ffffff8:] + buf = buf[(dw_length + 7) & 0x7FFFFFF8 :] signed_data.append((w_revision, w_cert_type, b_cert)) return signed_data @@ -361,10 +374,10 @@ def EvalPecoff(self, hashers=None): parts is added to results by HashIt() Args: - hashers: An iterable of hash classes (e.g. out of hashlib) which will - be instantiated for use. If 'None' is provided, a default set - of hashers is used. To select no hash function (e.g. to only - extract metadata), use an empty iterable. + hashers: An iterable of hash classes (e.g. 
out of hashlib) which will be + instantiated for use. If 'None' is provided, a default set of hashers is + used. To select no hash function (e.g. to only extract metadata), use an + empty iterable. Returns: True if the file is detected as a valid PE/COFF image file, diff --git a/grr/core/grr_response_core/lib/fingerprint_test.py b/grr/core/grr_response_core/lib/fingerprint_test.py index 5dd4c1346b..2f09daa598 100644 --- a/grr/core/grr_response_core/lib/fingerprint_test.py +++ b/grr/core/grr_response_core/lib/fingerprint_test.py @@ -1,11 +1,11 @@ #!/usr/bin/env python """Tests for config_lib classes.""" - import io import os from absl.testing import absltest + from grr_response_core.lib import fingerprint from grr_response_core.lib import package @@ -44,7 +44,7 @@ def testAdjustments(self): fp._AdjustIntervals(11, 20) self.assertEmpty(fp.fingers[0].ranges) - class MockHasher(object): + class MockHasher: def __init__(self): self.seen = b'' @@ -56,8 +56,9 @@ def testHashBlock(self): # Does it invoke a hash function? dummy = b'12345' fp = fingerprint.Fingerprinter(io.BytesIO(dummy)) - big_finger = fingerprint.Finger(None, [fingerprint.Range(0, len(dummy))], - None) + big_finger = fingerprint.Finger( + None, [fingerprint.Range(0, len(dummy))], None + ) hasher = self.MockHasher() big_finger.hashers = [hasher] fp.fingers.append(big_finger) @@ -69,15 +70,17 @@ def testSampleDataParsedCorrectly(self): for fname, expected in self.SAMPLE_LIST.items(): path = package.ResourcePath( 'grr-response-test', - os.path.join('grr_response_test', 'test_data', 'fingerprint', fname)) + os.path.join('grr_response_test', 'test_data', 'fingerprint', fname), + ) with io.open(path, 'rb') as f: fp = fingerprint.Fingerprinter(f) fp.EvalGeneric() fp.EvalPecoff() result = fp.HashIt() - self.assertCountEqual(result, expected, - 'Hashing results for %s do not match.' % fname) + self.assertCountEqual( + result, expected, 'Hashing results for %s do not match.' 
% fname + ) # pyformat: disable SAMPLE_DATA_1 = [{ diff --git a/grr/core/grr_response_core/lib/interpolation_test.py b/grr/core/grr_response_core/lib/interpolation_test.py index c3281c9162..35c7001230 100644 --- a/grr/core/grr_response_core/lib/interpolation_test.py +++ b/grr/core/grr_response_core/lib/interpolation_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl.testing import absltest from grr_response_core.lib import interpolation @@ -41,7 +40,8 @@ def testMultipleVarByteString(self): subst = interpolation.Substitution(var_config=var_config, scope_config={}) self.assertEqual( - subst.Substitute(b"%%foo%% %%bar%% %%baz%%"), b"42 BAR 1337") + subst.Substitute(b"%%foo%% %%bar%% %%baz%%"), b"42 BAR 1337" + ) def testSimpleScopeUnicodeString(self): scope_config = {sid("foo"): {vid("bar"): "BAR", vid("baz"): "BAZ"}} @@ -57,14 +57,8 @@ def testSimpleScopeByteString(self): def testMultipleScopeUnicodeString(self): scope_config = { - sid("foo"): { - vid("bar"): "BAR", - vid("baz"): "BAZ" - }, - sid("quux"): { - vid("norf"): "NORF", - vid("thud"): "THUD" - }, + sid("foo"): {vid("bar"): "BAR", vid("baz"): "BAZ"}, + sid("quux"): {vid("norf"): "NORF", vid("thud"): "THUD"}, } subst = interpolation.Substitution(var_config={}, scope_config=scope_config) @@ -73,14 +67,8 @@ def testMultipleScopeUnicodeString(self): def testMultipleScopeByteString(self): scope_config = { - sid("foo"): { - vid("bar"): 2, - vid("baz"): 3 - }, - sid("quux"): { - vid("norf"): 5, - vid("thud"): 7 - }, + sid("foo"): {vid("bar"): 2, vid("baz"): 3}, + sid("quux"): {vid("norf"): 5, vid("thud"): 7}, } subst = interpolation.Substitution(var_config={}, scope_config=scope_config) @@ -91,7 +79,8 @@ def testVarAndScope(self): var_config = {vid("foo"): "FOO"} scope_config = {sid("quux"): {vid("bar"): "BAR", vid("baz"): "BAZ"}} subst = interpolation.Substitution( - var_config=var_config, scope_config=scope_config) + var_config=var_config, scope_config=scope_config + ) pattern = "%%foo%% %%quux.bar%% %%quux.baz%%" self.assertEqual(subst.Substitute(pattern), "FOO BAR BAZ") @@ -104,13 +93,11 @@ def testMultipleVariableOccurrences(self): def testInterpolationHappensSimultaneously(self): var_config = {vid("foo"): "%%bar%%", vid("bar"): "%%quux.norf%%"} scope_config = { - sid("quux"): { - vid("norf"): "%%foo%%", - vid("thud"): "%%quux.norf%%" - } + sid("quux"): {vid("norf"): "%%foo%%", vid("thud"): "%%quux.norf%%"} } subst = interpolation.Substitution( - var_config=var_config, scope_config=scope_config) + var_config=var_config, scope_config=scope_config + ) pattern = "%%foo%% %%bar%% %%quux.norf%% %%quux.thud%%" output = "%%bar%% %%quux.norf%% %%foo%% %%quux.norf%%" @@ -143,11 +130,13 @@ def testListScopeVarsUnicodeString(self): def testListScopeVarsByteString(self): interpolator = interpolation.Interpolator(b"%%foo.A%% %%foo.B%% %%foo.C%%") self.assertEqual( - interpolator.ScopeVars(sid("foo")), { + interpolator.ScopeVars(sid("foo")), + { vid("A"), vid("B"), vid("C"), - }) + }, + ) def testBindVarSimpleUnicodeString(self): interpolator = interpolation.Interpolator("foo %%bar%% baz") @@ -251,16 +240,19 @@ def testBindScopeAndVarUnicodeString(self): interpolator.BindScope(sid("norf"), {vid("thud"): 9, vid("blargh"): 0}) strings = list(interpolator.Interpolate()) - self.assertCountEqual(strings, [ - "3|1|4|7|8", - "3|1|4|9|0", - "3|2|4|7|8", - "3|2|4|9|0", - "5|1|6|7|8", - "5|1|6|9|0", - "5|2|6|7|8", - "5|2|6|9|0", - ]) + self.assertCountEqual( + strings, + [ + "3|1|4|7|8", + "3|1|4|9|0", + "3|2|4|7|8", + "3|2|4|9|0", + "5|1|6|7|8", + 
"5|1|6|9|0", + "5|2|6|7|8", + "5|2|6|9|0", + ], + ) def testBindScopeKeyErrorScope(self): interpolator = interpolation.Interpolator("%%foo.bar%%") diff --git a/grr/core/grr_response_core/lib/lexer.py b/grr/core/grr_response_core/lib/lexer.py index 48fc158310..92d0fb1848 100644 --- a/grr/core/grr_response_core/lib/lexer.py +++ b/grr/core/grr_response_core/lib/lexer.py @@ -3,15 +3,13 @@ import logging import re -from typing import Text - from grr_response_core.lib import utils from grr_response_core.lib.util import precondition from grr_response_core.lib.util import text -class Token(object): +class Token: """A token action.""" state_regex = None @@ -20,20 +18,20 @@ def __init__(self, state_regex, regex, actions, next_state, flags=re.I): """Constructor. Args: - state_regex: If this regular expression matches the current state this - rule is considered. + rule is considered. regex: A regular expression to try and match from the current point. actions: A command separated list of method names in the Lexer to call. next_state: The next state we transition to if this Token matches. flags: re flags. """ - precondition.AssertType(regex, Text) - precondition.AssertOptionalType(state_regex, Text) + precondition.AssertType(regex, str) + precondition.AssertOptionalType(state_regex, str) if state_regex: - self.state_regex = re.compile(state_regex, re.DOTALL | re.M | re.S | re.U - | flags) + self.state_regex = re.compile( + state_regex, re.DOTALL | re.M | re.S | re.U | flags + ) self.regex = re.compile(regex, re.DOTALL | re.M | re.S | re.U | flags) self.re_str = regex @@ -55,15 +53,16 @@ class ParseError(Error): """A parse error occurred.""" -class Lexer(object): +class Lexer: """A generic feed lexer.""" + # A list of Token() instances. tokens = [] # Regex flags flags = 0 def __init__(self, data=""): - precondition.AssertType(data, Text) + precondition.AssertType(data, str) # Set the lexer up to process a new data feed. self.Reset() # Populate internal token list with class tokens, if defined. 
@@ -99,8 +98,12 @@ def NextToken(self): continue if self.verbose: - logging.debug("%s: Trying to match %r with %r", self.state, - self.buffer[:10], token.re_str) + logging.debug( + "%s: Trying to match %r with %r", + self.state, + self.buffer[:10], + token.re_str, + ) # Try to match the rule m = token.regex.match(self.buffer) @@ -117,8 +120,8 @@ def NextToken(self): # The match consumes the data off the buffer (the handler can put it back # if it likes) - self.processed_buffer += self.buffer[:m.end()] - self.buffer = self.buffer[m.end():] + self.processed_buffer += self.buffer[: m.end()] + self.buffer = self.buffer[m.end() :] self.processed += m.end() next_state = token.next_state @@ -154,7 +157,7 @@ def NextToken(self): return "Error" def Feed(self, data): - precondition.AssertType(data, Text) + precondition.AssertType(data, str) self.buffer += data def Empty(self): @@ -187,9 +190,9 @@ def PopState(self, **_): def PushBack(self, string="", **_): """Push the match back on the stream.""" - precondition.AssertType(string, Text) + precondition.AssertType(string, str) self.buffer = string + self.buffer - self.processed_buffer = self.processed_buffer[:-len(string)] + self.processed_buffer = self.processed_buffer[: -len(string)] def Close(self): """A convenience function to force us to parse all the data.""" @@ -198,8 +201,9 @@ def Close(self): return -class Expression(object): +class Expression: """A class representing an expression.""" + attribute = None args = None operator = None @@ -238,8 +242,11 @@ def AddArg(self, arg): return False def __str__(self): - return "Expression: (%s) (%s) %s" % (self.attribute, self.operator, - self.args) + return "Expression: (%s) (%s) %s" % ( + self.attribute, + self.operator, + self.args, + ) def PrintTree(self, depth=""): return "%s %s" % (depth, self) @@ -247,7 +254,8 @@ def PrintTree(self, depth=""): def Compile(self, filter_implemention): """Given a filter implementation, compile this expression.""" raise NotImplementedError( - "%s does not implement Compile." % self.__class__.__name__) + "%s does not implement Compile." % self.__class__.__name__ + ) class BinaryExpression(Expression): @@ -260,9 +268,11 @@ def __init__(self, operator="", part=None): self.args.append(part) super().__init__() - def __str__(self) -> Text: - return "Binary Expression: %s %s" % (self.operator, - [str(x) for x in self.args]) + def __str__(self) -> str: + return "Binary Expression: %s %s" % ( + self.operator, + [str(x) for x in self.args], + ) def AddOperands(self, lhs, rhs): if isinstance(lhs, Expression) and isinstance(rhs, Expression): @@ -270,7 +280,8 @@ def AddOperands(self, lhs, rhs): self.args.append(rhs) else: raise ParseError( - "Expected expression, got %s %s %s" % (lhs, self.operator, rhs)) + "Expected expression, got %s %s %s" % (lhs, self.operator, rhs) + ) def PrintTree(self, depth=""): result = "%s%s\n" % (depth, self.operator) @@ -316,15 +327,13 @@ class SearchParser(Lexer): tokens = [ # Double quoted string - Token("STRING", "\"", "PopState,StringFinish", None), + Token("STRING", '"', "PopState,StringFinish", None), Token("STRING", r"\\(.)", "StringEscape", None), Token("STRING", r"[^\\\"]+", "StringInsert", None), - # Single quoted string Token("SQ_STRING", "'", "PopState,StringFinish", None), Token("SQ_STRING", r"\\(.)", "StringEscape", None), Token("SQ_STRING", r"[^\\']+", "StringInsert", None), - # TODO(user): Implement a unary not operator. 
# The first thing we see in the initial state takes up to the ATTRIBUTE Token("INITIAL", r"(and|or|\&\&|\|\|)", "BinaryOperator", None), @@ -332,15 +341,14 @@ class SearchParser(Lexer): Token("INITIAL", r"\(", "BracketOpen", None), Token("INITIAL", r"\)", "BracketClose", None), Token("ATTRIBUTE", r"[\w._0-9]+", "StoreAttribute", "OPERATOR"), - Token("OPERATOR", r"[a-z0-9<>=\-\+\!\^\&%]+", "StoreOperator", - "ARG_LIST"), + Token( + "OPERATOR", r"[a-z0-9<>=\-\+\!\^\&%]+", "StoreOperator", "ARG_LIST" + ), Token("OPERATOR", "(!=|[<>=])", "StoreSpecialOperator", "ARG_LIST"), Token("ARG_LIST", r"[^\s'\"]+", "InsertArg", None), - # Start a string. - Token(".", "\"", "PushState,StringStart", "STRING"), + Token(".", '"', "PushState,StringStart", "STRING"), Token(".", "'", "PushState,StringStart", "SQ_STRING"), - # Skip whitespace. Token(".", r"\s+", None, None), ] @@ -376,7 +384,7 @@ def StringEscape(self, string, match, **_): string: The string that matched. match: The match object (m.group(1) is the escaped code) """ - precondition.AssertType(string, Text) + precondition.AssertType(string, str) if match.group(1) in "'\"rnbt": self.string += text.Unescape(string) else: @@ -423,9 +431,12 @@ def InsertArg(self, string="", **_): def _CombineBinaryExpressions(self, operator): for i in range(1, len(self.stack) - 1): item = self.stack[i] - if (isinstance(item, BinaryExpression) and item.operator == operator and - isinstance(self.stack[i - 1], Expression) and - isinstance(self.stack[i + 1], Expression)): + if ( + isinstance(item, BinaryExpression) + and item.operator == operator + and isinstance(self.stack[i - 1], Expression) + and isinstance(self.stack[i + 1], Expression) + ): lhs = self.stack[i - 1] rhs = self.stack[i + 1] @@ -437,8 +448,11 @@ def _CombineBinaryExpressions(self, operator): def _CombineParenthesis(self): for i in range(len(self.stack) - 2): - if (self.stack[i] == "(" and self.stack[i + 2] == ")" and - isinstance(self.stack[i + 1], Expression)): + if ( + self.stack[i] == "(" + and self.stack[i + 2] == ")" + and isinstance(self.stack[i + 1], Expression) + ): self.stack[i] = None self.stack[i + 2] = None @@ -468,9 +482,15 @@ def Reduce(self): return self.stack[0] def Error(self, message=None, weight=1): - raise ParseError(u"%s in position %s: %s <----> %s )" % - (utils.SmartUnicode(message), len(self.processed_buffer), - self.processed_buffer, self.buffer)) + raise ParseError( + "%s in position %s: %s <----> %s )" + % ( + utils.SmartUnicode(message), + len(self.processed_buffer), + self.processed_buffer, + self.buffer, + ) + ) def Parse(self): if not self.filter_string: diff --git a/grr/core/grr_response_core/lib/lexer_test.py b/grr/core/grr_response_core/lib/lexer_test.py index 32b1217eac..05309ff8ff 100644 --- a/grr/core/grr_response_core/lib/lexer_test.py +++ b/grr/core/grr_response_core/lib/lexer_test.py @@ -66,7 +66,8 @@ def testFailedParser(self): for expression in ( """filename contains "foo""", # Unterminated string """(filename contains "foo" """, # Unbalanced parenthesis - """filename contains foo or """): # empty right expression + """filename contains foo or """, + ): # empty right expression parser = lexer.SearchParser(expression) self.assertRaises(lexer.ParseError, parser.Parse) diff --git a/grr/core/grr_response_core/lib/parser.py b/grr/core/grr_response_core/lib/parser.py index 6d3023acb6..e8419b9684 100644 --- a/grr/core/grr_response_core/lib/parser.py +++ b/grr/core/grr_response_core/lib/parser.py @@ -17,7 +17,6 @@ class CommandParser(abstract.SingleResponseParser[Any]): 
"""Abstract parser for processing command output. Must implement the Parse function. - """ # TODO(hanuszczak): This should probably be abstract or private. @@ -33,13 +32,16 @@ def ParseResponse(self, knowledge_base, response): stdout=response.stdout, stderr=response.stderr, return_val=response.exit_status, - knowledge_base=knowledge_base) + knowledge_base=knowledge_base, + ) def CheckReturn(self, cmd, return_val): """Raise if return value is bad.""" if return_val != 0: - message = ("Parsing output of command '{command}' failed, as command had " - "{code} return code") + message = ( + "Parsing output of command '{command}' failed, as command had " + "{code} return code" + ) raise abstract.ParseError(message.format(command=cmd, code=return_val)) @@ -70,8 +72,9 @@ def Parse(self, stat, knowledge_base): def ParseResponse(self, knowledge_base, response): # TODO(hanuszczak): Why some of the registry value parsers anticipate string # response? This is stupid. - precondition.AssertType(response, - (rdf_client_fs.StatEntry, rdfvalue.RDFString)) + precondition.AssertType( + response, (rdf_client_fs.StatEntry, rdfvalue.RDFString) + ) return self.Parse(response, knowledge_base) diff --git a/grr/core/grr_response_core/lib/parsers/__init__.py b/grr/core/grr_response_core/lib/parsers/__init__.py index 4118180627..5223f06f84 100644 --- a/grr/core/grr_response_core/lib/parsers/__init__.py +++ b/grr/core/grr_response_core/lib/parsers/__init__.py @@ -1,9 +1,7 @@ #!/usr/bin/env python """Generic parsers (for GRR server and client code).""" -from typing import Iterator -from typing import Text -from typing import Type -from typing import TypeVar + +from typing import Iterator, Type, TypeVar from grr_response_core.lib import factory from grr_response_core.lib import rdfvalue @@ -23,37 +21,45 @@ _RDFValue = rdfvalue.RDFValue SINGLE_RESPONSE_PARSER_FACTORY: _Factory[SingleResponseParser[_RDFValue]] = ( - _Factory(SingleResponseParser[_RDFValue])) + _Factory(SingleResponseParser[_RDFValue]) +) MULTI_RESPONSE_PARSER_FACTORY: _Factory[MultiResponseParser[_RDFValue]] = ( - _Factory(MultiResponseParser[_RDFValue])) + _Factory(MultiResponseParser[_RDFValue]) +) -SINGLE_FILE_PARSER_FACTORY: _Factory[SingleFileParser[_RDFValue]] = ( - _Factory(SingleFileParser[_RDFValue])) +SINGLE_FILE_PARSER_FACTORY: _Factory[SingleFileParser[_RDFValue]] = _Factory( + SingleFileParser[_RDFValue] +) -MULTI_FILE_PARSER_FACTORY: _Factory[MultiFileParser[_RDFValue]] = ( - _Factory(MultiFileParser[_RDFValue])) +MULTI_FILE_PARSER_FACTORY: _Factory[MultiFileParser[_RDFValue]] = _Factory( + MultiFileParser[_RDFValue] +) _P = TypeVar("_P", bound=Parser) -class ArtifactParserFactory(object): +class ArtifactParserFactory: """A factory wrapper class that yields parsers for specific artifact.""" - def __init__(self, artifact_name: Text) -> None: + def __init__(self, artifact_name: str) -> None: """Initializes the artifact parser factory. Args: artifact_name: A name of the artifact this factory is supposed to provide parser instances for. 
""" - precondition.AssertType(artifact_name, Text) + precondition.AssertType(artifact_name, str) self._artifact_name = artifact_name def HasParsers(self) -> bool: - return (self.HasSingleResponseParsers() or self.HasMultiResponseParsers() or - self.HasSingleFileParsers() or self.HasMultiFileParsers()) + return ( + self.HasSingleResponseParsers() + or self.HasMultiResponseParsers() + or self.HasSingleFileParsers() + or self.HasMultiFileParsers() + ) def HasSingleResponseParsers(self) -> bool: return any(self.SingleResponseParserTypes()) @@ -65,7 +71,8 @@ def SingleResponseParserNames(self) -> Iterator[str]: return self._SupportedNames(SINGLE_RESPONSE_PARSER_FACTORY) def SingleResponseParserTypes( - self) -> Iterator[Type[SingleResponseParser[_RDFValue]]]: + self, + ) -> Iterator[Type[SingleResponseParser[_RDFValue]]]: return self._SupportedTypes(SINGLE_RESPONSE_PARSER_FACTORY) def HasMultiResponseParsers(self) -> bool: @@ -78,7 +85,8 @@ def MultiResponseParserNames(self) -> Iterator[str]: return self._SupportedNames(MULTI_RESPONSE_PARSER_FACTORY) def MultiResponseParserTypes( - self) -> Iterator[Type[MultiResponseParser[_RDFValue]]]: + self, + ) -> Iterator[Type[MultiResponseParser[_RDFValue]]]: return self._SupportedTypes(MULTI_RESPONSE_PARSER_FACTORY) def HasSingleFileParsers(self) -> bool: @@ -91,7 +99,8 @@ def SingleFileParserNames(self) -> Iterator[str]: return self._SupportedNames(SINGLE_FILE_PARSER_FACTORY) def SingleFileParserTypes( - self) -> Iterator[Type[SingleFileParser[_RDFValue]]]: + self, + ) -> Iterator[Type[SingleFileParser[_RDFValue]]]: return self._SupportedTypes(SINGLE_FILE_PARSER_FACTORY) def HasMultiFileParsers(self) -> bool: diff --git a/grr/core/grr_response_core/lib/parsers/all.py b/grr/core/grr_response_core/lib/parsers/all.py index 94228a06a6..09978a1e47 100644 --- a/grr/core/grr_response_core/lib/parsers/all.py +++ b/grr/core/grr_response_core/lib/parsers/all.py @@ -2,18 +2,10 @@ """A module for registering all known parsers.""" from grr_response_core.lib import parsers -from grr_response_core.lib.parsers import chrome_history -from grr_response_core.lib.parsers import config_file from grr_response_core.lib.parsers import cron_file_parser -from grr_response_core.lib.parsers import firefox3_history -from grr_response_core.lib.parsers import ie_history from grr_response_core.lib.parsers import linux_cmd_parser from grr_response_core.lib.parsers import linux_file_parser -from grr_response_core.lib.parsers import linux_pam_parser from grr_response_core.lib.parsers import linux_release_parser -from grr_response_core.lib.parsers import linux_service_parser -from grr_response_core.lib.parsers import linux_software_parser -from grr_response_core.lib.parsers import linux_sysctl_parser from grr_response_core.lib.parsers import osx_file_parser from grr_response_core.lib.parsers import osx_launchd from grr_response_core.lib.parsers import windows_persistence @@ -30,32 +22,12 @@ def Register(): "Dpkg", linux_cmd_parser.DpkgCmdParser) parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( "Dmidecode", linux_cmd_parser.DmidecodeCmdParser) - parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( - "Mount", config_file.MountCmdParser) parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( "OsxSpHardware", osx_file_parser.OSXSPHardwareDataTypeParser) - parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( - "Ps", linux_cmd_parser.PsCmdParser) parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( "Rpm", linux_cmd_parser.RpmCmdParser) - parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( - "SshdConfig", 
config_file.SshdConfigCmdParser) - parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( - "Sysctl", linux_sysctl_parser.SysctlCmdParser) - parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( - "YumList", linux_cmd_parser.YumListCmdParser) - parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( - "YumRepolist", linux_cmd_parser.YumRepolistCmdParser) - - # Grep parsers. - parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( - "Passwd", linux_file_parser.PasswdBufferParser) - parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( - "Netgroup", linux_file_parser.NetgroupBufferParser) # WMI query parsers. - parsers.MULTI_RESPONSE_PARSER_FACTORY.Register( - "WmiEventConsumer", wmi_parser.WMIEventConsumerParser) parsers.MULTI_RESPONSE_PARSER_FACTORY.Register( "WmiInstalledSoftware", wmi_parser.WMIInstalledSoftwareParser) parsers.MULTI_RESPONSE_PARSER_FACTORY.Register( @@ -74,8 +46,6 @@ def Register(): "WinCodepage", windows_registry_parser.CodepageParser) parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( "WinEnvironment", windows_registry_parser.WinEnvironmentParser) - parsers.MULTI_RESPONSE_PARSER_FACTORY.Register( - "WinServices", windows_registry_parser.WinServicesParser) parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( "WinSystemDrive", windows_registry_parser.WinSystemDriveParser) parsers.SINGLE_RESPONSE_PARSER_FACTORY.Register( @@ -102,80 +72,18 @@ def Register(): "WindowsPersistenceMechanism", windows_persistence.WindowsPersistenceMechanismsParser) - # Registry multi-parsers. - parsers.MULTI_RESPONSE_PARSER_FACTORY.Register( - "WinUserSpecialDirs", windows_registry_parser.WinUserSpecialDirs) - parsers.MULTI_RESPONSE_PARSER_FACTORY.Register( - "WindowsRegistryInstalledSoftware", - windows_registry_parser.WindowsRegistryInstalledSoftwareParser) - - # Artifact file multi-parsers. - parsers.MULTI_RESPONSE_PARSER_FACTORY.Register( - "OsxUsers", osx_file_parser.OSXUsersParser) - # File parsers. 
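# Illustrative sketch: parsers are wired up by registering a class in one of
# the module-level factories under a parser name. The class, artifact name and
# parser name below are made-up placeholders; the Register() call, the base
# class and the ParseFile() signature mirror the ones used in this file.
from grr_response_core.lib import parsers
from grr_response_core.lib import utils
from grr_response_core.lib.rdfvalues import protodict as rdf_protodict


class MyConfigParser(parsers.SingleFileParser[rdf_protodict.AttributedDict]):
  """Hypothetical parser that collects non-empty lines of a config file."""

  output_types = [rdf_protodict.AttributedDict]
  supported_artifacts = ["MyConfigFile"]  # hypothetical artifact name

  def ParseFile(self, knowledge_base, pathspec, filedesc):
    del knowledge_base  # Unused.
    lines = utils.ReadFileBytesAsUnicode(filedesc).splitlines()
    yield rdf_protodict.AttributedDict(
        filename=pathspec.path, lines=[l for l in lines if l.strip()]
    )


parsers.SINGLE_FILE_PARSER_FACTORY.Register("MyConfig", MyConfigParser)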
- parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "ChromeHistory", chrome_history.ChromeHistoryParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "CronAtAllAllowDeny", config_file.CronAtAllowDenyParser) parsers.SINGLE_FILE_PARSER_FACTORY.Register( "CronTab", cron_file_parser.CronTabParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "FirefoxHistory", firefox3_history.FirefoxHistoryParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "IeHistory", ie_history.IEHistoryParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "LinuxWtmp", linux_file_parser.LinuxWtmpParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "Mtab", config_file.MtabParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "Netgroup", linux_file_parser.NetgroupParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "NfsExports", config_file.NfsExportsParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "Ntpd", config_file.NtpdParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "PackageSource", config_file.PackageSourceParser) parsers.SINGLE_FILE_PARSER_FACTORY.Register( "Passwd", linux_file_parser.PasswdParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "Path", linux_file_parser.PathParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "SshdConfigFile", config_file.SshdConfigParser) - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "Sudoers", config_file.SudoersParser) parsers.SINGLE_FILE_PARSER_FACTORY.Register( "OsxLaunchdPlist", osx_file_parser.OSXLaunchdPlistParser) parsers.SINGLE_FILE_PARSER_FACTORY.Register( "OSXInstallHistoryPlist", osx_file_parser.OSXInstallHistoryPlistParser) - try: - from debian import deb822 # pylint: disable=g-import-not-at-top - parsers.SINGLE_FILE_PARSER_FACTORY.Register( - "DpkgStatusParser", - linux_software_parser.DebianPackagesStatusParser, - lambda: linux_software_parser.DebianPackagesStatusParser(deb822)) - except ImportError: - pass - # File multi-parsers. 
- parsers.MULTI_FILE_PARSER_FACTORY.Register( - "LinuxBaseShadow", linux_file_parser.LinuxBaseShadowParser) - parsers.MULTI_FILE_PARSER_FACTORY.Register( - "LinuxLsbInit", linux_service_parser.LinuxLSBInitParser) - parsers.MULTI_FILE_PARSER_FACTORY.Register( - "LinuxXinetd", linux_service_parser.LinuxXinetdParser) - parsers.MULTI_FILE_PARSER_FACTORY.Register( - "LinuxSysvInit", linux_service_parser.LinuxSysVInitParser) - parsers.MULTI_FILE_PARSER_FACTORY.Register( - "LinuxPam", linux_pam_parser.PAMParser) parsers.MULTI_FILE_PARSER_FACTORY.Register( "LinuxReleaseInfo", linux_release_parser.LinuxReleaseParser) - parsers.MULTI_FILE_PARSER_FACTORY.Register( - "PciDevicesInfo", linux_file_parser.PCIDevicesInfoParser) - parsers.MULTI_FILE_PARSER_FACTORY.Register( - "ProcSys", linux_sysctl_parser.ProcSysParser) - parsers.MULTI_FILE_PARSER_FACTORY.Register( - "Rsyslog", config_file.RsyslogParser) # pyformat: enable diff --git a/grr/core/grr_response_core/lib/parsers/chrome_history.py b/grr/core/grr_response_core/lib/parsers/chrome_history.py deleted file mode 100644 index e44ddff1a2..0000000000 --- a/grr/core/grr_response_core/lib/parsers/chrome_history.py +++ /dev/null @@ -1,154 +0,0 @@ -#!/usr/bin/env python -"""Parser for Google chrome/chromium History files.""" - -import logging -from typing import IO -from typing import Iterator -from typing import Text -from typing import Tuple - -from urllib import parse as urlparse -import sqlite3 - -from grr_response_core.lib import parsers -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_core.lib.rdfvalues import webhistory as rdf_webhistory -from grr_response_core.lib.util import sqlite - - -class ChromeHistoryParser( - parsers.SingleFileParser[rdf_webhistory.BrowserHistoryItem]): - """Parse Chrome history files into BrowserHistoryItem objects.""" - - output_types = [rdf_webhistory.BrowserHistoryItem] - supported_artifacts = ["ChromiumBasedBrowsersHistory"] - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_webhistory.BrowserHistoryItem]: - del knowledge_base # Unused. - - # TODO(user): Convert this to use the far more intelligent plaso parser. - chrome = ChromeParser() - path = pathspec.CollapsePath() - for timestamp, entry_type, url, data1, _, _ in chrome.Parse(path, filedesc): - if entry_type == "CHROME_DOWNLOAD": - yield rdf_webhistory.BrowserHistoryItem( - url=url, - domain=urlparse.urlparse(url).netloc, - access_time=timestamp, - program_name="Chrome", - source_path=pathspec.CollapsePath(), - download_path=data1) - elif entry_type == "CHROME_VISIT": - yield rdf_webhistory.BrowserHistoryItem( - url=url, - domain=urlparse.urlparse(url).netloc, - access_time=timestamp, - program_name="Chrome", - source_path=pathspec.CollapsePath(), - title=data1) - - -# TODO(hanuszczak): This shouldn't be a class, just a `Parse` function taking -# history file as input. -class ChromeParser(object): - """Class for handling the parsing of a Chrome history file.""" - VISITS_QUERY = ("SELECT visits.visit_time, urls.url, urls.title, " - "urls.typed_count " - "FROM urls, visits " - "WHERE urls.id = visits.url " - "ORDER BY visits.visit_time ASC;") - - # TODO(hanuszczak): Do we need to maintain code that is supposed to work with - # Chrome history format from 2013? - - # We use DESC here so we can pop off the end of the list and interleave with - # visits to maintain time order. 
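# Illustrative sketch: download rows and visit rows are collected separately
# and then merged into a single time-ordered stream by sorting on the leading
# timestamp field. Standalone example (timestamps and URLs are made up):
downloads = [(1367593890000000, "CHROME_DOWNLOAD", "http://example.com/a.zip")]
visits = [
    (1367593880000000, "CHROME_VISIT", "http://example.com/"),
    (1367593899000000, "CHROME_VISIT", "http://example.com/b"),
]
merged = sorted(downloads + visits, key=lambda row: row[0])
# merged is ordered purely by timestamp, regardless of event type.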
- DOWNLOADS_QUERY = ("SELECT downloads.start_time, downloads.url, " - "downloads.full_path, downloads.received_bytes, " - "downloads.total_bytes " - "FROM downloads " - "ORDER BY downloads.start_time DESC;") - - # This is the newer form of downloads, introduced circa Mar 2013. - DOWNLOADS_QUERY_2 = ("SELECT downloads.start_time, downloads_url_chains.url," - "downloads.target_path, downloads.received_bytes," - "downloads.total_bytes " - "FROM downloads, downloads_url_chains " - "WHERE downloads.id = downloads_url_chains.id " - "ORDER BY downloads.start_time DESC;") - - # Time diff to convert microseconds since Jan 1, 1601 00:00:00 to - # microseconds since Jan 1, 1970 00:00:00 - TIME_CONV_CONST = 11644473600000000 - - def ConvertTimestamp(self, timestamp): - if not isinstance(timestamp, int): - timestamp = 0 - elif timestamp > 11644473600000000: - timestamp -= self.TIME_CONV_CONST - elif timestamp < 631152000000000: # 01-01-1900 00:00:00 - # This means we got seconds since Jan 1, 1970, we need microseconds. - timestamp *= 1000000 - return timestamp - - # TODO(hanuszczak): This function should return a well-structured data instead - # of a tuple of 6 elements (of which 2 of them are never used). - def Parse(self, filepath: Text, filedesc: IO[bytes]) -> Iterator[Tuple]: # pylint: disable=g-bare-generic - """Iterator returning a list for each entry in history. - - We store all the download events in an array (choosing this over visits - since there are likely to be less of them). We later interleave them with - visit events to get an overall correct time order. - - Args: - filepath: A path corresponding to the database file. - filedesc: A file-like object - - Yields: - a list of attributes for each entry - """ - with sqlite.IOConnection(filedesc) as conn: - # The artifact may collect also not-database objects (e.g. journals). To - # prevent them from making the flow to fail we first check whether the - # file is really an SQLite database. If it is, then we make certain - # assumptions about its schema (to provide the user with visible error - # message that the parsing failed and maybe the format changed). If it is - # not, then we emit a warning but carry on without an error. - try: - list(conn.Query("SELECT * FROM sqlite_master LIMIT 0")) - except sqlite3.DatabaseError as error: - logging.warning("'%s' is not an SQLite database: %s", filepath, error) - return - - # Query for old style and newstyle downloads storage. 
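# Illustrative sketch: Chrome stores times as microseconds since 1601-01-01,
# so ConvertTimestamp() above shifts them onto the Unix epoch by subtracting
# TIME_CONV_CONST. A standalone conversion (the sample value is made up; the
# constant matches the one above):
import datetime

WEBKIT_TO_UNIX_US = 11644473600000000  # 1601-01-01 -> 1970-01-01, microseconds


def WebkitToDatetime(webkit_us):
  unix_us = webkit_us - WEBKIT_TO_UNIX_US
  return datetime.datetime(1970, 1, 1) + datetime.timedelta(microseconds=unix_us)


# WebkitToDatetime(13012067486556635)
#   -> datetime.datetime(2013, 5, 3, 15, 11, 26, 556635)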
- rows = [] - - try: - rows.extend(conn.Query(self.DOWNLOADS_QUERY)) - except sqlite3.Error as error: - logging.warning("Chrome history database error: %s", error) - - try: - rows.extend(conn.Query(self.DOWNLOADS_QUERY_2)) - except sqlite3.Error as error: - logging.warning("Chrome history database error: %s", error) - - results = [] - for timestamp, url, path, received_bytes, total_bytes in rows: - timestamp = self.ConvertTimestamp(timestamp) - results.append((timestamp, "CHROME_DOWNLOAD", url, path, received_bytes, - total_bytes)) - - for timestamp, url, title, typed_count in conn.Query(self.VISITS_QUERY): - timestamp = self.ConvertTimestamp(timestamp) - results.append((timestamp, "CHROME_VISIT", url, title, typed_count, "")) - - results.sort(key=lambda it: it[0]) - for it in results: - yield it diff --git a/grr/core/grr_response_core/lib/parsers/chrome_history_test.py b/grr/core/grr_response_core/lib/parsers/chrome_history_test.py deleted file mode 100644 index 54d27ad319..0000000000 --- a/grr/core/grr_response_core/lib/parsers/chrome_history_test.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python -# Copyright 2011 Google Inc. All Rights Reserved. -"""Tests for grr.parsers.chrome_history.""" - - -import datetime -import io -import os - -from absl import app - -from grr_response_core.lib.parsers import chrome_history -from grr_response_core.lib.util import temp -from grr.test_lib import test_lib - - -class ChromeHistoryTest(test_lib.GRRBaseTest): - """Test parsing of chrome history files.""" - - def testBasicParsing(self): - """Test we can parse a standard file.""" - history_file = os.path.join(self.base_path, "parser_test", "History2") - with io.open(history_file, mode="rb") as history_filedesc: - history = chrome_history.ChromeParser() - entries = list(history.Parse(history_file, history_filedesc)) - - try: - dt1 = datetime.datetime(1970, 1, 1) - dt1 += datetime.timedelta(microseconds=entries[0][0]) - except (TypeError, ValueError): - dt1 = entries[0][0] - - try: - dt2 = datetime.datetime(1970, 1, 1) - dt2 += datetime.timedelta(microseconds=entries[-1][0]) - except (TypeError, ValueError): - dt2 = entries[-1][0] - - # Check that our results are properly time ordered - time_results = [x[0] for x in entries] - self.assertEqual(time_results, sorted(time_results)) - - self.assertEqual(str(dt1), "2013-05-03 15:11:26.556635") - self.assertStartsWith(entries[0][2], - "https://www.google.ch/search?q=why+you+shouldn") - - self.assertEqual(str(dt2), "2013-05-03 15:11:39.763984") - self.assertStartsWith(entries[-1][2], "http://www.test.ch/") - - self.assertLen(entries, 4) - - def testTimeOrderingDownload(self): - """Test we can correctly time order downloads and visits.""" - history_file = os.path.join(self.base_path, "parser_test", "History3") - with io.open(history_file, mode="rb") as history_filedesc: - history = chrome_history.ChromeParser() - entries = list(history.Parse(history_file, history_filedesc)) - - # Check that our results are properly time ordered - time_results = [x[0] for x in entries] - self.assertEqual(time_results, sorted(time_results)) - self.assertLen(entries, 23) - - def testBasicParsingOldFormat(self): - """Test we can parse a standard file.""" - history_file = os.path.join(self.base_path, "parser_test", "History") - with io.open(history_file, mode="rb") as history_filedesc: - history = chrome_history.ChromeParser() - entries = list(history.Parse(history_file, history_filedesc)) - - try: - dt1 = datetime.datetime(1970, 1, 1) - dt1 += 
datetime.timedelta(microseconds=entries[0][0]) - except (TypeError, ValueError): - dt1 = entries[0][0] - - try: - dt2 = datetime.datetime(1970, 1, 1) - dt2 += datetime.timedelta(microseconds=entries[-1][0]) - except (TypeError, ValueError): - dt2 = entries[-1][0] - - # Check that our results are properly time ordered - time_results = [x[0] for x in entries] - self.assertEqual(time_results, sorted(time_results)) - - self.assertEqual(str(dt1), "2011-04-07 12:03:11") - self.assertEqual(entries[0][2], "http://start.ubuntu.com/10.04/Google/") - - self.assertEqual(str(dt2), "2011-05-23 08:37:27.061516") - self.assertStartsWith( - entries[-1][2], "https://chrome.google.com/webs" - "tore/detail/mfjkgbjaikamkkojmak" - "jclmkianficch") - - self.assertLen(entries, 71) - - def testNonSqliteDatabase(self): - with temp.AutoTempFilePath(suffix="-journal") as filepath: - with io.open(filepath, "wb") as filedesc: - filedesc.write(b"foobar") - - with io.open(filepath, "rb") as filedesc: - # This should not fail, but return an empty list of results. - results = list(chrome_history.ChromeParser().Parse(filepath, filedesc)) - self.assertEmpty(results) - - -def main(argv): - test_lib.main(argv) - - -if __name__ == "__main__": - app.run(main) diff --git a/grr/core/grr_response_core/lib/parsers/config_file.py b/grr/core/grr_response_core/lib/parsers/config_file.py index 1244f0e608..e921dcce70 100644 --- a/grr/core/grr_response_core/lib/parsers/config_file.py +++ b/grr/core/grr_response_core/lib/parsers/config_file.py @@ -1,26 +1,13 @@ #!/usr/bin/env python """Simple parsers for configuration files.""" + from collections import abc import logging import re -from typing import IO -from typing import Iterable -from typing import Iterator -from typing import Text from grr_response_core.lib import lexer -from grr_response_core.lib import parser -from grr_response_core.lib import parsers -from grr_response_core.lib import utils -from grr_response_core.lib.rdfvalues import anomaly as rdf_anomaly -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import config_file as rdf_config_file -from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_core.lib.rdfvalues import protodict as rdf_protodict -from grr_response_core.lib.rdfvalues import standard as rdf_standard from grr_response_core.lib.util import precondition -from grr_response_core.lib.util import text def AsIter(arg): @@ -66,14 +53,16 @@ class FieldParser(lexer.Lexer): compatible regex string. """ - def __init__(self, - comments=r"#", - cont=r"\\\s*\n", - ml_quote=False, - quot=(r"\"", r"'"), - sep=r"[ \t\f\v]+", - term=r"[\r\n]", - verbose=0): + def __init__( + self, + comments=r"#", + cont=r"\\\s*\n", + ml_quote=False, + quot=(r"\"", r"'"), + sep=r"[ \t\f\v]+", + term=r"[\r\n]", + verbose=0, + ): r"""A generalized field-based parser. Handles whitespace, csv etc. @@ -189,8 +178,8 @@ def BadLine(self, **_): logging.debug("Skipped bad line in file at %s", self.processed) self.field = "" - def ParseEntries(self, data: Text): - precondition.AssertType(data, Text) + def ParseEntries(self, data: str): + precondition.AssertType(data, str) # Flush any old results. 
self.Reset() @@ -211,15 +200,17 @@ class KeyValueParser(FieldParser): kv_sep defaults to "=" """ - def __init__(self, - comments=r"#", - cont=r"\\\s*\n", - kv_sep="=", - ml_quote=False, - quot=(r"\"", r"'"), - sep=r"[ \t\f\v]+", - term=r"[\r\n]", - verbose=0): + def __init__( + self, + comments=r"#", + cont=r"\\\s*\n", + kv_sep="=", + ml_quote=False, + quot=(r"\"", r"'"), + sep=r"[ \t\f\v]+", + term=r"[\r\n]", + verbose=0, + ): """A generalized key-value parser. Handles whitespace, csv etc. @@ -242,7 +233,8 @@ def __init__(self, quot=quot, sep=sep, term=term, - verbose=verbose) + verbose=verbose, + ) self.key_field = "" def _GenStates(self): @@ -274,8 +266,9 @@ def GenInitialState(self): def GenKeyState(self): for c in self.comments: - self._AddToken("KEY", c, "EndKeyField,EndEntry,PopState,PushBack", - "COMMENT") + self._AddToken( + "KEY", c, "EndKeyField,EndEntry,PopState,PushBack", "COMMENT" + ) for t in self.term: self._AddToken("KEY", t, "EndKeyField,EndEntry,PopState", None) for k in self.kv_sep: @@ -283,8 +276,9 @@ def GenKeyState(self): def GenValueState(self): for c in self.comments: - self._AddToken("VALUE", c, "EndField,EndEntry,PopState,PushBack", - "COMMENT") + self._AddToken( + "VALUE", c, "EndField,EndEntry,PopState,PushBack", "COMMENT" + ) for t in self.term: self._AddToken("VALUE", t, "EndField,EndEntry,PopState", None) for s in self.sep: @@ -317,349 +311,6 @@ def ParseToOrderedDict(self, data): return result -class NfsExportsParser(parsers.SingleFileParser[rdf_config_file.NfsExport]): - """Parser for NFS exports.""" - - output_types = [rdf_config_file.NfsExport] - supported_artifacts = ["NfsExportsFile"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._field_parser = FieldParser() - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_config_file.NfsExport]: - del knowledge_base # Unused. - del pathspec # Unused. - - for entry in self._field_parser.ParseEntries( - utils.ReadFileBytesAsUnicode(filedesc)): - if not entry: - continue - result = rdf_config_file.NfsExport() - result.share = entry[0] - for field in entry[1:]: - if field.startswith(("-", "(")): - result.defaults = field.strip("-()").split(",") - else: - client = rdf_config_file.NfsClient() - cfg = field.split("(", 1) - host = cfg[0] - if len(cfg) > 1: - options = cfg[1] - else: - options = None - client.host = host - if options: - client.options = options.strip("()").split(",") - result.clients.append(client) - yield result - - -class SshdFieldParser(object): - """The base class for the ssh config parsers.""" - - # Specify the values that are boolean or integer. Anything else is a string. 
- _integers = ["clientalivecountmax", - "magicudsport", - "maxauthtries", - "maxsessions", - "port", - "protocol", - "serverkeybits", - "x11displayoffset"] # pyformat: disable - _booleans = ["allowagentforwarding", - "challengeresponseauthentication", - "dsaauthentication", - "gssapiauthentication", - "gssapicleanupcredentials", - "gssapikeyexchange", - "gssapistorecredentialsonrekey", - "gssapistrictacceptorcheck", - "hostbasedauthentication", - "ignorerhosts", - "ignoreuserknownhosts", - "kbdinteractiveauthentication", - "kerberosauthentication", - "passwordauthentication", - "permitemptypasswords", - "permittunnel", - "permituserenvironment", - "pubkeyauthentication", - "rhostsrsaauthentication", - "rsaauthentication", - "strictmodes", - "uselogin", - "usepam", - "x11forwarding", - "x11uselocalhost"] # pyformat: disable - # Valid ways that parameters can repeat - _repeated = { - "acceptenv": r"[\n\s]+", - "allowgroups": r"[\s]+", - "allowusers": r"[\s]+", - "authenticationmethods": r"[\s]+", - "authorizedkeysfile": r"[\s]+", - "ciphers": r"[,]+", - "denygroups": r"[\s]+", - "denyusers": r"[\s]+", - "forcecommand": r"[\n]+", - "hostkey": r"[\n]+", - "kexalgorithms": r"[,]+", - "listenaddress": r"[\n]+", - "macs": r"[,]+", - "permitopen": r"[\s]+", - "port": r"[,\n]+", - "protocol": r"[,]+", - "pubkeyacceptedkeytypes": r"[,]+", - "subsystem": r"[\n]+" - } - _true = ["yes", "true", "1"] - _aliases = {"dsaauthentication": "pubkeyauthentication"} - _match_keywords = [ - "acceptenv", "allowagentforwarding", "allowgroups", "allowtcpforwarding", - "allowusers", "authenticationmethods", "authorizedkeyscommand", - "authorizedkeyscommanduser", "authorizedkeysfile", - "authorizedprincipalsfile", "banner", "chrootdirectory", "denygroups", - "denyusers", "forcecommand", "gatewayports", "gssapiauthentication", - "hostbasedauthentication", "hostbasedusesnamefrompacketonly", - "kbdinteractiveauthentication", "kerberosauthentication", "magicudspath", - "magicudsport", "maxauthtries", "maxsessions", "passwordauthentication", - "permitemptypasswords", "permitopen", "permitrootlogin", - "permittemphomedir", "permittty", "permittunnel", - "pubkeyacceptedkeytypes", "pubkeyauthentication", "rekeylimit", - "rhostsrsaauthentication", "rsaauthentication", "temphomedirpath", - "x11displayoffset", "x11forwarding", "x11uselocalhost" - ] - - def __init__(self): - super().__init__() - self.Flush() - - def Flush(self): - self.config = {} - self.matches = [] - self.section = self.config - self.processor = self._ParseEntry - - def ParseLine(self, line): - """Extracts keyword/value settings from the sshd config. - - The keyword is always the first string item. - Values are the remainder of the string. In cases where an sshd config - allows multiple values, these are split according to whatever separator(s) - sshd_config permits for that value. - - Keywords and values are normalized. Keywords are converted to lowercase. - Values are converted into integers, booleans or strings. Strings are always - lowercased. - - Args: - line: A line of the configuration file. - """ - kv = line.split(None, 1) - keyword = kv[0].lower() - # Safely set the argument string if it wasn't found. - values = kv[1:] or [""] - # Then split any parameters that are actually repeated items. - separators = self._repeated.get(keyword) - if separators: - repeated = [] - for v in values: - repeated.extend(re.split(separators, v)) - # Remove empty matches. - values = [v for v in repeated if v] - - # Now convert the values to the right types. 
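# Illustrative sketch of the conversion below (standalone; the sets are
# abbreviated copies of _integers/_booleans above and the sample lines are
# made up): keywords are lowercased and their values coerced to ints,
# booleans or lowercased strings.
INTEGERS = {"port", "maxauthtries"}
BOOLEANS = {"passwordauthentication", "permitemptypasswords"}
TRUE_VALUES = {"yes", "true", "1"}


def Normalize(line):
  keyword, _, value = line.partition(" ")
  keyword = keyword.lower()
  if keyword in INTEGERS:
    return keyword, int(value)
  if keyword in BOOLEANS:
    return keyword, value.lower() in TRUE_VALUES
  return keyword, value.lower()


# Normalize("MaxAuthTries 6")              -> ("maxauthtries", 6)
# Normalize("PasswordAuthentication no")   -> ("passwordauthentication", False)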
- if keyword in self._integers: - values = [int(v) for v in values] - elif keyword in self._booleans: - values = [v.lower() in self._true for v in values] - else: - values = [v.lower() for v in values] - # Only repeated arguments should be treated as a list. - if keyword not in self._repeated: - values = values[0] - # Switch sections for new match blocks. - if keyword == "match": - self._NewMatchSection(values) - # If it's an alias, resolve it. - if keyword in self._aliases: - keyword = self._aliases[keyword] - # Add the keyword/values to the section. - self.processor(keyword, values) - - def _ParseEntry(self, key, val): - """Adds an entry for a configuration setting. - - Args: - key: The name of the setting. - val: The value of the setting. - """ - if key in self._repeated: - setting = self.section.setdefault(key, []) - setting.extend(val) - else: - self.section.setdefault(key, val) - - def _ParseMatchGrp(self, key, val): - """Adds valid match group parameters to the configuration.""" - if key in self._match_keywords: - self._ParseEntry(key, val) - - def _NewMatchSection(self, val): - """Create a new configuration section for each match clause. - - Each match clause is added to the main config, and the criterion that will - trigger the match is recorded, as is the configuration. - - Args: - val: The value following the 'match' keyword. - """ - section = {"criterion": val, "config": {}} - self.matches.append(section) - # Now add configuration items to config section of the match block. - self.section = section["config"] - # Switch to a match-specific processor on a new match_block. - self.processor = self._ParseMatchGrp - - def GenerateResults(self): - matches = [] - for match in self.matches: - criterion, config = match["criterion"], match["config"] - block = rdf_config_file.SshdMatchBlock(criterion=criterion, config=config) - matches.append(block) - yield rdf_config_file.SshdConfig(config=self.config, matches=matches) - - -class SshdConfigParser(parsers.SingleFileParser[rdf_config_file.SshdConfig]): - """A parser for sshd_config files.""" - - supported_artifacts = ["SshdConfigFile"] - output_types = [rdf_config_file.SshdConfig] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._field_parser = SshdFieldParser() - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_config_file.SshdConfig]: - del knowledge_base # Unused. - del pathspec # Unused. - - # Clean out any residual state. - self._field_parser.Flush() - lines = [ - l.strip() for l in utils.ReadFileBytesAsUnicode(filedesc).splitlines() - ] - for line in lines: - # Remove comments (will break if it includes a quoted/escaped #) - line = line.split("#")[0].strip() - if line: - self._field_parser.ParseLine(line) - for result in self._field_parser.GenerateResults(): - yield result - - -class SshdConfigCmdParser(parser.CommandParser): - """A command parser for sshd -T output.""" - - supported_artifacts = ["SshdConfigCmd"] - output_types = [rdf_config_file.SshdConfig] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._field_parser = SshdFieldParser() - - def Parse(self, cmd, args, stdout, stderr, return_val, knowledge_base): - # Clean out any residual state. 
- self._field_parser.Flush() - lines = [l.strip() for l in stdout.splitlines()] - for line in lines: - if line: - self._field_parser.ParseLine(line) - for result in self._field_parser.GenerateResults(): - yield result - - -class MtabParser(parsers.SingleFileParser[rdf_client_fs.Filesystem]): - """Parser for mounted filesystem data acquired from /proc/mounts.""" - output_types = [rdf_client_fs.Filesystem] - supported_artifacts = ["LinuxProcMounts", "LinuxFstab"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._field_parser = FieldParser() - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_client_fs.Filesystem]: - del knowledge_base # Unused. - del pathspec # Unused. - - for entry in self._field_parser.ParseEntries( - utils.ReadFileBytesAsUnicode(filedesc)): - if not entry: - continue - result = rdf_client_fs.Filesystem() - result.device = text.Unescape(entry[0]) - result.mount_point = text.Unescape(entry[1]) - result.type = text.Unescape(entry[2]) - options = KeyValueParser(term=",").ParseToOrderedDict(entry[3]) - # Keys without values get assigned [] by default. Because these keys are - # actually true, if declared, change any [] values to True. - for k, v in options.items(): - options[k] = v or [True] - result.options = rdf_protodict.AttributedDict(**options) - yield result - - -class MountCmdParser(parser.CommandParser): - """Parser for mounted filesystem data acquired from the mount command.""" - output_types = [rdf_client_fs.Filesystem] - supported_artifacts = ["LinuxMountCmd"] - - mount_re = re.compile(r"(.*) on (.*) type (.*) \((.*)\)") - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._field_parser = FieldParser() - - def Parse(self, cmd, args, stdout, stderr, return_val, knowledge_base): - """Parse the mount command output.""" - _ = stderr, args, knowledge_base # Unused. - self.CheckReturn(cmd, return_val) - for entry in self._field_parser.ParseEntries(stdout): - line_str = " ".join(entry) - mount_rslt = self.mount_re.match(line_str) - if mount_rslt: - device, mount_point, fs_type, option_str = mount_rslt.groups() - result = rdf_client_fs.Filesystem() - result.device = device - result.mount_point = mount_point - result.type = fs_type - # Parse these options as a dict as some items may be key/values. - # KeyValue parser uses OrderedDict as the native parser method. Use it. - options = KeyValueParser(term=",").ParseToOrderedDict(option_str) - # Keys without values get assigned [] by default. Because these keys are - # actually true, if declared, change any [] values to True. 
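# Illustrative sketch of the normalization below (the sample dict is made up;
# real values come from KeyValueParser(term=",").ParseToOrderedDict()): option
# keys parsed without a value come back as empty lists, and those flag-style
# options are rewritten as [True].
options = {"rw": [], "relatime": [], "errors": ["remount-ro"]}
for k, v in options.items():
  options[k] = v or [True]
# options == {"rw": [True], "relatime": [True], "errors": ["remount-ro"]}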
- for k, v in options.items(): - options[k] = v or [True] - result.options = rdf_protodict.AttributedDict(**options) - yield result - - class RsyslogFieldParser(FieldParser): """Field parser for syslog configurations.""" @@ -704,520 +355,3 @@ def ParseAction(self, action): rslt.destination = dst.group(1) break return rslt - - -class RsyslogParser(parsers.MultiFileParser[rdf_protodict.AttributedDict]): - """Artifact parser for syslog configurations.""" - - output_types = [rdf_protodict.AttributedDict] - supported_artifacts = ["LinuxRsyslogConfigs"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._field_parser = RsyslogFieldParser() - - def ParseFiles( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspecs: Iterable[rdf_paths.PathSpec], - filedescs: Iterable[IO[bytes]], - ) -> Iterator[rdf_protodict.AttributedDict]: - del knowledge_base # Unused. - del pathspecs # Unused. - - # TODO(user): review quoting and line continuation. - result = rdf_config_file.LogConfig() - for file_obj in filedescs: - for entry in self._field_parser.ParseEntries( - utils.ReadFileBytesAsUnicode(file_obj)): - directive = entry[0] - log_rule = self._field_parser.log_rule_re.match(directive) - if log_rule and entry[1:]: - target = self._field_parser.ParseAction(entry[1]) - target.facility, target.priority = log_rule.groups() - result.targets.append(target) - yield result - - -class PackageSourceParser(parsers.SingleFileParser[rdf_protodict.AttributedDict] - ): - """Common code for APT and YUM source list parsing.""" - output_types = [rdf_protodict.AttributedDict] - - # Prevents this from automatically registering. - __abstract = True # pylint: disable=g-bad-name - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_protodict.AttributedDict]: - del knowledge_base # Unused. - - uris_to_parse = self.FindPotentialURIs(filedesc) - uris = [] - - for url_to_parse in uris_to_parse: - url = rdf_standard.URI.FromHumanReadable(url_to_parse) - - # if no transport then url_to_parse wasn't actually a valid URL - # either host or path also have to exist for this to be a valid URL - if url.transport and (url.host or url.path): - uris.append(url) - - filename = pathspec.path - cfg = {"filename": filename, "uris": uris} - yield rdf_protodict.AttributedDict(**cfg) - - def FindPotentialURIs(self, file_obj): - """Stub Method to be overridden by APT and Yum source parsers.""" - raise NotImplementedError("Please implement FindPotentialURIs.") - - # TODO: Make sure all special cases are caught by this function. - def ParseURIFromKeyValues(self, data, separator, uri_key): - """Parse key/value formatted source listing and return potential URLs. - - The fundamental shape of this format is as follows: - key: value # here : = separator - key : value - URI: [URL] # here URI = uri_key - [URL] # this is where it becomes trickey because [URL] - [URL] # can contain 'separator' specially if separator is : - key: value - - The key uri_key is of interest to us and since the next line - in the config could contain another [URL], we need to keep track of context - when we hit uri_key to be able to check if the next line(s) - have more [URL]. - - Args: - data: unprocessed lines from a file - separator: how the key/value pairs are separated - uri_key: starting name of the key containing URI. 
- - Returns: - A list of potential URLs found in data - """ - precondition.AssertType(data, Text) - precondition.AssertType(separator, Text) - - kv_entries = KeyValueParser(kv_sep=separator).ParseEntries(data) - spaced_entries = FieldParser().ParseEntries(data) - - uris = [] - check_uri_on_next_line = False - for kv_entry, sp_entry in zip(kv_entries, spaced_entries): - for k, v in kv_entry.items(): - # This line could be a URL if a) from key:value, value is empty OR - # b) if separator is : and first character of v starts with /. - if (check_uri_on_next_line and - (not v or (separator == ":" and v[0].startswith("/")))): - uris.append(sp_entry[0]) - else: - check_uri_on_next_line = False - if k.lower().startswith(uri_key) and v: - check_uri_on_next_line = True - uris.append(v[0]) # v is a list - - return uris - - -class APTPackageSourceParser(PackageSourceParser): - """Parser for APT source lists to extract URIs only.""" - supported_artifacts = ["APTSources"] - - def FindPotentialURIs(self, file_obj): - """Given a file, this will return all potential APT source URIs.""" - rfc822_format = "" # will contain all lines not in legacy format - uris_to_parse = [] - - for line in utils.ReadFileBytesAsUnicode(file_obj).splitlines(True): - # check if legacy style line - if it is then extract URL - m = re.search(r"^\s*deb(?:-\S+)?(?:\s+\[[^\]]*\])*\s+(\S+)(?:\s|$)", line) - if m: - uris_to_parse.append(m.group(1)) - else: - rfc822_format += line - - uris_to_parse.extend(self.ParseURIFromKeyValues(rfc822_format, ":", "uri")) - return uris_to_parse - - -class YumPackageSourceParser(PackageSourceParser): - """Parser for Yum source lists to extract URIs only.""" - supported_artifacts = ["YumSources"] - - def FindPotentialURIs(self, file_obj): - """Given a file, this will return all potential Yum source URIs.""" - return self.ParseURIFromKeyValues( - utils.ReadFileBytesAsUnicode(file_obj), "=", "baseurl") - - -class CronAtAllowDenyParser( - parsers.SingleFileParser[rdf_protodict.AttributedDict]): - """Parser for /etc/cron.allow /etc/cron.deny /etc/at.allow & /etc/at.deny.""" - output_types = [rdf_protodict.AttributedDict] - supported_artifacts = ["CronAtAllowDenyFiles"] - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_protodict.AttributedDict]: - del knowledge_base # Unused. - - lines = set([ - l.strip() for l in utils.ReadFileBytesAsUnicode(filedesc).splitlines() - ]) - - users = [] - bad_lines = [] - for line in lines: - # behaviour of At/Cron is undefined for lines with whitespace separated - # fields/usernames - if " " in line: - bad_lines.append(line) - elif line: # drop empty lines - users.append(line) - - filename = pathspec.path - cfg = {"filename": filename, "users": users} - yield rdf_protodict.AttributedDict(**cfg) - - if bad_lines: - yield rdf_anomaly.Anomaly( - type="PARSER_ANOMALY", - symptom="Dodgy entries in %s." % (filename), - reference_pathspec=pathspec, - finding=bad_lines) - - -class NtpdFieldParser(FieldParser): - """Field parser for ntpd.conf file.""" - output_types = [rdf_config_file.NtpConfig] - supported_artifacts = ["NtpConfFile"] - - # The syntax is based on: - # https://www.freebsd.org/cgi/man.cgi?query=ntp.conf&sektion=5 - # keywords with integer args. - _integers = set(["ttl", "hop"]) - # keywords with floating point args. - _floats = set(["broadcastdelay", "calldelay"]) - # keywords that have repeating args. 
- _repeated = set(["ttl", "hop"]) - # keywords that set an option state, but can be "repeated" as well. - _boolean = set(["enable", "disable"]) - # keywords that are keyed to their first argument, an address. - _address_based = set([ - "trap", "fudge", "server", "restrict", "peer", "broadcast", - "manycastclient" - ]) - # keywords that append/augment the config. - _accumulators = set(["includefile", "setvar"]) - # keywords that can appear multiple times, accumulating data each time. - _duplicates = _address_based | _boolean | _accumulators - # All the expected keywords. - _match_keywords = _integers | _floats | _repeated | _duplicates | set([ - "autokey", "revoke", "multicastclient", "driftfile", "broadcastclient", - "manycastserver", "includefile", "interface", "disable", "includefile", - "discard", "logconfig", "logfile", "tos", "tinker", "keys", "keysdir", - "requestkey", "trustedkey", "crypto", "control", "statsdir", "filegen" - ]) - - defaults = { - "auth": True, - "bclient": False, - "calibrate": False, - "kernel": False, - "monitor": True, - "ntp": True, - "pps": False, - "stats": False - } - - def __init__(self): - super().__init__() - # ntp.conf has no line continuation. Override the default 'cont' values - # then parse up the lines. - self.cont = "" - self.config = self.defaults.copy() - self.keyed = {} - - def ParseLine(self, entries): - """Extracts keyword/value settings from the ntpd config. - - The keyword is always the first entry item. - Values are the remainder of the entries. In cases where an ntpd config - allows multiple values, these are split according to whitespace or - duplicate entries. - - Keywords and values are normalized. Keywords are converted to lowercase. - Values are converted into integers, floats or strings. Strings are always - lowercased. - - Args: - entries: A list of items making up a single line of a ntp.conf file. - """ - # If no entries were found, short circuit. - if not entries: - return - keyword = entries[0].lower() - # Set the argument string if it wasn't found. - values = entries[1:] or [""] - - # Convert any types we need too. - if keyword in self._integers: - values = [int(v) for v in values] - if keyword in self._floats: - values = [float(v) for v in values] - - if keyword not in self._repeated | self._duplicates: - # We have a plain and simple single key/value config line. - if isinstance(values[0], str): - self.config[keyword] = " ".join(values) - else: - self.config[keyword] = values - - elif keyword in self._repeated: - # The keyword can have multiple single-word options, so add them as a list - # and overwrite previous settings. - self.config[keyword] = values - - elif keyword in self._duplicates: - if keyword in self._address_based: - # If we have an address keyed keyword, join the keyword and address - # together to make the complete key for this data. - address = values[0].lower() - values = values[1:] or [""] - # Add/overwrite the address in this 'keyed' keywords dictionary. - existing_keyword_config = self.keyed.setdefault(keyword, []) - # Create a dict which stores the server name and the options. - # Flatten the remaining options into a single string. - existing_keyword_config.append({ - "address": address, - "options": " ".join(values) - }) - - # Are we toggling an option? - elif keyword in self._boolean: - for option in values: - if keyword == "enable": - self.config[option] = True - else: - # As there are only two items in this set, we can assume disable. 
- self.config[option] = False - - else: - # We have a non-keyed & non-boolean keyword, so add to the collected - # data so far. Order matters technically. - prev_settings = self.config.setdefault(keyword, []) - prev_settings.append(" ".join(values)) - - -class NtpdParser(parsers.SingleFileParser[rdf_config_file.NtpConfig]): - """Artifact parser for ntpd.conf file.""" - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_config_file.NtpConfig]: - del knowledge_base # Unused. - del pathspec # Unused. - - # TODO(hanuszczak): This parser only allows single use because it messes - # with its state. This should be fixed. - field_parser = NtpdFieldParser() - for line in field_parser.ParseEntries( - utils.ReadFileBytesAsUnicode(filedesc)): - field_parser.ParseLine(line) - - yield rdf_config_file.NtpConfig( - config=field_parser.config, - server=field_parser.keyed.get("server"), - restrict=field_parser.keyed.get("restrict"), - fudge=field_parser.keyed.get("fudge"), - trap=field_parser.keyed.get("trap"), - peer=field_parser.keyed.get("peer"), - broadcast=field_parser.keyed.get("broadcast"), - manycastclient=field_parser.keyed.get("manycastclient")) - - def ParseMultiple(self, stats, file_objects, knowledge_base): - for s, f in zip(stats, file_objects): - for rslt in self.Parse(s, f, knowledge_base): - yield rslt - - -class SudoersFieldParser(FieldParser): - """Parser for privileged configuration files such as sudoers and pam.d/su.""" - - # Regex to remove comments from the file. The first group in the OR condition - # handles comments that cover a full line, while also ignoring #include(dir). - # The second group in the OR condition handles comments that begin partways - # through a line, without matching UIDs or GIDs which are specified with # in - # the format. - # TODO(user): this regex fails to match '#32 users', but handles quite a - # lot else. - # TODO(user): this should be rewritten as a proper lexer - COMMENTS_RE = re.compile(r"(#(?!include(?:dir)?\s+)\D+?$)", re.MULTILINE) - - ALIAS_TYPES = { - "User_Alias": rdf_config_file.SudoersAlias.Type.USER, - "Runas_Alias": rdf_config_file.SudoersAlias.Type.RUNAS, - "Host_Alias": rdf_config_file.SudoersAlias.Type.HOST, - "Cmnd_Alias": rdf_config_file.SudoersAlias.Type.CMD - } - ALIAS_FIELDS = { - "User_Alias": "users", - "Runas_Alias": "runas", - "Host_Alias": "hosts", - "Cmnd_Alias": "cmds" - } - DEFAULTS_KEY = "Defaults" - INCLUDE_KEYS = ["#include", "#includedir"] - - def __init__(self, *args, **kwargs): - kwargs["comments"] = [] - super().__init__(*args, **kwargs) - - def _ExtractList(self, fields, ignores=(",",), terminators=()): - """Extract a list from the given fields.""" - extracted = [] - i = 0 - for i, field in enumerate(fields): - # Space-separated comma; ignore, but this is not a finished list. - # Similar for any other specified ignores (eg, equals sign). - if field in ignores: - continue - - # However, some fields are specifically meant to terminate iteration. - if field in terminators: - break - - extracted.append(field.strip("".join(ignores))) - # Check for continuation; this will either be a trailing comma or the - # next field after this one being a comma. The lookahead here is a bit - # nasty. 
- if not (field.endswith(",") or - set(fields[i + 1:i + 2]).intersection(ignores)): - break - - return extracted, fields[i + 1:] - - def ParseSudoersEntry(self, entry, sudoers_config): - """Parse an entry and add it to the given SudoersConfig rdfvalue.""" - - key = entry[0] - if key in SudoersFieldParser.ALIAS_TYPES: - # Alias. - alias_entry = rdf_config_file.SudoersAlias( - type=SudoersFieldParser.ALIAS_TYPES.get(key), name=entry[1]) - - # Members of this alias, comma-separated. - members, _ = self._ExtractList(entry[2:], ignores=(",", "=")) - field = SudoersFieldParser.ALIAS_FIELDS.get(key) - getattr(alias_entry, field).Extend(members) - - sudoers_config.aliases.append(alias_entry) - elif key.startswith(SudoersFieldParser.DEFAULTS_KEY): - # Default. - # Identify scope if one exists (Defaults ...) - scope = None - if len(key) > len(SudoersFieldParser.DEFAULTS_KEY): - scope = key[len(SudoersFieldParser.DEFAULTS_KEY) + 1:] - - # There can be multiple defaults on a line, for the one scope. - entry = entry[1:] - defaults, _ = self._ExtractList(entry) - for default in defaults: - default_entry = rdf_config_file.SudoersDefault(scope=scope) - - # Extract key name and value(s). - default_name = default - value = [] - if "=" in default_name: - default_name, remainder = default_name.split("=", 1) - value = [remainder] - default_entry.name = default_name - if entry: - default_entry.value = " ".join(value) - - sudoers_config.defaults.append(default_entry) - elif key in SudoersFieldParser.INCLUDE_KEYS: - # TODO(user): make #includedir more obvious in the RDFValue somewhere - target = " ".join(entry[1:]) - sudoers_config.includes.append(target) - else: - users, entry = self._ExtractList(entry) - hosts, entry = self._ExtractList(entry, terminators=("=",)) - - # Remove = from = - if entry[0] == "=": - entry = entry[1:] - - # Command specification. - sudoers_entry = rdf_config_file.SudoersEntry( - users=users, hosts=hosts, cmdspec=entry) - - sudoers_config.entries.append(sudoers_entry) - - def Preprocess(self, data): - """Preprocess the given data, ready for parsing.""" - # Add whitespace to line continuations. - data = data.replace(":\\", ": \\") - - # Strip comments manually because sudoers has multiple meanings for '#'. - data = SudoersFieldParser.COMMENTS_RE.sub("", data) - return data - - -class SudoersParser(parsers.SingleFileParser[rdf_config_file.SudoersConfig]): - """Artifact parser for privileged configuration files.""" - - output_types = [rdf_config_file.SudoersConfig] - supported_artifacts = ["UnixSudoersConfiguration"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._field_parser = SudoersFieldParser() - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_config_file.SudoersConfig]: - del knowledge_base # Unused. - del pathspec # Unused. - - self._field_parser.ParseEntries( - self._field_parser.Preprocess(utils.ReadFileBytesAsUnicode(filedesc))) - result = rdf_config_file.SudoersConfig() - for entry in self._field_parser.entries: - # Handle multiple entries in one line, eg: - # foo bar : baz - # ... would become ... 
- # [[foo, bar], [foo, baz]] - key = entry[0] - nested_entries = [] - if ":" not in entry: - nested_entries = [entry] - else: - runner = [] - for field in entry: - if field == ":": - nested_entries.append(runner) - runner = [key] - continue - - runner.append(field) - - nested_entries.append(runner) - - for nested_entry in nested_entries: - self._field_parser.ParseSudoersEntry(nested_entry, result) - - yield result diff --git a/grr/core/grr_response_core/lib/parsers/config_file_test.py b/grr/core/grr_response_core/lib/parsers/config_file_test.py index 565c331ed8..333c64b586 100644 --- a/grr/core/grr_response_core/lib/parsers/config_file_test.py +++ b/grr/core/grr_response_core/lib/parsers/config_file_test.py @@ -1,15 +1,9 @@ #!/usr/bin/env python """Unit test for config files.""" -import io - from absl import app from grr_response_core.lib.parsers import config_file -from grr_response_core.lib.rdfvalues import anomaly as rdf_anomaly -from grr_response_core.lib.rdfvalues import config_file as rdf_config_file -from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_core.lib.rdfvalues import protodict as rdf_protodict from grr.test_lib import test_lib CFG = b""" @@ -31,44 +25,6 @@ """ -class SshdConfigTest(test_lib.GRRBaseTest): - """Test parsing of an sshd configuration.""" - - def GetConfig(self): - """Read in the test configuration file.""" - parser = config_file.SshdConfigParser() - results = list(parser.ParseFile(None, None, io.BytesIO(CFG))) - self.assertLen(results, 1) - return results[0] - - def testParseConfig(self): - """Ensure we can extract sshd settings.""" - result = self.GetConfig() - self.assertIsInstance(result, rdf_config_file.SshdConfig) - self.assertCountEqual([2], result.config.protocol) - expect = ["aes128-ctr", "aes256-ctr", "aes128-cbc", "aes256-cbc"] - self.assertCountEqual(expect, result.config.ciphers) - - def testFindNumericValues(self): - """Keywords with numeric settings are converted to integers.""" - result = self.GetConfig() - self.assertEqual(768, result.config.serverkeybits) - self.assertCountEqual([22, 2222, 10222], result.config.port) - - def testParseMatchGroups(self): - """Match groups are added to separate sections.""" - result = self.GetConfig() - # Multiple Match groups found. - self.assertLen(result.matches, 2) - # Config options set per Match group. - block_1, block_2 = result.matches - self.assertEqual("user root", block_1.criterion) - self.assertEqual("address 192.168.3.12", block_2.criterion) - self.assertEqual("yes", block_1.config.permitrootlogin) - self.assertEqual("no", block_2.config.permitrootlogin) - self.assertFalse(block_1.config.protocol) - - class FieldParserTests(test_lib.GRRBaseTest): """Test the field parser.""" @@ -79,13 +35,21 @@ def testParser(self): this should be another entry "with this quoted text as one field" 'an entry'with" only two" fields ;; and not this comment. 
""" - expected = [["each", "of", "these", "words", "should", "be", "fields"], - [ - "this", "should", "be", "another", "entry", - "with this quoted text as one field" - ], ["an entrywith only two", "fields"]] + expected = [ + ["each", "of", "these", "words", "should", "be", "fields"], + [ + "this", + "should", + "be", + "another", + "entry", + "with this quoted text as one field", + ], + ["an entrywith only two", "fields"], + ] cfg = config_file.FieldParser( - sep=["[ \t\f\v]+", ":", ";"], comments=["#", ";;"]) + sep=["[ \t\f\v]+", ":", ";"], comments=["#", ";;"] + ) results = cfg.ParseEntries(test_data) for i, expect in enumerate(expected): self.assertCountEqual(expect, results[i]) @@ -123,690 +87,17 @@ def testParser(self): = # Bad line 'a key'with" no" value field ;; and not this comment. """ - expected = [{ - "key1": ["a", "list", "of", "fields"] - }, { - "key 2": ["another", "entry"] - }, { - "a keywith no value field": [] - }] + expected = [ + {"key1": ["a", "list", "of", "fields"]}, + {"key 2": ["another", "entry"]}, + {"a keywith no value field": []}, + ] cfg = config_file.KeyValueParser(kv_sep=["=", ":"], comments=["#", ";;"]) results = cfg.ParseEntries(test_data) for i, expect in enumerate(expected): self.assertDictEqual(expect, results[i]) -class NfsExportParserTests(test_lib.GRRBaseTest): - """Test the NFS exports parser.""" - - def testParseNfsExportFile(self): - test_data = br""" - /path/to/foo -rw,sync host1(ro) host2 - /path/to/bar *.example.org(all_squash,ro) \ - 192.168.1.0/24 (rw) # Mistake here - space makes this default. - """ - exports = io.BytesIO(test_data) - parser = config_file.NfsExportsParser() - results = list(parser.ParseFile(None, None, exports)) - self.assertEqual("/path/to/foo", results[0].share) - self.assertCountEqual(["rw", "sync"], results[0].defaults) - self.assertEqual("host1", results[0].clients[0].host) - self.assertCountEqual(["ro"], results[0].clients[0].options) - self.assertEqual("host2", results[0].clients[1].host) - self.assertCountEqual([], results[0].clients[1].options) - self.assertEqual("/path/to/bar", results[1].share) - self.assertCountEqual(["rw"], results[1].defaults) - self.assertEqual("*.example.org", results[1].clients[0].host) - self.assertCountEqual(["all_squash", "ro"], results[1].clients[0].options) - self.assertEqual("192.168.1.0/24", results[1].clients[1].host) - self.assertCountEqual([], results[1].clients[1].options) - - -class MtabParserTests(test_lib.GRRBaseTest): - """Test the mtab and proc/mounts parser.""" - - def testParseMountData(self): - test_data = br""" - rootfs / rootfs rw 0 0 - arnie@host.example.org:/users/arnie /home/arnie/remote fuse.sshfs rw,nosuid,nodev,max_read=65536 0 0 - /dev/sr0 /media/USB\040Drive vfat ro,nosuid,nodev - """ - exports = io.BytesIO(test_data) - parser = config_file.MtabParser() - results = list(parser.ParseFile(None, None, exports)) - self.assertEqual("rootfs", results[0].device) - self.assertEqual("/", results[0].mount_point) - self.assertEqual("rootfs", results[0].type) - self.assertTrue(results[0].options.rw) - self.assertFalse(results[0].options.ro) - - self.assertEqual("arnie@host.example.org:/users/arnie", results[1].device) - self.assertEqual("/home/arnie/remote", results[1].mount_point) - self.assertEqual("fuse.sshfs", results[1].type) - self.assertTrue(results[1].options.rw) - self.assertTrue(results[1].options.nosuid) - self.assertTrue(results[1].options.nodev) - self.assertEqual(["65536"], results[1].options.max_read) - - self.assertEqual("/dev/sr0", results[2].device) - 
self.assertEqual("/media/USB Drive", results[2].mount_point) - self.assertEqual("vfat", results[2].type) - self.assertTrue(results[2].options.ro) - self.assertTrue(results[2].options.nosuid) - self.assertTrue(results[2].options.nodev) - - -class MountCmdTests(test_lib.GRRBaseTest): - """Test the mount command parser.""" - - def testParseMountData(self): - test_data = r""" - rootfs on / type rootfs (rw) - arnie@host.example.org:/users/arnie on /home/arnie/remote type fuse.sshfs (rw,nosuid,nodev,max_read=65536) - /dev/sr0 on /media/USB Drive type vfat (ro,nosuid,nodev) - """ - parser = config_file.MountCmdParser() - results = list(parser.Parse("/bin/mount", [], test_data, "", 0, None)) - self.assertEqual("rootfs", results[0].device) - self.assertEqual("/", results[0].mount_point) - self.assertEqual("rootfs", results[0].type) - self.assertTrue(results[0].options.rw) - self.assertFalse(results[0].options.ro) - - self.assertEqual("arnie@host.example.org:/users/arnie", results[1].device) - self.assertEqual("/home/arnie/remote", results[1].mount_point) - self.assertEqual("fuse.sshfs", results[1].type) - self.assertTrue(results[1].options.rw) - self.assertTrue(results[1].options.nosuid) - self.assertTrue(results[1].options.nodev) - self.assertEqual(["65536"], results[1].options.max_read) - - self.assertEqual("/dev/sr0", results[2].device) - self.assertEqual("/media/USB Drive", results[2].mount_point) - self.assertEqual("vfat", results[2].type) - self.assertTrue(results[2].options.ro) - self.assertTrue(results[2].options.nosuid) - self.assertTrue(results[2].options.nodev) - - -class RsyslogParserTests(test_lib.GRRBaseTest): - """Test the rsyslog parser.""" - - def testParseRsyslog(self): - test_data = br""" - $SomeDirective - daemon.* @@tcp.example.com.:514;RSYSLOG_ForwardFormat - syslog.debug,info @udp.example.com.:514;RSYSLOG_ForwardFormat - kern.* |/var/log/pipe - news,uucp.* ~ - user.* ^/usr/bin/log2cowsay - *.* /var/log/messages - *.emerg * - mail.* -/var/log/maillog - """ - log_conf = io.BytesIO(test_data) - parser = config_file.RsyslogParser() - results = list(parser.ParseFiles(None, [None], [log_conf])) - self.assertLen(results, 1) - tcp, udp, pipe, null, script, fs, wall, async_fs = [ - target for target in results[0].targets - ] - - self.assertEqual("daemon", tcp.facility) - self.assertEqual("*", tcp.priority) - self.assertEqual("TCP", tcp.transport) - self.assertEqual("tcp.example.com.:514", tcp.destination) - - self.assertEqual("syslog", udp.facility) - self.assertEqual("debug,info", udp.priority) - self.assertEqual("UDP", udp.transport) - self.assertEqual("udp.example.com.:514", udp.destination) - - self.assertEqual("kern", pipe.facility) - self.assertEqual("*", pipe.priority) - self.assertEqual("PIPE", pipe.transport) - self.assertEqual("/var/log/pipe", pipe.destination) - - self.assertEqual("news,uucp", null.facility) - self.assertEqual("*", null.priority) - self.assertEqual("NONE", null.transport) - self.assertFalse(null.destination) - - self.assertEqual("user", script.facility) - self.assertEqual("*", script.priority) - self.assertEqual("SCRIPT", script.transport) - self.assertEqual("/usr/bin/log2cowsay", script.destination) - - self.assertEqual("*", fs.facility) - self.assertEqual("*", fs.priority) - self.assertEqual("FILE", fs.transport) - self.assertEqual("/var/log/messages", fs.destination) - - self.assertEqual("*", wall.facility) - self.assertEqual("emerg", wall.priority) - self.assertEqual("WALL", wall.transport) - self.assertEqual("*", wall.destination) - - 
self.assertEqual("mail", async_fs.facility) - self.assertEqual("*", async_fs.priority) - self.assertEqual("FILE", async_fs.transport) - self.assertEqual("/var/log/maillog", async_fs.destination) - - -class APTPackageSourceParserTests(test_lib.GRRBaseTest): - """Test the APT package source lists parser.""" - - def testPackageSourceData(self): - test_data = br""" - # Security updates - deb http://security.debian.org/ wheezy/updates main contrib non-free - deb-src [arch=amd64,trusted=yes] ftp://security.debian.org/ wheezy/updates main contrib non-free - - ## Random comment - - # Different transport protocols below - deb ssh://ftp.debian.org/debian wheezy main contrib non-free - deb-src file:/mnt/deb-sources-files/ wheezy main contrib non-free - - # correct - referencing root file system - deb-src file:/ - # incorrect - deb-src http:// - - # Bad lines below - these shouldn't get any URIs back - deb - deb-src [arch=i386] - deb-src abcdefghijklmnopqrstuvwxyz - """ - file_obj = io.BytesIO(test_data) - pathspec = rdf_paths.PathSpec(path="/etc/apt/sources.list") - parser = config_file.APTPackageSourceParser() - results = list(parser.ParseFile(None, pathspec, file_obj)) - - result = [ - d for d in results if isinstance(d, rdf_protodict.AttributedDict) - ][0] - - self.assertEqual("/etc/apt/sources.list", result.filename) - self.assertLen(result.uris, 5) - - self.assertEqual("http", result.uris[0].transport) - self.assertEqual("security.debian.org", result.uris[0].host) - self.assertEqual("/", result.uris[0].path) - - self.assertEqual("ftp", result.uris[1].transport) - self.assertEqual("security.debian.org", result.uris[1].host) - self.assertEqual("/", result.uris[1].path) - - self.assertEqual("ssh", result.uris[2].transport) - self.assertEqual("ftp.debian.org", result.uris[2].host) - self.assertEqual("/debian", result.uris[2].path) - - self.assertEqual("file", result.uris[3].transport) - self.assertEqual("", result.uris[3].host) - self.assertEqual("/mnt/deb-sources-files/", result.uris[3].path) - - self.assertEqual("file", result.uris[4].transport) - self.assertEqual("", result.uris[4].host) - self.assertEqual("/", result.uris[4].path) - - def testEmptySourceData(self): - test_data = (b"# comment 1\n" - b"# deb http://security.debian.org/ wheezy/updates main\n" - b"URI :\n" - b"URI:\n" - b"# Trailing whitespace on purpose\n" - b"URI: \n" - b"\n" - b"URIs :\n" - b"URIs:\n" - b"# Trailing whitespace on purpose\n" - b"URIs: \n" - b"# comment 2\n") - - file_obj = io.BytesIO(test_data) - pathspec = rdf_paths.PathSpec(path="/etc/apt/sources.list.d/test.list") - parser = config_file.APTPackageSourceParser() - results = list(parser.ParseFile(None, pathspec, file_obj)) - - result = [ - d for d in results if isinstance(d, rdf_protodict.AttributedDict) - ][0] - - self.assertEqual("/etc/apt/sources.list.d/test.list", result.filename) - self.assertEmpty(result.uris) - - def testRFC822StyleSourceDataParser(self): - """Test source list formatted as per rfc822 style.""" - - test_data = br""" - # comment comment comment - Types: deb deb-src - URIs: http://example.com/debian - http://1.example.com/debian1 - http://2.example.com/debian2 - - http://willdetect.example.com/debian-strange - URIs : ftp://3.example.com/debian3 - http://4.example.com/debian4 - blahblahblahblahblahlbha - http://willdetect2.example.com/debian-w2 - - http://willdetect3.example.com/debian-w3 - URI - URI : ssh://5.example.com/debian5 - Suites: stable testing - Sections: component1 component2 - Description: short - long long long - [option1]: 
[option1-value] - - deb-src [arch=amd64,trusted=yes] ftp://security.debian.org/ wheezy/updates main contrib non-free - - # comment comment comment - Types: deb - URI:ftp://another.example.com/debian2 - Suites: experimental - Sections: component1 component2 - Enabled: no - Description: http://debian.org - This URL shouldn't be picked up by the parser - [option1]: [option1-value] - - """ - file_obj = io.BytesIO(test_data) - pathspec = rdf_paths.PathSpec(path="/etc/apt/sources.list.d/rfc822.list") - parser = config_file.APTPackageSourceParser() - results = list(parser.ParseFile(None, pathspec, file_obj)) - - result = [ - d for d in results if isinstance(d, rdf_protodict.AttributedDict) - ][0] - - self.assertEqual("/etc/apt/sources.list.d/rfc822.list", result.filename) - self.assertLen(result.uris, 11) - - self.assertEqual("ftp", result.uris[0].transport) - self.assertEqual("security.debian.org", result.uris[0].host) - self.assertEqual("/", result.uris[0].path) - - self.assertEqual("http", result.uris[1].transport) - self.assertEqual("example.com", result.uris[1].host) - self.assertEqual("/debian", result.uris[1].path) - - self.assertEqual("http", result.uris[2].transport) - self.assertEqual("1.example.com", result.uris[2].host) - self.assertEqual("/debian1", result.uris[2].path) - - self.assertEqual("http", result.uris[3].transport) - self.assertEqual("2.example.com", result.uris[3].host) - self.assertEqual("/debian2", result.uris[3].path) - - self.assertEqual("http", result.uris[4].transport) - self.assertEqual("willdetect.example.com", result.uris[4].host) - self.assertEqual("/debian-strange", result.uris[4].path) - - self.assertEqual("ftp", result.uris[5].transport) - self.assertEqual("3.example.com", result.uris[5].host) - self.assertEqual("/debian3", result.uris[5].path) - - self.assertEqual("http", result.uris[6].transport) - self.assertEqual("4.example.com", result.uris[6].host) - self.assertEqual("/debian4", result.uris[6].path) - - self.assertEqual("http", result.uris[7].transport) - self.assertEqual("willdetect2.example.com", result.uris[7].host) - self.assertEqual("/debian-w2", result.uris[7].path) - - self.assertEqual("http", result.uris[8].transport) - self.assertEqual("willdetect3.example.com", result.uris[8].host) - self.assertEqual("/debian-w3", result.uris[8].path) - - self.assertEqual("ssh", result.uris[9].transport) - self.assertEqual("5.example.com", result.uris[9].host) - self.assertEqual("/debian5", result.uris[9].path) - - self.assertEqual("ftp", result.uris[10].transport) - self.assertEqual("another.example.com", result.uris[10].host) - self.assertEqual("/debian2", result.uris[10].path) - - -class YumPackageSourceParserTests(test_lib.GRRBaseTest): - """Test the Yum package source lists parser.""" - - def testPackageSourceData(self): - test_data = br""" - # comment 1 - [centosdvdiso] - name=CentOS DVD ISO - baseurl=file:///mnt - http://mirror1.centos.org/CentOS/6/os/i386/ - baseurl =ssh://mirror2.centos.org/CentOS/6/os/i386/ - enabled=1 - gpgcheck=1 - gpgkey=file:///mnt/RPM-GPG-KEY-CentOS-6 - - # comment2 - [examplerepo] - name=Example Repository - baseurl = https://mirror3.centos.org/CentOS/6/os/i386/ - enabled=1 - gpgcheck=1 - gpgkey=http://mirror.centos.org/CentOS/6/os/i386/RPM-GPG-KEY-CentOS-6 - - """ - file_obj = io.BytesIO(test_data) - pathspec = rdf_paths.PathSpec(path="/etc/yum.repos.d/test1.repo") - parser = config_file.YumPackageSourceParser() - results = list(parser.ParseFile(None, pathspec, file_obj)) - - result = [ - d for d in results if isinstance(d, 
rdf_protodict.AttributedDict) - ][0] - - self.assertEqual("/etc/yum.repos.d/test1.repo", result.filename) - self.assertLen(result.uris, 4) - - self.assertEqual("file", result.uris[0].transport) - self.assertEqual("", result.uris[0].host) - self.assertEqual("/mnt", result.uris[0].path) - - self.assertEqual("http", result.uris[1].transport) - self.assertEqual("mirror1.centos.org", result.uris[1].host) - self.assertEqual("/CentOS/6/os/i386/", result.uris[1].path) - - self.assertEqual("ssh", result.uris[2].transport) - self.assertEqual("mirror2.centos.org", result.uris[2].host) - self.assertEqual("/CentOS/6/os/i386/", result.uris[2].path) - - self.assertEqual("https", result.uris[3].transport) - self.assertEqual("mirror3.centos.org", result.uris[3].host) - self.assertEqual("/CentOS/6/os/i386/", result.uris[3].path) - - def testEmptySourceData(self): - test_data = (b"# comment 1\n" - b"baseurl=\n" - b"# Trailing whitespace on purpose\n" - b"baseurl= \n" - b"# Trailing whitespace on purpose\n" - b"baseurl = \n" - b"baseurl\n" - b"# comment 2\n") - - file_obj = io.BytesIO(test_data) - pathspec = rdf_paths.PathSpec(path="/etc/yum.repos.d/emptytest.repo") - parser = config_file.YumPackageSourceParser() - results = list(parser.ParseFile(None, pathspec, file_obj)) - - result = [ - d for d in results if isinstance(d, rdf_protodict.AttributedDict) - ][0] - - self.assertEqual("/etc/yum.repos.d/emptytest.repo", result.filename) - self.assertEmpty(result.uris) - - -class CronAtAllowDenyParserTests(test_lib.GRRBaseTest): - """Test the cron/at allow/deny parser.""" - - def testParseCronData(self): - test_data = br"""root - user - - user2 user3 - root - hi hello - user - pparth""" - file_obj = io.BytesIO(test_data) - pathspec = rdf_paths.PathSpec(path="/etc/at.allow") - parser = config_file.CronAtAllowDenyParser() - results = list(parser.ParseFile(None, pathspec, file_obj)) - - result = [ - d for d in results if isinstance(d, rdf_protodict.AttributedDict) - ][0] - filename = result.filename - users = result.users - self.assertEqual("/etc/at.allow", filename) - self.assertCountEqual(["root", "user", "pparth"], users) - - anomalies = [a for a in results if isinstance(a, rdf_anomaly.Anomaly)] - self.assertLen(anomalies, 1) - anom = anomalies[0] - self.assertEqual("Dodgy entries in /etc/at.allow.", anom.symptom) - self.assertCountEqual(["user2 user3", "hi hello"], anom.finding) - self.assertEqual(pathspec, anom.reference_pathspec) - self.assertEqual("PARSER_ANOMALY", anom.type) - - -class NtpParserTests(test_lib.GRRBaseTest): - """Test the ntp.conf parser.""" - - def testParseNtpConfig(self): - test_data = br""" - # Time servers - server 1.2.3.4 iburst - server 4.5.6.7 iburst - server 8.9.10.11 iburst - server time.google.com iburst - server 2001:1234:1234:2::f iburst - - # Drift file - driftfile /var/lib/ntp/ntp.drift - - restrict default nomodify noquery nopeer - - # Guard against monlist NTP reflection attacks. - disable monitor - - # Enable the creation of a peerstats file - enable stats - statsdir /var/log/ntpstats - filegen peerstats file peerstats type day link enable - - # Test only. - ttl 127 88 - broadcastdelay 0.01 -""" - conffile = io.BytesIO(test_data) - parser = config_file.NtpdParser() - results = list(parser.ParseFile(None, None, conffile)) - - # We expect some results. - self.assertTrue(results) - # There should be only one result. - self.assertLen(results, 1) - # Now that we are sure, just use that single result for easy of reading. 
- results = results[0] - - # Check all the expected "simple" config keywords are present. - expected_config_keywords = set([ - "driftfile", "statsdir", "filegen", "ttl", "broadcastdelay" - ]) | set(config_file.NtpdFieldParser.defaults.keys()) - self.assertEqual(expected_config_keywords, set(results.config.keys())) - - # Check all the expected "keyed" config keywords are present. - self.assertTrue(results.server) - self.assertTrue(results.restrict) - # And check one that isn't in the config, isn't in out result. - self.assertFalse(results.trap) - - # Check we got all the "servers". - servers = [ - "1.2.3.4", "4.5.6.7", "8.9.10.11", "time.google.com", - "2001:1234:1234:2::f" - ] - self.assertCountEqual(servers, [r.address for r in results.server]) - # In our test data, they all have "iburst" as an arg. Check that is found. - for r in results.server: - self.assertEqual("iburst", r.options) - - # Check a few values were parsed correctly. - self.assertEqual("/var/lib/ntp/ntp.drift", results.config["driftfile"]) - self.assertEqual("/var/log/ntpstats", results.config["statsdir"]) - self.assertEqual("peerstats file peerstats type day link enable", - results.config["filegen"]) - self.assertLen(results.restrict, 1) - self.assertEqual("default", results.restrict[0].address) - self.assertEqual("nomodify noquery nopeer", results.restrict[0].options) - # A option that can have a list of integers. - self.assertEqual([127, 88], results.config["ttl"]) - # An option that should only have a single float. - self.assertEqual([0.01], results.config["broadcastdelay"]) - - # Check the modified defaults. - self.assertFalse(results.config["monitor"]) - self.assertTrue(results.config["stats"]) - - # Check an unlisted defaults are unmodified. - self.assertFalse(results.config["kernel"]) - self.assertTrue(results.config["auth"]) - - -class SudoersParserTest(test_lib.GRRBaseTest): - """Test the sudoers parser.""" - - def testIncludes(self): - test_data = br""" - # general comment - #include a # end of line comment - #includedir b - #includeis now a comment - """ - contents = io.BytesIO(test_data) - config = config_file.SudoersParser() - result = list(config.ParseFile(None, None, contents)) - - self.assertListEqual(list(result[0].includes), ["a", "b"]) - self.assertListEqual(list(result[0].entries), []) - - def testParseAliases(self): - test_data = br""" - User_Alias basic = a , b, c - User_Alias left = a, b, c :\ - right = d, e, f - User_Alias complex = #1000, %group, %#1001, %:nonunix, %:#1002 - """ - contents = io.BytesIO(test_data) - config = config_file.SudoersParser() - result = list(config.ParseFile(None, None, contents)) - - golden = { - "aliases": [ - { - "name": "basic", - "type": "USER", - "users": ["a", "b", "c"], - }, - { - "name": "left", - "type": "USER", - "users": ["a", "b", "c"], - }, - { - "name": "right", - "type": "USER", - "users": ["d", "e", "f"], - }, - { - "name": "complex", - "type": "USER", - "users": ["#1000", "%group", "%#1001", "%:nonunix", "%:#1002"], - }, - ], - } - - self.assertDictEqual(result[0].ToPrimitiveDict(), golden) - - def testDefaults(self): - test_data = br""" - Defaults syslog=auth - Defaults>root !set_logname - Defaults:FULLTIMERS !lecture - Defaults@SERVERS log_year, logfile=/var/log/sudo.log - """ - contents = io.BytesIO(test_data) - config = config_file.SudoersParser() - result = list(config.ParseFile(None, None, contents)) - - golden = { - "defaults": [ - { - "name": "syslog", - "value": "auth", - }, - { - "scope": "root", - "name": "!set_logname", - "value": "", - }, - 
{ - "scope": "FULLTIMERS", - "name": "!lecture", - "value": "", - }, - # 4th entry is split into two, for each option. - { - "scope": "SERVERS", - "name": "log_year", - "value": "", - }, - { - "scope": "SERVERS", - "name": "logfile", - "value": "/var/log/sudo.log", - }, - ], - } - - self.assertDictEqual(result[0].ToPrimitiveDict(), golden) - - def testSpecs(self): - test_data = br""" - # user specs - root ALL = (ALL) ALL - %wheel ALL = (ALL) ALL - bob SPARC = (OP) ALL : SGI = (OP) ALL - fred ALL = (DB) NOPASSWD: ALL - """ - contents = io.BytesIO(test_data) - config = config_file.SudoersParser() - result = list(config.ParseFile(None, None, contents)) - - golden = { - "entries": [ - { - "users": ["root"], - "hosts": ["ALL"], - "cmdspec": ["(ALL)", "ALL"], - }, - { - "users": ["%wheel"], - "hosts": ["ALL"], - "cmdspec": ["(ALL)", "ALL"], - }, - { - "users": ["bob"], - "hosts": ["SPARC"], - "cmdspec": ["(OP)", "ALL"], - }, - { - "users": ["bob"], - "hosts": ["SGI"], - "cmdspec": ["(OP)", "ALL"], - }, - { - "users": ["fred"], - "hosts": ["ALL"], - "cmdspec": ["(DB)", "NOPASSWD:", "ALL"], - }, - ], - } - - self.assertDictEqual(result[0].ToPrimitiveDict(), golden) - - def main(args): test_lib.main(args) diff --git a/grr/core/grr_response_core/lib/parsers/cron_file_parser.py b/grr/core/grr_response_core/lib/parsers/cron_file_parser.py index e1a19419aa..2652b84dc7 100644 --- a/grr/core/grr_response_core/lib/parsers/cron_file_parser.py +++ b/grr/core/grr_response_core/lib/parsers/cron_file_parser.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Simple parsers for cron type files.""" - from typing import IO from typing import Iterator @@ -41,9 +40,12 @@ def ParseFile( month=str(job.month), dayofweek=str(job.dow), command=str(job.command), - comment=str(job.comment))) + comment=str(job.comment), + ) + ) yield rdf_cronjobs.CronTabFile( # We're interested in the nominal file path, not the full Pathspec. 
path=pathspec.last.path, - jobs=entries) + jobs=entries, + ) diff --git a/grr/core/grr_response_core/lib/parsers/cron_file_parser_test.py b/grr/core/grr_response_core/lib/parsers/cron_file_parser_test.py index 2f75c8a48c..c2f9ac4502 100644 --- a/grr/core/grr_response_core/lib/parsers/cron_file_parser_test.py +++ b/grr/core/grr_response_core/lib/parsers/cron_file_parser_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for grr.parsers.cron_file_parser.""" - import os from absl import app @@ -32,7 +31,7 @@ def testCronTabParser(self): self.assertEqual(result.jobs[0].dayofmonth, "3") self.assertEqual(result.jobs[0].month, "4") self.assertEqual(result.jobs[0].dayofweek, "5") - self.assertEqual(result.jobs[0].command, "/usr/bin/echo \"test\"") + self.assertEqual(result.jobs[0].command, '/usr/bin/echo "test"') def main(args): diff --git a/grr/core/grr_response_core/lib/parsers/firefox3_history.py b/grr/core/grr_response_core/lib/parsers/firefox3_history.py deleted file mode 100644 index 08f218deb0..0000000000 --- a/grr/core/grr_response_core/lib/parsers/firefox3_history.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python -"""Parser for Mozilla Firefox3 3 History files.""" - -from typing import IO -from typing import Iterator -from typing import Tuple - -from urllib import parse as urlparse - -from grr_response_core.lib import parsers -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_core.lib.rdfvalues import webhistory as rdf_webhistory -from grr_response_core.lib.util import sqlite - - -class FirefoxHistoryParser( - parsers.SingleFileParser[rdf_webhistory.BrowserHistoryItem]): - """Parse Chrome history files into BrowserHistoryItem objects.""" - - output_types = [rdf_webhistory.BrowserHistoryItem] - supported_artifacts = ["FirefoxHistory"] - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_webhistory.BrowserHistoryItem]: - del knowledge_base # Unused. - - # TODO(user): Convert this to use the far more intelligent plaso parser. - ff = Firefox3History() - for timestamp, unused_entry_type, url, title in ff.Parse(filedesc): - yield rdf_webhistory.BrowserHistoryItem( - url=url, - domain=urlparse.urlparse(url).netloc, - access_time=timestamp, - program_name="Firefox", - source_path=pathspec.CollapsePath(), - title=title) - - -# TODO(hanuszczak): This should not be a class. -class Firefox3History(object): - """Class for handling the parsing of a Firefox 3 history file.""" - - VISITS_QUERY = ("SELECT moz_historyvisits.visit_date, moz_places.url," - " moz_places.title " - "FROM moz_places, moz_historyvisits " - "WHERE moz_places.id = moz_historyvisits.place_id " - "ORDER BY moz_historyvisits.visit_date ASC;") - - # TODO(hanuszczak): This should return well-structured data. 
- def Parse(self, filedesc: IO[bytes]) -> Iterator[Tuple]: # pylint: disable=g-bare-generic - """Iterator returning dict for each entry in history.""" - with sqlite.IOConnection(filedesc) as conn: - for timestamp, url, title in conn.Query(self.VISITS_QUERY): - if not isinstance(timestamp, int): - timestamp = 0 - - yield timestamp, "FIREFOX3_VISIT", url, title diff --git a/grr/core/grr_response_core/lib/parsers/firefox3_history_test.py b/grr/core/grr_response_core/lib/parsers/firefox3_history_test.py deleted file mode 100644 index 324efce503..0000000000 --- a/grr/core/grr_response_core/lib/parsers/firefox3_history_test.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python -# Copyright 2011 Google Inc. All Rights Reserved. -"""Tests for grr.parsers.firefox3_history.""" - - -import datetime -import io -import os - -from absl import app - -from grr_response_core.lib.parsers import firefox3_history -from grr.test_lib import test_lib - - -class Firefox3HistoryTest(test_lib.GRRBaseTest): - """Test parsing of Firefox 3 history files.""" - - # places.sqlite contains a single history entry: - # 2011-07-01 11:16:21.371935, FIREFOX3_VISIT, - # http://news.google.com/, Google News - - def testBasicParsing(self): - """Test we can parse a standard file.""" - history_file = os.path.join(self.base_path, "places.sqlite") - with io.open(history_file, mode="rb") as history_filedesc: - history = firefox3_history.Firefox3History() - # Parse returns (timestamp, dtype, url, title) - entries = [x for x in history.Parse(history_filedesc)] - - self.assertLen(entries, 1) - - try: - dt1 = datetime.datetime(1970, 1, 1) - dt1 += datetime.timedelta(microseconds=entries[0][0]) - except (TypeError, ValueError): - dt1 = entries[0][0] - - self.assertEqual(str(dt1), "2011-07-01 11:16:21.371935") - self.assertEqual(entries[0][2], "http://news.google.com/") - self.assertEqual(entries[0][3], "Google News") - - def testNewHistoryFile(self): - """Tests reading of history files written by recent versions of Firefox.""" - history_file = os.path.join(self.base_path, "new_places.sqlite") - with io.open(history_file, mode="rb") as history_filedesc: - history = firefox3_history.Firefox3History() - entries = [x for x in history.Parse(history_filedesc)] - - self.assertLen(entries, 3) - self.assertEqual(entries[1][3], - "Slashdot: News for nerds, stuff that matters") - self.assertEqual(entries[2][0], 1342526323608384) - self.assertEqual(entries[2][1], "FIREFOX3_VISIT") - self.assertEqual( - entries[2][2], - "https://blog.duosecurity.com/2012/07/exploit-mitigations" - "-in-android-jelly-bean-4-1/") - - # Check that our results are properly time ordered - time_results = [x[0] for x in entries] - self.assertEqual(time_results, sorted(time_results)) - - -def main(argv): - test_lib.main(argv) - - -if __name__ == "__main__": - app.run(main) diff --git a/grr/core/grr_response_core/lib/parsers/ie_history.py b/grr/core/grr_response_core/lib/parsers/ie_history.py deleted file mode 100644 index b036d7fbfa..0000000000 --- a/grr/core/grr_response_core/lib/parsers/ie_history.py +++ /dev/null @@ -1,172 +0,0 @@ -#!/usr/bin/env python -# pylint: disable=line-too-long -"""Parser for IE index.dat files. - -Note that this is a very naive and incomplete implementation and should be -replaced with a more intelligent one. Do not implement anything based on this -code, it is a placeholder for something real. 
- -For anyone who wants a useful reference, see this: -https://forensicswiki.xyz/wiki/index.php?title=Internet_Explorer_History_File_Format -https://github.com/libyal/libmsiecf/blob/master/documentation/MSIE%20Cache%20File%20%28index.dat%29%20format.asciidoc -""" -# pylint: enable=line-too-long - -import logging -import operator -import struct -from typing import IO, Iterator -from urllib import parse as urlparse - -from grr_response_core.lib import parsers -from grr_response_core.lib import rdfvalue -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_core.lib.rdfvalues import webhistory as rdf_webhistory - -# Difference between 1 Jan 1601 and 1 Jan 1970. -WIN_UNIX_DIFF_MSECS = 11644473600 * 1e6 - - -class IEHistoryParser( - parsers.SingleFileParser[rdf_webhistory.BrowserHistoryItem] -): - """Parse IE index.dat files into BrowserHistoryItem objects.""" - - output_types = [rdf_webhistory.BrowserHistoryItem] - supported_artifacts = ["InternetExplorerHistory"] - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_webhistory.BrowserHistoryItem]: - del knowledge_base # Unused. - - # TODO(user): Convert this to use the far more intelligent plaso parser. - ie = IEParser(filedesc) - for dat in ie.Parse(): - access_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( - int(dat.get("mtime")) - ) - yield rdf_webhistory.BrowserHistoryItem( - url=dat["url"], - domain=urlparse.urlparse(dat["url"]).netloc, - access_time=access_time, - program_name="Internet Explorer", - source_path=pathspec.CollapsePath(), - ) - - -class IEParser(object): - """Parser object for index.dat files. - - The file format for IE index.dat files is somewhat poorly documented. - The following implementation is based on information from: - - https://forensicswiki.xyz/wiki/index.php?title=Internet_Explorer_History_File_Format - - Returns results in chronological order based on mtime - """ - - FILE_HEADER = b"Client UrlCache MMF Ver 5.2" - BLOCK_SIZE = 0x80 - - def __init__(self, input_obj): - """Initialize. - - Args: - input_obj: A file like object to read the index.dat from. - """ - self._file = input_obj - self._entries = [] - - def Parse(self): - """Parse the file.""" - if not self._file: - logging.error("Couldn't open file") - return - - # Limit read size to 5MB. - self.input_dat = self._file.read(1024 * 1024 * 5) - if not self.input_dat.startswith(self.FILE_HEADER): - logging.error("Invalid index.dat file %s", self._file) - return - - # Events aren't time ordered in the history file, so we collect them all - # then sort. - events = [] - for event in self._DoParse(): - events.append(event) - - for event in sorted(events, key=operator.itemgetter("mtime")): - yield event - - def _GetRecord(self, offset, record_size): - """Retrieve a single record from the file. - - Args: - offset: offset from start of input_dat where header starts - record_size: length of the header according to file (untrusted) - - Returns: - A dict containing a single browser history record. 
- """ - record_header = "<4sLQQL" - get4 = lambda x: struct.unpack(" Sequence[rdf_client.SoftwarePackage]: - parser = linux_cmd_parser.YumListCmdParser() - parsed = list( - parser.Parse( - cmd="yum", - args=["list installed"], - stdout=output.encode("utf-8"), - stderr=b"", - return_val=0, - knowledge_base=None)) - - return parsed[0].packages - - class LinuxCmdParserTest(test_lib.GRRBaseTest): """Test parsing of linux command output.""" - def testYumListCmdParser(self): - """Ensure we can extract packages from yum output.""" - parser = linux_cmd_parser.YumListCmdParser() - content = open(os.path.join(self.base_path, "yum.out"), "rb").read() - out = list( - parser.Parse("/usr/bin/yum", ["list installed -q"], content, b"", 0, - None)) - self.assertLen(out, 1) - self.assertLen(out[0].packages, 2) - package = out[0].packages[0] - self.assertIsInstance(package, rdf_client.SoftwarePackage) - self.assertEqual(package.name, "ConsoleKit") - self.assertEqual(package.architecture, "x86_64") - self.assertEqual(package.publisher, "@base") - - def testYumRepolistCmdParser(self): - """Test to see if we can get data from yum repolist output.""" - parser = linux_cmd_parser.YumRepolistCmdParser() - content = open(os.path.join(self.base_path, "repolist.out"), "rb").read() - repolist = list( - parser.Parse("/usr/bin/yum", ["repolist", "-v", "-q"], content, b"", 0, - None)) - self.assertIsInstance(repolist[0], rdf_client.PackageRepository) - - self.assertEqual(repolist[0].id, "rhel") - self.assertEqual(repolist[0].name, "rhel repo") - self.assertEqual(repolist[0].revision, "1") - self.assertEqual(repolist[0].last_update, "Sun Mar 15 08:51:32") - self.assertEqual(repolist[0].num_packages, "12") - self.assertEqual(repolist[0].size, "8 GB") - self.assertEqual(repolist[0].baseurl, "http://rhel/repo") - self.assertEqual(repolist[0].timeout, - "1200 second(s) (last: Mon Apr 1 20:30:02 2016)") - self.assertLen(repolist, 2) - def testRpmCmdParser(self): """Ensure we can extract packages from rpm output.""" parser = linux_cmd_parser.RpmCmdParser() @@ -253,7 +51,7 @@ def testRpmCmdParser(self): "keyutils-libs": "1.2-1.el5", "less": "436-9.el5", "libstdc++-devel": "4.1.2-55.el5", - "gcc-c++": "4.1.2-55.el5" + "gcc-c++": "4.1.2-55.el5", } self.assertCountEqual(expected, software) self.assertEqual("Broken rpm database.", anomaly[0].symptom) @@ -270,10 +68,14 @@ def testDpkgCmdParser(self): package_list.packages[0], rdf_client.SoftwarePackage( name="acpi-support-base", - description="scripts for handling base ACPI events such as the power button", + description=( + "scripts for handling base ACPI events such as the power button" + ), version="0.140-5", architecture="all", - install_state=rdf_client.SoftwarePackage.InstallState.INSTALLED)) + install_state=rdf_client.SoftwarePackage.InstallState.INSTALLED, + ), + ) self.assertEqual( package_list.packages[22], rdf_client.SoftwarePackage( @@ -281,7 +83,9 @@ def testDpkgCmdParser(self): description=None, # Test package with empty description. 
version="1:3.2-6", architecture="amd64", - install_state=rdf_client.SoftwarePackage.InstallState.INSTALLED)) + install_state=rdf_client.SoftwarePackage.InstallState.INSTALLED, + ), + ) def testDpkgCmdParserPrecise(self): """Ensure we can extract packages from dpkg output on ubuntu precise.""" @@ -300,7 +104,9 @@ def testDpkgCmdParserPrecise(self): description="add and remove users and groups", version="3.113ubuntu2", architecture=None, - install_state=rdf_client.SoftwarePackage.InstallState.INSTALLED)) + install_state=rdf_client.SoftwarePackage.InstallState.INSTALLED, + ), + ) self.assertEqual( package_list.packages[12], rdf_client.SoftwarePackage( @@ -308,14 +114,17 @@ def testDpkgCmdParserPrecise(self): description=None, # Test package with empty description. version="1:3.2-1ubuntu1", architecture=None, - install_state=rdf_client.SoftwarePackage.InstallState.INSTALLED)) + install_state=rdf_client.SoftwarePackage.InstallState.INSTALLED, + ), + ) def testDmidecodeParser(self): """Test to see if we can get data from dmidecode output.""" parser = linux_cmd_parser.DmidecodeCmdParser() content = open(os.path.join(self.base_path, "dmidecode.out"), "rb").read() parse_result = list( - parser.Parse("/usr/sbin/dmidecode", ["-q"], content, b"", 0, None)) + parser.Parse("/usr/sbin/dmidecode", ["-q"], content, b"", 0, None) + ) self.assertLen(parse_result, 1) hardware = parse_result[0] @@ -324,8 +133,9 @@ def testDmidecodeParser(self): self.assertEqual(hardware.serial_number, "2UA25107BB") self.assertEqual(hardware.system_manufacturer, "Hewlett-Packard") self.assertEqual(hardware.system_product_name, "HP Z420 Workstation") - self.assertEqual(hardware.system_uuid, - "4596BF80-41F0-11E2-A3B4-10604B5C7F38") + self.assertEqual( + hardware.system_uuid, "4596BF80-41F0-11E2-A3B4-10604B5C7F38" + ) self.assertEqual(hardware.system_sku_number, "C2R51UC#ABA") self.assertEqual(hardware.system_family, "103C_53335X G=D") @@ -336,103 +146,6 @@ def testDmidecodeParser(self): self.assertEqual(hardware.bios_revision, "2.8") -class PsCmdParserTest(absltest.TestCase): - - def testRealOutput(self): - stdout = b"""\ -UID PID PPID C STIME TTY TIME CMD -root 1 0 0 Oct02 ? 00:01:35 /sbin/init splash -root 2 0 0 Oct02 ? 00:00:00 [kthreadd] -root 5 2 0 Oct02 ? 00:00:00 [kworker/0:0H] -colord 68931 1 0 Oct02 ? 00:00:00 /usr/lib/colord/colord -foobar 69081 69080 1 Oct02 ? 
02:08:49 cinnamon --replace -""" - - parser = linux_cmd_parser.PsCmdParser() - processes = list(parser.Parse("/bin/ps", "-ef", stdout, b"", 0, None)) - - self.assertLen(processes, 5) - - self.assertEqual(processes[0].username, "root") - self.assertEqual(processes[0].pid, 1) - self.assertEqual(processes[0].ppid, 0) - self.assertEqual(processes[0].cpu_percent, 0.0) - self.assertEqual(processes[0].terminal, "?") - self.assertEqual(processes[0].cmdline, ["/sbin/init", "splash"]) - - self.assertEqual(processes[1].username, "root") - self.assertEqual(processes[1].pid, 2) - self.assertEqual(processes[1].ppid, 0) - self.assertEqual(processes[1].cpu_percent, 0.0) - self.assertEqual(processes[1].terminal, "?") - self.assertEqual(processes[1].cmdline, ["[kthreadd]"]) - - self.assertEqual(processes[2].username, "root") - self.assertEqual(processes[2].pid, 5) - self.assertEqual(processes[2].ppid, 2) - self.assertEqual(processes[2].cpu_percent, 0.0) - self.assertEqual(processes[2].terminal, "?") - self.assertEqual(processes[2].cmdline, ["[kworker/0:0H]"]) - - self.assertEqual(processes[3].username, "colord") - self.assertEqual(processes[3].pid, 68931) - self.assertEqual(processes[3].ppid, 1) - self.assertEqual(processes[3].cpu_percent, 0.0) - self.assertEqual(processes[3].terminal, "?") - self.assertEqual(processes[3].cmdline, ["/usr/lib/colord/colord"]) - - self.assertEqual(processes[4].username, "foobar") - self.assertEqual(processes[4].pid, 69081) - self.assertEqual(processes[4].ppid, 69080) - self.assertEqual(processes[4].cpu_percent, 1.0) - self.assertEqual(processes[4].terminal, "?") - self.assertEqual(processes[4].cmdline, ["cinnamon", "--replace"]) - - def testDoesNotFailOnIncorrectInput(self): - stdout = b"""\ -UID PID PPID C STIME TTY TIME CMD -foo 1 0 0 Sep01 ? 00:01:23 /baz/norf -bar 2 1 0 Sep02 ? 00:00:00 /baz/norf --thud --quux -THIS IS AN INVALID LINE -quux 5 2 0 Sep03 ? 00:00:00 /blargh/norf -quux ??? ??? 0 Sep04 ? 00:00:00 ??? -foo 4 2 0 Sep05 ? 
00:00:00 /foo/bar/baz --quux=1337 -""" - - parser = linux_cmd_parser.PsCmdParser() - processes = list(parser.Parse("/bin/ps", "-ef", stdout, b"", 0, None)) - - self.assertLen(processes, 4) - - self.assertEqual(processes[0].username, "foo") - self.assertEqual(processes[0].pid, 1) - self.assertEqual(processes[0].ppid, 0) - self.assertEqual(processes[0].cpu_percent, 0) - self.assertEqual(processes[0].terminal, "?") - self.assertEqual(processes[0].cmdline, ["/baz/norf"]) - - self.assertEqual(processes[1].username, "bar") - self.assertEqual(processes[1].pid, 2) - self.assertEqual(processes[1].ppid, 1) - self.assertEqual(processes[1].cpu_percent, 0) - self.assertEqual(processes[1].terminal, "?") - self.assertEqual(processes[1].cmdline, ["/baz/norf", "--thud", "--quux"]) - - self.assertEqual(processes[2].username, "quux") - self.assertEqual(processes[2].pid, 5) - self.assertEqual(processes[2].ppid, 2) - self.assertEqual(processes[2].cpu_percent, 0) - self.assertEqual(processes[2].terminal, "?") - self.assertEqual(processes[2].cmdline, ["/blargh/norf"]) - - self.assertEqual(processes[3].username, "foo") - self.assertEqual(processes[3].pid, 4) - self.assertEqual(processes[3].ppid, 2) - self.assertEqual(processes[3].cpu_percent, 0) - self.assertEqual(processes[3].terminal, "?") - self.assertEqual(processes[3].cmdline, ["/foo/bar/baz", "--quux=1337"]) - - def main(args): test_lib.main(args) diff --git a/grr/core/grr_response_core/lib/parsers/linux_file_parser.py b/grr/core/grr_response_core/lib/parsers/linux_file_parser.py index c2d19a3462..d8c762a9f7 100644 --- a/grr/core/grr_response_core/lib/parsers/linux_file_parser.py +++ b/grr/core/grr_response_core/lib/parsers/linux_file_parser.py @@ -1,100 +1,15 @@ #!/usr/bin/env python """Simple parsers for Linux files.""" -import collections -import logging -import os -import re -from typing import Any -from typing import IO -from typing import Iterable -from typing import Iterator -from typing import Optional -from typing import Text +from typing import IO, Iterator, Optional -from grr_response_core import config from grr_response_core.lib import parsers -from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils -from grr_response_core.lib.parsers import config_file -from grr_response_core.lib.rdfvalues import anomaly as rdf_anomaly from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import file_finder as rdf_file_finder from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_core.lib.rdfvalues import protodict as rdf_protodict from grr_response_core.lib.util import precondition -class PCIDevicesInfoParser(parsers.MultiFileParser[rdf_client.PCIDevice]): - """Parser for PCI devices' info files located in /sys/bus/pci/devices/*/*.""" - - output_types = [rdf_client.PCIDevice] - supported_artifacts = ["PCIDevicesInfoFiles"] - - def ParseFiles( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspecs: Iterable[rdf_paths.PathSpec], - filedescs: Iterable[IO[bytes]], - ) -> Iterator[rdf_client.PCIDevice]: - del knowledge_base # Unused. - - # Each file gives us only partial information for a particular PCI device. - # Iterate through all the files first to create a dictionary encapsulating - # complete information for each of the PCI device on the system. We need - # all information for a PCI device before a proto for it can be created. 
- # We will store data in a dictionary of dictionaries that looks like this: - # data = { '0000:7f:0d.0': { 'class': '0x088000', - # 'vendor': '0x8086', - # 'device': '0x0ee1' } } - # The key is location of PCI device on system in extended B/D/F notation - # and value is a dictionary containing filename:data pairs for each file - # returned by artifact collection for that PCI device. - - # Extended B/D/F is of form "domain:bus:device.function". Compile a regex - # so we can use it to skip parsing files that don't match it. - hc = r"[0-9A-Fa-f]" - bdf_regex = re.compile(r"^%s+:%s+:%s+\.%s+" % (hc, hc, hc, hc)) - - # This will make sure that when a non-existing 'key' (PCI location) - # is accessed for the first time a new 'key':{} pair is auto-created - data = collections.defaultdict(dict) - - for pathspec, file_obj in zip(pathspecs, filedescs): - filename = pathspec.Basename() - # Location of PCI device is the name of parent directory of returned file. - bdf = pathspec.Dirname().Basename() - - # Make sure we only parse files that are under a valid B/D/F folder - if bdf_regex.match(bdf): - # Remove newlines from all files except config. Config contains raw data - # so we don't want to touch it even if it has a newline character. - file_data = file_obj.read() - if filename != "config": - file_data = file_data.rstrip(b"\n") - data[bdf][filename] = file_data - - # Now that we've captured all information for each PCI device. Let's convert - # the dictionary into a list of PCIDevice protos. - for bdf, bdf_filedata in data.items(): - pci_device = rdf_client.PCIDevice() - bdf_split = bdf.split(":") - df_split = bdf_split[2].split(".") - - # We'll convert the hex into decimal to store in the protobuf. - pci_device.domain = int(bdf_split[0], 16) - pci_device.bus = int(bdf_split[1], 16) - pci_device.device = int(df_split[0], 16) - pci_device.function = int(df_split[1], 16) - - pci_device.class_id = bdf_filedata.get("class") - pci_device.vendor = bdf_filedata.get("vendor") - pci_device.vendor_device_id = bdf_filedata.get("device") - pci_device.config = bdf_filedata.get("config") - - yield pci_device - - class PasswdParser(parsers.SingleFileParser[rdf_client.User]): """Parser for passwd files. Yields User semantic values.""" @@ -103,7 +18,7 @@ class PasswdParser(parsers.SingleFileParser[rdf_client.User]): @classmethod def ParseLine(cls, index, line) -> Optional[rdf_client.User]: - precondition.AssertType(line, Text) + precondition.AssertType(line, str) fields = "username,password,uid,gid,fullname,homedir,shell".split(",") try: @@ -116,12 +31,14 @@ def ParseLine(cls, index, line) -> Optional[rdf_client.User]: homedir=dat["homedir"], shell=dat["shell"], gid=int(dat["gid"]), - full_name=dat["fullname"]) + full_name=dat["fullname"], + ) return user - except (IndexError, KeyError): - raise parsers.ParseError("Invalid passwd file at line %d. %s" % - ((index + 1), line)) + except (IndexError, KeyError) as e: + raise parsers.ParseError( + "Invalid passwd file at line %d. 
%s" % ((index + 1), line) + ) from e def ParseFile( self, @@ -141,29 +58,9 @@ def ParseFile( yield user -class PasswdBufferParser(parsers.SingleResponseParser[rdf_client.User]): - """Parser for lines grepped from passwd files.""" - - output_types = [rdf_client.User] - supported_artifacts = ["LinuxPasswdHomedirs"] - - def ParseResponse( - self, - knowledge_base: rdf_client.KnowledgeBase, - response: rdfvalue.RDFValue, - ) -> Iterator[rdf_client.User]: - if not isinstance(response, rdf_file_finder.FileFinderResult): - raise TypeError(f"Unexpected response type: `{type(response)}`") - - lines = [x.data.decode("utf-8") for x in response.matches] - for index, line in enumerate(lines): - user = PasswdParser.ParseLine(index, line.strip()) - if user is not None: - yield user - - class UtmpStruct(utils.Struct): """Parse wtmp file from utmp.h.""" + _fields = [ ("h", "ut_type"), ("i", "pid"), @@ -181,738 +78,3 @@ class UtmpStruct(utils.Struct): ("i", "ip_4"), ("20s", "nothing"), ] - - -class LinuxWtmpParser(parsers.SingleFileParser[rdf_client.User]): - """Simplified parser for linux wtmp files. - - Yields User semantic values for USER_PROCESS events. - """ - - output_types = [rdf_client.User] - supported_artifacts = ["LinuxWtmp"] - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_client.User]: - del knowledge_base # Unused. - del pathspec # Unused. - - users = {} - wtmp = filedesc.read() - while wtmp: - try: - record = UtmpStruct(wtmp) - except utils.ParsingError: - break - - wtmp = wtmp[record.size:] - # Users only appear for USER_PROCESS events, others are system. - if record.ut_type != 7: - continue - - # Lose the null termination - record.user = record.user.split(b"\x00", 1)[0] - - # Store the latest login time. - # TODO(user): remove the 0 here once RDFDatetime can support times - # pre-epoch properly. 
- try: - users[record.user] = max(users[record.user], record.sec, 0) - except KeyError: - users[record.user] = record.sec - - for user, last_login in users.items(): - yield rdf_client.User( - username=utils.SmartUnicode(user), last_logon=last_login * 1000000) - - -class NetgroupParser(parsers.SingleFileParser[rdf_client.User]): - """Parser that extracts users from a netgroup file.""" - - output_types = [rdf_client.User] - supported_artifacts = ["NetgroupConfiguration"] - # From useradd man page - USERNAME_REGEX = r"^[a-z_][a-z0-9_-]{0,30}[$]?$" - - @classmethod - def ParseLines(cls, lines): - users = set() - filter_regexes = [ - re.compile(x) - for x in config.CONFIG["Artifacts.netgroup_filter_regexes"] - ] - username_regex = re.compile(cls.USERNAME_REGEX) - ignorelist = config.CONFIG["Artifacts.netgroup_ignore_users"] - for index, line in enumerate(lines): - if line.startswith("#"): - continue - - splitline = line.split(" ") - group_name = splitline[0] - - if filter_regexes: - filter_match = False - for regex in filter_regexes: - if regex.search(group_name): - filter_match = True - break - if not filter_match: - continue - - for member in splitline[1:]: - if member.startswith("("): - try: - _, user, _ = member.split(",") - if user not in users and user not in ignorelist: - if not username_regex.match(user): - yield rdf_anomaly.Anomaly( - type="PARSER_ANOMALY", - symptom="Invalid username: %s" % user) - else: - users.add(user) - yield rdf_client.User(username=utils.SmartUnicode(user)) - except ValueError: - raise parsers.ParseError("Invalid netgroup file at line %d: %s" % - (index + 1, line)) - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_client.User]: - """Parse the netgroup file and return User objects. - - Lines are of the form: - group1 (-,user1,) (-,user2,) (-,user3,) - - Groups are ignored, we return users in lines that match the filter regexes, - or all users in the file if no filters are specified. - - We assume usernames are in the default regex format specified in the adduser - man page. Notably no non-ASCII characters. - - Args: - knowledge_base: A knowledgebase for the client to whom the file belongs. - pathspec: A pathspec corresponding to the parsed file. - filedesc: A file-like object to parse. - - Returns: - rdf_client.User - """ - del knowledge_base # Unused. - del pathspec # Unused. - - lines = [ - l.strip() for l in utils.ReadFileBytesAsUnicode(filedesc).splitlines() - ] - return self.ParseLines(lines) - - -class NetgroupBufferParser(parsers.SingleResponseParser[rdf_client.User]): - """Parser for lines grepped from /etc/netgroup files.""" - - output_types = [rdf_client.User] - - def ParseResponse( - self, - knowledge_base: rdf_client.KnowledgeBase, - response: rdfvalue.RDFValue, - ) -> Iterator[rdf_client.User]: - if not isinstance(response, rdf_file_finder.FileFinderResult): - raise TypeError(f"Unexpected response type: `{type(response)}`") - - return NetgroupParser.ParseLines( - [x.data.decode("utf-8").strip() for x in response.matches]) - - -# TODO(hanuszczak): Subclasses of this class do not respect any types at all, -# this should be fixed. -class LinuxBaseShadowParser(parsers.MultiFileParser[Any]): - """Base parser to process user/groups with shadow files.""" - - # A list of hash types and hash matching expressions. 
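# A rough, standalone sketch of how a crypt(3)-style hash string is matched
# against one of the prefix patterns listed below; the sample hash is
# fabricated (16-character salt, 86-character digest).
import re

sha512_re = re.compile(r"\$6\$[A-z\d\./]{0,16}\$[A-z\d\./]{86}$")
sample = "$6$saltsaltsaltsalt$" + "r" * 86
assert sha512_re.match(sample) is not None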
- hashes = [("SHA512", re.compile(r"\$6\$[A-z\d\./]{0,16}\$[A-z\d\./]{86}$")), - ("SHA256", re.compile(r"\$5\$[A-z\d\./]{0,16}\$[A-z\d\./]{43}$")), - ("DISABLED", re.compile(r"!.*")), ("UNSET", re.compile(r"\*.*")), - ("MD5", re.compile(r"\$1\$([A-z\d\./]{1,8}\$)?[A-z\d\./]{22}$")), - ("DES", re.compile(r"[A-z\d\./]{2}.{11}$")), - ("BLOWFISH", re.compile(r"\$2a?\$\d\d\$[A-z\d\.\/]{22}$")), - ("NTHASH", re.compile(r"\$3\$")), ("UNUSED", re.compile(r"\$4\$"))] - - # Prevents this from automatically registering. - __abstract = True # pylint: disable=g-bad-name - - base_store = None - shadow_store = None - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Entries as defined by "getent", i.e. account databases used by nsswitch. - self.entry = {} - # Shadow files - self.shadow = {} - - def GetPwStore(self, pw_attr): - """Decide if the passwd field is a passwd or a reference to shadow. - - Evaluates the contents of the password field to determine how the password - is stored. - - If blank either no password is required or no access is granted. - This behavior is system and application dependent. - - If 'x', the encrypted password is stored in /etc/shadow. - - Otherwise, the password is any other string, it's treated as an encrypted - password. - - Args: - pw_attr: The password field as a string. - - Returns: - An enum indicating the location of the password store. - """ - # PwEntry.PwStore enum values. - if pw_attr == "x": - return self.shadow_store - return self.base_store - - def GetHashType(self, hash_str): - """Identify the type of hash in a hash string. - - Args: - hash_str: A string value that may be a hash. - - Returns: - A string description of the type of hash. - """ - # Return the type of the first matching hash. - for hash_type, hash_re in self.hashes: - if hash_re.match(hash_str): - return hash_type - # No hash matched. - return "EMPTY" - - def _ParseFile(self, file_obj, line_parser): - """Process a file line by line. - - Args: - file_obj: The file to parse. - line_parser: The parser method used to process and store line content. - - Raises: - parser.ParseError if the parser is unable to process the line. - """ - lines = [ - l.strip() for l in utils.ReadFileBytesAsUnicode(file_obj).splitlines() - ] - try: - for index, line in enumerate(lines): - if line: - line_parser(line) - except (IndexError, KeyError) as e: - raise parsers.ParseError("Invalid file at line %d: %s" % (index + 1, e)) - - def ReconcileShadow(self, store_type): - """Verify that entries that claim to use shadow files have a shadow entry. - - If the entries of the non-shadowed file indicate that a shadow file is used, - check that there is actually an entry for that file in shadow. - - Args: - store_type: The type of password store that should be used (e.g. 
- /etc/shadow or /etc/gshadow) - """ - for k, v in self.entry.items(): - if v.pw_entry.store == store_type: - shadow_entry = self.shadow.get(k) - if shadow_entry is not None: - v.pw_entry = shadow_entry - else: - v.pw_entry.store = "UNKNOWN" - - def _Anomaly(self, msg, found): - return rdf_anomaly.Anomaly( - type="PARSER_ANOMALY", symptom=msg, finding=found) - - @staticmethod - def MemberDiff(data1, set1_name, data2, set2_name): - """Helper method to perform bidirectional set differences.""" - set1 = set(data1) - set2 = set(data2) - diffs = [] - msg = "Present in %s, missing in %s: %s" - if set1 != set2: - in_set1 = set1 - set2 - in_set2 = set2 - set1 - if in_set1: - diffs.append(msg % (set1_name, set2_name, ",".join(in_set1))) - if in_set2: - diffs.append(msg % (set2_name, set1_name, ",".join(in_set2))) - return diffs - - def ParseFiles(self, knowledge_base, pathspecs, filedescs): - del knowledge_base # Unused. - - fileset = { - pathspec.path: obj for pathspec, obj in zip(pathspecs, filedescs) - } - return self.ParseFileset(fileset) - - -class LinuxSystemGroupParser(LinuxBaseShadowParser): - """Parser for group files. Yields Group semantic values.""" - - output_types = [rdf_client.Group] - supported_artifacts = ["LoginPolicyConfiguration"] - - base_store = rdf_client.PwEntry.PwStore.GROUP - shadow_store = rdf_client.PwEntry.PwStore.GSHADOW - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.gshadow_members = {} - - def ParseGshadowEntry(self, line): - """Extract the members of each group from /etc/gshadow. - - Identifies the groups in /etc/gshadow and several attributes of the group, - including how the password is crypted (if set). - - gshadow files have the format group_name:passwd:admins:members - admins are both group members and can manage passwords and memberships. - - Args: - line: An entry in gshadow. - """ - fields = ("name", "passwd", "administrators", "members") - if line: - rslt = dict(zip(fields, line.split(":"))) - # Add the shadow state to the internal store. - name = rslt["name"] - pw_entry = self.shadow.setdefault(name, rdf_client.PwEntry()) - pw_entry.store = self.shadow_store - pw_entry.hash_type = self.GetHashType(rslt["passwd"]) - # Add the members to the internal store. - members = self.gshadow_members.setdefault(name, set()) - for accts in rslt["administrators"], rslt["members"]: - if accts: - members.update(accts.split(",")) - - def ParseGroupEntry(self, line): - """Extract the members of a group from /etc/group.""" - fields = ("name", "passwd", "gid", "members") - if line: - rslt = dict(zip(fields, line.split(":"))) - name = rslt["name"] - group = self.entry.setdefault(name, rdf_client.Group(name=name)) - group.pw_entry.store = self.GetPwStore(rslt["passwd"]) - if group.pw_entry.store == self.base_store: - group.pw_entry.hash_type = self.GetHashType(rslt["passwd"]) - # If the group contains NIS entries, they may not have a gid. - if rslt["gid"]: - group.gid = int(rslt["gid"]) - group.members = set(rslt["members"].split(",")) - - def MergeMembers(self): - """Add shadow group members to the group if gshadow is used. - - Normally group and shadow should be in sync, but no guarantees. Merges the - two stores as membership in either file may confer membership. 
- """ - for group_name, members in self.gshadow_members.items(): - group = self.entry.get(group_name) - if group and group.pw_entry.store == self.shadow_store: - group.members = members.union(group.members) - - def FindAnomalies(self): - """Identify any anomalous group attributes or memberships.""" - for grp_name, group in self.entry.items(): - shadow = self.shadow.get(grp_name) - gshadows = self.gshadow_members.get(grp_name, []) - if shadow is not None: - diff = self.MemberDiff(group.members, "group", gshadows, "gshadow") - if diff: - msg = "Group/gshadow members differ in group: %s" % grp_name - yield self._Anomaly(msg, diff) - - diff = self.MemberDiff(self.entry, "group", self.gshadow_members, "gshadow") - if diff: - yield self._Anomaly("Mismatched group and gshadow files.", diff) - - def ParseFileset(self, fileset=None): - """Process linux system group and gshadow files. - - Orchestrates collection of account entries from /etc/group and /etc/gshadow. - The group and gshadow entries are reconciled and member users are added to - the entry. - - Args: - fileset: A dict of files mapped from path to an open file. - - Yields: - - A series of Group entries, each of which is populated with group - [memberships and indications of the shadow state of any group password. - - A series of anomalies in cases where there are mismatches between group - and gshadow states. - """ - # Get relevant shadow attributes. - gshadow = fileset.get("/etc/gshadow") - if gshadow: - self._ParseFile(gshadow, self.ParseGshadowEntry) - else: - logging.debug("No /etc/gshadow file.") - group = fileset.get("/etc/group") - if group: - self._ParseFile(group, self.ParseGroupEntry) - else: - logging.debug("No /etc/group file.") - self.ReconcileShadow(self.shadow_store) - # Identify any anomalous group/shadow entries. - # This needs to be done before memberships are merged: merged memberships - # are the *effective* membership regardless of weird configurations. - for anom in self.FindAnomalies(): - yield anom - # Then add shadow group members to the group membership. - self.MergeMembers() - for group in self.entry.values(): - yield group - - -class LinuxSystemPasswdParser(LinuxBaseShadowParser): - """Parser for local accounts.""" - - output_types = [rdf_client.User] - supported_artifacts = ["LoginPolicyConfiguration"] - - base_store = rdf_client.PwEntry.PwStore.PASSWD - shadow_store = rdf_client.PwEntry.PwStore.SHADOW - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.groups = {} # Groups mapped by name. - self.memberships = {} # Group memberships per user. - self.uids = {} # Assigned uids - self.gids = {} # Assigned gids - - def ParseShadowEntry(self, line): - """Extract the user accounts in /etc/shadow. - - Identifies the users in /etc/shadow and several attributes of their account, - including how their password is crypted and password aging characteristics. - - Args: - line: An entry of the shadow file. - """ - fields = ("login", "passwd", "last_change", "min_age", "max_age", - "warn_time", "inactivity", "expire", "reserved") - if line: - rslt = dict(zip(fields, line.split(":"))) - pw_entry = self.shadow.setdefault(rslt["login"], rdf_client.PwEntry()) - pw_entry.store = self.shadow_store - pw_entry.hash_type = self.GetHashType(rslt["passwd"]) - # Treat carefully here in case these values aren't set. 
- last_change = rslt.get("last_change") - if last_change: - pw_entry.age = int(last_change) - max_age = rslt.get("max_age") - if max_age: - pw_entry.max_age = int(max_age) - - def ParsePasswdEntry(self, line): - """Process the passwd entry fields and primary group memberships.""" - fields = ("uname", "passwd", "uid", "gid", "fullname", "homedir", "shell") - if line: - rslt = dict(zip(fields, line.split(":"))) - user = self.entry.setdefault(rslt["uname"], rdf_client.User()) - user.username = rslt["uname"] - user.pw_entry.store = self.GetPwStore(rslt["passwd"]) - if user.pw_entry.store == self.base_store: - user.pw_entry.hash_type = self.GetHashType(rslt["passwd"]) - # If the passwd file contains NIS entries they may not have uid/gid set. - if rslt["uid"]: - user.uid = int(rslt["uid"]) - if rslt["gid"]: - user.gid = int(rslt["gid"]) - user.homedir = rslt["homedir"] - user.shell = rslt["shell"] - user.full_name = rslt["fullname"] - # Map uid numbers to detect duplicates. - uids = self.uids.setdefault(user.uid, set()) - uids.add(user.username) - # Map primary group memberships to populate memberships. - gid = self.gids.setdefault(user.gid, set()) - gid.add(user.username) - - def _Members(self, group): - """Unify members of a group and accounts with the group as primary gid.""" - group.members = set(group.members).union(self.gids.get(group.gid, [])) - return group - - def AddGroupMemberships(self): - """Adds aggregate group membership from group, gshadow and passwd.""" - self.groups = {g.name: self._Members(g) for g in self.groups.values()} - # Map the groups a user is a member of, irrespective of primary/extra gid. - for g in self.groups.values(): - for user in g.members: - membership = self.memberships.setdefault(user, set()) - membership.add(g.gid) - # Now add the completed membership to the user account. - for user in self.entry.values(): - user.gids = self.memberships.get(user.username) - - def FindAnomalies(self): - """Identify anomalies in the password/shadow and group/gshadow data.""" - # Find anomalous group entries. - findings = [] - group_entries = {g.gid for g in self.groups.values()} - for gid in set(self.gids) - group_entries: - undefined = ",".join(self.gids.get(gid, [])) - findings.append( - "gid %d assigned without /etc/groups entry: %s" % (gid, undefined)) - if findings: - yield self._Anomaly("Accounts with invalid gid.", findings) - - # Find any shared user IDs. - findings = [] - for uid, names in self.uids.items(): - if len(names) > 1: - findings.append("uid %d assigned to multiple accounts: %s" % - (uid, ",".join(sorted(names)))) - if findings: - yield self._Anomaly("Accounts with shared uid.", findings) - - # Find privileged groups with unusual members. - findings = [] - root_grp = self.groups.get("root") - if root_grp is not None: - root_members = sorted([m for m in root_grp.members if m != "root"]) - if root_members: - findings.append("Accounts in 'root' group: %s" % ",".join(root_members)) - if findings: - yield self._Anomaly("Privileged group with unusual members.", findings) - - # Find accounts without passwd/shadow entries. 
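# An illustrative sketch of the bidirectional set difference performed by
# MemberDiff just below; the account names mirror the test fixtures further
# down ("miss" has no shadow entry, "ok" has no passwd entry).
passwd_names = {"root", "miss", "bad1", "bad2"}
shadow_names = {"root", "ok", "bad1", "bad2"}
only_in_passwd = passwd_names - shadow_names   # {"miss"}
only_in_shadow = shadow_names - passwd_names   # {"ok"}
findings = [
    "Present in passwd, missing in shadow: " + ",".join(sorted(only_in_passwd)),
    "Present in shadow, missing in passwd: " + ",".join(sorted(only_in_shadow)),
]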
- diffs = self.MemberDiff(self.entry, "passwd", self.shadow, "shadow") - if diffs: - yield self._Anomaly("Mismatched passwd and shadow files.", diffs) - - def AddPassword(self, fileset): - """Add the passwd entries to the shadow store.""" - passwd = fileset.get("/etc/passwd") - if passwd: - self._ParseFile(passwd, self.ParsePasswdEntry) - else: - logging.debug("No /etc/passwd file.") - - def AddShadow(self, fileset): - """Add the shadow entries to the shadow store.""" - shadow = fileset.get("/etc/shadow") - if shadow: - self._ParseFile(shadow, self.ParseShadowEntry) - else: - logging.debug("No /etc/shadow file.") - - def ParseFileset(self, fileset=None): - """Process linux system login files. - - Orchestrates collection of account entries from /etc/passwd and - /etc/shadow. The passwd and shadow entries are reconciled and group - memberships are mapped to the account. - - Args: - fileset: A dict of files mapped from path to an open file. - - Yields: - - A series of User entries, each of which is populated with - group memberships and indications of the shadow state of the account. - - A series of anomalies in cases where there are mismatches between passwd - and shadow state. - """ - self.AddPassword(fileset) - self.AddShadow(fileset) - self.ReconcileShadow(self.shadow_store) - # Get group memberships using the files that were already collected. - # Separate out groups and anomalies. - for rdf in LinuxSystemGroupParser().ParseFileset(fileset): - if isinstance(rdf, rdf_client.Group): - self.groups[rdf.name] = rdf - else: - yield rdf - self.AddGroupMemberships() - for user in self.entry.values(): - yield user - for grp in self.groups.values(): - yield grp - for anom in self.FindAnomalies(): - yield anom - - -class PathParser(parsers.SingleFileParser[rdf_protodict.AttributedDict]): - """Parser for dotfile entries. - - Extracts path attributes from dotfiles to infer effective paths for users. - This parser doesn't attempt or expect to determine path state for all cases, - rather, it is a best effort attempt to detect common misconfigurations. It is - not intended to detect maliciously obfuscated path modifications. - """ - output_types = [rdf_protodict.AttributedDict] - # TODO(user): Modify once a decision is made on contextual selection of - # parsed results for artifact data. - supported_artifacts = ["RootUserShellConfigs", "ShellConfigurationFile"] - - # https://cwe.mitre.org/data/definitions/426.html - _TARGETS = ("CLASSPATH", "LD_AOUT_LIBRARY_PATH", "LD_AOUT_PRELOAD", - "LD_LIBRARY_PATH", "LD_PRELOAD", "MODULE_PATH", "PATH", - "PERL5LIB", "PERLLIB", "PYTHONPATH", "RUBYLIB") - _SH_CONTINUATION = ("{", "}", "||", "&&", "export") - - _CSH_FILES = (".login", ".cshrc", ".tcsh", "csh.cshrc", "csh.login", - "csh.logout") - # This matches "set a = (b . ../../.. )", "set a=(. b c)" etc. - _CSH_SET_RE = re.compile(r"(\w+)\s*=\s*\((.*)\)$") - - # This matches $PATH, ${PATH}, "$PATH" and "${ PATH }" etc. - # Omits more fancy parameter expansion e.g. ${unset_val:=../..} - _SHELLVAR_RE = re.compile(r'"?\$\{?\s*(\w+)\s*\}?"?') - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Terminate entries on ";" to capture multiple values on one line. 
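# A compact sketch of how a PATH-like value is split into entries and how a
# "$VAR"/"${VAR}" entry is recognised via the _SHELLVAR_RE pattern above;
# the "current" value of PATH used for interpolation here is hypothetical.
import re

shellvar_re = re.compile(r'"?\$\{?\s*(\w+)\s*\}?"?')
known_paths = {"PATH": ["/usr/bin", "/bin"]}
expanded = []
for entry in "${HOME}/bin::$PATH".split(":"):
    match = shellvar_re.match(entry)
    if not entry:
        expanded.append(".")            # an empty entry means the current dir
    elif match and match.group(1).upper() in known_paths:
        expanded.extend(known_paths[match.group(1).upper()])
    else:
        expanded.append(entry)
assert expanded == ["${HOME}/bin", ".", "/usr/bin", "/bin"]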
- self.parser = config_file.FieldParser(term=r"[\r\n;]") - - def _ExpandPath(self, target, vals, paths): - """Extract path information, interpolating current path values as needed.""" - if target not in self._TARGETS: - return - expanded = [] - for val in vals: - # Null entries specify the current directory, so :a::b:c: is equivalent - # to .:a:.:b:c:. - shellvar = self._SHELLVAR_RE.match(val) - if not val: - expanded.append(".") - elif shellvar: - # The value may actually be in braces as well. Always convert to upper - # case so we deal with stuff like lowercase csh path. - existing = paths.get(shellvar.group(1).upper()) - if existing: - expanded.extend(existing) - else: - expanded.append(val) - else: - expanded.append(val) - paths[target] = expanded - - def _ParseShVariables(self, lines): - """Extract env_var and path values from sh derivative shells. - - Iterates over each line, word by word searching for statements that set the - path. These are either variables, or conditions that would allow a variable - to be set later in the line (e.g. export). - - Args: - lines: A list of lines, each of which is a list of space separated words. - - Returns: - a dictionary of path names and values. - """ - paths = {} - for line in lines: - for entry in line: - if "=" in entry: - # Pad out the list so that it's always 2 elements, even if the split - # failed. - target, vals = (entry.split("=", 1) + [""])[:2] - if vals: - path_vals = vals.split(":") - else: - path_vals = [] - self._ExpandPath(target, path_vals, paths) - elif entry not in self._SH_CONTINUATION: - # Stop processing the line unless the entry might allow paths to still - # be set, e.g. - # reserved words: "export" - # conditions: { PATH=VAL } && PATH=:$PATH || PATH=. - break - return paths - - def _ParseCshVariables(self, lines): - """Extract env_var and path values from csh derivative shells. - - Path attributes can be set several ways: - - setenv takes the form "setenv PATH_NAME COLON:SEPARATED:LIST" - - set takes the form "set path_name=(space separated list)" and is - automatically exported for several types of files. - - The first entry in each stanza is used to decide what context to use. - Other entries are used to identify the path name and any assigned values. - - Args: - lines: A list of lines, each of which is a list of space separated words. - - Returns: - a dictionary of path names and values. - """ - paths = {} - for line in lines: - if len(line) < 2: - continue - action = line[0] - if action == "setenv": - target = line[1] - path_vals = [] - if line[2:]: - path_vals = line[2].split(":") - self._ExpandPath(target, path_vals, paths) - elif action == "set": - set_vals = self._CSH_SET_RE.search(" ".join(line[1:])) - if set_vals: - target, vals = set_vals.groups() - # Automatically exported to ENV vars. - if target in ("path", "term", "user"): - target = target.upper() - path_vals = vals.split() - self._ExpandPath(target, path_vals, paths) - return paths - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_protodict.AttributedDict]: - """Identifies the paths set within a file. - - Expands paths within the context of the file, but does not infer fully - expanded paths from external states. There are plenty of cases where path - attributes are unresolved, e.g. sourcing other files. - - Lines are not handled literally. A field parser is used to: - - Break lines with multiple distinct statements into separate lines (e.g. 
- lines with a ';' separating stanzas. - - Strip out comments. - - Handle line continuations to capture multi-line configurations into one - statement. - - Args: - knowledge_base: A knowledgebase for the client to whom the file belongs. - pathspec: A pathspec corresponding to the parsed file. - filedesc: A file-like object to parse. - - Yields: - An attributed dict for each env vars. 'name' contains the path name, and - 'vals' contains its vals. - """ - del knowledge_base # Unused. - - lines = self.parser.ParseEntries(utils.ReadFileBytesAsUnicode(filedesc)) - if os.path.basename(pathspec.path) in self._CSH_FILES: - paths = self._ParseCshVariables(lines) - else: - paths = self._ParseShVariables(lines) - for path_name, path_vals in paths.items(): - yield rdf_protodict.AttributedDict( - config=pathspec.path, name=path_name, vals=path_vals) diff --git a/grr/core/grr_response_core/lib/parsers/linux_file_parser_test.py b/grr/core/grr_response_core/lib/parsers/linux_file_parser_test.py index 1519f39795..b9c2e1aad9 100644 --- a/grr/core/grr_response_core/lib/parsers/linux_file_parser_test.py +++ b/grr/core/grr_response_core/lib/parsers/linux_file_parser_test.py @@ -1,146 +1,19 @@ #!/usr/bin/env python """Unit test for the linux file parser.""" - import io -import operator -import os from absl import app from grr_response_core.lib import parsers from grr_response_core.lib.parsers import linux_file_parser -from grr_response_core.lib.rdfvalues import anomaly as rdf_anomaly from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import file_finder as rdf_file_finder -from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr.test_lib import test_lib class LinuxFileParserTest(test_lib.GRRBaseTest): """Test parsing of linux files.""" - def testPCIDevicesInfoParser(self): - """Ensure we can extract PCI devices info.""" - - # Test when there's data for one PCI device only. - test_data1 = { - "/sys/bus/pci/devices/0000:00:01.0/vendor": b"0x0e00\n", - "/sys/bus/pci/devices/0000:00:01.0/class": b"0x060400\n", - "/sys/bus/pci/devices/0000:00:01.0/device": b"0x0e02\n", - "/sys/bus/pci/devices/0000:00:01.0/config": b"0200" - } - device_1 = rdf_client.PCIDevice( - domain=0, - bus=0, - device=1, - function=0, - class_id="0x060400", - vendor="0x0e00", - vendor_device_id="0x0e02", - config=b"0200") - parsed_results = self._ParsePCIDeviceTestData(test_data1) - self._MatchPCIDeviceResultToExpected(parsed_results, [device_1]) - - test_data2 = { - "/sys/bus/pci/devices/0000:00:00.0/vendor": - b"0x8086\n", - "/sys/bus/pci/devices/0000:00:00.0/class": - b"0x060000\n", - "/sys/bus/pci/devices/0000:00:00.0/device": - b"0x0e00\n", - "/sys/bus/pci/devices/0000:00:00.0/config": (b"\xea\xe8\xe7\xbc\x7a\x84" - b"\x91"), - } - device_2 = rdf_client.PCIDevice( - domain=0, - bus=0, - device=0, - function=0, - class_id="0x060000", - vendor="0x8086", - vendor_device_id="0x0e00", - config=b"\xea\xe8\xe7\xbcz\x84\x91") - parsed_results = self._ParsePCIDeviceTestData(test_data2) - self._MatchPCIDeviceResultToExpected(parsed_results, [device_2]) - - # Test for when there's missing data. 
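# A brief sketch of decoding the extended B/D/F key used in the test data
# below ("domain:bus:device.function", all fields hexadecimal), matching the
# values the parser stores on the resulting PCIDevice.
bdf = "0000:00:03.0"
domain_hex, bus_hex, df = bdf.split(":")
device_hex, function_hex = df.split(".")
assert (int(domain_hex, 16), int(bus_hex, 16),
        int(device_hex, 16), int(function_hex, 16)) == (0, 0, 3, 0)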
- test_data3 = { - "/sys/bus/pci/devices/0000:00:03.0/vendor": b"0x0e00\n", - "/sys/bus/pci/devices/0000:00:03.0/config": b"0030" - } - device_3 = rdf_client.PCIDevice( - domain=0, bus=0, device=3, function=0, vendor="0x0e00", config=b"0030") - parsed_results = self._ParsePCIDeviceTestData(test_data3) - self._MatchPCIDeviceResultToExpected(parsed_results, [device_3]) - - # Test when data contains non-valid B/D/F folders/files. - test_data4 = { - "/sys/bus/pci/devices/0000:00:05.0/vendor": b"0x0e00\n", - "/sys/bus/pci/devices/0000:00:05.0/class": b"0x060400\n", - "/sys/bus/pci/devices/0000:00:05.0/device": b"0x0e02\n", - "/sys/bus/pci/devices/0000:00:05.0/config": b"0200", - "/sys/bus/pci/devices/crazyrandomfile/test1": b"test1", - "/sys/bus/pci/devices/::./test2": b"test2", - "/sys/bus/pci/devices/00:5.0/test3": b"test3" - } - device_4 = rdf_client.PCIDevice( - domain=0, - bus=0, - device=5, - function=0, - class_id="0x060400", - vendor="0x0e00", - vendor_device_id="0x0e02", - config=b"0200") - parsed_results = self._ParsePCIDeviceTestData(test_data4) - self._MatchPCIDeviceResultToExpected(parsed_results, [device_4]) - - # Test when there's multiple PCI devices in the test_data. - combined_data = test_data1.copy() - combined_data.update(test_data3) - combined_data.update(test_data4) - combined_data.update(test_data2) - parsed_results = self._ParsePCIDeviceTestData(combined_data) - self._MatchPCIDeviceResultToExpected( - parsed_results, [device_1, device_4, device_2, device_3]) - - def _ParsePCIDeviceTestData(self, test_data): - """Given test_data dictionary, parse it using PCIDevicesInfoParser.""" - parser = linux_file_parser.PCIDevicesInfoParser() - pathspecs = [] - file_objs = [] - - # Populate stats, file_ojbs, kb_ojbs lists needed by the parser. - for filename, data in test_data.items(): - pathspec = rdf_paths.PathSpec(path=filename, pathtype="OS") - file_obj = io.BytesIO(data) - pathspecs.append(pathspec) - file_objs.append(file_obj) - - return list(parser.ParseFiles(None, pathspecs, file_objs)) - - def _MatchPCIDeviceResultToExpected(self, parsed_results, expected_output): - """Make sure the parsed_results match expected_output.""" - - # Check the size matches. - self.assertLen(parsed_results, len(expected_output)) - - # Sort parsed_results and expected_outputs so we're comparing properly. - results = sorted(parsed_results, key=operator.attrgetter("device")) - outputs = sorted(expected_output, key=operator.attrgetter("device")) - - # Check all the content matches. 
- for result, output in zip(results, outputs): - self.assertEqual(result.domain, output.domain) - self.assertEqual(result.bus, output.bus) - self.assertEqual(result.device, output.device) - self.assertEqual(result.function, output.function) - self.assertEqual(result.class_id, output.class_id) - self.assertEqual(result.vendor, output.vendor) - self.assertEqual(result.vendor_device_id, output.vendor_device_id) - self.assertEqual(result.config, output.config) - def testPasswdParser(self): """Ensure we can extract users from a passwd file.""" parser = linux_file_parser.PasswdParser() @@ -162,349 +35,6 @@ def testPasswdParser(self): with self.assertRaises(parsers.ParseError): list(parser.ParseFile(None, None, io.BytesIO(dat))) - def testPasswdBufferParser(self): - """Ensure we can extract users from a passwd file.""" - parser = linux_file_parser.PasswdBufferParser() - buf1 = rdf_client.BufferReference( - data=b"user1:x:1000:1000:User1 Name,,,:/home/user1:/bin/bash\n") - buf2 = rdf_client.BufferReference( - data=b"user2:x:1000:1000:User2 Name,,,:/home/user2:/bin/bash\n") - - ff_result = rdf_file_finder.FileFinderResult(matches=[buf1, buf2]) - out = list(parser.ParseResponse(rdf_client.KnowledgeBase(), ff_result)) - self.assertLen(out, 2) - self.assertIsInstance(out[1], rdf_client.User) - self.assertIsInstance(out[1], rdf_client.User) - self.assertEqual(out[0].username, "user1") - self.assertEqual(out[0].full_name, "User1 Name,,,") - - def testNetgroupParser(self): - """Ensure we can extract users from a netgroup file.""" - parser = linux_file_parser.NetgroupParser() - dat = """group1 (-,user1,) (-,user2,) (-,user3,) -#group1 comment -group2 (-,user4,) (-,user2,) - -super_group (-,user5,) (-,user6,) (-,文德文,) group1 group2 -super_group2 (-,user7,) super_group -super_group3 (-,user5,) (-,user6,) group1 group2 -""" - dat_fd = io.BytesIO(dat.encode("utf-8")) - - with test_lib.ConfigOverrider( - {"Artifacts.netgroup_ignore_users": ["user2", "user3"]}): - out = list(parser.ParseFile(None, None, dat_fd)) - users = [] - for result in out: - if isinstance(result, rdf_anomaly.Anomaly): - self.assertIn("文德文", result.symptom) - else: - users.append(result) - - self.assertCountEqual([x.username for x in users], - [u"user1", u"user4", u"user5", u"user6", u"user7"]) - - dat_fd.seek(0) - - with test_lib.ConfigOverrider( - {"Artifacts.netgroup_filter_regexes": [r"^super_group3$"]}): - out = list(parser.ParseFile(None, None, dat_fd)) - self.assertCountEqual([x.username for x in out], [u"user5", u"user6"]) - - def testNetgroupBufferParser(self): - """Ensure we can extract users from a netgroup file.""" - parser = linux_file_parser.NetgroupBufferParser() - buf1 = rdf_client.BufferReference( - data=b"group1 (-,user1,) (-,user2,) (-,user3,)\n") - buf2 = rdf_client.BufferReference( - data=b"super_group3 (-,user5,) (-,user6,) group1 group2\n") - - ff_result = rdf_file_finder.FileFinderResult(matches=[buf1, buf2]) - with test_lib.ConfigOverrider( - {"Artifacts.netgroup_ignore_users": ["user2", "user3"]}): - out = list(parser.ParseResponse(rdf_client.KnowledgeBase, ff_result)) - self.assertCountEqual([x.username for x in out], - [u"user1", u"user5", u"user6"]) - - def testNetgroupParserBadInput(self): - parser = linux_file_parser.NetgroupParser() - dat = b"""group1 (-,user1,) (-,user2,) (-,user3,) -#group1 comment -group2 user4 (-user2,) -super_group (-,,user5,) (-user6,) group1 group2 -super_group2 (-,user7,) super_group -""" - with self.assertRaises(parsers.ParseError): - list(parser.ParseFile(None, None, 
io.BytesIO(dat))) - - def testWtmpParser(self): - """Test parsing of wtmp file.""" - parser = linux_file_parser.LinuxWtmpParser() - path = os.path.join(self.base_path, "VFSFixture/var/log/wtmp") - with open(path, "rb") as wtmp_fd: - out = list(parser.ParseFile(None, None, wtmp_fd)) - - self.assertLen(out, 3) - self.assertCountEqual(["%s:%d" % (x.username, x.last_logon) for x in out], [ - "user1:1296552099000000", "user2:1296552102000000", - "user3:1296569997000000" - ]) - - -class LinuxShadowParserTest(test_lib.GRRBaseTest): - """Test parsing of linux shadow files.""" - - crypt = { - "DES": "A.root/o0tr.o", - "MD5": "$1$roo/root/o07r.0tROOTro", - "SHA256": "$5$sal/s.lt5a17${0}".format("r" * 43), - "SHA512": "$6$sa./sa1T${0}".format("r" * 86), - "UNSET": "*", - "DISABLED": "!$1$roo/rootroootROOO0oooTroooooooo", - "EMPTY": "" - } - - def _GenFiles(self, passwd, shadow, group, gshadow): - pathspecs = [] - files = [] - for path in ["/etc/passwd", "/etc/shadow", "/etc/group", "/etc/gshadow"]: - pathspecs.append(rdf_paths.PathSpec(path=path)) - for data in passwd, shadow, group, gshadow: - if data is None: - data = [] - lines = "\n".join(data).format(**self.crypt).encode("utf-8") - files.append(io.BytesIO(lines)) - return pathspecs, files - - def testNoAnomaliesWhenEverythingIsFine(self): - passwd = [ - "ok_1:x:1000:1000::/home/ok_1:/bin/bash", - "ok_2:x:1001:1001::/home/ok_2:/bin/bash" - ] - shadow = [ - "ok_1:{SHA256}:16000:0:99999:7:::", "ok_2:{SHA512}:16000:0:99999:7:::" - ] - group = ["ok_1:x:1000:ok_1", "ok_2:x:1001:ok_2"] - gshadow = ["ok_1:::ok_1", "ok_2:::ok_2"] - pathspecs, files = self._GenFiles(passwd, shadow, group, gshadow) - parser = linux_file_parser.LinuxSystemPasswdParser() - rdfs = parser.ParseFiles(None, pathspecs, files) - results = [r for r in rdfs if isinstance(r, rdf_anomaly.Anomaly)] - self.assertFalse(results) - - def testSystemGroupParserAnomaly(self): - """Detect anomalies in group/gshadow files.""" - group = [ - "root:x:0:root,usr1", "adm:x:1:syslog,usr1", - "users:x:1000:usr1,usr2,usr3,usr4" - ] - gshadow = ["root::usr4:root", "users:{DES}:usr1:usr2,usr3,usr4"] - pathspecs, files = self._GenFiles(None, None, group, gshadow) - - # Set up expected anomalies. 
- member = { - "symptom": - "Group/gshadow members differ in group: root", - "finding": [ - "Present in group, missing in gshadow: usr1", - "Present in gshadow, missing in group: usr4" - ], - "type": - "PARSER_ANOMALY" - } - group = { - "symptom": "Mismatched group and gshadow files.", - "finding": ["Present in group, missing in gshadow: adm"], - "type": "PARSER_ANOMALY" - } - expected = [rdf_anomaly.Anomaly(**member), rdf_anomaly.Anomaly(**group)] - - parser = linux_file_parser.LinuxSystemGroupParser() - rdfs = parser.ParseFiles(None, pathspecs, files) - results = [r for r in rdfs if isinstance(r, rdf_anomaly.Anomaly)] - self.assertEqual(expected, results) - - def testSystemAccountAnomaly(self): - passwd = [ - "root:x:0:0::/root:/bin/sash", - "miss:x:1000:100:Missing:/home/miss:/bin/bash", - "bad1:x:0:1001:Bad 1:/home/bad1:/bin/bash", - "bad2:x:1002:0:Bad 2:/home/bad2:/bin/bash" - ] - shadow = [ - "root:{UNSET}:16000:0:99999:7:::", "ok:{SHA512}:16000:0:99999:7:::", - "bad1::16333:0:99999:7:::", "bad2:{DES}:16333:0:99999:7:::" - ] - group = [ - "root:x:0:root", "miss:x:1000:miss", "bad1:x:1001:bad1", - "bad2:x:1002:bad2" - ] - gshadow = ["root:::root", "miss:::miss", "bad1:::bad1", "bad2:::bad2"] - pathspecs, files = self._GenFiles(passwd, shadow, group, gshadow) - - no_grp = { - "symptom": "Accounts with invalid gid.", - "finding": ["gid 100 assigned without /etc/groups entry: miss"], - "type": "PARSER_ANOMALY" - } - uid = { - "symptom": "Accounts with shared uid.", - "finding": ["uid 0 assigned to multiple accounts: bad1,root"], - "type": "PARSER_ANOMALY" - } - gid = { - "symptom": "Privileged group with unusual members.", - "finding": ["Accounts in 'root' group: bad2"], - "type": "PARSER_ANOMALY" - } - no_match = { - "symptom": - "Mismatched passwd and shadow files.", - "finding": [ - "Present in passwd, missing in shadow: miss", - "Present in shadow, missing in passwd: ok" - ], - "type": - "PARSER_ANOMALY" - } - expected = [ - rdf_anomaly.Anomaly(**no_grp), - rdf_anomaly.Anomaly(**uid), - rdf_anomaly.Anomaly(**gid), - rdf_anomaly.Anomaly(**no_match) - ] - - parser = linux_file_parser.LinuxSystemPasswdParser() - rdfs = parser.ParseFiles(None, pathspecs, files) - results = [r for r in rdfs if isinstance(r, rdf_anomaly.Anomaly)] - - self.assertLen(results, len(expected)) - for expect, result in zip(expected, results): - self.assertEqual(expect.symptom, result.symptom) - # Expand out repeated field helper. 
- self.assertCountEqual(list(expect.finding), list(result.finding)) - self.assertEqual(expect.type, result.type) - - def GetExpectedUser(self, algo, user_store, group_store): - user = rdf_client.User( - username="user", - full_name="User", - uid="1001", - gid="1001", - homedir="/home/user", - shell="/bin/bash") - user.pw_entry = rdf_client.PwEntry(store=user_store, hash_type=algo) - user.gids = [1001] - grp = rdf_client.Group(gid=1001, members=["user"], name="user") - grp.pw_entry = rdf_client.PwEntry(store=group_store, hash_type=algo) - return user, grp - - def CheckExpectedUser(self, algo, expect, result): - self.assertEqual(expect.username, result.username) - self.assertEqual(expect.gid, result.gid) - self.assertEqual(expect.pw_entry.store, result.pw_entry.store) - self.assertEqual(expect.pw_entry.hash_type, result.pw_entry.hash_type) - self.assertCountEqual(expect.gids, result.gids) - - def CheckExpectedGroup(self, algo, expect, result): - self.assertEqual(expect.name, result.name) - self.assertEqual(expect.gid, result.gid) - self.assertEqual(expect.pw_entry.store, result.pw_entry.store) - self.assertEqual(expect.pw_entry.hash_type, result.pw_entry.hash_type) - - def CheckCryptResults(self, passwd, shadow, group, gshadow, algo, usr, grp): - pathspecs, files = self._GenFiles(passwd, shadow, group, gshadow) - parser = linux_file_parser.LinuxSystemPasswdParser() - results = list(parser.ParseFiles(None, pathspecs, files)) - usrs = [r for r in results if isinstance(r, rdf_client.User)] - grps = [r for r in results if isinstance(r, rdf_client.Group)] - self.assertLen(usrs, 1, "Different number of usr %s results" % algo) - self.assertLen(grps, 1, "Different number of grp %s results" % algo) - self.CheckExpectedUser(algo, usr, usrs[0]) - self.CheckExpectedGroup(algo, grp, grps[0]) - - def testSetShadowedEntries(self): - passwd = ["user:x:1001:1001:User:/home/user:/bin/bash"] - group = ["user:x:1001:user"] - for algo, crypted in self.crypt.items(): - # Flush the parser for each iteration. - shadow = ["user:%s:16000:0:99999:7:::" % crypted] - gshadow = ["user:%s::user" % crypted] - usr, grp = self.GetExpectedUser(algo, "SHADOW", "GSHADOW") - self.CheckCryptResults(passwd, shadow, group, gshadow, algo, usr, grp) - - def testSetNonShadowedEntries(self): - shadow = ["user::16000:0:99999:7:::"] - gshadow = ["user:::user"] - for algo, crypted in self.crypt.items(): - # Flush the parser for each iteration. - passwd = ["user:%s:1001:1001:User:/home/user:/bin/bash" % crypted] - group = ["user:%s:1001:user" % crypted] - usr, grp = self.GetExpectedUser(algo, "PASSWD", "GROUP") - self.CheckCryptResults(passwd, shadow, group, gshadow, algo, usr, grp) - - -class LinuxDotFileParserTest(test_lib.GRRBaseTest): - """Test parsing of user dotfiles.""" - - def testFindPaths(self): - # TODO(user): Deal with cases where multiple vars are exported. 
- # export TERM PERLLIB=.:shouldntbeignored - bashrc_data = io.BytesIO(b""" - IGNORE='bad' PATH=${HOME}/bin:$PATH - { PYTHONPATH=/path1:/path2 } - export TERM=screen-256color - export http_proxy="http://proxy.example.org:3128/" - export HTTP_PROXY=$http_proxy - if [[ "$some_condition" ]]; then - export PATH=:$PATH; LD_LIBRARY_PATH=foo:bar:$LD_LIBRARY_PATH - PYTHONPATH=$PATH:"${PYTHONPATH}" - CLASSPATH= - fi - echo PATH=/should/be/ignored - # Ignore PATH=foo:bar - TERM=vt100 PS=" Foo" PERL5LIB=:shouldntbeignored - """) - cshrc_data = io.BytesIO(b""" - setenv PATH ${HOME}/bin:$PATH - setenv PYTHONPATH /path1:/path2 - set term = (screen-256color) - setenv http_proxy "http://proxy.example.org:3128/" - setenv HTTP_PROXY $http_proxy - if ( -e "$some_condition" ) then - set path = (. $path); setenv LD_LIBRARY_PATH foo:bar:$LD_LIBRARY_PATH - setenv PYTHONPATH $PATH:"${PYTHONPATH}" - setenv CLASSPATH - endif - echo PATH=/should/be/ignored - setenv PERL5LIB :shouldntbeignored - """) - parser = linux_file_parser.PathParser() - bashrc_pathspec = rdf_paths.PathSpec.OS(path="/home/user1/.bashrc") - cshrc_pathspec = rdf_paths.PathSpec.OS(path="/home/user1/.cshrc") - bashrc = { - r.name: r.vals - for r in parser.ParseFile(None, bashrc_pathspec, bashrc_data) - } - cshrc = { - r.name: r.vals - for r in parser.ParseFile(None, cshrc_pathspec, cshrc_data) - } - expected = { - "PATH": [".", "${HOME}/bin", "$PATH"], - "PYTHONPATH": [".", "${HOME}/bin", "$PATH", "/path1", "/path2"], - "LD_LIBRARY_PATH": ["foo", "bar", "$LD_LIBRARY_PATH"], - "CLASSPATH": [], - "PERL5LIB": [".", "shouldntbeignored"] - } - # Got the same environment variables for bash and cshrc files. - self.assertCountEqual(expected, bashrc) - self.assertCountEqual(expected, cshrc) - # The path values are expanded correctly. - for var_name in ("PATH", "PYTHONPATH", "LD_LIBRARY_PATH"): - self.assertEqual(expected[var_name], bashrc[var_name]) - self.assertEqual(expected[var_name], cshrc[var_name]) - def main(args): test_lib.main(args) diff --git a/grr/core/grr_response_core/lib/parsers/linux_pam_parser.py b/grr/core/grr_response_core/lib/parsers/linux_pam_parser.py deleted file mode 100644 index 779e9c1261..0000000000 --- a/grr/core/grr_response_core/lib/parsers/linux_pam_parser.py +++ /dev/null @@ -1,206 +0,0 @@ -#!/usr/bin/env python -"""Parsers for Linux PAM configuration files.""" - -import os -import re -from typing import IO -from typing import Iterable -from typing import Iterator - -from grr_response_core.lib import parsers -from grr_response_core.lib import utils -from grr_response_core.lib.parsers import config_file -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import config_file as rdf_config_file -from grr_response_core.lib.rdfvalues import paths as rdf_paths - - -class PAMFieldParser(config_file.FieldParser): - """Field parser for PAM configurations.""" - - # The syntax is based on: - # http://linux.die.net/man/5/pam.d - - PAMDIR = "/etc/pam.d" - OLD_PAMCONF_FILENAME = "/etc/pam.conf" - PAMCONF_RE = re.compile( - r""" - (\S+) # The "type". - \s+ # separator - ( # Now match the "control" argument. - \[[^\]]*\] # Complex form. e.g. [success=ok default=die] etc. - | \w+ # Or a single word form. - ) # End of the "control" argument. - \s+ # separator - (\S+) # The "module-path". - (?:\s+(.*))? # And the optional "module-arguments" is anything else. 
- """, re.VERBOSE) - - def _FixPath(self, path): - # Anchor any relative paths in the PAMDIR - if not os.path.isabs(path): - return os.path.join(self.PAMDIR, path) - else: - return path - - def EnumerateAllConfigs(self, pathspecs, file_objects): - """Generate RDFs for the fully expanded configs. - - Args: - pathspecs: A list of pathspecs corresponding to the file_objects. - file_objects: A list of file handles. - - Returns: - A tuple of a list of RDFValue PamConfigEntries found & a list of strings - which are the external config references found. - """ - # Convert the stats & file_objects into a cache of a - # simple path keyed dict of file contents. - cache = {} - for pathspec, file_obj in zip(pathspecs, file_objects): - cache[pathspec.path] = utils.ReadFileBytesAsUnicode(file_obj) - - result = [] - external = [] - # Check to see if we have the old pam config file laying around. - if self.OLD_PAMCONF_FILENAME in cache: - # The PAM documentation says if it contains config data, then - # it takes precedence over the rest of the config. - # If it doesn't, the rest of the PAMDIR config counts. - result, external = self.EnumerateConfig(None, self.OLD_PAMCONF_FILENAME, - cache) - if result: - return result, external - - # If we made it here, there isn't a old-style pam.conf file worth - # speaking of, so process everything! - for path in cache: - # PAM uses the basename as the 'service' id. - service = os.path.basename(path) - r, e = self.EnumerateConfig(service, path, cache) - result.extend(r) - external.extend(e) - return result, external - - def EnumerateConfig(self, service, path, cache, filter_type=None): - """Return PamConfigEntries it finds as it recursively follows PAM configs. - - Args: - service: A string containing the service name we are processing. - path: A string containing the file path name we want. - cache: A dictionary keyed on path, with the file contents (list of str). - filter_type: A string containing type name of the results we want. - - Returns: - A tuple of a list of RDFValue PamConfigEntries found & a list of strings - which are the external config references found. - """ - - result = [] - external = [] - path = self._FixPath(path) - - # Make sure we only look at files under PAMDIR. - # Check we have the file in our artifact/cache. If not, our artifact - # didn't give it to us, and that's a problem. - # Note: This should only ever happen if it was referenced - # from /etc/pam.conf so we can assume that was the file. - if path not in cache: - external.append("%s -> %s", self.OLD_PAMCONF_FILENAME, path) - return result, external - - for tokens in self.ParseEntries(cache[path]): - if path == self.OLD_PAMCONF_FILENAME: - # We are processing the old style PAM conf file. It's a special case. - # It's format is "service type control module-path module-arguments" - # i.e. the 'service' is the first arg, the rest is line - # is like everything else except for that addition. - try: - service = tokens[0] # Grab the service from the start line. - tokens = tokens[1:] # Make the rest of the line look like "normal". - except IndexError: - continue # It's a blank line, skip it. - - # Process any inclusions in the line. - new_path = None - filter_request = None - try: - # If a line starts with @include, then include the entire referenced - # file. - # e.g. "@include common-auth" - if tokens[0] == "@include": - new_path = tokens[1] - # If a line's second arg is an include/substack, then filter the - # referenced file only including entries that match the 'type' - # requested. - # e.g. 
"auth include common-auth-screensaver" - elif tokens[1] in ["include", "substack"]: - new_path = tokens[2] - filter_request = tokens[0] - except IndexError: - # It's not a valid include line, so keep processing as normal. - pass - - # If we found an include file, enumerate that file now, and - # included it where we are in this config file. - if new_path: - # Preemptively check to see if we have a problem where the config - # is referencing a file outside of the expected/defined artifact. - # Doing it here allows us to produce a better context for the - # problem. Hence the slight duplication of code. - - new_path = self._FixPath(new_path) - if new_path not in cache: - external.append("%s -> %s" % (path, new_path)) - continue # Skip to the next line of the file. - r, e = self.EnumerateConfig(service, new_path, cache, filter_request) - result.extend(r) - external.extend(e) - else: - # If we have been asked to filter on types, skip over any types - # we are not interested in. - if filter_type and tokens[0] != filter_type: - continue # We can skip this line. - - # If we got here, then we want to include this line in this service's - # config. - - # Reform the line and break into the correct fields as best we can. - # Note: ParseEntries doesn't cope with what we need to do. - match = self.PAMCONF_RE.match(" ".join(tokens)) - if match: - p_type, control, module_path, module_args = match.group(1, 2, 3, 4) - # Trim a leading "-" from the type field if present. - if p_type.startswith("-"): - p_type = p_type[1:] - result.append( - rdf_config_file.PamConfigEntry( - service=service, - type=p_type, - control=control, - module_path=module_path, - module_args=module_args)) - return result, external - - -class PAMParser(parsers.MultiFileParser[rdf_config_file.PamConfig]): - """Artifact parser for PAM configurations.""" - - output_types = [rdf_config_file.PamConfig] - supported_artifacts = ["LinuxPamConfigs"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._field_parser = PAMFieldParser() - - def ParseFiles( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspecs: Iterable[rdf_paths.PathSpec], - filedescs: Iterable[IO[bytes]], - ) -> Iterator[rdf_config_file.PamConfig]: - del knowledge_base # Unused. - - results, externals = self._field_parser.EnumerateAllConfigs( - pathspecs, filedescs) - yield rdf_config_file.PamConfig(entries=results, external_config=externals) diff --git a/grr/core/grr_response_core/lib/parsers/linux_pam_parser_test.py b/grr/core/grr_response_core/lib/parsers/linux_pam_parser_test.py deleted file mode 100644 index 76c30573af..0000000000 --- a/grr/core/grr_response_core/lib/parsers/linux_pam_parser_test.py +++ /dev/null @@ -1,236 +0,0 @@ -#!/usr/bin/env python -"""Unit test for the linux pam config parser.""" - - -import platform -import unittest - -from absl import app - -from grr_response_core.lib.parsers import linux_pam_parser -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import config_file as rdf_config_file -from grr.test_lib import artifact_test_lib -from grr.test_lib import test_lib - -ETC_PAM_CONF_EMPTY = b""" -# Nothing to do here. 
- # white space - -# ^ blank line -""" -ETC_PAM_CONF_SIMPLE = b""" -ssh auth required test.so -telnet auth required unix.so -ssh session required pam_limits.so -""" -ETC_PAM_CONF_COMPLEX = ETC_PAM_CONF_SIMPLE + b""" -telnet account include filt_include -ssh @include full_include -""" -ETC_PAMD_FILT_INCLUDE = b""" -account required pam_nologin.so -auth required pam_env.so envfile=/etc/default/locale -""" -ETC_PAMD_FULL_INCLUDE = ETC_PAMD_FILT_INCLUDE -ETC_PAMD_SSH = b""" -auth required test.so # Comment -session required pam_limits.so random=option # Comment -account include filt_include # only include 'account' entries from file. -@include full_include # Include everything from file 'full_include' -""" -ETC_PAMD_TELNET = b""" -# Blank line - -# Multi line and 'type' with a leading '-'. --auth [success=ok new_authtok_reqd=ok ignore=ignore default=bad] \ - testing.so module arguments # Comments -""" -ETC_PAMD_EXTERNAL = b""" -password substack nonexistant -auth optional testing.so -@include /external/nonexistant -""" - -TELNET_ONLY_CONFIG = {'/etc/pam.d/telnet': ETC_PAMD_TELNET} -TELNET_ONLY_CONFIG_EXPECTED = [ - ('telnet', 'auth', - '[success=ok new_authtok_reqd=ok ignore=ignore default=bad]', 'testing.so', - 'module arguments') -] - -TELNET_WITH_PAMCONF = { - '/etc/pam.conf': ETC_PAM_CONF_EMPTY, - '/etc/pam.d/telnet': ETC_PAMD_TELNET -} -TELNET_WITH_PAMCONF_EXPECTED = TELNET_ONLY_CONFIG_EXPECTED - -PAM_CONF_SIMPLE = {'/etc/pam.conf': ETC_PAM_CONF_SIMPLE} -PAM_CONF_SIMPLE_EXPECTED = [('ssh', 'auth', 'required', 'test.so', ''), - ('telnet', 'auth', 'required', 'unix.so', ''), - ('ssh', 'session', 'required', 'pam_limits.so', '')] - -PAM_CONF_OVERRIDE = { - '/etc/pam.conf': ETC_PAM_CONF_SIMPLE, - '/etc/pam.d/telnet': ETC_PAMD_TELNET -} -PAM_CONF_OVERRIDE_EXPECTED = PAM_CONF_SIMPLE_EXPECTED - -PAM_CONF_OVERRIDE_COMPLEX = { - '/etc/pam.conf': ETC_PAM_CONF_COMPLEX, - '/etc/pam.d/ssh': ETC_PAMD_SSH, - '/etc/pam.d/full_include': ETC_PAMD_FULL_INCLUDE, - '/etc/pam.d/filt_include': ETC_PAMD_FILT_INCLUDE, - '/etc/pam.d/telnet': ETC_PAMD_TELNET -} -PAM_CONF_OVERRIDE_COMPLEX_EXPECTED = PAM_CONF_SIMPLE_EXPECTED + [ - ('telnet', 'account', 'required', 'pam_nologin.so', ''), - ('ssh', 'account', 'required', 'pam_nologin.so', ''), - ('ssh', 'auth', 'required', 'pam_env.so', 'envfile=/etc/default/locale') -] - -PAM_CONF_TYPICAL = { - '/etc/pam.conf': ETC_PAM_CONF_EMPTY, - '/etc/pam.d/ssh': ETC_PAMD_SSH, - '/etc/pam.d/full_include': ETC_PAMD_FULL_INCLUDE, - '/etc/pam.d/filt_include': ETC_PAMD_FILT_INCLUDE, - '/etc/pam.d/telnet': ETC_PAMD_TELNET -} -PAM_CONF_TYPICAL_EXPECTED = TELNET_ONLY_CONFIG_EXPECTED + [ - ('ssh', 'auth', 'required', 'test.so', ''), - ('ssh', 'session', 'required', 'pam_limits.so', 'random=option'), - ('ssh', 'account', 'required', 'pam_nologin.so', ''), - ('ssh', 'account', 'required', 'pam_nologin.so', ''), - ('ssh', 'auth', 'required', 'pam_env.so', 'envfile=/etc/default/locale'), - ('filt_include', 'account', 'required', 'pam_nologin.so', ''), - ('filt_include', 'auth', 'required', 'pam_env.so', - 'envfile=/etc/default/locale'), - ('full_include', 'account', 'required', 'pam_nologin.so', ''), - ('full_include', 'auth', 'required', 'pam_env.so', - 'envfile=/etc/default/locale') -] - -PAM_CONF_EXTERNAL_REF = { - '/etc/pam.conf': ETC_PAM_CONF_EMPTY, - '/etc/pam.d/external': ETC_PAMD_EXTERNAL -} -PAM_CONF_EXTERNAL_REF_EXPECTED = [('external', 'auth', 'optional', 'testing.so', - '')] -PAM_CONF_EXTERNAL_REF_ERRORS = [ - '/etc/pam.d/external -> /etc/pam.d/nonexistant', - '/etc/pam.d/external 
-> /external/nonexistant' -] - - -# TODO: This test fails on Windows, but could theoretically pass. -@unittest.skipIf(platform.system() == 'Windows', - 'Test fails on Windows (but is non-criticial for Windows).') -class LinuxPAMParserTest(test_lib.GRRBaseTest): - """Test parsing of PAM config files.""" - - def setUp(self): - super().setUp() - self.kb = rdf_client.KnowledgeBase(fqdn='test.example.com', os='Linux') - - def _EntryToTuple(self, entry): - return (entry.service, entry.type, entry.control, entry.module_path, - entry.module_args) - - def _EntriesToTuples(self, entries): - return [self._EntryToTuple(x) for x in entries] - - def testParseMultiple(self): - """Tests for the ParseMultiple() method.""" - parser = linux_pam_parser.PAMParser() - - # Parse the simplest 'normal' config we can. - # e.g. a single entry for 'telnet' with no includes etc. - pathspecs, file_objs = artifact_test_lib.GenPathspecFileData( - TELNET_ONLY_CONFIG) - out = list(parser.ParseFiles(self.kb, pathspecs, file_objs)) - self.assertLen(out, 1) - self.assertIsInstance(out[0], rdf_config_file.PamConfig) - self.assertCountEqual(TELNET_ONLY_CONFIG_EXPECTED, - self._EntriesToTuples(out[0].entries)) - self.assertEqual([], out[0].external_config) - - # Parse the simplest 'normal' config we can but with an effectively - # empty /etc/pam.conf file. - # e.g. a single entry for 'telnet' with no includes etc. - pathspecs, file_objs = artifact_test_lib.GenPathspecFileData( - TELNET_WITH_PAMCONF) - out = list(parser.ParseFiles(self.kb, pathspecs, file_objs)) - self.assertLen(out, 1) - self.assertIsInstance(out[0], rdf_config_file.PamConfig) - entry = out[0].entries[0] - self.assertEqual( - ('telnet', 'auth', - '[success=ok new_authtok_reqd=ok ignore=ignore default=bad]', - 'testing.so', 'module arguments'), self._EntryToTuple(entry)) - self.assertCountEqual(TELNET_WITH_PAMCONF_EXPECTED, - self._EntriesToTuples(out[0].entries)) - self.assertEqual([], out[0].external_config) - - # Parse a simple old-style pam config. i.e. Just /etc/pam.conf. - pathspecs, file_objs = artifact_test_lib.GenPathspecFileData( - PAM_CONF_SIMPLE) - out = list(parser.ParseFiles(self.kb, pathspecs, file_objs)) - self.assertLen(out, 1) - self.assertIsInstance(out[0], rdf_config_file.PamConfig) - self.assertCountEqual(PAM_CONF_SIMPLE_EXPECTED, - self._EntriesToTuples(out[0].entries)) - self.assertEqual([], out[0].external_config) - - # Parse a simple old-style pam config overriding a 'new' style config. - # i.e. Configs in /etc/pam.conf override everything else. - pathspecs, file_objs = artifact_test_lib.GenPathspecFileData( - PAM_CONF_OVERRIDE) - out = list(parser.ParseFiles(self.kb, pathspecs, file_objs)) - self.assertLen(out, 1) - self.assertIsInstance(out[0], rdf_config_file.PamConfig) - self.assertCountEqual(PAM_CONF_OVERRIDE_EXPECTED, - self._EntriesToTuples(out[0].entries)) - self.assertEqual([], out[0].external_config) - - # Parse a complex old-style pam config overriding a 'new' style config but - # the /etc/pam.conf includes parts from the /etc/pam.d dir. - # i.e. Configs in /etc/pam.conf override everything else but imports stuff. 
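# A quick sketch of how a simple pam.d entry maps onto the
# (service, type, control, module-path, module-arguments) tuples asserted in
# these tests; plain whitespace splitting only covers the single-word control
# form, while the bracketed "[success=ok ...]" form needs PAMCONF_RE above.
line = "session required pam_limits.so random=option"
p_type, control, module_path, module_args = line.split(None, 3)
assert ("ssh", p_type, control, module_path, module_args) == (
    "ssh", "session", "required", "pam_limits.so", "random=option")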
- pathspecs, file_objs = artifact_test_lib.GenPathspecFileData( - PAM_CONF_OVERRIDE_COMPLEX) - out = list(parser.ParseFiles(self.kb, pathspecs, file_objs)) - self.assertLen(out, 1) - self.assertIsInstance(out[0], rdf_config_file.PamConfig) - self.assertCountEqual(PAM_CONF_OVERRIDE_COMPLEX_EXPECTED, - self._EntriesToTuples(out[0].entries)) - self.assertEqual([], out[0].external_config) - - # Parse a normal-looking pam configuration. - # i.e. A no-op of a /etc/pam.conf with multiple files under /etc/pam.d - # that have includes etc. - pathspecs, file_objs = artifact_test_lib.GenPathspecFileData( - PAM_CONF_TYPICAL) - out = list(parser.ParseFiles(self.kb, pathspecs, file_objs)) - self.assertLen(out, 1) - self.assertIsInstance(out[0], rdf_config_file.PamConfig) - self.assertCountEqual(PAM_CONF_TYPICAL_EXPECTED, - self._EntriesToTuples(out[0].entries)) - self.assertEqual([], out[0].external_config) - - # Parse a config which has references to external or missing files. - pathspecs, file_objs = artifact_test_lib.GenPathspecFileData( - PAM_CONF_EXTERNAL_REF) - out = list(parser.ParseFiles(self.kb, pathspecs, file_objs)) - self.assertLen(out, 1) - self.assertIsInstance(out[0], rdf_config_file.PamConfig) - self.assertCountEqual(PAM_CONF_EXTERNAL_REF_EXPECTED, - self._EntriesToTuples(out[0].entries)) - self.assertCountEqual(PAM_CONF_EXTERNAL_REF_ERRORS, - list(out[0].external_config)) - - -def main(args): - test_lib.main(args) - - -if __name__ == '__main__': - app.run(main) diff --git a/grr/core/grr_response_core/lib/parsers/linux_release_parser.py b/grr/core/grr_response_core/lib/parsers/linux_release_parser.py index 3112bc269b..bd863b4da7 100644 --- a/grr/core/grr_response_core/lib/parsers/linux_release_parser.py +++ b/grr/core/grr_response_core/lib/parsers/linux_release_parser.py @@ -3,11 +3,7 @@ import collections import re - -from typing import IO -from typing import Iterable -from typing import Iterator -from typing import Text +from typing import IO, Iterable, Iterator from grr_response_core.lib import parsers from grr_response_core.lib import utils @@ -22,11 +18,12 @@ _SYSTEMD_OS_RELEASE_VERSION = 'VERSION_ID' ParsedRelease = collections.namedtuple('ParsedRelease', 'release, major, minor') -WeightedReleaseFile = collections.namedtuple('WeightedReleaseFile', - 'weight, path, processor') +WeightedReleaseFile = collections.namedtuple( + 'WeightedReleaseFile', 'weight, path, processor' +) -class ReleaseParseHandler(object): +class ReleaseParseHandler: """Base class for distribution data file parse handlers.""" def __init__(self, contents): @@ -35,7 +32,7 @@ def __init__(self, contents): Args: contents: file contents that are to be parsed. """ - precondition.AssertOptionalType(contents, Text) + precondition.AssertOptionalType(contents, str) self.contents = contents def Parse(self): @@ -211,7 +208,7 @@ def ParseFiles( yield rdf_protodict.Dict({ 'os_release': result.release, 'os_major_version': result.major, - 'os_minor_version': result.minor + 'os_minor_version': result.minor, }) return @@ -224,7 +221,7 @@ def ParseFiles( yield rdf_protodict.Dict({ 'os_release': 'AmazonLinuxAMI', 'os_major_version': int(match_object.group(1)), - 'os_minor_version': int(match_object.group(2)) + 'os_minor_version': int(match_object.group(2)), }) return @@ -236,7 +233,8 @@ def ParseFiles( # No successful parse. yield rdf_anomaly.Anomaly( - type='PARSER_ANOMALY', symptom='Unable to determine distribution.') + type='PARSER_ANOMALY', symptom='Unable to determine distribution.' 
+ ) def _ParseOSReleaseFile(self, matches_dict): # The spec for the os-release file is given at @@ -266,8 +264,11 @@ def _ParseOSReleaseFile(self, matches_dict): # multi-part version numbers so we use a default minor version of # zero. os_minor_version = 0 if minor_match is None else int(minor_match) - if (os_release_name and os_major_version is not None and - os_minor_version is not None): + if ( + os_release_name + and os_major_version is not None + and os_minor_version is not None + ): return rdf_protodict.Dict({ 'os_release': os_release_name, 'os_major_version': os_major_version, diff --git a/grr/core/grr_response_core/lib/parsers/linux_release_parser_test.py b/grr/core/grr_response_core/lib/parsers/linux_release_parser_test.py index 6fb4259630..bfc49e624d 100644 --- a/grr/core/grr_response_core/lib/parsers/linux_release_parser_test.py +++ b/grr/core/grr_response_core/lib/parsers/linux_release_parser_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Unit test for the linux distribution parser.""" - import io import os @@ -160,8 +159,10 @@ def testEndToEndRockyLinux(self): def testEndToEndAmazon(self): parser = linux_release_parser.LinuxReleaseParser() test_data = [ - ("/etc/system-release", - os.path.join(self.parser_test_dir, "amazon-system-release")), + ( + "/etc/system-release", + os.path.join(self.parser_test_dir, "amazon-system-release"), + ), ] pathspecs, file_objects = self._CreateTestData(test_data) actual_result = list(parser.ParseFiles(None, pathspecs, file_objects)) @@ -177,8 +178,10 @@ def testEndToEndAmazon(self): def testEndToEndCoreOS(self): parser = linux_release_parser.LinuxReleaseParser() test_data = [ - ("/etc/os-release", - os.path.join(self.parser_test_dir, "coreos-os-release")), + ( + "/etc/os-release", + os.path.join(self.parser_test_dir, "coreos-os-release"), + ), ] pathspecs, file_objects = self._CreateTestData(test_data) actual_result = list(parser.ParseFiles(None, pathspecs, file_objects)) @@ -194,8 +197,10 @@ def testEndToEndCoreOS(self): def testEndToEndGoogleCOS(self): parser = linux_release_parser.LinuxReleaseParser() test_data = [ - ("/etc/os-release", - os.path.join(self.parser_test_dir, "google-cos-os-release")), + ( + "/etc/os-release", + os.path.join(self.parser_test_dir, "google-cos-os-release"), + ), ] pathspecs, file_objects = self._CreateTestData(test_data) actual_result = list(parser.ParseFiles(None, pathspecs, file_objects)) diff --git a/grr/core/grr_response_core/lib/parsers/linux_service_parser.py b/grr/core/grr_response_core/lib/parsers/linux_service_parser.py deleted file mode 100644 index 5c61c05e26..0000000000 --- a/grr/core/grr_response_core/lib/parsers/linux_service_parser.py +++ /dev/null @@ -1,350 +0,0 @@ -#!/usr/bin/env python -"""Simple parsers for configuration files.""" - - -import logging -import os -import re -from typing import IO -from typing import Iterable -from typing import Iterator -from typing import Text - -from grr_response_core.lib import lexer -from grr_response_core.lib import parsers -from grr_response_core.lib import utils -from grr_response_core.lib.parsers import config_file -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_core.lib.rdfvalues import protodict as rdf_protodict -from grr_response_core.lib.util import precondition - - -class LSBInitLexer(lexer.Lexer): - """Parse out upstart configurations from init scripts. 
- - Runlevels in /etc/init.d are defined in stanzas like: - ### BEGIN INIT INFO - # Provides: sshd - # Required-Start: $remote_fs $syslog - # Required-Stop: $remote_fs $syslog - # Default-Start: 2 3 4 5 - # Default-Stop: 1 - # Short-Description: OpenBSD Secure Shell server - ### END INIT INFO - """ - - tokens = [ - lexer.Token("INITIAL", r"### BEGIN INIT INFO", None, "UPSTART"), - lexer.Token("UPSTART", r"### END INIT INFO", "Finish", "INITIAL"), - lexer.Token("UPSTART", r"#\s+([-\w]+):\s+([^#\n]*)", "StoreEntry", None), - lexer.Token("UPSTART", r"\n\s*\w+", "Finish", None), - lexer.Token(".*", ".", None, None) - ] - - required = {"provides", "default-start"} - - def __init__(self): - super().__init__() - self.entries = {} - - def StoreEntry(self, match, **_): - key, val = match.groups() - setting = key.strip().lower() - if setting: - self.entries[setting] = val - - def Finish(self, **_): - self.buffer = [] - - def ParseEntries(self, data): - precondition.AssertType(data, Text) - self.entries = {} - self.Reset() - self.Feed(data) - self.Close() - found = set(self.entries) - if self.required.issubset(found): - return self.entries - - -def _LogInvalidRunLevels(states, valid): - """Log any invalid run states found.""" - invalid = set() - for state in states: - if state not in valid: - invalid.add(state) - if invalid: - logging.warning("Invalid init runlevel(s) encountered: %s", - ", ".join(invalid)) - - -def GetRunlevelsLSB(states): - """Accepts a string and returns a list of strings of numeric LSB runlevels.""" - if not states: - return set() - valid = set(["0", "1", "2", "3", "4", "5", "6"]) - _LogInvalidRunLevels(states, valid) - return valid.intersection(set(states.split())) - - -def GetRunlevelsNonLSB(states): - """Accepts a string and returns a list of strings of numeric LSB runlevels.""" - if not states: - return set() - convert_table = { - "0": "0", - "1": "1", - "2": "2", - "3": "3", - "4": "4", - "5": "5", - "6": "6", - # SysV, Gentoo, Solaris, HP-UX all allow an alpha variant - # for single user. 
https://en.wikipedia.org/wiki/Runlevel - "S": "1", - "s": "1" - } - _LogInvalidRunLevels(states, convert_table) - return set([convert_table[s] for s in states.split() if s in convert_table]) - - -class LinuxLSBInitParser( - parsers.MultiFileParser[rdf_client.LinuxServiceInformation]): - """Parses LSB style /etc/init.d entries.""" - - output_types = [rdf_client.LinuxServiceInformation] - supported_artifacts = ["LinuxLSBInit"] - - def _Facilities(self, condition): - results = [] - for facility in condition.split(): - for expanded in self.insserv.get(facility, []): - if expanded not in results: - results.append(expanded) - return results - - def _ParseInit(self, init_files): - init_lexer = LSBInitLexer() - for path, file_obj in init_files: - init = init_lexer.ParseEntries(utils.ReadFileBytesAsUnicode(file_obj)) - if init: - service = rdf_client.LinuxServiceInformation() - service.name = init.get("provides") - service.start_mode = "INIT" - service.start_on = GetRunlevelsLSB(init.get("default-start")) - if service.start_on: - service.starts = True - service.stop_on = GetRunlevelsLSB(init.get("default-stop")) - service.description = init.get("short-description") - service.start_after = self._Facilities(init.get("required-start", [])) - service.stop_after = self._Facilities(init.get("required-stop", [])) - yield service - else: - logging.debug("No runlevel information found in %s", path) - - def _InsservExpander(self, facilities, val): - """Expand insserv variables.""" - expanded = [] - if val.startswith("$"): - vals = facilities.get(val, []) - for v in vals: - expanded.extend(self._InsservExpander(facilities, v)) - elif val.startswith("+"): - expanded.append(val[1:]) - else: - expanded.append(val) - return expanded - - def _ParseInsserv(self, data): - """/etc/insserv.conf* entries define system facilities. - - Full format details are in man 8 insserv, but the basic structure is: - $variable facility1 facility2 - $second_variable facility3 $variable - - Any init script that specifies Required-Start: $second_variable needs to be - expanded to facility1 facility2 facility3. - - Args: - data: A string of insserv definitions. - """ - p = config_file.FieldParser() - entries = p.ParseEntries(data) - raw = {e[0]: e[1:] for e in entries} - # Now expand out the facilities to services. - facilities = {} - for k, v in raw.items(): - # Remove interactive tags. - k = k.replace("<", "").replace(">", "") - facilities[k] = v - for k, vals in facilities.items(): - self.insserv[k] = [] - for v in vals: - self.insserv[k].extend(self._InsservExpander(facilities, v)) - - def ParseFiles( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspecs: Iterable[rdf_paths.PathSpec], - filedescs: Iterable[IO[bytes]], - ) -> Iterator[rdf_client.LinuxServiceInformation]: - del knowledge_base # Unused. 
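The _ParseInsserv docstring above describes the expansion rule: an init script that declares "Required-Start: $second_variable" has to end up depending on every concrete service hidden behind that facility variable. A minimal standalone sketch of that expansion idea, using made-up helper and variable names rather than the GRR classes:

def expand_facility(facilities, token):
    # Recursively resolve $variables into concrete service names; a leading
    # "+" marks an optional dependency and is stripped, anything else is
    # already a service name.
    if token.startswith("$"):
        expanded = []
        for value in facilities.get(token, []):
            expanded.extend(expand_facility(facilities, value))
        return expanded
    if token.startswith("+"):
        return [token[1:]]
    return [token]

facilities = {
    "$local_fs": ["+umountfs"],
    "$remote_fs": ["$local_fs", "umountnfs", "sendsigs"],
}

# An init script with "Required-Start: $remote_fs" depends on all of these:
print(expand_facility(facilities, "$remote_fs"))
# ['umountfs', 'umountnfs', 'sendsigs']

GetRunlevelsLSB and GetRunlevelsNonLSB then reduce the Default-Start/Default-Stop strings to the numeric runlevels 0 through 6 before the service entry is emitted.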
- - self.insserv = {} - paths = [pathspec.path for pathspec in pathspecs] - files = dict(zip(paths, filedescs)) - insserv_data = "" - init_files = [] - for k, v in files.items(): - if k.startswith("/etc/insserv.conf"): - insserv_data += "%s\n" % utils.ReadFileBytesAsUnicode(v) - else: - init_files.append((k, v)) - self._ParseInsserv(insserv_data) - for rslt in self._ParseInit(init_files): - yield rslt - - -class LinuxXinetdParser( - parsers.MultiFileParser[rdf_client.LinuxServiceInformation]): - """Parses xinetd entries.""" - - output_types = [rdf_client.LinuxServiceInformation] - supported_artifacts = ["LinuxXinetd"] - - def _ParseSection(self, section, cfg): - p = config_file.KeyValueParser() - # Skip includedir, we get this from the artifact. - if section.startswith("includedir"): - return - elif section.startswith("default"): - for val in p.ParseEntries(cfg): - self.default.update(val) - elif section.startswith("service"): - svc = section.replace("service", "").strip() - if not svc: - return - self.entries[svc] = {} - for val in p.ParseEntries(cfg): - self.entries[svc].update(val) - - def _ProcessEntries(self, fd): - """Extract entries from the xinetd config files.""" - p = config_file.KeyValueParser(kv_sep="{", term="}", sep=None) - data = utils.ReadFileBytesAsUnicode(fd) - entries = p.ParseEntries(data) - for entry in entries: - for section, cfg in entry.items(): - # The parser returns a list of configs. There will only be one. - if cfg: - cfg = cfg[0].strip() - else: - cfg = "" - self._ParseSection(section, cfg) - - def _GenConfig(self, cfg): - """Interpolate configurations with defaults to generate actual configs.""" - # Some setting names may have a + or - suffix. These indicate that the - # settings modify the default values. - merged = self.default.copy() - for setting, vals in cfg.items(): - option, operator = (setting.split(None, 1) + [None])[:2] - vals = set(vals) - default = set(self.default.get(option, [])) - # If there is an operator, updated values accordingly. - if operator == "+": - vals = default.union(vals) - elif operator == "-": - vals = default.difference(vals) - merged[option] = list(vals) - return rdf_protodict.AttributedDict(**merged) - - def _GenService(self, name, cfg): - # Merge the config values. - service = rdf_client.LinuxServiceInformation(name=name) - service.config = self._GenConfig(cfg) - if service.config.disable == ["no"]: - service.starts = True - service.start_mode = "XINETD" - service.start_after = ["xinetd"] - return service - - def ParseFiles( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspecs: Iterable[rdf_paths.PathSpec], - filedescs: Iterable[IO[bytes]], - ) -> Iterator[rdf_client.LinuxServiceInformation]: - del knowledge_base # Unused. - - self.entries = {} - self.default = {} - paths = [pathspec.path for pathspec in pathspecs] - files = dict(zip(paths, filedescs)) - for v in files.values(): - self._ProcessEntries(v) - for name, cfg in self.entries.items(): - yield self._GenService(name, cfg) - - -class LinuxSysVInitParser( - parsers.MultiFileParser[rdf_client.LinuxServiceInformation]): - """Parses SysV runlevel entries. - - Reads the stat entries for files under /etc/rc* runlevel scripts. - Identifies start and stop levels for services. - - Yields: - LinuxServiceInformation for each service with a runlevel entry. - Anomalies if there are non-standard service startup definitions. 
- """ - - output_types = [rdf_client.LinuxServiceInformation] - supported_artifacts = ["LinuxSysVInit"] - - runlevel_re = re.compile(r"/etc/rc(?:\.)?([0-6S]|local$)(?:\.d)?") - runscript_re = re.compile(r"(?P<action>[KS])(?P<priority>\d+)(?P<name>\S+)") - - def ParseFiles( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspecs: Iterable[rdf_paths.PathSpec], - filedescs: Iterable[IO[bytes]], - ) -> Iterator[rdf_client.LinuxServiceInformation]: - """Identify the init scripts and the start/stop scripts at each runlevel. - - Evaluate all the stat entries collected from the system. - If the path name matches a runlevel spec, and if the filename matches a - sysv init symlink process the link as a service. - - Args: - knowledge_base: A client's knowledge base. - pathspecs: A list of path description for collected files. - filedescs: A list of file descriptors of collected files. - - Yields: - rdf_client.LinuxServiceInformation for each detected service. - """ - del knowledge_base, filedescs # Unused. - - services = {} - for pathspec in pathspecs: - path = pathspec.path - runlevel = self.runlevel_re.match(os.path.dirname(path)) - runscript = self.runscript_re.match(os.path.basename(path)) - if runlevel and runscript: - svc = runscript.groupdict() - service = services.setdefault( - svc["name"], - rdf_client.LinuxServiceInformation( - name=svc["name"], start_mode="INIT")) - runlvl = GetRunlevelsNonLSB(runlevel.group(1)) - if svc["action"] == "S" and runlvl: - service.start_on.append(runlvl.pop()) - service.starts = True - elif runlvl: - service.stop_on.append(runlvl.pop()) - - for svc in services.values(): - yield svc diff --git a/grr/core/grr_response_core/lib/parsers/linux_service_parser_test.py b/grr/core/grr_response_core/lib/parsers/linux_service_parser_test.py deleted file mode 100644 index 398598233c..0000000000 --- a/grr/core/grr_response_core/lib/parsers/linux_service_parser_test.py +++ /dev/null @@ -1,137 +0,0 @@ -#!/usr/bin/env python -"""Unit test for the linux sysctl parser.""" - - -import io - -from absl import app -from absl.testing import absltest - -from grr_response_core.lib.parsers import linux_service_parser - -from grr_response_core.lib.parsers import parsers_test_lib -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr.test_lib import artifact_test_lib -from grr.test_lib import test_lib - - -class LinuxLSBInitParserTest(test_lib.GRRBaseTest): - """Test parsing of linux /etc/init.d files with LSB headers.""" - - def testParseLSBInit(self): - """Init entries return accurate LinuxServiceInformation values.""" - configs = parsers_test_lib.GenInit("sshd", "OpenBSD Secure Shell server") - pathspecs, files = artifact_test_lib.GenPathspecFileData(configs) - - parser = linux_service_parser.LinuxLSBInitParser() - results = list(parser.ParseFiles(None, pathspecs, files)) - self.assertIsInstance(results[0], rdf_client.LinuxServiceInformation) - result = results[0] - self.assertEqual("sshd", result.name) - self.assertEqual("OpenBSD Secure Shell server", result.description) - self.assertEqual("INIT", result.start_mode) - self.assertCountEqual([2, 3, 4, 5], result.start_on) - self.assertCountEqual([1], result.stop_on) - self.assertCountEqual([ - "umountfs", "umountnfs", "sendsigs", "rsyslog", "sysklogd", "syslog-ng", - "dsyslog", "inetutils-syslogd" - ], result.start_after) - self.assertCountEqual( - ["rsyslog", "sysklogd", "syslog-ng", "dsyslog", "inetutils-syslogd"], - result.stop_after) - - def testSkipBadLSBInit(self): 
- """Bad Init entries fail gracefully.""" - empty = "" - snippet = r"""# Provides: sshd""" - unfinished = """ - ### BEGIN INIT INFO - what are you thinking? - """ - data = { - "/tmp/empty": empty.encode("utf-8"), - "/tmp/snippet": snippet.encode("utf-8"), - "/tmp/unfinished": unfinished.encode("utf-8"), - } - pathspecs, files = artifact_test_lib.GenPathspecFileData(data) - parser = linux_service_parser.LinuxLSBInitParser() - results = list(parser.ParseFiles(None, pathspecs, files)) - self.assertFalse(results) - - -class LinuxXinetdParserTest(test_lib.GRRBaseTest): - """Test parsing of xinetd entries.""" - - def testParseXinetd(self): - """Xinetd entries return accurate LinuxServiceInformation values.""" - configs = parsers_test_lib.GenXinetd("telnet", "yes") - configs.update(parsers_test_lib.GenXinetd("forwarder", "no")) - pathspecs, files = artifact_test_lib.GenPathspecFileData(configs) - - parser = linux_service_parser.LinuxXinetdParser() - results = list(parser.ParseFiles(None, pathspecs, files)) - self.assertLen(results, 2) - self.assertCountEqual(["forwarder", "telnet"], [r.name for r in results]) - for rslt in results: - self.assertFalse(rslt.start_on) - self.assertFalse(rslt.stop_on) - self.assertFalse(rslt.stop_after) - if rslt.name == "telnet": - self.assertFalse(rslt.start_mode) - self.assertFalse(rslt.start_after) - self.assertFalse(rslt.starts) - else: - self.assertEqual(rslt.start_mode, - rdf_client.LinuxServiceInformation.StartMode.XINETD) - self.assertCountEqual(["xinetd"], list(rslt.start_after)) - self.assertTrue(rslt.starts) - - -class LinuxSysVInitParserTest(absltest.TestCase): - """Test parsing of sysv startup and shutdown links.""" - - def testParseServices(self): - knowledge_base = rdf_client.KnowledgeBase() - - paths = [ - # Directories. - "/etc", - "/etc/rc1.d", - "/etc/rc2.d", - "/etc/rc6.d", - "/etc/rcS.d", - # Files. - "/etc/rc.local", - "/etc/ignoreme", - "/etc/rc2.d/S20ssh", - # Links. 
- "/etc/rc1.d/S90single", - "/etc/rc1.d/K20ssh", - "/etc/rc1.d/ignore", - "/etc/rc2.d/S20ntp", - "/etc/rc2.d/S30ufw", - "/etc/rc6.d/K20ssh", - "/etc/rcS.d/S20firewall", - ] - pathspecs = [rdf_paths.PathSpec(path=path) for path in paths] - filedescs = [io.BytesIO(b"") for _ in paths] - - parser = linux_service_parser.LinuxSysVInitParser() - results = list(parser.ParseFiles(knowledge_base, pathspecs, filedescs)) - - services = {service.name: service for service in results} - self.assertLen(services, 5) - self.assertCountEqual(["single", "ssh", "ntp", "ufw", "firewall"], services) - self.assertCountEqual([2], services["ssh"].start_on) - self.assertCountEqual([1, 6], services["ssh"].stop_on) - self.assertTrue(services["ssh"].starts) - self.assertCountEqual([1], services["firewall"].start_on) - self.assertTrue(services["firewall"].starts) - - -def main(args): - test_lib.main(args) - - -if __name__ == "__main__": - app.run(main) diff --git a/grr/core/grr_response_core/lib/parsers/linux_software_parser.py b/grr/core/grr_response_core/lib/parsers/linux_software_parser.py deleted file mode 100644 index 2c56177057..0000000000 --- a/grr/core/grr_response_core/lib/parsers/linux_software_parser.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python -"""Simple parsers for Linux files.""" - -import re -from typing import IO -from typing import Iterator - -from grr_response_core.lib import parsers -from grr_response_core.lib import utils -from grr_response_core.lib.rdfvalues import anomaly as rdf_anomaly -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import paths as rdf_paths - - -class DebianPackagesStatusParser( - parsers.SingleFileParser[rdf_client.SoftwarePackages]): - """Parser for /var/lib/dpkg/status. Yields SoftwarePackage semantic values.""" - - output_types = [rdf_client.SoftwarePackages] - supported_artifacts = ["DebianPackagesStatus"] - - installed_re = re.compile(r"^\w+ \w+ installed$") - - def __init__(self, deb822): - """Initializes the parser. - - Args: - deb822: An accessor for RFC822-like data formats. - """ - self._deb822 = deb822 - - def ParseFile( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspec: rdf_paths.PathSpec, - filedesc: IO[bytes], - ) -> Iterator[rdf_client.SoftwarePackages]: - del knowledge_base # Unused. - del pathspec # Unused. 
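The deleted DebianPackagesStatusParser walks /var/lib/dpkg/status with python-debian's deb822 paragraph iterator and keeps only entries whose Status field matches the installed_re pattern. A rough standalone sketch of the same approach, assuming python-debian is available; the function name and default path are only illustrative:

import re

from debian import deb822  # python-debian, assumed to be installed

INSTALLED_RE = re.compile(r"^\w+ \w+ installed$")

def installed_packages(status_path="/var/lib/dpkg/status"):
    # Yield (name, version) pairs for packages recorded as installed.
    with open(status_path, encoding="utf-8", errors="replace") as status:
        for pkg in deb822.Packages.iter_paragraphs(status):
            if INSTALLED_RE.match(pkg.get("Status", "")):
                yield pkg["Package"], pkg.get("Version", "")

# for name, version in installed_packages():
#     print(name, version)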
- - packages = [] - sw_data = utils.ReadFileBytesAsUnicode(filedesc) - try: - for pkg in self._deb822.Packages.iter_paragraphs(sw_data.splitlines()): - if self.installed_re.match(pkg["Status"]): - packages.append( - rdf_client.SoftwarePackage( - name=pkg["Package"], - description=pkg["Description"], - version=pkg["Version"], - architecture=pkg["Architecture"], - publisher=pkg["Maintainer"], - install_state="INSTALLED")) - except SystemError: - yield rdf_anomaly.Anomaly( - type="PARSER_ANOMALY", symptom="Invalid dpkg status file") - finally: - if packages: - yield rdf_client.SoftwarePackages(packages=packages) diff --git a/grr/core/grr_response_core/lib/parsers/linux_software_parser_test.py b/grr/core/grr_response_core/lib/parsers/linux_software_parser_test.py deleted file mode 100644 index 4e1a503371..0000000000 --- a/grr/core/grr_response_core/lib/parsers/linux_software_parser_test.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python -"""Unit test for the linux file parser.""" - - -import os -import unittest - -from absl import app - -from grr_response_core.lib.parsers import linux_software_parser -from grr_response_core.lib.rdfvalues import anomaly as rdf_anomaly -from grr.test_lib import test_lib - -try: - from debian import deb822 # pylint: disable=g-import-not-at-top -except ImportError: - raise unittest.SkipTest("`deb822` not available") - - -class LinuxSoftwareParserTest(test_lib.GRRBaseTest): - """Test parsing of linux software collection.""" - - def testDebianPackagesStatusParser(self): - """Test parsing of a status file.""" - parser = linux_software_parser.DebianPackagesStatusParser(deb822) - path = os.path.join(self.base_path, "dpkg_status") - with open(path, "rb") as data: - out = list(parser.ParseFile(None, None, data)) - self.assertLen(out, 1) - package_list = out[0] - self.assertLen(package_list.packages, 2) - package0 = package_list.packages[0] - self.assertEqual(("t1", "v1"), (package0.name, package0.version)) - package1 = package_list.packages[1] - self.assertEqual(("t2", "v2"), (package1.name, package1.version)) - - def testDebianPackagesStatusParserBadInput(self): - """If the status file is broken, fail nicely.""" - parser = linux_software_parser.DebianPackagesStatusParser(deb822) - path = os.path.join(self.base_path, "numbers.txt") - with open(path, "rb") as data: - out = list(parser.ParseFile(None, None, data)) - for result in out: - self.assertIsInstance(result, rdf_anomaly.Anomaly) - - -def main(args): - test_lib.main(args) - - -if __name__ == "__main__": - app.run(main) diff --git a/grr/core/grr_response_core/lib/parsers/linux_sysctl_parser.py b/grr/core/grr_response_core/lib/parsers/linux_sysctl_parser.py deleted file mode 100644 index cdbd932dfb..0000000000 --- a/grr/core/grr_response_core/lib/parsers/linux_sysctl_parser.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python -"""Simple parsers for configuration files.""" - - -from typing import IO -from typing import Iterable -from typing import Iterator - -from grr_response_core.lib import parser -from grr_response_core.lib import parsers -from grr_response_core.lib.parsers import config_file -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_core.lib.rdfvalues import protodict as rdf_protodict - - -class ProcSysParser(parsers.MultiFileParser[rdf_protodict.AttributedDict]): - """Parser for /proc/sys entries.""" - - output_types = [rdf_protodict.AttributedDict] - supported_artifacts = 
["LinuxProcSysHardeningSettings"] - - def _Parse(self, pathspec, file_obj): - # Remove /proc/sys - key = pathspec.path.replace("/proc/sys/", "", 1) - key = key.replace("/", "_") - value = file_obj.read().decode("utf-8").split() - if len(value) == 1: - value = value[0] - return key, value - - def ParseFiles( - self, - knowledge_base: rdf_client.KnowledgeBase, - pathspecs: Iterable[rdf_paths.PathSpec], - filedescs: Iterable[IO[bytes]], - ) -> Iterator[rdf_protodict.AttributedDict]: - del knowledge_base # Unused. - - config = {} - for pathspec, file_obj in zip(pathspecs, filedescs): - k, v = self._Parse(pathspec, file_obj) - config[k] = v - yield rdf_protodict.AttributedDict(config) - - -class SysctlCmdParser(parser.CommandParser): - """Parser for sysctl -a output.""" - - output_types = [rdf_protodict.AttributedDict] - supported_artifacts = ["LinuxSysctlCmd"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lexer = config_file.KeyValueParser() - - def Parse(self, cmd, args, stdout, stderr, return_val, knowledge_base): - """Parse the sysctl output.""" - _ = stderr, args, knowledge_base # Unused. - self.CheckReturn(cmd, return_val) - result = rdf_protodict.AttributedDict() - # The KeyValueParser generates an ordered dict by default. The sysctl vals - # aren't ordering dependent, but there's no need to un-order it. - for k, v in self.lexer.ParseToOrderedDict(stdout).items(): - key = k.replace(".", "_") - if len(v) == 1: - v = v[0] - result[key] = v - return [result] diff --git a/grr/core/grr_response_core/lib/parsers/linux_sysctl_parser_test.py b/grr/core/grr_response_core/lib/parsers/linux_sysctl_parser_test.py deleted file mode 100644 index df5a67bee0..0000000000 --- a/grr/core/grr_response_core/lib/parsers/linux_sysctl_parser_test.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python -"""Unit test for the linux sysctl parser.""" - - -import io - -from absl import app - -from grr_response_core.lib.parsers import linux_sysctl_parser -from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_core.lib.rdfvalues import protodict as rdf_protodict -from grr.test_lib import test_lib - - -class ProcSysParserTest(test_lib.GRRBaseTest): - """Test parsing of linux /proc/sys data.""" - - def _GenTestData(self, paths, data): - pathspecs = [] - files = [] - for path in paths: - p = rdf_paths.PathSpec(path=path) - pathspecs.append(p) - for val in data: - files.append(io.BytesIO(val.encode("utf-8"))) - return pathspecs, files - - def testParseSysctl(self): - """Sysctl entries return an underscore separated key and 0+ values.""" - parser = linux_sysctl_parser.ProcSysParser() - paths = ["/proc/sys/net/ipv4/ip_forward", "/proc/sys/kernel/printk"] - vals = ["0", "3 4 1 3"] - pathspecs, files = self._GenTestData(paths, vals) - results = list(parser.ParseFiles(None, pathspecs, files)) - self.assertLen(results, 1) - self.assertIsInstance(results[0], rdf_protodict.AttributedDict) - self.assertEqual("0", results[0].net_ipv4_ip_forward) - self.assertEqual(["3", "4", "1", "3"], results[0].kernel_printk) - - -class SysctlCmdParserTest(test_lib.GRRBaseTest): - """Test parsing of linux sysctl -a command output.""" - - def testParseSysctl(self): - """Sysctl entries return an underscore separated key and 0+ values.""" - content = """ - kernel.printk = 3 4 1 3 - net.ipv4.ip_forward = 0 - """ - parser = linux_sysctl_parser.SysctlCmdParser() - results = parser.Parse("/sbin/sysctl", ["-a"], content, "", 0, None) - self.assertLen(results, 1) - 
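ProcSysParser reduces each collected /proc/sys file to an underscore-separated key and splits multi-valued settings on whitespace, which is exactly what these assertions check. A tiny sketch of that normalisation step (the helper name is made up):

def proc_sys_entry(path, raw_value):
    # Mirror the key/value normalisation performed by ProcSysParser._Parse.
    key = path.replace("/proc/sys/", "", 1).replace("/", "_")
    values = raw_value.split()
    return key, values[0] if len(values) == 1 else values

print(proc_sys_entry("/proc/sys/net/ipv4/ip_forward", "0"))
# ('net_ipv4_ip_forward', '0')
print(proc_sys_entry("/proc/sys/kernel/printk", "3 4 1 3"))
# ('kernel_printk', ['3', '4', '1', '3'])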
self.assertIsInstance(results[0], rdf_protodict.AttributedDict) - self.assertEqual("0", results[0].net_ipv4_ip_forward) - self.assertEqual(["3", "4", "1", "3"], results[0].kernel_printk) - - -def main(args): - test_lib.main(args) - - -if __name__ == "__main__": - app.run(main) diff --git a/grr/core/grr_response_core/lib/parsers/local/__init__.py b/grr/core/grr_response_core/lib/parsers/local/__init__.py deleted file mode 100644 index 2cdaee1f35..0000000000 --- a/grr/core/grr_response_core/lib/parsers/local/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -#!/usr/bin/env python -"""This directory contains local site-specific parser implementations.""" diff --git a/grr/core/grr_response_core/lib/parsers/osx_file_parser.py b/grr/core/grr_response_core/lib/parsers/osx_file_parser.py index 770705fba0..1e550d86c5 100644 --- a/grr/core/grr_response_core/lib/parsers/osx_file_parser.py +++ b/grr/core/grr_response_core/lib/parsers/osx_file_parser.py @@ -1,57 +1,22 @@ #!/usr/bin/env python """Simple parsers for OS X files.""" - import datetime import io -import os import plistlib -import stat -from typing import IO -from typing import Iterable -from typing import Iterator - +from typing import IO, Iterator from grr_response_core.lib import parser from grr_response_core.lib import parsers -from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr_response_core.lib.rdfvalues import plist as rdf_plist -class OSXUsersParser(parsers.MultiResponseParser[rdf_client.User]): - """Parser for Glob of /Users/*.""" - - output_types = [rdf_client.User] - - # TODO: The parser has to be invoked explicitly, we should not - # relly on magic parsing anymore. - supported_artifacts = [] - - _ignore_users = ["Shared"] - - def ParseResponses( - self, - knowledge_base: rdf_client.KnowledgeBase, - responses: Iterable[rdfvalue.RDFValue], - ) -> Iterator[rdf_client.User]: - for response in responses: - if not isinstance(response, rdf_client_fs.StatEntry): - raise TypeError(f"Unexpected response type: `{type(response)}`") - - # TODO: `st_mode` has to be an `int`, not `StatMode`. - if stat.S_ISDIR(int(response.st_mode)): - homedir = response.pathspec.path - username = os.path.basename(homedir) - if username not in self._ignore_users: - yield rdf_client.User(username=username, homedir=homedir) - - # TODO(hanuszczak): Why is a command parser in a file called `osx_file_parsers`? 
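The OSXUsersParser removed above derived user accounts from StatEntry results under /Users/*: directories become users, the shared profile is ignored, and plain files are skipped by the S_ISDIR check. A condensed sketch of that filtering, using plain (st_mode, path) tuples in place of the RDF types:

import os
import stat

IGNORED_USERS = {"Shared"}

def users_from_stats(entries):
    # Yield (username, homedir) for every /Users/* directory that is not
    # on the ignore list, the way the removed parser did.
    for st_mode, path in entries:
        if stat.S_ISDIR(st_mode) and os.path.basename(path) not in IGNORED_USERS:
            yield os.path.basename(path), path

entries = [
    (0o040755, "/Users/user1"),       # directory -> user
    (0o040755, "/Users/Shared"),      # directory, but ignored
    (0o100644, "/Users/.localized"),  # regular file, skipped
]
print(list(users_from_stats(entries)))
# [('user1', '/Users/user1')]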
class OSXSPHardwareDataTypeParser(parser.CommandParser): """Parser for the Hardware Data from System Profiler.""" + output_types = [rdf_client.HardwareInfo] supported_artifacts = ["OSXSPHardwareDataType"] @@ -76,7 +41,8 @@ def Parse(self, cmd, args, stdout, stderr, return_val, knowledge_base): yield rdf_client.HardwareInfo( serial_number=serial_number, bios_version=bios_version, - system_product_name=system_product_name) + system_product_name=system_product_name, + ) class OSXLaunchdPlistParser(parsers.SingleFileParser[rdf_plist.LaunchdPlist]): @@ -84,7 +50,8 @@ class OSXLaunchdPlistParser(parsers.SingleFileParser[rdf_plist.LaunchdPlist]): output_types = [rdf_plist.LaunchdPlist] supported_artifacts = [ - "MacOSLaunchAgentsPlistFile", "MacOSLaunchDaemonsPlistFile" + "MacOSLaunchAgentsPlistFile", + "MacOSLaunchDaemonsPlistFile", ] def ParseFile( @@ -98,18 +65,44 @@ def ParseFile( kwargs = {"path": pathspec.last.path} direct_copy_items = [ - "Label", "Disabled", "UserName", "GroupName", "Program", - "StandardInPath", "StandardOutPath", "StandardErrorPath", - "LimitLoadToSessionType", "EnableGlobbing", "EnableTransactions", - "OnDemand", "RunAtLoad", "RootDirectory", "WorkingDirectory", "Umask", - "TimeOut", "ExitTimeOut", "ThrottleInterval", "InitGroups", - "StartOnMount", "StartInterval", "Debug", "WaitForDebugger", "Nice", - "ProcessType", "AbandonProcessGroup", "LowPriorityIO", "LaunchOnlyOnce" + "Label", + "Disabled", + "UserName", + "GroupName", + "Program", + "StandardInPath", + "StandardOutPath", + "StandardErrorPath", + "LimitLoadToSessionType", + "EnableGlobbing", + "EnableTransactions", + "OnDemand", + "RunAtLoad", + "RootDirectory", + "WorkingDirectory", + "Umask", + "TimeOut", + "ExitTimeOut", + "ThrottleInterval", + "InitGroups", + "StartOnMount", + "StartInterval", + "Debug", + "WaitForDebugger", + "Nice", + "ProcessType", + "AbandonProcessGroup", + "LowPriorityIO", + "LaunchOnlyOnce", ] string_array_items = [ - "LimitLoadToHosts", "LimitLoadFromHosts", "LimitLoadToSessionType", - "ProgramArguments", "WatchPaths", "QueueDirectories" + "LimitLoadToHosts", + "LimitLoadFromHosts", + "LimitLoadToSessionType", + "ProgramArguments", + "WatchPaths", + "QueueDirectories", ] flag_only_items = ["SoftResourceLimits", "HardResourceLimits", "Sockets"] @@ -142,7 +135,8 @@ def ParseFile( if plist.get("inetdCompatibility") is not None: kwargs["inetdCompatibilityWait"] = plist.get("inetdCompatibility").get( - "Wait") + "Wait" + ) keepalive = plist.get("KeepAlive") if isinstance(keepalive, bool) or keepalive is None: @@ -158,7 +152,9 @@ def ParseFile( for pathstate in pathstates: keepalivedict["PathState"].append( rdf_plist.PlistBoolDictEntry( - name=pathstate, value=pathstates[pathstate])) + name=pathstate, value=pathstates[pathstate] + ) + ) otherjobs = keepalive.get("OtherJobEnabled") if otherjobs is not None: @@ -166,7 +162,9 @@ def ParseFile( for otherjob in otherjobs: keepalivedict["OtherJobEnabled"].append( rdf_plist.PlistBoolDictEntry( - name=otherjob, value=otherjobs[otherjob])) + name=otherjob, value=otherjobs[otherjob] + ) + ) kwargs["KeepAliveDict"] = rdf_plist.LaunchdKeepAlive(**keepalivedict) envvars = plist.get("EnvironmentVariables") @@ -175,7 +173,9 @@ def ParseFile( for envvar in envvars: kwargs["EnvironmentVariables"].append( rdf_plist.PlistStringDictEntry( - name=envvar, value=str(envvars[envvar]))) + name=envvar, value=str(envvars[envvar]) + ) + ) startcalendarinterval = plist.get("StartCalendarInterval") if startcalendarinterval is not None: @@ -186,7 +186,8 @@ def 
ParseFile( Hour=startcalendarinterval.get("Hour"), Day=startcalendarinterval.get("Day"), Weekday=startcalendarinterval.get("Weekday"), - Month=startcalendarinterval.get("Month")) + Month=startcalendarinterval.get("Month"), + ) ] else: kwargs["StartCalendarInterval"] = [] @@ -197,13 +198,16 @@ def ParseFile( Hour=entry.get("Hour"), Day=entry.get("Day"), Weekday=entry.get("Weekday"), - Month=entry.get("Month"))) + Month=entry.get("Month"), + ) + ) yield rdf_plist.LaunchdPlist(**kwargs) class OSXInstallHistoryPlistParser( - parsers.SingleFileParser[rdf_client.SoftwarePackages]): + parsers.SingleFileParser[rdf_client.SoftwarePackages] +): """Parse InstallHistory plist files into SoftwarePackage objects.""" output_types = [rdf_client.SoftwarePackages] @@ -225,7 +229,8 @@ def ParseFile( if not isinstance(plist, list): raise parsers.ParseError( - "InstallHistory plist is a '%s', expecting a list" % type(plist)) + "InstallHistory plist is a '%s', expecting a list" % type(plist) + ) packages = [] for sw in plist: @@ -235,7 +240,9 @@ def ParseFile( version=sw.get("displayVersion"), description=",".join(sw.get("packageIdentifiers", [])), # TODO(hanuszczak): make installed_on an RDFDatetime - installed_on=_DateToEpoch(sw.get("date")))) + installed_on=_DateToEpoch(sw.get("date")), + ) + ) if packages: yield rdf_client.SoftwarePackages(packages=packages) diff --git a/grr/core/grr_response_core/lib/parsers/osx_file_parser_test.py b/grr/core/grr_response_core/lib/parsers/osx_file_parser_test.py index 76a589ed56..19b845ed3a 100644 --- a/grr/core/grr_response_core/lib/parsers/osx_file_parser_test.py +++ b/grr/core/grr_response_core/lib/parsers/osx_file_parser_test.py @@ -3,15 +3,14 @@ import io import os - import plistlib + from absl import app from grr_response_core.lib import parsers from grr_response_core.lib.parsers import osx_file_parser from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import client_action as rdf_client_action -from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr.test_lib import test_lib from grr.test_lib import time @@ -20,30 +19,6 @@ class TestOSXFileParsing(test_lib.GRRBaseTest): """Test parsing of OSX files.""" - def testOSXUsersParser(self): - """Ensure we can extract users from a passwd file.""" - paths = ["/Users/user1", "/Users/user2", "/Users/Shared"] - statentries = [] - for path in paths: - statentries.append( - rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec( - path=path, pathtype=rdf_paths.PathSpec.PathType.OS), - st_mode=16877)) - - statentries.append( - rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec( - path="/Users/.localized", - pathtype=rdf_paths.PathSpec.PathType.OS), - st_mode=33261)) - - parser = osx_file_parser.OSXUsersParser() - out = list(parser.ParseResponses(rdf_client.KnowledgeBase(), statentries)) - self.assertCountEqual([x.username for x in out], ["user1", "user2"]) - self.assertCountEqual([x.homedir for x in out], - ["/Users/user1", "/Users/user2"]) - def testOSXSPHardwareDataTypeParserInvalidInput(self): parser = osx_file_parser.OSXSPHardwareDataTypeParser() @@ -62,11 +37,19 @@ def testOSXSPHardwareDataTypeParserInvalidInput(self): def testOSXSPHardwareDataTypeParser(self): parser = osx_file_parser.OSXSPHardwareDataTypeParser() - content = open(os.path.join(self.base_path, "system_profiler.xml"), - "rb").read() + content = open( + os.path.join(self.base_path, "system_profiler.xml"), "rb" + ).read() 
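OSXInstallHistoryPlistParser expects the install-history plist to decode to a list of dictionaries and copies displayName, displayVersion, packageIdentifiers and date into SoftwarePackage fields. A bare-bones sketch of the same extraction with plistlib; the default path and the output dict layout are illustrative only:

import plistlib

def read_install_history(path="/Library/Receipts/InstallHistory.plist"):
    # Return a simple dict for each recorded install event.
    with open(path, "rb") as fd:
        entries = plistlib.load(fd)
    if not isinstance(entries, list):
        raise ValueError("expected a list, got %r" % type(entries))
    return [
        {
            "name": entry.get("displayName"),
            "version": entry.get("displayVersion"),
            "description": ",".join(entry.get("packageIdentifiers", [])),
            "installed_on": entry.get("date"),  # plistlib yields datetime objects
        }
        for entry in entries
    ]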
result = list( - parser.Parse("/usr/sbin/system_profiler", ["SPHardwareDataType -xml"], - content, "", 0, None)) + parser.Parse( + "/usr/sbin/system_profiler", + ["SPHardwareDataType -xml"], + content, + "", + 0, + None, + ) + ) self.assertEqual(result[0].serial_number, "C02JQ0F5F6L9") self.assertEqual(result[0].bios_version, "MBP101.00EE.B02") self.assertEqual(result[0].system_product_name, "MacBookPro10,1") @@ -83,10 +66,13 @@ def testOSXLaunchdPlistParser(self): for result in results: self.assertEqual(result.Label, "com.google.code.grr") - self.assertCountEqual(result.ProgramArguments, [ - "/usr/lib/grr/grr_3.0.0.5_amd64/grr", - "--config=/usr/lib/grr/grr_3.0.0.5_amd64/grr.yaml" - ]) + self.assertCountEqual( + result.ProgramArguments, + [ + "/usr/lib/grr/grr_3.0.0.5_amd64/grr", + "--config=/usr/lib/grr/grr_3.0.0.5_amd64/grr.yaml", + ], + ) def testOSXInstallHistoryPlistParserInvalidInput(self): parser = osx_file_parser.OSXInstallHistoryPlistParser() @@ -122,12 +108,16 @@ def testOSXInstallHistoryPlistParser(self): "com.eset.esetNod32Antivirus.pkgid.pkg," "com.eset.esetNod32Antivirus.com.eset.esets_daemon.pkg," "com.eset.esetNod32Antivirus.esetsbkp.pkg," - "com.eset.esetNod32Antivirus.esets_kac_64_106.pkg") + "com.eset.esetNod32Antivirus.esets_kac_64_106.pkg", + ) self.assertEqual( packages[0].installed_on, - time.HumanReadableToMicrosecondsSinceEpoch("2017-07-20T18:40:22Z")) - self.assertEqual(packages[0].install_state, - rdf_client.SoftwarePackage.InstallState.INSTALLED) + time.HumanReadableToMicrosecondsSinceEpoch("2017-07-20T18:40:22Z"), + ) + self.assertEqual( + packages[0].install_state, + rdf_client.SoftwarePackage.InstallState.INSTALLED, + ) # old grr agent self.assertEqual(packages[1].name, "grr") @@ -135,9 +125,12 @@ def testOSXInstallHistoryPlistParser(self): self.assertEqual(packages[1].description, "com.google.code.grr.grr_3.2.1.0") self.assertEqual( packages[1].installed_on, - time.HumanReadableToMicrosecondsSinceEpoch("2018-03-13T05:39:17Z")) - self.assertEqual(packages[1].install_state, - rdf_client.SoftwarePackage.InstallState.INSTALLED) + time.HumanReadableToMicrosecondsSinceEpoch("2018-03-13T05:39:17Z"), + ) + self.assertEqual( + packages[1].install_state, + rdf_client.SoftwarePackage.InstallState.INSTALLED, + ) # new grr agent self.assertEqual(packages[2].name, "grr") @@ -145,24 +138,32 @@ def testOSXInstallHistoryPlistParser(self): self.assertEqual(packages[2].description, "com.google.code.grr.grr_3.2.3.2") self.assertEqual( packages[2].installed_on, - time.HumanReadableToMicrosecondsSinceEpoch("2018-08-07T16:07:10Z")) - self.assertEqual(packages[2].install_state, - rdf_client.SoftwarePackage.InstallState.INSTALLED) + time.HumanReadableToMicrosecondsSinceEpoch("2018-08-07T16:07:10Z"), + ) + self.assertEqual( + packages[2].install_state, + rdf_client.SoftwarePackage.InstallState.INSTALLED, + ) # Sierra self.assertEqual(packages[3].name, "macOS Sierra Update") self.assertEqual(packages[3].version, "10.12.6") self.assertEqual( - packages[3].description, "com.apple.pkg.update.os.10.12.6Patch.16G29," + packages[3].description, + "com.apple.pkg.update.os.10.12.6Patch.16G29," "com.apple.pkg.FirmwareUpdate," "com.apple.update.fullbundleupdate.16G29," - "com.apple.pkg.EmbeddedOSFirmware") + "com.apple.pkg.EmbeddedOSFirmware", + ) # echo $(( $(date --date="2017-07-25T04:26:10Z" +"%s") * 1000000)) self.assertEqual( packages[3].installed_on, - time.HumanReadableToMicrosecondsSinceEpoch("2017-07-25T04:26:10Z")) - self.assertEqual(packages[3].install_state, - 
rdf_client.SoftwarePackage.InstallState.INSTALLED) + time.HumanReadableToMicrosecondsSinceEpoch("2017-07-25T04:26:10Z"), + ) + self.assertEqual( + packages[3].install_state, + rdf_client.SoftwarePackage.InstallState.INSTALLED, + ) # MacOS 11.2 self.assertEqual(packages[4].name, "macOS 11.2") @@ -170,9 +171,12 @@ def testOSXInstallHistoryPlistParser(self): self.assertEqual(packages[4].description, "") self.assertEqual( packages[4].installed_on, - time.HumanReadableToMicrosecondsSinceEpoch("2021-02-09T22:34:52Z")) - self.assertEqual(packages[4].install_state, - rdf_client.SoftwarePackage.InstallState.INSTALLED) + time.HumanReadableToMicrosecondsSinceEpoch("2021-02-09T22:34:52Z"), + ) + self.assertEqual( + packages[4].install_state, + rdf_client.SoftwarePackage.InstallState.INSTALLED, + ) def main(argv): diff --git a/grr/core/grr_response_core/lib/parsers/osx_launchd.py b/grr/core/grr_response_core/lib/parsers/osx_launchd.py index 4a1e4ea245..4d4e5d8bf2 100644 --- a/grr/core/grr_response_core/lib/parsers/osx_launchd.py +++ b/grr/core/grr_response_core/lib/parsers/osx_launchd.py @@ -3,7 +3,6 @@ """Parser for OSX launchd jobs.""" - import re from typing import Iterator @@ -14,7 +13,7 @@ from grr_response_core.lib.rdfvalues import standard as rdf_standard -class OSXLaunchdJobDict(object): +class OSXLaunchdJobDict: """Cleanup launchd jobs reported by the service management framework. Exclude some rubbish like logged requests that aren't real jobs (see @@ -34,7 +33,7 @@ def __init__(self, launchdjobs): Args: launchdjobs: NSCFArray of NSCFDictionarys containing launchd job data from - the ServiceManagement framework. + the ServiceManagement framework. """ self.launchdjobs = launchdjobs @@ -54,6 +53,7 @@ def FilterItem(self, launchditem): Args: launchditem: job NSCFDictionary + Returns: True if the item should be filtered (dropped) """ @@ -64,8 +64,10 @@ def FilterItem(self, launchditem): class DarwinPersistenceMechanismsParser( - parsers.SingleResponseParser[rdf_standard.PersistenceFile]): + parsers.SingleResponseParser[rdf_standard.PersistenceFile] +): """Turn various persistence objects into PersistenceFiles.""" + output_types = [rdf_standard.PersistenceFile] supported_artifacts = ["DarwinPersistenceMechanisms"] @@ -80,10 +82,12 @@ def ParseResponse( if isinstance(response, rdf_client.OSXServiceInformation): if response.program: pathspec = rdf_paths.PathSpec( - path=response.program, pathtype=rdf_paths.PathSpec.PathType.UNSET) + path=response.program, pathtype=rdf_paths.PathSpec.PathType.UNSET + ) elif response.args: pathspec = rdf_paths.PathSpec( - path=response.args[0], pathtype=rdf_paths.PathSpec.PathType.UNSET) + path=response.args[0], pathtype=rdf_paths.PathSpec.PathType.UNSET + ) if pathspec is not None: yield rdf_standard.PersistenceFile(pathspec=pathspec) diff --git a/grr/core/grr_response_core/lib/parsers/osx_launchd_test.py b/grr/core/grr_response_core/lib/parsers/osx_launchd_test.py index 69feb38100..351d63622b 100644 --- a/grr/core/grr_response_core/lib/parsers/osx_launchd_test.py +++ b/grr/core/grr_response_core/lib/parsers/osx_launchd_test.py @@ -1,8 +1,6 @@ #!/usr/bin/env python """Tests for grr.parsers.osx_launchd.""" - - from absl import app from grr_response_core.lib.parsers import osx_launchd @@ -42,7 +40,8 @@ class DarwinPersistenceMechanismsParserTest(flow_test_lib.FlowTestsBaseclass): def testParse(self): parser = osx_launchd.DarwinPersistenceMechanismsParser() serv_info = rdf_client.OSXServiceInformation( - label="blah", args=["/blah/test", "-v"]) + label="blah", 
args=["/blah/test", "-v"] + ) results = list(parser.ParseResponse(rdf_client.KnowledgeBase(), serv_info)) self.assertEqual(results[0].pathspec.path, "/blah/test") diff --git a/grr/core/grr_response_core/lib/parsers/parsers_test.py b/grr/core/grr_response_core/lib/parsers/parsers_test.py index 215b2a28f5..f37d6630b5 100644 --- a/grr/core/grr_response_core/lib/parsers/parsers_test.py +++ b/grr/core/grr_response_core/lib/parsers/parsers_test.py @@ -1,8 +1,5 @@ #!/usr/bin/env python - -from typing import IO -from typing import Iterable -from typing import Iterator +from typing import IO, Iterable, Iterator from unittest import mock from absl.testing import absltest @@ -17,8 +14,11 @@ class ArtifactParserFactoryTest(absltest.TestCase): - @mock.patch.object(parsers, "SINGLE_RESPONSE_PARSER_FACTORY", - factory.Factory(parsers.SingleResponseParser)) + @mock.patch.object( + parsers, + "SINGLE_RESPONSE_PARSER_FACTORY", + factory.Factory(parsers.SingleResponseParser), + ) def testSingleResponseParsers(self): class FooParser(parsers.SingleResponseParser[None]): @@ -70,8 +70,11 @@ def ParseResponse( thud_parsers = thud_factory.SingleResponseParsers() self.assertCountEqual(map(type, thud_parsers), [BarParser, BazParser]) - @mock.patch.object(parsers, "MULTI_RESPONSE_PARSER_FACTORY", - factory.Factory(parsers.MultiResponseParser)) + @mock.patch.object( + parsers, + "MULTI_RESPONSE_PARSER_FACTORY", + factory.Factory(parsers.MultiResponseParser), + ) def testMultiResponseParsers(self): class FooParser(parsers.MultiResponseParser[None]): @@ -107,8 +110,11 @@ def ParseResponses( bar_parsers = bar_factory.MultiResponseParsers() self.assertCountEqual(map(type, bar_parsers), [BarParser]) - @mock.patch.object(parsers, "SINGLE_FILE_PARSER_FACTORY", - factory.Factory(parsers.SingleFileParser)) + @mock.patch.object( + parsers, + "SINGLE_FILE_PARSER_FACTORY", + factory.Factory(parsers.SingleFileParser), + ) def testSingleFileParsers(self): class FooParser(parsers.SingleFileParser): @@ -128,8 +134,11 @@ def ParseFile(self, knowledge_base, pathspec, filedesc): baz_parsers = baz_factory.SingleFileParsers() self.assertCountEqual(map(type, baz_parsers), []) - @mock.patch.object(parsers, "MULTI_FILE_PARSER_FACTORY", - factory.Factory(parsers.MultiFileParser)) + @mock.patch.object( + parsers, + "MULTI_FILE_PARSER_FACTORY", + factory.Factory(parsers.MultiFileParser), + ) def testMultiFileParsers(self): class FooParser(parsers.MultiFileParser[None]): @@ -171,10 +180,16 @@ def ParseFiles( thud_parsers = thud_factory.MultiFileParsers() self.assertCountEqual(map(type, thud_parsers), [BarParser]) - @mock.patch.object(parsers, "SINGLE_FILE_PARSER_FACTORY", - factory.Factory(parsers.SingleFileParser)) - @mock.patch.object(parsers, "MULTI_RESPONSE_PARSER_FACTORY", - factory.Factory(parsers.MultiResponseParser)) + @mock.patch.object( + parsers, + "SINGLE_FILE_PARSER_FACTORY", + factory.Factory(parsers.SingleFileParser), + ) + @mock.patch.object( + parsers, + "MULTI_RESPONSE_PARSER_FACTORY", + factory.Factory(parsers.MultiResponseParser), + ) def testAllParsers(self): class FooParser(parsers.SingleFileParser[None]): diff --git a/grr/core/grr_response_core/lib/parsers/parsers_test_lib.py b/grr/core/grr_response_core/lib/parsers/parsers_test_lib.py index 637eabe627..7a2c525bab 100644 --- a/grr/core/grr_response_core/lib/parsers/parsers_test_lib.py +++ b/grr/core/grr_response_core/lib/parsers/parsers_test_lib.py @@ -1,14 +1,13 @@ #!/usr/bin/env python """Parser testing lib.""" - import io from grr_response_core.lib.rdfvalues import 
client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import paths as rdf_paths -def GenInit(svc, desc, start=("2", "3", "4", "5"), stop=("1")): +def GenInit(svc, desc, start=("2", "3", "4", "5"), stop="1"): """Generate init file.""" insserv = r""" $local_fs +umountfs @@ -28,7 +27,7 @@ def GenInit(svc, desc, start=("2", "3", "4", "5"), stop=("1")): """ % (svc, " ".join(start), " ".join(stop), desc) return { "/etc/insserv.conf": insserv.encode("utf-8"), - "/etc/init.d/%s" % svc: tmpl.encode("utf-8") + "/etc/init.d/%s" % svc: tmpl.encode("utf-8"), } diff --git a/grr/core/grr_response_core/lib/parsers/registry_init.py b/grr/core/grr_response_core/lib/parsers/registry_init.py deleted file mode 100644 index 1db69dce90..0000000000 --- a/grr/core/grr_response_core/lib/parsers/registry_init.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python -"""Loads all the parsers so they are visible in the registry.""" - -# pylint: disable=g-import-not-at-top -# pylint: disable=unused-import -from grr_response_core.lib.parsers import config_file -from grr_response_core.lib.parsers import cron_file_parser -from grr_response_core.lib.parsers import ie_history -from grr_response_core.lib.parsers import linux_cmd_parser -from grr_response_core.lib.parsers import linux_file_parser -from grr_response_core.lib.parsers import linux_pam_parser -from grr_response_core.lib.parsers import linux_release_parser -from grr_response_core.lib.parsers import linux_service_parser -from grr_response_core.lib.parsers import linux_software_parser -from grr_response_core.lib.parsers import linux_sysctl_parser -from grr_response_core.lib.parsers import local -from grr_response_core.lib.parsers import osx_file_parser -from grr_response_core.lib.parsers import osx_launchd -from grr_response_core.lib.parsers import windows_persistence -from grr_response_core.lib.parsers import windows_registry_parser -from grr_response_core.lib.parsers import wmi_parser diff --git a/grr/core/grr_response_core/lib/parsers/windows_persistence.py b/grr/core/grr_response_core/lib/parsers/windows_persistence.py index 5f92a5372e..d8af6767bd 100644 --- a/grr/core/grr_response_core/lib/parsers/windows_persistence.py +++ b/grr/core/grr_response_core/lib/parsers/windows_persistence.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Parse various Windows persistence mechanisms into PersistenceFiles.""" - from typing import Iterator from grr_response_core.lib import artifact_utils @@ -15,8 +14,10 @@ class WindowsPersistenceMechanismsParser( - parsers.SingleResponseParser[rdf_standard.PersistenceFile]): + parsers.SingleResponseParser[rdf_standard.PersistenceFile] +): """Turn various persistence objects into PersistenceFiles.""" + output_types = [rdf_standard.PersistenceFile] supported_artifacts = ["WindowsPersistenceMechanisms"] # Required for environment variable expansion @@ -26,8 +27,9 @@ def _GetFilePaths(self, path, kb): """Guess windows filenames from a commandline string.""" environ_vars = artifact_utils.GetWindowsEnvironmentVariablesMap(kb) - path_guesses = path_detection_windows.DetectExecutablePaths([path], - environ_vars) + path_guesses = path_detection_windows.DetectExecutablePaths( + [path], environ_vars + ) if not path_guesses: # TODO(user): yield a ParserAnomaly object @@ -35,7 +37,8 @@ def _GetFilePaths(self, path, kb): return [ rdf_paths.PathSpec( - path=path, pathtype=rdf_paths.PathSpec.PathType.UNSET) + path=path, pathtype=rdf_paths.PathSpec.PathType.UNSET + ) for path in path_guesses ] @@ -53,10 +56,12 @@ def ParseResponse( elif 
response.HasField("image_path"): pathspecs = self._GetFilePaths(response.image_path, knowledge_base) - if (isinstance(response, rdf_client_fs.StatEntry) and - response.HasField("registry_type")): - pathspecs = self._GetFilePaths(response.registry_data.string, - knowledge_base) + if isinstance(response, rdf_client_fs.StatEntry) and response.HasField( + "registry_type" + ): + pathspecs = self._GetFilePaths( + response.registry_data.string, knowledge_base + ) for pathspec in pathspecs: yield rdf_standard.PersistenceFile(pathspec=pathspec) diff --git a/grr/core/grr_response_core/lib/parsers/windows_persistence_test.py b/grr/core/grr_response_core/lib/parsers/windows_persistence_test.py index 441ba553c4..a04b4711c2 100644 --- a/grr/core/grr_response_core/lib/parsers/windows_persistence_test.py +++ b/grr/core/grr_response_core/lib/parsers/windows_persistence_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for grr.parsers.windows_persistence.""" - from absl import app from grr_response_core.lib.parsers import windows_persistence @@ -17,22 +16,26 @@ class WindowsPersistenceMechanismsParserTest(flow_test_lib.FlowTestsBaseclass): def testParse(self): parser = windows_persistence.WindowsPersistenceMechanismsParser() - path = (r"HKEY_LOCAL_MACHINE\Software\Microsoft\Windows\CurrentVersion" - r"\Run\test") + path = ( + r"HKEY_LOCAL_MACHINE\Software\Microsoft\Windows\CurrentVersion" + r"\Run\test" + ) pathspec = rdf_paths.PathSpec( - path=path, pathtype=rdf_paths.PathSpec.PathType.REGISTRY) + path=path, pathtype=rdf_paths.PathSpec.PathType.REGISTRY + ) reg_data = "C:\\blah\\some.exe /v" reg_type = rdf_client_fs.StatEntry.RegistryType.REG_SZ stat = rdf_client_fs.StatEntry( pathspec=pathspec, registry_type=reg_type, - registry_data=rdf_protodict.DataBlob(string=reg_data)) + registry_data=rdf_protodict.DataBlob(string=reg_data), + ) persistence = [stat] image_paths = [ "system32\\drivers\\ACPI.sys", "%systemroot%\\system32\\svchost.exe -k netsvcs", - "\\SystemRoot\\system32\\drivers\\acpipmi.sys" + "\\SystemRoot\\system32\\drivers\\acpipmi.sys", ] reg_key = "HKEY_LOCAL_MACHINE/SYSTEM/CurrentControlSet/services/AcpiPmi" for path in image_paths: @@ -40,16 +43,18 @@ def testParse(self): name="blah", display_name="GRRservice", image_path=path, - registry_key=reg_key) + registry_key=reg_key, + ) persistence.append(serv_info) knowledge_base = rdf_client.KnowledgeBase() knowledge_base.environ_systemroot = "C:\\Windows" expected = [ - "C:\\blah\\some.exe", "C:\\Windows\\system32\\drivers\\ACPI.sys", + "C:\\blah\\some.exe", + "C:\\Windows\\system32\\drivers\\ACPI.sys", "C:\\Windows\\system32\\svchost.exe", - "C:\\Windows\\system32\\drivers\\acpipmi.sys" + "C:\\Windows\\system32\\drivers\\acpipmi.sys", ] for index, item in enumerate(persistence): diff --git a/grr/core/grr_response_core/lib/parsers/windows_registry_parser.py b/grr/core/grr_response_core/lib/parsers/windows_registry_parser.py index 38ce65ac73..b402429e57 100644 --- a/grr/core/grr_response_core/lib/parsers/windows_registry_parser.py +++ b/grr/core/grr_response_core/lib/parsers/windows_registry_parser.py @@ -1,22 +1,15 @@ #!/usr/bin/env python """Simple parsers for registry keys and values.""" - -import logging -import os import re -from typing import Iterable -from typing import Iterator from grr_response_core.lib import artifact_utils from grr_response_core.lib import parser from grr_response_core.lib import parsers from grr_response_core.lib import rdfvalue -from grr_response_core.lib import type_info from grr_response_core.lib import utils 
from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs -from grr_response_core.lib.util import precondition SID_RE = re.compile(r"^S-\d-\d+-(\d+-){1,14}\d+$") @@ -27,7 +20,8 @@ class CurrentControlSetKBParser(parser.RegistryValueParser): output_types = [rdfvalue.RDFString] # TODO(user): remove CurrentControlSet after artifact cleanup. supported_artifacts = [ - "WindowsRegistryCurrentControlSet", "CurrentControlSet" + "WindowsRegistryCurrentControlSet", + "CurrentControlSet", ] def Parse(self, stat, unused_knowledge_base): @@ -35,10 +29,12 @@ def Parse(self, stat, unused_knowledge_base): value = stat.registry_data.GetValue() if not str(value).isdigit() or int(value) > 999 or int(value) < 0: - raise parsers.ParseError("Invalid value for CurrentControlSet key %s" % - value) + raise parsers.ParseError( + "Invalid value for CurrentControlSet key %s" % value + ) yield rdfvalue.RDFString( - "HKEY_LOCAL_MACHINE\\SYSTEM\\ControlSet%03d" % int(value)) + "HKEY_LOCAL_MACHINE\\SYSTEM\\ControlSet%03d" % int(value) + ) class WinEnvironmentParser(parser.RegistryValueParser): @@ -60,7 +56,8 @@ def Parse(self, stat, knowledge_base): if not value: raise parsers.ParseError("Invalid value for key %s" % stat.pathspec.path) value = artifact_utils.ExpandWindowsEnvironmentVariables( - value, knowledge_base) + value, knowledge_base + ) if value: yield rdfvalue.RDFString(value) @@ -84,8 +81,9 @@ def Parse(self, stat, _): if re.match(r"^[A-Za-z]:$", systemdrive): yield rdfvalue.RDFString(systemdrive) else: - raise parsers.ParseError("Bad drive letter for key %s" % - stat.pathspec.path) + raise parsers.ParseError( + "Bad drive letter for key %s" % stat.pathspec.path + ) class WinSystemRootParser(parser.RegistryValueParser): @@ -127,7 +125,8 @@ def Parse(self, stat, knowledge_base): # Provide a default, if the registry value is not available. value = "%SystemDrive%\\Documents and Settings" interpolated_value = artifact_utils.ExpandWindowsEnvironmentVariables( - value, knowledge_base) + value, knowledge_base + ) yield rdfvalue.RDFString(interpolated_value) @@ -141,7 +140,8 @@ class AllUsersProfileEnvironmentVariable(parser.RegistryParser): def Parse(self, stat, knowledge_base): value = stat.registry_data.GetValue() or "All Users" all_users_dir = artifact_utils.ExpandWindowsEnvironmentVariables( - "%ProfilesDirectory%\\" + value, knowledge_base) + "%ProfilesDirectory%\\" + value, knowledge_base + ) yield rdfvalue.RDFString(all_users_dir) @@ -151,6 +151,7 @@ class WinUserSids(parser.RegistryParser): This reads a listing of the profile paths to extract a list of SIDS for users with profiles on a system. """ + output_types = [rdf_client.User] supported_artifacts = ["WindowsRegistryProfiles"] @@ -181,163 +182,6 @@ def Parse(self, stat, knowledge_base): yield kb_user -class WinUserSpecialDirs(parser.RegistryMultiParser): - r"""Parser for extracting special folders from registry. - - Keys will come from HKEY_USERS and will list the Shell Folders and user's - Environment key. We extract each subkey that matches on of our knowledge base - attributes. 
- - Known folder GUIDs: - http://msdn.microsoft.com/en-us/library/windows/desktop/dd378457(v=vs.85).aspx - """ - output_types = [rdf_client.User] - supported_artifacts = ["WindowsUserShellFolders"] - # Required for environment variable expansion - knowledgebase_dependencies = [ - "environ_systemdrive", "environ_systemroot", "users.userprofile" - ] - - key_var_mapping = { - "Shell Folders": { - "{A520A1A4-1780-4FF6-BD18-167343C5AF16}": "localappdata_low", - "Desktop": "desktop", - "AppData": "appdata", - "Local AppData": "localappdata", - "Cookies": "cookies", - "Cache": "internet_cache", - "Recent": "recent", - "Startup": "startup", - "Personal": "personal", - }, - "Environment": { - "TEMP": "temp", - }, - "Volatile Environment": { - "USERDOMAIN": "userdomain", - }, - } - - def ParseMultiple(self, stats, knowledge_base): - """Parse each returned registry value.""" - user_dict = {} - - for stat in stats: - sid_str = stat.pathspec.path.split("/", 3)[2] - if SID_RE.match(sid_str): - if sid_str not in user_dict: - user_dict[sid_str] = rdf_client.User(sid=sid_str) - - if stat.registry_data.GetValue(): - # Look up in the mapping if we can use this entry to populate a user - # attribute, and if so, set it. - reg_key_name = stat.pathspec.Dirname().Basename() - if reg_key_name in self.key_var_mapping: - map_dict = self.key_var_mapping[reg_key_name] - reg_key = stat.pathspec.Basename() - kb_attr = map_dict.get(reg_key) - if kb_attr: - value = artifact_utils.ExpandWindowsEnvironmentVariables( - stat.registry_data.GetValue(), knowledge_base) - value = artifact_utils.ExpandWindowsUserEnvironmentVariables( - value, knowledge_base, sid=sid_str) - user_dict[sid_str].Set(kb_attr, value) - - # Now yield each user we found. - return user_dict.values() - - -class WinServicesParser( - parsers.MultiResponseParser[rdf_client.WindowsServiceInformation]): - """Parser for Windows services values from the registry. - - See service key doco: - http://support.microsoft.com/kb/103000 - """ - - output_types = [rdf_client.WindowsServiceInformation] - supported_artifacts = ["WindowsServices"] - - def __init__(self): - # The key can be "services" or "Services" on different versions of windows. - self.service_re = re.compile( - r".*HKEY_LOCAL_MACHINE/SYSTEM/[^/]+/services/([^/]+)(/(.*))?$", - re.IGNORECASE) - super().__init__() - - def _GetServiceName(self, path): - return self.service_re.match(path).group(1) - - def _GetKeyName(self, path): - key_name = self.service_re.match(path).group(3) - if key_name is None: - return None - return key_name.lower() - - def ParseResponses( - self, - knowledge_base: rdf_client.KnowledgeBase, - responses: Iterable[rdfvalue.RDFValue], - ) -> Iterator[rdf_client.WindowsServiceInformation]: - """Parse Service registry keys and return WindowsServiceInformation.""" - del knowledge_base # Unused. - precondition.AssertIterableType(responses, rdf_client_fs.StatEntry) - - services = {} - field_map = { - "Description": "description", - "DisplayName": "display_name", - "Group": "group_name", - "DriverPackageId": "driver_package_id", - "ErrorControl": "error_control", - "ImagePath": "image_path", - "ObjectName": "object_name", - "Start": "startup_type", - "Type": "service_type", - "Parameters/ServiceDLL": "service_dll" - } - - # Field map key should be converted to lowercase because key acquired - # through self._GetKeyName could have some characters in different - # case than the field map, e.g. ServiceDLL and ServiceDll. 
- field_map = {k.lower(): v for k, v in field_map.items()} - for stat in responses: - - # Ignore subkeys - if not stat.HasField("registry_data"): - continue - - service_name = self._GetServiceName(stat.pathspec.path) - reg_key = os.path.dirname(stat.pathspec.path) - service_info = rdf_client.WindowsServiceInformation( - name=service_name, registry_key=reg_key) - services.setdefault(service_name, service_info) - - key = self._GetKeyName(stat.pathspec.path) - - if key in field_map: - try: - services[service_name].Set(field_map[key], - stat.registry_data.GetValue()) - except type_info.TypeValueError: - - # Flatten multi strings into a simple string - if (stat.registry_type == - rdf_client_fs.StatEntry.RegistryType.REG_MULTI_SZ): - services[service_name].Set( - field_map[key], - utils.SmartUnicode(stat.registry_data.GetValue())) - else: - # Log failures for everything else - # TODO(user): change this to yield a ParserAnomaly object. - dest_type = type(services[service_name].Get(field_map[key])) - logging.debug("Wrong type set for %s:%s, expected %s, got %s", - stat.pathspec.path, stat.registry_data.GetValue(), - dest_type, type(stat.registry_data.GetValue())) - - return services.values() - - class WinTimezoneParser(parser.RegistryValueParser): """Parser for TimeZoneKeyName value.""" @@ -549,36 +393,3 @@ def Parse(self, stat, knowledge_base): "Central Standard Time": "CST6CDT", "Pacific Standard Time": "PST8PDT", } - - -class WindowsRegistryInstalledSoftwareParser(parser.RegistryMultiParser): - """Parser registry uninstall keys yields rdf_client.SoftwarePackages.""" - output_types = [rdf_client.SoftwarePackages] - supported_artifacts = ["WindowsUninstallKeys"] - - def ParseMultiple(self, stats, kb): - del kb # unused - - apps = {} - for stat in stats: - matches = re.search(r"/CurrentVersion/Uninstall/([^/]+)/([^$]+)", - stat.pathspec.path) - if not matches: - continue - app_name, key = matches.groups() - apps.setdefault(app_name, {})[key] = stat.registry_data.GetValue() - - packages = [] - for key, app in apps.items(): - if "DisplayName" not in app: - continue - packages.append( - rdf_client.SoftwarePackage.Installed( - name=app.get("DisplayName"), - description=app.get("Publisher", ""), - version=app.get("DisplayVersion", ""))) - - if packages: - return [rdf_client.SoftwarePackages(packages=packages)] - - return [] diff --git a/grr/core/grr_response_core/lib/parsers/windows_registry_parser_test.py b/grr/core/grr_response_core/lib/parsers/windows_registry_parser_test.py index e0b9f9e0ed..c6d4022cf6 100644 --- a/grr/core/grr_response_core/lib/parsers/windows_registry_parser_test.py +++ b/grr/core/grr_response_core/lib/parsers/windows_registry_parser_test.py @@ -1,11 +1,9 @@ #!/usr/bin/env python """Tests for grr.parsers.windows_registry_parser.""" - from absl import app from grr_response_core.lib.parsers import windows_registry_parser -from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr_response_core.lib.rdfvalues import protodict as rdf_protodict @@ -20,137 +18,31 @@ def _MakeRegStat(self, path, value, registry_type): pathspec = rdf_paths.PathSpec( path=path, path_options=options, - pathtype=rdf_paths.PathSpec.PathType.REGISTRY) + pathtype=rdf_paths.PathSpec.PathType.REGISTRY, + ) if registry_type == rdf_client_fs.StatEntry.RegistryType.REG_MULTI_SZ: reg_data = rdf_protodict.DataBlob( list=rdf_protodict.BlobArray( - 
content=[rdf_protodict.DataBlob(string=value)])) + content=[rdf_protodict.DataBlob(string=value)] + ) + ) else: reg_data = rdf_protodict.DataBlob().SetValue(value) return rdf_client_fs.StatEntry( - pathspec=pathspec, registry_data=reg_data, registry_type=registry_type) - - def testGetServiceName(self): - hklm = "HKEY_LOCAL_MACHINE/SYSTEM/CurrentControlSet/services" - parser = windows_registry_parser.WinServicesParser() - self.assertEqual( - parser._GetServiceName("%s/SomeService/Start" % hklm), "SomeService") - self.assertEqual( - parser._GetServiceName("%s/SomeService/Parameters/ServiceDLL" % hklm), - "SomeService") - - def testWinServicesParser(self): - dword = rdf_client_fs.StatEntry.RegistryType.REG_DWORD_LITTLE_ENDIAN - reg_str = rdf_client_fs.StatEntry.RegistryType.REG_SZ - hklm = "HKEY_LOCAL_MACHINE/SYSTEM/CurrentControlSet/Services" - hklm_set01 = "HKEY_LOCAL_MACHINE/SYSTEM/CurrentControlSet/services" - service_keys = [ - ("%s/ACPI/Type" % hklm, 1, dword), - ("%s/ACPI/Start" % hklm, 0, dword), - # This one is broken, the parser should just ignore it. - ("%s/notarealservice" % hklm, 3, dword), - ("%s/ACPI/ErrorControl" % hklm, 3, dword), - ("%s/ACPI/ImagePath" % hklm, "system32\\drivers\\ACPI.sys", reg_str), - ("%s/ACPI/DisplayName" % hklm, "Microsoft ACPI Driver", reg_str), - ("%s/ACPI/Group" % hklm, "Boot Bus Extender", reg_str), - ("%s/ACPI/DriverPackageId" % hklm, - "acpi.inf_amd64_neutral_99aaaaabcccccccc", reg_str), - ("%s/AcpiPmi/Start" % hklm_set01, 3, dword), - ("%s/AcpiPmi/DisplayName" % hklm_set01, "AcpiPmi", - rdf_client_fs.StatEntry.RegistryType.REG_MULTI_SZ), - (u"%s/中国日报/DisplayName" % hklm, u"中国日报", reg_str), - (u"%s/中国日报/Parameters/ServiceDLL" % hklm, "blah.dll", reg_str) - ] - - stats = [self._MakeRegStat(*x) for x in service_keys] - parser = windows_registry_parser.WinServicesParser() - results = parser.ParseResponses(None, stats) - - names = [] - for result in results: - if result.display_name == u"中国日报": - self.assertEqual(result.display_name, u"中国日报") - self.assertEqual(result.service_dll, "blah.dll") - names.append(result.display_name) - elif str(result.registry_key).endswith("AcpiPmi"): - self.assertEqual(result.name, "AcpiPmi") - self.assertEqual(result.startup_type, 3) - self.assertEqual(result.display_name, "['AcpiPmi']") - self.assertEqual(result.registry_key, "%s/AcpiPmi" % hklm_set01) - names.append(result.display_name) - elif str(result.registry_key).endswith("ACPI"): - self.assertEqual(result.name, "ACPI") - self.assertEqual(result.service_type, 1) - self.assertEqual(result.startup_type, 0) - self.assertEqual(result.error_control, 3) - self.assertEqual(result.image_path, "system32\\drivers\\ACPI.sys") - self.assertEqual(result.display_name, "Microsoft ACPI Driver") - self.assertEqual(result.group_name, "Boot Bus Extender") - self.assertEqual(result.driver_package_id, - "acpi.inf_amd64_neutral_99aaaaabcccccccc") - names.append(result.display_name) - self.assertCountEqual(names, - [u"中国日报", "['AcpiPmi']", "Microsoft ACPI Driver"]) - - def testWinUserSpecialDirs(self): - reg_str = rdf_client_fs.StatEntry.RegistryType.REG_SZ - hk_u = "registry/HKEY_USERS/S-1-1-1010-10101-1010" - service_keys = [("%s/Environment/TEMP" % hk_u, r"temp\path", reg_str), - ("%s/Volatile Environment/USERDOMAIN" % hk_u, "GEVULOT", - reg_str)] - - stats = [self._MakeRegStat(*x) for x in service_keys] - parser = windows_registry_parser.WinUserSpecialDirs() - results = list(parser.ParseMultiple(stats, None)) - self.assertEqual(results[0].temp, r"temp\path") - 
self.assertEqual(results[0].userdomain, "GEVULOT") + pathspec=pathspec, registry_data=reg_data, registry_type=registry_type + ) def testWinSystemDriveParser(self): - sysroot = (r"HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\Windows NT" - r"\CurrentVersion\SystemRoot") + sysroot = ( + r"HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\Windows NT" + r"\CurrentVersion\SystemRoot" + ) stat = self._MakeRegStat(sysroot, r"C:\Windows", None) parser = windows_registry_parser.WinSystemDriveParser() self.assertEqual(r"C:", next(parser.Parse(stat, None))) - def testWindowsRegistryInstalledSoftware(self): - reg_str = rdf_client_fs.StatEntry.RegistryType.REG_SZ - hklm = "HKEY_LOCAL_MACHINE" - k = hklm + "/Software/Microsoft/Windows/CurrentVersion/Uninstall" - service_keys = [ - # Valid. - (k + "/Google Chrome/DisplayName", "Google Chrome", reg_str), - (k + "/Google Chrome/DisplayVersion", "89.0.4389.82", reg_str), - (k + "/Google Chrome/Publisher", "Google LLC", reg_str), - # Invalid - Contains no data. - (k + "/AddressBook/Default", "", reg_str), - # Invalid - Missing DisplayName. - (k + "/Foo/DisplayVersion", "1.2.3.4", reg_str), - (k + "/Foo/Publisher", "Bar Inc", reg_str), - # Valid. - (k + "/Baz/DisplayName", "Baz", reg_str), - (k + "/Baz/DisplayVersion", "2.3.4.5", reg_str), - (k + "/Baz/Publisher", "Baz LLC", reg_str), - ] - - stats = [self._MakeRegStat(*x) for x in service_keys] - parser = windows_registry_parser.WindowsRegistryInstalledSoftwareParser() - - got = parser.ParseMultiple(stats, None) # KnowledgeBase is not used. - want = [ - rdf_client.SoftwarePackages(packages=[ - rdf_client.SoftwarePackage.Installed( - name="Google Chrome", - description="Google LLC", - version="89.0.4389.82"), - rdf_client.SoftwarePackage.Installed( - name="Baz", description="Baz LLC", version="2.3.4.5"), - ]) - ] - - self.assertEqual(want, got) - def main(argv): test_lib.main(argv) diff --git a/grr/core/grr_response_core/lib/parsers/wmi_parser.py b/grr/core/grr_response_core/lib/parsers/wmi_parser.py index 75986baaa5..38becec47f 100644 --- a/grr/core/grr_response_core/lib/parsers/wmi_parser.py +++ b/grr/core/grr_response_core/lib/parsers/wmi_parser.py @@ -1,18 +1,13 @@ #!/usr/bin/env python """Simple parsers for the output of WMI queries.""" - import calendar import struct import time - from grr_response_core.lib import parser -from grr_response_core.lib import rdfvalue -from grr_response_core.lib.rdfvalues import anomaly as rdf_anomaly from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs -from grr_response_core.lib.rdfvalues import wmi as rdf_wmi from grr_response_core.lib.util import precondition @@ -42,7 +37,7 @@ def BinarySIDtoStringSID(sid): precondition.AssertType(sid, bytes) if not sid: - return u"" + return "" str_sid_components = [sid[0]] # Now decode the 48-byte portion @@ -57,87 +52,22 @@ def BinarySIDtoStringSID(sid): start = 8 for i in range(subauthority_count): - authority = sid[start:start + 4] + authority = sid[start : start + 4] if not authority: break if len(authority) < 4: - message = ("In binary SID '%s', component %d has been truncated. " - "Expected 4 bytes, found %d: (%s)") + message = ( + "In binary SID '%s', component %d has been truncated. 
" + "Expected 4 bytes, found %d: (%s)" + ) message %= (sid, i, len(authority), authority) raise ValueError(message) str_sid_components.append(struct.unpack(" "ArtifactProcessorDescriptor": + cls, parser_cls: Type[parsers.Parser] + ) -> "ArtifactProcessorDescriptor": """Creates a descriptor corresponding to the given parser. Args: @@ -379,7 +343,8 @@ def FromParser( return cls( name=parser_cls.__name__, description=description, - output_types=output_types) + output_types=output_types, + ) class ArtifactDescriptor(rdf_structs.RDFProtoStruct): @@ -409,18 +374,21 @@ def Validate(self): class ArtifactProgress(rdf_structs.RDFProtoStruct): """Collection progress of an Artifact.""" + protobuf = flows_pb2.ArtifactProgress rdf_deps = [] class ArtifactCollectorFlowProgress(rdf_structs.RDFProtoStruct): """Collection progress of ArtifactCollectorFlow.""" + protobuf = flows_pb2.ArtifactCollectorFlowProgress rdf_deps = [ArtifactProgress] class ClientActionResult(rdf_structs.RDFProtoStruct): """An RDFValue representing one type of response for a client action.""" + protobuf = artifact_pb2.ClientActionResult def GetValueClass(self): diff --git a/grr/core/grr_response_core/lib/rdfvalues/benchmark_test.py b/grr/core/grr_response_core/lib/rdfvalues/benchmark_test.py index 5a9e231f5a..52b8f9b6c3 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/benchmark_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/benchmark_test.py @@ -1,8 +1,6 @@ #!/usr/bin/env python """This module tests the RDFValue implementation for performance.""" -from typing import Text - from absl import app from grr_response_core.lib import type_info @@ -21,30 +19,40 @@ class StructGrrMessage(rdf_structs.RDFProtoStruct): rdf_structs.ProtoString( name="session_id", field_number=1, - description="Every Flow has a unique session id."), + description="Every Flow has a unique session id.", + ), rdf_structs.ProtoUnsignedInteger( name="request_id", field_number=2, - description="This message is in response to this request number"), + description="This message is in response to this request number", + ), rdf_structs.ProtoUnsignedInteger( name="response_id", field_number=3, - description="Responses for each request."), + description="Responses for each request.", + ), rdf_structs.ProtoString( name="name", field_number=4, - description=("This is the name of the client action that will be " - "executed. It is set by the flow and is executed by " - "the client.")), + description=( + "This is the name of the client action that will be " + "executed. It is set by the flow and is executed by " + "the client." 
+ ), + ), rdf_structs.ProtoBinary( name="args", field_number=5, - description="This field contains an encoded rdfvalue."), + description="This field contains an encoded rdfvalue.", + ), rdf_structs.ProtoString( name="source", field_number=6, - description=("Client name where the message came from (This is " - "copied from the MessageList)")), + description=( + "Client name where the message came from (This is " + "copied from the MessageList)" + ), + ), ) @@ -54,7 +62,10 @@ class FastGrrMessageList(rdf_structs.RDFProtoStruct): type_description = type_info.TypeDescriptorSet( rdf_structs.ProtoList( rdf_structs.ProtoEmbedded( - name="job", field_number=1, nested=StructGrrMessage))) + name="job", field_number=1, nested=StructGrrMessage + ) + ) + ) class RDFValueBenchmark(benchmark_test_lib.AverageMicroBenchmarks): @@ -64,12 +75,13 @@ class RDFValueBenchmark(benchmark_test_lib.AverageMicroBenchmarks): units = "us" USER_ACCOUNT = dict( - username=u"user", - full_name=u"John Smith", + username="user", + full_name="John Smith", last_logon=10000, - userdomain=u"Some domain name", - homedir=u"/home/user", - sid=u"some sid") + userdomain="Some domain name", + homedir="/home/user", + sid="some sid", + ) def testObjectCreation(self): """Compare the speed of object creation to raw protobufs.""" @@ -107,34 +119,45 @@ def ProtoCreateAndSerializeFromProto(): s.ParseFromString(test_proto) self.assertEqual(s.SerializeToString(), test_proto) - self.TimeIt(RDFStructCreateAndSerialize, - "SProto Create from keywords and serialize.") + self.TimeIt( + RDFStructCreateAndSerialize, + "SProto Create from keywords and serialize.", + ) - self.TimeIt(RDFStructCreateAndSerializeSetValue, - "SProto Create, Set And Serialize") + self.TimeIt( + RDFStructCreateAndSerializeSetValue, "SProto Create, Set And Serialize" + ) - self.TimeIt(RDFStructCreateAndSerializeFromProto, - "SProto from serialized and serialize.") + self.TimeIt( + RDFStructCreateAndSerializeFromProto, + "SProto from serialized and serialize.", + ) - self.TimeIt(ProtoCreateAndSerialize, - "Protobuf from keywords and serialize.") + self.TimeIt( + ProtoCreateAndSerialize, "Protobuf from keywords and serialize." + ) - self.TimeIt(ProtoCreateAndSerializeSetValue, - "Protobuf Create, Set and serialize") + self.TimeIt( + ProtoCreateAndSerializeSetValue, "Protobuf Create, Set and serialize" + ) - self.TimeIt(ProtoCreateAndSerializeFromProto, - "Protobuf from serialized and serialize.") + self.TimeIt( + ProtoCreateAndSerializeFromProto, + "Protobuf from serialized and serialize.", + ) def testObjectCreation2(self): def ProtoCreateAndSerialize(): s = jobs_pb2.GrrMessage( - name=u"foo", request_id=1, response_id=1, session_id=u"session") + name="foo", request_id=1, response_id=1, session_id="session" + ) return len(s.SerializeToString()) def RDFStructCreateAndSerialize(): s = StructGrrMessage( - name=u"foo", request_id=1, response_id=1, session_id=u"session") + name="foo", request_id=1, response_id=1, session_id="session" + ) return len(s.SerializeToBytes()) @@ -189,12 +212,12 @@ def RDFStructCreateAndSerialize(): self.TimeIt( RDFStructCreateAndSerialize, "RDFStruct Repeated Fields", - repetitions=repeats) + repetitions=repeats, + ) self.TimeIt( - ProtoCreateAndSerialize, - "Protobuf Repeated Fields", - repetitions=repeats) + ProtoCreateAndSerialize, "Protobuf Repeated Fields", repetitions=repeats + ) # Check that we can unserialize a protobuf encoded using the standard # library. 
@@ -213,7 +236,8 @@ def testDecode(self): """Test decoding performance.""" s = jobs_pb2.GrrMessage( - name=u"foo", request_id=1, response_id=1, session_id=u"session") + name="foo", request_id=1, response_id=1, session_id="session" + ) data = s.SerializeToString() def ProtoDecode(): @@ -221,13 +245,13 @@ def ProtoDecode(): new_s.ParseFromString(data) self.assertEqual(new_s.session_id, "session") - self.assertIsInstance(new_s.session_id, Text) + self.assertIsInstance(new_s.session_id, str) def RDFStructDecode(): new_s = StructGrrMessage.FromSerializedBytes(data) self.assertEqual(new_s.session_id, "session") - self.assertIsInstance(new_s.session_id, Text) + self.assertIsInstance(new_s.session_id, str) self.TimeIt(RDFStructDecode) self.TimeIt(ProtoDecode) @@ -248,13 +272,13 @@ def ProtoDecode(): new_s.ParseFromString(data) self.assertEqual(new_s.username, "user") - self.assertIsInstance(new_s.username, Text) + self.assertIsInstance(new_s.username, str) def RDFStructDecode(): new_s = rdf_client.User.FromSerializedBytes(data) self.assertEqual(new_s.username, "user") - self.assertIsInstance(new_s.username, Text) + self.assertIsInstance(new_s.username, str) self.TimeIt(RDFStructDecode) self.TimeIt(ProtoDecode) @@ -262,19 +286,22 @@ def RDFStructDecode(): def testEncode(self): """Comparing encoding speed of a typical protobuf.""" s = jobs_pb2.GrrMessage( - name=u"foo", request_id=1, response_id=1, session_id=u"session") + name="foo", request_id=1, response_id=1, session_id="session" + ) serialized = s.SerializeToString() def ProtoEncode(): s1 = jobs_pb2.GrrMessage( - name=u"foo", request_id=1, response_id=1, session_id=u"session") + name="foo", request_id=1, response_id=1, session_id="session" + ) test = s1.SerializeToString() self.assertLen(serialized, len(test)) def RDFStructEncode(): s2 = StructGrrMessage( - name=u"foo", request_id=1, response_id=1, session_id=u"session") + name="foo", request_id=1, response_id=1, session_id="session" + ) test = s2.SerializeToBytes() self.assertLen(serialized, len(test)) @@ -287,17 +314,18 @@ def testEncodeDecode(self): def Check(s, new_s): self.assertEqual(s.name, new_s.name) - self.assertEqual(s.name, u"foo") + self.assertEqual(s.name, "foo") self.assertEqual(s.request_id, new_s.request_id) self.assertEqual(s.request_id, 1) self.assertEqual(s.response_id, new_s.response_id) self.assertEqual(s.response_id, 1) self.assertEqual(s.session_id, new_s.session_id) - self.assertEqual(s.session_id, u"session") + self.assertEqual(s.session_id, "session") def ProtoEncodeDecode(): s = jobs_pb2.GrrMessage( - name=u"foo", request_id=1, response_id=1, session_id=u"session") + name="foo", request_id=1, response_id=1, session_id="session" + ) data = s.SerializeToString() new_s = jobs_pb2.GrrMessage() @@ -307,7 +335,8 @@ def ProtoEncodeDecode(): def RDFStructEncodeDecode(): s = StructGrrMessage( - name=u"foo", request_id=1, response_id=1, session_id=u"session") + name="foo", request_id=1, response_id=1, session_id="session" + ) data = s.SerializeToBytes() new_s = StructGrrMessage.FromSerializedBytes(data) @@ -325,7 +354,8 @@ def testDecodeEncode(self): """Test performance of decode/encode cycle.""" s = jobs_pb2.GrrMessage( - name=u"foo", request_id=1, response_id=1, session_id=u"session") + name="foo", request_id=1, response_id=1, session_id="session" + ) data = s.SerializeToString() def ProtoDecodeEncode(): diff --git a/grr/core/grr_response_core/lib/rdfvalues/client.py b/grr/core/grr_response_core/lib/rdfvalues/client.py index 940da76fa6..036eb6747d 100644 --- 
a/grr/core/grr_response_core/lib/rdfvalues/client.py +++ b/grr/core/grr_response_core/lib/rdfvalues/client.py @@ -5,13 +5,10 @@ client. """ -import binascii -import hashlib import logging import platform import re import socket -import struct import sys from typing import Mapping @@ -26,7 +23,6 @@ from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr_response_core.lib.rdfvalues import protodict as rdf_protodict from grr_response_core.lib.rdfvalues import structs as rdf_structs -from grr_response_core.lib.util import text from grr_response_proto import jobs_pb2 from grr_response_proto import knowledge_base_pb2 from grr_response_proto import sysinfo_pb2 @@ -62,18 +58,21 @@ def __init__(self, initializer=None): @classmethod def _Normalize(cls, string): - normalized = super(ClientURN, cls)._Normalize(string.strip()) + normalized = super()._Normalize(string.strip()) if normalized: match = cls.CLIENT_ID_RE.match(normalized) if not match: raise type_info.TypeValueError( "Client URN '{!r} from initializer {!r} malformed".format( - normalized, string)) + normalized, string + ) + ) clientid = match.group("clientid") clientid_correctcase = "".join( - (clientid[0].upper(), clientid[1:].lower())) + (clientid[0].upper(), clientid[1:].lower()) + ) normalized = normalized.replace(clientid, clientid_correctcase, 1) return normalized @@ -85,25 +84,6 @@ def Validate(cls, value): return False - @classmethod - def FromPrivateKey(cls, private_key): - return cls.FromPublicKey(private_key.GetPublicKey()) - - @classmethod - def FromPublicKey(cls, public_key): - """An alternate constructor which generates a new client id.""" - # Our CN will be the first 64 bits of the hash of the public key - # in MPI format - the length of the key in 4 bytes + the key - # prefixed with a 0. This weird format is an artifact from the way - # M2Crypto handled this, we have to live with it for now. - n = public_key.GetN() - raw_n = binascii.unhexlify("%x" % n) - - mpi_format = struct.pack(">i", len(raw_n) + 1) + b"\x00" + raw_n - - digest = text.Hexify(hashlib.sha256(mpi_format).digest()[:8]) - return cls("C.{}".format(digest)) - def Add(self, path): """Add a relative stem to the current value and return a new RDFURN. @@ -125,19 +105,12 @@ def Add(self, path): return rdfvalue.RDFURN(utils.JoinPath(self._value, path)) -class PCIDevice(rdf_structs.RDFProtoStruct): - """A PCI device on the client. - - This class describes a PCI device located on the client. - """ - protobuf = sysinfo_pb2.PCIDevice - - class PackageRepository(rdf_structs.RDFProtoStruct): """Description of the configured repositories (Yum etc). Describes the configured software package repositories. """ + protobuf = sysinfo_pb2.PackageRepository @@ -147,6 +120,7 @@ class ManagementAgent(rdf_structs.RDFProtoStruct): Describes the state, last run timestamp, and name of the management agent installed on the system. 
""" + protobuf = sysinfo_pb2.ManagementAgent rdf_deps = [ rdfvalue.RDFDatetime, @@ -155,11 +129,13 @@ class ManagementAgent(rdf_structs.RDFProtoStruct): class PwEntry(rdf_structs.RDFProtoStruct): """Information about password structures.""" + protobuf = knowledge_base_pb2.PwEntry class Group(rdf_structs.RDFProtoStruct): """Information about system posix groups.""" + protobuf = knowledge_base_pb2.Group rdf_deps = [ PwEntry, @@ -168,6 +144,7 @@ class Group(rdf_structs.RDFProtoStruct): class User(rdf_structs.RDFProtoStruct): """Information about the users.""" + protobuf = knowledge_base_pb2.User rdf_deps = [ PwEntry, @@ -194,6 +171,7 @@ class KnowledgeBaseUser(User): class KnowledgeBase(rdf_structs.RDFProtoStruct): """Information about the system and users.""" + protobuf = knowledge_base_pb2.KnowledgeBase rdf_deps = [ User, @@ -214,7 +192,8 @@ def MergeOrAddUser(self, kb_user): """ user = self.GetUser( - sid=kb_user.sid, uid=kb_user.uid, username=kb_user.username) + sid=kb_user.sid, uid=kb_user.uid, username=kb_user.username + ) new_attrs = [] merge_conflicts = [] # Record when we overwrite a value. if not user: @@ -293,16 +272,19 @@ def GetKbFieldNames(self): class HardwareInfo(rdf_structs.RDFProtoStruct): """Various hardware information.""" + protobuf = sysinfo_pb2.HardwareInfo class ClientInformation(rdf_structs.RDFProtoStruct): """The GRR client information.""" + protobuf = jobs_pb2.ClientInformation class BufferReference(rdf_structs.RDFProtoStruct): """Stores information about a buffer in a file on the client.""" + protobuf = jobs_pb2.BufferReference rdf_deps = [ rdf_paths.PathSpec, @@ -314,6 +296,7 @@ def __eq__(self, other): class Process(rdf_structs.RDFProtoStruct): """Represent a process on the client.""" + protobuf = sysinfo_pb2.Process rdf_deps = [ rdf_client_network.NetworkConnection, @@ -361,10 +344,12 @@ def FromPsutilProcess(cls, psutil_process): try: # Not available on Windows. if hasattr(psutil_process, "uids"): - (response.real_uid, response.effective_uid, - response.saved_uid) = psutil_process.uids() - (response.real_gid, response.effective_gid, - response.saved_gid) = psutil_process.gids() + (response.real_uid, response.effective_uid, response.saved_uid) = ( + psutil_process.uids() + ) + (response.real_gid, response.effective_gid, response.saved_gid) = ( + psutil_process.gids() + ) except (psutil.NoSuchProcess, psutil.AccessDenied): pass @@ -427,13 +412,15 @@ def FromPsutilProcess(cls, psutil_process): try: for c in psutil_process.connections(): conn = response.connections.Append( - family=c.family, type=c.type, pid=psutil_process.pid) + family=c.family, type=c.type, pid=psutil_process.pid + ) try: conn.state = c.status except ValueError: - logging.info("Encountered unknown connection status (%s).", - c.status) + logging.info( + "Encountered unknown connection status (%s).", c.status + ) try: conn.local_address.ip, conn.local_address.port = c.laddr @@ -446,8 +433,9 @@ def FromPsutilProcess(cls, psutil_process): # Could be in state LISTEN. 
if c.remote_address: - (conn.remote_address.ip, - conn.remote_address.port) = c.remote_address + (conn.remote_address.ip, conn.remote_address.port) = ( + c.remote_address + ) except (psutil.NoSuchProcess, psutil.AccessDenied): pass @@ -464,26 +452,31 @@ class NamedPipe(rdf_structs.RDFProtoStruct): class SoftwarePackage(rdf_structs.RDFProtoStruct): """Represent an installed package on the client.""" + protobuf = sysinfo_pb2.SoftwarePackage @classmethod def Installed(cls, **kwargs): return SoftwarePackage( - install_state=SoftwarePackage.InstallState.INSTALLED, **kwargs) + install_state=SoftwarePackage.InstallState.INSTALLED, **kwargs + ) @classmethod def Pending(cls, **kwargs): return SoftwarePackage( - install_state=SoftwarePackage.InstallState.PENDING, **kwargs) + install_state=SoftwarePackage.InstallState.PENDING, **kwargs + ) @classmethod def Uninstalled(cls, **kwargs): return SoftwarePackage( - install_state=SoftwarePackage.InstallState.UNINSTALLED, **kwargs) + install_state=SoftwarePackage.InstallState.UNINSTALLED, **kwargs + ) class SoftwarePackages(rdf_structs.RDFProtoStruct): """A list of installed packages on the system.""" + protobuf = sysinfo_pb2.SoftwarePackages rdf_deps = [ @@ -493,11 +486,13 @@ class SoftwarePackages(rdf_structs.RDFProtoStruct): class LogMessage(rdf_structs.RDFProtoStruct): """A log message sent from the client to the server.""" + protobuf = jobs_pb2.LogMessage class Uname(rdf_structs.RDFProtoStruct): """A protobuf to represent the current system.""" + protobuf = jobs_pb2.Uname rdf_deps = [ rdfvalue.RDFDatetime, @@ -524,8 +519,10 @@ def signature(self): if result: return result - raise ValueError("PEP 425 Signature not set - this is likely an old " - "component file, please back it up and remove it.") + raise ValueError( + "PEP 425 Signature not set - this is likely an old " + "component file, please back it up and remove it." 
+ ) @classmethod def FromCurrentSystem(cls): @@ -560,8 +557,11 @@ def FromCurrentSystem(cls): # 0.34.2 pep_platform = pep425tags.get_platform(None) pep425tag = "%s%s-%s-%s" % ( - pep425tags.get_abbr_impl(), pep425tags.get_impl_ver(), - str(pep425tags.get_abi_tag()).lower(), pep_platform) + pep425tags.get_abbr_impl(), + pep425tags.get_impl_ver(), + str(pep425tags.get_abi_tag()).lower(), + pep_platform, + ) else: # For example: windows_7_amd64 pep425tag = "%s_%s_%s" % (system, release, architecture) @@ -580,6 +580,7 @@ def FromCurrentSystem(cls): class StartupInfo(rdf_structs.RDFProtoStruct): """Information about the startup of a GRR agent.""" + protobuf = jobs_pb2.StartupInfo rdf_deps = [ ClientInformation, @@ -589,6 +590,7 @@ class StartupInfo(rdf_structs.RDFProtoStruct): class WindowsServiceInformation(rdf_structs.RDFProtoStruct): """Windows Service.""" + protobuf = sysinfo_pb2.WindowsServiceInformation rdf_deps = [ rdf_protodict.Dict, @@ -598,6 +600,7 @@ class WindowsServiceInformation(rdf_structs.RDFProtoStruct): class OSXServiceInformation(rdf_structs.RDFProtoStruct): """OSX Service (launchagent/daemon).""" + protobuf = sysinfo_pb2.OSXServiceInformation rdf_deps = [ rdfvalue.RDFURN, @@ -606,6 +609,7 @@ class OSXServiceInformation(rdf_structs.RDFProtoStruct): class LinuxServiceInformation(rdf_structs.RDFProtoStruct): """Linux Service (init/upstart/systemd).""" + protobuf = sysinfo_pb2.LinuxServiceInformation rdf_deps = [ rdf_protodict.AttributedDict, @@ -621,11 +625,13 @@ class RunKey(rdf_structs.RDFProtoStruct): class RunKeyEntry(rdf_protodict.RDFValueArray): """Structure of a Run Key entry with keyname, filepath, and last written.""" + rdf_type = RunKey class ClientCrash(rdf_structs.RDFProtoStruct): """Details of a client crash.""" + protobuf = jobs_pb2.ClientCrash rdf_deps = [ ClientInformation, @@ -644,11 +650,13 @@ class EdrAgent(rdf_structs.RDFProtoStruct): class FleetspeakValidationInfoTag(rdf_structs.RDFProtoStruct): """Dictionary entry in FleetspeakValidationInfo.""" + protobuf = jobs_pb2.FleetspeakValidationInfoTag class FleetspeakValidationInfo(rdf_structs.RDFProtoStruct): """Dictionary-like struct containing Fleetspeak ValidationInfo.""" + protobuf = jobs_pb2.FleetspeakValidationInfo rdf_deps = [FleetspeakValidationInfoTag] @@ -665,6 +673,7 @@ def ToStringDict(self) -> Mapping[str, str]: class ClientSummary(rdf_structs.RDFProtoStruct): """Object containing client's summary data.""" + protobuf = jobs_pb2.ClientSummary rdf_deps = [ ClientInformation, diff --git a/grr/core/grr_response_core/lib/rdfvalues/client_action.py b/grr/core/grr_response_core/lib/rdfvalues/client_action.py index 1da722c510..b6bd70823f 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/client_action.py +++ b/grr/core/grr_response_core/lib/rdfvalues/client_action.py @@ -1,12 +1,10 @@ #!/usr/bin/env python """Client actions requests and responses.""" - from grr_response_core.lib.rdfvalues import crypto as rdf_crypto from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr_response_core.lib.rdfvalues import protodict as rdf_protodict from grr_response_core.lib.rdfvalues import structs as rdf_structs - from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 @@ -51,6 +49,7 @@ class ExecuteResponse(rdf_structs.RDFProtoStruct): class Iterator(rdf_structs.RDFProtoStruct): """An Iterated client action is one which can be resumed on the client.""" + protobuf = jobs_pb2.Iterator rdf_deps = [ rdf_protodict.Dict, @@ -90,6 +89,7 @@ def AddRequest(self, *args, **kw): class 
FingerprintResponse(rdf_structs.RDFProtoStruct): """Proto containing dicts with hashes.""" + protobuf = jobs_pb2.FingerprintResponse rdf_deps = [ rdf_protodict.Dict, @@ -114,4 +114,5 @@ class StatFSRequest(rdf_structs.RDFProtoStruct): class ListNetworkConnectionsArgs(rdf_structs.RDFProtoStruct): """Args for the ListNetworkConnections client action.""" + protobuf = flows_pb2.ListNetworkConnectionsArgs diff --git a/grr/core/grr_response_core/lib/rdfvalues/client_fs.py b/grr/core/grr_response_core/lib/rdfvalues/client_fs.py index 73e910c331..be5de93eec 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/client_fs.py +++ b/grr/core/grr_response_core/lib/rdfvalues/client_fs.py @@ -4,7 +4,6 @@ import stat from typing import Text - from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client_action as rdf_client_action from grr_response_core.lib.rdfvalues import paths as rdf_paths @@ -22,6 +21,7 @@ class Filesystem(rdf_structs.RDFProtoStruct): This class describes a filesystem mounted on the client. """ + protobuf = sysinfo_pb2.Filesystem rdf_deps = [ rdf_protodict.AttributedDict, @@ -33,6 +33,7 @@ class Filesystems(rdf_protodict.RDFValueArray): This is used to represent the list of valid filesystems on the client. """ + rdf_type = Filesystem @@ -44,21 +45,25 @@ class FolderInformation(rdf_structs.RDFProtoStruct): the location of user specific items, e.g. the Temporary folder, or the Internet cache. """ + protobuf = jobs_pb2.FolderInformation class WindowsVolume(rdf_structs.RDFProtoStruct): """A disk volume on a windows client.""" + protobuf = sysinfo_pb2.WindowsVolume class UnixVolume(rdf_structs.RDFProtoStruct): """A disk volume on a unix client.""" + protobuf = sysinfo_pb2.UnixVolume class Volume(rdf_structs.RDFProtoStruct): """A disk volume on the client.""" + protobuf = sysinfo_pb2.Volume rdf_deps = [ rdfvalue.RDFDatetime, @@ -68,8 +73,9 @@ class Volume(rdf_structs.RDFProtoStruct): def FreeSpacePercent(self): try: - return (self.actual_available_allocation_units / - self.total_allocation_units) * 100.0 + return ( + self.actual_available_allocation_units / self.total_allocation_units + ) * 100.0 except ZeroDivisionError: return 100 @@ -78,8 +84,11 @@ def FreeSpaceBytes(self): def AUToBytes(self, allocation_units): """Convert a number of allocation units to bytes.""" - return (allocation_units * self.sectors_per_allocation_unit * - self.bytes_per_sector) + return ( + allocation_units + * self.sectors_per_allocation_unit + * self.bytes_per_sector + ) def AUToGBytes(self, allocation_units): """Convert a number of allocation units to GigaBytes.""" @@ -87,8 +96,13 @@ def AUToGBytes(self, allocation_units): def Name(self): """Return the best available name for this volume.""" - return (self.name or self.device_path or self.windowsvolume.drive_letter or - self.unixvolume.mount_point or None) + return ( + self.name + or self.device_path + or self.windowsvolume.drive_letter + or self.unixvolume.mount_point + or None + ) class DiskUsage(rdf_structs.RDFProtoStruct): @@ -97,11 +111,13 @@ class DiskUsage(rdf_structs.RDFProtoStruct): class Volumes(rdf_protodict.RDFValueArray): """A list of disk volumes on the client.""" + rdf_type = Volume class StatMode(rdfvalue.RDFInteger): """The mode of a file.""" + protobuf_type = "unsigned_integer" def __str__(self) -> Text: @@ -131,9 +147,9 @@ def __str__(self) -> Text: bin_mode = "0" * (9 - len(bin_mode)) + bin_mode bits = [] - for i in range(len(mode_template)): + for i, mode_ in enumerate(mode_template): if bin_mode[i] == "1": - 
bit = mode_template[i] + bit = mode_ else: bit = "-" @@ -172,6 +188,7 @@ class ExtAttr(rdf_structs.RDFProtoStruct): class StatEntry(rdf_structs.RDFProtoStruct): """Represent an extended stat response.""" + protobuf = jobs_pb2.StatEntry rdf_deps = [ rdf_protodict.DataBlob, @@ -189,6 +206,7 @@ def AFF4Path(self, client_urn): class FindSpec(rdf_structs.RDFProtoStruct): """A find specification.""" + protobuf = jobs_pb2.FindSpec rdf_deps = [ rdfvalue.RDFBytes, @@ -205,17 +223,23 @@ def Validate(self): """Ensure the pathspec is valid.""" self.pathspec.Validate() - if (self.HasField("start_time") and self.HasField("end_time") and - self.start_time > self.end_time): + if ( + self.HasField("start_time") + and self.HasField("end_time") + and self.start_time > self.end_time + ): raise ValueError("Start time must be before end time.") if not self.path_regex and not self.data_regex and not self.path_glob: - raise ValueError("A Find specification can not contain both an empty " - "path regex and an empty data regex") + raise ValueError( + "A Find specification can not contain both an empty " + "path regex and an empty data regex" + ) class BareGrepSpec(rdf_structs.RDFProtoStruct): """A GrepSpec without a target.""" + protobuf = flows_pb2.BareGrepSpec rdf_deps = [ rdf_standard.LiteralExpression, diff --git a/grr/core/grr_response_core/lib/rdfvalues/client_network.py b/grr/core/grr_response_core/lib/rdfvalues/client_network.py index 1670e3547c..c0f2734842 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/client_network.py +++ b/grr/core/grr_response_core/lib/rdfvalues/client_network.py @@ -4,9 +4,7 @@ import binascii import ipaddress import logging -from typing import Optional -from typing import Text -from typing import Union +from typing import Optional, Union from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import protodict as rdf_protodict @@ -23,6 +21,7 @@ class NetworkEndpoint(rdf_structs.RDFProtoStruct): class NetworkConnection(rdf_structs.RDFProtoStruct): """Information about a single network connection.""" + protobuf = sysinfo_pb2.NetworkConnection rdf_deps = [ NetworkEndpoint, @@ -31,6 +30,7 @@ class NetworkConnection(rdf_structs.RDFProtoStruct): class Connections(rdf_protodict.RDFValueArray): """A list of connections on the host.""" + rdf_type = NetworkConnection @@ -44,6 +44,7 @@ class NetworkAddress(rdf_structs.RDFProtoStruct): available on windows before python 3.4. So we use the older IPv4 functions for v4 addresses and our own pure python implementations for IPv6. 
""" + protobuf = jobs_pb2.NetworkAddress @classmethod @@ -62,7 +63,7 @@ def FromPackedBytes(cls, ip: bytes) -> "NetworkAddress": return result @property - def human_readable_address(self) -> Text: + def human_readable_address(self) -> str: addr = self.AsIPAddr() if addr is not None: return str(addr) @@ -70,8 +71,8 @@ def human_readable_address(self) -> Text: return "" @human_readable_address.setter - def human_readable_address(self, value: Text) -> None: - precondition.AssertType(value, Text) + def human_readable_address(self, value: str) -> None: + precondition.AssertType(value, str) addr = ipaddress.ip_address(value) if isinstance(addr, ipaddress.IPv6Address): @@ -98,8 +99,9 @@ def AsIPAddr(self) -> Optional[IPAddress]: return ipaddress.IPv6Address(self.packed_bytes) except ipaddress.AddressValueError: hex_packed_bytes = text.Hexify(self.packed_bytes) - logging.error("AddressValueError for %s (%s)", hex_packed_bytes, - self.address_type) + logging.error( + "AddressValueError for %s (%s)", hex_packed_bytes, self.address_type + ) raise message = "IP address has invalid type: {}".format(self.address_type) @@ -108,6 +110,7 @@ def AsIPAddr(self) -> Optional[IPAddress]: class DNSClientConfiguration(rdf_structs.RDFProtoStruct): """DNS client config.""" + protobuf = sysinfo_pb2.DNSClientConfiguration @@ -115,17 +118,18 @@ class MacAddress(rdfvalue.RDFBytes): """A MAC address.""" @property - def human_readable_address(self) -> Text: + def human_readable_address(self) -> str: return text.Hexify(self._value) @classmethod - def FromHumanReadableAddress(cls, string: Text): - precondition.AssertType(string, Text) + def FromHumanReadableAddress(cls, string: str): + precondition.AssertType(string, str) return cls(binascii.unhexlify(string.encode("ascii"))) class Interface(rdf_structs.RDFProtoStruct): """A network interface on the client system.""" + protobuf = jobs_pb2.Interface rdf_deps = [ MacAddress, @@ -146,6 +150,7 @@ def GetIPAddresses(self): class Interfaces(rdf_protodict.RDFValueArray): """The list of interfaces on a host.""" + rdf_type = Interface def GetIPAddresses(self): diff --git a/grr/core/grr_response_core/lib/rdfvalues/client_stats.py b/grr/core/grr_response_core/lib/rdfvalues/client_stats.py index e1a2b06026..748dcde43d 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/client_stats.py +++ b/grr/core/grr_response_core/lib/rdfvalues/client_stats.py @@ -1,23 +1,21 @@ #!/usr/bin/env python """Stats-related client rdfvalues.""" - - from grr_response_core.lib import rdfvalue - from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import structs as rdf_structs - from grr_response_proto import jobs_pb2 class CpuSeconds(rdf_structs.RDFProtoStruct): """CPU usage is reported as both a system and user components.""" + protobuf = jobs_pb2.CpuSeconds class CpuSample(rdf_structs.RDFProtoStruct): """A single CPU sample.""" + protobuf = jobs_pb2.CpuSample rdf_deps = [ rdfvalue.RDFDatetime, @@ -47,7 +45,8 @@ def FromMany(cls, samples): timestamp=max(sample.timestamp for sample in samples), cpu_percent=cpu_percent, user_cpu_time=max(sample.user_cpu_time for sample in samples), - system_cpu_time=max(sample.system_cpu_time for sample in samples)) + system_cpu_time=max(sample.system_cpu_time for sample in samples), + ) class IOSample(rdf_structs.RDFProtoStruct): @@ -79,11 +78,13 @@ def FromMany(cls, samples): read_bytes=max(sample.read_bytes for sample in samples), read_count=max(sample.read_count for sample in samples), write_bytes=max(sample.write_bytes 
for sample in samples), - write_count=max(sample.write_count for sample in samples)) + write_count=max(sample.write_count for sample in samples), + ) class ClientStats(rdf_structs.RDFProtoStruct): """A client stat object.""" + protobuf = jobs_pb2.ClientStats rdf_deps = [ CpuSample, @@ -108,9 +109,11 @@ def Downsampled(cls, stats, interval=None): result = cls(stats) result.cpu_samples = cls._Downsample( - kind=CpuSample, samples=stats.cpu_samples, interval=interval) + kind=CpuSample, samples=stats.cpu_samples, interval=interval + ) result.io_samples = cls._Downsample( - kind=IOSample, samples=stats.io_samples, interval=interval) + kind=IOSample, samples=stats.io_samples, interval=interval + ) return result @classmethod @@ -126,6 +129,7 @@ def _Downsample(cls, kind, samples, interval): class ClientResources(rdf_structs.RDFProtoStruct): """An RDFValue class representing the client resource usage.""" + protobuf = jobs_pb2.ClientResources rdf_deps = [ rdf_client.ClientURN, diff --git a/grr/core/grr_response_core/lib/rdfvalues/client_test.py b/grr/core/grr_response_core/lib/rdfvalues/client_test.py index 771bd3be58..bda277739f 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/client_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/client_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Test client RDFValues.""" - import platform import socket from unittest import mock @@ -133,8 +132,10 @@ def testInitialization(self): def testURNValidation(self): # These should all come out the same: C.00aaeccbb45f33a3 test_set = [ - "C.00aaeccbb45f33a3", "C.00aaeccbb45f33a3".upper(), - "c.00aaeccbb45f33a3", "C.00aaeccbb45f33a3 " + "C.00aaeccbb45f33a3", + "C.00aaeccbb45f33a3".upper(), + "c.00aaeccbb45f33a3", + "C.00aaeccbb45f33a3 ", ] results = [] for urnstr in test_set: @@ -150,28 +151,34 @@ def testURNValidation(self): rdf_client.ClientURN(rdf_client.ClientURN(test_set[0])) error_set = [ - "B.00aaeccbb45f33a3", "c.00accbb45f33a3", "aff5:/C.00aaeccbb45f33a3" + "B.00aaeccbb45f33a3", + "c.00accbb45f33a3", + "aff5:/C.00aaeccbb45f33a3", ] for badurn in error_set: self.assertRaises(type_info.TypeValueError, rdf_client.ClientURN, badurn) -class NetworkAddressTests(rdf_test_base.RDFValueTestMixin, - test_lib.GRRBaseTest): +class NetworkAddressTests( + rdf_test_base.RDFValueTestMixin, test_lib.GRRBaseTest +): """Test the NetworkAddress.""" rdfvalue_class = rdf_client_network.NetworkAddress def GenerateSample(self, number=0): return rdf_client_network.NetworkAddress( - human_readable_address="192.168.0.%s" % number) + human_readable_address="192.168.0.%s" % number + ) def testIPv4(self): sample = rdf_client_network.NetworkAddress( - human_readable_address="192.168.0.1") - self.assertEqual(sample.address_type, - rdf_client_network.NetworkAddress.Family.INET) + human_readable_address="192.168.0.1" + ) + self.assertEqual( + sample.address_type, rdf_client_network.NetworkAddress.Family.INET + ) # Equal to socket.inet_pton(socket.AF_INET, "192.168.0.1"), which is # unavailable on Windows. self.assertEqual(sample.packed_bytes, b"\xc0\xa8\x00\x01") @@ -186,12 +193,13 @@ def testIPv6(self): # on Windows. 
expected_addresses = [ b"\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29", - b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", ] for address, expected in zip(ipv6_addresses, expected_addresses): sample = rdf_client_network.NetworkAddress(human_readable_address=address) - self.assertEqual(sample.address_type, - rdf_client_network.NetworkAddress.Family.INET6) + self.assertEqual( + sample.address_type, rdf_client_network.NetworkAddress.Family.INET6 + ) self.assertEqual(sample.packed_bytes, expected) self.assertEqual(sample.human_readable_address, address) @@ -227,7 +235,8 @@ def testGetFQDN(self): def testGetFQDN_Localhost(self): with mock.patch.object( - socket, "getfqdn", return_value=rdf_client._LOCALHOST): + socket, "getfqdn", return_value=rdf_client._LOCALHOST + ): with mock.patch.object(socket, "gethostname", return_value="foo"): uname = self.rdfvalue_class.FromCurrentSystem() self.assertEqual(uname.fqdn, "foo") @@ -241,24 +250,28 @@ def testFromMany(self): timestamp=rdfvalue.RDFDatetime.FromHumanReadable("2001-01-01"), cpu_percent=0.2, user_cpu_time=0.1, - system_cpu_time=0.5), + system_cpu_time=0.5, + ), rdf_client_stats.CpuSample( timestamp=rdfvalue.RDFDatetime.FromHumanReadable("2001-02-01"), cpu_percent=0.1, user_cpu_time=2.5, - system_cpu_time=1.2), + system_cpu_time=1.2, + ), rdf_client_stats.CpuSample( timestamp=rdfvalue.RDFDatetime.FromHumanReadable("2001-03-01"), cpu_percent=0.6, user_cpu_time=3.4, - system_cpu_time=2.4), + system_cpu_time=2.4, + ), ] expected = rdf_client_stats.CpuSample( timestamp=rdfvalue.RDFDatetime.FromHumanReadable("2001-03-01"), cpu_percent=0.3, user_cpu_time=3.4, - system_cpu_time=2.4) + system_cpu_time=2.4, + ) self.assertEqual(rdf_client_stats.CpuSample.FromMany(samples), expected) @@ -274,21 +287,25 @@ def testFromMany(self): rdf_client_stats.IOSample( timestamp=rdfvalue.RDFDatetime.FromHumanReadable("2001-01-01"), read_bytes=0, - write_bytes=0), + write_bytes=0, + ), rdf_client_stats.IOSample( timestamp=rdfvalue.RDFDatetime.FromHumanReadable("2002-01-01"), read_bytes=512, - write_bytes=1024), + write_bytes=1024, + ), rdf_client_stats.IOSample( timestamp=rdfvalue.RDFDatetime.FromHumanReadable("2003-01-01"), read_bytes=2048, - write_bytes=4096), + write_bytes=4096, + ), ] expected = rdf_client_stats.IOSample( timestamp=rdfvalue.RDFDatetime.FromHumanReadable("2003-01-01"), read_bytes=2048, - write_bytes=4096) + write_bytes=4096, + ) self.assertEqual(rdf_client_stats.IOSample.FromMany(samples), expected) @@ -308,42 +325,51 @@ def testDownsampled(self): timestamp=timestamp("2001-01-01 00:00"), user_cpu_time=2.5, system_cpu_time=3.2, - cpu_percent=0.5), + cpu_percent=0.5, + ), rdf_client_stats.CpuSample( timestamp=timestamp("2001-01-01 00:05"), user_cpu_time=2.6, system_cpu_time=4.7, - cpu_percent=0.6), + cpu_percent=0.6, + ), rdf_client_stats.CpuSample( timestamp=timestamp("2001-01-01 00:10"), user_cpu_time=10.0, system_cpu_time=14.2, - cpu_percent=0.9), + cpu_percent=0.9, + ), rdf_client_stats.CpuSample( timestamp=timestamp("2001-01-01 00:12"), user_cpu_time=12.3, system_cpu_time=14.9, - cpu_percent=0.1), + cpu_percent=0.1, + ), rdf_client_stats.CpuSample( timestamp=timestamp("2001-01-01 00:21"), user_cpu_time=16.1, system_cpu_time=22.3, - cpu_percent=0.4) + cpu_percent=0.4, + ), ], io_samples=[ rdf_client_stats.IOSample( timestamp=timestamp("2001-01-01 00:00"), read_count=0, - write_count=0), + write_count=0, + ), rdf_client_stats.IOSample( 
timestamp=timestamp("2001-01-01 00:02"), read_count=3, - write_count=5), + write_count=5, + ), rdf_client_stats.IOSample( timestamp=timestamp("2001-01-01 00:12"), read_count=6, - write_count=8), - ]) + write_count=8, + ), + ], + ) expected = rdf_client_stats.ClientStats( cpu_samples=[ @@ -351,31 +377,38 @@ def testDownsampled(self): timestamp=timestamp("2001-01-01 00:05"), user_cpu_time=2.6, system_cpu_time=4.7, - cpu_percent=0.55), + cpu_percent=0.55, + ), rdf_client_stats.CpuSample( timestamp=timestamp("2001-01-01 00:12"), user_cpu_time=12.3, system_cpu_time=14.9, - cpu_percent=0.5), + cpu_percent=0.5, + ), rdf_client_stats.CpuSample( timestamp=timestamp("2001-01-01 00:21"), user_cpu_time=16.1, system_cpu_time=22.3, - cpu_percent=0.4), + cpu_percent=0.4, + ), ], io_samples=[ rdf_client_stats.IOSample( timestamp=timestamp("2001-01-01 00:02"), read_count=3, - write_count=5), + write_count=5, + ), rdf_client_stats.IOSample( timestamp=timestamp("2001-01-01 00:12"), read_count=6, - write_count=8), - ]) + write_count=8, + ), + ], + ) actual = rdf_client_stats.ClientStats.Downsampled( - stats, interval=rdfvalue.Duration.From(10, rdfvalue.MINUTES)) + stats, interval=rdfvalue.Duration.From(10, rdfvalue.MINUTES) + ) self.assertEqual(actual, expected) @@ -388,21 +421,35 @@ def testFromPsutilProcess(self): res = rdf_client.Process.FromPsutilProcess(p) int_fields = [ - "pid", "ppid", "ctime", "num_threads", "user_cpu_time", - "system_cpu_time", "RSS_size", "VMS_size", "memory_percent" + "pid", + "ppid", + "ctime", + "num_threads", + "user_cpu_time", + "system_cpu_time", + "RSS_size", + "VMS_size", + "memory_percent", ] if platform.system() != "Windows": int_fields.extend([ - "real_uid", "effective_uid", "saved_uid", "real_gid", "effective_gid", - "saved_gid" + "real_uid", + "effective_uid", + "saved_uid", + "real_gid", + "effective_gid", + "saved_gid", ]) for field in int_fields: self.assertGreater( - getattr(res, field), 0, + getattr(res, field), + 0, "rdf_client.Process.{} is not greater than 0, got {!r}.".format( - field, getattr(res, field))) + field, getattr(res, field) + ), + ) string_fields = ["name", "exe", "cmdline", "cwd", "username"] @@ -411,8 +458,10 @@ def testFromPsutilProcess(self): for field in string_fields: self.assertNotEqual( - getattr(res, field), "", - "rdf_client.Process.{} is the empty string.".format(field)) + getattr(res, field), + "", + "rdf_client.Process.{} is the empty string.".format(field), + ) # Prevent flaky tests by allowing "sleeping" as state of current process. 
self.assertIn(res.status, ["running", "sleeping"]) diff --git a/grr/core/grr_response_core/lib/rdfvalues/cloud.py b/grr/core/grr_response_core/lib/rdfvalues/cloud.py index ba95988eeb..e61abc20f2 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/cloud.py +++ b/grr/core/grr_response_core/lib/rdfvalues/cloud.py @@ -1,14 +1,12 @@ #!/usr/bin/env python """Cloud-related rdfvalues.""" - - from grr_response_core.lib.rdfvalues import protodict as rdf_protodict from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 -AMAZON_URL_BASE = "http://169.254.169.254/latest/meta-data/" +AMAZON_URL_BASE = "http://169.254.169.254/latest/meta-data" AMAZON_BIOS_REGEX = ".*amazon" AMAZON_SERVICE_REGEX = "SERVICE_NAME: AWSLiteAgent" # Using the ip and not metadata.google.internal to avoid issues on endpoints @@ -19,11 +17,36 @@ class CloudMetadataRequest(rdf_structs.RDFProtoStruct): + """RDF wrapper for `CloudMetadataRequest` message.""" + protobuf = flows_pb2.CloudMetadataRequest rdf_deps = [ rdf_protodict.Dict, ] + @classmethod + def ForAmazon(cls, *args, **kwargs) -> "CloudMetadataRequest": + return cls( + *args, + **kwargs, + bios_version_regex=AMAZON_BIOS_REGEX, + service_name_regex=AMAZON_SERVICE_REGEX, + instance_type=CloudInstance.InstanceType.AMAZON, + timeout=1.0, + ) + + @classmethod + def ForGoogle(cls, *args, **kwargs) -> "CloudMetadataRequest": + return cls( + *args, + **kwargs, + bios_version_regex=GOOGLE_BIOS_REGEX, + service_name_regex=GOOGLE_SERVICE_REGEX, + headers={"Metadata-Flavor": "Google"}, + instance_type=CloudInstance.InstanceType.GOOGLE, + timeout=1.0, + ) + class CloudMetadataRequests(rdf_structs.RDFProtoStruct): protobuf = flows_pb2.CloudMetadataRequests @@ -59,63 +82,81 @@ class CloudInstance(rdf_structs.RDFProtoStruct): ] -def _MakeArgs(amazon_collection_map, google_collection_map): - """Build metadata requests list from collection maps.""" - request_list = [] - for url, label in amazon_collection_map.items(): - request_list.append( - CloudMetadataRequest( - bios_version_regex=AMAZON_BIOS_REGEX, - service_name_regex=AMAZON_SERVICE_REGEX, - instance_type="AMAZON", - timeout=1.0, - url=url, - label=label)) - for url, label in google_collection_map.items(): - request_list.append( - CloudMetadataRequest( - bios_version_regex=GOOGLE_BIOS_REGEX, - service_name_regex=GOOGLE_SERVICE_REGEX, - headers={"Metadata-Flavor": "Google"}, - instance_type="GOOGLE", - timeout=1.0, - url=url, - label=label)) - return request_list - - def MakeGoogleUniqueID(cloud_instance): """Make the google unique ID of zone/project/id.""" - if not (cloud_instance.zone and cloud_instance.project_id and - cloud_instance.instance_id): - raise ValueError("Bad zone/project_id/id: '%s/%s/%s'" % - (cloud_instance.zone, cloud_instance.project_id, - cloud_instance.instance_id)) + if not ( + cloud_instance.zone + and cloud_instance.project_id + and cloud_instance.instance_id + ): + raise ValueError( + "Bad zone/project_id/id: '%s/%s/%s'" + % ( + cloud_instance.zone, + cloud_instance.project_id, + cloud_instance.instance_id, + ) + ) return "/".join([ - cloud_instance.zone.split("/")[-1], cloud_instance.project_id, - cloud_instance.instance_id + cloud_instance.zone.split("/")[-1], + cloud_instance.project_id, + cloud_instance.instance_id, ]) def BuildCloudMetadataRequests(): """Build the standard set of cloud metadata to collect during interrogate.""" - amazon_collection_map = { - "/".join((AMAZON_URL_BASE, "instance-id")): 
"instance_id", - "/".join((AMAZON_URL_BASE, "ami-id")): "ami_id", - "/".join((AMAZON_URL_BASE, "hostname")): "hostname", - "/".join((AMAZON_URL_BASE, "public-hostname")): "public_hostname", - "/".join((AMAZON_URL_BASE, "instance-type")): "instance_type", - } - google_collection_map = { - "/".join((GOOGLE_URL_BASE, "instance/id")): "instance_id", - "/".join((GOOGLE_URL_BASE, "instance/zone")): "zone", - "/".join((GOOGLE_URL_BASE, "project/project-id")): "project_id", - "/".join((GOOGLE_URL_BASE, "instance/hostname")): "hostname", - "/".join((GOOGLE_URL_BASE, "instance/machine-type")): "machine_type", - } - - return CloudMetadataRequests(requests=_MakeArgs(amazon_collection_map, - google_collection_map)) + return CloudMetadataRequests( + requests=[ + CloudMetadataRequest.ForAmazon( + url="/".join((AMAZON_URL_BASE, "instance-id")), + label="instance_id", + ), + CloudMetadataRequest.ForAmazon( + url="/".join((AMAZON_URL_BASE, "ami-id")), + label="ami_id", + ), + CloudMetadataRequest.ForAmazon( + url="/".join((AMAZON_URL_BASE, "hostname")), + label="hostname", + ), + CloudMetadataRequest.ForAmazon( + url="/".join((AMAZON_URL_BASE, "public-hostname")), + label="public_hostname", + # Per AWS documentation, `public-hostname` is available only on + # hosts with public IPv4 addresses and `enableDnsHostnames` option + # set. Otherwise, attempts to query this property will result in + # 404 HTTP status. + # + # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-categories.html + ignore_http_errors=True, + ), + CloudMetadataRequest.ForAmazon( + url="/".join((AMAZON_URL_BASE, "instance-type")), + label="instance_type", + ), + CloudMetadataRequest.ForGoogle( + url="/".join((GOOGLE_URL_BASE, "instance/id")), + label="instance_id", + ), + CloudMetadataRequest.ForGoogle( + url="/".join((GOOGLE_URL_BASE, "instance/zone")), + label="zone", + ), + CloudMetadataRequest.ForGoogle( + url="/".join((GOOGLE_URL_BASE, "project/project-id")), + label="project_id", + ), + CloudMetadataRequest.ForGoogle( + url="/".join((GOOGLE_URL_BASE, "instance/hostname")), + label="hostname", + ), + CloudMetadataRequest.ForGoogle( + url="/".join((GOOGLE_URL_BASE, "instance/machine-type")), + label="machine_type", + ), + ], + ) def ConvertCloudMetadataResponsesToCloudInstance(metadata_responses): @@ -128,6 +169,7 @@ def ConvertCloudMetadataResponsesToCloudInstance(metadata_responses): Args: metadata_responses: CloudMetadataResponses object from the client. 
+ Returns: CloudInstance object Raises: @@ -141,7 +183,8 @@ def ConvertCloudMetadataResponsesToCloudInstance(metadata_responses): result = CloudInstance(cloud_type="AMAZON", amazon=cloud_instance) else: raise ValueError( - "Unknown cloud instance type: %s" % metadata_responses.instance_type) + "Unknown cloud instance type: %s" % metadata_responses.instance_type + ) for cloud_metadata in metadata_responses.responses: setattr(cloud_instance, cloud_metadata.label, cloud_metadata.text) diff --git a/grr/core/grr_response_core/lib/rdfvalues/cloud_test.py b/grr/core/grr_response_core/lib/rdfvalues/cloud_test.py index fc2e71b21b..c0aaf5f10c 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/cloud_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/cloud_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for grr.lib.rdfvalues.cloud.""" - from absl import app from grr_response_core.lib.rdfvalues import cloud as rdf_cloud @@ -14,10 +13,12 @@ def testMakeGoogleUniqueID(self): google_cloud_instance = rdf_cloud.GoogleCloudInstance( instance_id="1771384456894610289", zone="projects/123456789733/zones/us-central1-a", - project_id="myproject") + project_id="myproject", + ) self.assertEqual( rdf_cloud.MakeGoogleUniqueID(google_cloud_instance), - "us-central1-a/myproject/1771384456894610289") + "us-central1-a/myproject/1771384456894610289", + ) def main(argv): diff --git a/grr/core/grr_response_core/lib/rdfvalues/config.py b/grr/core/grr_response_core/lib/rdfvalues/config.py index 2f0d77bd2f..ce7e6bd6ab 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/config.py +++ b/grr/core/grr_response_core/lib/rdfvalues/config.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Implementations of RDFValues used in GRR config options definitions.""" - from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import config_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/config_file.py b/grr/core/grr_response_core/lib/rdfvalues/config_file.py index 0d1f77ecde..69d83acc23 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/config_file.py +++ b/grr/core/grr_response_core/lib/rdfvalues/config_file.py @@ -8,11 +8,13 @@ class LogTarget(rdf_structs.RDFProtoStruct): """An RDFValue representation of a logging target.""" + protobuf = config_file_pb2.LogTarget class LogConfig(rdf_structs.RDFProtoStruct): """An RDFValue representation of a logging configuration.""" + protobuf = config_file_pb2.LogConfig rdf_deps = [ LogTarget, @@ -21,36 +23,22 @@ class LogConfig(rdf_structs.RDFProtoStruct): class NfsClient(rdf_structs.RDFProtoStruct): """An RDFValue representation of an NFS Client configuration.""" + protobuf = config_file_pb2.NfsClient class NfsExport(rdf_structs.RDFProtoStruct): """An RDFValue representation of an NFS Export entry.""" + protobuf = config_file_pb2.NfsExport rdf_deps = [ NfsClient, ] -class SshdMatchBlock(rdf_structs.RDFProtoStruct): - """An RDFValue representation of an sshd config match block.""" - protobuf = config_file_pb2.SshdMatchBlock - rdf_deps = [ - rdf_protodict.AttributedDict, - ] - - -class SshdConfig(rdf_structs.RDFProtoStruct): - """An RDFValue representation of a sshd config file.""" - protobuf = config_file_pb2.SshdConfig - rdf_deps = [ - rdf_protodict.AttributedDict, - SshdMatchBlock, - ] - - class NtpConfig(rdf_structs.RDFProtoStruct): """An RDFValue representation of a ntp config file.""" + protobuf = config_file_pb2.NtpConfig rdf_deps = [ rdf_protodict.AttributedDict, @@ -59,37 +47,14 @@ class NtpConfig(rdf_structs.RDFProtoStruct): class 
PamConfigEntry(rdf_structs.RDFProtoStruct): """An RDFValue representation of a single entry in a PAM configuration.""" + protobuf = config_file_pb2.PamConfigEntry class PamConfig(rdf_structs.RDFProtoStruct): """An RDFValue representation of an entire PAM configuration.""" + protobuf = config_file_pb2.PamConfig rdf_deps = [ PamConfigEntry, ] - - -class SudoersAlias(rdf_structs.RDFProtoStruct): - """An RDFValue representation of a sudoers alias.""" - protobuf = config_file_pb2.SudoersAlias - - -class SudoersDefault(rdf_structs.RDFProtoStruct): - """An RDFValue representation of a sudoers default.""" - protobuf = config_file_pb2.SudoersDefault - - -class SudoersEntry(rdf_structs.RDFProtoStruct): - """An RDFValue representation of a sudoers file command list entry.""" - protobuf = config_file_pb2.SudoersEntry - - -class SudoersConfig(rdf_structs.RDFProtoStruct): - """An RDFValue representation of a sudoers config file.""" - protobuf = config_file_pb2.SudoersConfig - rdf_deps = [ - SudoersAlias, - SudoersDefault, - SudoersEntry, - ] diff --git a/grr/core/grr_response_core/lib/rdfvalues/cronjobs.py b/grr/core/grr_response_core/lib/rdfvalues/cronjobs.py index 5f1dc40157..3488fddbde 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/cronjobs.py +++ b/grr/core/grr_response_core/lib/rdfvalues/cronjobs.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """RDFValues for GRR client-side cron jobs parsing.""" - from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/crypto.py b/grr/core/grr_response_core/lib/rdfvalues/crypto.py index 43a083d36a..8fd061020b 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/crypto.py +++ b/grr/core/grr_response_core/lib/rdfvalues/crypto.py @@ -5,7 +5,6 @@ import hashlib import logging import os -from typing import Text from cryptography import exceptions from cryptography import x509 @@ -68,14 +67,17 @@ def __init__(self, initializer=None): elif isinstance(initializer, bytes): try: value = x509.load_pem_x509_certificate( - initializer, backend=openssl.backend) + initializer, backend=openssl.backend + ) except (ValueError, TypeError) as e: - raise rdfvalue.DecodeError("Invalid certificate %s: %s" % - (initializer, e)) + raise rdfvalue.DecodeError( + "Invalid certificate %s: %s" % (initializer, e) + ) super().__init__(value) else: - raise rdfvalue.InitializeError("Cannot initialize %s from %s." % - (self.__class__, initializer)) + raise rdfvalue.InitializeError( + "Cannot initialize %s from %s." % (self.__class__, initializer) + ) if self._value is not None: self.GetCN() # This can also raise if there isn't exactly one CN entry. 
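Note: the RDFX509Cert hunk above keeps the PEM parsing done via x509.load_pem_x509_certificate and the requirement that GetCN() finds exactly one CN entry. The following is a minimal, standalone sketch of those underlying cryptography-package calls; the helper name and the pem_bytes argument are illustrative assumptions, not GRR API.

# Illustrative sketch only, not part of the patch.
from cryptography import x509
from cryptography.x509 import oid


def CommonNameFromPEM(pem_bytes: bytes) -> str:
  # load_pem_x509_certificate raises ValueError/TypeError on malformed input;
  # the RDF wrapper above converts that into rdfvalue.DecodeError.
  cert = x509.load_pem_x509_certificate(pem_bytes)
  # The subject is expected to carry exactly one CN attribute, mirroring the
  # "exactly one CN entry" check mentioned in the hunk above.
  attrs = cert.subject.get_attributes_for_oid(oid.NameOID.COMMON_NAME)
  if len(attrs) != 1:
    raise ValueError("Expected exactly one CN entry, got %d" % len(attrs))
  return attrs[0].value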
@@ -109,8 +111,8 @@ def FromSerializedBytes(cls, value: bytes): return cls(value) @classmethod - def FromHumanReadable(cls, string: Text): - precondition.AssertType(string, Text) + def FromHumanReadable(cls, string: str): + precondition.AssertType(string, str) return cls.FromSerializedBytes(string.encode("ascii")) @classmethod @@ -128,7 +130,7 @@ def SerializeToBytes(self) -> bytes: def AsPEM(self): return self.SerializeToBytes() - def __str__(self) -> Text: + def __str__(self) -> str: return self.SerializeToBytes().decode("ascii") def Verify(self, public_key): @@ -158,54 +160,10 @@ def Verify(self, public_key): public_key.Verify( self._value.tbs_certificate_bytes, self._value.signature, - hash_algorithm=self._value.signature_hash_algorithm) + hash_algorithm=self._value.signature_hash_algorithm, + ) return True - @classmethod - def ClientCertFromCSR(cls, csr): - """Creates a new cert for the given common name. - - Args: - csr: A CertificateSigningRequest. - - Returns: - The signed cert. - """ - builder = x509.CertificateBuilder() - # Use the client CN for a cert serial_id. This will ensure we do - # not have clashing cert id. - common_name = csr.GetCN() - serial = int(common_name.split(".")[1], 16) - builder = builder.serial_number(serial) - builder = builder.subject_name( - x509.Name( - [x509.NameAttribute(oid.NameOID.COMMON_NAME, str(common_name))])) - - now = rdfvalue.RDFDatetime.Now() - now_plus_year = now + rdfvalue.Duration.From(52, rdfvalue.WEEKS) - builder = builder.not_valid_after(now_plus_year.AsDatetime()) - now_minus_ten = now - rdfvalue.Duration.From(10, rdfvalue.SECONDS) - builder = builder.not_valid_before(now_minus_ten.AsDatetime()) - # TODO(user): dependency loop with - # grr/core/grr_response_core/config/client.py. - # pylint: disable=protected-access - ca_cert = config_lib._CONFIG["CA.certificate"] - # pylint: enable=protected-access - builder = builder.issuer_name(ca_cert.GetIssuer()) - builder = builder.public_key(csr.GetPublicKey().GetRawPublicKey()) - - # TODO(user): dependency loop with - # grr/core/grr_response_core/config/client.py. - # pylint: disable=protected-access - ca_key = config_lib._CONFIG["PrivateKeys.ca_key"] - # pylint: enable=protected-access - - return RDFX509Cert( - builder.sign( - private_key=ca_key.GetRawPrivateKey(), - algorithm=hashes.SHA256(), - backend=openssl.backend)) - class CertificateSigningRequest(rdfvalue.RDFPrimitive): """A CSR Rdfvalue.""" @@ -219,17 +177,24 @@ def __init__(self, initializer=None, common_name=None, private_key=None): value = x509.load_pem_x509_csr(initializer, backend=openssl.backend) super().__init__(value) elif common_name and private_key: - value = x509.CertificateSigningRequestBuilder().subject_name( - x509.Name( - [x509.NameAttribute(oid.NameOID.COMMON_NAME, - str(common_name))])).sign( - private_key.GetRawPrivateKey(), - hashes.SHA256(), - backend=openssl.backend) + value = ( + x509.CertificateSigningRequestBuilder() + .subject_name( + x509.Name([ + x509.NameAttribute(oid.NameOID.COMMON_NAME, str(common_name)) + ]) + ) + .sign( + private_key.GetRawPrivateKey(), + hashes.SHA256(), + backend=openssl.backend, + ) + ) super().__init__(value) elif initializer is not None: - raise rdfvalue.InitializeError("Cannot initialize %s from %s." % - (self.__class__, initializer)) + raise rdfvalue.InitializeError( + "Cannot initialize %s from %s." 
% (self.__class__, initializer) + ) @classmethod def FromSerializedBytes(cls, value: bytes): @@ -251,7 +216,7 @@ def SerializeToBytes(self) -> bytes: def AsPEM(self): return self.SerializeToBytes() - def __str__(self) -> Text: + def __str__(self) -> str: return self.SerializeToBytes().decode("ascii") def GetCN(self): @@ -273,7 +238,8 @@ def Verify(self, public_key): public_key.Verify( self._value.tbs_certrequest_bytes, self._value.signature, - hash_algorithm=self._value.signature_hash_algorithm) + hash_algorithm=self._value.signature_hash_algorithm, + ) return True @@ -292,20 +258,22 @@ def __init__(self, initializer=None): super().__init__(initializer) return - if isinstance(initializer, Text): + if isinstance(initializer, str): initializer = initializer.encode("ascii") if isinstance(initializer, bytes): try: value = serialization.load_pem_public_key( - initializer, backend=openssl.backend) + initializer, backend=openssl.backend + ) super().__init__(value) return except (TypeError, ValueError, exceptions.UnsupportedAlgorithm) as e: raise type_info.TypeValueError("Public key invalid: %s" % e) - raise rdfvalue.InitializeError("Cannot initialize %s from %s." % - (self.__class__, initializer)) + raise rdfvalue.InitializeError( + "Cannot initialize %s from %s." % (self.__class__, initializer) + ) def GetRawPublicKey(self): return self._value @@ -321,8 +289,8 @@ def FromWireFormat(cls, value): return cls.FromSerializedBytes(value) @classmethod - def FromHumanReadable(cls, string: Text): - precondition.AssertType(string, Text) + def FromHumanReadable(cls, string: str): + precondition.AssertType(string, str) return cls.FromSerializedBytes(string.encode("ascii")) def SerializeToBytes(self) -> bytes: @@ -330,12 +298,13 @@ def SerializeToBytes(self) -> bytes: return b"" return self._value.public_bytes( encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo) + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) def GetN(self): return self._value.public_numbers().n - def __str__(self) -> Text: + def __str__(self) -> str: return self.SerializeToBytes().decode("ascii") # TODO(user): this should return a string, since PEM format @@ -358,7 +327,9 @@ def Encrypt(self, message): padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA1()), algorithm=hashes.SHA1(), - label=None)) + label=None, + ), + ) except ValueError as e: raise CipherError(e) @@ -370,16 +341,16 @@ def Verify(self, message, signature, hash_algorithm=None): if hash_algorithm is None: hash_algorithm = hashes.SHA256() - last_e = None for padding_algorithm in [ padding.PSS( - mgf=padding.MGF1(hash_algorithm), - salt_length=padding.PSS.MAX_LENGTH), - padding.PKCS1v15() + mgf=padding.MGF1(hash_algorithm), salt_length=padding.PSS.MAX_LENGTH + ), + padding.PKCS1v15(), ]: try: - self._value.verify(signature, message, padding_algorithm, - hash_algorithm) + self._value.verify( + signature, message, padding_algorithm, hash_algorithm + ) return True except exceptions.InvalidSignature as e: @@ -404,16 +375,18 @@ def __init__(self, initializer=None, allow_prompt=None): super().__init__(initializer) return - if isinstance(initializer, Text): + if isinstance(initializer, str): initializer = initializer.encode("ascii") if not isinstance(initializer, bytes): - raise rdfvalue.InitializeError("Cannot initialize %s from %s." % - (self.__class__, initializer)) + raise rdfvalue.InitializeError( + "Cannot initialize %s from %s." 
% (self.__class__, initializer) + ) try: value = serialization.load_pem_private_key( - initializer, password=None, backend=openssl.backend) + initializer, password=None, backend=openssl.backend + ) super().__init__(value) return except (TypeError, ValueError, exceptions.UnsupportedAlgorithm) as e: @@ -443,14 +416,15 @@ def __init__(self, initializer=None, allow_prompt=None): # The private key is encrypted and we can ask the user for the passphrase. password = utils.PassphraseCallback() value = serialization.load_pem_private_key( - initializer, password=password, backend=openssl.backend) + initializer, password=password, backend=openssl.backend + ) super().__init__(value) except (TypeError, ValueError, exceptions.UnsupportedAlgorithm) as e: raise type_info.TypeValueError("Unable to load private key: %s" % e) @classmethod - def FromHumanReadable(cls, string: Text): - precondition.AssertType(string, Text) + def FromHumanReadable(cls, string: str): + precondition.AssertType(string, str) return cls.FromSerializedBytes(string.encode("ascii")) def GetRawPrivateKey(self): @@ -459,17 +433,13 @@ def GetRawPrivateKey(self): def GetPublicKey(self): return RSAPublicKey(self._value.public_key()) - def Sign(self, message, use_pss=False): + def Sign(self, message): """Sign a given message.""" precondition.AssertType(message, bytes) - # TODO(amoser): This should use PSS by default at some point. - if not use_pss: - padding_algorithm = padding.PKCS1v15() - else: - padding_algorithm = padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH) - + padding_algorithm = padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH + ) return self._value.sign(message, padding_algorithm, hashes.SHA256()) def Decrypt(self, message): @@ -482,14 +452,17 @@ def Decrypt(self, message): padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA1()), algorithm=hashes.SHA1(), - label=None)) + label=None, + ), + ) except ValueError as e: raise CipherError(e) @classmethod def GenerateKey(cls, bits=2048, exponent=65537): key = rsa.generate_private_key( - public_exponent=exponent, key_size=bits, backend=openssl.backend) + public_exponent=exponent, key_size=bits, backend=openssl.backend + ) return cls(key) @classmethod @@ -500,7 +473,6 @@ def FromSerializedBytes(cls, value: bytes): @classmethod def FromWireFormat(cls, value): precondition.AssertType(value, bytes) - return cls(value) def SerializeToBytes(self) -> bytes: if self._value is None: @@ -508,9 +480,10 @@ def SerializeToBytes(self) -> bytes: return self._value.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption()) + encryption_algorithm=serialization.NoEncryption(), + ) - def __str__(self) -> Text: + def __str__(self) -> str: digest = hashlib.sha256(self.AsPEM()).hexdigest() return "%s (%s)" % ((self.__class__).__name__, digest) @@ -526,7 +499,8 @@ def AsPassphraseProtectedPEM(self, passphrase): return self._value.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(passphrase)) + encryption_algorithm=serialization.BestAvailableEncryption(passphrase), + ) def KeyLen(self): if self._value is None: @@ -546,6 +520,7 @@ class PEMPublicKey(RSAPublicKey): class Hash(rdf_structs.RDFProtoStruct): """A hash object containing multiple digests.""" + protobuf = jobs_pb2.Hash rdf_deps = [ 
rdf_standard.AuthenticodeSignedData, @@ -561,6 +536,7 @@ class SignedBlob(rdf_structs.RDFProtoStruct): The client can receive and verify a signed blob (e.g. driver or executable binary). Once verified, the client may execute this. """ + protobuf = jobs_pb2.SignedBlob def Verify(self, public_key): @@ -577,8 +553,11 @@ def Verify(self, public_key): """ if self.digest_type != self.HashType.SHA256: raise rdfvalue.DecodeError("Unsupported digest.") + # TODO: Remove PKCS1v15 signature type when client adopted + # change to PSS. if self.signature_type not in [ - self.SignatureType.RSA_PKCS1v15, self.SignatureType.RSA_PSS + self.SignatureType.RSA_PKCS1v15, + self.SignatureType.RSA_PSS, ]: raise rdfvalue.DecodeError("Unsupported signature type.") @@ -606,7 +585,7 @@ def Sign(self, data, signing_key, verify_key=None): logging.warning("signing key is too short.") self.signature = signing_key.Sign(data) - self.signature_type = self.SignatureType.RSA_PKCS1v15 + self.signature_type = self.SignatureType.RSA_PSS self.digest = hashlib.sha256(data).digest() self.digest_type = self.HashType.SHA256 @@ -636,8 +615,9 @@ def __init__(self, initializer=None): precondition.AssertType(initializer, bytes) if len(initializer) % 8: - raise CipherError("Invalid key length %d (%s)." % - (len(initializer) * 8, initializer)) + raise CipherError( + "Invalid key length %d (%s)." % (len(initializer) * 8, initializer) + ) super().__init__(initializer) @@ -656,17 +636,17 @@ def FromSerializedBytes(cls, value: bytes): return cls(value) @classmethod - def FromHumanReadable(cls, string: Text): - precondition.AssertType(string, Text) + def FromHumanReadable(cls, string: str): + precondition.AssertType(string, str) return cls(binascii.unhexlify(string)) - def __str__(self) -> Text: + def __str__(self) -> str: return "%s (%s)" % (self.__class__.__name__, self.AsHexDigest()) def __len__(self) -> int: return len(self._value) - def AsHexDigest(self) -> Text: + def AsHexDigest(self) -> str: return text.Hexify(self._value) def SerializeToBytes(self): @@ -684,7 +664,7 @@ def RawBytes(self): return self._value -class StreamingCBCEncryptor(object): +class StreamingCBCEncryptor: """A class to stream data to a CBCCipher object.""" def __init__(self, cipher): @@ -707,7 +687,7 @@ def Finalize(self): return res -class AES128CBCCipher(object): +class AES128CBCCipher: """A Cipher using AES128 in CBC mode and PKCS7 for padding.""" algorithm = None @@ -732,8 +712,8 @@ def UnPad(self, padded_data): def GetEncryptor(self): return ciphers.Cipher( - algorithms.AES(self.key), modes.CBC(self.iv), - backend=openssl.backend).encryptor() + algorithms.AES(self.key), modes.CBC(self.iv), backend=openssl.backend + ).encryptor() def Encrypt(self, data): """A convenience method which pads and encrypts at once.""" @@ -747,8 +727,8 @@ def Encrypt(self, data): def GetDecryptor(self): return ciphers.Cipher( - algorithms.AES(self.key), modes.CBC(self.iv), - backend=openssl.backend).decryptor() + algorithms.AES(self.key), modes.CBC(self.iv), backend=openssl.backend + ).decryptor() def Decrypt(self, data): """A convenience method which pads and decrypts at once.""" @@ -763,6 +743,7 @@ def Decrypt(self, data): class SymmetricCipher(rdf_structs.RDFProtoStruct): """Abstract symmetric cipher operations.""" + protobuf = jobs_pb2.SymmetricCipher rdf_deps = [ EncryptionKey, @@ -776,7 +757,8 @@ def Generate(cls, algorithm): return cls( _algorithm=algorithm, _key=EncryptionKey.GenerateKey(length=128), - _iv=EncryptionKey.GenerateKey(length=128)) + 
_iv=EncryptionKey.GenerateKey(length=128), + ) def _get_cipher(self): if self._algorithm != self.Algorithm.AES128CBC: @@ -797,7 +779,7 @@ def Decrypt(self, data): return self._get_cipher().Decrypt(data) -class HMAC(object): +class HMAC: """A wrapper for the cryptography HMAC object.""" def __init__(self, key, use_sha256=False): @@ -848,6 +830,7 @@ def Verify(self, message, signature): class Password(rdf_structs.RDFProtoStruct): """A password stored in the database.""" + protobuf = jobs_pb2.Password def _CalculateHash(self, password, salt, iteration_count): @@ -856,7 +839,8 @@ def _CalculateHash(self, password, salt, iteration_count): length=32, salt=salt, iterations=iteration_count, - backend=openssl.backend) + backend=openssl.backend, + ) return kdf.derive(password) def SetPassword(self, password): @@ -864,15 +848,16 @@ def SetPassword(self, password): self.iteration_count = 100000 # prevent non-descriptive 'key_material must be bytes' error later - if isinstance(password, Text): + if isinstance(password, str): password = password.encode("utf-8") - self.hashed_pwd = self._CalculateHash(password, self.salt, - self.iteration_count) + self.hashed_pwd = self._CalculateHash( + password, self.salt, self.iteration_count + ) def CheckPassword(self, password): # prevent non-descriptive 'key_material must be bytes' error later - if isinstance(password, Text): + if isinstance(password, str): password = password.encode("utf-8") h = self._CalculateHash(password, self.salt, self.iteration_count) diff --git a/grr/core/grr_response_core/lib/rdfvalues/crypto_test.py b/grr/core/grr_response_core/lib/rdfvalues/crypto_test.py index 4633dd5ee1..ee19e85e61 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/crypto_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/crypto_test.py @@ -1,10 +1,10 @@ #!/usr/bin/env python """Crypto rdfvalue tests.""" - import binascii import hashlib import os +import unittest from unittest import mock from absl import app @@ -25,7 +25,8 @@ class SignedBlobTest(rdf_test_base.RDFValueTestMixin, test_lib.GRRBaseTest): def setUp(self): super().setUp() self.private_key = config.CONFIG[ - "PrivateKeys.executable_signing_private_key"] + "PrivateKeys.executable_signing_private_key" + ] self.public_key = config.CONFIG["Client.executable_signing_public_key"] def GenerateSample(self, number=0): @@ -34,6 +35,13 @@ def GenerateSample(self, number=0): return result + @unittest.skip( + "Samples are expected to be different for the same data as PSS padding" + " generates a new Hash for every sample." + ) + def testComparisons(self): + pass + def testSignVerify(self): sample = self.GenerateSample() @@ -41,15 +49,17 @@ def testSignVerify(self): # Change the data - this should fail since the hash is incorrect. sample.data += b"X" - self.assertRaises(rdf_crypto.VerificationError, sample.Verify, - self.public_key) + self.assertRaises( + rdf_crypto.VerificationError, sample.Verify, self.public_key + ) # Update the hash sample.digest = hashlib.sha256(sample.data).digest() # Should still fail. - self.assertRaises(rdf_crypto.VerificationError, sample.Verify, - self.public_key) + self.assertRaises( + rdf_crypto.VerificationError, sample.Verify, self.public_key + ) # If we change the digest verification should fail. sample = self.GenerateSample() @@ -60,7 +70,7 @@ def testSignVerify(self): # PSS should be accepted. 
sample = self.GenerateSample() sample.signature_type = sample.SignatureType.RSA_PSS - sample.signature = self.private_key.Sign(sample.data, use_pss=1) + sample.signature = self.private_key.Sign(sample.data) sample.Verify(self.public_key) @@ -81,36 +91,16 @@ def setUp(self): config_stubber.Start() self.addCleanup(config_stubber.Stop) - def testInvalidX509Certificates(self): - """Deliberately try to parse an invalid certificate.""" - config.CONFIG.Initialize(data=""" -[Frontend] -certificate = -----BEGIN CERTIFICATE----- - MIIDczCCAVugAwIBAgIJANdK3LO+9qOIMA0GCSqGSIb3DQEBCwUAMFkxCzAJBgNV - uqnFquJfg8xMWHHJmPEocDpJT8Tlmbw= - -----END CERTIFICATE----- -""") - config.CONFIG.context = [] - - errors = config.CONFIG.Validate("Frontend") - self.assertCountEqual(list(errors.keys()), ["Frontend.certificate"]) - def testInvalidRSAPrivateKey(self): """Deliberately try to parse invalid RSA keys.""" config.CONFIG.Initialize(data=""" [PrivateKeys] -server_key = -----BEGIN PRIVATE KEY----- - MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBAMdgLNxyvDnQsuqp - jzITFeE6mjs3k1I= - -----END PRIVATE KEY----- executable_signing_private_key = -----BEGIN RSA PRIVATE KEY----- MIIBOgIBAAJBALnfFW1FffeKPs5PLUhFOSkNrr9TDCODQAI3WluLh0sW7/ro93eo -----END RSA PRIVATE KEY----- """) config.CONFIG.context = [] - with self.assertRaises(config_lib.ConfigFormatError): - config.CONFIG.Get("PrivateKeys.server_key") with self.assertRaises(config_lib.ConfigFormatError): config.CONFIG.Get("PrivateKeys.executable_signing_private_key") @@ -144,7 +134,8 @@ def testRSAPublicKeyFailure(self): errors = config.CONFIG.Validate("Client") self.assertCountEqual( - list(errors.keys()), ["Client.executable_signing_public_key"]) + list(errors.keys()), ["Client.executable_signing_public_key"] + ) def testRSAPrivate(self): """Tests parsing an RSA private key.""" @@ -181,7 +172,8 @@ def testRSAPrivate(self): config.CONFIG.context = [] self.assertIsInstance( config.CONFIG.Get("PrivateKeys.executable_signing_private_key"), - rdf_crypto.RSAPrivateKey) + rdf_crypto.RSAPrivateKey, + ) class CryptoUtilTest(CryptoTestBase): @@ -193,23 +185,29 @@ def testStreamingCBCEncryptor(self): message = b"Hello World!!!!!" * 10 for plaintext, partitions in [ - (message, [ - [160], - [80, 80], - [75, 75, 10], - [1, 159], - [10] * 16, - [1] * 160, - ]), + ( + message, + [ + [160], + [80, 80], + [75, 75, 10], + [1, 159], + [10] * 16, + [1] * 160, + ], + ), # Prime length, not a multiple of blocksize. 
- (message[:149], [ - [149], - [80, 69], - [75, 55, 19], - [1, 148], - [10] * 14 + [9], - [1] * 149, - ]) + ( + message[:149], + [ + [149], + [80, 69], + [75, 55, 19], + [1, 148], + [10] * 14 + [9], + [1] * 149, + ], + ), ]: for partition in partitions: cipher = rdf_crypto.AES128CBCCipher(key, iv) @@ -217,7 +215,7 @@ def testStreamingCBCEncryptor(self): offset = 0 out = [] for n in partition: - next_partition = plaintext[offset:offset + n] + next_partition = plaintext[offset : offset + n] out.append(streaming_cbc.Update(next_partition)) offset += n out.append(streaming_cbc.Finalize()) @@ -263,8 +261,9 @@ def testAES128CBCCipher(self): self.assertRaises(rdf_crypto.CipherError, cipher.Decrypt, plain_text) -class SymmetricCipherTest(rdf_test_base.RDFValueTestMixin, - test_lib.GRRBaseTest): +class SymmetricCipherTest( + rdf_test_base.RDFValueTestMixin, test_lib.GRRBaseTest +): rdfvalue_class = rdf_crypto.SymmetricCipher sample_cache = {} @@ -320,8 +319,11 @@ def testPassPhraseEncryption(self): with self.assertRaises(type_info.TypeValueError): rdf_crypto.RSAPrivateKey(protected_pem) - with mock.patch.object(config.CONFIG, "context", - config.CONFIG.context + ["Commandline Context"]): + with mock.patch.object( + config.CONFIG, + "context", + config.CONFIG.context + ["Commandline Context"], + ): rdf_crypto.RSAPrivateKey(protected_pem) # allow_prompt=False even prevents this in the Commandline Context. @@ -344,12 +346,21 @@ def testSignVerify(self): broken_signature = _Tamper(signature) broken_message = _Tamper(message) - self.assertRaises(rdf_crypto.VerificationError, public_key.Verify, message, - broken_signature) - self.assertRaises(rdf_crypto.VerificationError, public_key.Verify, - broken_message, signature) - self.assertRaises(rdf_crypto.VerificationError, public_key.Verify, message, - b"") + self.assertRaises( + rdf_crypto.VerificationError, + public_key.Verify, + message, + broken_signature, + ) + self.assertRaises( + rdf_crypto.VerificationError, + public_key.Verify, + broken_message, + signature, + ) + self.assertRaises( + rdf_crypto.VerificationError, public_key.Verify, message, b"" + ) def testEncryptDecrypt(self): private_key = rdf_crypto.RSAPrivateKey.GenerateKey(bits=2048) @@ -363,26 +374,15 @@ def testEncryptDecrypt(self): plaintext = private_key.Decrypt(ciphertext) self.assertEqual(plaintext, message) - self.assertRaises(rdf_crypto.CipherError, private_key.Decrypt, - _Tamper(ciphertext)) - - def testPSSPadding(self): - private_key = rdf_crypto.RSAPrivateKey.GenerateKey(bits=2048) - public_key = private_key.GetPublicKey() - message = b"Hello World!" - - # Generate two different signtures, one using PKCS1v15 padding, one using - # PSS. The crypto code should accept both as valid. - signature_pkcs1v15 = private_key.Sign(message) - signature_pss = private_key.Sign(message, use_pss=True) - self.assertNotEqual(signature_pkcs1v15, signature_pss) - public_key.Verify(message, signature_pkcs1v15) - public_key.Verify(message, signature_pss) + self.assertRaises( + rdf_crypto.CipherError, private_key.Decrypt, _Tamper(ciphertext) + ) def testM2CryptoSigningCompatibility(self): pem = open(os.path.join(self.base_path, "m2crypto/rsa_key"), "rb").read() - signature = open(os.path.join(self.base_path, "m2crypto/signature"), - "rb").read() + signature = open( + os.path.join(self.base_path, "m2crypto/signature"), "rb" + ).read() private_key = rdf_crypto.RSAPrivateKey(pem) message = b"Signed by M2Crypto!" 
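Note: in this patch RSAPrivateKey.Sign drops its use_pss flag and always signs with PSS padding, while RSAPublicKey.Verify keeps accepting both PSS and PKCS1v15 signatures during the transition. Below is a minimal sketch of the equivalent sign/verify round trip using the cryptography package directly; the generated key and message are made-up illustrations, not values from the patch.

# Illustrative sketch only, not part of the patch.
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa

private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
message = b"example payload"  # Hypothetical data to sign.
pss_padding = padding.PSS(
    mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
)
signature = private_key.sign(message, pss_padding, hashes.SHA256())
public_key = private_key.public_key()
# verify() raises InvalidSignature on a mismatch instead of returning False.
public_key.verify(signature, message, pss_padding, hashes.SHA256())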
@@ -395,7 +395,8 @@ def testM2CryptoEncryptionCompatibility(self): pem = open(os.path.join(self.base_path, "m2crypto/rsa_key"), "rb").read() private_key = rdf_crypto.RSAPrivateKey(pem) ciphertext = open( - os.path.join(self.base_path, "m2crypto/rsa_ciphertext"), "rb").read() + os.path.join(self.base_path, "m2crypto/rsa_ciphertext"), "rb" + ).read() message = b"Encrypted by M2Crypto!" plaintext = private_key.Decrypt(ciphertext) @@ -414,12 +415,17 @@ def testHMAC(self): h.Verify(message, signature) broken_message = message + b"!" - self.assertRaises(rdf_crypto.VerificationError, h.Verify, broken_message, - signature) + self.assertRaises( + rdf_crypto.VerificationError, h.Verify, broken_message, signature + ) broken_signature = _Tamper(signature) - self.assertRaises(rdf_crypto.VerificationError, h.Verify, b"Hello World!", - broken_signature) + self.assertRaises( + rdf_crypto.VerificationError, + h.Verify, + b"Hello World!", + broken_signature, + ) def testSHA256(self): """Tests that both types of signatures are ok.""" @@ -437,7 +443,8 @@ def testM2CryptoCompatibility(self): message = b"HMAC by M2Crypto!" signature = binascii.unhexlify("99cae3ec7b41ceb6e6619f2f85368cb3ae118b70") key = rdf_crypto.EncryptionKey.FromHumanReadable( - "94bd4e0ecc8397a8b2cdbc4b127ee7b0") + "94bd4e0ecc8397a8b2cdbc4b127ee7b0" + ) h = rdf_crypto.HMAC(key) self.assertEqual(h.HMAC(message), signature) @@ -447,27 +454,10 @@ def testM2CryptoCompatibility(self): class RDFX509CertTest(CryptoTestBase): - def testCertificateVerification(self): - private_key = rdf_crypto.RSAPrivateKey.GenerateKey() - csr = rdf_crypto.CertificateSigningRequest( - common_name="C.0000000000000001", private_key=private_key) - client_cert = rdf_crypto.RDFX509Cert.ClientCertFromCSR(csr) - - ca_signing_key = config.CONFIG["PrivateKeys.ca_key"] - - csr.Verify(private_key.GetPublicKey()) - client_cert.Verify(ca_signing_key.GetPublicKey()) - - wrong_key = rdf_crypto.RSAPrivateKey.GenerateKey() - with self.assertRaises(rdf_crypto.VerificationError): - csr.Verify(wrong_key.GetPublicKey()) - - with self.assertRaises(rdf_crypto.VerificationError): - client_cert.Verify(wrong_key.GetPublicKey()) - def testExpiredTestCertificate(self): - pem = open(os.path.join(self.base_path, "outdated_certificate"), - "rb").read() + pem = open( + os.path.join(self.base_path, "outdated_certificate"), "rb" + ).read() certificate = rdf_crypto.RDFX509Cert(pem) exception_catcher = self.assertRaises(rdf_crypto.VerificationError) @@ -478,23 +468,6 @@ def testExpiredTestCertificate(self): self.assertIn("Certificate expired!", str(exception_catcher.exception)) - def testCertificateValidation(self): - private_key = rdf_crypto.RSAPrivateKey.GenerateKey() - csr = rdf_crypto.CertificateSigningRequest( - common_name="C.0000000000000001", private_key=private_key) - client_cert = rdf_crypto.RDFX509Cert.ClientCertFromCSR(csr) - - now = rdfvalue.RDFDatetime.Now() - now_plus_year_and_a_bit = now + rdfvalue.Duration.From(55, rdfvalue.WEEKS) - now_minus_a_bit = now - rdfvalue.Duration.From(1, rdfvalue.HOURS) - with test_lib.FakeTime(now_plus_year_and_a_bit): - with self.assertRaises(rdf_crypto.VerificationError): - client_cert.Verify(private_key.GetPublicKey()) - - with test_lib.FakeTime(now_minus_a_bit): - with self.assertRaises(rdf_crypto.VerificationError): - client_cert.Verify(private_key.GetPublicKey()) - class PasswordTest(CryptoTestBase): diff --git a/grr/core/grr_response_core/lib/rdfvalues/events.py b/grr/core/grr_response_core/lib/rdfvalues/events.py index 1db980c293..6d53141350 100644 
--- a/grr/core/grr_response_core/lib/rdfvalues/events.py +++ b/grr/core/grr_response_core/lib/rdfvalues/events.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """RDF values related to events.""" - from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_core.lib.util import random diff --git a/grr/core/grr_response_core/lib/rdfvalues/file_finder.py b/grr/core/grr_response_core/lib/rdfvalues/file_finder.py index 452710779c..2c84618f9e 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/file_finder.py +++ b/grr/core/grr_response_core/lib/rdfvalues/file_finder.py @@ -67,7 +67,8 @@ def Validate(self): # The literal must not be empty in the literal match condition. if not self.HasField("literal") or not self.literal: raise ValueError( - "No literal provided to FileFinderContentsLiteralMatchCondition.") + "No literal provided to FileFinderContentsLiteralMatchCondition." + ) class FileFinderCondition(rdf_structs.RDFProtoStruct): @@ -218,6 +219,7 @@ def Download(cls, **kwargs): class FileFinderArgs(rdf_structs.RDFProtoStruct): """An RDF value representing file finder flow arguments.""" + protobuf = flows_pb2.FileFinderArgs rdf_deps = [ FileFinderAction, @@ -244,12 +246,14 @@ class FileFinderResult(rdf_structs.RDFProtoStruct): class CollectFilesByKnownPathArgs(rdf_structs.RDFProtoStruct): """Arguments for CollectFilesByKnownPath.""" + protobuf = flows_pb2.CollectFilesByKnownPathArgs rdf_deps = [] class CollectFilesByKnownPathResult(rdf_structs.RDFProtoStruct): """Result returned by CollectFilesByKnownPath.""" + protobuf = flows_pb2.CollectFilesByKnownPathResult rdf_deps = [ rdf_crypto.Hash, @@ -259,12 +263,14 @@ class CollectFilesByKnownPathResult(rdf_structs.RDFProtoStruct): class CollectFilesByKnownPathProgress(rdf_structs.RDFProtoStruct): """Progress returned by CollectFilesByKnownPath.""" + protobuf = flows_pb2.CollectFilesByKnownPathProgress rdf_deps = [] class CollectMultipleFilesArgs(rdf_structs.RDFProtoStruct): """Arguments for CollectMultipleFiles.""" + protobuf = flows_pb2.CollectMultipleFilesArgs rdf_deps = [ rdf_paths.GlobExpression, @@ -280,6 +286,7 @@ class CollectMultipleFilesArgs(rdf_structs.RDFProtoStruct): class CollectMultipleFilesResult(rdf_structs.RDFProtoStruct): """Result returned by CollectMultipleFiles.""" + protobuf = flows_pb2.CollectMultipleFilesResult rdf_deps = [ rdf_crypto.Hash, @@ -289,6 +296,7 @@ class CollectMultipleFilesResult(rdf_structs.RDFProtoStruct): class CollectMultipleFilesProgress(rdf_structs.RDFProtoStruct): """Progress returned by CollectMultipleFiles.""" + protobuf = flows_pb2.CollectMultipleFilesProgress rdf_deps = [] diff --git a/grr/core/grr_response_core/lib/rdfvalues/flows.py b/grr/core/grr_response_core/lib/rdfvalues/flows.py index 1fad73e2d1..3389374e89 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/flows.py +++ b/grr/core/grr_response_core/lib/rdfvalues/flows.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """RDFValue implementations related to flow scheduling.""" - import threading import time @@ -18,6 +17,7 @@ class GrrMessage(rdf_structs.RDFProtoStruct): """An RDFValue class to manage GRR messages.""" + protobuf = jobs_pb2.GrrMessage rdf_deps = [ rdf_protodict.EmbeddedRDFValue, @@ -31,11 +31,9 @@ class GrrMessage(rdf_structs.RDFProtoStruct): next_id_base = 0 max_ttl = 5 - def __init__(self, - initializer=None, - payload=None, - generate_task_id=False, - **kwarg): + def __init__( + self, initializer=None, payload=None, generate_task_id=False, **kwarg + ): 
super().__init__(initializer=initializer, **kwarg) if payload is not None: @@ -83,13 +81,15 @@ def HasTaskID(self): @property def args(self): - raise RuntimeError("Direct access to serialized args is not permitted! " - "Use payload field.") + raise RuntimeError( + "Direct access to serialized args is not permitted! Use payload field." + ) @args.setter def args(self, value): - raise RuntimeError("Direct access to serialized args is not permitted! " - "Use payload field.") + raise RuntimeError( + "Direct access to serialized args is not permitted! Use payload field." + ) @property def payload(self): @@ -121,6 +121,7 @@ class GrrStatus(rdf_structs.RDFProtoStruct): followed by a single status message. The GrrStatus message contains error and traceback information for any failures on the client. """ + protobuf = jobs_pb2.GrrStatus rdf_deps = [ rdf_client_stats.CpuSeconds, @@ -142,6 +143,7 @@ class Notification(rdf_structs.RDFProtoStruct): Usually the notification means that some operation is completed, and provides a link to view the results. """ + protobuf = jobs_pb2.Notification rdf_deps = [ rdfvalue.RDFDatetime, @@ -154,7 +156,7 @@ class Notification(rdf_structs.RDFProtoStruct): "FlowStatus", # Link to a flow "GrantAccess", # Link to an access grant page "ArchiveGenerationFinished", - "Error" + "Error", ] @@ -168,6 +170,7 @@ class FlowNotification(rdf_structs.RDFProtoStruct): class NotificationList(rdf_protodict.RDFValueArray): """A List of notifications for this user.""" + rdf_type = Notification @@ -191,6 +194,7 @@ def __len__(self): class CipherProperties(rdf_structs.RDFProtoStruct): """Contains information about a cipher and keys.""" + protobuf = jobs_pb2.CipherProperties rdf_deps = [ rdf_crypto.EncryptionKey, @@ -224,6 +228,7 @@ class CipherMetadata(rdf_structs.RDFProtoStruct): class FlowLog(rdf_structs.RDFProtoStruct): """An RDFValue class representing flow log entries.""" + protobuf = jobs_pb2.FlowLog rdf_deps = [ rdf_client.ClientURN, @@ -250,4 +255,5 @@ class ClientCommunication(rdf_structs.RDFProtoStruct): class EmptyFlowArgs(rdf_structs.RDFProtoStruct): """Some flows do not take argumentnts.""" + protobuf = flows_pb2.EmptyFlowArgs diff --git a/grr/core/grr_response_core/lib/rdfvalues/flows_test.py b/grr/core/grr_response_core/lib/rdfvalues/flows_test.py index 30701d4be9..6975af96a3 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/flows_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/flows_test.py @@ -1,8 +1,6 @@ #!/usr/bin/env python """Test for the flow state class.""" - - from absl import app from grr_response_core.lib import rdfvalue @@ -26,24 +24,45 @@ def testSessionIDValidation(self): rdfvalue.SessionID(rdfvalue.RDFURN("aff4:/flows/DEBUG-user1:12345678:hunt")) def testBadStructure(self): - self.assertRaises(rdfvalue.InitializeError, rdfvalue.SessionID, - rdfvalue.RDFURN("aff4:/flows/A:123456:1:")) - self.assertRaises(rdfvalue.InitializeError, rdfvalue.SessionID, - rdfvalue.RDFURN("aff4:/flows/A:123456::")) - self.assertRaises(rdfvalue.InitializeError, rdfvalue.SessionID, - rdfvalue.RDFURN("aff4:/flows/A:123456:")) - self.assertRaises(rdfvalue.InitializeError, rdfvalue.SessionID, - rdfvalue.RDFURN("aff4:/flows/A:")) - self.assertRaises(rdfvalue.InitializeError, rdfvalue.SessionID, - rdfvalue.RDFURN("aff4:/flows/:")) + self.assertRaises( + rdfvalue.InitializeError, + rdfvalue.SessionID, + rdfvalue.RDFURN("aff4:/flows/A:123456:1:"), + ) + self.assertRaises( + rdfvalue.InitializeError, + rdfvalue.SessionID, + rdfvalue.RDFURN("aff4:/flows/A:123456::"), + ) + 
self.assertRaises( + rdfvalue.InitializeError, + rdfvalue.SessionID, + rdfvalue.RDFURN("aff4:/flows/A:123456:"), + ) + self.assertRaises( + rdfvalue.InitializeError, + rdfvalue.SessionID, + rdfvalue.RDFURN("aff4:/flows/A:"), + ) + self.assertRaises( + rdfvalue.InitializeError, + rdfvalue.SessionID, + rdfvalue.RDFURN("aff4:/flows/:"), + ) def testBadQueue(self): - self.assertRaises(rdfvalue.InitializeError, rdfvalue.SessionID, - rdfvalue.RDFURN("aff4:/flows/A%b:12345678")) + self.assertRaises( + rdfvalue.InitializeError, + rdfvalue.SessionID, + rdfvalue.RDFURN("aff4:/flows/A%b:12345678"), + ) def testBadFlowID(self): - self.assertRaises(rdfvalue.InitializeError, rdfvalue.SessionID, - rdfvalue.RDFURN("aff4:/flows/A:1234567G%sdf")) + self.assertRaises( + rdfvalue.InitializeError, + rdfvalue.SessionID, + rdfvalue.RDFURN("aff4:/flows/A:1234567G%sdf"), + ) def main(argv): diff --git a/grr/core/grr_response_core/lib/rdfvalues/large_file.py b/grr/core/grr_response_core/lib/rdfvalues/large_file.py index 63c741e900..f21a9096ab 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/large_file.py +++ b/grr/core/grr_response_core/lib/rdfvalues/large_file.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with RDF wrappers for large file collection proto messages.""" + from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import large_file_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/memory.py b/grr/core/grr_response_core/lib/rdfvalues/memory.py index 4e2f073111..3d56b32279 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/memory.py +++ b/grr/core/grr_response_core/lib/rdfvalues/memory.py @@ -23,6 +23,7 @@ class YaraSignatureShard(rdf_structs.RDFProtoStruct): class YaraProcessScanRequest(rdf_structs.RDFProtoStruct): """Args for YaraProcessScan flow and client action.""" + protobuf = flows_pb2.YaraProcessScanRequest rdf_deps = [ YaraSignature, @@ -36,7 +37,9 @@ def __init__(self, *args, **kwargs): # These default values were migrated from the Protobuf definition. 
if not self.HasField("include_errors_in_results"): - self.include_errors_in_results = YaraProcessScanRequest.ErrorPolicy.NO_ERRORS + self.include_errors_in_results = ( + YaraProcessScanRequest.ErrorPolicy.NO_ERRORS + ) if not self.HasField("include_misses_in_results"): self.include_misses_in_results = False if not self.HasField("ignore_grr_process"): @@ -59,6 +62,8 @@ def __init__(self, *args, **kwargs): self.dump_process_on_match = False if not self.HasField("process_dump_size_limit"): self.process_dump_size_limit = 0 + if not self.HasField("context_window"): + self.context_window = 50 class ProcessMemoryError(rdf_structs.RDFProtoStruct): @@ -68,6 +73,7 @@ class ProcessMemoryError(rdf_structs.RDFProtoStruct): class YaraStringMatch(rdf_structs.RDFProtoStruct): """A result of Yara string matching.""" + protobuf = flows_pb2.YaraStringMatch rdf_deps = [] @@ -82,16 +88,22 @@ def FromLibYaraStringMatch(cls, yara_string_match): class YaraMatch(rdf_structs.RDFProtoStruct): """A result of Yara matching.""" + protobuf = flows_pb2.YaraMatch rdf_deps = [YaraStringMatch] @classmethod - def FromLibYaraMatch(cls, yara_match): + def FromLibYaraMatch(cls, yara_match, data, context_window): res = cls() res.rule_name = yara_match.rule - res.string_matches = [ - YaraStringMatch.FromLibYaraStringMatch(sm) for sm in yara_match.strings - ] + string_matches = [] + for sm in yara_match.strings: + yara_string_match = YaraStringMatch.FromLibYaraStringMatch(sm) + if context_window > 0: + context = data[sm[0] - context_window : sm[0] + context_window] + yara_string_match.context = context + string_matches.append(yara_string_match) + res.string_matches = string_matches return res @@ -112,6 +124,7 @@ class YaraProcessScanResponse(rdf_structs.RDFProtoStruct): class YaraProcessDumpArgs(rdf_structs.RDFProtoStruct): """Args for DumpProcessMemory flow and YaraProcessDump client action.""" + protobuf = flows_pb2.YaraProcessDumpArgs rdf_deps = [rdfvalue.ByteSize] diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_anomaly.py b/grr/core/grr_response_core/lib/rdfvalues/mig_anomaly.py index 92c53f2cc3..32ad0f7afc 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_anomaly.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_anomaly.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import anomaly as rdf_anomaly from grr_response_proto import anomaly_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_artifacts.py b/grr/core/grr_response_core/lib/rdfvalues/mig_artifacts.py index f179cb4e24..7ba0ffa6c2 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_artifacts.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_artifacts.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import artifacts as rdf_artifacts from grr_response_proto import artifact_pb2 from grr_response_proto import flows_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_client.py b/grr/core/grr_response_core/lib/rdfvalues/mig_client.py index c6eddcefdc..4396cc7752 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_client.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_client.py @@ -1,19 +1,12 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_proto import 
jobs_pb2 from grr_response_proto import knowledge_base_pb2 from grr_response_proto import sysinfo_pb2 -def ToProtoPCIDevice(rdf: rdf_client.PCIDevice) -> sysinfo_pb2.PCIDevice: - return rdf.AsPrimitiveProto() - - -def ToRDFPCIDevice(proto: sysinfo_pb2.PCIDevice) -> rdf_client.PCIDevice: - return rdf_client.PCIDevice.FromSerializedBytes(proto.SerializeToString()) - - def ToProtoPackageRepository( rdf: rdf_client.PackageRepository, ) -> sysinfo_pb2.PackageRepository: diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_client_action.py b/grr/core/grr_response_core/lib/rdfvalues/mig_client_action.py index 1265a4d00c..f5494c6370 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_client_action.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_client_action.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import client_action as rdf_client_action from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_client_fs.py b/grr/core/grr_response_core/lib/rdfvalues/mig_client_fs.py index 03bd80a9d1..6e60fe8e47 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_client_fs.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_client_fs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_client_network.py b/grr/core/grr_response_core/lib/rdfvalues/mig_client_network.py index 4872af9cf3..b36e776cfd 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_client_network.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_client_network.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import client_network as rdf_client_network from grr_response_proto import jobs_pb2 from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_client_stats.py b/grr/core/grr_response_core/lib/rdfvalues/mig_client_stats.py index e8cea2618c..7bf12f0619 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_client_stats.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_client_stats.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import client_stats as rdf_client_stats from grr_response_proto import jobs_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_cloud.py b/grr/core/grr_response_core/lib/rdfvalues/mig_cloud.py index e997b9575e..4d5d50330d 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_cloud.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_cloud.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import cloud as rdf_cloud from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_config.py b/grr/core/grr_response_core/lib/rdfvalues/mig_config.py index 9558c7a042..843c7b35c0 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_config.py +++ 
b/grr/core/grr_response_core/lib/rdfvalues/mig_config.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import config as rdf_config from grr_response_proto import config_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_config_file.py b/grr/core/grr_response_core/lib/rdfvalues/mig_config_file.py index 4aa12f6640..7355b40fd3 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_config_file.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_config_file.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import config_file as rdf_config_file from grr_response_proto import config_file_pb2 @@ -60,34 +61,6 @@ def ToRDFNfsExport( ) -def ToProtoSshdMatchBlock( - rdf: rdf_config_file.SshdMatchBlock, -) -> config_file_pb2.SshdMatchBlock: - return rdf.AsPrimitiveProto() - - -def ToRDFSshdMatchBlock( - proto: config_file_pb2.SshdMatchBlock, -) -> rdf_config_file.SshdMatchBlock: - return rdf_config_file.SshdMatchBlock.FromSerializedBytes( - proto.SerializeToString() - ) - - -def ToProtoSshdConfig( - rdf: rdf_config_file.SshdConfig, -) -> config_file_pb2.SshdConfig: - return rdf.AsPrimitiveProto() - - -def ToRDFSshdConfig( - proto: config_file_pb2.SshdConfig, -) -> rdf_config_file.SshdConfig: - return rdf_config_file.SshdConfig.FromSerializedBytes( - proto.SerializeToString() - ) - - def ToProtoNtpConfig( rdf: rdf_config_file.NtpConfig, ) -> config_file_pb2.NtpConfig: @@ -128,59 +101,3 @@ def ToRDFPamConfig( return rdf_config_file.PamConfig.FromSerializedBytes( proto.SerializeToString() ) - - -def ToProtoSudoersAlias( - rdf: rdf_config_file.SudoersAlias, -) -> config_file_pb2.SudoersAlias: - return rdf.AsPrimitiveProto() - - -def ToRDFSudoersAlias( - proto: config_file_pb2.SudoersAlias, -) -> rdf_config_file.SudoersAlias: - return rdf_config_file.SudoersAlias.FromSerializedBytes( - proto.SerializeToString() - ) - - -def ToProtoSudoersDefault( - rdf: rdf_config_file.SudoersDefault, -) -> config_file_pb2.SudoersDefault: - return rdf.AsPrimitiveProto() - - -def ToRDFSudoersDefault( - proto: config_file_pb2.SudoersDefault, -) -> rdf_config_file.SudoersDefault: - return rdf_config_file.SudoersDefault.FromSerializedBytes( - proto.SerializeToString() - ) - - -def ToProtoSudoersEntry( - rdf: rdf_config_file.SudoersEntry, -) -> config_file_pb2.SudoersEntry: - return rdf.AsPrimitiveProto() - - -def ToRDFSudoersEntry( - proto: config_file_pb2.SudoersEntry, -) -> rdf_config_file.SudoersEntry: - return rdf_config_file.SudoersEntry.FromSerializedBytes( - proto.SerializeToString() - ) - - -def ToProtoSudoersConfig( - rdf: rdf_config_file.SudoersConfig, -) -> config_file_pb2.SudoersConfig: - return rdf.AsPrimitiveProto() - - -def ToRDFSudoersConfig( - proto: config_file_pb2.SudoersConfig, -) -> rdf_config_file.SudoersConfig: - return rdf_config_file.SudoersConfig.FromSerializedBytes( - proto.SerializeToString() - ) diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_cronjobs.py b/grr/core/grr_response_core/lib/rdfvalues/mig_cronjobs.py index 8bcfadbf09..afaee049ff 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_cronjobs.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_cronjobs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import cronjobs as rdf_cronjobs 
from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_crypto.py b/grr/core/grr_response_core/lib/rdfvalues/mig_crypto.py index bbe5bac3a8..e5483f1c68 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_crypto.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_crypto.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import crypto as rdf_crypto from grr_response_proto import jobs_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_dummy.py b/grr/core/grr_response_core/lib/rdfvalues/mig_dummy.py index 620d9d102e..416b131c93 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_dummy.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_dummy.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import dummy as rdf_dummy from grr_response_proto import dummy_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_events.py b/grr/core/grr_response_core/lib/rdfvalues/mig_events.py index 7c31c85261..4b5c696a3e 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_events.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_events.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import events as rdf_events from grr_response_proto import jobs_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_file_finder.py b/grr/core/grr_response_core/lib/rdfvalues/mig_file_finder.py index cc5415cd35..b73cf8b9c9 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_file_finder.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_file_finder.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import file_finder as rdf_file_finder from grr_response_proto import flows_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_flows.py b/grr/core/grr_response_core/lib/rdfvalues/mig_flows.py index 31c5efa864..1a72ff5330 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_flows.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_flows.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_large_file.py b/grr/core/grr_response_core/lib/rdfvalues/mig_large_file.py index 19326b0281..d3952ee28f 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_large_file.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_large_file.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import large_file as rdf_large_file from grr_response_proto import large_file_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_memory.py b/grr/core/grr_response_core/lib/rdfvalues/mig_memory.py index b199642cdc..643c4cb0ea 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_memory.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_memory.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during 
RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import memory as rdf_memory from grr_response_proto import flows_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_osquery.py b/grr/core/grr_response_core/lib/rdfvalues/mig_osquery.py index 56cb1ae3ad..b1a3a6f6d5 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_osquery.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_osquery.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import osquery as rdf_osquery from grr_response_proto import osquery_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_paths.py b/grr/core/grr_response_core/lib/rdfvalues/mig_paths.py index a32d033a22..e38fd7318b 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_paths.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_paths.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_plist.py b/grr/core/grr_response_core/lib/rdfvalues/mig_plist.py index 273e15e2f6..212266137b 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_plist.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_plist.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import plist as rdf_plist from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_protodict.py b/grr/core/grr_response_core/lib/rdfvalues/mig_protodict.py index 9db4945476..91c91e63a3 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_protodict.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_protodict.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from typing import Any, Dict from grr_response_core.lib.rdfvalues import protodict as rdf_protodict @@ -76,7 +77,7 @@ def FromProtoAttributedDictToNativeDict( def FromNativeDictToProtoAttributedDict( - dictionary: Dict[Any, Any] + dictionary: Dict[Any, Any], ) -> jobs_pb2.AttributedDict: rdf_dict = rdf_protodict.AttributedDict().FromDict(dictionary) return ToProtoAttributedDict(rdf_dict) diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_read_low_level.py b/grr/core/grr_response_core/lib/rdfvalues/mig_read_low_level.py index 757e161f59..20a5f53f61 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_read_low_level.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_read_low_level.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import read_low_level as rdf_read_low_level from grr_response_proto import read_low_level_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_search.py b/grr/core/grr_response_core/lib/rdfvalues/mig_search.py index 8427f63d22..c4aefc9a73 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_search.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_search.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import search as rdf_search from 
grr_response_proto.api import search_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_standard.py b/grr/core/grr_response_core/lib/rdfvalues/mig_standard.py index 7c65d245f6..be7657dfdc 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_standard.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_standard.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import standard as rdf_standard from grr_response_proto import jobs_pb2 from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_stats.py b/grr/core/grr_response_core/lib/rdfvalues/mig_stats.py index 4724d6ebce..82e352e36c 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_stats.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_stats.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import stats as rdf_stats from grr_response_proto import analysis_pb2 from grr_response_proto import jobs_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_structs.py b/grr/core/grr_response_core/lib/rdfvalues/mig_structs.py index 8c33f6cbc1..42e11b4cae 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_structs.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_structs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from google.protobuf import any_pb2 from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import semantic_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_timeline.py b/grr/core/grr_response_core/lib/rdfvalues/mig_timeline.py index 389dd9c8d3..3093badac4 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_timeline.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_timeline.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import timeline as rdf_timeline from grr_response_proto import timeline_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_webhistory.py b/grr/core/grr_response_core/lib/rdfvalues/mig_webhistory.py index 3f5f6fc72c..d9a7d1f550 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_webhistory.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_webhistory.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import webhistory as rdf_webhistory from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_wkt.py b/grr/core/grr_response_core/lib/rdfvalues/mig_wkt.py index 82e4efe721..e064bdab58 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_wkt.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_wkt.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from google.protobuf import timestamp_pb2 from grr_response_core.lib.rdfvalues import wkt as rdf_wkt diff --git a/grr/core/grr_response_core/lib/rdfvalues/mig_wmi.py b/grr/core/grr_response_core/lib/rdfvalues/mig_wmi.py index 6efdd5d854..9f47410595 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/mig_wmi.py +++ b/grr/core/grr_response_core/lib/rdfvalues/mig_wmi.py @@ -1,5 +1,6 @@ 
#!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_core.lib.rdfvalues import wmi as rdf_wmi from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/osquery_test.py b/grr/core/grr_response_core/lib/rdfvalues/osquery_test.py index dc315e9549..8e48095407 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/osquery_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/osquery_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl.testing import absltest from grr_response_core.lib.rdfvalues import osquery as rdf_osquery diff --git a/grr/core/grr_response_core/lib/rdfvalues/paths.py b/grr/core/grr_response_core/lib/rdfvalues/paths.py index 9d67a73a33..2240d5653f 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/paths.py +++ b/grr/core/grr_response_core/lib/rdfvalues/paths.py @@ -15,10 +15,10 @@ On the server the PathSpec is represented as a PathSpec object, and stored as an attribute of the AFF4 object. This module defines this abstraction. """ + import itertools import posixpath import re - from typing import Sequence from grr_response_core.lib import artifact_utils @@ -37,6 +37,7 @@ class PathSpec(rdf_structs.RDFProtoStruct): class makes it easier to manipulate these structures by providing useful helpers. """ + protobuf = jobs_pb2.PathSpec rdf_deps = [ rdfvalue.ByteSize, @@ -236,7 +237,9 @@ def AFF4Path(self, client_urn): if not self.HasField("pathtype"): raise ValueError( "Can't determine AFF4 path without a valid pathtype for {}.".format( - self)) + self + ) + ) first_component = self[0] dev = first_component.path @@ -244,8 +247,11 @@ def AFF4Path(self, client_urn): # We divide here just to get prettier numbers in the GUI dev += ":{}".format(first_component.offset // 512) - if (len(self) > 1 and first_component.pathtype == PathSpec.PathType.OS and - self[1].pathtype in (PathSpec.PathType.TSK, PathSpec.PathType.NTFS)): + if ( + len(self) > 1 + and first_component.pathtype == PathSpec.PathType.OS + and self[1].pathtype in (PathSpec.PathType.TSK, PathSpec.PathType.NTFS) + ): result = [self.AFF4_PREFIXES[self[1].pathtype], dev] # Skip the top level pathspec. 
@@ -282,6 +288,7 @@ def _unique(iterable): class GlobComponentExplanation(rdf_structs.RDFProtoStruct): """A sub-part of a GlobExpression with examples.""" + protobuf = flows_pb2.GlobComponentExplanation @@ -291,11 +298,21 @@ class GlobComponentExplanation(rdf_structs.RDFProtoStruct): _VAR_PATTERN = re.compile("(" + "|".join([r"%%\w+%%", r"%%\w+\.\w+%%"]) + ")") _REGEX_SPLIT_PATTERN = re.compile( - "(" + "|".join(["{[^}]+,[^}]+}", "\\?", "\\*\\*\\/?", "\\*"]) + ")") - -_COMPONENT_SPLIT_PATTERN = re.compile("(" + "|".join([ - r"{[^}]+,[^}]+}", r"\?", r"\*\*\d*/?", r"\*", r"%%\w+%%", r"%%\w+\.\w+%%" -]) + ")") + "(" + "|".join(["{[^}]+,[^}]+}", "\\?", "\\*\\*\\/?", "\\*"]) + ")" +) + +_COMPONENT_SPLIT_PATTERN = re.compile( + "(" + + "|".join([ + r"{[^}]+,[^}]+}", + r"\?", + r"\*\*\d*/?", + r"\*", + r"%%\w+%%", + r"%%\w+\.\w+%%", + ]) + + ")" +) class GlobExpression(rdfvalue.RDFString): @@ -337,7 +354,7 @@ def InterpolateGrouping(self, pattern): components = [] offset = 0 for match in GROUPING_PATTERN.finditer(pattern): - components.append([pattern[offset:match.start()]]) + components.append([pattern[offset : match.start()]]) # Expand the attribute into the set of possibilities: alternatives = match.group(1).split(",") @@ -348,7 +365,7 @@ def InterpolateGrouping(self, pattern): # Now calculate the cartesian products of all these sets to form all # strings. for vector in itertools.product(*components): - yield u"".join(vector) + yield "".join(vector) def _ReplaceRegExGrouping(self, grouping): alternatives = grouping.group(1).split(",") @@ -366,8 +383,9 @@ def _ReplaceRegExPart(self, part): else: return re.escape(part) - def ExplainComponents(self, example_count: int, - knowledge_base) -> Sequence[GlobComponentExplanation]: + def ExplainComponents( + self, example_count: int, knowledge_base + ) -> Sequence[GlobComponentExplanation]: """Returns a list of GlobComponentExplanations with examples.""" parts = _COMPONENT_SPLIT_PATTERN.split(self._value) components = [] @@ -388,7 +406,8 @@ def ExplainComponents(self, example_count: int, # possible values, this should still be enough. try: examples = artifact_utils.InterpolateKbAttributes( - glob_part, knowledge_base) + glob_part, knowledge_base + ) except artifact_utils.Error: # Interpolation can fail for many non-critical reasons, e.g. when the # client is missing a KB attribute. @@ -410,6 +429,6 @@ def AsRegEx(self): A RegularExpression() object. 
""" parts = _REGEX_SPLIT_PATTERN.split(self._value) - result = u"".join(self._ReplaceRegExPart(p) for p in parts) + result = "".join(self._ReplaceRegExPart(p) for p in parts) - return rdf_standard.RegularExpression(u"(?i)\\A%s\\Z" % result) + return rdf_standard.RegularExpression("(?i)\\A%s\\Z" % result) diff --git a/grr/core/grr_response_core/lib/rdfvalues/paths_test.py b/grr/core/grr_response_core/lib/rdfvalues/paths_test.py index 53c83cff47..0b38a6f91e 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/paths_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/paths_test.py @@ -29,7 +29,8 @@ def GenerateSample(self, number=0): def testPop(self): """Test we can pop arbitrary elements from the pathspec.""" sample = rdf_paths.PathSpec( - path="/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/", pathtype=rdf_paths.PathSpec.PathType.OS + ) for i in range(5): sample.Append(path=str(i), pathtype=rdf_paths.PathSpec.PathType.OS) @@ -54,7 +55,8 @@ def testPathSpec(self): pathspec_pb.nested_path.pathtype = 2 reference_pathspec = rdf_paths.PathSpec.FromSerializedBytes( - pathspec_pb.SerializeToString()) + pathspec_pb.SerializeToString() + ) # Create a new RDFPathspec from scratch. pathspec = rdf_paths.PathSpec() @@ -107,14 +109,15 @@ def testPathSpec(self): def testUnicodePaths(self): """Test that we can manipulate paths in unicode.""" - sample = rdf_paths.PathSpec(pathtype=1, path=u"/dev/c/msn升级程序[1].exe") + sample = rdf_paths.PathSpec(pathtype=1, path="/dev/c/msn升级程序[1].exe") # Ensure we can convert to a string. str(sample) def testCopy(self): sample = rdf_paths.PathSpec( - path="/", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/", pathtype=rdf_paths.PathSpec.PathType.OS + ) sample.Append(path="foo", pathtype=rdf_paths.PathSpec.PathType.TSK) # Make a copy of the original and change it. @@ -165,18 +168,21 @@ def testGroupingInterpolation(self): glob_expression = rdf_paths.GlobExpression() interpolated = glob_expression.InterpolateGrouping("/home/*.{sh,deb}") - self.assertCountEqual(interpolated, [u"/home/*.deb", u"/home/*.sh"]) + self.assertCountEqual(interpolated, ["/home/*.deb", "/home/*.sh"]) interpolated = glob_expression.InterpolateGrouping("/home/*.{sh, deb}") - self.assertCountEqual(interpolated, [u"/home/*. deb", u"/home/*.sh"]) + self.assertCountEqual(interpolated, ["/home/*. deb", "/home/*.sh"]) interpolated = glob_expression.InterpolateGrouping( - "HKEY_CLASSES_ROOT/CLSID/{16d12736-7a9e-4765-bec6-f301d679caaa}") + "HKEY_CLASSES_ROOT/CLSID/{16d12736-7a9e-4765-bec6-f301d679caaa}" + ) self.assertCountEqual( interpolated, - [u"HKEY_CLASSES_ROOT/CLSID/{16d12736-7a9e-4765-bec6-f301d679caaa}"]) + ["HKEY_CLASSES_ROOT/CLSID/{16d12736-7a9e-4765-bec6-f301d679caaa}"], + ) def testValidation(self): glob_expression = rdf_paths.GlobExpression( - "/home/%%users.username%%/**/.mozilla/") + "/home/%%users.username%%/**/.mozilla/" + ) glob_expression.Validate() glob_expression = rdf_paths.GlobExpression("/home/**/**") @@ -259,28 +265,33 @@ def testRegExIsCaseInsensitive(self): self.assertFalse(regex.Match("/foo/bar2/blah.COM")) def testGlobExpressionSplitsIntoExplainableComponents(self): - kb = rdf_client.KnowledgeBase(users=[ - rdf_client.User(homedir="/home/foo"), - rdf_client.User(homedir="/home/bar"), - rdf_client.User(homedir="/home/baz"), - ]) + kb = rdf_client.KnowledgeBase( + users=[ + rdf_client.User(homedir="/home/foo"), + rdf_client.User(homedir="/home/bar"), + rdf_client.User(homedir="/home/baz"), + ] + ) # Test for preservation of **/ because it behaves different to **. 
ge = rdf_paths.GlobExpression("/foo/**/{bar,baz}/bar?/.*baz") components = ge.ExplainComponents(2, kb) self.assertEqual( [c.glob_expression for c in components], - ["/foo/", "**/", "{bar,baz}", "/bar", "?", "/.", "*", "baz"]) + ["/foo/", "**/", "{bar,baz}", "/bar", "?", "/.", "*", "baz"], + ) ge = rdf_paths.GlobExpression("/foo/**bar") components = ge.ExplainComponents(2, kb) - self.assertEqual([c.glob_expression for c in components], - ["/foo/", "**", "bar"]) + self.assertEqual( + [c.glob_expression for c in components], ["/foo/", "**", "bar"] + ) ge = rdf_paths.GlobExpression("/foo/**10bar") components = ge.ExplainComponents(2, kb) - self.assertEqual([c.glob_expression for c in components], - ["/foo/", "**10", "bar"]) + self.assertEqual( + [c.glob_expression for c in components], ["/foo/", "**10", "bar"] + ) ge = rdf_paths.GlobExpression("/{foo,bar,baz}") components = ge.ExplainComponents(2, kb) @@ -288,8 +299,9 @@ def testGlobExpressionSplitsIntoExplainableComponents(self): ge = rdf_paths.GlobExpression("%%users.homedir%%/foo") components = ge.ExplainComponents(2, kb) - self.assertEqual([c.glob_expression for c in components], - ["%%users.homedir%%", "/foo"]) + self.assertEqual( + [c.glob_expression for c in components], ["%%users.homedir%%", "/foo"] + ) self.assertEqual(components[0].examples, ["/home/foo", "/home/bar"]) def testExplainComponentsReturnsEmptyExamplesOnKbError(self): @@ -300,7 +312,8 @@ def testExplainComponentsReturnsEmptyExamplesOnKbError(self): self.assertEqual(components[0].examples, []) def _testAFF4Path_mountPointResolution( - self, pathtype: rdf_paths.PathSpec.PathType) -> None: + self, pathtype: rdf_paths.PathSpec.PathType + ) -> None: path = rdf_paths.PathSpec( path="\\\\.\\Volume{1234}\\", pathtype=rdf_paths.PathSpec.PathType.OS, @@ -308,11 +321,13 @@ def _testAFF4Path_mountPointResolution( nested_path=rdf_paths.PathSpec( path="/windows/", pathtype=pathtype, - )) + ), + ) prefix = rdf_paths.PathSpec.AFF4_PREFIXES[pathtype] self.assertEqual( str(path.AFF4Path(rdf_client.ClientURN("C.0000000000000001"))), - f"aff4:/C.0000000000000001{prefix}/\\\\.\\Volume{{1234}}\\/windows") + f"aff4:/C.0000000000000001{prefix}/\\\\.\\Volume{{1234}}\\/windows", + ) def testAFF4Path_mountPointResolution_TSK(self): self._testAFF4Path_mountPointResolution(rdf_paths.PathSpec.PathType.TSK) diff --git a/grr/core/grr_response_core/lib/rdfvalues/plist.py b/grr/core/grr_response_core/lib/rdfvalues/plist.py index fb6196acca..0fa9d56bf1 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/plist.py +++ b/grr/core/grr_response_core/lib/rdfvalues/plist.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Plist related rdfvalues.""" - from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/proto2.py b/grr/core/grr_response_core/lib/rdfvalues/proto2.py index 766a745594..30590e3821 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/proto2.py +++ b/grr/core/grr_response_core/lib/rdfvalues/proto2.py @@ -9,11 +9,9 @@ library. """ - import inspect import logging - from grr_response_core.lib import rdfvalue from grr_response_core.lib import type_info from grr_response_proto import semantic_pb2 @@ -92,12 +90,14 @@ def DefineFromWireFormat(cls, protobuf): different names. 
""" mro_chain = [c.__name__ for c in inspect.getmro(cls)] - if (protobuf.__name__ not in mro_chain and - not getattr(cls, "allow_custom_class_name", False)): + if protobuf.__name__ not in mro_chain and not getattr( + cls, "allow_custom_class_name", False + ): raise ProtobufNameMustMatchClassOrParentClassError( "Can't define RDFProtoStruct class %s from proto %s " - "(proto name must match one of the classes in the MRO chain: %s)" % - (cls.__name__, protobuf.__name__, ", ".join(mro_chain))) + "(proto name must match one of the classes in the MRO chain: %s)" + % (cls.__name__, protobuf.__name__, ", ".join(mro_chain)) + ) cls.recorded_rdf_deps = set() @@ -127,7 +127,8 @@ def DefineFromWireFormat(cls, protobuf): name=field.name, friendly_name=options.friendly_name, field_number=field.number, - labels=list(options.label)) + labels=list(options.label), + ) if field.has_default_value: kwargs["default"] = field.default_value @@ -137,34 +138,46 @@ def DefineFromWireFormat(cls, protobuf): cls.recorded_rdf_deps.add(options.type) rdf_type = rdfvalue.RDFValue.classes.get(options.type) if rdf_type: - if (CHECK_PROTOBUF_DEPENDENCIES and rdf_type not in cls.rdf_deps and - options.type not in cls.rdf_deps): + if ( + CHECK_PROTOBUF_DEPENDENCIES + and rdf_type not in cls.rdf_deps + and options.type not in cls.rdf_deps + ): raise rdfvalue.InitializeError( "%s.%s: field %s is of type %s, " - "but type is missing from its dependencies list" % - (cls.__module__, cls.__name__, field.name, options.type)) + "but type is missing from its dependencies list" + % (cls.__module__, cls.__name__, field.name, options.type) + ) # Make sure that the field type is the same as what is required by the # semantic type. required_field_type = _SEMANTIC_PRIMITIVE_TO_FIELD_TYPE[ - rdf_type.protobuf_type] + rdf_type.protobuf_type + ] if required_field_type != field.type: raise rdfvalue.InitializeError( - ("%s: .proto file uses incorrect field to store Semantic Value " - "%s: Should be %s") % - (cls.__name__, field.name, rdf_type.protobuf_type)) + ( + "%s: .proto file uses incorrect field to store Semantic Value" + " %s: Should be %s" + ) + % (cls.__name__, field.name, rdf_type.protobuf_type) + ) type_descriptor = classes_dict["ProtoRDFValue"]( - rdf_type=options.type, **kwargs) + rdf_type=options.type, **kwargs + ) # A semantic protobuf is already a semantic value so it is an error to # specify it in two places. elif options.type and field.type == TYPE_MESSAGE: raise rdfvalue.InitializeError( - ("%s: .proto file specified both Semantic Value type %s and " - "Semantic protobuf %s") % - (cls.__name__, options.type, field.message_type.name)) + ( + "%s: .proto file specified both Semantic Value type %s and " + "Semantic protobuf %s" + ) + % (cls.__name__, options.type, field.message_type.name) + ) # Try to figure out what this field actually is from the descriptor. 
elif field.type == TYPE_DOUBLE: @@ -188,18 +201,22 @@ def DefineFromWireFormat(cls, protobuf): dynamic_cb = getattr(cls, options.dynamic_type, None) if dynamic_cb is not None: type_descriptor = classes_dict["ProtoDynamicEmbedded"]( - dynamic_cb=dynamic_cb, **kwargs) + dynamic_cb=dynamic_cb, **kwargs + ) else: - logging.warning("Dynamic type specifies a non existent callback %s", - options.dynamic_type) + logging.warning( + "Dynamic type specifies a non existent callback %s", + options.dynamic_type, + ) - elif (field.type == TYPE_MESSAGE and field.message_type.name == "Any"): + elif field.type == TYPE_MESSAGE and field.message_type.name == "Any": if options.no_dynamic_type_lookup: type_descriptor = classes_dict["ProtoAnyValue"](**kwargs) else: dynamic_cb = getattr(cls, options.dynamic_type, None) type_descriptor = classes_dict["ProtoDynamicAnyValueEmbedded"]( - dynamic_cb=dynamic_cb, **kwargs) + dynamic_cb=dynamic_cb, **kwargs + ) elif field.type == TYPE_INT64 or field.type == TYPE_INT32: type_descriptor = classes_dict["ProtoSignedInteger"](**kwargs) @@ -214,22 +231,31 @@ def DefineFromWireFormat(cls, protobuf): # when it is known. Therefore this can actually also refer to this current # protobuf (i.e. nested proto). type_descriptor = classes_dict["ProtoEmbedded"]( - nested=field.message_type.name, **kwargs) + nested=field.message_type.name, **kwargs + ) cls.recorded_rdf_deps.add(field.message_type.name) if CHECK_PROTOBUF_DEPENDENCIES: found = False for d in cls.rdf_deps: - if (hasattr(d, "__name__") and d.__name__ == field.message_type.name - or d == field.message_type.name): + if ( + hasattr(d, "__name__") + and d.__name__ == field.message_type.name + or d == field.message_type.name + ): found = True if not found: raise rdfvalue.InitializeError( "%s.%s: TYPE_MESSAGE field %s is %s, " - "but type is missing from its dependencies list" % - (cls.__module__, cls.__name__, field.name, - field.message_type.name)) + "but type is missing from its dependencies list" + % ( + cls.__module__, + cls.__name__, + field.name, + field.message_type.name, + ) + ) # TODO(user): support late binding here. if type_descriptor.type: @@ -260,10 +286,18 @@ def DefineFromWireFormat(cls, protobuf): # different protobuf. if semantic_protobuf_primitive != field.message_type.name: raise rdfvalue.InitializeError( - ("%s.%s: Conflicting primitive (%s) and semantic protobuf %s " - "which implements primitive protobuf (%s)") % - (cls.__name__, field.name, field.message_type.name, - type_descriptor.type.__name__, semantic_protobuf_primitive)) + ( + "%s.%s: Conflicting primitive (%s) and semantic protobuf %s " + "which implements primitive protobuf (%s)" + ) + % ( + cls.__name__, + field.name, + field.message_type.name, + type_descriptor.type.__name__, + semantic_protobuf_primitive, + ) + ) elif field.enum_type: # It is an enum. 
# TODO(hanuszczak): Protobuf descriptors use `bytes` objects to represent @@ -282,7 +316,8 @@ def DefineFromWireFormat(cls, protobuf): enum_dict[enum_value_name] = enum_value.number description = enum_value.GetOptions().Extensions[ - semantic_pb2.description] + semantic_pb2.description + ] enum_descriptions[enum_value_name] = description labels = [ label @@ -295,7 +330,8 @@ def DefineFromWireFormat(cls, protobuf): enum=enum_dict, enum_descriptions=enum_descriptions, enum_labels=enum_labels, - **kwargs) + **kwargs, + ) # Attach the enum container to the class for easy reference: setattr(cls, enum_desc_name, type_descriptor.enum_container) @@ -306,7 +342,8 @@ def DefineFromWireFormat(cls, protobuf): if field.label == LABEL_REPEATED: options = field.GetOptions().Extensions[semantic_pb2.sem_type] type_descriptor = classes_dict["ProtoList"]( - type_descriptor, labels=list(options.label)) + type_descriptor, labels=list(options.label) + ) try: cls.AddDescriptor(type_descriptor) @@ -328,5 +365,6 @@ def DefineFromWireFormat(cls, protobuf): leftover_deps.remove(d) if leftover_deps: raise rdfvalue.InitializeError( - "Found superfluous dependencies for %s: %s" % - (cls.__name__, ",".join(leftover_deps))) + "Found superfluous dependencies for %s: %s" + % (cls.__name__, ",".join(leftover_deps)) + ) diff --git a/grr/core/grr_response_core/lib/rdfvalues/protodict.py b/grr/core/grr_response_core/lib/rdfvalues/protodict.py index ebfc5f77f8..e8c1fcb072 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/protodict.py +++ b/grr/core/grr_response_core/lib/rdfvalues/protodict.py @@ -21,8 +21,11 @@ class EmbeddedRDFValue(rdf_structs.RDFProtoStruct): ] def __init__(self, initializer=None, payload=None, *args, **kwargs): - if (not payload and isinstance(initializer, rdfvalue.RDFValue) and - not isinstance(initializer, EmbeddedRDFValue)): + if ( + not payload + and isinstance(initializer, rdfvalue.RDFValue) + and not isinstance(initializer, EmbeddedRDFValue) + ): # The initializer is an RDFValue object that we can use as payload. payload = initializer initializer = None @@ -52,6 +55,7 @@ def payload(self, payload): class DataBlob(rdf_structs.RDFProtoStruct): """Wrapper class for DataBlob protobuf.""" + protobuf = jobs_pb2.DataBlob rdf_deps = [ "BlobArray", # TODO(user): dependency loop. 
@@ -72,8 +76,14 @@ def SetValue(self, value, raise_on_error=True): Raises: TypeError: if the value can't be serialized and raise_on_error is True """ - type_mappings = [(Text, "string"), (bytes, "data"), (bool, "boolean"), - (int, "integer"), (dict, "dict"), (float, "float")] + type_mappings = [ + (Text, "string"), + (bytes, "data"), + (bool, "boolean"), + (int, "integer"), + (dict, "dict"), + (float, "float"), + ] if value is None: self.none = "None" @@ -83,14 +93,14 @@ def SetValue(self, value, raise_on_error=True): self.rdf_value.name = value.__class__.__name__ elif isinstance(value, (list, tuple)): - self.list.content.Extend([ - DataBlob().SetValue(v, raise_on_error=raise_on_error) for v in value - ]) + self.list.content.Extend( + [DataBlob().SetValue(v, raise_on_error=raise_on_error) for v in value] + ) elif isinstance(value, set): - self.set.content.Extend([ - DataBlob().SetValue(v, raise_on_error=raise_on_error) for v in value - ]) + self.set.content.Extend( + [DataBlob().SetValue(v, raise_on_error=raise_on_error) for v in value] + ) elif isinstance(value, dict): self.dict.FromDict(value, raise_on_error=raise_on_error) @@ -102,8 +112,10 @@ def SetValue(self, value, raise_on_error=True): return self - message = "Unsupported type for ProtoDict: %s of type %s" % (value, - type(value)) + message = "Unsupported type for ProtoDict: %s of type %s" % ( + value, + type(value), + ) if raise_on_error: raise TypeError(message) @@ -120,8 +132,15 @@ def GetValue(self, ignore_error=True): return None field_names = [ - "integer", "string", "data", "boolean", "list", "dict", "rdf_value", - "float", "set" + "integer", + "string", + "data", + "boolean", + "list", + "dict", + "rdf_value", + "float", + "set", ] values = [getattr(self, x) for x in field_names if self.HasField(x)] @@ -168,6 +187,7 @@ class Dict(rdf_structs.RDFProtoStruct): The dict may contain strings (python unicode objects), int64, or binary blobs (python string objects) as keys and values. """ + protobuf = jobs_pb2.Dict rdf_deps = [ KeyValue, @@ -212,7 +232,8 @@ def FromDict(self, dictionary, raise_on_error=True): for key, value in dictionary.items(): self._values[key] = KeyValue( k=DataBlob().SetValue(key, raise_on_error=raise_on_error), - v=DataBlob().SetValue(value, raise_on_error=raise_on_error)) + v=DataBlob().SetValue(value, raise_on_error=raise_on_error), + ) self.dat = self._values.values() # pytype: disable=annotation-type-mismatch return self @@ -283,7 +304,8 @@ def SetItem(self, key, value, raise_on_error=True): cast(rdf_structs.RepeatedFieldHelper, self.dat).dirty = True self._values[key] = KeyValue( k=DataBlob().SetValue(key, raise_on_error=raise_on_error), - v=DataBlob().SetValue(value, raise_on_error=raise_on_error)) + v=DataBlob().SetValue(value, raise_on_error=raise_on_error), + ) def __setitem__(self, key, value): # TODO(user):pytype: assigning "dirty" here is a hack. The assumption @@ -294,7 +316,8 @@ def __setitem__(self, key, value): raise TypeError("self.dat has an unexpected type %s" % self.dat.__class__) cast(rdf_structs.RepeatedFieldHelper, self.dat).dirty = True self._values[key] = KeyValue( - k=DataBlob().SetValue(key), v=DataBlob().SetValue(value)) + k=DataBlob().SetValue(key), v=DataBlob().SetValue(value) + ) def __iter__(self): for x in self._values.values(): @@ -429,6 +452,7 @@ class RDFValueArray(rdf_structs.RDFProtoStruct): protobuf with a repeated field (This can be now done dynamically, which is the main reason we used this in the past). 
""" + protobuf = jobs_pb2.BlobArray allow_custom_class_name = True rdf_deps = [ @@ -450,8 +474,9 @@ def __init__(self, initializer=None): except TypeError: if initializer is not None: raise rdfvalue.InitializeError( - "%s can not be initialized from %s" % - (self.__class__.__name__, type(initializer))) + "%s can not be initialized from %s" + % (self.__class__.__name__, type(initializer)) + ) def Append(self, value=None, **kwarg): """Add another member to the array. @@ -468,16 +493,20 @@ def Append(self, value=None, **kwarg): ValueError: If the value to add is not allowed. """ if self.rdf_type is not None: - if (isinstance(value, rdfvalue.RDFValue) and - value.__class__ != self.rdf_type): + if ( + isinstance(value, rdfvalue.RDFValue) + and value.__class__ != self.rdf_type + ): raise ValueError("Can only accept %s" % self.rdf_type) try: # Try to coerce the value. value = self.rdf_type(value, **kwarg) # pylint: disable=not-callable - except (TypeError, ValueError): - raise ValueError("Unable to initialize %s from type %s" % - (self.__class__.__name__, type(value))) + except (TypeError, ValueError) as e: + raise ValueError( + "Unable to initialize %s from type %s" + % (self.__class__.__name__, type(value)) + ) from e self.content.Append(DataBlob().SetValue(value)) diff --git a/grr/core/grr_response_core/lib/rdfvalues/protodict_test.py b/grr/core/grr_response_core/lib/rdfvalues/protodict_test.py index 14d424f142..79d983efdd 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/protodict_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/protodict_test.py @@ -9,7 +9,6 @@ an __iter__) method, but are serializable as an RDFProto. """ - from collections import abc from typing import Text @@ -74,7 +73,7 @@ def testSerialization(self): test_dict = dict( key1=1, # Integer. key2="foo", # String. - key3=u"\u4f60\u597d", # Unicode. + key3="\u4f60\u597d", # Unicode. key5=rdfvalue.RDFDatetime.FromHumanReadable("2012/12/11"), # RDFValue. key6=None, # Support None Encoding. key7=rdf_structs.EnumNamedValue(5, name="Test"), # Enums. @@ -114,7 +113,8 @@ def testNestedDictsMultipleTypes(self): key2=rdf_protodict.Dict({"A": 1}), key3=[1, 2, 3, [1, 2, [3]]], key4=[[], None, ["abc"]], - key5=set([1, 2, 3])) + key5=set([1, 2, 3]), + ) sample = rdf_protodict.Dict(**test_dict) self.CheckTestDict(test_dict, sample) @@ -124,7 +124,7 @@ def testNestedDictsMultipleTypes(self): def testNestedDictsOpaqueTypes(self): - class UnSerializable(object): + class UnSerializable: pass test_dict = dict( @@ -133,7 +133,8 @@ class UnSerializable(object): key3=[1, UnSerializable(), 3, [1, 2, [3]]], key4=[[], None, ["abc"]], key5=UnSerializable(), - key6=["a", UnSerializable(), "b"]) + key6=["a", UnSerializable(), "b"], + ) self.assertRaises(TypeError, rdf_protodict.Dict, **test_dict) @@ -425,7 +426,7 @@ def testArray(self): None, # None. rdfvalue.RDFDatetime.Now(), # An RDFValue instance. [1, 2], # A nested list. - u"升级程序", # Unicode. + "升级程序", # Unicode. 
] sample = rdf_protodict.RDFValueArray(test_list) @@ -457,8 +458,9 @@ def testPop(self): self.assertEqual(sample.Pop(), "world") -class EmbeddedRDFValueTest(rdf_test_base.RDFProtoTestMixin, - test_lib.GRRBaseTest): +class EmbeddedRDFValueTest( + rdf_test_base.RDFProtoTestMixin, test_lib.GRRBaseTest +): rdfvalue_class = rdf_protodict.EmbeddedRDFValue def GenerateSample(self, number=0): diff --git a/grr/core/grr_response_core/lib/rdfvalues/read_low_level.py b/grr/core/grr_response_core/lib/rdfvalues/read_low_level.py index 3b62bf5b37..e2c5e5f02a 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/read_low_level.py +++ b/grr/core/grr_response_core/lib/rdfvalues/read_low_level.py @@ -9,6 +9,7 @@ class ReadLowLevelArgs(rdf_structs.RDFProtoStruct): """Arguments for ReadLowLevel flow.""" + protobuf = read_low_level_pb2.ReadLowLevelArgs rdf_deps = [ rdfvalue.ByteSize, @@ -30,18 +31,22 @@ def Validate(self): raise ValueError(f"Negative length ({self.length})") if self.length > self.MAX_RAW_DATA_BYTES: - raise ValueError(f"Cannot read more than {self.MAX_RAW_DATA_BYTES} bytes " - f"({self.length} bytes requested") + raise ValueError( + f"Cannot read more than {self.MAX_RAW_DATA_BYTES} bytes " + f"({self.length} bytes requested" + ) class ReadLowLevelFlowResult(rdf_structs.RDFProtoStruct): """Result returned by ReadLowLevel.""" + protobuf = read_low_level_pb2.ReadLowLevelFlowResult rdf_deps = [] class ReadLowLevelRequest(rdf_structs.RDFProtoStruct): """Request for ReadLowLevel action.""" + protobuf = read_low_level_pb2.ReadLowLevelRequest rdf_deps = [ rdfvalue.ByteSize, @@ -50,6 +55,7 @@ class ReadLowLevelRequest(rdf_structs.RDFProtoStruct): class ReadLowLevelResult(rdf_structs.RDFProtoStruct): """Result for ReadLowLevel action.""" + protobuf = read_low_level_pb2.ReadLowLevelResult rdf_deps = [ rdf_client.BufferReference, diff --git a/grr/core/grr_response_core/lib/rdfvalues/standard.py b/grr/core/grr_response_core/lib/rdfvalues/standard.py index f23f4ee8ae..8022ca498e 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/standard.py +++ b/grr/core/grr_response_core/lib/rdfvalues/standard.py @@ -1,10 +1,7 @@ #!/usr/bin/env python """Standard RDFValues.""" - import re -from typing import Text - from urllib import parse as urlparse from grr_response_core.lib import config_lib @@ -18,8 +15,10 @@ class RegularExpression(rdfvalue.RDFString): """A semantic regular expression.""" - context_help_url = ("investigating-with-grr/flows/" - "literal-and-regex-matching.html#regex-matches") + context_help_url = ( + "investigating-with-grr/flows/" + "literal-and-regex-matching.html#regex-matches" + ) def __init__(self, initializer=None): super().__init__(initializer=initializer) @@ -52,8 +51,10 @@ def FindIter(self, text): class LiteralExpression(rdfvalue.RDFBytes): """A RDFBytes literal for use in GrepSpec.""" - context_help_url = ("investigating-with-grr/flows/" - "literal-and-regex-matching.html#literal-matches") + context_help_url = ( + "investigating-with-grr/flows/" + "literal-and-regex-matching.html#literal-matches" + ) class EmailAddress(rdfvalue.RDFString): @@ -81,8 +82,10 @@ def __init__(self, initializer=None): domain = config_lib._CONFIG["Logging.domain"] # pylint: enable=protected-access if self._value and domain and self._match.group(1) != domain: - raise ValueError("Email address '%s' does not belong to the configured " - "domain '%s'" % (self._match.group(1), domain)) + raise ValueError( + "Email address '%s' does not belong to the configured domain '%s'" + % (self._match.group(1), domain) + ) 
class AuthenticodeSignedData(rdf_structs.RDFProtoStruct): @@ -99,6 +102,7 @@ class PersistenceFile(rdf_structs.RDFProtoStruct): class URI(rdf_structs.RDFProtoStruct): """Represets a URI with its individual components separated.""" + protobuf = sysinfo_pb2.URI def __init__(self, initializer=None, **kwargs): @@ -125,13 +129,13 @@ def FromSerializedBytes(cls, value: bytes): return cls(urlparse.urlparse(value.decode("utf-8"))) @classmethod - def FromHumanReadable(cls, value: Text): - precondition.AssertType(value, Text) + def FromHumanReadable(cls, value: str): + precondition.AssertType(value, str) return cls(urlparse.urlparse(value)) def SerializeToBytes(self) -> bytes: return self.SerializeToHumanReadable().encode("utf-8") - def SerializeToHumanReadable(self) -> Text: + def SerializeToHumanReadable(self) -> str: parts = (self.transport, self.host, self.path, self.query, self.fragment) return urlparse.urlunsplit(parts) # pytype: disable=bad-return-type diff --git a/grr/core/grr_response_core/lib/rdfvalues/standard_test.py b/grr/core/grr_response_core/lib/rdfvalues/standard_test.py index 3d3f280b1e..b320f3a650 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/standard_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/standard_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Test standard RDFValues.""" - from absl import app from grr_response_core.lib.rdfvalues import standard as rdf_standard @@ -23,7 +22,8 @@ def testURI(self): host="google.com", path="/index", query="q=hi", - fragment="anchor1") + fragment="anchor1", + ) self.assertEqual(sample.transport, "http") self.assertEqual(sample.host, "google.com") self.assertEqual(sample.path, "/index") diff --git a/grr/core/grr_response_core/lib/rdfvalues/stats.py b/grr/core/grr_response_core/lib/rdfvalues/stats.py index 795c49373b..29d6609c46 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/stats.py +++ b/grr/core/grr_response_core/lib/rdfvalues/stats.py @@ -2,11 +2,7 @@ """RDFValue instances related to the statistics collection.""" import bisect -import math -import threading - -from grr_response_core.lib import utils from grr_response_core.lib.rdfvalues import client_stats as rdf_client_stats from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import analysis_pb2 @@ -20,8 +16,9 @@ class Distribution(rdf_structs.RDFProtoStruct): def __init__(self, initializer=None, bins=None): if initializer and bins: - raise ValueError("Either 'initializer' or 'bins' arguments can " - "be specified.") + raise ValueError( + "Either 'initializer' or 'bins' arguments can be specified." 
+ ) super().__init__(initializer=initializer) @@ -81,58 +78,25 @@ class StatsHistogramBin(rdf_structs.RDFProtoStruct): class StatsHistogram(rdf_structs.RDFProtoStruct): """Histogram with a user-provided set of bins.""" + protobuf = jobs_pb2.StatsHistogram rdf_deps = [ StatsHistogramBin, ] - @classmethod - def FromBins(cls, bins): - res = cls() - for b in bins: - res.bins.Append(StatsHistogramBin(range_max_value=b)) - return res - - def RegisterValue(self, value): - """Puts a given value into an appropriate bin.""" - if self.bins: - for b in self.bins: - if b.range_max_value > value: - b.num += 1 - return - - self.bins[-1].num += 1 - class RunningStats(rdf_structs.RDFProtoStruct): """Class for collecting running stats: mean, stddev and histogram data.""" + protobuf = jobs_pb2.RunningStats rdf_deps = [ StatsHistogram, ] - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._sum_sq = 0 - - def RegisterValue(self, value): - self.num += 1 - self.sum += value - self._sum_sq += value**2 - self.stddev = math.sqrt(self._sum_sq / self.num - self.mean**2) - - self.histogram.RegisterValue(value) - - @property - def mean(self): - if self.num == 0: - return 0 - else: - return self.sum / self.num - class ClientResourcesStats(rdf_structs.RDFProtoStruct): """RDF value representing clients' resources usage statistics for hunts.""" + protobuf = jobs_pb2.ClientResourcesStats rdf_deps = [ rdf_client_stats.ClientResources, @@ -140,65 +104,65 @@ class ClientResourcesStats(rdf_structs.RDFProtoStruct): ] CPU_STATS_BINS = [ - 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8, 9, 10, - 15, 20 + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1, + 1.5, + 2, + 2.5, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 15, + 20, ] NETWORK_STATS_BINS = [ - 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, - 131072, 262144, 524288, 1048576, 2097152 + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, + 65536, + 131072, + 262144, + 524288, + 1048576, + 2097152, ] NUM_WORST_PERFORMERS = 10 - def __init__(self, initializer=None, **kwargs): - super().__init__(initializer=initializer, **kwargs) - - self.user_cpu_stats.histogram = StatsHistogram.FromBins(self.CPU_STATS_BINS) - self.system_cpu_stats.histogram = StatsHistogram.FromBins( - self.CPU_STATS_BINS) - self.network_bytes_sent_stats.histogram = StatsHistogram.FromBins( - self.NETWORK_STATS_BINS) - - self.lock = threading.RLock() - - def __getstate__(self): - # We can't pickle the lock. 
- res = self.__dict__.copy() - del res["lock"] - return res - - def __setstate__(self, state): - self.__dict__ = state - self.lock = threading.RLock() - - @utils.Synchronized - def RegisterResources(self, client_resources): - """Update stats with info about resources consumed by a single client.""" - self.user_cpu_stats.RegisterValue(client_resources.cpu_usage.user_cpu_time) - self.system_cpu_stats.RegisterValue( - client_resources.cpu_usage.system_cpu_time) - self.network_bytes_sent_stats.RegisterValue( - client_resources.network_bytes_sent) - - self.worst_performers.Append(client_resources) - new_worst_performers = sorted( - self.worst_performers, - key=lambda s: s.cpu_usage.user_cpu_time + s.cpu_usage.system_cpu_time, - reverse=True)[:self.NUM_WORST_PERFORMERS] - self.worst_performers = new_worst_performers - class Sample(rdf_structs.RDFProtoStruct): """A Graph sample is a single data point.""" + protobuf = analysis_pb2.Sample class SampleFloat(rdf_structs.RDFProtoStruct): """A Graph float data point.""" + protobuf = analysis_pb2.SampleFloat class Graph(rdf_structs.RDFProtoStruct): """A Graph is a collection of sample points.""" + protobuf = analysis_pb2.Graph rdf_deps = [ Sample, @@ -223,6 +187,7 @@ def __iter__(self): class ClientGraphSeries(rdf_structs.RDFProtoStruct): """A collection of graphs for a single client-report type.""" + protobuf = analysis_pb2.ClientGraphSeries rdf_deps = [ Graph, diff --git a/grr/core/grr_response_core/lib/rdfvalues/stats_test.py b/grr/core/grr_response_core/lib/rdfvalues/stats_test.py deleted file mode 100644 index 67df6fbbe7..0000000000 --- a/grr/core/grr_response_core/lib/rdfvalues/stats_test.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python -"""Test for RunningStats class.""" - - - -from absl import app - -from grr_response_core.lib.rdfvalues import stats as rdf_stats -from grr_response_core.lib.rdfvalues import test_base as rdf_test_base -from grr.test_lib import test_lib - - -class RunningStatsTest(rdf_test_base.RDFValueTestMixin, test_lib.GRRBaseTest): - rdfvalue_class = rdf_stats.RunningStats - - def GenerateSample(self, number=0): - value = rdf_stats.RunningStats() - value.RegisterValue(number) - value.RegisterValue(number * 2) - value.histogram = rdf_stats.StatsHistogram.FromBins([2.0, number, 10.0]) - return value - - def testMeanIsCalculatedCorrectly(self): - stats = rdf_stats.RunningStats() - values = range(100) - - for v in values: - stats.RegisterValue(v) - - # Compare calculated mean with a precalculated value. - self.assertAlmostEqual(stats.mean, 49.5) - - def testStdDevIsCalculatedCorrectly(self): - stats = rdf_stats.RunningStats() - values = range(100) - - for v in values: - stats.RegisterValue(v) - - # Compare calculated standard deviation with a precalculated value. 
- self.assertAlmostEqual(stats.stddev, 28.86607004) - - def testHistogramIsCalculatedCorrectly(self): - stats = rdf_stats.RunningStats() - stats.histogram = rdf_stats.StatsHistogram.FromBins([2.0, 4.0, 10.0]) - - stats.RegisterValue(1.0) - stats.RegisterValue(1.0) - - stats.RegisterValue(2.0) - stats.RegisterValue(2.1) - stats.RegisterValue(2.2) - - stats.RegisterValue(8.0) - stats.RegisterValue(9.0) - stats.RegisterValue(10.0) - stats.RegisterValue(11.0) - - self.assertAlmostEqual(stats.histogram.bins[0].range_max_value, 2.0) - self.assertEqual(stats.histogram.bins[0].num, 2) - - self.assertAlmostEqual(stats.histogram.bins[1].range_max_value, 4.0) - self.assertEqual(stats.histogram.bins[1].num, 3) - - self.assertAlmostEqual(stats.histogram.bins[2].range_max_value, 10.0) - self.assertEqual(stats.histogram.bins[2].num, 4) - - -def main(argv): - test_lib.main(argv) - - -if __name__ == "__main__": - app.run(main) diff --git a/grr/core/grr_response_core/lib/rdfvalues/structs.py b/grr/core/grr_response_core/lib/rdfvalues/structs.py index dd4443c63f..75960bd156 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/structs.py +++ b/grr/core/grr_response_core/lib/rdfvalues/structs.py @@ -7,7 +7,7 @@ import functools import logging import struct -from typing import ByteString, Iterator, Optional, Sequence, Text, Type, TypeVar, cast +from typing import ByteString, Iterator, Optional, Sequence, Type, TypeVar, cast from google.protobuf import any_pb2 from google.protobuf import wrappers_pb2 @@ -49,7 +49,7 @@ def SignedVarintEncode(value): """Encode a signed integer as a signed varint.""" if value < 0: - value += (1 << 64) + value += 1 << 64 return VarintEncode(value) @@ -57,8 +57,8 @@ def SignedVarintEncode(value): def SignedVarintReader(buf, pos=0): """A signed 64 bit decoder for signed varints.""" result, p = VarintReader(buf, pos) - if result > 0x7fffffffffffffff: - result -= (1 << 64) + if result > 0x7FFFFFFFFFFFFFFF: + result -= 1 << 64 return (result, p) @@ -81,8 +81,10 @@ def Tag(entry): elif wire_format is not None: encoded_tag = wire_format[0] else: - raise AssertionError("Each entry is expected to have " - "either a type_descriptor or wire_format.") + raise AssertionError( + "Each entry is expected to have " + "either a type_descriptor or wire_format." + ) return VarintReader(encoded_tag, 0)[0] @@ -95,8 +97,9 @@ def _SerializeEntries(entries): output = [] for python_format, wire_format, type_descriptor in entries: - if wire_format is None or (python_format and - type_descriptor.IsDirty(python_format)): + if wire_format is None or ( + python_format and type_descriptor.IsDirty(python_format) + ): wire_format = type_descriptor.ConvertToWireFormat(python_format) precondition.AssertIterableType(wire_format, bytes) @@ -112,8 +115,9 @@ def ReadIntoObject(buff, index, value_obj, length=0): # Split the buffer into tags and wire_format representations, then collect # these into the raw data cache. - for (encoded_tag, encoded_length, encoded_field) in SplitBuffer( - buff, index=index, length=length): + for encoded_tag, encoded_length, encoded_field in SplitBuffer( + buff, index=index, length=length + ): type_info_obj = value_obj.type_infos_by_encoded_tag.get(encoded_tag) @@ -156,8 +160,9 @@ class ProtoType(type_info.TypeInfoObject): This is an abstract class - do not instantiate directly. """ + # Must be overridden by implementations. - wire_type = None + wire_type: int # We cache the serialized version of the tag here so we just need to do a # string comparison instead of decoding the tag each time. 
@@ -181,12 +186,14 @@ class ProtoType(type_info.TypeInfoObject): # access. set_default_on_access = False - def __init__(self, - field_number=None, - required=False, - labels=None, - set_default_on_access=None, - **kwargs): + def __init__( + self, + field_number=None, + required=False, + labels=None, + set_default_on_access=None, + **kwargs, + ): super().__init__(**kwargs) # TODO: Without this type hint, pytype thinks that field_number # is always None. @@ -249,7 +256,6 @@ def ConvertFromWireFormat(self, value, container=None): Returns: The parameter encoded in the python format representation. - """ raise NotImplementedError @@ -267,7 +273,6 @@ def ConvertToWireFormat(self, value): Returns: The parameter encoded in the wire format representation. - """ raise NotImplementedError @@ -279,9 +284,12 @@ def _FormatDefault(self): return " [default = %s]" % self.GetDefault() def _FormatField(self): - result = " optional %s %s = %s%s" % (self.proto_type_name, self.name, - self.field_number, - self._FormatDefault()) + result = " optional %s %s = %s%s" % ( + self.proto_type_name, + self.name, + self.field_number, + self._FormatDefault(), + ) return result + ";\n" @@ -289,7 +297,7 @@ def Definition(self): """Return a string with the definition of this field.""" return self._FormatDescriptionComment() + self._FormatField() - def Format(self, value) -> Iterator[Text]: + def Format(self, value) -> Iterator[str]: """A Generator for display lines representing value.""" yield str(value) @@ -302,11 +310,14 @@ def GetDefault(self, container=None): _ = container return self.default - def __str__(self) -> Text: + def __str__(self) -> str: # TODO: This fails for ProtoList. return "<Field %s (%s) of %s (field_number %s)>" % ( - self.name, self.__class__.__name__, self.owner.__name__, - self.field_number) + self.name, + self.__class__.__name__, + self.owner.__name__, + self.field_number, + ) def SetOwner(self, owner): self.owner = owner @@ -320,7 +331,7 @@ class ProtoString(ProtoType): # This descriptor describes unicode strings. type = rdfvalue.RDFString - def __init__(self, default=u"", **kwargs): + def __init__(self, default="", **kwargs): # Strings default to "" if not specified. super().__init__(**kwargs) @@ -332,13 +343,13 @@ def GetDefault(self, container=None): _ = container return self.default - def Validate(self, value, **_) -> Text: # pytype: disable=signature-mismatch # overriding-parameter-count-checks + def Validate(self, value, **_) -> str: # pytype: disable=signature-mismatch # overriding-parameter-count-checks """Validates a python format representation of the value.""" if isinstance(value, rdfvalue.RDFString): # TODO(hanuszczak): Use `str` here. - return Text(value) + return str(value) - if isinstance(value, Text): + if isinstance(value, str): return value if isinstance(value, bytes): @@ -346,8 +357,9 @@ def Validate(self, value, **_) -> Text: # pytype: disable=signature-mismatch # raise type_info.TypeValueError( "Not a valid unicode string: {!r} of type {}".format( - value, - type(value).__name__)) + value, type(value).__name__ + ) + ) def ConvertFromWireFormat(self, value, container=None): """Internally strings are utf8 encoded.""" @@ -473,6 +485,7 @@ class ProtoFixed32(ProtoUnsignedInteger): The wire format is a 4 byte string, while the python type is an int. """ + _size = 4 proto_type_name = "sfixed32" @@ -503,6 +516,7 @@ class ProtoFixedU32(ProtoFixed32): The wire format is a 4 byte string, while the python type is an int. 
""" + proto_type_name = "fixed32" def ConvertToWireFormat(self, value): @@ -517,6 +531,7 @@ class ProtoFloat(ProtoFixed32): The wire format is a 4 byte string, while the python type is a float. """ + proto_type_name = "float" def Validate(self, value, **_): @@ -537,6 +552,7 @@ class ProtoDouble(ProtoFixed64): The wire format is a 8 byte string, while the python type is a float. """ + proto_type_name = "double" def Validate(self, value, **_): @@ -561,11 +577,9 @@ class EnumNamedValue(rdfvalue.RDFPrimitive): protobuf_type = "integer" - def __init__(self, - initializer=None, - name=None, - description=None, - labels=None): + def __init__( + self, initializer=None, name=None, description=None, labels=None + ): if initializer is None: initializer = 0 @@ -576,8 +590,8 @@ def __init__(self, if labels is None: labels = () - precondition.AssertType(name, Text) - precondition.AssertOptionalType(description, Text) + precondition.AssertType(name, str) + precondition.AssertOptionalType(description, str) precondition.AssertIterableType(labels, int) super().__init__((int(initializer), name, description, tuple(labels))) @@ -612,12 +626,13 @@ def __lt__(self, other): return (self.id, self.name) < (other.id, other.name) return NotImplemented - def __str__(self) -> Text: + def __str__(self) -> str: return self.name def __repr__(self): return "{}(initializer={!r}, name={!r})".format( - type(self).__name__, self.id, self.name) + type(self).__name__, self.id, self.name + ) def __int__(self): return self.id @@ -644,8 +659,8 @@ def FromSerializedBytes(cls, value: bytes): return cls(0) @classmethod - def FromHumanReadable(cls, string: Text): - precondition.AssertType(string, Text) + def FromHumanReadable(cls, string: str): + precondition.AssertType(string, str) try: num = int(string) @@ -670,13 +685,15 @@ class ProtoEnum(ProtoSignedInteger): type = EnumNamedValue - def __init__(self, - default=None, - enum_name=None, - enum=None, - enum_descriptions=None, - enum_labels=None, - **kwargs): + def __init__( + self, + default=None, + enum_name=None, + enum=None, + enum_descriptions=None, + enum_labels=None, + **kwargs, + ): super().__init__(**kwargs) if enum_name is None: raise type_info.TypeValueError("Enum groups must be given a name.") @@ -694,7 +711,8 @@ def __init__(self, name=enum_name, descriptions=enum_descriptions, enum_labels=enum_labels, - values=(enum or {})) + values=(enum or {}), + ) self.enum = self.enum_container.enum_dict self.reverse_enum = self.enum_container.reverse_enum @@ -705,7 +723,8 @@ def __init__(self, def GetDefault(self, container=None): _ = container return EnumNamedValue( - self.default, name=self.reverse_enum.get(self.default)) + self.default, name=self.reverse_enum.get(self.default) + ) def Validate(self, value, **_): """Check that value is a valid enum.""" @@ -723,8 +742,9 @@ def Validate(self, value, **_): checked_value = int(value) if checked_value is None: raise type_info.TypeValueError( - "Value %s is not a valid enum value for field %s" % - (value, self.name)) + "Value %s is not a valid enum value for field %s" + % (value, self.name) + ) return EnumNamedValue(checked_value, name=self.reverse_enum.get(value)) @@ -774,10 +794,7 @@ def Validate(self, value, **_): return bool(int(super().Validate(value))) def ConvertFromWireFormat(self, value, container=None): - return bool( - int( - super(ProtoBoolean, - self).ConvertFromWireFormat(value, container=container))) + return bool(int(super().ConvertFromWireFormat(value, container=container))) def ConvertToWireFormat(self, value): return 
super().ConvertToWireFormat(bool(value)) @@ -824,7 +841,8 @@ def __init__(self, nested=None, **kwargs): else: raise type_info.TypeValueError( - "Only RDFProtoStructs can be nested, not %s" % nested.__name__) + "Only RDFProtoStructs can be nested, not %s" % nested.__name__ + ) def ConvertFromWireFormat(self, value, container=None): """The wire format is simply a string.""" @@ -852,8 +870,9 @@ class is finally defined. It gives the field descriptor an opportunity to TypeError: If the target class is not of the expected type. """ if not issubclass(target, RDFProtoStruct): - raise TypeError("Field %s expects a protobuf, but target is %s" % - (self, target)) + raise TypeError( + "Field %s expects a protobuf, but target is %s" % (self, target) + ) self.late_bound = False @@ -881,16 +900,18 @@ def GetDefault(self, container=None): def Validate(self, value, **_): # pytype: disable=signature-mismatch # overriding-parameter-count-checks if isinstance(value, str): - raise type_info.TypeValueError("Field %s must be of type %s" % - (self.name, self.type.__name__)) + raise type_info.TypeValueError( + "Field %s must be of type %s" % (self.name, self.type.__name__) + ) # We may coerce it to the correct type. if value.__class__ is not self.type: try: value = self.type(value) - except rdfvalue.InitializeError: - raise type_info.TypeValueError("Field %s must be of type %s" % - (self.name, self.type.__name__)) + except rdfvalue.InitializeError as e: + raise type_info.TypeValueError( + "Field %s must be of type %s" % (self.name, self.type.__name__) + ) from e return value @@ -899,8 +920,11 @@ def Definition(self): return self._FormatDescriptionComment() + self._FormatField() def _FormatField(self): - result = " optional %s %s = %s" % (self.proto_type_name, self.name, - self.field_number) + result = " optional %s %s = %s" % ( + self.proto_type_name, + self.name, + self.field_number, + ) return result + ";\n" def Format(self, value): @@ -946,8 +970,10 @@ def Validate(self, value, container=None): required_type = self._type(container) if required_type and not isinstance(value, required_type): - raise ValueError("Expected value of type %s, but got %s" % - (required_type, value.__class__.__name__)) + raise ValueError( + "Expected value of type %s, but got %s" + % (required_type, value.__class__.__name__) + ) return value @@ -1000,7 +1026,8 @@ def ConvertFromWireFormat(self, value, container=None): # If one of the protobuf library wrapper classes is used, unwrap the value. if result.type_url.startswith("type.googleapis.com/google.protobuf."): wrapper_cls = self.__class__.WRAPPER_BY_TYPE[ - converted_value.protobuf_type] + converted_value.protobuf_type + ] wrapper_value = wrapper_cls() wrapper_value.ParseFromString(result.value) return converted_value.FromWireFormat(wrapper_value.value) @@ -1032,8 +1059,10 @@ def ConvertToWireFormat(self, value): # Is it a protobuf-based value? 
if hasattr(value.__class__, "protobuf"): if value.__class__.protobuf: - type_name = ("type.googleapis.com/grr.%s" % - value.__class__.protobuf.__name__) + type_name = ( + "type.googleapis.com/grr.%s" + % value.__class__.protobuf.__name__ + ) else: type_name = value.__class__.__name__ data = value.SerializeToBytes() @@ -1043,12 +1072,14 @@ def ConvertToWireFormat(self, value): wrapped_data = wrapper_cls() wrapped_data.value = serialization.ToWireFormat(value) - type_name = ("type.googleapis.com/google.protobuf.%s" % - wrapper_cls.__name__) + type_name = ( + "type.googleapis.com/google.protobuf.%s" % wrapper_cls.__name__ + ) data = wrapped_data.SerializeToString() else: - raise ValueError("Can't convert value %s to a protobuf.Any value." % - value) + raise ValueError( + "Can't convert value %s to a protobuf.Any value." % value + ) any_value = AnyValue(type_url=type_name, value=data) output = _SerializeEntries(_GetOrderedEntries(any_value.GetRawData())) @@ -1085,7 +1116,7 @@ def ConvertToWireFormat(self, value): return (self.encoded_tag, VarintEncode(len(data)), data) -class RepeatedFieldHelper(abc.Sequence, object): +class RepeatedFieldHelper(abc.Sequence): """A helper for the RDFProto to handle repeated fields. This helper is intended to only be constructed from the RDFProto class. @@ -1143,7 +1174,8 @@ def IsDirty(self): def Copy(self): return RepeatedFieldHelper( - wrapped_list=self.wrapped_list[:], type_descriptor=self.type_descriptor) + wrapped_list=self.wrapped_list[:], type_descriptor=self.type_descriptor + ) def Append(self, rdf_value=utils.NotAValue, wire_format=None, **kwargs): """Append the value to our internal list.""" @@ -1159,9 +1191,9 @@ def Append(self, rdf_value=utils.NotAValue, wire_format=None, **kwargs): rdf_value = self.type_descriptor.Validate(rdf_value, **kwargs) except (TypeError, ValueError) as e: raise type_info.TypeValueError( - "Assignment value must be %s, but %s can not " - "be coerced. Error: %s" % - (self.type_descriptor.proto_type_name, type(rdf_value), e)) + "Assignment value must be %s, but %s can not be coerced. 
Error: %s" + % (self.type_descriptor.proto_type_name, type(rdf_value), e) + ) self.wrapped_list.append((rdf_value, wire_format)) @@ -1188,12 +1220,14 @@ def __getitem__(self, item): result.append(self.wrapped_list[i]) return self.__class__( - wrapped_list=result, type_descriptor=self.type_descriptor) + wrapped_list=result, type_descriptor=self.type_descriptor + ) python_format, wire_format = self.wrapped_list[item] if python_format is None: python_format = self.type_descriptor.ConvertFromWireFormat( - wire_format, container=self.container) + wire_format, container=self.container + ) self.wrapped_list[item] = (python_format, wire_format) @@ -1206,7 +1240,7 @@ def __ne__(self, other): return not self == other # pylint: disable=g-comparison-negation def __eq__(self, other): - if not isinstance(other, Sequence) or isinstance(other, (ByteString, Text)): + if not isinstance(other, Sequence) or isinstance(other, (ByteString, str)): return NotImplemented if len(self) != len(other): return False @@ -1215,7 +1249,7 @@ def __eq__(self, other): return False return True - def __str__(self) -> Text: + def __str__(self) -> str: result = ["'%s': [" % self.type_descriptor.name] for element in self: for line in self.type_descriptor.Format(element): @@ -1263,8 +1297,10 @@ class ProtoList(ProtoType): def __init__(self, delegate, labels=None, **kwargs): self.delegate = delegate if not isinstance(delegate, ProtoType): - raise AttributeError("Delegate class must derive from ProtoType, not %s" % - delegate.__class__.__name__) + raise AttributeError( + "Delegate class must derive from ProtoType, not %s" + % delegate.__class__.__name__ + ) # If our delegate is late bound we must also be late bound. This means that # the repeated field is not registered in the owner protobuf just @@ -1281,7 +1317,8 @@ def __init__(self, delegate, labels=None, **kwargs): description=delegate.description, field_number=delegate.field_number, friendly_name=delegate.friendly_name, - labels=labels) + labels=labels, + ) def IsDirty(self, value): return value.IsDirty() @@ -1289,7 +1326,8 @@ def IsDirty(self, value): def GetDefault(self, container=None): # By default an empty RepeatedFieldHelper. return RepeatedFieldHelper( - type_descriptor=self.delegate, container=container) + type_descriptor=self.delegate, container=container + ) def Validate(self, value, **_): # pytype: disable=signature-mismatch # overriding-parameter-count-checks """Check that value is a list of the required type.""" @@ -1297,8 +1335,10 @@ def Validate(self, value, **_): # pytype: disable=signature-mismatch # overrid # elements in a RepeatedFieldHelper already are coerced to the delegate # type. In that case we just make a copy. This only works when the value # wraps the same type as us. - if (value.__class__ is RepeatedFieldHelper and - value.type_descriptor is self.delegate): + if ( + value.__class__ is RepeatedFieldHelper + and value.type_descriptor is self.delegate + ): result = value.Copy() # Make sure the base class finds the value valid. 
@@ -1327,7 +1367,8 @@ def ConvertToWireFormat(self, value): """ output = _SerializeEntries( (python_format, wire_format, value.type_descriptor) - for (python_format, wire_format) in value.wrapped_list) + for (python_format, wire_format) in value.wrapped_list + ) return b"", b"", output def Format(self, value): @@ -1340,8 +1381,11 @@ def Format(self, value): yield "]" def _FormatField(self): - result = " repeated %s %s = %s" % (self.delegate.proto_type_name, - self.name, self.field_number) + result = " repeated %s %s = %s" % ( + self.delegate.proto_type_name, + self.name, + self.field_number, + ) return result + ";\n" @@ -1404,7 +1448,8 @@ class ProtoRDFValue(ProtoType): unsigned_integer_32=ProtoUnsignedInteger, integer=ProtoUnsignedInteger, signed_integer=ProtoSignedInteger, - string=ProtoString) + string=ProtoString, + ) def __init__(self, rdf_type=None, default=None, **kwargs): super().__init__(**kwargs) @@ -1452,8 +1497,9 @@ def _GetPrimitiveEncoder(self): """Finds the primitive encoder according to the type's protobuf_type.""" # Decide what should the primitive type be for packing the target rdfvalue # into the protobuf and create a delegate descriptor to control that. - primitive_cls = self._PROTO_DATA_STORE_LOOKUP[serialization.GetProtobufType( - self.type)] + primitive_cls = self._PROTO_DATA_STORE_LOOKUP[ + serialization.GetProtobufType(self.type) + ] self.primitive_desc = primitive_cls(**self._kwargs) # Our wiretype is the same as the delegate's. @@ -1477,8 +1523,9 @@ def IsDirty(self, python_format): return python_format.dirty def Definition(self): - return ("\n // Semantic Type: %s" % - self.type.__name__) + self.primitive_desc.Definition() + return ( + "\n // Semantic Type: %s" % self.type.__name__ + ) + self.primitive_desc.Definition() def Validate(self, value, **_): # pytype: disable=signature-mismatch # overriding-parameter-count-checks # Try to coerce into the correct type: @@ -1495,7 +1542,8 @@ def ConvertFromWireFormat(self, value, container=None): # rdfvalue. We use the delegate primitive descriptor to perform the # conversion. 
value = self.primitive_desc.ConvertFromWireFormat( - value, container=container) + value, container=container + ) result = self.type(value) @@ -1503,7 +1551,8 @@ def ConvertFromWireFormat(self, value, container=None): def ConvertToWireFormat(self, value): return self.primitive_desc.ConvertToWireFormat( - value.SerializeToWireFormat()) + value.SerializeToWireFormat() + ) def Copy(self, field_number=None): """Returns descriptor copy, optionally changing field number.""" @@ -1514,11 +1563,15 @@ def Copy(self, field_number=None): return ProtoRDFValue( rdf_type=self.original_proto_type_name, default=getattr(self, "default", None), - **new_args) + **new_args, + ) def _FormatField(self): - result = " optional %s %s = %s" % (self.proto_type_name, self.name, - self.field_number) + result = " optional %s %s = %s" % ( + self.proto_type_name, + self.name, + self.field_number, + ) return result + ";\n" def Format(self, value): @@ -1528,7 +1581,11 @@ def Format(self, value): def __str__(self): return "<Field %s (Sem Type: %s) of %s: field_number: %s>" % ( - self.name, self.proto_type_name, self.owner.__name__, self.field_number) + self.name, + self.proto_type_name, + self.owner.__name__, + self.field_number, + ) class RDFStructMetaclass(rdfvalue.RDFValueMetaclass): @@ -1537,7 +1594,7 @@ class RDFStructMetaclass(rdfvalue.RDFValueMetaclass): _HAS_DYNAMIC_ATTRIBUTES = True # help out pytype def __init__(untyped_cls, name, bases, env_dict): # pylint: disable=no-self-argument - super(RDFStructMetaclass, untyped_cls).__init__(name, bases, env_dict) + super().__init__(name, bases, env_dict) # TODO(user):pytype: find a more elegant solution (if possible). # cast() doesn't accept forward references and argument annotations @@ -1630,11 +1687,13 @@ def __init__(self, initializer=None, **kwargs): if not hasattr(self.__class__, arg): if arg in self.late_bound_type_infos: raise AttributeError( - "Field %s refers to an as yet undefined Semantic Type." % - self.late_bound_type_infos[arg]) + "Field %s refers to an as yet undefined Semantic Type." + % self.late_bound_type_infos[arg] + ) - raise AttributeError("Proto %s has no field %s" % - (self.__class__.__name__, arg)) + raise AttributeError( + "Proto %s has no field %s" % (self.__class__.__name__, arg) + ) # Call setattr to allow the class to define @property pseudo fields which # can also be initialized. @@ -1647,8 +1706,10 @@ def __init__(self, initializer=None, **kwargs): self.CopyConstructor(initializer) else: - raise ValueError("%s can not be initialized from %s" % - (self.__class__.__name__, type(initializer))) + raise ValueError( + "%s can not be initialized from %s" + % (self.__class__.__name__, type(initializer)) + ) def CopyConstructor(self, other): """Efficiently copy from other into this object. @@ -1750,8 +1811,11 @@ def FromSerializedBytes(cls, value: bytes): try: ReadIntoObject(value, 0, instance) except ValueError: - logging.error("Error in ReadIntoObject. %d bytes, extract: %r", - len(value), value[:1000]) + logging.error( + "Error in ReadIntoObject. 
%d bytes, extract: %r", + len(value), + value[:1000], + ) raise instance.dirty = True @@ -1799,11 +1863,13 @@ def Format(self): """Format a message in a human readable way.""" yield "message %s {" % self.__class__.__name__ - for k, (python_format, wire_format, - type_descriptor) in sorted(self.GetRawData().items()): + for k, (python_format, wire_format, type_descriptor) in sorted( + self.GetRawData().items() + ): if python_format is None: python_format = type_descriptor.ConvertFromWireFormat( - wire_format, container=self) + wire_format, container=self + ) # Skip printing of unknown fields. if isinstance(k, str): @@ -1814,12 +1880,12 @@ def Format(self): yield "}" - def __str__(self) -> Text: + def __str__(self) -> str: return "\n".join(self.Format()) def __dir__(self): """Add the virtualized fields to the console's tab completion.""" - return (dir(super()) + [x.name for x in self.type_infos]) # pylint: disable=not-an-iterable + return dir(super()) + [x.name for x in self.type_infos] # pylint: disable=not-an-iterable def _Set(self, value, type_descriptor): """Validate the value and set the attribute with it.""" @@ -1839,8 +1905,9 @@ def _Set(self, value, type_descriptor): # Make sure to invalidate our parent's cache if needed. self.dirty = True - if (self._prev_hash is not None and - prev_value != self.Get(attr, allow_set_default=False)): + if self._prev_hash is not None and prev_value != self.Get( + attr, allow_set_default=False + ): try: hash(self) # Recompute hash to raise if hash changed due to mutation. except AssertionError as ex: @@ -1850,7 +1917,9 @@ def _Set(self, value, type_descriptor): "sets or keys of dicts is discouraged. If used anyway, mutating is " "prohibited, because it causes the hash to change. Be aware that " "accessing unset fields can trigger a mutation.".format( - type(self).__name__, attr, value, prev_value)) from ex + type(self).__name__, attr, value, prev_value + ) + ) from ex return value def Set(self, attr, value): @@ -1862,8 +1931,9 @@ def _GetTypeDescriptor(self, attr): type_descriptor = self.type_infos.get(attr) if type_descriptor is None: - raise AttributeError("'%s' object has no attribute '%s'" % - (self.__class__.__name__, attr)) + raise AttributeError( + "'%s' object has no attribute '%s'" % (self.__class__.__name__, attr) + ) return type_descriptor @@ -1898,7 +1968,8 @@ def Get(self, attr, allow_set_default=True): # Decode on demand and cache for next time. 
if python_format is None: python_format = type_descriptor.ConvertFromWireFormat( - wire_format, container=self) + wire_format, container=self + ) self._data[attr] = (python_format, wire_format, type_descriptor) @@ -1917,21 +1988,20 @@ def ClearFieldsWithLabel(self, label, exceptions=None): def AddDescriptor(cls, field_desc): if not isinstance(field_desc, ProtoType): raise type_info.TypeValueError( - "%s field '%s' should be of type ProtoType" % - (cls.__name__, field_desc.name)) + "%s field '%s' should be of type ProtoType" + % (cls.__name__, field_desc.name) + ) cls.type_infos_by_field_number[field_desc.field_number] = field_desc cls.type_infos.Append(field_desc) -class EnumContainer(object): +class EnumContainer: """A data class to hold enum objects.""" - def __init__(self, - name=None, - descriptions=None, - enum_labels=None, - values=None): + def __init__( + self, name=None, descriptions=None, enum_labels=None, values=None + ): descriptions = descriptions or {} enum_labels = enum_labels or {} values = values or {} @@ -1945,7 +2015,8 @@ def __init__(self, v, name=k, description=descriptions.get(k, None), - labels=enum_labels.get(k, None)) + labels=enum_labels.get(k, None), + ) self.enum_dict[k] = v self.reverse_enum[v] = k setattr(self, k, v) @@ -1966,6 +2037,7 @@ class RDFProtoStruct(RDFStruct): This implementation is faster than the standard protobuf library. """ + # TODO(user): if a semantic proto defines a field with the same name as # these class variables under some circumstances the proto default value will # be set incorrectly. Figure out a way to make this safe. @@ -2039,9 +2111,11 @@ def FromDict(self, dictionary): for dynamic_field in dynamic_fields: nested_value = dynamic_field.GetDefault(container=self) if nested_value is None: - raise RuntimeError("Can't initialize dynamic field %s, probably some " - "necessary fields weren't supplied." % - dynamic_field.name) + raise RuntimeError( + "Can't initialize dynamic field %s, probably some " + "necessary fields weren't supplied." + % dynamic_field.name + ) nested_value.FromDict(dictionary[dynamic_field.name]) self.Set(dynamic_field.name, nested_value) @@ -2055,7 +2129,8 @@ def _ToPrimitive(self, value, stringify_leaf_fields): # then protodict has already been loaded. # TODO(user): remove this hack elif "Dict" in rdfvalue.RDFValue.classes and isinstance( - value, rdfvalue.RDFValue.classes["Dict"]): + value, rdfvalue.RDFValue.classes["Dict"] + ): primitive_dict = {} # TODO(user):pytype: get rid of a dependency loop described above and # do a proper type check. @@ -2112,6 +2187,8 @@ def Validate(self): @classmethod def FromTextFormat(cls, text): """Parse this object from a text representation.""" + if cls.protobuf is None: + raise ValueError("protobuf must be set on cls.") tmp = cls.protobuf() # pylint: disable=not-callable text_format.Merge(text, tmp) @@ -2122,8 +2199,9 @@ def AddDescriptor(cls, field_desc): """Register this descriptor with the Proto Struct.""" if not isinstance(field_desc, ProtoType): raise type_info.TypeValueError( - "%s field '%s' should be of type ProtoType" % - (cls.__name__, field_desc.name)) + "%s field '%s' should be of type ProtoType" + % (cls.__name__, field_desc.name) + ) # Ensure the field descriptor knows the class that owns it. 
field_desc.SetOwner(cls) @@ -2138,8 +2216,9 @@ def AddDescriptor(cls, field_desc): # Ensure this field number is unique: if field_desc.field_number in cls.type_infos_by_field_number: raise type_info.TypeValueError( - "Field number %s for field %s is not unique in %s" % - (field_desc.field_number, field_desc.name, cls.__name__)) + "Field number %s for field %s is not unique in %s" + % (field_desc.field_number, field_desc.name, cls.__name__) + ) # We store an index of the type info by tag values to speed up parsing. cls.type_infos_by_field_number[field_desc.field_number] = field_desc @@ -2153,30 +2232,39 @@ def AddDescriptor(cls, field_desc): # This lambda is a class method so pylint: disable=protected-access # This is much faster than __setattr__/__getattr__ setattr( - cls, field_desc.name, - property(lambda self: self.Get(field_desc.name), - lambda self, x: self._Set(x, field_desc), None, - field_desc.description)) + cls, + field_desc.name, + property( + lambda self: self.Get(field_desc.name), + lambda self, x: self._Set(x, field_desc), + None, + field_desc.description, + ), + ) def UnionCast(self): union_field = getattr(self, self.union_field) cast_field_name = str(union_field).lower() set_fields = set( - type_descriptor.name for type_descriptor, _ in self.ListSetFields()) + type_descriptor.name for type_descriptor, _ in self.ListSetFields() + ) union_cases = [ case.lower() for case in self.type_infos[self.union_field].enum_container.enum_dict ] - mismatched_union_cases = ( - set_fields.intersection(union_cases).difference([cast_field_name])) + mismatched_union_cases = set_fields.intersection(union_cases).difference( + [cast_field_name] + ) if mismatched_union_cases: - raise ValueError("Inconsistent union proto data. Expected only %r " - "to be set, %r are also set." % - (cast_field_name, list(mismatched_union_cases))) + raise ValueError( + "Inconsistent union proto data. Expected only %r " + "to be set, %r are also set." 
+ % (cast_field_name, list(mismatched_union_cases)) + ) try: return getattr(self, cast_field_name) @@ -2186,6 +2274,7 @@ def UnionCast(self): class SemanticDescriptor(RDFProtoStruct): """A semantic protobuf describing the .proto extension.""" + protobuf = semantic_pb2.SemanticDescriptor @@ -2194,6 +2283,7 @@ class SemanticDescriptor(RDFProtoStruct): class AnyValue(RDFProtoStruct): """Protobuf with arbitrary serialized proto and its type.""" + protobuf = any_pb2.Any allow_custom_class_name = True @@ -2256,4 +2346,6 @@ def Unpack(self, cls: Type[_V]) -> _V: def TypeURL(cls: Type[_V]) -> str: + if cls.protobuf is None: + raise ValueError("protobuf must be set on cls.") return f"type.googleapis.com/{cls.protobuf.DESCRIPTOR.full_name}" diff --git a/grr/core/grr_response_core/lib/rdfvalues/structs_test.py b/grr/core/grr_response_core/lib/rdfvalues/structs_test.py index ed41256160..853c809ef1 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/structs_test.py +++ b/grr/core/grr_response_core/lib/rdfvalues/structs_test.py @@ -1,10 +1,8 @@ #!/usr/bin/env python """Test RDFStruct implementations.""" - import base64 import random -from typing import Text from absl import app from absl.testing import absltest @@ -37,7 +35,8 @@ class TestStructWithManyFields(rdf_structs.RDFProtoStruct): field_number=i + 1, default="string", description="A string value", - ) for i in range(100) + ) + for i in range(100) ]) @@ -50,16 +49,18 @@ class TestStruct(rdf_structs.RDFProtoStruct): field_number=1, default="string", description="A string value", - labels=[rdf_structs.SemanticDescriptor.Labels.HIDDEN]), + labels=[rdf_structs.SemanticDescriptor.Labels.HIDDEN], + ), rdf_structs.ProtoUnsignedInteger( - name="int", field_number=2, default=5, - description="An integer value"), + name="int", field_number=2, default=5, description="An integer value" + ), rdf_structs.ProtoList( rdf_structs.ProtoString( name="repeated", field_number=3, - description="A repeated string value")), - + description="A repeated string value", + ) + ), # We can serialize an arbitrary RDFValue. This will be serialized into a # binary string and parsed on demand. rdf_structs.ProtoRDFValue( @@ -67,36 +68,38 @@ class TestStruct(rdf_structs.RDFProtoStruct): field_number=6, default=rdfvalue.RDFURN("www.google.com"), rdf_type="RDFURN", - description="An arbitrary RDFValue field."), + description="An arbitrary RDFValue field.", + ), rdf_structs.ProtoEnum( name="type", field_number=7, enum_name="Type", - enum={ - "FIRST": 1, - "SECOND": 2, - "THIRD": 3 - }, + enum={"FIRST": 1, "SECOND": 2, "THIRD": 3}, default=3, - description="An enum field"), + description="An enum field", + ), rdf_structs.ProtoFloat( name="float", field_number=8, description="A float number", - default=1.1), + default=1.1, + ), ) # In order to define a recursive structure we must add it manually after the # class definition. 
TestStruct.AddDescriptor( - rdf_structs.ProtoEmbedded(name="nested", field_number=4, - nested=TestStruct),) + rdf_structs.ProtoEmbedded(name="nested", field_number=4, nested=TestStruct), +) TestStruct.AddDescriptor( rdf_structs.ProtoList( rdf_structs.ProtoEmbedded( - name="repeat_nested", field_number=5, nested=TestStruct)),) + name="repeat_nested", field_number=5, nested=TestStruct + ) + ), +) class VersionedTestStructV1(rdf_structs.RDFProtoStruct): @@ -107,7 +110,8 @@ class VersionedTestStructV1(rdf_structs.RDFProtoStruct): name="bool1", field_number=1, default=False, - )) + ) + ) class VersionedTestStructV2(rdf_structs.RDFProtoStruct): @@ -118,11 +122,13 @@ class VersionedTestStructV2(rdf_structs.RDFProtoStruct): name="bool1", field_number=1, default=False, - ), rdf_structs.ProtoBoolean( + ), + rdf_structs.ProtoBoolean( name="bool2", field_number=2, default=False, - )) + ), + ) class TestStructWithBool(rdf_structs.RDFProtoStruct): @@ -133,7 +139,8 @@ class TestStructWithBool(rdf_structs.RDFProtoStruct): name="foo", field_number=1, default=False, - )) + ) + ) class TestStructWithEnum(rdf_structs.RDFProtoStruct): @@ -144,18 +151,18 @@ class TestStructWithEnum(rdf_structs.RDFProtoStruct): name="foo", field_number=1, enum_name="Foo", - enum={ - "ZERO": 0, - "ONE": 1, - "TWO": 2 - }, - default=0)) + enum={"ZERO": 0, "ONE": 1, "TWO": 2}, + default=0, + ) + ) class PartialTest1(rdf_structs.RDFProtoStruct): """This is a protobuf with fewer fields than TestStruct.""" + type_description = type_info.TypeDescriptorSet( - rdf_structs.ProtoUnsignedInteger(name="int", field_number=2),) + rdf_structs.ProtoUnsignedInteger(name="int", field_number=2), + ) class DynamicTypeTest(rdf_structs.RDFProtoStruct): @@ -167,15 +174,19 @@ class DynamicTypeTest(rdf_structs.RDFProtoStruct): field_number=1, # By default return the TestStruct proto. default="TestStruct", - description="A string value"), + description="A string value", + ), rdf_structs.ProtoDynamicEmbedded( name="dynamic", # The callback here returns the type specified by the type member. dynamic_cb=lambda x: rdf_structs.RDFProtoStruct.classes.get(x.type), field_number=2, - description="A dynamic value based on another field."), + description="A dynamic value based on another field.", + ), rdf_structs.ProtoEmbedded( - name="nested", field_number=3, nested=rdf_client.User)) + name="nested", field_number=3, nested=rdf_client.User + ), + ) class DynamicAnyValueTypeTest(rdf_structs.RDFProtoStruct): @@ -183,13 +194,15 @@ class DynamicAnyValueTypeTest(rdf_structs.RDFProtoStruct): type_description = type_info.TypeDescriptorSet( rdf_structs.ProtoString( - name="type", field_number=1, description="A string value"), + name="type", field_number=1, description="A string value" + ), rdf_structs.ProtoDynamicAnyValueEmbedded( name="dynamic", # The callback here returns the type specified by the type member. dynamic_cb=lambda x: rdf_structs.RDFProtoStruct.classes.get(x.type), field_number=2, - description="A dynamic value based on another field."), + description="A dynamic value based on another field.", + ), ) @@ -198,27 +211,32 @@ class AnyValueWithoutTypeFunctionTest(rdf_structs.RDFProtoStruct): type_description = type_info.TypeDescriptorSet( rdf_structs.ProtoDynamicAnyValueEmbedded( - name="dynamic", field_number=1, description="A dynamic value."),) + name="dynamic", field_number=1, description="A dynamic value." 
+ ), + ) class LateBindingTest(rdf_structs.RDFProtoStruct): type_description = type_info.TypeDescriptorSet( # A nested protobuf referring to an undefined type. rdf_structs.ProtoEmbedded( - name="nested", field_number=1, nested="UndefinedYet"), + name="nested", field_number=1, nested="UndefinedYet" + ), rdf_structs.ProtoRDFValue( name="rdfvalue", field_number=6, rdf_type="UndefinedRDFValue", - description="An undefined RDFValue field."), - + description="An undefined RDFValue field.", + ), # A repeated late bound field. rdf_structs.ProtoList( rdf_structs.ProtoRDFValue( name="repeated", field_number=7, rdf_type="UndefinedRDFValue2", - description="An undefined RDFValue field.")), + description="An undefined RDFValue field.", + ) + ), ) @@ -230,28 +248,28 @@ class UnionTest(rdf_structs.RDFProtoStruct): name="struct_flavor", field_number=1, enum_name="Type", - enum={ - "FIRST": 1, - "SECOND": 2, - "THIRD": 3 - }, + enum={"FIRST": 1, "SECOND": 2, "THIRD": 3}, default=3, - description="An union enum field"), + description="An union enum field", + ), rdf_structs.ProtoFloat( name="first", field_number=2, description="A float number", - default=1.1), + default=1.1, + ), rdf_structs.ProtoString( name="second", field_number=3, default="string", - description="A string value"), + description="A string value", + ), rdf_structs.ProtoUnsignedInteger( name="third", field_number=4, default=5, - description="An integer value"), + description="An integer value", + ), ) @@ -279,7 +297,8 @@ def GenerateSample(self, number=1): int=number, foobar="foo%s" % number, urn="www.example.com", - float=2.3 + number) + float=2.3 + number, + ) def testDynamicType(self): test_pb = DynamicTypeTest() @@ -317,12 +336,13 @@ def testAnyValueWithoutTypeCallback(self): rdfvalue.RDFString("test"), rdfvalue.RDFInteger(1234), rdfvalue.RDFBytes(b"abc"), - rdf_flows.GrrStatus(status="WORKER_STUCK", error_message="stuck") + rdf_flows.GrrStatus(status="WORKER_STUCK", error_message="stuck"), ]: test_pb.dynamic = value_to_assign serialized = test_pb.SerializeToBytes() deserialized = AnyValueWithoutTypeFunctionTest.FromSerializedBytes( - serialized) + serialized + ) self.assertEqual(deserialized, test_pb) self.assertEqual(type(deserialized.dynamic), type(value_to_assign)) @@ -338,7 +358,8 @@ def testDynamicAnyValueType(self): # Test serialization/deserialization. serialized = test_pb.SerializeToBytes() self.assertEqual( - DynamicAnyValueTypeTest.FromSerializedBytes(serialized), test_pb) + DynamicAnyValueTypeTest.FromSerializedBytes(serialized), test_pb + ) # Test proto definition. self.assertEqual( @@ -346,7 +367,8 @@ def testDynamicAnyValueType(self): "message DynamicAnyValueTypeTest {\n\n " "// A string value\n optional string type = 1;\n\n " "// A dynamic value based on another field.\n " - "optional google.protobuf.Any dynamic = 2;\n}\n") + "optional google.protobuf.Any dynamic = 2;\n}\n", + ) def testUninitializedDynamicValueIsNonePerDefault(self): response = rdf_flow_objects.FlowResponse() # Do not set payload. @@ -371,17 +393,23 @@ def testStructDefinition(self): """Ensure that errors in struct definitions are raised.""" # A descriptor without a field number should raise. self.assertRaises( - type_info.TypeValueError, rdf_structs.ProtoEmbedded, name="name") + type_info.TypeValueError, rdf_structs.ProtoEmbedded, name="name" + ) # Adding a duplicate field number should raise. 
self.assertRaises( - type_info.TypeValueError, TestStruct.AddDescriptor, - rdf_structs.ProtoUnsignedInteger(name="int", field_number=2)) + type_info.TypeValueError, + TestStruct.AddDescriptor, + rdf_structs.ProtoUnsignedInteger(name="int", field_number=2), + ) # Adding a descriptor which is not a Proto* descriptor is not allowed for # Struct fields: - self.assertRaises(type_info.TypeValueError, TestStruct.AddDescriptor, - type_info.String(name="int")) + self.assertRaises( + type_info.TypeValueError, + TestStruct.AddDescriptor, + type_info.String(name="int"), + ) def testRepeatedMember(self): tested = TestStruct(int=5) @@ -404,8 +432,9 @@ def testRepeatedMember(self): # Check that slicing works. sliced = new_tested.repeat_nested[3:5] self.assertEqual(sliced.__class__, new_tested.repeat_nested.__class__) - self.assertEqual(sliced.type_descriptor, - new_tested.repeat_nested.type_descriptor) + self.assertEqual( + sliced.type_descriptor, new_tested.repeat_nested.type_descriptor + ) self.assertLen(sliced, 2) self.assertEqual(sliced[0].foobar, "Nest3") @@ -477,8 +506,9 @@ def testRDFStruct(self): tested.nested = TestStruct(foobar="nested_foo") # Not OK to use the wrong semantic type. - self.assertRaises(ValueError, setattr, tested, "nested", - PartialTest1(int=1)) + self.assertRaises( + ValueError, setattr, tested, "nested", PartialTest1(int=1) + ) # Not OK to assign a serialized string - even if it is for the right type - # since there is no type checking. @@ -497,8 +527,9 @@ def testRDFStruct(self): tested.repeated = ["string"] self.assertEqual(tested.repeated, ["string"]) - self.assertRaises(type_info.TypeValueError, setattr, tested, "repeated", - [1, 2, 3]) + self.assertRaises( + type_info.TypeValueError, setattr, tested, "repeated", [1, 2, 3] + ) # Coercing on assignment. This field is an RDFURN: tested.urn = "www.example.com" @@ -526,7 +557,7 @@ def testRDFStruct(self): tested.type = "2" self.assertEqual(tested.type, 2) # unicode strings should be treated the same way. - tested.type = u"2" + tested.type = "2" self.assertEqual(tested.type, 2) # Out of range values are permitted and preserved through serialization. tested.type = 4 @@ -556,8 +587,9 @@ def testCacheInvalidation(self): def testLateBinding(self): # The LateBindingTest protobuf is not fully defined. - self.assertRaises(KeyError, LateBindingTest.type_infos.__getitem__, - "nested") + self.assertRaises( + KeyError, LateBindingTest.type_infos.__getitem__, "nested" + ) self.assertIn("UndefinedYet", rdfvalue._LATE_BINDING_STORE) @@ -573,7 +605,9 @@ def testLateBinding(self): class UndefinedYet(rdf_structs.RDFProtoStruct): type_description = type_info.TypeDescriptorSet( rdf_structs.ProtoString( - name="foobar", field_number=1, description="A string value"),) + name="foobar", field_number=1, description="A string value" + ), + ) # The field is now resolved. self.assertNotIn("UndefinedYet", rdfvalue._LATE_BINDING_STORE) @@ -587,8 +621,9 @@ class UndefinedYet(rdf_structs.RDFProtoStruct): def testRDFValueLateBinding(self): # The LateBindingTest protobuf is not fully defined. - self.assertRaises(KeyError, LateBindingTest.type_infos.__getitem__, - "rdfvalue") + self.assertRaises( + KeyError, LateBindingTest.type_infos.__getitem__, "rdfvalue" + ) self.assertIn("UndefinedRDFValue", rdfvalue._LATE_BINDING_STORE) @@ -615,8 +650,9 @@ class UndefinedRDFValue(rdfvalue.RDFString): def testRepeatedRDFValueLateBinding(self): # The LateBindingTest protobuf is not fully defined. 
- self.assertRaises(KeyError, LateBindingTest.type_infos.__getitem__, - "repeated") + self.assertRaises( + KeyError, LateBindingTest.type_infos.__getitem__, "repeated" + ) self.assertIn("UndefinedRDFValue2", rdfvalue._LATE_BINDING_STORE) @@ -729,17 +765,14 @@ def testConversionToPrimitiveDictNoSerialization(self): int=2, repeated=["value0", "value1"], nested=TestStruct(int=567), - repeat_nested=[TestStruct(int=568)]) + repeat_nested=[TestStruct(int=568)], + ) expected_dict = { "foobar": "foo", "int": 2, "repeated": ["value0", "value1"], - "nested": { - "int": 567 - }, - "repeat_nested": [{ - "int": 568 - }] + "nested": {"int": 567}, + "repeat_nested": [{"int": 568}], } self.assertEqual(test_struct.ToPrimitiveDict(), expected_dict) @@ -749,20 +782,18 @@ def testConversionToPrimitiveDictWithSerialization(self): int=2, repeated=["value0", "value1"], nested=TestStruct(int=567), - repeat_nested=[TestStruct(int=568)]) + repeat_nested=[TestStruct(int=568)], + ) expected_dict = { "foobar": "foo", "int": "2", # Serialized "repeated": ["value0", "value1"], - "nested": { - "int": "567" # Serialized - }, - "repeat_nested": [{ - "int": "568" # Serialized - }] + "nested": {"int": "567"}, # Serialized + "repeat_nested": [{"int": "568"}], # Serialized } self.assertEqual( - test_struct.ToPrimitiveDict(stringify_leaf_fields=True), expected_dict) + test_struct.ToPrimitiveDict(stringify_leaf_fields=True), expected_dict + ) def testToPrimitiveDictStringifyBytes(self): data = b"\xff\xfe\xff" @@ -771,12 +802,13 @@ def testToPrimitiveDictStringifyBytes(self): class FooStruct(rdf_structs.RDFProtoStruct): type_description = type_info.TypeDescriptorSet( - rdf_structs.ProtoBinary(name="data", field_number=1)) + rdf_structs.ProtoBinary(name="data", field_number=1) + ) foo_struct = FooStruct(data=data) foo_dict = foo_struct.ToPrimitiveDict(stringify_leaf_fields=True) - self.assertIsInstance(foo_dict["data"], Text) + self.assertIsInstance(foo_dict["data"], str) self.assertEqual(foo_dict["data"], encoded_data) def testToPrimitiveDictStringifyRDFBytes(self): @@ -787,12 +819,14 @@ class BarStruct(rdf_structs.RDFProtoStruct): type_description = type_info.TypeDescriptorSet( rdf_structs.ProtoRDFValue( - name="data", field_number=1, rdf_type=rdfvalue.RDFBytes)) + name="data", field_number=1, rdf_type=rdfvalue.RDFBytes + ) + ) bar_struct = BarStruct(data=data) bar_dict = bar_struct.ToPrimitiveDict(stringify_leaf_fields=True) - self.assertIsInstance(bar_dict["data"], Text) + self.assertIsInstance(bar_dict["data"], str) self.assertEqual(bar_dict["data"], encoded_data) def _GenerateSampleWithManyFields(self): @@ -804,7 +838,8 @@ def _GenerateSampleWithManyFields(self): sample = TestStructWithManyFields(**fields) parsed = TestStructWithManyFields.FromSerializedBytes( - sample.SerializeToBytes()) + sample.SerializeToBytes() + ) return sample, parsed @@ -822,10 +857,12 @@ def testSymmetricEqualityForDisjointFields(self): # the default value. 
self.assertNotEqual( rdf_client_stats.IOSample(write_count=6), - rdf_client_stats.IOSample(read_bytes=0)) + rdf_client_stats.IOSample(read_bytes=0), + ) self.assertNotEqual( rdf_client_stats.IOSample(read_bytes=0), - rdf_client_stats.IOSample(write_count=6)) + rdf_client_stats.IOSample(write_count=6), + ) def testDefaultRepeatedSetterDoesNotChangeEquality(self): sample = TestStruct() @@ -846,7 +883,8 @@ def testCanDeserializeOldVersion(self): def testCanCompareEqualityAcrossVersions(self): a = VersionedTestStructV1.FromSerializedBytes( - VersionedTestStructV2(bool1=True, bool2=True).SerializeToBytes()) + VersionedTestStructV2(bool1=True, bool2=True).SerializeToBytes() + ) b = VersionedTestStructV1(bool1=True) @@ -897,12 +935,18 @@ def testNestedProtobufAssignment(self): pathspec = rdf_paths.PathSpec(path=test_path, pathtype=1) # Should raise - incompatible RDFType. - self.assertRaises(ValueError, setattr, container, "pathspec", - rdfvalue.RDFString("hello")) + self.assertRaises( + ValueError, setattr, container, "pathspec", rdfvalue.RDFString("hello") + ) # Should raise - incompatible RDFProto type. - self.assertRaises(ValueError, setattr, container, "pathspec", - rdf_client_fs.StatEntry(st_size=5)) + self.assertRaises( + ValueError, + setattr, + container, + "pathspec", + rdf_client_fs.StatEntry(st_size=5), + ) # Assign directly. container.device = pathspec @@ -922,7 +966,9 @@ def testSimpleTypeAssignment(self): name="test", field_number=45, default=rdfvalue.RDFInteger(0), - rdf_type=rdfvalue.RDFInteger)) + rdf_type=rdfvalue.RDFInteger, + ) + ) self.assertIsInstance(sample.test, rdfvalue.RDFInteger) @@ -952,13 +998,19 @@ def testSimpleTypeAssignment(self): self.assertEqual(sample.test, 10) # Assign an RDFValue which can not be coerced. - self.assertRaises(type_info.TypeValueError, setattr, sample, "test", - rdfvalue.RDFString("hello")) + self.assertRaises( + type_info.TypeValueError, + setattr, + sample, + "test", + rdfvalue.RDFString("hello"), + ) def testComplexConstruction(self): """Test that we can construct RDFProtos with nested fields.""" pathspec = rdf_paths.PathSpec( - path="/foobar", pathtype=rdf_paths.PathSpec.PathType.TSK) + path="/foobar", pathtype=rdf_paths.PathSpec.PathType.TSK + ) sample = rdf_client_fs.StatEntry(pathspec=pathspec, st_size=5) self.assertEqual(sample.pathspec.path, "/foobar") @@ -968,10 +1020,11 @@ def testComplexConstruction(self): def testUnicodeSupport(self): pathspec = rdf_paths.PathSpec( - path="/foobar", pathtype=rdf_paths.PathSpec.PathType.TSK) - pathspec.path = u"Grüezi" + path="/foobar", pathtype=rdf_paths.PathSpec.PathType.TSK + ) + pathspec.path = "Grüezi" - self.assertEqual(pathspec.path, u"Grüezi") + self.assertEqual(pathspec.path, "Grüezi") def testRepeatedFields(self): """Test handling of protobuf repeated fields.""" @@ -988,7 +1041,8 @@ def testRepeatedFields(self): # Add an rdfvalue. 
sample.addresses.Append( - rdf_client_network.NetworkAddress(human_readable_address="1.2.3.4")) + rdf_client_network.NetworkAddress(human_readable_address="1.2.3.4") + ) self.assertLen(sample.addresses, 2) self.assertEqual(sample.addresses[1].human_readable_address, "1.2.3.4") @@ -1050,14 +1104,13 @@ class EnumContainerTest(absltest.TestCase): def setUp(self): super().setUp() self.enum_container = rdf_structs.EnumContainer( - name="foo", values={ - "bar": 1, - "baz": 2 - }) + name="foo", values={"bar": 1, "baz": 2} + ) def testFromString(self): self.assertEqual( - self.enum_container.FromString("bar"), self.enum_container.bar) + self.enum_container.FromString("bar"), self.enum_container.bar + ) def testFromString_invalidValue(self): with self.assertRaises(ValueError): @@ -1194,8 +1247,22 @@ def testCorrectlyReadsEncodedIntegers(self): def testReaderCanReadWhatEncoderHasEncoded(self): for v in [ - 0, 1, 2, 10, 20, 63, 64, 65, 127, 128, 129, 255, 256, 257, 1 << 63 - 1, - 1 << 64 - 1 + 0, + 1, + 2, + 10, + 20, + 63, + 64, + 65, + 127, + 128, + 129, + 255, + 256, + 257, + 1 << 63 - 1, + 1 << 64 - 1, ]: with self.subTest(value=v): buf = rdf_structs.VarintEncode(v) @@ -1204,24 +1271,32 @@ def testReaderCanReadWhatEncoderHasEncoded(self): self.assertLen(buf, p) def testDecodingZeroBufferRaises(self): - with self.assertRaisesRegex(ValueError, - "Too many bytes when decoding varint"): + with self.assertRaisesRegex( + ValueError, "Too many bytes when decoding varint" + ): rdf_structs.VarintReader(b"", 0) def testDecodingValueLargerThan64BitReturnsTruncatedValue(self): self.assertEqual( - rdf_structs.VarintReader(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02", - 0), (0, 10)) + rdf_structs.VarintReader( + b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02", 0 + ), + (0, 10), + ) def testRaisesWhenBufferIsTooLong(self): - with self.assertRaisesRegex(ValueError, - "Too many bytes when decoding varint"): + with self.assertRaisesRegex( + ValueError, "Too many bytes when decoding varint" + ): + rdf_structs.VarintReader( + b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01\x00\x00", 0 + ) + with self.assertRaisesRegex( + ValueError, "Too many bytes when decoding varint" + ): rdf_structs.VarintReader( - b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01\x00\x00", 0) - with self.assertRaisesRegex(ValueError, - "Too many bytes when decoding varint"): - rdf_structs.VarintReader(b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF", - 0) + b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF", 0 + ) def testNonCanonicalZeroDoesNotRaise(self): self.assertEqual(rdf_structs.VarintReader(b"\x80\x80\x80\x00", 0), (0, 4)) @@ -1252,7 +1327,8 @@ def testCorrectlyReadsEncodedSignedIntegers(self): for src_val, b in self._ENCODED_SIGNED_VARINTS: with self.subTest(buffer=b): self.assertEqual( - rdf_structs.SignedVarintReader(b, 0), (src_val, len(b))) + rdf_structs.SignedVarintReader(b, 0), (src_val, len(b)) + ) class SplitBufferTest(absltest.TestCase): @@ -1287,37 +1363,43 @@ def testRaisesOnBrokenVarintTag(self): def testCorrectlyProcessesVarintTag(self): self.assertEqual( - rdf_structs.SplitBuffer("\x00\x01", 0, 2), [(b"\x00", b"", b"\x01")]) + rdf_structs.SplitBuffer("\x00\x01", 0, 2), [(b"\x00", b"", b"\x01")] + ) def testRaisesOnOversizedFixed64Tag(self): for l in range(8): with self.subTest(l=l): - with self.assertRaisesRegex(ValueError, - "Fixed64 tag exceeds available buffer"): + with self.assertRaisesRegex( + ValueError, "Fixed64 tag exceeds available buffer" + ): buf = b"\x01" + b"\x00" * l rdf_structs.SplitBuffer(buf, 0, 
len(buf)) def testCorrectlyProcessesFixed64Tag(self): self.assertEqual( rdf_structs.SplitBuffer("\x01\x00\x01\x02\x03\x04\x05\0x6\0x7", 0, 9), - [(b"\x01", b"", b"\x00\x01\x02\x03\x04\x05\x00x")]) + [(b"\x01", b"", b"\x00\x01\x02\x03\x04\x05\x00x")], + ) def testRaisesOnOversizedFixed32Tag(self): for l in range(4): with self.subTest(l=l): - with self.assertRaisesRegex(ValueError, - "Fixed32 tag exceeds available buffer"): + with self.assertRaisesRegex( + ValueError, "Fixed32 tag exceeds available buffer" + ): buf = b"\x05" + b"\x00" * l rdf_structs.SplitBuffer(buf, 0, len(buf)) def testCorrectlyProcessesFixed32Tag(self): self.assertEqual( rdf_structs.SplitBuffer("\x05\x00\x01\x02\x03", 0, 5), - [(b"\x05", b"", b"\x00\x01\x02\x03")]) + [(b"\x05", b"", b"\x00\x01\x02\x03")], + ) def testRaisesOnBrokenLengthDelimitedTag(self): - with self.assertRaisesRegex(ValueError, - "Broken length_delimited tag encountered"): + with self.assertRaisesRegex( + ValueError, "Broken length_delimited tag encountered" + ): rdf_structs.SplitBuffer(b"\x02\xff", 0, 2) def testRaisesOnLengthDelimitedTagExceedingMaxInt(self): @@ -1326,14 +1408,16 @@ def testRaisesOnLengthDelimitedTagExceedingMaxInt(self): rdf_structs.SplitBuffer(buf, 0, len(buf)) def testRaisesOnOversizedLengthDelimitedTag(self): - with self.assertRaisesRegex(ValueError, - "Length tag exceeds available buffer"): + with self.assertRaisesRegex( + ValueError, "Length tag exceeds available buffer" + ): rdf_structs.SplitBuffer(b"\x02\x02", 0, 2) def testCorrectlyProcessesLengthDelimitedTag(self): self.assertEqual( rdf_structs.SplitBuffer("\x02\x02\x00\x01", 0, 4), - [(b"\x02", b"\x02", b"\x00\x01")]) + [(b"\x02", b"\x02", b"\x00\x01")], + ) def testRaisesOnUnknownTag(self): with self.assertRaisesRegex(ValueError, "Unexpected Tag"): diff --git a/grr/core/grr_response_core/lib/rdfvalues/test_base.py b/grr/core/grr_response_core/lib/rdfvalues/test_base.py index 5a4a7e513b..0aabfd0c80 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/test_base.py +++ b/grr/core/grr_response_core/lib/rdfvalues/test_base.py @@ -1,15 +1,13 @@ #!/usr/bin/env python """The base classes for RDFValue tests.""" -from typing import Text - from grr_response_core.lib import serialization from grr_response_core.lib.rdfvalues import structs as rdf_structs # pylint:mode=test -class RDFValueTestMixin(object): +class RDFValueTestMixin: """The base class for testing RDFValue implementations.""" # This should be overridden by the RDFValue class we want to test. @@ -94,15 +92,16 @@ def testSerialization(self, sample=None): if protobuf_type == "bytes": self.assertIsInstance(serialized, bytes) elif protobuf_type == "string": - self.assertIsInstance(serialized, Text) + self.assertIsInstance(serialized, str) elif protobuf_type in ["unsigned_integer", "integer"]: self.assertIsInstance(serialized, int) else: self.fail("%s has no valid protobuf_type" % self.rdfvalue_class) # Ensure we can parse it again. 
- rdfvalue_object = serialization.FromWireFormat(self.rdfvalue_class, - serialized) + rdfvalue_object = serialization.FromWireFormat( + self.rdfvalue_class, serialized + ) self.CheckRDFValue(rdfvalue_object, sample) diff --git a/grr/core/grr_response_core/lib/rdfvalues/timeline.py b/grr/core/grr_response_core/lib/rdfvalues/timeline.py index 3fa6a6e6bf..cb783034c3 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/timeline.py +++ b/grr/core/grr_response_core/lib/rdfvalues/timeline.py @@ -2,7 +2,6 @@ """A module with RDF value wrappers for timeline protobufs.""" import os - from typing import Iterator from grr_response_core.lib.rdfvalues import structs as rdf_structs diff --git a/grr/core/grr_response_core/lib/rdfvalues/webhistory.py b/grr/core/grr_response_core/lib/rdfvalues/webhistory.py index 48fb720b51..d8211ae597 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/webhistory.py +++ b/grr/core/grr_response_core/lib/rdfvalues/webhistory.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """RDFValues describing web history artifacts.""" - from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/rdfvalues/wkt.py b/grr/core/grr_response_core/lib/rdfvalues/wkt.py index a9433a728c..b717a7fcea 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/wkt.py +++ b/grr/core/grr_response_core/lib/rdfvalues/wkt.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with RDF wrappers for Protocol Buffers Well-Known Types.""" + from google.protobuf import timestamp_pb2 from grr_response_core.lib.rdfvalues import structs as rdf_structs diff --git a/grr/core/grr_response_core/lib/rdfvalues/wmi.py b/grr/core/grr_response_core/lib/rdfvalues/wmi.py index 35e620075b..fb1f743f78 100644 --- a/grr/core/grr_response_core/lib/rdfvalues/wmi.py +++ b/grr/core/grr_response_core/lib/rdfvalues/wmi.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """WMI RDF values.""" - from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import sysinfo_pb2 diff --git a/grr/core/grr_response_core/lib/registry.py b/grr/core/grr_response_core/lib/registry.py index ddd8e93cbb..b014b22994 100644 --- a/grr/core/grr_response_core/lib/registry.py +++ b/grr/core/grr_response_core/lib/registry.py @@ -37,8 +37,9 @@ def IsAbstract(cls): # naming them Abstract. 
abstract_attribute = "_%s__abstract" % cls.__name__ - return (cls.__name__.startswith("Abstract") or - hasattr(cls, abstract_attribute)) + return cls.__name__.startswith("Abstract") or hasattr( + cls, abstract_attribute + ) def __init__(cls, name, bases, env_dict): abc.ABCMeta.__init__(cls, name, bases, env_dict) @@ -58,8 +59,10 @@ def __init__(cls, name, bases, env_dict): try: if cls.classes and cls.__name__ in cls.classes: - raise RuntimeError("Duplicate names for registered classes: %s, %s" % - (cls, cls.classes[cls.__name__])) + raise RuntimeError( + "Duplicate names for registered classes: %s, %s" + % (cls, cls.classes[cls.__name__]) + ) cls.classes[cls.__name__] = cls cls.classes_by_name[getattr(cls, "name", None)] = cls @@ -117,26 +120,6 @@ def __init__(cls, name, bases, env_dict): EventRegistry.EVENT_NAME_MAP.setdefault(ev, set()).add(cls) -class AFF4FlowRegistry(MetaclassRegistry): - """A dedicated registry that only contains flows.""" - - FLOW_REGISTRY = {} - - def __init__(cls, name, bases, env_dict): - MetaclassRegistry.__init__(cls, name, bases, env_dict) - - if not cls.IsAbstract(): - cls.FLOW_REGISTRY[name] = cls - - @classmethod - def FlowClassByName(mcs, flow_name): - flow_cls = mcs.FLOW_REGISTRY.get(flow_name) - if flow_cls is None: - raise ValueError("Flow '%s' not known." % flow_name) - - return flow_cls - - class FlowRegistry(MetaclassRegistry): """A dedicated registry that only contains new style flows.""" @@ -183,7 +166,7 @@ class SystemCronJobRegistry(CronJobRegistry): SYSTEM_CRON_REGISTRY = {} def __init__(cls, name, bases, env_dict): - super(SystemCronJobRegistry, cls).__init__(name, bases, env_dict) + super().__init__(name, bases, env_dict) if not cls.IsAbstract(): cls.SYSTEM_CRON_REGISTRY[name] = cls diff --git a/grr/core/grr_response_core/lib/serialization.py b/grr/core/grr_response_core/lib/serialization.py index ba367ec8cf..6914eb698c 100644 --- a/grr/core/grr_response_core/lib/serialization.py +++ b/grr/core/grr_response_core/lib/serialization.py @@ -2,7 +2,6 @@ """(De-)serialization to bytes, wire format, and human readable strings.""" import abc -from typing import Text from grr_response_core.lib import rdfvalue from grr_response_core.lib.util import precondition @@ -18,31 +17,27 @@ def protobuf_type(self): @abc.abstractmethod def FromBytes(self, value: bytes): """Deserializes a value from bytes outputted by ToBytes.""" - pass @abc.abstractmethod def ToBytes(self, value) -> bytes: """Serializes `value` into bytes which can be parsed with FromBytes.""" - pass @abc.abstractmethod def FromWireFormat(self, value): """Deserializes a value from a primitive outputted by ToWireFormat.""" - pass @abc.abstractmethod def ToWireFormat(self, value): """Serializes to a primitive which can be parsed with FromWireFormat.""" - pass @abc.abstractmethod - def FromHumanReadable(self, string: Text): + def FromHumanReadable(self, string: str): """Deserializes a value from a string outputted by str(value).""" - pass class BoolConverter(Converter): """Converter for Python's `bool`.""" + protobuf_type = "unsigned_integer" wrapping_type = bool @@ -59,11 +54,11 @@ def FromWireFormat(self, value: int) -> bool: def ToWireFormat(self, value: bool) -> int: return 1 if value else 0 - def FromHumanReadable(self, string: Text) -> bool: + def FromHumanReadable(self, string: str) -> bool: upper_string = string.upper() - if upper_string == u"TRUE" or string == u"1": + if upper_string == "TRUE" or string == "1": return True - elif upper_string == u"FALSE" or string == u"0": + elif upper_string 
== "FALSE" or string == "0": return False else: raise ValueError("Unparsable boolean string: `%s`" % string) @@ -94,7 +89,7 @@ def ToWireFormat(self, value: rdfvalue.RDFValue): precondition.AssertType(value, self._cls) return value.SerializeToWireFormat() - def FromHumanReadable(self, string: Text) -> rdfvalue.RDFValue: + def FromHumanReadable(self, string: str) -> rdfvalue.RDFValue: if issubclass(self._cls, rdfvalue.RDFPrimitive): return self._cls.FromHumanReadable(string) else: @@ -115,9 +110,9 @@ def GetProtobufType(cls): return _GetFactory(cls).protobuf_type -def FromHumanReadable(cls, string: Text): +def FromHumanReadable(cls, string: str): """Deserializes a value of `cls` from a string outputted by str(value).""" - precondition.AssertType(string, Text) + precondition.AssertType(string, str) return _GetFactory(cls).FromHumanReadable(string) diff --git a/grr/core/grr_response_core/lib/serialization_test.py b/grr/core/grr_response_core/lib/serialization_test.py index f493719d9c..afa3491dd7 100644 --- a/grr/core/grr_response_core/lib/serialization_test.py +++ b/grr/core/grr_response_core/lib/serialization_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for (de)serialization logic.""" - from absl.testing import absltest from grr_response_core.lib import serialization @@ -10,38 +9,42 @@ class BoolConverterTest(absltest.TestCase): def testFromHumanReadableTrue(self): - self.assertIs(serialization.FromHumanReadable(bool, u"true"), True) - self.assertIs(serialization.FromHumanReadable(bool, u"True"), True) - self.assertIs(serialization.FromHumanReadable(bool, u"TRUE"), True) - self.assertIs(serialization.FromHumanReadable(bool, u"1"), True) + self.assertIs(serialization.FromHumanReadable(bool, "true"), True) + self.assertIs(serialization.FromHumanReadable(bool, "True"), True) + self.assertIs(serialization.FromHumanReadable(bool, "TRUE"), True) + self.assertIs(serialization.FromHumanReadable(bool, "1"), True) def testFromHumanReadableFalse(self): - self.assertIs(serialization.FromHumanReadable(bool, u"false"), False) - self.assertIs(serialization.FromHumanReadable(bool, u"False"), False) - self.assertIs(serialization.FromHumanReadable(bool, u"FALSE"), False) - self.assertIs(serialization.FromHumanReadable(bool, u"0"), False) + self.assertIs(serialization.FromHumanReadable(bool, "false"), False) + self.assertIs(serialization.FromHumanReadable(bool, "False"), False) + self.assertIs(serialization.FromHumanReadable(bool, "FALSE"), False) + self.assertIs(serialization.FromHumanReadable(bool, "0"), False) def testFromHumanReadableRaisesOnIncorrectInteger(self): with self.assertRaises(ValueError): - serialization.FromHumanReadable(bool, u"2") + serialization.FromHumanReadable(bool, "2") def testFromHumanReadableRaisesOnWeirdInput(self): with self.assertRaises(ValueError): - serialization.FromHumanReadable(bool, u"yes") + serialization.FromHumanReadable(bool, "yes") def testWireFormat(self): self.assertIs( serialization.FromWireFormat(bool, serialization.ToWireFormat(True)), - True) + True, + ) self.assertIs( serialization.FromWireFormat(bool, serialization.ToWireFormat(False)), - False) + False, + ) def testBytes(self): self.assertIs( - serialization.FromBytes(bool, serialization.ToBytes(True)), True) + serialization.FromBytes(bool, serialization.ToBytes(True)), True + ) self.assertIs( - serialization.FromBytes(bool, serialization.ToBytes(False)), False) + serialization.FromBytes(bool, serialization.ToBytes(False)), False + ) def testHumanReadable(self): 
self.assertIs(serialization.FromHumanReadable(bool, str(True)), True) diff --git a/grr/core/grr_response_core/lib/time_utils.py b/grr/core/grr_response_core/lib/time_utils.py index 1c8dc3e16e..cad9e03245 100644 --- a/grr/core/grr_response_core/lib/time_utils.py +++ b/grr/core/grr_response_core/lib/time_utils.py @@ -4,7 +4,7 @@ from grr_response_core.lib import rdfvalue -class TimeRange(object): +class TimeRange: """An object representing a closed time-range. Attributes: @@ -25,8 +25,9 @@ def __init__(self, start: rdfvalue.RDFDatetime, end: rdfvalue.RDFDatetime): """ if start > end: raise ValueError( - "Invalid time-range: %s > %s." % (start.AsMicrosecondsSinceEpoch(), - end.AsMicrosecondsSinceEpoch())) + "Invalid time-range: %s > %s." + % (start.AsMicrosecondsSinceEpoch(), end.AsMicrosecondsSinceEpoch()) + ) self._start = start self._end = end diff --git a/grr/core/grr_response_core/lib/time_utils_test.py b/grr/core/grr_response_core/lib/time_utils_test.py index e8826709c2..f66dba4d55 100644 --- a/grr/core/grr_response_core/lib/time_utils_test.py +++ b/grr/core/grr_response_core/lib/time_utils_test.py @@ -12,14 +12,17 @@ class TimeUtilsTest(absltest.TestCase): def testInvalidTimeRange(self): - with self.assertRaisesWithLiteralMatch(ValueError, - "Invalid time-range: 2000 > 1000."): + with self.assertRaisesWithLiteralMatch( + ValueError, "Invalid time-range: 2000 > 1000." + ): time_utils.TimeRange( - rdfvalue.RDFDatetime(2000), rdfvalue.RDFDatetime(1000)) + rdfvalue.RDFDatetime(2000), rdfvalue.RDFDatetime(1000) + ) def testIncludesTimeRange(self): time_range = time_utils.TimeRange( - rdfvalue.RDFDatetime(1000), rdfvalue.RDFDatetime(2000)) + rdfvalue.RDFDatetime(1000), rdfvalue.RDFDatetime(2000) + ) self.assertFalse(time_range.Includes(rdfvalue.RDFDatetime(500))) self.assertTrue(time_range.Includes(rdfvalue.RDFDatetime(1000))) self.assertTrue(time_range.Includes(rdfvalue.RDFDatetime(1500))) diff --git a/grr/core/grr_response_core/lib/type_info.py b/grr/core/grr_response_core/lib/type_info.py index 00776c7a10..5507c6b0f5 100644 --- a/grr/core/grr_response_core/lib/type_info.py +++ b/grr/core/grr_response_core/lib/type_info.py @@ -7,9 +7,7 @@ """ import logging -from typing import Iterable -from typing import Optional -from typing import Text +from typing import Iterable, Optional from grr_response_core.lib import rdfvalue from grr_response_core.lib import serialization @@ -37,12 +35,14 @@ class TypeInfoObject(metaclass=MetaclassRegistry): # The delegate type this TypeInfoObject manages. _type = None - def __init__(self, - name="", - default=None, - description="", - friendly_name="", - hidden=False): + def __init__( + self, + name="", + default=None, + description="", + friendly_name="", + hidden=False, + ): """Build a TypeInfo type descriptor. Args: @@ -99,9 +99,11 @@ def ToString(self, value): def Help(self): """Returns a helpful string describing this type info.""" - return "%s\n Description: %s\n Default: %s" % (self.name, - self.description, - self.GetDefault()) + return "%s\n Description: %s\n Default: %s" % ( + self.name, + self.description, + self.GetDefault(), + ) class RDFValueType(TypeInfoObject): @@ -141,9 +143,11 @@ def Validate(self, value): # Try to coerce the type to the correct rdf_class. 
try: return self.rdfclass(value) - except rdfvalue.InitializeError: - raise TypeValueError("Value for arg %s should be an %s" % - (self.name, self.rdfclass.__name__)) + except rdfvalue.InitializeError as e: + raise TypeValueError( + "Value for arg %s should be an %s" + % (self.name, self.rdfclass.__name__) + ) from e return value @@ -211,12 +215,14 @@ def Validate(self, value): r = self.rdfclass() r.FromDict(value) return r - except (AttributeError, TypeError, rdfvalue.InitializeError): + except (AttributeError, TypeError, rdfvalue.InitializeError) as e: # AttributeError is raised if value contains items that don't # belong to the given rdfstruct. # TypeError will be raised if value is not a dict-like object. - raise TypeValueError("Value for arg %s should be an %s" % - (self.name, self.rdfclass.__name__)) + raise TypeValueError( + "Value for arg %s should be an %s" + % (self.name, self.rdfclass.__name__) + ) from e return value @@ -250,10 +256,13 @@ def __iter__(self): def __str__(self): result = "\n ".join( - ["%s: %s" % (x.name, x.description) for x in self.descriptors]) + ["%s: %s" % (x.name, x.description) for x in self.descriptors] + ) return "<TypeDescriptorSet for %s>\n %s\n</TypeDescriptorSet>\n" % ( - self.__class__.__name__, result) + self.__class__.__name__, + result, + ) def __add__(self, other): return self.Add(other) @@ -288,8 +297,7 @@ def Remove(self, *descriptor_names): new_descriptor_map.pop(name, None) new_descriptors = [ - desc for desc in self.descriptors - if desc in new_descriptor_map.values() + desc for desc in self.descriptors if desc in new_descriptor_map.values() ] return TypeDescriptorSet(*new_descriptors) @@ -385,24 +393,24 @@ def ToString(self, value): class String(TypeInfoObject): """A String type.""" - _type = Text + _type = str - def __init__(self, default: Text = "", **kwargs): - precondition.AssertType(default, Text) + def __init__(self, default: str = "", **kwargs): + precondition.AssertType(default, str) super().__init__(default=default, **kwargs) - def Validate(self, value: Text) -> Text: - if not isinstance(value, Text): + def Validate(self, value: str) -> str: + if not isinstance(value, str): raise TypeValueError("'{}' is not a valid string".format(value)) return value - def FromString(self, string: Text) -> Text: - precondition.AssertType(string, Text) + def FromString(self, string: str) -> str: + precondition.AssertType(string, str) return string - def ToString(self, value: Text) -> Text: - precondition.AssertType(value, Text) + def ToString(self, value: str) -> str: + precondition.AssertType(value, str) return value @@ -421,11 +429,11 @@ def Validate(self, value: bytes) -> bytes: return value - def FromString(self, string: Text) -> bytes: - precondition.AssertType(string, Text) + def FromString(self, string: str) -> bytes: + precondition.AssertType(string, str) return string.encode("utf-8") - def ToString(self, value: bytes) -> Text: + def ToString(self, value: bytes) -> str: precondition.AssertType(value, bytes) return value.decode("utf-8") @@ -453,6 +461,7 @@ def FromString(self, string): class Float(Integer): """Type info describing a float.""" + _type = float def Validate(self, value): @@ -496,7 +505,7 @@ def __init__(self, choices=None, validator=None, **kwargs): Args: choices: list of available choices validator: validator to use for each of the list *items* the validator for - the top level is a list. + the top level is a list. **kwargs: passed through to parent class. 
""" self.choices = choices diff --git a/grr/core/grr_response_core/lib/type_info_test.py b/grr/core/grr_response_core/lib/type_info_test.py index 658a0bacc9..7de7e1b4c5 100644 --- a/grr/core/grr_response_core/lib/type_info_test.py +++ b/grr/core/grr_response_core/lib/type_info_test.py @@ -27,8 +27,8 @@ def testTypeInfoStringObjects(self): self.assertRaises(type_info.TypeValueError, a.Validate, 1) self.assertRaises(type_info.TypeValueError, a.Validate, None) a.Validate("test") - a.Validate(u"test") - a.Validate(u"/test-Îñ铁网åţî[öñåļ(îžåţîờñ") + a.Validate("test") + a.Validate("/test-Îñ铁网åţî[öñåļ(îžåţîờñ") def testTypeInfoNumberObjects(self): """Test the type info objects behave as expected.""" @@ -45,8 +45,9 @@ def testTypeInfoListObjects(self): self.assertRaises(type_info.TypeValueError, a.Validate, "test") self.assertRaises(type_info.TypeValueError, a.Validate, None) self.assertRaises(type_info.TypeValueError, a.Validate, ["test"]) - self.assertRaises(type_info.TypeValueError, a.Validate, - [rdf_paths.PathSpec()]) + self.assertRaises( + type_info.TypeValueError, a.Validate, [rdf_paths.PathSpec()] + ) a.Validate([1, 2, 3]) def testTypeInfoListConvertsObjectsOnValidation(self): @@ -105,11 +106,13 @@ def testTypeDescriptorSet(self): type_infos = [ type_info.String(name="output", default="analysis/{p}/{u}-{t}"), type_info.String( - description="Profile to use.", name="profile", default=""), + description="Profile to use.", name="profile", default="" + ), type_info.String( description="A comma separated list of plugins.", name="plugins", - default=""), + default="", + ), ] info = type_info.TypeDescriptorSet( @@ -118,11 +121,17 @@ def testTypeDescriptorSet(self): type_infos[2], ) - new_info = type_info.TypeDescriptorSet(type_infos[0],) + new_info = type_info.TypeDescriptorSet( + type_infos[0], + ) - updated_info = new_info + type_info.TypeDescriptorSet(type_infos[1],) + updated_info = new_info + type_info.TypeDescriptorSet( + type_infos[1], + ) - updated_info += type_info.TypeDescriptorSet(type_infos[2],) + updated_info += type_info.TypeDescriptorSet( + type_infos[2], + ) self.assertEqual(info.descriptor_map, updated_info.descriptor_map) self.assertCountEqual(info.descriptors, updated_info.descriptors) diff --git a/grr/core/grr_response_core/lib/util/aead.py b/grr/core/grr_response_core/lib/util/aead.py index 6c632f56a0..6aa4aca654 100644 --- a/grr/core/grr_response_core/lib/util/aead.py +++ b/grr/core/grr_response_core/lib/util/aead.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with utilities for AEAD streams.""" + import io import itertools import os diff --git a/grr/core/grr_response_core/lib/util/body.py b/grr/core/grr_response_core/lib/util/body.py index ee7a14b5c1..3d7d5f3741 100644 --- a/grr/core/grr_response_core/lib/util/body.py +++ b/grr/core/grr_response_core/lib/util/body.py @@ -1,11 +1,10 @@ #!/usr/bin/env python """A module with utilities for working with the Sleuthkit's body format.""" + import enum import io import stat - -from typing import Iterator -from typing import Optional +from typing import Iterator, Optional from grr_response_proto import timeline_pb2 diff --git a/grr/core/grr_response_core/lib/util/body_test.py b/grr/core/grr_response_core/lib/util/body_test.py index 73555143c0..9fe740dae9 100644 --- a/grr/core/grr_response_core/lib/util/body_test.py +++ b/grr/core/grr_response_core/lib/util/body_test.py @@ -183,12 +183,12 @@ def testPathWithNewline(self): def testPathWithQuote(self): entry = timeline_pb2.TimelineEntry() - entry.path = "/foo\"bar".encode("utf-8") 
+ entry.path = '/foo"bar'.encode("utf-8") stream = body.Stream(iter([entry])) content = b"".join(stream).decode("utf-8") - self.assertIn("|/foo\"bar|", content) + self.assertIn('|/foo"bar|', content) def testPathWithBackslash(self): entry = timeline_pb2.TimelineEntry() diff --git a/grr/core/grr_response_core/lib/util/cache.py b/grr/core/grr_response_core/lib/util/cache.py index eeeacb11e9..131b6d076b 100644 --- a/grr/core/grr_response_core/lib/util/cache.py +++ b/grr/core/grr_response_core/lib/util/cache.py @@ -49,8 +49,8 @@ def Foo(id): NOTE 2: all decorated functions' arguments have to be hashable. Args: - min_time_between_calls: An rdfvalue.Duration specifying the minimal - time to pass between 2 consecutive function calls with same arguments. + min_time_between_calls: An rdfvalue.Duration specifying the minimal time to + pass between 2 consecutive function calls with same arguments. Returns: A Python function decorator. diff --git a/grr/core/grr_response_core/lib/util/cache_test.py b/grr/core/grr_response_core/lib/util/cache_test.py index 2eedffea22..64b9163934 100644 --- a/grr/core/grr_response_core/lib/util/cache_test.py +++ b/grr/core/grr_response_core/lib/util/cache_test.py @@ -19,7 +19,8 @@ def setUp(self): def testCallsFunctionEveryTimeWhenMinTimeBetweenCallsZero(self): decorated = cache.WithLimitedCallFrequency(rdfvalue.Duration(0))( - self.mock_fn) + self.mock_fn + ) for _ in range(10): decorated() @@ -60,8 +61,8 @@ def testCacheIsCleanedAfterMinTimeBetweenCallsHasElapsed(self): def testCallsFunctionOnceInGivenTimeRangeWhenMinTimeBetweenCallsNonZero(self): decorated = cache.WithLimitedCallFrequency( - rdfvalue.Duration.From(30, rdfvalue.SECONDS))( - self.mock_fn) + rdfvalue.Duration.From(30, rdfvalue.SECONDS) + )(self.mock_fn) now = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(now): @@ -81,8 +82,8 @@ def testCallsFunctionOnceInGivenTimeRangeWhenMinTimeBetweenCallsNonZero(self): def testCachingIsDonePerArguments(self): decorated = cache.WithLimitedCallFrequency( - rdfvalue.Duration.From(30, rdfvalue.SECONDS))( - self.mock_fn) + rdfvalue.Duration.From(30, rdfvalue.SECONDS) + )(self.mock_fn) now = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(now): @@ -120,8 +121,8 @@ def Fn(): return self.mock_fn() decorated = cache.WithLimitedCallFrequency( - rdfvalue.Duration.From(30, rdfvalue.SECONDS))( - Fn) + rdfvalue.Duration.From(30, rdfvalue.SECONDS) + )(Fn) results = [] @@ -159,8 +160,8 @@ def Fn(x): return x decorated = cache.WithLimitedCallFrequency( - rdfvalue.Duration.From(30, rdfvalue.SECONDS))( - Fn) + rdfvalue.Duration.From(30, rdfvalue.SECONDS) + )(Fn) def T(): decorated(1) @@ -185,8 +186,8 @@ def testPropagatesExceptions(self): mock_fn.__name__ = "foo" # Expected by functools.wraps. decorated = cache.WithLimitedCallFrequency( - rdfvalue.Duration.From(30, rdfvalue.SECONDS))( - mock_fn) + rdfvalue.Duration.From(30, rdfvalue.SECONDS) + )(mock_fn) with self.assertRaises(ValueError): decorated() @@ -196,8 +197,8 @@ def testExceptionIsNotCached(self): mock_fn.__name__ = "foo" # Expected by functools.wraps. 
decorated = cache.WithLimitedCallFrequency( - rdfvalue.Duration.From(30, rdfvalue.SECONDS))( - mock_fn) + rdfvalue.Duration.From(30, rdfvalue.SECONDS) + )(mock_fn) for _ in range(10): with self.assertRaises(ValueError): diff --git a/grr/core/grr_response_core/lib/util/chunked.py b/grr/core/grr_response_core/lib/util/chunked.py index 21d1c5f47e..68876f9660 100644 --- a/grr/core/grr_response_core/lib/util/chunked.py +++ b/grr/core/grr_response_core/lib/util/chunked.py @@ -2,9 +2,7 @@ """A module with utilities for a very simple chunked serialization format.""" import struct -from typing import IO -from typing import Iterator -from typing import Optional +from typing import IO, Iterator, Optional class Error(Exception): @@ -34,8 +32,9 @@ def Write(buf: IO[bytes], chunk: bytes) -> None: buf.write(chunk) -def Read(buf: IO[bytes], - max_chunk_size: Optional[int] = None) -> Optional[bytes]: +def Read( + buf: IO[bytes], max_chunk_size: Optional[int] = None +) -> Optional[bytes]: """Reads a single chunk from the input buffer. Args: @@ -68,26 +67,29 @@ def Read(buf: IO[bytes], # informative exception message. if max_chunk_size is not None and count > max_chunk_size: - raise ChunkSizeTooBigError(f"Malformed input: chunk size {count} " - f"is bigger than {max_chunk_size}") + raise ChunkSizeTooBigError( + f"Malformed input: chunk size {count} is bigger than {max_chunk_size}" + ) chunk = buf.read(count) if len(chunk) != count: raise ChunkTruncatedError( f"Malformed input: chunk size {count} " - f"is bigger than actual number of bytes read {len(chunk)}") + f"is bigger than actual number of bytes read {len(chunk)}" + ) return chunk -def ReadAll(buf: IO[bytes], - max_chunk_size: Optional[int] = None) -> Iterator[bytes]: +def ReadAll( + buf: IO[bytes], max_chunk_size: Optional[int] = None +) -> Iterator[bytes]: """Reads all the chunks from the input buffer (until the end). Args: buf: An input buffer to read the chunks from. max_chunk_size: If set, will raise if chunk's size is larger than a given - value. + value. Yields: Chunks of bytes stored in the buffer. 
diff --git a/grr/core/grr_response_core/lib/util/chunked_test.py b/grr/core/grr_response_core/lib/util/chunked_test.py index 28a6ec95f3..e3f20de828 100644 --- a/grr/core/grr_response_core/lib/util/chunked_test.py +++ b/grr/core/grr_response_core/lib/util/chunked_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import io from absl.testing import absltest diff --git a/grr/core/grr_response_core/lib/util/collection_test.py b/grr/core/grr_response_core/lib/util/collection_test.py index 85746fc21d..896001b381 100644 --- a/grr/core/grr_response_core/lib/util/collection_test.py +++ b/grr/core/grr_response_core/lib/util/collection_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl.testing import absltest from grr_response_core.lib.util import collection @@ -195,46 +194,14 @@ def testSingleKeys(self): def testMultipleKeys(self): in_dicts = {"a": [1, 2], "b": [3, 4], "c": [5, 6]} out_dicts = [ - { - "a": 1, - "b": 3, - "c": 5 - }, - { - "a": 1, - "b": 3, - "c": 6 - }, - { - "a": 1, - "b": 4, - "c": 5 - }, - { - "a": 1, - "b": 4, - "c": 6 - }, - { - "a": 2, - "b": 3, - "c": 5 - }, - { - "a": 2, - "b": 3, - "c": 6 - }, - { - "a": 2, - "b": 4, - "c": 5 - }, - { - "a": 2, - "b": 4, - "c": 6 - }, + {"a": 1, "b": 3, "c": 5}, + {"a": 1, "b": 3, "c": 6}, + {"a": 1, "b": 4, "c": 5}, + {"a": 1, "b": 4, "c": 6}, + {"a": 2, "b": 3, "c": 5}, + {"a": 2, "b": 3, "c": 6}, + {"a": 2, "b": 4, "c": 5}, + {"a": 2, "b": 4, "c": 6}, ] self.assertCountEqual(list(collection.DictProduct(in_dicts)), out_dicts) diff --git a/grr/core/grr_response_core/lib/util/filesystem.py b/grr/core/grr_response_core/lib/util/filesystem.py index 04328350e8..bb176088c9 100644 --- a/grr/core/grr_response_core/lib/util/filesystem.py +++ b/grr/core/grr_response_core/lib/util/filesystem.py @@ -5,16 +5,12 @@ import os import platform import stat - -from typing import Dict -from typing import NamedTuple -from typing import Optional -from typing import Text +from typing import Dict, NamedTuple, Optional from grr_response_core.lib.util import precondition -class Stat(object): +class Stat: """A wrapper around standard `os.[l]stat` function. The standard API for using `stat` results is very clunky and unpythonic. @@ -29,7 +25,7 @@ class Stat(object): """ @classmethod - def FromPath(cls, path: Text, follow_symlink: bool = True) -> "Stat": + def FromPath(cls, path: str, follow_symlink: bool = True) -> "Stat": """Returns stat information about the given OS path, calling os.[l]stat. Args: @@ -62,10 +58,12 @@ def FromPath(cls, path: Text, follow_symlink: bool = True) -> "Stat": return cls(path=path, stat_obj=stat_obj, symlink_target=target) - def __init__(self, - path: Text, - stat_obj: os.stat_result, - symlink_target: Optional[Text] = None) -> None: + def __init__( + self, + path: str, + stat_obj: os.stat_result, + symlink_target: Optional[str] = None, + ) -> None: """Wrap an existing stat result in a `filesystem.Stat` instance. Args: @@ -82,7 +80,7 @@ def __init__(self, def GetRaw(self) -> os.stat_result: return self._stat - def GetPath(self) -> Text: + def GetPath(self) -> str: return self._path def GetLinuxFlags(self) -> int: @@ -125,7 +123,7 @@ def GetChangeTime(self) -> int: def GetDevice(self) -> int: return self._stat.st_dev - def GetSymlinkTarget(self) -> Optional[Text]: + def GetSymlinkTarget(self) -> Optional[str]: return self._symlink_target def IsDirectory(self) -> bool: @@ -166,6 +164,7 @@ def _FetchLinuxFlags(self) -> int: try: # This import is Linux-specific. 
import fcntl # pylint: disable=g-import-not-at-top + buf = array.array("l", [0]) # TODO(user):pytype: incorrect type spec for fcntl.ioctl # pytype: disable=wrong-arg-types @@ -186,7 +185,7 @@ def _FetchOsxFlags(self) -> int: return self._stat.st_flags # pytype: disable=attribute-error -class StatCache(object): +class StatCache: """An utility class for avoiding unnecessary syscalls to `[l]stat`. This class is useful in situations where manual bookkeeping of stat results @@ -195,12 +194,12 @@ class StatCache(object): smart enough to cache symlink results when a file is not a symlink. """ - _Key = NamedTuple("_Key", (("path", Text), ("follow_symlink", bool))) # pylint: disable=invalid-name + _Key = NamedTuple("_Key", (("path", str), ("follow_symlink", bool))) # pylint: disable=invalid-name def __init__(self): self._cache: Dict[StatCache._Key, Stat] = {} - def Get(self, path: Text, follow_symlink: bool = True) -> Stat: + def Get(self, path: str, follow_symlink: bool = True) -> Stat: """Stats given file or returns a cached result if available. Args: diff --git a/grr/core/grr_response_core/lib/util/filesystem_test.py b/grr/core/grr_response_core/lib/util/filesystem_test.py index 91422f7703..f28e5d8a77 100644 --- a/grr/core/grr_response_core/lib/util/filesystem_test.py +++ b/grr/core/grr_response_core/lib/util/filesystem_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import datetime import io import os @@ -82,8 +81,10 @@ def testSocket(self): @unittest.skipIf(platform.system() == "Windows", "requires Unix-like system") def testSymlink(self): - with temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath, \ - temp.AutoTempFilePath() as temp_filepath: + with ( + temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath, + temp.AutoTempFilePath() as temp_filepath, + ): with io.open(temp_filepath, "wb") as fd: fd.write(b"foobar") @@ -135,11 +136,14 @@ def testGetOsxFlags(self): self.assertFalse(stat.GetOsxFlags() & self.UF_IMMUTABLE) self.assertEqual(stat.GetLinuxFlags(), 0) - @unittest.skipIf(platform.system() == "Windows", - "Windows does not support os.symlink().") + @unittest.skipIf( + platform.system() == "Windows", "Windows does not support os.symlink()." + ) def testGetFlagsSymlink(self): - with temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath, \ - temp.AutoTempFilePath() as temp_filepath: + with ( + temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath, + temp.AutoTempFilePath() as temp_filepath, + ): temp_linkpath = os.path.join(temp_dirpath, "foo") os.symlink(temp_filepath, temp_linkpath) @@ -148,8 +152,9 @@ def testGetFlagsSymlink(self): self.assertEqual(stat.GetLinuxFlags(), 0) self.assertEqual(stat.GetOsxFlags(), 0) - @unittest.skipIf(platform.system() == "Windows", - "Windows does not support socket.AF_UNIX.") + @unittest.skipIf( + platform.system() == "Windows", "Windows does not support socket.AF_UNIX." 
+ ) def testGetFlagsSocket(self): with temp.AutoTempDirPath(remove_non_empty=True) as temp_dirpath: temp_socketpath = os.path.join(temp_dirpath, "foo") @@ -163,8 +168,9 @@ def testGetFlagsSocket(self): # pylint: disable=line-too-long # [1]: https://unix.stackexchange.com/questions/367008/why-is-socket-path-length-limited-to-a-hundred-chars # pylint: enable=ling-too-long - if ((platform.system() == "Linux" and len(temp_socketpath) > 108) or - (platform.system() == "Darwin" and len(temp_socketpath) > 104)): + if (platform.system() == "Linux" and len(temp_socketpath) > 108) or ( + platform.system() == "Darwin" and len(temp_socketpath) > 104 + ): message = "Generated path '{}' is too long for a socket path" self.skipTest(message.format(temp_socketpath)) @@ -238,8 +244,9 @@ def testBasicUsage(self): self.assertEqual(other_baz_stat.GetSize(), 9) self.assertFalse(stat_mock.FromPath.called) - @unittest.skipIf(platform.system() == "Windows", - "Windows does not support os.symlink().") + @unittest.skipIf( + platform.system() == "Windows", "Windows does not support os.symlink()." + ) def testFollowSymlink(self): with io.open(self.Path("foo"), "wb") as fd: fd.write(b"123456") diff --git a/grr/core/grr_response_core/lib/util/gzchunked.py b/grr/core/grr_response_core/lib/util/gzchunked.py index f9c41f5eb7..f28edaf0f6 100644 --- a/grr/core/grr_response_core/lib/util/gzchunked.py +++ b/grr/core/grr_response_core/lib/util/gzchunked.py @@ -5,7 +5,6 @@ import io import os import struct - from typing import Iterator from grr_response_core.lib.util import chunked diff --git a/grr/core/grr_response_core/lib/util/gzchunked_test.py b/grr/core/grr_response_core/lib/util/gzchunked_test.py index 91d058cfd6..d37678014f 100644 --- a/grr/core/grr_response_core/lib/util/gzchunked_test.py +++ b/grr/core/grr_response_core/lib/util/gzchunked_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import gzip import io import os diff --git a/grr/core/grr_response_core/lib/util/io.py b/grr/core/grr_response_core/lib/util/io.py index 189713e9f1..159e355a1f 100644 --- a/grr/core/grr_response_core/lib/util/io.py +++ b/grr/core/grr_response_core/lib/util/io.py @@ -1,8 +1,8 @@ #!/usr/bin/env python """A module with utilities for working with I/O.""" + import io -from typing import IO -from typing import Iterator +from typing import IO, Iterator def Chunk(stream: IO[bytes], size: int) -> Iterator[bytes]: diff --git a/grr/core/grr_response_core/lib/util/io_test.py b/grr/core/grr_response_core/lib/util/io_test.py index 0ed85c1273..e10e5057fb 100644 --- a/grr/core/grr_response_core/lib/util/io_test.py +++ b/grr/core/grr_response_core/lib/util/io_test.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import io import itertools + from absl.testing import absltest from grr_response_core.lib.util import io as ioutil diff --git a/grr/core/grr_response_core/lib/util/iterator.py b/grr/core/grr_response_core/lib/util/iterator.py index 820715a998..3bc93fc5de 100644 --- a/grr/core/grr_response_core/lib/util/iterator.py +++ b/grr/core/grr_response_core/lib/util/iterator.py @@ -1,8 +1,7 @@ #!/usr/bin/env python """A module with utilities for working with iterators.""" -from typing import Iterator -from typing import Optional -from typing import TypeVar + +from typing import Iterator, Optional, TypeVar _T = TypeVar("_T") diff --git a/grr/core/grr_response_core/lib/util/precondition.py b/grr/core/grr_response_core/lib/util/precondition.py index 6ee11a35d5..ffb21e68d6 100644 --- a/grr/core/grr_response_core/lib/util/precondition.py +++ 
b/grr/core/grr_response_core/lib/util/precondition.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with assertion functions for checking preconditions.""" + import collections from collections import abc import re @@ -106,6 +107,8 @@ def ValidateClientId(client_id): def ValidateFlowId(flow_id): """Raises, if the given value is not a valid FlowId string.""" _ValidateStringId("flow_id", flow_id) - if (len(flow_id) not in [8, 16] or - re.match(r"^[0-9a-fA-F]*$", flow_id) is None): + if ( + len(flow_id) not in [8, 16] + or re.match(r"^[0-9a-fA-F]*$", flow_id) is None + ): raise ValueError("Flow id has incorrect format: `%s`" % flow_id) diff --git a/grr/core/grr_response_core/lib/util/random.py b/grr/core/grr_response_core/lib/util/random.py index f9e0b3bec4..f4b80e41ce 100644 --- a/grr/core/grr_response_core/lib/util/random.py +++ b/grr/core/grr_response_core/lib/util/random.py @@ -1,8 +1,8 @@ #!/usr/bin/env python """A module with utilities for optimized pseudo-random number generation.""" + import os import struct - import threading from typing import Callable, List @@ -29,7 +29,8 @@ def UInt32() -> int: except IndexError: data = os.urandom(struct.calcsize("=L") * _random_buffer_size) _random_buffer.extend( - struct.unpack("=" + "L" * _random_buffer_size, data)) + struct.unpack("=" + "L" * _random_buffer_size, data) + ) return _random_buffer.pop() diff --git a/grr/core/grr_response_core/lib/util/random_test.py b/grr/core/grr_response_core/lib/util/random_test.py index c0b3cedd5a..50497f8604 100644 --- a/grr/core/grr_response_core/lib/util/random_test.py +++ b/grr/core/grr_response_core/lib/util/random_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import io import os from unittest import mock diff --git a/grr/core/grr_response_core/lib/util/retry.py b/grr/core/grr_response_core/lib/util/retry.py index b505e20eca..f5ebb6a4ae 100644 --- a/grr/core/grr_response_core/lib/util/retry.py +++ b/grr/core/grr_response_core/lib/util/retry.py @@ -1,18 +1,13 @@ #!/usr/bin/env python """A module with utilities for retrying function execution.""" + import dataclasses import datetime import functools import logging import random import time -from typing import Callable -from typing import Generic -from typing import Optional -from typing import Tuple -from typing import Type -from typing import TypeVar -from typing import Union +from typing import Callable, Generic, Optional, Tuple, Type, TypeVar, Union @dataclasses.dataclass @@ -27,6 +22,7 @@ class Opts: jitter: A random jitter to add to delay between retries. sleep: A sleep function used for delaying retries. """ + attempts: int = 1 init_delay: datetime.timedelta = datetime.timedelta(0) diff --git a/grr/core/grr_response_core/lib/util/sqlite.py b/grr/core/grr_response_core/lib/util/sqlite.py index 097cc2f120..79f32ee794 100644 --- a/grr/core/grr_response_core/lib/util/sqlite.py +++ b/grr/core/grr_response_core/lib/util/sqlite.py @@ -3,17 +3,13 @@ import contextlib import io -from typing import IO -from typing import Iterator -from typing import Text -from typing import Tuple - import sqlite3 +from typing import IO, Iterator, Tuple from grr_response_core.lib.util import temp -class ConnectionContext(object): +class ConnectionContext: """A wrapper class around an SQLite connection object. 
This class wraps a low-level SQLite connection that is error-prone and does @@ -28,7 +24,7 @@ def __init__(self, conn: sqlite3.Connection) -> None: """ self._conn = conn - def Query(self, query: Text) -> Iterator[Tuple]: # pylint: disable=g-bare-generic + def Query(self, query: str) -> Iterator[Tuple]: # pylint: disable=g-bare-generic """Queries the underlying database. Args: diff --git a/grr/core/grr_response_core/lib/util/sqlite_test.py b/grr/core/grr_response_core/lib/util/sqlite_test.py index d8bd6aaa28..aec01a37f6 100644 --- a/grr/core/grr_response_core/lib/util/sqlite_test.py +++ b/grr/core/grr_response_core/lib/util/sqlite_test.py @@ -1,10 +1,9 @@ #!/usr/bin/env python - import contextlib import io +import sqlite3 from absl.testing import absltest -import sqlite3 from grr_response_core.lib.util import sqlite from grr_response_core.lib.util import temp diff --git a/grr/core/grr_response_core/lib/util/statx.py b/grr/core/grr_response_core/lib/util/statx.py index 2a8eb62841..f4c3d17068 100644 --- a/grr/core/grr_response_core/lib/util/statx.py +++ b/grr/core/grr_response_core/lib/util/statx.py @@ -6,6 +6,7 @@ [1]: https://www.man7.org/linux/man-pages/man2/statx.2.html """ + import ctypes import functools import operator @@ -21,6 +22,7 @@ # TODO(hanuszczak): Migrate to data classes on support for 3.7 is available. class Result(NamedTuple): """A result of extended stat collection.""" + # A bitmask with extra file attributes. attributes: int # A number of hard links. @@ -115,13 +117,13 @@ class _StatxStruct(ctypes.Structure): def rdev(self) -> int: """Device identifier (if the file represents a device).""" # https://elixir.bootlin.com/linux/v5.6/source/tools/include/nolibc/nolibc.h - return ((self.stx_rdev_major & 0xfff) << 8) | (self.stx_rdev_minor & 0xff) + return ((self.stx_rdev_major & 0xFFF) << 8) | (self.stx_rdev_minor & 0xFF) @property def dev(self) -> int: """Device identifier of the filesystem the file resides on.""" # https://elixir.bootlin.com/linux/v5.6/source/tools/include/nolibc/nolibc.h - return ((self.stx_dev_major & 0xfff) << 8) | (self.stx_dev_minor & 0xff) + return ((self.stx_dev_major & 0xFFF) << 8) | (self.stx_dev_minor & 0xFF) # https://elixir.bootlin.com/linux/v3.4/source/include/linux/fcntl.h @@ -139,18 +141,22 @@ def dev(self) -> int: _STATX_CTIME = 0x00000080 _STATX_INO = 0x00000100 _STATX_SIZE = 0x00000200 -_STATX_ALL = functools.reduce(operator.__or__, [ - _STATX_MODE, - _STATX_NLINK, - _STATX_UID, - _STATX_GID, - _STATX_ATIME, - _STATX_BTIME, - _STATX_MTIME, - _STATX_CTIME, - _STATX_INO, - _STATX_SIZE, -], 0) +_STATX_ALL = functools.reduce( + operator.__or__, + [ + _STATX_MODE, + _STATX_NLINK, + _STATX_UID, + _STATX_GID, + _STATX_ATIME, + _STATX_BTIME, + _STATX_MTIME, + _STATX_CTIME, + _STATX_INO, + _STATX_SIZE, + ], + 0, +) if platform.system() == "Linux": @@ -178,8 +184,13 @@ def dev(self) -> int: def _GetImplLinuxStatx(path: bytes) -> Result: """A Linux-specific stat implementation through `statx`.""" c_result = _StatxStruct() - c_status = _statx(0, path, _AT_SYMLINK_NOFOLLOW | _AT_STATX_SYNC_AS_STAT, - _STATX_ALL, ctypes.pointer(c_result)) + c_status = _statx( + 0, + path, + _AT_SYMLINK_NOFOLLOW | _AT_STATX_SYNC_AS_STAT, + _STATX_ALL, + ctypes.pointer(c_result), + ) if c_status != 0: raise OSError(f"Failed to stat '{path}', error code: {c_status}") @@ -197,7 +208,8 @@ def _GetImplLinuxStatx(path: bytes) -> Result: ctime_ns=c_result.stx_ctime.nanos, mtime_ns=c_result.stx_mtime.nanos, rdev=c_result.rdev, - dev=c_result.dev) + dev=c_result.dev, + ) _GetImpl = 
_GetImplLinuxStatx BTIME_SUPPORT = True @@ -220,7 +232,8 @@ def _GetImplLinux(path: bytes) -> Result: ctime_ns=stat_obj.st_ctime_ns, mtime_ns=stat_obj.st_mtime_ns, rdev=stat_obj.st_rdev, - dev=stat_obj.st_dev) + dev=stat_obj.st_dev, + ) _GetImpl = _GetImplLinux BTIME_SUPPORT = False @@ -247,7 +260,8 @@ def _GetImplMacos(path: bytes) -> Result: ctime_ns=stat_obj.st_ctime_ns, mtime_ns=stat_obj.st_mtime_ns, rdev=stat_obj.st_rdev, - dev=stat_obj.st_dev) + dev=stat_obj.st_dev, + ) _GetImpl = _GetImplMacos BTIME_SUPPORT = True @@ -278,7 +292,8 @@ def _GetImplWindows(path: bytes) -> Result: ctime_ns=stat_obj.st_ctime_ns, mtime_ns=stat_obj.st_mtime_ns, rdev=0, # Not available. - dev=stat_obj.st_dev) + dev=stat_obj.st_dev, + ) _GetImpl = _GetImplWindows BTIME_SUPPORT = True diff --git a/grr/core/grr_response_core/lib/util/temp.py b/grr/core/grr_response_core/lib/util/temp.py index d803172060..d78dd8ab08 100644 --- a/grr/core/grr_response_core/lib/util/temp.py +++ b/grr/core/grr_response_core/lib/util/temp.py @@ -6,7 +6,6 @@ import shutil import tempfile from typing import Optional -from typing import Text from absl import flags @@ -15,7 +14,7 @@ FLAGS = flags.FLAGS -def _TestTempRootPath() -> Optional[Text]: +def _TestTempRootPath() -> Optional[str]: """Returns a default root path for storing temporary files during tests.""" # `TEST_TMPDIR` and `FLAGS.test_tmpdir` are only defined only for test # environments. For non-test code, we use the default temporary directory. @@ -36,7 +35,7 @@ def _TestTempRootPath() -> Optional[Text]: return test_tmpdir -def TempDirPath(suffix: Text = "", prefix: Text = "tmp") -> Text: +def TempDirPath(suffix: str = "", prefix: str = "tmp") -> str: """Creates a temporary directory based on the environment configuration. The directory will be placed in folder as specified by the `TEST_TMPDIR` @@ -50,15 +49,17 @@ def TempDirPath(suffix: Text = "", prefix: Text = "tmp") -> Text: Returns: An absolute path to the created directory. """ - precondition.AssertType(suffix, Text) - precondition.AssertType(prefix, Text) + precondition.AssertType(suffix, str) + precondition.AssertType(prefix, str) return tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=_TestTempRootPath()) -def TempFilePath(suffix: Text = "", - prefix: Text = "tmp", - dir: Optional[Text] = None) -> Text: # pylint: disable=redefined-builtin +def TempFilePath( + suffix: str = "", + prefix: str = "tmp", + dir: Optional[str] = None, # pylint: disable=redefined-builtin +) -> str: """Creates a temporary file based on the environment configuration. If no directory is specified the file will be placed in folder as specified by @@ -80,9 +81,9 @@ def TempFilePath(suffix: Text = "", ValueError: If the specified directory is not part of the default test temporary directory. """ - precondition.AssertType(suffix, Text) - precondition.AssertType(prefix, Text) - precondition.AssertOptionalType(dir, Text) + precondition.AssertType(suffix, str) + precondition.AssertType(prefix, str) + precondition.AssertOptionalType(dir, str) root = _TestTempRootPath() if not dir: @@ -98,7 +99,7 @@ def TempFilePath(suffix: Text = "", return path -class AutoTempDirPath(object): +class AutoTempDirPath: """Creates a temporary directory based on the environment configuration. The directory will be placed in folder as specified by the `TEST_TMPDIR` @@ -108,29 +109,33 @@ class AutoTempDirPath(object): This object is a context manager and the directory is automatically removed when it goes out of scope. 
- Args: - suffix: A suffix to end the directory name with. - prefix: A prefix to begin the directory name with. - remove_non_empty: If set to `True` the directory removal will succeed even - if it is not empty. - Returns: An absolute path to the created directory. """ - def __init__(self, - suffix: Text = "", - prefix: Text = "tmp", - remove_non_empty: bool = False): - precondition.AssertType(suffix, Text) - precondition.AssertType(prefix, Text) + def __init__( + self, + suffix: str = "", + prefix: str = "tmp", + remove_non_empty: bool = False, + ): + """Creates a temporary directory based on the environment configuration. + + Args: + suffix: A suffix to end the directory name with. + prefix: A prefix to begin the directory name with. + remove_non_empty: If set to `True` the directory removal will succeed even + if it is not empty. + """ + precondition.AssertType(suffix, str) + precondition.AssertType(prefix, str) precondition.AssertType(remove_non_empty, bool) self.suffix = suffix self.prefix = prefix self.remove_non_empty = remove_non_empty - def __enter__(self) -> Text: + def __enter__(self) -> str: self.path = TempDirPath(suffix=self.suffix, prefix=self.prefix) return self.path @@ -145,7 +150,7 @@ def __exit__(self, exc_type, exc_value, traceback): os.rmdir(self.path) -class AutoTempFilePath(object): +class AutoTempFilePath: """Creates a temporary file based on the environment configuration. If no directory is specified the file will be placed in folder as specified by @@ -158,11 +163,6 @@ class AutoTempFilePath(object): This object is a context manager and the associated file is automatically removed when it goes out of scope. - Args: - suffix: A suffix to end the file name with. - prefix: A prefix to begin the file name with. - dir: A directory to place the file in. - Returns: An absolute path to the created file. @@ -171,21 +171,32 @@ class AutoTempFilePath(object): temporary directory. """ - def __init__(self, - suffix: Text = "", - prefix: Text = "tmp", - dir: Optional[Text] = None): # pylint: disable=redefined-builtin - precondition.AssertType(prefix, Text) - precondition.AssertType(suffix, Text) - precondition.AssertOptionalType(dir, Text) + def __init__( + self, + suffix: str = "", + prefix: str = "tmp", + dir: Optional[str] = None, # pylint: disable=redefined-builtin + ): + """Creates a temporary file based on the environment configuration. + + Args: + suffix: A suffix to end the file name with. + prefix: A prefix to begin the file name with. + dir: A directory to place the file in. 
+ """ + precondition.AssertType(suffix, str) + precondition.AssertType(prefix, str) + precondition.AssertOptionalType(dir, str) self.suffix = suffix self.prefix = prefix self.dir = dir - def __enter__(self) -> Text: + def __enter__(self) -> str: self.path = TempFilePath( - suffix=self.suffix, prefix=self.prefix, dir=self.dir) + suffix=self.suffix, prefix=self.prefix, dir=self.dir + ) return self.path def __exit__(self, exc_type, exc_value, traceback): diff --git a/grr/core/grr_response_core/lib/util/temp_test.py b/grr/core/grr_response_core/lib/util/temp_test.py index cd65870d0b..7811e7c9a9 100644 --- a/grr/core/grr_response_core/lib/util/temp_test.py +++ b/grr/core/grr_response_core/lib/util/temp_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import io import os diff --git a/grr/core/grr_response_core/lib/util/text.py b/grr/core/grr_response_core/lib/util/text.py index 527be2bdfe..451d8165d3 100644 --- a/grr/core/grr_response_core/lib/util/text.py +++ b/grr/core/grr_response_core/lib/util/text.py @@ -3,12 +3,10 @@ import binascii -from typing import Text - from grr_response_core.lib.util import precondition -def Asciify(data: bytes) -> Text: +def Asciify(data: bytes) -> str: """Turns given bytes to human-readable ASCII representation. All ASCII-representable bytes are turned into proper characters, whereas all @@ -25,7 +23,7 @@ def Asciify(data: bytes) -> Text: return repr(data)[2:-1] -def Hexify(data: bytes) -> Text: +def Hexify(data: bytes) -> str: """Turns given bytes to its hex representation. It works just like `binascii.hexlify` but always returns string objects rather @@ -50,5 +48,5 @@ def Unescape(string: str) -> str: Returns: An unescaped version of the input string. """ - precondition.AssertType(string, Text) + precondition.AssertType(string, str) return string.encode("utf-8").decode("unicode_escape") diff --git a/grr/core/grr_response_core/lib/util/text_test.py b/grr/core/grr_response_core/lib/util/text_test.py index e777d782e5..8041c06221 100644 --- a/grr/core/grr_response_core/lib/util/text_test.py +++ b/grr/core/grr_response_core/lib/util/text_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl.testing import absltest from grr_response_core.lib.util import text @@ -39,7 +38,7 @@ def testWhitespace(self): def testQuotemark(self): self.assertEqual(text.Unescape("\\'"), "'") - self.assertEqual(text.Unescape("\\\""), "\"") + self.assertEqual(text.Unescape('\\"'), '"') def testMany(self): self.assertEqual(text.Unescape("foo\\n\\'bar\\'\nbaz"), "foo\n'bar'\nbaz") diff --git a/grr/core/grr_response_core/lib/util/timeline.py b/grr/core/grr_response_core/lib/util/timeline.py index 4f68448a50..5acc7c7728 100644 --- a/grr/core/grr_response_core/lib/util/timeline.py +++ b/grr/core/grr_response_core/lib/util/timeline.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module defining timeline-related utility functions.""" + from typing import Iterator from grr_response_core.lib.util import gzchunked @@ -13,6 +14,7 @@ def _ParseTimelineEntryProto(bstr: bytes) -> timeline_pb2.TimelineEntry: def DeserializeTimelineEntryProtoStream( - entries: Iterator[bytes],) -> Iterator[timeline_pb2.TimelineEntry]: + entries: Iterator[bytes], +) -> Iterator[timeline_pb2.TimelineEntry]: """Deserializes given gzchunked stream chunks into TimelineEntry protos.""" return map(_ParseTimelineEntryProto, gzchunked.Deserialize(entries)) diff --git a/grr/core/grr_response_core/lib/utils.py b/grr/core/grr_response_core/lib/utils.py index 5b79289d9a..5b69d48318 100644 --- 
a/grr/core/grr_response_core/lib/utils.py +++ b/grr/core/grr_response_core/lib/utils.py @@ -18,7 +18,7 @@ import tempfile import threading import time -from typing import Generic, Iterable, Optional, Text, TypeVar +from typing import Generic, Iterable, Optional, TypeVar import weakref import zipfile @@ -44,7 +44,7 @@ def Wrapped(self, *args): return Wrapped -class TempDirectory(object): +class TempDirectory: """A self cleaning temporary directory. Do not use this function for any client related temporary files! Use @@ -75,13 +75,15 @@ def NewFunction(self, *args, **kw): class InterruptableThread(threading.Thread): """A class which exits once the main thread exits.""" - def __init__(self, - target=None, - args=None, - kwargs=None, - sleep_time=10, - name: Optional[Text] = None, - **kw): + def __init__( + self, + target=None, + args=None, + kwargs=None, + sleep_time=10, + name: Optional[str] = None, + **kw, + ): self.exit = False self.last_run = 0 self.target = target @@ -121,13 +123,13 @@ def run(self): self.last_run = now() # Exit if the main thread disappears. - while (time and not self.exit and - now() < self.last_run + self.sleep_time): + while time and not self.exit and now() < self.last_run + self.sleep_time: sleep(1) -class Node(object): +class Node: """An entry to a linked list.""" + next = None prev = None data = None @@ -146,7 +148,7 @@ def __repr__(self): # TODO(user):pytype: self.next and self.prev are assigned to self but then # are used in AppendNode in a very different way. Should be redesigned. # pytype: disable=attribute-error -class LinkedList(object): +class LinkedList: """A simple doubly linked list used for fast caches.""" def __init__(self): @@ -211,7 +213,7 @@ def Print(self): # pytype: enable=attribute-error -class FastStore(object): +class FastStore: """This is a cache which expires objects in oldest first manner. This implementation first appeared in PyFlag. @@ -408,7 +410,8 @@ def HouseKeeper(): TimeBasedCache.active_caches = weakref.WeakSet() # This thread is designed to never finish. TimeBasedCache.house_keeper_thread = InterruptableThread( - name="HouseKeeperThread", target=HouseKeeper) + name="HouseKeeperThread", target=HouseKeeper + ) TimeBasedCache.house_keeper_thread.start() TimeBasedCache.active_caches.add(self) @@ -445,12 +448,12 @@ def Get(self, key): return stored.value -class Struct(object): +class Struct: """A baseclass for parsing binary Structs.""" # Derived classes must initialize this into an array of (format, # name) tuples. - _fields = None + _fields: list[tuple[str, str]] def __init__(self, data): """Parses ourselves from data.""" @@ -458,7 +461,7 @@ def __init__(self, data): self.size = struct.calcsize(format_str) try: - parsed_data = struct.unpack(format_str, data[:self.size]) + parsed_data = struct.unpack(format_str, data[: self.size]) except struct.error: raise ParsingError("Unable to parse") @@ -492,7 +495,7 @@ def SmartUnicode(string): Returns: a unicode object. """ - if isinstance(string, Text): + if isinstance(string, str): return string if isinstance(string, bytes): @@ -518,14 +521,14 @@ def FormatAsHexString(num, width=None, prefix="0x"): return "%s%s" % (prefix, hex_str) -def FormatAsTimestamp(timestamp: int) -> Text: +def FormatAsTimestamp(timestamp: int) -> str: if not timestamp: return "-" return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(timestamp)) -def NormalizePath(path: Text, sep: Text = "/") -> Text: +def NormalizePath(path: str, sep: str = "/") -> str: """A sane implementation of os.path.normpath. 
The standard implementation treats leading / and // as different leading to @@ -544,8 +547,8 @@ def NormalizePath(path: Text, sep: Text = "/") -> Text: that would result in the system opening the same physical file will produce the same normalized path. """ - precondition.AssertType(path, Text) - precondition.AssertType(sep, Text) + precondition.AssertType(path, str) + precondition.AssertType(sep, str) if not path: return sep @@ -586,7 +589,7 @@ def NormalizePath(path: Text, sep: Text = "/") -> Text: # TODO(hanuszczak): The linter complains for a reason here, the signature of # this function should be fixed as soon as possible. -def JoinPath(stem: Text = "", *parts: Text) -> Text: # pylint: disable=keyword-arg-before-vararg +def JoinPath(stem: str = "", *parts: str) -> str: # pylint: disable=keyword-arg-before-vararg """A sane version of os.path.join. The intention here is to append the stem to the path. The standard module @@ -600,8 +603,8 @@ def JoinPath(stem: Text = "", *parts: Text) -> Text: # pylint: disable=keyword- Returns: a normalized path. """ - precondition.AssertIterableType(parts, Text) - precondition.AssertType(stem, Text) + precondition.AssertIterableType(parts, str) + precondition.AssertType(stem, str) result = (stem + NormalizePath("/".join(parts))).replace("//", "/") result = result.rstrip("/") @@ -637,9 +640,9 @@ def GeneratePassphrase(length=20): return "".join(random.choice(valid_chars) for i in range(length)) -def PassphraseCallback(verify=False, - prompt1="Enter passphrase:", - prompt2="Verify passphrase:"): +def PassphraseCallback( + verify=False, prompt1="Enter passphrase:", prompt2="Verify passphrase:" +): """A utility function to read a passphrase from stdin.""" while 1: try: @@ -664,7 +667,7 @@ def FormatNumberAsString(num): return "%3.1f%s" % (num, "TB") -class NotAValue(object): +class NotAValue: pass @@ -714,7 +717,8 @@ def GetValueAndReset(self): """Gets stream buffer since the last GetValueAndReset() call.""" if not self._stream: raise ArchiveAlreadyClosedError( - "Attempting to get a value from a closed stream.") + "Attempting to get a value from a closed stream." + ) value = self._stream.getvalue() self._stream.seek(0) @@ -736,7 +740,8 @@ def __init__(self, compression=zipfile.ZIP_STORED): self._compression = compression self._stream = RollingMemoryStream() self._zipfile = zipfile.ZipFile( - self._stream, mode="w", compression=compression, allowZip64=True) + self._stream, mode="w", compression=compression, allowZip64=True + ) self._zipopen = None @@ -800,7 +805,8 @@ def Close(self): def WriteFromFD(self, src_fd, arcname=None, compress_type=None, st=None): """A convenience method for adding an entire file to the ZIP archive.""" yield self.WriteFileHeader( - arcname=arcname, compress_type=compress_type, st=st) + arcname=arcname, compress_type=compress_type, st=st + ) while True: buf = src_fd.read(1024 * 1024) @@ -835,7 +841,8 @@ def __init__(self): # TODO(user):pytype: self._stream should be a valid IO object. 
# pytype: disable=wrong-arg-types self._tar_fd = tarfile.open( - mode="w:gz", fileobj=self._stream, encoding="utf-8") + mode="w:gz", fileobj=self._stream, encoding="utf-8" + ) # pytype: enable=wrong-arg-types self._ResetState() @@ -860,7 +867,7 @@ def Close(self): def WriteFileHeader(self, arcname=None, st=None): """Writes file header.""" - precondition.AssertType(arcname, Text) + precondition.AssertType(arcname, str) if st is None: raise ValueError("Stat object can't be None.") @@ -888,10 +895,14 @@ def WriteFileChunk(self, chunk): def WriteFileFooter(self): """Writes file footer (finishes the file).""" + if self.cur_info is None: + raise ValueError("WriteFileHeader() must be called first.") if self.cur_file_size != self.cur_info.size: - raise IOError("Incorrect file size: st_size=%d, but written %d bytes." % - (self.cur_info.size, self.cur_file_size)) + raise IOError( + "Incorrect file size: st_size=%d, but written %d bytes." + % (self.cur_info.size, self.cur_file_size) + ) # TODO(user):pytype: BLOCKSIZE/NUL constants are not visible to type # checker. @@ -944,7 +955,7 @@ def output_size(self): return self._stream.tell() -class Stubber(object): +class Stubber: """A context manager for doing simple stubs.""" def __init__(self, module, target_name, stub): @@ -1018,8 +1029,9 @@ def MergeDirectories(src: str, dst: str) -> None: def ResolveHostnameToIP(host, port): """Resolves a hostname to an IP address.""" - ip_addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC, 0, - socket.IPPROTO_TCP) + ip_addrs = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, 0, socket.IPPROTO_TCP + ) # getaddrinfo returns tuples (family, socktype, proto, canonname, sockaddr). # We are interested in sockaddr which is in turn a tuple # (address, port) for IPv4 or (address, port, flow info, scope id) @@ -1032,8 +1044,11 @@ def ResolveHostnameToIP(host, port): def ProcessIdString(): - return "%s@%s:%d" % (psutil.Process().name(), socket.gethostname(), - os.getpid()) + return "%s@%s:%d" % ( + psutil.Process().name(), + socket.gethostname(), + os.getpid(), + ) def RegexListDisjunction(regex_list: Iterable[bytes]): diff --git a/grr/core/grr_response_core/lib/utils_test.py b/grr/core/grr_response_core/lib/utils_test.py index 29e89d542b..1e89cf5e0f 100644 --- a/grr/core/grr_response_core/lib/utils_test.py +++ b/grr/core/grr_response_core/lib/utils_test.py @@ -145,7 +145,8 @@ def FormatAsHexStringTest(self): # No trailing "L". self.assertEqual(utils.FormatAsHexString(int(1e19)), "0x8ac7230489e80000") self.assertEqual( - utils.FormatAsHexString(int(1e19), 5), "0x8ac7230489e80000") + utils.FormatAsHexString(int(1e19), 5), "0x8ac7230489e80000" + ) def testXor(self): test_str = b"foobar4815162342" diff --git a/grr/core/grr_response_core/path_detection/core.py b/grr/core/grr_response_core/path_detection/core.py index 270552c78b..869e203f42 100644 --- a/grr/core/grr_response_core/path_detection/core.py +++ b/grr/core/grr_response_core/path_detection/core.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """The path detection interface (base) class definitions.""" - import abc import shlex @@ -23,7 +22,7 @@ def SplitIntoComponents(str_in): stripped). """ - if str_in.startswith(("\"", "'")): + if str_in.startswith(('"', "'")): return shlex.split(str_in) else: components = str_in.split(" ", 1) @@ -42,11 +41,10 @@ def Extract(self, components): Args: components: Source string represented as a list of components. Components - are generated by applying SplitIntoComponents to a string. 
- Components is effectively a space-separated representation of a - source string. " ".join(components) should produce the original string - (with pairs of single/double quotes stripped). See - SplitIntoComponents for details. + are generated by applying SplitIntoComponents to a string. Components is + effectively a space-separated representation of a source string. " + ".join(components) should produce the original string (with pairs of + single/double quotes stripped). See SplitIntoComponents for details. Returns: A list of extracted paths (as strings). @@ -70,7 +68,7 @@ def Process(self, path): raise NotImplementedError() -class Detector(object): +class Detector: """Configurable class that implements all detection steps.""" def __init__(self, extractors=None, post_processors=None): diff --git a/grr/core/grr_response_core/path_detection/core_test.py b/grr/core/grr_response_core/path_detection/core_test.py index 3703967781..b36e1a6146 100644 --- a/grr/core/grr_response_core/path_detection/core_test.py +++ b/grr/core/grr_response_core/path_detection/core_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests core paths detection logic.""" - from absl import app from grr_response_core.path_detection import core @@ -15,54 +14,67 @@ def testSplitsBySpaceInTrivialCases(self): """Test it splits components by space in trivial cases.""" self.assertEqual( core.SplitIntoComponents(r"C:\Program Files\Realtek\Audio\blah.exe -s"), - [r"C:\Program", r"Files\Realtek\Audio\blah.exe", r"-s"]) + [r"C:\Program", r"Files\Realtek\Audio\blah.exe", r"-s"], + ) self.assertEqual( core.SplitIntoComponents( - r"rundll32.exe C:\Windows\system32\advpack.dll,DelNodeRunDLL32"), - [r"rundll32.exe", r"C:\Windows\system32\advpack.dll,DelNodeRunDLL32"]) + r"rundll32.exe C:\Windows\system32\advpack.dll,DelNodeRunDLL32" + ), + [r"rundll32.exe", r"C:\Windows\system32\advpack.dll,DelNodeRunDLL32"], + ) def testStripsDoubleQuotes(self): """Test it strips double quotes.""" self.assertEqual( core.SplitIntoComponents( - "\"C:\\Program Files\\Realtek\\Audio\\blah.exe\""), - [r"C:\Program Files\Realtek\Audio\blah.exe"]) + '"C:\\Program Files\\Realtek\\Audio\\blah.exe"' + ), + [r"C:\Program Files\Realtek\Audio\blah.exe"], + ) def testStripsSingleQuotes(self): """Test it strips single quotes.""" self.assertEqual( core.SplitIntoComponents(r"'C:\Program Files\Realtek\Audio\blah.exe'"), - [r"C:\Program Files\Realtek\Audio\blah.exe"]) + [r"C:\Program Files\Realtek\Audio\blah.exe"], + ) def testStripsSingleQuotesEvenIfFirstComponentIsNotQuoted(self): """Test it strips single quotes even if first component is not quoted.""" self.assertEqual( core.SplitIntoComponents( - r"rundll32.exe 'C:\Program Files\Realtek\Audio\blah.exe'"), - [r"rundll32.exe", r"C:\Program Files\Realtek\Audio\blah.exe"]) + r"rundll32.exe 'C:\Program Files\Realtek\Audio\blah.exe'" + ), + [r"rundll32.exe", r"C:\Program Files\Realtek\Audio\blah.exe"], + ) def testStripsSingleQuotesEvenIfThereIsCommaAfterQuote(self): """Test it strips single quotes even if there's a comma after the quote.""" self.assertEqual( core.SplitIntoComponents( - r"rundll32.exe 'C:\Program Files\Realtek\Audio\blah.exe',SomeFunc"), - [r"rundll32.exe", r"C:\Program Files\Realtek\Audio\blah.exe,SomeFunc"]) + r"rundll32.exe 'C:\Program Files\Realtek\Audio\blah.exe',SomeFunc" + ), + [r"rundll32.exe", r"C:\Program Files\Realtek\Audio\blah.exe,SomeFunc"], + ) def testStripsDoubleQuotesEvenIfFirstComponentIsNotQuoted(self): """Test it strips double quotes even first component is not quoted.""" 
self.assertEqual( core.SplitIntoComponents( - "rundll32.exe " - "\"C:\\Program Files\\Realtek\\Audio\\blah.exe\""), - [r"rundll32.exe", r"C:\Program Files\Realtek\Audio\blah.exe"]) + 'rundll32.exe "C:\\Program Files\\Realtek\\Audio\\blah.exe"' + ), + [r"rundll32.exe", r"C:\Program Files\Realtek\Audio\blah.exe"], + ) def testStripsDoubleQuotesEvenIfThereIsCommaAfterQuote(self): """Test it strips double quotes even if there's a comma after the quote.""" self.assertEqual( core.SplitIntoComponents( "rundll32.exe " - "\"C:\\Program Files\\Realtek\\Audio\\blah.exe\",SomeFunc"), - [r"rundll32.exe", r"C:\Program Files\Realtek\Audio\blah.exe,SomeFunc"]) + '"C:\\Program Files\\Realtek\\Audio\\blah.exe",SomeFunc' + ), + [r"rundll32.exe", r"C:\Program Files\Realtek\Audio\blah.exe,SomeFunc"], + ) class TestExtractor(core.Extractor): @@ -100,34 +112,38 @@ def testReturnsWhatSingleExtractorReturns(self): def testReturnsCombinedResultsFromTwoExtractors(self): """Test it returns combined results from two extractors.""" detector = core.Detector( - extractors=[TestExtractor(multiplier=2), - TestExtractor(multiplier=3)]) + extractors=[TestExtractor(multiplier=2), TestExtractor(multiplier=3)] + ) self.assertEqual(detector.Detect("a b"), set(["b_0", "b_1", "b_2"])) def testAppliesPostProcessorToExtractedPaths(self): """Test it applies the post processor to extracted paths.""" detector = core.Detector( extractors=[TestExtractor(multiplier=2)], - post_processors=[TestPostProcessor("_bar")]) + post_processors=[TestPostProcessor("_bar")], + ) self.assertEqual(detector.Detect("a b"), set(["b_0_bar", "b_1_bar"])) def testPostProcessorMayReturnMultipleProcessedPaths(self): """Test the post processor may return multiple processed paths.""" detector = core.Detector( extractors=[TestExtractor(multiplier=2)], - post_processors=[TestPostProcessor("_bar", count=2)]) + post_processors=[TestPostProcessor("_bar", count=2)], + ) self.assertEqual( detector.Detect("a b"), - set(["b_0_bar", "b_1_bar", "b_0_bar_bar", "b_1_bar_bar"])) + set(["b_0_bar", "b_1_bar", "b_0_bar_bar", "b_1_bar_bar"]), + ) def testAppliesMultiplePostProcessorsToExtractedPaths(self): """Test it applies multiple post processors to extracted paths.""" detector = core.Detector( extractors=[TestExtractor(multiplier=2)], - post_processors=[TestPostProcessor("_foo"), - TestPostProcessor("_bar")]) + post_processors=[TestPostProcessor("_foo"), TestPostProcessor("_bar")], + ) self.assertEqual( - detector.Detect("a b"), set(["b_0_foo_bar", "b_1_foo_bar"])) + detector.Detect("a b"), set(["b_0_foo_bar", "b_1_foo_bar"]) + ) def main(argv): diff --git a/grr/core/grr_response_core/path_detection/windows.py b/grr/core/grr_response_core/path_detection/windows.py index 8cfe22642f..ddeed43e02 100644 --- a/grr/core/grr_response_core/path_detection/windows.py +++ b/grr/core/grr_response_core/path_detection/windows.py @@ -1,10 +1,8 @@ #!/usr/bin/env python """Windows paths detection classes.""" - import re - from grr_response_core.path_detection import core @@ -29,7 +27,7 @@ def Extract(self, components): if rundll_index == -1: return [] - components = components[(rundll_index + 1):] + components = components[(rundll_index + 1) :] # We expect components after "rundll32.exe" to point at a DLL and a # function. 
For example: @@ -43,8 +41,16 @@ def Extract(self, components): class ExecutableExtractor(core.Extractor): """Extractor for ordinary paths.""" - EXECUTABLE_EXTENSIONS = ("exe", "com", "bat", "dll", "msi", "sys", "scr", - "pif") + EXECUTABLE_EXTENSIONS = ( + "exe", + "com", + "bat", + "dll", + "msi", + "sys", + "scr", + "pif", + ) def Extract(self, components): """Extracts interesting paths from a given path. @@ -58,7 +64,7 @@ def Extract(self, components): for index, component in enumerate(components): if component.lower().endswith(self.EXECUTABLE_EXTENSIONS): - extracted_path = " ".join(components[0:index + 1]) + extracted_path = " ".join(components[0 : index + 1]) return [extracted_path] return [] @@ -80,15 +86,16 @@ def __init__(self, vars_map): Args: vars_map: Dictionary of "string" -> "string|list", i.e. a mapping of - environment variables names to their suggested values or to lists - of their suggested values. + environment variables names to their suggested values or to lists of + their suggested values. """ super(core.PostProcessor, self).__init__() self.vars_map = {} for var_name, value in vars_map.items(): var_regex = re.compile( - re.escape("%" + var_name + "%"), flags=re.IGNORECASE) + re.escape("%" + var_name + "%"), flags=re.IGNORECASE + ) self.vars_map[var_name.lower()] = (var_regex, value) def Process(self, path): @@ -151,7 +158,8 @@ def CreateWindowsRegistryExecutablePathsDetector(vars_map=None): """ return core.Detector( extractors=[RunDllExtractor(), ExecutableExtractor()], - post_processors=[EnvVarsPostProcessor(vars_map or {})],) + post_processors=[EnvVarsPostProcessor(vars_map or {})], + ) def DetectExecutablePaths(source_values, vars_map=None): @@ -160,8 +168,8 @@ def DetectExecutablePaths(source_values, vars_map=None): Args: source_values: A list of strings to detect paths in. vars_map: Dictionary of "string" -> "string|list", i.e. a mapping of - environment variables names to their suggested values or to lists - of their suggested values. + environment variables names to their suggested values or to lists of their + suggested values. Yields: A list of detected paths (as strings). 
diff --git a/grr/core/grr_response_core/path_detection/windows_test.py b/grr/core/grr_response_core/path_detection/windows_test.py index 416bd2aa3f..bbcd911e29 100644 --- a/grr/core/grr_response_core/path_detection/windows_test.py +++ b/grr/core/grr_response_core/path_detection/windows_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for windows paths detection logic.""" - from absl import app from grr_response_core.path_detection import windows @@ -30,17 +29,20 @@ def testRunDllCheckIsCaseInsensitive(self): def testReturnsAllComponentsExceptForTheFirstOneIfFirstOneIsRunDll(self): """Test it returns all components except for the first 'rundll'.""" self.assertEqual( - self.extractor.Extract(["rundll32.exe", "b", "c", "d"]), ["b c d"]) + self.extractor.Extract(["rundll32.exe", "b", "c", "d"]), ["b c d"] + ) def testReturnsThirdOutOfThreeComponentsIfFirstTwoAreRunDll(self): """Test it returns 3rd out of 2 components if the first two are rundll.""" self.assertEqual( - self.extractor.Extract([r"C:\some", r"path\rundll32.exe", "b"]), ["b"]) + self.extractor.Extract([r"C:\some", r"path\rundll32.exe", "b"]), ["b"] + ) def testStripsFunctionName(self): """Test it strips the function name.""" self.assertEqual( - self.extractor.Extract(["rundll32.exe", "b,FuncName"]), ["b"]) + self.extractor.Extract(["rundll32.exe", "b,FuncName"]), ["b"] + ) class ExecutableExtractorTest(test_lib.GRRBaseTest): @@ -76,64 +78,71 @@ def testReplacesOneVariable(self): processor = windows.EnvVarsPostProcessor({"foo": "bar"}) self.assertEqual( processor.Process(r"C:\WINDOWS\%foo%\something"), - [r"C:\WINDOWS\bar\something"]) + [r"C:\WINDOWS\bar\something"], + ) def testReplacesTwoVariables(self): """Test it correctly replaces two variables.""" processor = windows.EnvVarsPostProcessor({"foo": "bar", "blah": "blahblah"}) self.assertEqual( processor.Process(r"C:\WINDOWS\%foo%\%blah%\something"), - [r"C:\WINDOWS\bar\blahblah\something"]) + [r"C:\WINDOWS\bar\blahblah\something"], + ) def testVariableReplacementIsCaseInsensitive(self): """Test variable replacement is case insensitive.""" processor = windows.EnvVarsPostProcessor({"foo": "bar"}) self.assertEqual( processor.Process(r"C:\WINDOWS\%FoO%\something"), - [r"C:\WINDOWS\bar\something"]) + [r"C:\WINDOWS\bar\something"], + ) def testGeneratesMultipleReplacementsIfReplacementIsList(self): """Test it generates multiple replacements if replacement is a list.""" processor = windows.EnvVarsPostProcessor({"foo": ["bar", "blah"]}) self.assertEqual( set(processor.Process(r"C:\WINDOWS\%foo%\something")), - set([r"C:\WINDOWS\bar\something", r"C:\WINDOWS\blah\something"])) + set([r"C:\WINDOWS\bar\something", r"C:\WINDOWS\blah\something"]), + ) def testVariableValueIsStableInASinglePath(self): """Test it keeps variable value stable in a single path.""" processor = windows.EnvVarsPostProcessor({"foo": ["bar", "blah"]}) self.assertEqual( set(processor.Process(r"C:\WINDOWS\%foo%\%foo%\something")), - set([ - r"C:\WINDOWS\bar\bar\something", r"C:\WINDOWS\blah\blah\something" - ])) + set( + [r"C:\WINDOWS\bar\bar\something", r"C:\WINDOWS\blah\blah\something"] + ), + ) def testGeneratesProductIfTwoReplacementsHaveMultipleValues(self): """Test it generates a product if two replacements have multiple values.""" - processor = windows.EnvVarsPostProcessor({ - "foo": ["bar1", "bar2"], - "blah": ["blah1", "blah2"] - }) + processor = windows.EnvVarsPostProcessor( + {"foo": ["bar1", "bar2"], "blah": ["blah1", "blah2"]} + ) self.assertEqual( 
set(processor.Process(r"C:\WINDOWS\%foo%\%blah%\something")), set([ r"C:\WINDOWS\bar1\blah1\something", r"C:\WINDOWS\bar1\blah2\something", r"C:\WINDOWS\bar2\blah1\something", - r"C:\WINDOWS\bar2\blah2\something" - ])) + r"C:\WINDOWS\bar2\blah2\something", + ]), + ) def testReplacesSystemRootPrefixWithSystemRootVariable(self): """Test it replaces system root prefix with a system root variable.""" processor = windows.EnvVarsPostProcessor({"systemroot": "blah"}) self.assertEqual( - processor.Process(r"\SystemRoot\foo\bar"), [r"blah\foo\bar"]) + processor.Process(r"\SystemRoot\foo\bar"), [r"blah\foo\bar"] + ) def testReplacesSystem32PrefixWithSystemRootVariable(self): """Test it replaces system32 prefix with a system root variable.""" processor = windows.EnvVarsPostProcessor({"systemroot": "blah"}) self.assertEqual( - processor.Process(r"System32\foo\bar"), [r"blah\system32\foo\bar"]) + processor.Process(r"System32\foo\bar"), [r"blah\system32\foo\bar"] + ) class WindowsRegistryExecutablePathsDetectorTest(test_lib.GRRBaseTest): @@ -141,12 +150,20 @@ class WindowsRegistryExecutablePathsDetectorTest(test_lib.GRRBaseTest): def testExtractsPathsFromNonRunDllStrings(self): """Test it extracts paths from non-rundll strings.""" - fixture = [(r"C:\Program Files\Realtek\Audio\blah.exe -s", - r"C:\Program Files\Realtek\Audio\blah.exe"), - (r"'C:\Program Files\Realtek\Audio\blah.exe' -s", - r"C:\Program Files\Realtek\Audio\blah.exe"), - (r"C:\Program Files\NVIDIA Corporation\nwiz.exe /quiet /blah", - r"C:\Program Files\NVIDIA Corporation\nwiz.exe")] + fixture = [ + ( + r"C:\Program Files\Realtek\Audio\blah.exe -s", + r"C:\Program Files\Realtek\Audio\blah.exe", + ), + ( + r"'C:\Program Files\Realtek\Audio\blah.exe' -s", + r"C:\Program Files\Realtek\Audio\blah.exe", + ), + ( + r"C:\Program Files\NVIDIA Corporation\nwiz.exe /quiet /blah", + r"C:\Program Files\NVIDIA Corporation\nwiz.exe", + ), + ] for in_str, result in fixture: self.assertEqual(list(windows.DetectExecutablePaths([in_str])), [result]) @@ -154,32 +171,46 @@ def testExtractsPathsFromNonRunDllStrings(self): def testExctactsPathsFromRunDllStrings(self): """Test it extracts paths from rundll strings.""" fixture = [ - (r"rundll32.exe C:\Windows\system32\advpack.dll,DelNodeRunDLL32", - r"C:\Windows\system32\advpack.dll"), - (r"rundll32.exe 'C:\Program Files\Realtek\Audio\blah.exe',blah", - r"C:\Program Files\Realtek\Audio\blah.exe"), - (r"'rundll32.exe' 'C:\Program Files\Realtek\Audio\blah.exe',blah", - r"C:\Program Files\Realtek\Audio\blah.exe") + ( + r"rundll32.exe C:\Windows\system32\advpack.dll,DelNodeRunDLL32", + r"C:\Windows\system32\advpack.dll", + ), + ( + r"rundll32.exe 'C:\Program Files\Realtek\Audio\blah.exe',blah", + r"C:\Program Files\Realtek\Audio\blah.exe", + ), + ( + r"'rundll32.exe' 'C:\Program Files\Realtek\Audio\blah.exe',blah", + r"C:\Program Files\Realtek\Audio\blah.exe", + ), ] for in_str, result in fixture: self.assertEqual( set(windows.DetectExecutablePaths([in_str])), - set([result, "rundll32.exe"])) + set([result, "rundll32.exe"]), + ) def testReplacesEnvironmentVariable(self): """Test it replaces environment variables.""" mapping = { "programfiles": r"C:\Program Files", } - fixture = [(r"%ProgramFiles%\Realtek\Audio\blah.exe -s", - r"C:\Program Files\Realtek\Audio\blah.exe"), - (r"'%ProgramFiles%\Realtek\Audio\blah.exe' -s", - r"C:\Program Files\Realtek\Audio\blah.exe")] + fixture = [ + ( + r"%ProgramFiles%\Realtek\Audio\blah.exe -s", + r"C:\Program Files\Realtek\Audio\blah.exe", + ), + ( + 
r"'%ProgramFiles%\Realtek\Audio\blah.exe' -s", + r"C:\Program Files\Realtek\Audio\blah.exe", + ), + ] for in_str, result in fixture: self.assertEqual( - list(windows.DetectExecutablePaths([in_str], mapping)), [result]) + list(windows.DetectExecutablePaths([in_str], mapping)), [result] + ) def testReplacesEnvironmentVariablesWithMultipleMappings(self): """Test it replaces environment variables with multiple mappings.""" @@ -191,18 +222,27 @@ def testReplacesEnvironmentVariablesWithMultipleMappings(self): ] } - fixture = [(r"%AppData%\Realtek\Audio\blah.exe -s", [ - r"C:\Users\foo\Application Data\Realtek\Audio\blah.exe", - r"C:\Users\bar\Application Data\Realtek\Audio\blah.exe" - ]), - (r"'%AppData%\Realtek\Audio\blah.exe' -s", [ - r"C:\Users\foo\Application Data\Realtek\Audio\blah.exe", - r"C:\Users\bar\Application Data\Realtek\Audio\blah.exe" - ])] + fixture = [ + ( + r"%AppData%\Realtek\Audio\blah.exe -s", + [ + r"C:\Users\foo\Application Data\Realtek\Audio\blah.exe", + r"C:\Users\bar\Application Data\Realtek\Audio\blah.exe", + ], + ), + ( + r"'%AppData%\Realtek\Audio\blah.exe' -s", + [ + r"C:\Users\foo\Application Data\Realtek\Audio\blah.exe", + r"C:\Users\bar\Application Data\Realtek\Audio\blah.exe", + ], + ), + ] for in_str, result in fixture: self.assertEqual( - set(windows.DetectExecutablePaths([in_str], mapping)), set(result)) + set(windows.DetectExecutablePaths([in_str], mapping)), set(result) + ) def main(argv): diff --git a/grr/core/grr_response_core/stats/default_stats_collector.py b/grr/core/grr_response_core/stats/default_stats_collector.py index b51c10e7dd..d7fdbde145 100644 --- a/grr/core/grr_response_core/stats/default_stats_collector.py +++ b/grr/core/grr_response_core/stats/default_stats_collector.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Default implementation for a stats-collector.""" - import abc from grr_response_core.lib import utils @@ -43,18 +42,23 @@ def _DefaultValue(self): def Get(self, fields=None): """Gets the metric value corresponding to the given field values.""" if not self._field_defs and fields: - raise ValueError("Metric was registered without fields, " - "but following fields were provided: %s." % (fields,)) + raise ValueError( + "Metric was registered without fields, " + "but following fields were provided: %s." % (fields,) + ) if self._field_defs and not fields: - raise ValueError("Metric was registered with fields (%s), " - "but no fields were provided." % self._field_defs) + raise ValueError( + "Metric was registered with fields (%s), but no fields were provided." + % self._field_defs + ) if self._field_defs and fields and len(self._field_defs) != len(fields): raise ValueError( "Metric was registered with %d fields (%s), but " - "%d fields were provided (%s)." % (len( - self._field_defs), self._field_defs, len(fields), fields)) + "%d fields were provided (%s)." 
+ % (len(self._field_defs), self._field_defs, len(fields), fields) + ) metric_value = self._metric_values.get(_FieldsToKey(fields)) return self._DefaultValue() if metric_value is None else metric_value @@ -74,7 +78,8 @@ def Increment(self, delta, fields=None): """Increments counter value by a given delta.""" if delta < 0: raise ValueError( - "Counter increment should not be < 0 (received: %d)" % delta) + "Counter increment should not be < 0 (received: %d)" % delta + ) self._metric_values[_FieldsToKey(fields)] = self.Get(fields=fields) + delta @@ -94,8 +99,29 @@ class _EventMetric(_Metric): def __init__(self, bins, fields): super().__init__(fields) self._bins = bins or [ - 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8, 9, - 10, 15, 20, 50, 100 + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1, + 1.5, + 2, + 2.5, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 15, + 20, + 50, + 100, ] def _DefaultValue(self): @@ -159,17 +185,21 @@ def __init__(self): def _InitializeMetric(self, metadata): """See base class.""" field_defs = stats_utils.FieldDefinitionTuplesFromProtos( - metadata.fields_defs) + metadata.fields_defs + ) if metadata.metric_type == rdf_stats.MetricMetadata.MetricType.COUNTER: self._counter_metrics[metadata.varname] = _CounterMetric(field_defs) elif metadata.metric_type == rdf_stats.MetricMetadata.MetricType.EVENT: self._event_metrics[metadata.varname] = _EventMetric( - list(metadata.bins), field_defs) + list(metadata.bins), field_defs + ) elif metadata.metric_type == rdf_stats.MetricMetadata.MetricType.GAUGE: value_type = stats_utils.PythonTypeFromMetricValueType( - metadata.value_type) + metadata.value_type + ) self._gauge_metrics[metadata.varname] = _GaugeMetric( - value_type, field_defs) + value_type, field_defs + ) else: raise ValueError("Unknown metric type: %s." 
% metadata.metric_type) diff --git a/grr/core/grr_response_core/stats/default_stats_collector_test.py b/grr/core/grr_response_core/stats/default_stats_collector_test.py index 996a901563..2ce3120345 100644 --- a/grr/core/grr_response_core/stats/default_stats_collector_test.py +++ b/grr/core/grr_response_core/stats/default_stats_collector_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for the DefaultStatsCollector.""" - from absl import app from grr_response_core.stats import default_stats_collector diff --git a/grr/core/grr_response_core/stats/metrics.py b/grr/core/grr_response_core/stats/metrics.py index b40be26418..f0153fa248 100644 --- a/grr/core/grr_response_core/stats/metrics.py +++ b/grr/core/grr_response_core/stats/metrics.py @@ -26,7 +26,8 @@ def __init__(self, metadata): def GetValue(self, fields=None): """Returns the value of a given metric for given field values.""" return stats_collector_instance.Get().GetMetricValue( - self.name, fields=fields) + self.name, fields=fields + ) def GetFields(self): """Returns all field values for the given metric.""" @@ -49,12 +50,15 @@ def __init__(self, name, fields=(), docstring=None, units=None): value_type=rdf_stats.MetricMetadata.ValueType.INT, fields_defs=stats_utils.FieldDefinitionProtosFromTuples(fields), docstring=docstring, - units=units)) + units=units, + ) + ) def Increment(self, delta=1, fields=None): """Increments a counter metric by a given delta.""" stats_collector_instance.Get().IncrementCounter( - self.name, delta, fields=fields) + self.name, delta, fields=fields + ) def Counted(self, fields=None): """Returns a decorator that counts function calls.""" @@ -85,17 +89,21 @@ def __init__(self, name, value_type, fields=(), docstring=None, units=None): value_type=stats_utils.MetricValueTypeFromPythonType(value_type), fields_defs=stats_utils.FieldDefinitionProtosFromTuples(fields), docstring=docstring, - units=units)) + units=units, + ) + ) def SetValue(self, value, fields=None): """Sets value of a given gauge metric.""" stats_collector_instance.Get().SetGaugeValue( - self.name, value, fields=fields) + self.name, value, fields=fields + ) def SetCallback(self, callback, fields=None): """Attaches a callback to the given gauge metric.""" stats_collector_instance.Get().SetGaugeCallback( - self.name, callback, fields=fields) + self.name, callback, fields=fields + ) class Event(AbstractMetric): @@ -115,7 +123,9 @@ def __init__(self, name, bins=(), fields=(), docstring=None, units=None): value_type=rdf_stats.MetricMetadata.ValueType.DISTRIBUTION, fields_defs=stats_utils.FieldDefinitionProtosFromTuples(fields), docstring=docstring, - units=units)) + units=units, + ) + ) def RecordEvent(self, value, fields=None): """Records value corresponding to the given event metric.""" diff --git a/grr/core/grr_response_core/stats/metrics_test.py b/grr/core/grr_response_core/stats/metrics_test.py index 17f5cad24c..0204179ea7 100644 --- a/grr/core/grr_response_core/stats/metrics_test.py +++ b/grr/core/grr_response_core/stats/metrics_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for the metrics interface for stats collection.""" - from unittest import mock from absl.testing import absltest @@ -12,37 +11,45 @@ from grr.test_lib import stats_test_lib -class MetricsTest(stats_test_lib.StatsTestMixin, - stats_test_lib.StatsCollectorTestMixin, absltest.TestCase): +class MetricsTest( + stats_test_lib.StatsTestMixin, + stats_test_lib.StatsCollectorTestMixin, + absltest.TestCase, +): def testCounterRegistration(self): with self.SetUpStatsCollector( - 
default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): metrics.Counter("cfoo") self.assertIsNotNone(self.collector.GetMetricMetadata("cfoo")) def testGaugeRegistration(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): metrics.Gauge("gfoo", int) self.assertIsNotNone(self.collector.GetMetricMetadata("gfoo")) def testEventRegistration(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): metrics.Event("efoo") self.assertIsNotNone(self.collector.GetMetricMetadata("efoo")) def testCounterIncrement(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): counter = metrics.Counter("cfoo", fields=[("bar", str)]) with self.assertStatsCounterDelta(1, counter, fields=["baz"]): counter.Increment(fields=["baz"]) def testGetValue(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): counter = metrics.Counter("cfoo", fields=[("bar", str)]) self.assertEqual(counter.GetValue(["baz"]), 0) counter.Increment(fields=["baz"]) @@ -50,7 +57,8 @@ def testGetValue(self): def testGetFields(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): counter = metrics.Counter("cfoo", fields=[("bar", str)]) self.assertEmpty(counter.GetFields()) counter.Increment(fields=["baz"]) @@ -59,7 +67,8 @@ def testGetFields(self): def testCountedDecoratorIncrement(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): counter = metrics.Counter("cfoo", fields=[("bar", str)]) @counter.Counted(fields=["baz"]) @@ -71,7 +80,8 @@ def Foo(): def testSuccessesCountedDecoratorIncrement(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): counter = metrics.Counter("cfoo", fields=[("bar", str)]) @counter.SuccessesCounted(fields=["baz"]) @@ -83,7 +93,8 @@ def Foo(): def testErrorsCountedDecoratorIncrement(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): counter = metrics.Counter("cfoo", fields=[("bar", str)]) @counter.ErrorsCounted(fields=["baz"]) @@ -96,21 +107,24 @@ def Foo(): def testSetGaugeValue(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): gauge = metrics.Gauge("gfoo", int, fields=[("bar", str)]) with self.assertStatsCounterDelta(42, gauge, fields=["baz"]): gauge.SetValue(42, fields=["baz"]) def testRecordEvent(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): event = metrics.Event("efoo", fields=[("bar", str)]) with self.assertStatsCounterDelta(1, event, fields=["baz"]): event.RecordEvent(42, fields=["baz"]) def testTimedDecorator(self): with self.SetUpStatsCollector( - default_stats_collector.DefaultStatsCollector()): + default_stats_collector.DefaultStatsCollector() + ): event = metrics.Event("efoo", fields=[("bar", str)]) @event.Timed(fields=["baz"]) @@ -122,9 +136,11 @@ def 
Foo(): def testMetricCanBeRegisteredAfterStatsCollectorHasBeenSetUp(self): with mock.patch.multiple( - stats_collector_instance, _metadatas=[], _stats_singleton=None): + stats_collector_instance, _metadatas=[], _stats_singleton=None + ): stats_collector_instance.Set( - default_stats_collector.DefaultStatsCollector()) + default_stats_collector.DefaultStatsCollector() + ) counter = metrics.Counter("cfoo") counter.Increment(1) diff --git a/grr/core/grr_response_core/stats/stats_collector_instance.py b/grr/core/grr_response_core/stats/stats_collector_instance.py index a973ee85e9..20666cf942 100644 --- a/grr/core/grr_response_core/stats/stats_collector_instance.py +++ b/grr/core/grr_response_core/stats/stats_collector_instance.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Contains a stats-collector singleton shared across a GRR process.""" - import logging import threading diff --git a/grr/core/grr_response_core/stats/stats_test_utils.py b/grr/core/grr_response_core/stats/stats_test_utils.py index cd1d7f758a..4b91e75f07 100644 --- a/grr/core/grr_response_core/stats/stats_test_utils.py +++ b/grr/core/grr_response_core/stats/stats_test_utils.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Common tests for stats-collector implementations.""" - import abc import time from unittest import mock @@ -18,7 +17,8 @@ class StatsCollectorTest( stats_test_lib.StatsCollectorTestMixin, absltest.TestCase, - metaclass=abc.ABCMeta): + metaclass=abc.ABCMeta, +): """Stats collection tests. Each test method has uniquely-named metrics to accommodate implementations @@ -70,7 +70,8 @@ def testDecrementingCounterRaises(self): def testCounterWithFields(self): with self.SetUpStatsCollector(self._CreateStatsCollector()): counter = metrics.Counter( - "testCounterWithFields_counter", fields=[("dimension", str)]) + "testCounterWithFields_counter", fields=[("dimension", str)] + ) # Test that default values for any fields values are 0." 
self.assertEqual(0, counter.GetValue(fields=["a"])) @@ -134,7 +135,8 @@ def testSimpleGauge(self): def testGaugeWithFields(self): with self.SetUpStatsCollector(self._CreateStatsCollector()): int_gauge = metrics.Gauge( - "testGaugeWithFields_int_gauge", int, fields=[("dimension", str)]) + "testGaugeWithFields_int_gauge", int, fields=[("dimension", str)] + ) self.assertEqual(0, int_gauge.GetValue(fields=["dimension_value_1"])) self.assertEqual(0, int_gauge.GetValue(fields=["dimesnioN_value_2"])) @@ -162,7 +164,8 @@ def testGaugeWithCallback(self): def testSimpleEventMetric(self): with self.SetUpStatsCollector(self._CreateStatsCollector()): event_metric = metrics.Event( - "testSimpleEventMetric_event_metric", bins=[0.0, 0.1, 0.2]) + "testSimpleEventMetric_event_metric", bins=[0.0, 0.1, 0.2] + ) data = event_metric.GetValue() self.assertAlmostEqual(0, data.sum) @@ -196,7 +199,8 @@ def testEventMetricWithFields(self): event_metric = metrics.Event( "testEventMetricWithFields_event_metric", bins=[0.0, 0.1, 0.2], - fields=[("dimension", str)]) + fields=[("dimension", str)], + ) data = event_metric.GetValue(fields=["dimension_value_1"]) self.assertAlmostEqual(0, data.sum) @@ -222,10 +226,12 @@ def testEventMetricWithFields(self): def testRaisesOnImproperFieldsUsage1(self): with self.SetUpStatsCollector(self._CreateStatsCollector()): counter = metrics.Counter("testRaisesOnImproperFieldsUsage1_counter") - int_gauge = metrics.Gauge("testRaisesOnImproperFieldsUsage1_int_gauge", - int) + int_gauge = metrics.Gauge( + "testRaisesOnImproperFieldsUsage1_int_gauge", int + ) event_metric = metrics.Event( - "testRaisesOnImproperFieldsUsage1_event_metric") + "testRaisesOnImproperFieldsUsage1_event_metric" + ) # Check for counters with self.assertRaises(ValueError): @@ -243,14 +249,17 @@ def testRaisesOnImproperFieldsUsage2(self): with self.SetUpStatsCollector(self._CreateStatsCollector()): counter = metrics.Counter( "testRaisesOnImproperFieldsUsage2_counter", - fields=[("dimension", str)]) + fields=[("dimension", str)], + ) int_gauge = metrics.Gauge( "testRaisesOnImproperFieldsUsage2_int_gauge", int, - fields=[("dimension", str)]) + fields=[("dimension", str)], + ) event_metric = metrics.Event( "testRaisesOnImproperFieldsUsage2_event_metric", - fields=[("dimension", str)]) + fields=[("dimension", str)], + ) # Check for counters with self.assertRaises(ValueError): @@ -281,34 +290,47 @@ def testGetAllMetricsMetadataWorksCorrectlyOnSimpleMetrics(self): metrics.Event(event_metric_name) metadatas = self.collector.GetAllMetricsMetadata() - self.assertEqual(metadatas[counter_name].metric_type, - rdf_stats.MetricMetadata.MetricType.COUNTER) + self.assertEqual( + metadatas[counter_name].metric_type, + rdf_stats.MetricMetadata.MetricType.COUNTER, + ) self.assertFalse(metadatas[counter_name].fields_defs) - self.assertEqual(metadatas[int_gauge_name].metric_type, - rdf_stats.MetricMetadata.MetricType.GAUGE) - self.assertEqual(metadatas[int_gauge_name].fields_defs, [ - rdf_stats.MetricFieldDefinition( - field_name="dimension", - field_type=rdf_stats.MetricFieldDefinition.FieldType.STR) - ]) - - self.assertEqual(metadatas[event_metric_name].metric_type, - rdf_stats.MetricMetadata.MetricType.EVENT) + self.assertEqual( + metadatas[int_gauge_name].metric_type, + rdf_stats.MetricMetadata.MetricType.GAUGE, + ) + self.assertEqual( + metadatas[int_gauge_name].fields_defs, + [ + rdf_stats.MetricFieldDefinition( + field_name="dimension", + field_type=rdf_stats.MetricFieldDefinition.FieldType.STR, + ) + ], + ) + + self.assertEqual( + 
metadatas[event_metric_name].metric_type, + rdf_stats.MetricMetadata.MetricType.EVENT, + ) self.assertFalse(metadatas[event_metric_name].fields_defs) def testGetMetricFieldsWorksCorrectly(self): with self.SetUpStatsCollector(self._CreateStatsCollector()): counter = metrics.Counter( "testGetMetricFieldsWorksCorrectly_counter", - fields=[("dimension1", str), ("dimension2", str)]) + fields=[("dimension1", str), ("dimension2", str)], + ) int_gauge = metrics.Gauge( "testGetMetricFieldsWorksCorrectly_int_gauge", int, - fields=[("dimension", str)]) + fields=[("dimension", str)], + ) event_metric = metrics.Event( "testGetMetricFieldsWorksCorrectly_event_metric", - fields=[("dimension", str)]) + fields=[("dimension", str)], + ) counter.Increment(fields=["b", "b"]) counter.Increment(fields=["a", "c"]) @@ -400,7 +422,8 @@ def testCombiningDecorators(self): with self.SetUpStatsCollector(self._CreateStatsCollector()): counter = metrics.Counter("testCombiningDecorators_counter") event_metric = metrics.Event( - "testCombiningDecorators_event_metric", bins=[0.0, 0.1, 0.2]) + "testCombiningDecorators_event_metric", bins=[0.0, 0.1, 0.2] + ) @event_metric.Timed() @counter.Counted() @@ -419,7 +442,8 @@ def testExceptionHandling(self): with self.SetUpStatsCollector(self._CreateStatsCollector()): counter = metrics.Counter("testExceptionHandling_counter") event_metric = metrics.Event( - "testExceptionHandling_event_metric", bins=[0, 0.1, 0.2]) + "testExceptionHandling_event_metric", bins=[0, 0.1, 0.2] + ) @event_metric.Timed() @counter.Counted() @@ -441,7 +465,8 @@ def testMultipleFuncs(self): with self.SetUpStatsCollector(self._CreateStatsCollector()): counter = metrics.Counter("testMultipleFuncs_counter") event_metric = metrics.Event( - "testMultipleFuncs_event_metric", bins=[0, 1, 2]) + "testMultipleFuncs_event_metric", bins=[0, 1, 2] + ) @counter.Counted() def Func1(n): diff --git a/grr/core/grr_response_core/stats/stats_utils.py b/grr/core/grr_response_core/stats/stats_utils.py index 52f8536938..90c226ce8b 100644 --- a/grr/core/grr_response_core/stats/stats_utils.py +++ b/grr/core/grr_response_core/stats/stats_utils.py @@ -1,15 +1,13 @@ #!/usr/bin/env python """Utilities for handling stats.""" - import functools import time -from typing import Text from grr_response_core.lib.rdfvalues import stats as rdf_stats -class Timed(object): +class Timed: """A decorator that records timing metrics for function calls.""" def __init__(self, event_metric, fields=None): @@ -31,7 +29,7 @@ def Decorated(*args, **kwargs): return Decorated -class Counted(object): +class Counted: """A decorator that counts function calls.""" def __init__(self, counter_metric, fields=None): @@ -51,7 +49,7 @@ def Decorated(*args, **kwargs): return Decorated -class SuccessesCounted(object): +class SuccessesCounted: """A decorator that counts function calls that don't raise an exception.""" def __init__(self, counter_metric, fields=None): @@ -70,7 +68,7 @@ def Decorated(*args, **kwargs): return Decorated -class ErrorsCounted(object): +class ErrorsCounted: """A decorator that counts function calls that raise an exception.""" def __init__(self, counter_metric, fields=None): @@ -97,7 +95,7 @@ def FieldDefinitionProtosFromTuples(field_def_tuples): for field_name, field_type in field_def_tuples: if field_type is int: field_type = rdf_stats.MetricFieldDefinition.FieldType.INT - elif issubclass(field_type, Text): + elif issubclass(field_type, str): field_type = rdf_stats.MetricFieldDefinition.FieldType.STR elif issubclass(field_type, bool): field_type = 
rdf_stats.MetricFieldDefinition.FieldType.BOOL @@ -105,7 +103,9 @@ def FieldDefinitionProtosFromTuples(field_def_tuples): raise ValueError("Invalid field type: %s" % field_type) field_def_protos.append( rdf_stats.MetricFieldDefinition( - field_name=field_name, field_type=field_type)) + field_name=field_name, field_type=field_type + ) + ) return field_def_protos diff --git a/grr/core/scripts/make_new_server_key.py b/grr/core/scripts/make_new_server_key.py deleted file mode 100644 index 1645123db0..0000000000 --- a/grr/core/scripts/make_new_server_key.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python -"""A console script to create new server keys. - -Use this script to make a new server key and certificate. You can just run the -script inside a console using `run -i make_new_server_key.py`. -""" - -from grr import config -from grr.lib.rdfvalues import crypto as rdf_crypto -from grr.server import key_utils - -ca_certificate = config.CONFIG["CA.certificate"] -ca_private_key = config.CONFIG["PrivateKeys.ca_key"] - -# Check the current certificate serial number -existing_cert = config.CONFIG["Frontend.certificate"] -print("Current serial number:", existing_cert.GetSerialNumber()) - -server_private_key = rdf_crypto.RSAPrivateKey.GenerateKey(bits=4096) -server_cert = key_utils.MakeCASignedCert( - u"grr", - server_private_key, - ca_certificate, - ca_private_key, - serial_number=existing_cert.GetSerialNumber() + 1) - -print("New Server cert (Frontend.certificate):") -print(server_cert.AsPEM()) - -print("New Server Private Key:") -print(server_private_key.AsPEM()) diff --git a/grr/proto/grr_response_proto/artifact.proto b/grr/proto/grr_response_proto/artifact.proto index ecd06349c6..899ed22e8e 100644 --- a/grr/proto/grr_response_proto/artifact.proto +++ b/grr/proto/grr_response_proto/artifact.proto @@ -90,9 +90,8 @@ message Artifact { description: "A list of artifact collectors.", }]; */ - repeated string provides = 8 [(sem_type) = { - description: "A list of knowledgebase values this artifact provides.", - }]; + reserved 8; + reserved "provides"; repeated ArtifactSource sources = 9 [(sem_type) = { description: "A list of artifact sources.", diff --git a/grr/proto/grr_response_proto/config_file.proto b/grr/proto/grr_response_proto/config_file.proto index 39b2b2cb51..976e60d6bf 100644 --- a/grr/proto/grr_response_proto/config_file.proto +++ b/grr/proto/grr_response_proto/config_file.proto @@ -51,23 +51,6 @@ message NfsExport { repeated NfsClient clients = 3; } -// An sshd match block configuration. This is a subcomponent of an sshd config. -message SshdMatchBlock { - optional string criterion = 1 - [(sem_type) = { description: "Criteria that trigger a match block." }]; - optional AttributedDict config = 2 - [(sem_type) = { description: "The configuration of the match block." }]; -} - -// A sshd configuration containing the sshd settings, and any number of match -// groups. -message SshdConfig { - optional AttributedDict config = 1 - [(sem_type) = { description: "The main sshd configuration." }]; - repeated SshdMatchBlock matches = 2 - [(sem_type) = { description: "Match block sections." }]; -} - // A ntp configuration containing the all the ntp settings. message NtpConfig { optional AttributedDict config = 1 @@ -112,60 +95,3 @@ message PamConfig { description: "Details of references to external config files." }]; } - -// Sudoers aliases. -message SudoersAlias { - enum Type { - USER = 0; - RUNAS = 1; - HOST = 2; - CMD = 3; - } - - optional Type type = 1 [(sem_type) = { description: "Alias type." 
}]; - optional string name = 2 [(sem_type) = { description: "Alias name." }]; - - repeated string users = 3 - [(sem_type) = { description: "User list, if type is USER." }]; - repeated string runas = 4 - [(sem_type) = { description: "Runas list, if type is RUNAS." }]; - repeated string hosts = 5 - [(sem_type) = { description: "Host list, if type is HOST." }]; - repeated string cmds = 6 - [(sem_type) = { description: "Command list, if type is CMD." }]; -} - -// Default setting in sudoers. -message SudoersDefault { - optional string scope = 1 - [(sem_type) = { description: "Scope for this default (eg, >root)." }]; - optional string name = 2 [(sem_type) = { - description: "Name for the default, including negations (!)." - }]; - optional string value = 3 - [(sem_type) = { description: "Value for the default, if one exists." }]; -} - -// Sudoers file entry. -message SudoersEntry { - repeated string users = 1 - [(sem_type) = { description: "Users this rule applies to." }]; - repeated string hosts = 2 - [(sem_type) = { description: "Hosts this rule applies to (optional)." }]; - repeated string cmdspec = 3 [(sem_type) = { - description: "All content after the '=' in the sudoers rule." - }]; -} - -// Sudoers configuration. -message SudoersConfig { - repeated SudoersDefault defaults = 1 [(sem_type) = { - description: "Default settings (binary options or key/value pairs)." - }]; - repeated SudoersAlias aliases = 2 - [(sem_type) = { description: "Aliases within a sudoers file." }]; - repeated SudoersEntry entries = 3 - [(sem_type) = { description: "Entries within a sudoers file." }]; - repeated string includes = 4 - [(sem_type) = { description: "Includes within a sudoers file." }]; -} diff --git a/grr/proto/grr_response_proto/deprecated.proto b/grr/proto/grr_response_proto/deprecated.proto index de413c6cc2..b46f4fcca6 100644 --- a/grr/proto/grr_response_proto/deprecated.proto +++ b/grr/proto/grr_response_proto/deprecated.proto @@ -532,14 +532,6 @@ message UninstallArgs { }]; } -message UpdateClientArgs { - reserved 1; - optional string binary_path = 2 [(sem_type) = { - description: "Identifies the binary uploaded to GRR server that has " - "to be run on the client to perform the update.", - }]; -} - message KeepAliveArgs { optional uint64 duration = 1 [ (sem_type) = { @@ -863,3 +855,162 @@ message ApiReportTickSpecifier { optional float x = 1; optional string label = 2; } + +message ChromeHistoryArgs { + optional PathSpec.PathType pathtype = 1 [ + (sem_type) = { description: "Type of path access to use." }, + default = OS + ]; + + optional bool get_archive = 2 [(sem_type) = { + description: "Should we get Archived History as well (3 months old)." + }]; + + optional string username = 3 [(sem_type) = { + description: "The user to get Chrome history for. If history_path is " + "not set this will be used to guess the path to the " + "history files. Can be in form DOMAIN\\user.", + }]; + + optional string history_path = 5 [(sem_type) = { + description: "Path to a profile directory that contains a History " + "file.", + }]; +} + +message FirefoxHistoryArgs { + optional PathSpec.PathType pathtype = 1 [ + (sem_type) = { description: "Type of path access to use." }, + default = OS + ]; + + optional bool get_archive = 2 [ + (sem_type) = { + description: "Should we get Archived History as well (3 months old).", + }, + default = false + ]; + + optional string username = 3 [(sem_type) = { + description: "The user to get history for. 
If history_path is " + "not set this will be used to guess the path to the " + "history files. Can be in form DOMAIN\\user." + }]; + + optional string history_path = 5 [(sem_type) = { + description: "Path to a profile directory that contains a History file.", + }]; +} + +message PCIDevice { + // Location of PCI device on the system. + optional uint32 domain = 1 [(sem_type) = { + description: "PCI domain this device is in.", + }]; + optional uint32 bus = 2 [(sem_type) = { + description: "8 bit PCI bus this device is on.", + }]; + optional uint32 device = 3 [(sem_type) = { + description: "5 bit device location on the bus.", + }]; + optional uint32 function = 4 [(sem_type) = { + description: "3 bit device function that has been mapped.", + }]; + + // Information from the PCI device itself. + optional string class_id = 5 [(sem_type) = { + description: "Hex string of 24 bit device class ID (e.g. Display).", + }]; + optional string vendor = 6 [(sem_type) = { + description: "Hex string of 16 bit device vendor ID.", + }]; + optional string vendor_device_id = 7 [(sem_type) = { + description: "Hex string of 16 bit device ID as set by the vendor.", + }]; + // This is stored as bytes to preserve data as-is. + optional bytes config = 8 [(sem_type) = { + description: "64 to 256 bytes of PCI configuration space header.", + }]; +} + +message SshdMatchBlock { + optional string criterion = 1 [(sem_type) = { + description: "Criteria that trigger a match block.", + }]; + optional AttributedDict config = 2 [(sem_type) = { + description: "The configuration of the match block.", + }]; +} + +message SshdConfig { + optional AttributedDict config = 1 [(sem_type) = { + description: "The main sshd configuration.", + }]; + repeated SshdMatchBlock matches = 2 [(sem_type) = { + description: "Match block sections.", + }]; +} + +message SudoersAlias { + enum Type { + USER = 0; + RUNAS = 1; + HOST = 2; + CMD = 3; + } + + optional Type type = 1 [(sem_type) = { description: "Alias type." }]; + optional string name = 2 [(sem_type) = { description: "Alias name." 
}]; + + repeated string users = 3 [(sem_type) = { + description: "User list, if type is USER.", + }]; + repeated string runas = 4 [(sem_type) = { + description: "Runas list, if type is RUNAS.", + }]; + repeated string hosts = 5 [(sem_type) = { + description: "Host list, if type is HOST.", + }]; + repeated string cmds = 6 [(sem_type) = { + description: "Command list, if type is CMD.", + }]; +} + +message SudoersDefault { + optional string scope = 1 [(sem_type) = { + description: "Scope for this default (eg, >root).", + }]; + optional string name = 2 [(sem_type) = { + description: "Name for the default, including negations (!).", + }]; + optional string value = 3 [(sem_type) = { + description: "Value for the default, if one exists.", + }]; +} + +message SudoersEntry { + repeated string users = 1 [(sem_type) = { + description: "Users this rule applies to.", + }]; + repeated string hosts = 2 [(sem_type) = { + description: "Hosts this rule applies to (optional).", + }]; + repeated string cmdspec = 3 [(sem_type) = { + description: "All content after the '=' in the sudoers rule.", + }]; +} + +message SudoersConfig { + repeated SudoersDefault defaults = 1 [(sem_type) = { + description: "Default settings (binary options or key/value pairs).", + }]; + repeated SudoersAlias aliases = 2 [(sem_type) = { + description: "Aliases within a sudoers file.", + }]; + repeated SudoersEntry entries = 3 [(sem_type) = { + description: "Entries within a sudoers file.", + }]; + repeated string includes = 4 [(sem_type) = { + description: "Includes within a sudoers file.", + }]; +} diff --git a/grr/proto/grr_response_proto/export.proto b/grr/proto/grr_response_proto/export.proto index 63d38039c8..0049765fdd 100644 --- a/grr/proto/grr_response_proto/export.proto +++ b/grr/proto/grr_response_proto/export.proto @@ -51,9 +51,7 @@ message ExportedMetadata { }]; optional uint64 client_age = 4 [(sem_type) = { type: "RDFDatetime", description: "Age of the client." }]; - // TODO: Remove this field once it is confirmed that it is safe - // to do. - optional string uname = 5 [(sem_type) = { description: "Uname string." }]; + reserved 5; optional string os_release = 6 [(sem_type) = { description: "The OS release identifier e.g. 7, OSX, debian." }]; diff --git a/grr/proto/grr_response_proto/flows.proto b/grr/proto/grr_response_proto/flows.proto index 2f1c2d16bb..95a696048e 100644 --- a/grr/proto/grr_response_proto/flows.proto +++ b/grr/proto/grr_response_proto/flows.proto @@ -25,6 +25,13 @@ message CloudMetadataRequest { }]; optional CloudInstance.InstanceType instance_type = 7 [(sem_type) = { description: "AMAZON/GOOGLE etc." }, default = UNSET]; + + // Whether to ignore HTTP errors when the request is made. + // + // Some metadata requests are not critical and the whole action invocation + // should not fail in case attempt to obtain them fails. By setting this flag + // we make such failures non-critical. 
+ optional bool ignore_http_errors = 8; } message CloudMetadataRequests { @@ -550,6 +557,15 @@ message OnlineNotificationArgs { }]; } +// Next field ID: 3 +message UpdateClientArgs { + reserved 1; + optional string binary_path = 2 [(sem_type) = { + description: "Identifies the binary uploaded to GRR server that has " + "to be run on the client to perform the update.", + }]; +} + // Next field ID: 3 message LaunchBinaryArgs { optional string binary = 1 [(sem_type) = { @@ -748,31 +764,6 @@ message FileCollectorArgs { }]; } -// Next field ID: 6 -message FirefoxHistoryArgs { - optional PathSpec.PathType pathtype = 1 [ - (sem_type) = { description: "Type of path access to use." }, - default = OS - ]; - - optional bool get_archive = 2 [ - (sem_type) = { - description: "Should we get Archived History as well (3 months old).", - }, - default = false - ]; - - optional string username = 3 [(sem_type) = { - description: "The user to get history for. If history_path is " - "not set this will be used to guess the path to the " - "history files. Can be in form DOMAIN\\user." - }]; - - optional string history_path = 5 [(sem_type) = { - description: "Path to a profile directory that contains a History file.", - }]; -} - // Next field ID: 2 message ListDirectoryArgs { optional PathSpec pathspec = 1 @@ -866,29 +857,6 @@ message GetMBRArgs { ]; } -// Next field ID: 6 -message ChromeHistoryArgs { - optional PathSpec.PathType pathtype = 1 [ - (sem_type) = { description: "Type of path access to use." }, - default = OS - ]; - - optional bool get_archive = 2 [(sem_type) = { - description: "Should we get Archived History as well (3 months old)." - }]; - - optional string username = 3 [(sem_type) = { - description: "The user to get Chrome history for. If history_path is " - "not set this will be used to guess the path to the " - "history files. Can be in form DOMAIN\\user.", - }]; - - optional string history_path = 5 [(sem_type) = { - description: "Path to a profile directory that contains a History " - "file.", - }]; -} - enum Browser { UNDEFINED = 0; CHROME = 1; @@ -2007,7 +1975,7 @@ message YaraSignatureShard { optional bytes payload = 2; } -// Next field ID: 24 +// Next field ID: 26 message YaraProcessScanRequest { optional string yara_signature = 1 [(sem_type) = { type: "YaraSignature", @@ -2065,6 +2033,10 @@ message YaraProcessScanRequest { "flag to change this behavior.", label: ADVANCED, }]; + + // Whether to skip scanning all parent processes of the GRR agent. + optional bool ignore_parent_processes = 24; + optional uint32 per_process_timeout = 5 [(sem_type) = { description: "A timeout in seconds that is applied while scanning; " "applies to each scan individually.", @@ -2127,6 +2099,10 @@ message YaraProcessScanRequest { type: "Duration", }]; + optional uint32 context_window = 25 [(sem_type) = { + description: "Include this many bytes of context around the hit.", + }]; + // ImplementationType has been introduced for rolling-out sandboxing. // The purpose is to make it possible to switch the implementation at // run-time via the UI in case there are problems. 
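
For context on the new `YaraProcessScanRequest` fields above (`ignore_parent_processes`, field 24, and `context_window`, field 25): a minimal sketch of populating the generated message, assuming the standard `grr_response_proto.flows_pb2` module; the field names come from this change, while the YARA rule string is purely illustrative:

    from grr_response_proto import flows_pb2

    args = flows_pb2.YaraProcessScanRequest()
    args.yara_signature = "rule Example { condition: true }"  # illustrative rule
    args.ignore_parent_processes = True  # skip scanning the agent's parent processes
    args.context_window = 64  # include 64 bytes of context around each hit
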
@@ -2166,6 +2142,9 @@ message YaraStringMatch { optional bytes data = 3 [(sem_type) = { description: "The actual data that matched.", }]; + optional bytes context = 4 [(sem_type) = { + description: "The context around the match.", + }]; } message YaraMatch { @@ -2272,6 +2251,9 @@ message YaraProcessDumpArgs { "the remaining memory pages will be dumped up to size_limit.", label: ADVANCED, }]; + + // Whether to skip dumping all parent processes of the GRR agent. + optional bool ignore_parent_processes = 13; } message ProcessMemoryRegion { @@ -2363,6 +2345,9 @@ message FlowProcessingRequest { optional uint64 delivery_time = 3 [(sem_type) = { type: "RDFDatetime", }]; + optional uint64 creation_time = 4 [(sem_type) = { + type: "RDFDatetime", + }]; } message FlowRequest { @@ -2385,9 +2370,12 @@ message FlowRequest { optional uint64 nr_responses_expected = 8; // Id of the response that should be the next for incremental processing. optional uint64 next_response_id = 10; + optional uint64 timestamp = 11 [(sem_type) = { + type: "RDFDatetime", + }]; } -// Next id: 9 +// Next id: 10 message FlowResponse { optional string client_id = 1; optional string flow_id = 2; @@ -2408,6 +2396,9 @@ message FlowResponse { optional google.protobuf.Any any_payload = 8 [(sem_type) = { no_dynamic_type_lookup: true, }]; + optional uint64 timestamp = 9 [(sem_type) = { + type: "RDFDatetime", + }]; } // Next id: 7 @@ -2444,7 +2435,7 @@ message FlowLogEntry { optional string message = 2; } -// Next id: 12 +// Next id: 13 message FlowStatus { optional string client_id = 1; optional string flow_id = 2; @@ -2471,15 +2462,21 @@ message FlowStatus { optional uint64 runtime_us = 11 [(sem_type) = { type: "Duration", }]; + optional uint64 timestamp = 12 [(sem_type) = { + type: "RDFDatetime", + }]; } -// Next id: 6 +// Next id: 7 message FlowIterator { optional string client_id = 1; optional string flow_id = 2; optional string hunt_id = 5; optional uint64 request_id = 3; optional uint64 response_id = 4; + optional uint64 timestamp = 6 [(sem_type) = { + type: "RDFDatetime", + }]; } // Next id: 36 diff --git a/grr/proto/grr_response_proto/jobs.proto b/grr/proto/grr_response_proto/jobs.proto index 8e5607307e..8a8c7d9702 100644 --- a/grr/proto/grr_response_proto/jobs.proto +++ b/grr/proto/grr_response_proto/jobs.proto @@ -508,10 +508,20 @@ message RequestState { } message CpuSeconds { - optional float user_cpu_time = 1 [(sem_type) = { + // TODO: Remove `deprecated` fields. + // DEPRECATED: Use user_cpu_seconds instead. + optional float deprecated_user_cpu_time = 1 [(sem_type) = { friendly_name: "User cpu seconds used", }]; - optional float system_cpu_time = 2 [(sem_type) = { + // DEPRECATED: Use system_cpu_seconds instead. + optional float deprecated_system_cpu_time = 2 [(sem_type) = { + friendly_name: "System cpu seconds used", + }]; + + optional double user_cpu_time = 3 [(sem_type) = { + friendly_name: "User cpu seconds used", + }]; + optional double system_cpu_time = 4 [(sem_type) = { friendly_name: "System cpu seconds used", }]; } diff --git a/grr/proto/grr_response_proto/sysinfo.proto b/grr/proto/grr_response_proto/sysinfo.proto index ea8dbc0314..308d89339b 100644 --- a/grr/proto/grr_response_proto/sysinfo.proto +++ b/grr/proto/grr_response_proto/sysinfo.proto @@ -56,38 +56,6 @@ message ManagementAgent { }]; } -// Describe a PCI Device -message PCIDevice { - // Location of PCI device on the system. 
- optional uint32 domain = 1 [(sem_type) = { - description: "PCI domain this device is in.", - }]; - optional uint32 bus = 2 [(sem_type) = { - description: "8 bit PCI bus this device is on.", - }]; - optional uint32 device = 3 [(sem_type) = { - description: "5 bit device location on the bus.", - }]; - optional uint32 function = 4 [(sem_type) = { - description: "3 bit device function that has been mapped.", - }]; - - // Information from the PCI device itself. - optional string class_id = 5 [(sem_type) = { - description: "Hex string of 24 bit device class ID (e.g. Display).", - }]; - optional string vendor = 6 [(sem_type) = { - description: "Hex string of 16 bit device vendor ID.", - }]; - optional string vendor_device_id = 7 [(sem_type) = { - description: "Hex string of 16 bit device ID as set by the vendor.", - }]; - // This is stored as bytes to preserve data as-is. - optional bytes config = 8 [(sem_type) = { - description: "64 to 256 bytes of PCI configuration space header.", - }]; -} - // A Process record describing a system process. message Process { option (semantic) = { diff --git a/grr/server/grr_response_server/artifact.py b/grr/server/grr_response_server/artifact.py index 4e7bbd30fc..8a1deaa617 100644 --- a/grr/server/grr_response_server/artifact.py +++ b/grr/server/grr_response_server/artifact.py @@ -910,9 +910,17 @@ def _ProcessWindowsProfiles( # pylint: disable=line-too-long # pyformat: disable args.paths.extend([ - rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\*", - rf"HKEY_USERS\{user.sid}\Environment\*", - rf"HKEY_USERS\{user.sid}\Volatile Environment\*", + rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\{{A520A1A4-1780-4FF6-BD18-167343C5AF16}}", + rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\Desktop", + rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\AppData", + rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\Local AppData", + rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\Cookies", + rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\Cache", + rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\Recent", + rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\Startup", + rf"HKEY_USERS\{user.sid}\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\Personal", + rf"HKEY_USERS\{user.sid}\Environment\TEMP", + rf"HKEY_USERS\{user.sid}\Volatile Environment\USERDOMAIN", ]) # pylint: enable=line-too-long # pyformat: enable @@ -981,15 +989,35 @@ def _ProcessWindowsProfileExtras( registry_value = parts[-1] registry_data = response.stat_entry.registry_data.string - attrs = windows_registry_parser.WinUserSpecialDirs.key_var_mapping - try: - attr = attrs[registry_key][registry_value] - except KeyError: + # TODO: Replace with `match` once we can use Python 3.10 + # features. 
+ case = (registry_key, registry_value) + if case == ("Shell Folders", "{A520A1A4-1780-4FF6-BD18-167343C5AF16}"): + user.localappdata_low = registry_data + elif case == ("Shell Folders", "Desktop"): + user.desktop = registry_data + elif case == ("Shell Folders", "AppData"): + user.appdata = registry_data + elif case == ("Shell Folders", "Local AppData"): + user.localappdata = registry_data + elif case == ("Shell Folders", "Cookies"): + user.cookies = registry_data + elif case == ("Shell Folders", "Cache"): + user.internet_cache = registry_data + elif case == ("Shell Folders", "Recent"): + user.recent = registry_data + elif case == ("Shell Folders", "Startup"): + user.startup = registry_data + elif case == ("Shell Folders", "Personal"): + user.personal = registry_data + elif case == ("Environment", "TEMP"): + user.temp = registry_data + elif case == ("Volatile Environment", "USERDOMAIN"): + user.userdomain = registry_data + else: self.Log("Invalid registry value for %r", path) continue - setattr(user, attr, registry_data) - def _ProcessWindowsWMIUserAccounts( self, responses: flow_responses.Responses[rdfvalue.RDFValue], @@ -1027,6 +1055,23 @@ def _ProcessWindowsWMIUserAccounts( def End(self, responses): """Finish up.""" del responses + + # TODO: `%LOCALAPPDATA%` is a very often used variable that we + # potentially not collect due to limitations of the Windows registry. For + # now, in case we did not collect it, we set it to the default Windows value + # (which should be the case almost always but is nevertheless not the most + # way of handling it). + # + # Alternatively, we could develop a more general way of handling default + # environment variable values in case they are missing. + for user in self.state.knowledge_base.users: + if not user.localappdata: + self.Log( + "Missing `%%LOCALAPPDATA%%` for '%s', using Windows default", + user.username, + ) + user.localappdata = rf"{user.userprofile}\AppData\Local" + self.SendReply(self.state.knowledge_base) def InitializeKnowledgeBase(self): diff --git a/grr/server/grr_response_server/artifact_registry.py b/grr/server/grr_response_server/artifact_registry.py index 09dcab21c2..215f38b629 100644 --- a/grr/server/grr_response_server/artifact_registry.py +++ b/grr/server/grr_response_server/artifact_registry.py @@ -22,6 +22,7 @@ # files. DEPRECATED_ARTIFACT_FIELDS = frozenset([ "labels", + "provides", ]) @@ -339,7 +340,6 @@ def GetArtifacts(self, name_list=None, source_type=None, exclude_dependents=False, - provides=None, reload_datastore_artifacts=False): """Retrieve artifact classes with optional filtering. @@ -352,7 +352,6 @@ def GetArtifacts(self, source_type exclude_dependents: if true only artifacts with no dependencies will be returned - provides: return the artifacts that provide these dependencies reload_datastore_artifacts: If true, the data store sources are queried for new artifacts. 
@@ -376,14 +375,7 @@ def GetArtifacts(self, if exclude_dependents and GetArtifactPathDependencies(artifact): continue - if not provides: - results[artifact.name] = artifact - else: - # This needs to remain the last test, if it matches the result is added - for provide_string in artifact.provides: - if provide_string in provides: - results[artifact.name] = artifact - break + results[artifact.name] = artifact return list(results.values()) @@ -432,78 +424,6 @@ def Exists(self, name: str) -> bool: def GetArtifactNames(self, *args, **kwargs): return set([a.name for a in self.GetArtifacts(*args, **kwargs)]) - @utils.Synchronized - def SearchDependencies(self, - os_name, - artifact_name_list, - existing_artifact_deps=None, - existing_expansion_deps=None): - """Return a set of artifact names needed to fulfill dependencies. - - Search the path dependency tree for all artifacts that can fulfill - dependencies of artifact_name_list. If multiple artifacts provide a - dependency, they are all included. - - Args: - os_name: operating system string - artifact_name_list: list of artifact names to find dependencies for. - existing_artifact_deps: existing dependencies to add to, for recursion, - e.g. set(["WindowsRegistryProfiles", "WindowsEnvironmentVariablePath"]) - existing_expansion_deps: existing expansion dependencies to add to, for - recursion, e.g. set(["users.userprofile", "users.homedir"]) - - Returns: - (artifact_names, expansion_names): a tuple of sets, one with artifact - names, the other expansion names - """ - artifact_deps = existing_artifact_deps or set() - expansion_deps = existing_expansion_deps or set() - - artifact_objs = self.GetArtifacts( - os_name=os_name, name_list=artifact_name_list) - artifact_deps = artifact_deps.union([a.name for a in artifact_objs]) - - for artifact in artifact_objs: - expansions = GetArtifactPathDependencies(artifact) - if expansions: - expansion_deps = expansion_deps.union(set(expansions)) - # Get the names of the artifacts that provide those expansions - new_artifact_names = self.GetArtifactNames( - os_name=os_name, provides=expansions) - missing_artifacts = new_artifact_names - artifact_deps - - if missing_artifacts: - # Add those artifacts and any child dependencies - new_artifacts, new_expansions = self.SearchDependencies( - os_name, - new_artifact_names, - existing_artifact_deps=artifact_deps, - existing_expansion_deps=expansion_deps) - artifact_deps = artifact_deps.union(new_artifacts) - expansion_deps = expansion_deps.union(new_expansions) - - return artifact_deps, expansion_deps - - @utils.Synchronized - def DumpArtifactsToYaml(self, sort_by_os=True): - """Dump a list of artifacts into a yaml string.""" - artifact_list = self.GetArtifacts() - if sort_by_os: - # Sort so its easier to split these if necessary. - yaml_list = [] - for os_name in rdf_artifacts.Artifact.SUPPORTED_OS_LIST: - done = {a.name: a for a in artifact_list if a.supported_os == [os_name]} - # Separate into knowledge_base and non-kb for easier sorting. - done_sorted = list(sorted(done.values(), key=lambda x: x.name)) - yaml_list.extend(x.ToYaml() for x in done_sorted if x.provides) - yaml_list.extend(x.ToYaml() for x in done_sorted if not x.provides) - artifact_list = [a for a in artifact_list if a.name not in done] - yaml_list.extend(x.ToYaml() for x in artifact_list) # The rest. 
- else: - yaml_list = [x.ToYaml() for x in artifact_list] - - return "---\n\n".join(yaml_list) - REGISTRY = ArtifactRegistry() @@ -563,18 +483,12 @@ def ValidateSyntax(rdf_artifact): detail = "invalid `supported_os` ('%s' not in %s)" % (supp_os, valid_os) raise rdf_artifacts.ArtifactSyntaxError(rdf_artifact, detail) - # Anything listed in provides must be defined in the KnowledgeBase - valid_provides = rdf_client.KnowledgeBase().GetKbFieldNames() - for kb_var in rdf_artifact.provides: - if kb_var not in valid_provides: - detail = "broken `provides` ('%s' not in %s)" % (kb_var, valid_provides) - raise rdf_artifacts.ArtifactSyntaxError(rdf_artifact, detail) + kb_field_names = rdf_client.KnowledgeBase().GetKbFieldNames() # Any %%blah%% path dependencies must be defined in the KnowledgeBase for dep in GetArtifactPathDependencies(rdf_artifact): - if dep not in valid_provides: - detail = "broken path dependencies ('%s' not in %s)" % (dep, - valid_provides) + if dep not in kb_field_names: + detail = f"broken path dependencies ({dep!r} not in {kb_field_names})" raise rdf_artifacts.ArtifactSyntaxError(rdf_artifact, detail) for source in rdf_artifact.sources: diff --git a/grr/server/grr_response_server/artifact_registry_test.py b/grr/server/grr_response_server/artifact_registry_test.py index 26a43123fc..99ff7adff4 100644 --- a/grr/server/grr_response_server/artifact_registry_test.py +++ b/grr/server/grr_response_server/artifact_registry_test.py @@ -87,7 +87,6 @@ def testValidateSyntaxSimple(self): artifact = rdf_artifacts.Artifact( name="Foo", doc="This is Foo.", - provides=["fqdn", "domain"], supported_os=["Windows"], urls=["https://example.com"]) ar.ValidateSyntax(artifact) @@ -113,15 +112,13 @@ def testValidateSyntaxWithSources(self): artifact = rdf_artifacts.Artifact( name="Bar", doc="This is Bar.", - provides=["environ_windir"], supported_os=["Windows"], urls=["https://example.com"], sources=[registry_key_source, file_source]) ar.ValidateSyntax(artifact) def testValidateSyntaxMissingDoc(self): - artifact = rdf_artifacts.Artifact( - name="Baz", provides=["os"], supported_os=["Linux"]) + artifact = rdf_artifacts.Artifact(name="Baz", supported_os=["Linux"]) with self.assertRaisesRegex(rdf_artifacts.ArtifactSyntaxError, "missing doc"): @@ -131,19 +128,11 @@ def testValidateSyntaxInvalidSupportedOs(self): artifact = rdf_artifacts.Artifact( name="Quux", doc="This is Quux.", - provides=["os"], supported_os=["Solaris"]) with self.assertRaisesRegex(rdf_artifacts.ArtifactSyntaxError, "'Solaris'"): ar.ValidateSyntax(artifact) - def testValidateSyntaxBrokenProvides(self): - artifact = rdf_artifacts.Artifact( - name="Thud", doc="This is Thud.", provides=["fqdn", "garbage"]) - - with self.assertRaisesRegex(rdf_artifacts.ArtifactSyntaxError, "'garbage'"): - ar.ValidateSyntax(artifact) - def testValidateSyntaxBadSource(self): source = { "type": rdf_artifacts.ArtifactSource.SourceType.ARTIFACT_GROUP, @@ -153,7 +142,6 @@ def testValidateSyntaxBadSource(self): artifact = rdf_artifacts.Artifact( name="Barf", doc="This is Barf.", - provides=["os"], sources=[source]) with self.assertRaisesRegex(rdf_artifacts.ArtifactSyntaxError, @@ -252,6 +240,7 @@ def testArtifactsFromYamlIgnoresDeprecatedFields(self): - type: PATH attributes: paths: ['/bar', '/baz'] + provides: [os_release, os_major_version, os_minor_version] --- name: Quux doc: Lorem ipsum. 
@@ -260,6 +249,7 @@ def testArtifactsFromYamlIgnoresDeprecatedFields(self): - type: PATH attributes: paths: ['/norf', '/thud'] + provides: [domain] """) artifacts = registry.ArtifactsFromYaml(yaml) artifacts.sort(key=lambda artifact: artifact.name) diff --git a/grr/server/grr_response_server/artifact_test.py b/grr/server/grr_response_server/artifact_test.py index 9dbdca611c..96a7b424ce 100644 --- a/grr/server/grr_response_server/artifact_test.py +++ b/grr/server/grr_response_server/artifact_test.py @@ -21,7 +21,6 @@ from grr_response_core.lib import parser from grr_response_core.lib import parsers from grr_response_core.lib import rdfvalue -from grr_response_core.lib.parsers import linux_file_parser from grr_response_core.lib.parsers import wmi_parser from grr_response_core.lib.rdfvalues import anomaly as rdf_anomaly from grr_response_core.lib.rdfvalues import artifacts as rdf_artifacts @@ -39,6 +38,7 @@ from grr_response_server.databases import db from grr_response_server.databases import db_test_utils from grr_response_server.flows.general import collectors +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr.test_lib import action_mocks from grr.test_lib import artifact_test_lib @@ -236,18 +236,28 @@ def testUploadArtifactYamlFileAndDumpToYaml(self): artifact.UploadArtifactYamlFile(filedesc.read()) loaded_artifacts = artifact_registry.REGISTRY.GetArtifacts() self.assertGreaterEqual(len(loaded_artifacts), 20) - self.assertIn("DepsWindirRegex", [a.name for a in loaded_artifacts]) - - # Now dump back to YAML. - yaml_data = artifact_registry.REGISTRY.DumpArtifactsToYaml() - for snippet in [ - "name: TestFilesArtifact", - "urls:\\s*- https://msdn.microsoft.com/en-us/library/", - "returned_types:\\s*- SoftwarePackage", - "args:\\s*- --list", - "cmd: /usr/bin/dpkg", - ]: - self.assertRegex(yaml_data, snippet) + + artifacts_by_name = { + artifact.name: artifact for artifact in loaded_artifacts + } + self.assertIn("DepsWindirRegex", artifacts_by_name) + self.assertIn("TestFilesArtifact", artifacts_by_name) + self.assertStartsWith( + artifacts_by_name["WMIActiveScriptEventConsumer"].urls[0], + "https://msdn.microsoft.com/en-us/library/", + ) + self.assertEqual( + artifacts_by_name["TestEchoArtifact"].sources[0].returned_types, + ["SoftwarePackages"], + ) + self.assertEqual( + artifacts_by_name["TestCmdArtifact"].sources[0].attributes["cmd"], + "/usr/bin/dpkg", + ) + self.assertEqual( + artifacts_by_name["TestCmdArtifact"].sources[0].attributes["args"], + ["--list"], + ) finally: artifact.LoadArtifactsOnce() @@ -429,28 +439,6 @@ def testFilesArtifact(self): fd = file_store.OpenFile(cp) self.assertNotEmpty(fd.read()) - @parser_test_lib.WithParser("Passwd", linux_file_parser.PasswdBufferParser) - def testLinuxPasswdHomedirsArtifact(self): - """Check LinuxPasswdHomedirs artifacts.""" - with vfs_test_lib.FakeTestDataVFSOverrider(): - fd = self.RunCollectorAndGetResults( - ["LinuxPasswdHomedirs"], - client_mock=action_mocks.ClientFileFinderWithVFS(), - client_id=test_lib.TEST_CLIENT_ID, - ) - - self.assertLen(fd, 5) - self.assertCountEqual( - [x.username for x in fd], - [u"exomemory", u"gevulot", u"gogol", u"user1", u"user2"]) - for user in fd: - if user.username == u"exomemory": - self.assertEqual(user.full_name, u"Never Forget (admin)") - self.assertEqual(user.gid, 47) - self.assertEqual(user.homedir, u"/var/lib/exomemory") - self.assertEqual(user.shell, u"/bin/sh") - self.assertEqual(user.uid, 46) - def 
testArtifactOutput(self): """Check we can run command based artifacts.""" client_id = test_lib.TEST_CLIENT_ID @@ -674,18 +662,10 @@ def EnumerateUsers( self.assertEqual(user.last_logon.AsSecondsSinceEpoch(), 1296552099) self.assertEqual(user.homedir, "/home/user1") - @parser_test_lib.WithAllParsers def testKnowledgeBaseRetrievalLinuxNoUsers(self): """Cause a users.username dependency failure.""" - with test_lib.ConfigOverrider({ - "Artifacts.knowledge_base": [ - "NetgroupConfiguration", - ], - "Artifacts.netgroup_filter_regexes": ["^doesntexist$"], - }): - with vfs_test_lib.FakeTestDataVFSOverrider(): - with test_lib.SuppressLogs(): - kb = self._RunKBI(require_complete=False) + with vfs_test_lib.FakeTestDataVFSOverrider(): + kb = self._RunKBI(require_complete=False) self.assertEqual(kb.os_major_version, 14) self.assertEqual(kb.os_minor_version, 4) @@ -1008,7 +988,9 @@ def _WriteFile(self, path: str, data: bytes) -> None: path_info = rdf_objects.PathInfo.OS(components=components) path_info.hash_entry.sha256 = bytes(blob_id) - data_store.REL_DB.WritePathInfos(self.client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + self.client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) client_path = db.ClientPath.OS( client_id=self.client_id, components=components) diff --git a/grr/server/grr_response_server/artifact_utils_test.py b/grr/server/grr_response_server/artifact_utils_test.py index 6c25644c8b..c9db67182d 100644 --- a/grr/server/grr_response_server/artifact_utils_test.py +++ b/grr/server/grr_response_server/artifact_utils_test.py @@ -88,61 +88,20 @@ def testGetArtifacts(self, registry): for result in results: self.assertFalse(ar.GetArtifactPathDependencies(result)) - # Check provides filtering - results = registry.GetArtifacts( - os_name="Windows", provides=["users.homedir", "domain"]) - for result in results: - # provides contains at least one of the filter strings - self.assertGreaterEqual( - len(set(result.provides).union(set(["users.homedir", "domain"]))), 1) - - results = registry.GetArtifacts( - os_name="Windows", provides=["nothingprovidesthis"]) - self.assertEmpty(results) - @artifact_test_lib.PatchDefaultArtifactRegistry def testGetArtifactNames(self, registry): registry.AddFileSource(self.test_artifacts_file) - result_objs = registry.GetArtifacts( - os_name="Windows", provides=["users.homedir", "domain"]) + result_objs = registry.GetArtifacts(os_name="Windows") - results_names = registry.GetArtifactNames( - os_name="Windows", provides=["users.homedir", "domain"]) + results_names = registry.GetArtifactNames(os_name="Windows") self.assertCountEqual(set([a.name for a in result_objs]), results_names) self.assertNotEmpty(results_names) - results_names = registry.GetArtifactNames( - os_name="Darwin", provides=["users.username"]) + results_names = registry.GetArtifactNames(os_name="Darwin") self.assertIn("UsersDirectory", results_names) - @artifact_test_lib.PatchCleanArtifactRegistry - def testSearchDependencies(self, registry): - registry.AddFileSource(self.test_artifacts_file) - - names, expansions = registry.SearchDependencies( - "Windows", [u"TestAggregationArtifactDeps", u"DepsParent"]) - - # This list contains all artifacts that can provide the dependency, e.g. - # DepsHomedir and DepsHomedir2 both provide - # users.homedir. 
- self.assertCountEqual(names, [ - u"DepsHomedir", u"DepsHomedir2", u"DepsDesktop", u"DepsParent", - u"DepsWindir", u"DepsWindirRegex", u"DepsControlSet", - u"TestAggregationArtifactDeps" - ]) - - self.assertCountEqual(expansions, [ - "current_control_set", "users.homedir", "users.desktop", - "environ_windir", "users.username" - ]) - - # None of these match the OS, so we should get an empty list. - names, expansions = registry.SearchDependencies( - "Darwin", [u"TestCmdArtifact", u"TestFileArtifact"]) - self.assertCountEqual(names, []) - @artifact_test_lib.PatchCleanArtifactRegistry def testArtifactConversion(self, registry): registry.AddFileSource(self.test_artifacts_file) @@ -396,7 +355,6 @@ def GenerateSample(self, number=0): result = rdf_artifacts.Artifact( name="artifact%s" % number, doc="Doco", - provides="environ_windir", supported_os="Windows", urls="http://blah") return result @@ -430,7 +388,6 @@ def testGetArtifactPathDependencies(self): artifact = rdf_artifacts.Artifact( name="artifact", doc="Doco", - provides=["environ_windir"], supported_os=["Windows"], urls=["http://blah"], sources=sources) @@ -465,30 +422,11 @@ def testValidateSyntax(self): artifact = rdf_artifacts.Artifact( name="good", doc="Doco", - provides=["environ_windir"], supported_os=["Windows"], urls=["http://blah"], sources=sources) ar.ValidateSyntax(artifact) - def testValidateSyntaxBadProvides(self): - sources = [{ - "type": rdf_artifacts.ArtifactSource.SourceType.FILE, - "attributes": { - "paths": [r"%%environ_systemdrive%%\Temp"] - } - }] - - artifact = rdf_artifacts.Artifact( - name="bad", - doc="Doco", - provides=["windir"], - supported_os=["Windows"], - urls=["http://blah"], - sources=sources) - with self.assertRaises(rdf_artifacts.ArtifactDefinitionError): - ar.ValidateSyntax(artifact) - def testValidateSyntaxBadPathDependency(self): sources = [{ "type": rdf_artifacts.ArtifactSource.SourceType.FILE, @@ -500,7 +438,6 @@ def testValidateSyntaxBadPathDependency(self): artifact = rdf_artifacts.Artifact( name="bad", doc="Doco", - provides=["environ_windir"], supported_os=["Windows"], urls=["http://blah"], sources=sources) diff --git a/grr/server/grr_response_server/authorization/auth_manager.py b/grr/server/grr_response_server/authorization/auth_manager.py index bad1541547..adecbf0ec3 100644 --- a/grr/server/grr_response_server/authorization/auth_manager.py +++ b/grr/server/grr_response_server/authorization/auth_manager.py @@ -38,7 +38,8 @@ def CreateAuthorizations(self, yaml_data, auth_class): auth_object = auth_class(**auth) if auth_object.key in self.auth_objects: raise InvalidAuthorization( - "Duplicate authorizations for %s" % auth_object.key) + "Duplicate authorizations for %s" % auth_object.key + ) self.auth_objects[auth_object.key] = auth_object def GetAuthorizationForSubject(self, subject): @@ -67,8 +68,9 @@ class AuthorizationManager(object): def __init__(self, group_access_manager=None): self.authorized_users = dict() - self.group_access_manager = (group_access_manager or - groups.CreateGroupAccessManager()) + self.group_access_manager = ( + group_access_manager or groups.CreateGroupAccessManager() + ) self.Initialize() def Initialize(self): @@ -100,9 +102,9 @@ def CheckPermissions(self, username, subject): """Checks if a given user has access to a given subject.""" if subject in self.authorized_users: - return ((username in self.authorized_users[subject]) or - self.group_access_manager.MemberOfAuthorizedGroup( - username, subject)) + return ( + username in self.authorized_users[subject] + ) or 
self.group_access_manager.MemberOfAuthorizedGroup(username, subject) # In case the subject is not found, the safest thing to do is to raise. # It's up to the users of this class to handle this exception and diff --git a/grr/server/grr_response_server/authorization/auth_manager_test.py b/grr/server/grr_response_server/authorization/auth_manager_test.py index 35354d3ef5..54a8d3a289 100644 --- a/grr/server/grr_response_server/authorization/auth_manager_test.py +++ b/grr/server/grr_response_server/authorization/auth_manager_test.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Tests for AuthorizationManager.""" + from unittest import mock from absl import app @@ -38,10 +39,14 @@ def testCreateAuthorizationsInitializesAuthorizationsFromYaml(self): self.assertEqual( self.auth_reader.GetAuthorizationForSubject("ApiCallRobotRouter").data, - dict(router="ApiCallRobotRouter", users=["foo", "bar"])) + dict(router="ApiCallRobotRouter", users=["foo", "bar"]), + ) self.assertEqual( - self.auth_reader.GetAuthorizationForSubject("ApiCallDisabledRouter") - .data, dict(router="ApiCallDisabledRouter", users=["blah"])) + self.auth_reader.GetAuthorizationForSubject( + "ApiCallDisabledRouter" + ).data, + dict(router="ApiCallDisabledRouter", users=["blah"]), + ) def testCreateAuthorizationsRaisesOnDuplicateKeys(self): yaml_data = """ @@ -58,7 +63,8 @@ def testGetAllAuthorizationObjectsPreservesOrder(self): self.auth_reader.CreateAuthorizations(yaml_data, DummyAuthorization) for index, authorization in enumerate( - self.auth_reader.GetAllAuthorizationObjects()): + self.auth_reader.GetAllAuthorizationObjects() + ): self.assertEqual(authorization.key, "Router%d" % index) def testGetAuthSubjectsPreservesOrder(self): @@ -77,7 +83,8 @@ def setUp(self): self.group_access_manager = groups.NoGroupAccess() self.auth_manager = auth_manager.AuthorizationManager( - group_access_manager=self.group_access_manager) + group_access_manager=self.group_access_manager + ) def testGetAuthSubjectsPreservesOrder(self): for index in range(10): @@ -93,32 +100,41 @@ def testCheckPermissionRaisesInvalidSubjectIfNoSubjectRegistered(self): def testCheckPermissionsReturnsFalseIfDenyAllWasCalled(self): self.auth_manager.DenyAll("subject-bar") self.assertFalse( - self.auth_manager.CheckPermissions("user-foo", "subject-bar")) + self.auth_manager.CheckPermissions("user-foo", "subject-bar") + ) def testCheckPermissionsReturnsTrueIfUserWasAuthorized(self): self.auth_manager.AuthorizeUser("user-foo", "subject-bar") self.assertTrue( - self.auth_manager.CheckPermissions("user-foo", "subject-bar")) + self.auth_manager.CheckPermissions("user-foo", "subject-bar") + ) def testCheckPermissionsReturnsFalseIfUserWasNotAuthorized(self): self.auth_manager.AuthorizeUser("user-foo", "subject-bar") self.assertFalse( - self.auth_manager.CheckPermissions("user-bar", "subject-bar")) + self.auth_manager.CheckPermissions("user-bar", "subject-bar") + ) def testCheckPermissionsReturnsTrueIfGroupWasAuthorized(self): self.auth_manager.DenyAll("subject-bar") - with mock.patch.object(self.group_access_manager, "MemberOfAuthorizedGroup", - lambda *args: True): + with mock.patch.object( + self.group_access_manager, "MemberOfAuthorizedGroup", lambda *args: True + ): self.assertTrue( - self.auth_manager.CheckPermissions("user-bar", "subject-bar")) + self.auth_manager.CheckPermissions("user-bar", "subject-bar") + ) def testCheckPermissionsReturnsFalseIfGroupWasNotAuthorized(self): self.auth_manager.DenyAll("subject-bar") - with mock.patch.object(self.group_access_manager, 
"MemberOfAuthorizedGroup", - lambda *args: False): + with mock.patch.object( + self.group_access_manager, + "MemberOfAuthorizedGroup", + lambda *args: False, + ): self.assertFalse( - self.auth_manager.CheckPermissions("user-bar", "subject-bar")) + self.auth_manager.CheckPermissions("user-bar", "subject-bar") + ) def main(argv): diff --git a/grr/server/grr_response_server/authorization/client_approval_auth.py b/grr/server/grr_response_server/authorization/client_approval_auth.py index 5417bb49f3..eebcafa9ed 100644 --- a/grr/server/grr_response_server/authorization/client_approval_auth.py +++ b/grr/server/grr_response_server/authorization/client_approval_auth.py @@ -3,12 +3,10 @@ import io - from grr_response_core import config from grr_response_core.lib import utils from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import acls_pb2 - from grr_response_server import access_control from grr_response_server.authorization import auth_manager @@ -31,6 +29,7 @@ class ErrorInvalidApprovalSpec(Error): class ClientApprovalAuthorization(rdf_structs.RDFProtoStruct): """Authorization to approve clients with a particular label.""" + protobuf = acls_pb2.ClientApprovalAuthorization @property @@ -38,14 +37,16 @@ def label(self): label = self.Get("label") if not label: raise ErrorInvalidClientApprovalAuthorization( - "label string cannot be empty") + "label string cannot be empty" + ) return label @label.setter def label(self, value): if not isinstance(value, str) or not value: raise ErrorInvalidClientApprovalAuthorization( - "label must be a non-empty string") + "label must be a non-empty string" + ) self.Set("label", value) @property @@ -124,9 +125,10 @@ def CheckApproversForLabel(self, client_id, requester, approvers, label): if auth.requester_must_be_authorized: if not self.CheckPermissions(requester, label): raise access_control.UnauthorizedAccess( - "User %s not in %s or groups:%s for %s" % - (requester, auth.users, auth.groups, label), - subject=client_id) + "User %s not in %s or groups:%s for %s" + % (requester, auth.users, auth.groups, label), + subject=client_id, + ) approved_count = 0 for approver in approvers: @@ -135,9 +137,10 @@ def CheckApproversForLabel(self, client_id, requester, approvers, label): if approved_count < auth.num_approvers_required: raise access_control.UnauthorizedAccess( - "Found %s approvers for %s, needed %s" % - (approved_count, label, auth.num_approvers_required), - subject=client_id) + "Found %s approvers for %s, needed %s" + % (approved_count, label, auth.num_approvers_required), + subject=client_id, + ) return True diff --git a/grr/server/grr_response_server/authorization/client_approval_auth_test.py b/grr/server/grr_response_server/authorization/client_approval_auth_test.py index 0ab71c131e..a5d79fe2e1 100644 --- a/grr/server/grr_response_server/authorization/client_approval_auth_test.py +++ b/grr/server/grr_response_server/authorization/client_approval_auth_test.py @@ -9,30 +9,36 @@ from grr.test_lib import test_lib -class ClientApprovalAuthorizationTest(rdf_test_base.RDFValueTestMixin, - test_lib.GRRBaseTest): +class ClientApprovalAuthorizationTest( + rdf_test_base.RDFValueTestMixin, test_lib.GRRBaseTest +): rdfvalue_class = client_approval_auth.ClientApprovalAuthorization def GenerateSample(self, number=0): return client_approval_auth.ClientApprovalAuthorization( - label="label%d" % number, users=["test", "test2"]) + label="label%d" % number, users=["test", "test2"] + ) def testApprovalValidation(self): # String instead of 
list of users with self.assertRaises( - client_approval_auth.ErrorInvalidClientApprovalAuthorization): + client_approval_auth.ErrorInvalidClientApprovalAuthorization + ): client_approval_auth.ClientApprovalAuthorization( - label="label", users="test") + label="label", users="test" + ) # Missing label acl = client_approval_auth.ClientApprovalAuthorization(users=["test"]) with self.assertRaises( - client_approval_auth.ErrorInvalidClientApprovalAuthorization): + client_approval_auth.ErrorInvalidClientApprovalAuthorization + ): print(acl.label) # Bad label with self.assertRaises( - client_approval_auth.ErrorInvalidClientApprovalAuthorization): + client_approval_auth.ErrorInvalidClientApprovalAuthorization + ): acl.label = None @@ -72,14 +78,16 @@ def _CreateAuthMultiApproval(self): def testRaisesOnNoApprovals(self): self._CreateAuthSingleLabel() with self.assertRaises(access_control.UnauthorizedAccess): - self.mgr.CheckApproversForLabel(self.client_id, "requester_user", [], - "label1") + self.mgr.CheckApproversForLabel( + self.client_id, "requester_user", [], "label1" + ) def testRaisesOnSelfApproval(self): self._CreateAuthSingleLabel() with self.assertRaises(access_control.UnauthorizedAccess): - self.mgr.CheckApproversForLabel(self.client_id, "requester_user", - ["requester_user"], "label1") + self.mgr.CheckApproversForLabel( + self.client_id, "requester_user", ["requester_user"], "label1" + ) def testRaisesOnAuthorizedSelfApproval(self): self._CreateAuthSingleLabel() @@ -89,19 +97,22 @@ def testRaisesOnAuthorizedSelfApproval(self): def testRaisesOnApprovalFromUnauthorized(self): self._CreateAuthSingleLabel() with self.assertRaises(access_control.UnauthorizedAccess): - self.mgr.CheckApproversForLabel(self.client_id, "requester_user", - ["approver1"], "label1") + self.mgr.CheckApproversForLabel( + self.client_id, "requester_user", ["approver1"], "label1" + ) def testPassesWithApprovalFromApprovedUser(self): self._CreateAuthSingleLabel() - self.mgr.CheckApproversForLabel(self.client_id, "requester_user", - ["approver1", "two"], "label1") + self.mgr.CheckApproversForLabel( + self.client_id, "requester_user", ["approver1", "two"], "label1" + ) def testRaisesWhenRequesterNotAuthorized(self): self._CreateAuthCheckRequester() with self.assertRaises(access_control.UnauthorizedAccess): - self.mgr.CheckApproversForLabel(self.client_id, "requester_user", ["one"], - "label1") + self.mgr.CheckApproversForLabel( + self.client_id, "requester_user", ["one"], "label1" + ) def testRaisesOnSelfApprovalByAuthorizedRequester(self): self._CreateAuthCheckRequester() @@ -110,19 +121,22 @@ def testRaisesOnSelfApprovalByAuthorizedRequester(self): def testPassesWhenApproverAndRequesterAuthorized(self): self._CreateAuthCheckRequester() - self.mgr.CheckApproversForLabel(self.client_id, "one", ["one", "two"], - "label1") + self.mgr.CheckApproversForLabel( + self.client_id, "one", ["one", "two"], "label1" + ) def testRaisesWhenOnlyOneAuthorizedApprover(self): self._CreateAuthMultiApproval() with self.assertRaises(access_control.UnauthorizedAccess): - self.mgr.CheckApproversForLabel(self.client_id, "one", ["one", "two"], - "label1") + self.mgr.CheckApproversForLabel( + self.client_id, "one", ["one", "two"], "label1" + ) def testPassesWithTwoAuthorizedApprovers(self): self._CreateAuthMultiApproval() - self.mgr.CheckApproversForLabel(self.client_id, "one", ["two", "four"], - "label1") + self.mgr.CheckApproversForLabel( + self.client_id, "one", ["two", "four"], "label1" + ) def main(argv): diff --git 
a/grr/server/grr_response_server/authorization/mig_client_approval_auth.py b/grr/server/grr_response_server/authorization/mig_client_approval_auth.py index 4919cabc9f..0ec7e9a894 100644 --- a/grr/server/grr_response_server/authorization/mig_client_approval_auth.py +++ b/grr/server/grr_response_server/authorization/mig_client_approval_auth.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import acls_pb2 from grr_response_server.authorization import client_approval_auth diff --git a/grr/server/grr_response_server/bin/api_shell_raw_access.py b/grr/server/grr_response_server/bin/api_shell_raw_access.py index a74ce8d514..793be1a36d 100644 --- a/grr/server/grr_response_server/bin/api_shell_raw_access.py +++ b/grr/server/grr_response_server/bin/api_shell_raw_access.py @@ -18,32 +18,42 @@ from grr_response_server.gui import api_call_context _PAGE_SIZE = flags.DEFINE_integer( - "page_size", 1000, - "Page size used when paging through collections of items. Default is 1000.") + "page_size", + 1000, + "Page size used when paging through collections of items. Default is 1000.", +) _USERNAME = flags.DEFINE_string( - "username", None, "Username to use when making raw API calls. If not " - "specified, USER environment variable value will be used.") + "username", + None, + "Username to use when making raw API calls. If not " + "specified, USER environment variable value will be used.", +) _EXEC_CODE = flags.DEFINE_string( - "exec_code", None, + "exec_code", + None, "If present, no IPython shell is started but the code given in " "the flag is run instead (comparable to the -c option of " "IPython). The code will be able to use a predefined " - "global 'grrapi' object.") + "global 'grrapi' object.", +) _EXEC_FILE = flags.DEFINE_string( - "exec_file", None, + "exec_file", + None, "If present, no IPython shell is started but the code given in " "command file is supplied as input instead. The code " "will be able to use a predefined global 'grrapi' " - "object.") + "object.", +) _VERSION = flags.DEFINE_bool( "version", default=False, allow_override=True, - help="Print the API shell version number and exit immediately.") + help="Print the API shell version number and exit immediately.", +) def main(argv=None): @@ -54,8 +64,10 @@ def main(argv=None): return config.CONFIG.AddContext(contexts.COMMAND_LINE_CONTEXT) - config.CONFIG.AddContext(contexts.CONSOLE_CONTEXT, - "Context applied when running the console binary.") + config.CONFIG.AddContext( + contexts.CONSOLE_CONTEXT, + "Context applied when running the console binary.", + ) server_startup.Init() fleetspeak_connector.Init() @@ -64,14 +76,18 @@ def main(argv=None): username = os.environ["USER"] if not username: - print("Username has to be specified with either --username flag or " - "USER environment variable.") + print( + "Username has to be specified with either --username flag or " + "USER environment variable." 
+ ) sys.exit(1) grrapi = api.GrrApi( connector=api_shell_raw_access_lib.RawConnector( context=api_call_context.ApiCallContext(username=username), - page_size=_PAGE_SIZE.value)) + page_size=_PAGE_SIZE.value, + ) + ) if _EXEC_CODE.value and _EXEC_FILE.value: print("--exec_code --exec_file flags can't be supplied together.") diff --git a/grr/server/grr_response_server/bin/api_shell_raw_access_lib.py b/grr/server/grr_response_server/bin/api_shell_raw_access_lib.py index ffc07bbe38..f1057e96a1 100644 --- a/grr/server/grr_response_server/bin/api_shell_raw_access_lib.py +++ b/grr/server/grr_response_server/bin/api_shell_raw_access_lib.py @@ -4,7 +4,6 @@ from typing import Optional from google.protobuf import message - from grr_api_client import connectors from grr_api_client import errors from grr_api_client import utils @@ -32,11 +31,13 @@ def __init__(self, page_size=None, context=None): self._root_router = api_root_router.ApiRootRouter() def _MatchRouter(self, method_name, args): - if (hasattr(self._router, method_name) and - hasattr(self._root_router, method_name)): + if hasattr(self._router, method_name) and hasattr( + self._root_router, method_name + ): mdata = self._router.__class__.GetAnnotatedMethods()[method_name] root_mdata = self._root_router.__class__.GetAnnotatedMethods()[ - method_name] + method_name + ] if args is None: if mdata.args_type is None: @@ -46,13 +47,15 @@ def _MatchRouter(self, method_name, args): else: if mdata.args_type and mdata.args_type.protobuf == args.__class__: return self._router - elif (root_mdata.args_type and - root_mdata.args_type.protobuf == args.__class__): + elif ( + root_mdata.args_type + and root_mdata.args_type.protobuf == args.__class__ + ): return self._root_router raise RuntimeError( - "Can't unambiguously select root/non-root router for %s" % - method_name) + "Can't unambiguously select root/non-root router for %s" % method_name + ) elif hasattr(self._router, method_name): return self._router elif hasattr(self._root_router, method_name): @@ -79,8 +82,9 @@ def _CallMethod(self, method_name, args): except NotImplementedError as e: raise errors.ApiNotImplementedError( "Method {} is not implemented in {}.".format( - method_name, - type(router).__name__)) from e + method_name, type(router).__name__ + ) + ) from e except Exception as e: # pylint: disable=broad-except raise errors.UnknownError(e) diff --git a/grr/server/grr_response_server/bin/api_shell_raw_access_lib_test.py b/grr/server/grr_response_server/bin/api_shell_raw_access_lib_test.py index be8d5d1edb..6f18d1fcbc 100644 --- a/grr/server/grr_response_server/bin/api_shell_raw_access_lib_test.py +++ b/grr/server/grr_response_server/bin/api_shell_raw_access_lib_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - import io from absl import app @@ -21,7 +20,8 @@ def setUp(self): super().setUp() self.connector = api_shell_raw_access_lib.RawConnector( context=api_call_context.ApiCallContext(self.test_username), - page_size=10) + page_size=10, + ) def testCorrectlyCallsGeneralMethod(self): self.SetupClients(10) @@ -35,7 +35,8 @@ def testCorrectlyCallsStreamingMethod(self): fixture_test_lib.ClientFixture(client_id) args = vfs_pb2.ApiGetFileBlobArgs( - client_id=client_id, file_path="fs/tsk/c/bin/rbash") + client_id=client_id, file_path="fs/tsk/c/bin/rbash" + ) out = io.BytesIO() self.connector.SendStreamingRequest("GetFileBlob", args).WriteToStream(out) self.assertEqual(out.getvalue(), b"Hello world") diff --git a/grr/server/grr_response_server/bin/config_updater.py 
b/grr/server/grr_response_server/bin/config_updater.py index 97f95333a2..76a90cdf25 100644 --- a/grr/server/grr_response_server/bin/config_updater.py +++ b/grr/server/grr_response_server/bin/config_updater.py @@ -25,56 +25,70 @@ from grr_response_server.bin import config_updater_util parser = argparse_flags.ArgumentParser( - description=("Set configuration parameters for the GRR Server." - "\nThis script has numerous subcommands to perform " - "various actions. When you are first setting up, you " - "probably only care about 'initialize'.")) + description=( + "Set configuration parameters for the GRR Server." + "\nThis script has numerous subcommands to perform " + "various actions. When you are first setting up, you " + "probably only care about 'initialize'." + ) +) # Generic arguments. parser.add_argument( "--version", action="version", version=config_server.VERSION["packageversion"], - help="Print config updater version number and exit immediately.") + help="Print config updater version number and exit immediately.", +) subparsers = parser.add_subparsers( - title="subcommands", dest="subparser_name", description="valid subcommands") + title="subcommands", dest="subparser_name", description="valid subcommands" +) # Subparsers. parser_generate_keys = subparsers.add_parser( - "generate_keys", help="Generate crypto keys in the configuration.") + "generate_keys", help="Generate crypto keys in the configuration." +) parser_repack_clients = subparsers.add_parser( "repack_clients", - help="Repack the clients binaries with the current configuration.") + help="Repack the clients binaries with the current configuration.", +) parser_initialize = subparsers.add_parser( - "initialize", help="Run all the required steps to setup a new GRR install.") + "initialize", help="Run all the required steps to setup a new GRR install." +) parser_set_var = subparsers.add_parser("set_var", help="Set a config variable.") parser_switch_datastore = subparsers.add_parser( "switch_datastore", - help="Switch from a legacy datastore (AFF4) " - "to the new optimized implementation (REL_DB).") + help=( + "Switch from a legacy datastore (AFF4) " + "to the new optimized implementation (REL_DB)." + ), +) # Update an existing user. parser_update_user = subparsers.add_parser( - "update_user", help="Update user settings.") + "update_user", help="Update user settings." +) parser_update_user.add_argument("username", help="Username to update.") parser_update_user.add_argument( "--password", default=None, - help="New password for this user (will prompt for password if not given).") + help="New password for this user (will prompt for password if not given).", +) parser_update_user.add_argument( "--admin", default=True, type=config_updater_util.ArgparseBool, - help="Make the user an admin, if they aren't already.") + help="Make the user an admin, if they aren't already.", +) parser_add_user = subparsers.add_parser("add_user", help="Add a new user.") @@ -85,121 +99,154 @@ "--admin", default=True, type=config_updater_util.ArgparseBool, - help="Add the user with admin privileges.") + help="Add the user with admin privileges.", +) parser_initialize.add_argument( - "--external_hostname", default=None, help="External hostname to use.") + "--external_hostname", default=None, help="External hostname to use." +) parser_initialize.add_argument( - "--admin_password", default=None, help="Admin password for web interface.") + "--admin_password", default=None, help="Admin password for web interface." 
+) parser_initialize.add_argument( "--noprompt", default=False, action="store_true", - help="Set to avoid prompting during initialize.") + help="Set to avoid prompting during initialize.", +) parser_initialize.add_argument( "--redownload_templates", default=False, action="store_true", - help="Re-download templates during noninteractive config initialization " - "(server debs already include templates).") + help=( + "Re-download templates during noninteractive config initialization " + "(server debs already include templates)." + ), +) # TODO(hanuszczak): Rename this flag to `repack_templates` (true by default). parser_initialize.add_argument( "--norepack_templates", default=False, action="store_true", - help="Skip template repacking during noninteractive config initialization.") + help="Skip template repacking during noninteractive config initialization.", +) parser_initialize.add_argument( "--mysql_hostname", - help="Hostname for a running MySQL instance (only appplies if --noprompt " - "is set).") + help=( + "Hostname for a running MySQL instance (only appplies if --noprompt " + "is set)." + ), +) parser_initialize.add_argument( "--mysql_port", type=int, - help="Port for a running MySQL instance (only applies if --noprompt " - "is set).") + help=( + "Port for a running MySQL instance (only applies if --noprompt is set)." + ), +) parser_initialize.add_argument( "--mysql_db", - help="Name of GRR's MySQL database (only applies if --noprompt is set).") + help="Name of GRR's MySQL database (only applies if --noprompt is set).", +) parser_initialize.add_argument( "--mysql_fleetspeak_db", - help="Name of Fleetspeak's MySQL database (only applies if --noprompt is set)." + help=( + "Name of Fleetspeak's MySQL database (only applies if --noprompt is" + " set)." + ), ) parser_initialize.add_argument( "--mysql_username", - help="Name of GRR MySQL database user (only applies if --noprompt is set).") + help="Name of GRR MySQL database user (only applies if --noprompt is set).", +) parser_initialize.add_argument( "--mysql_password", - help="Password for GRR MySQL database user (only applies if --noprompt is " - "set).") + help=( + "Password for GRR MySQL database user (only applies if --noprompt is " + "set)." + ), +) parser_initialize.add_argument( "--mysql_client_key_path", - help="The path name of the client private key file.") + help="The path name of the client private key file.", +) parser_initialize.add_argument( "--mysql_client_cert_path", - help="The path name of the client public key certificate file.") + help="The path name of the client public key certificate file.", +) parser_initialize.add_argument( "--mysql_ca_cert_path", - help="The path name of the Certificate Authority (CA) certificate file.") + help="The path name of the Certificate Authority (CA) certificate file.", +) # Deprecated. There is no choice anymore, relational db is always enabled. parser_initialize.add_argument( "--use_rel_db", default=True, action="store_true", - help="Use the new-generation datastore (REL_DB). Deprecated, REL_DB is now " - "the only available choice.") + help=( + "Use the new-generation datastore (REL_DB). Deprecated, REL_DB is now " + "the only available choice." 
+ ), +) parser_initialize.add_argument( "--use_fleetspeak", default=False, action="store_true", - help="Use the new-generation communication framework (Fleetspeak).") + help="Use the new-generation communication framework (Fleetspeak).", +) parser_set_var.add_argument("var", help="Variable to set.") parser_set_var.add_argument("val", help="Value to set.") # Delete an existing user. parser_delete_user = subparsers.add_parser( - "delete_user", help="Delete an user account.") + "delete_user", help="Delete a user account." +) parser_delete_user.add_argument("username", help="Username to delete.") # Show user account. parser_show_user = subparsers.add_parser( - "show_user", help="Display user settings or list all users.") + "show_user", help="Display user settings or list all users." +) parser_show_user.add_argument( "--username", default=None, nargs="?", - help="Username to display. If not specified, list all users.") + help="Username to display. If not specified, list all users.", +) # Generate Keys Arguments parser_generate_keys.add_argument( "--overwrite_keys", default=False, action="store_true", - help="Required to overwrite existing keys.") + help="Required to overwrite existing keys.", +) # Repack arguments. parser_repack_clients.add_argument( "--noupload", default=False, action="store_true", - help="Don't upload the client binaries to the datastore.") + help="Don't upload the client binaries to the datastore.", +) def _ExtendWithUploadArgs(upload_parser): @@ -211,24 +258,31 @@ def _ExtendWithUploadSignedArgs(upload_signed_parser): "--platform", required=True, choices=maintenance_utils.SUPPORTED_PLATFORMS, - help="The platform the file will be used on. This determines which " - "signing keys to use, and the path on the server the file will be " - "uploaded to.") + help=( + "The platform the file will be used on. This determines which " + "signing keys to use, and the path on the server the file will be " + "uploaded to." + ), + ) upload_signed_parser.add_argument( "--upload_subdirectory", required=False, default="", - help="Directory path under which to place an uploaded python-hack " - "or executable, e.g for a Windows executable named 'hello.exe', " - "if --upload_subdirectory is set to 'test', the path of the " - "uploaded binary will be 'windows/test/hello.exe', relative to " - "the root path for executables.") + help=( + "Directory path under which to place an uploaded python-hack " + "or executable, e.g. for a Windows executable named 'hello.exe', " + "if --upload_subdirectory is set to 'test', the path of the " + "uploaded binary will be 'windows/test/hello.exe', relative to " + "the root path for executables." + ), + ) # Upload parsers. parser_upload_artifact = subparsers.add_parser( - "upload_artifact", help="Upload a raw json artifact file.") + "upload_artifact", help="Upload a raw json artifact file." +) _ExtendWithUploadArgs(parser_upload_artifact) @@ -236,41 +290,55 @@ def _ExtendWithUploadSignedArgs(upload_signed_parser): "--overwrite_artifact", default=False, action="store_true", - help="Overwrite existing artifact.") + help="Overwrite existing artifact.", +) parser_delete_artifacts = subparsers.add_parser( - "delete_artifacts", help="Delete a list of artifacts from the data store.") + "delete_artifacts", help="Delete a list of artifacts from the data store." +) parser_delete_artifacts.add_argument( - "--artifact", default=[], action="append", help="The artifacts to delete.") + "--artifact", default=[], action="append", help="The artifacts to delete." 
+) parser_upload_python = subparsers.add_parser( "upload_python", - help="Sign and upload a 'python hack' which can be used to execute code on " - "a client.") + help=( + "Sign and upload a 'python hack' which can be used to execute code on " + "a client." + ), +) _ExtendWithUploadArgs(parser_upload_python) _ExtendWithUploadSignedArgs(parser_upload_python) parser_upload_exe = subparsers.add_parser( "upload_exe", - help="Sign and upload an executable which can be used to execute code on " - "a client.") + help=( + "Sign and upload an executable which can be used to execute code on " + "a client." + ), +) _ExtendWithUploadArgs(parser_upload_exe) _ExtendWithUploadSignedArgs(parser_upload_exe) parser_rotate_key = subparsers.add_parser( - "rotate_server_key", help="Sets a new server key.") + "rotate_server_key", help="Sets a new server key." +) parser_rotate_key.add_argument( - "--common_name", default="grr", help="The common name to use for the cert.") + "--common_name", default="grr", help="The common name to use for the cert." +) parser_rotate_key.add_argument( "--keylength", default=None, - help="The key length for the new server key. " - "Defaults to the Server.rsa_key_length config option.") + help=( + "The key length for the new server key. " + "Defaults to the Server.rsa_key_length config option." + ), +) def main(args): @@ -297,14 +365,16 @@ def main(args): redownload_templates=args.redownload_templates, repack_templates=not args.norepack_templates, use_fleetspeak=args.use_fleetspeak, - mysql_fleetspeak_db=args.mysql_fleetspeak_db) + mysql_fleetspeak_db=args.mysql_fleetspeak_db, + ) else: config_updater_util.Initialize( grr_config.CONFIG, external_hostname=args.external_hostname, admin_password=args.admin_password, redownload_templates=args.redownload_templates, - repack_templates=not args.norepack_templates) + repack_templates=not args.norepack_templates, + ) return server_startup.Init() @@ -317,7 +387,8 @@ def main(args): if args.subparser_name == "generate_keys": try: config_updater_keys_util.GenerateKeys( - grr_config.CONFIG, overwrite_keys=args.overwrite_keys) + grr_config.CONFIG, overwrite_keys=args.overwrite_keys + ) except RuntimeError as e: # GenerateKeys will raise if keys exist and overwrite_keys is not set. print("ERROR: %s" % e) @@ -336,28 +407,32 @@ def main(args): elif args.subparser_name == "update_user": config_updater_util.UpdateUser( - args.username, password=args.password, is_admin=args.admin) + args.username, password=args.password, is_admin=args.admin + ) elif args.subparser_name == "delete_user": config_updater_util.DeleteUser(args.username) elif args.subparser_name == "add_user": config_updater_util.CreateUser( - args.username, password=args.password, is_admin=args.admin) + args.username, password=args.password, is_admin=args.admin + ) elif args.subparser_name == "upload_python": config_updater_util.UploadSignedBinary( args.file, config_pb2.ApiGrrBinary.Type.PYTHON_HACK, args.platform, - upload_subdirectory=args.upload_subdirectory) + upload_subdirectory=args.upload_subdirectory, + ) elif args.subparser_name == "upload_exe": config_updater_util.UploadSignedBinary( args.file, config_pb2.ApiGrrBinary.Type.EXECUTABLE, args.platform, - upload_subdirectory=args.upload_subdirectory) + upload_subdirectory=args.upload_subdirectory, + ) elif args.subparser_name == "set_var": var = args.var @@ -391,7 +466,7 @@ def main(args): elif args.subparser_name == "rotate_server_key": print(""" -You are about to rotate the server key. 
Note that: +You are about to rotate the Fleetspeak server key. Note that: - Clients might experience intermittent connection problems after the server keys rotated. @@ -401,20 +476,12 @@ def main(args): to accept any certificate with a smaller serial number from that point on. """) - if input("Continue? [yN]: ").upper() == "Y": - if args.keylength: - keylength = int(args.keylength) - else: - keylength = grr_config.CONFIG["Server.rsa_key_length"] - - maintenance_utils.RotateServerKey( - cn=args.common_name, keylength=keylength) - if grr_config.CONFIG["Server.fleetspeak_enabled"]: config_updater_util.FleetspeakConfig().RotateKey() - print("Fleetspeak server key rotated, " - "please restart fleetspeak-server.") + print( + "Fleetspeak server key rotated, please restart fleetspeak-server." + ) def Run(): diff --git a/grr/server/grr_response_server/bin/config_updater_keys_util.py b/grr/server/grr_response_server/bin/config_updater_keys_util.py index 40c1a279b7..90e1a1ffe2 100644 --- a/grr/server/grr_response_server/bin/config_updater_keys_util.py +++ b/grr/server/grr_response_server/bin/config_updater_keys_util.py @@ -4,7 +4,6 @@ from grr_response_core import config as grr_config from grr_response_core.lib import utils from grr_response_core.lib.rdfvalues import crypto as rdf_crypto -from grr_response_server import key_utils class Error(Exception): @@ -31,36 +30,29 @@ def _GenerateCSRFKey(config): def GenerateKeys(config, overwrite_keys=False): """Generate the keys we need for a GRR server.""" - if not hasattr(key_utils, "MakeCACert"): - raise OpenSourceKeyUtilsRequiredError( - "Generate keys can only run with open source key_utils.") - if (config.Get("PrivateKeys.server_key", default=None) and - not overwrite_keys): - print(config.Get("PrivateKeys.server_key")) + + if ( + config.Get("PrivateKeys.executable_signing_private_key", default=None) + and not overwrite_keys + ): + print(config.Get("PrivateKeys.executable_signing_private_key")) raise KeysAlreadyExistError( - "Config %s already has keys, use --overwrite_keys to " - "override." % config.parser) + "Config %s already has keys, use --overwrite_keys to override." + % config.parser + ) length = grr_config.CONFIG["Server.rsa_key_length"] print("All keys will have a bit length of %d." 
% length) print("Generating executable signing key") executable_key = rdf_crypto.RSAPrivateKey.GenerateKey(bits=length) - config.Set("PrivateKeys.executable_signing_private_key", - executable_key.AsPEM().decode("ascii")) - config.Set("Client.executable_signing_public_key", - executable_key.GetPublicKey().AsPEM().decode("ascii")) - - print("Generating CA keys") - ca_key = rdf_crypto.RSAPrivateKey.GenerateKey(bits=length) - ca_cert = key_utils.MakeCACert(ca_key) - config.Set("CA.certificate", ca_cert.AsPEM().decode("ascii")) - config.Set("PrivateKeys.ca_key", ca_key.AsPEM().decode("ascii")) - - print("Generating Server keys") - server_key = rdf_crypto.RSAPrivateKey.GenerateKey(bits=length) - server_cert = key_utils.MakeCASignedCert(u"grr", server_key, ca_cert, ca_key) - config.Set("Frontend.certificate", server_cert.AsPEM().decode("ascii")) - config.Set("PrivateKeys.server_key", server_key.AsPEM().decode("ascii")) + config.Set( + "PrivateKeys.executable_signing_private_key", + executable_key.AsPEM().decode("ascii"), + ) + config.Set( + "Client.executable_signing_public_key", + executable_key.GetPublicKey().AsPEM().decode("ascii"), + ) print("Generating secret key for csrf protection.") _GenerateCSRFKey(config) diff --git a/grr/server/grr_response_server/bin/config_updater_util.py b/grr/server/grr_response_server/bin/config_updater_util.py index 834ef37de8..2ba9ee8d6b 100644 --- a/grr/server/grr_response_server/bin/config_updater_util.py +++ b/grr/server/grr_response_server/bin/config_updater_util.py @@ -10,7 +10,7 @@ import subprocess import sys import time -from typing import Optional, Text, Generator +from typing import Generator, Optional from urllib import parse as urlparse import MySQLdb @@ -56,7 +56,8 @@ class ConfigInitError(Exception): def __init__(self): super().__init__( "Aborting config initialization. Please run 'grr_config_updater " - "initialize' to retry initialization.") + "initialize' to retry initialization." + ) class BinaryTooLargeError(Exception): @@ -76,10 +77,9 @@ def __init__(self, username): def ImportConfig(filename, config): """Reads an old config file and imports keys and user accounts.""" - sections_to_import = ["PrivateKeys"] entries_to_import = [ - "Client.executable_signing_public_key", "CA.certificate", - "Frontend.certificate" + "Client.executable_signing_public_key", + "PrivateKeys.executable_signing_private_key", ] options_imported = 0 old_config = grr_config.CONFIG.MakeNewConfig() @@ -87,8 +87,7 @@ def ImportConfig(filename, config): for entry in old_config.raw_data: try: - section = entry.split(".")[0] - if section in sections_to_import or entry in entries_to_import: + if entry in entries_to_import: config.Set(entry, old_config.Get(entry)) print("Imported %s." 
% entry) options_imported += 1 @@ -118,12 +117,17 @@ def RetryQuestion(question_text, output_re="", default_val=None): def RetryBoolQuestion(question_text, default_bool): if not isinstance(default_bool, bool): - raise ValueError("default_bool should be a boolean, not %s" % - type(default_bool)) + raise ValueError( + "default_bool should be a boolean, not %s" % type(default_bool) + ) default_val = "Y" if default_bool else "N" prompt_suff = "[Yn]" if default_bool else "[yN]" - return RetryQuestion("%s %s: " % (question_text, prompt_suff), "[yY]|[nN]", - default_val)[0].upper() == "Y" + return ( + RetryQuestion( + "%s %s: " % (question_text, prompt_suff), "[yY]|[nN]", default_val + )[0].upper() + == "Y" + ) def RetryIntQuestion(question_text: str, default_int: int) -> int: @@ -137,7 +141,7 @@ def GetPassword(question_text: str) -> str: # pytype: enable=wrong-arg-types -def ConfigureHostnames(config, external_hostname: Optional[Text] = None): +def ConfigureHostnames(config, external_hostname: Optional[str] = None): """This configures the hostnames stored in the config.""" if not external_hostname: try: @@ -146,30 +150,36 @@ def ConfigureHostnames(config, external_hostname: Optional[Text] = None): print("Sorry, we couldn't guess your hostname.\n") external_hostname = RetryQuestion( - "Please enter your hostname e.g. " - "grr.example.com", "^[\\.A-Za-z0-9-]+$", external_hostname) + "Please enter your hostname e.g. grr.example.com", + "^[\\.A-Za-z0-9-]+$", + external_hostname, + ) print("""\n\n-=Server URL=- The Server URL specifies the URL that the clients will connect to communicate with the server. For best results this should be publicly accessible. By default this will be port 8080 with the URL ending in /control. """) - frontend_url = RetryQuestion("Frontend URL", "^http://.*/$", - "http://%s:8080/" % external_hostname) + frontend_url = RetryQuestion( + "Frontend URL", "^http://.*/$", "http://%s:8080/" % external_hostname + ) config.Set("Client.server_urls", [frontend_url]) frontend_port = urlparse.urlparse(frontend_url).port or grr_config.CONFIG.Get( - "Frontend.bind_port") + "Frontend.bind_port" + ) config.Set("Frontend.bind_port", frontend_port) print("""\n\n-=AdminUI URL=-: The UI URL specifies where the Administrative Web Interface can be found. """) - ui_url = RetryQuestion("AdminUI URL", "^http[s]*://.*$", - "http://%s:8000" % external_hostname) + ui_url = RetryQuestion( + "AdminUI URL", "^http[s]*://.*$", "http://%s:8000" % external_hostname + ) config.Set("AdminUI.url", ui_url) ui_port = urlparse.urlparse(ui_url).port or grr_config.CONFIG.Get( - "AdminUI.port") + "AdminUI.port" + ) config.Set("AdminUI.port", ui_port) @@ -190,7 +200,8 @@ def CheckMySQLConnection(db_options): db=db_options["Mysql.database_name"], user=db_options["Mysql.database_username"], passwd=db_options["Mysql.database_password"], - charset="utf8") + charset="utf8", + ) if "Mysql.port" in db_options: connection_options["port"] = db_options["Mysql.port"] @@ -223,16 +234,22 @@ def CheckMySQLConnection(db_options): if len(mysql_op_error.args) < 2: # We expect the exception's arguments to be an error-code and # an error message. - print("Unexpected exception type received from MySQL. %d attempts " - "left: %s" % (tries_left, mysql_op_error)) + print( + "Unexpected exception type received from MySQL. %d attempts " + "left: %s" % (tries_left, mysql_op_error) + ) time.sleep(_MYSQL_RETRY_WAIT_SECS) continue if mysql_op_error.args[0] == mysql_conn_errors.CONNECTION_ERROR: - print("Failed to connect to MySQL. 
Is it running? %d attempts left." % - tries_left) + print( + "Failed to connect to MySQL. Is it running? %d attempts left." + % tries_left + ) elif mysql_op_error.args[0] == mysql_conn_errors.UNKNOWN_HOST: - print("Unknown-hostname error encountered while trying to connect to " - "MySQL.") + print( + "Unknown-hostname error encountered while trying to connect to " + "MySQL." + ) return False # No need for retry. elif mysql_op_error.args[0] == general_mysql_errors.BAD_DB_ERROR: # GRR db doesn't exist yet. That's expected if this is the initial @@ -240,17 +257,24 @@ def CheckMySQLConnection(db_options): return True elif mysql_op_error.args[0] in ( general_mysql_errors.ACCESS_DENIED_ERROR, - general_mysql_errors.DBACCESS_DENIED_ERROR): - print("Permission error encountered while trying to connect to " - "MySQL: %s" % mysql_op_error) + general_mysql_errors.DBACCESS_DENIED_ERROR, + ): + print( + "Permission error encountered while trying to connect to MySQL: %s" + % mysql_op_error + ) return False # No need for retry. else: - print("Unexpected operational error encountered while trying to " - "connect to MySQL. %d attempts left: %s" % - (tries_left, mysql_op_error)) + print( + "Unexpected operational error encountered while trying to " + "connect to MySQL. %d attempts left: %s" + % (tries_left, mysql_op_error) + ) except MySQLdb.Error as mysql_error: - print("Unexpected error encountered while trying to connect to MySQL. " - "%d attempts left: %s" % (tries_left, mysql_error)) + print( + "Unexpected error encountered while trying to connect to MySQL. " + "%d attempts left: %s" % (tries_left, mysql_error) + ) time.sleep(_MYSQL_RETRY_WAIT_SECS) return False @@ -265,34 +289,42 @@ def ConfigureMySQLDatastore(config): print("GRR will use MySQL as its database backend. 
Enter connection details:") datastore_init_complete = False while not datastore_init_complete: - db_options["Mysql.host"] = RetryQuestion("MySQL Host", "^[\\.A-Za-z0-9-]+$", - config["Mysql.host"]) + db_options["Mysql.host"] = RetryQuestion( + "MySQL Host", "^[\\.A-Za-z0-9-]+$", config["Mysql.host"] + ) db_options["Mysql.port"] = int( - RetryQuestion("MySQL Port (0 for local socket)", "^[0-9]+$", - config["Mysql.port"])) - db_options["Mysql.database"] = RetryQuestion("MySQL Database", - "^[A-Za-z0-9-]+$", - config["Mysql.database_name"]) + RetryQuestion( + "MySQL Port (0 for local socket)", "^[0-9]+$", config["Mysql.port"] + ) + ) + db_options["Mysql.database"] = RetryQuestion( + "MySQL Database", "^[A-Za-z0-9-]+$", config["Mysql.database_name"] + ) db_options["Mysql.database_name"] = db_options["Mysql.database"] db_options["Mysql.username"] = RetryQuestion( - "MySQL Username", "[A-Za-z0-9-@]+$", config["Mysql.database_username"]) + "MySQL Username", "[A-Za-z0-9-@]+$", config["Mysql.database_username"] + ) db_options["Mysql.database_username"] = db_options["Mysql.username"] db_options["Mysql.password"] = GetPassword( - "Please enter password for database user %s: " % - db_options["Mysql.username"]) + "Please enter password for database user %s: " + % db_options["Mysql.username"] + ) db_options["Mysql.database_password"] = db_options["Mysql.password"] use_ssl = RetryBoolQuestion("Configure SSL connections for MySQL?", False) if use_ssl: db_options["Mysql.client_key_path"] = RetryQuestion( "Path to the client private key file", - default_val=config["Mysql.client_key_path"]) + default_val=config["Mysql.client_key_path"], + ) db_options["Mysql.client_cert_path"] = RetryQuestion( "Path to the client certificate file", - default_val=config["Mysql.client_cert_path"]) + default_val=config["Mysql.client_cert_path"], + ) db_options["Mysql.ca_cert_path"] = RetryQuestion( "Path to the CA certificate file", - default_val=config["Mysql.ca_cert_path"]) + default_val=config["Mysql.ca_cert_path"], + ) if CheckMySQLConnection(db_options): print("Successfully connected to MySQL with the provided details.") @@ -301,7 +333,9 @@ def ConfigureMySQLDatastore(config): print("Error: Could not connect to MySQL with the provided details.") should_retry = RetryBoolQuestion( "Re-enter MySQL details? Answering 'no' will abort config " - "initialization: ", True) + "initialization: ", + True, + ) if should_retry: db_options.clear() else: @@ -328,22 +362,26 @@ def __init__(self): self.mysql_database: str = None self.mysql_unix_socket: str = None self.config_dir = package.ResourcePath( - "fleetspeak-server-bin", "fleetspeak-server-bin/etc/fleetspeak-server") + "fleetspeak-server-bin", "fleetspeak-server-bin/etc/fleetspeak-server" + ) self._fleetspeak_config_command_path = package.ResourcePath( "fleetspeak-server-bin", - "fleetspeak-server-bin/usr/bin/fleetspeak-config") + "fleetspeak-server-bin/usr/bin/fleetspeak-config", + ) def Prompt(self, config): """Sets up the in-memory configuration interactively.""" if self._IsFleetspeakPresent(): self.use_fleetspeak = RetryBoolQuestion( - "Use Fleetspeak (next generation communication " - "framework)?", True) + "Use Fleetspeak (next generation communication framework)?", True + ) else: self.use_fleetspeak = False - print("Fleetspeak (next generation " - "communication framework) seems to be missing.") + print( + "Fleetspeak (next generation " + "communication framework) seems to be missing." 
+ ) print("Skipping Fleetspeak configuration.\n") if self.use_fleetspeak: @@ -354,11 +392,14 @@ def Prompt(self, config): print("Sorry, we couldn't guess your hostname.\n") self.external_hostname = RetryQuestion( - "Please enter your hostname e.g. " - "grr.example.com", "^[\\.A-Za-z0-9-]+$", self.external_hostname) + "Please enter your hostname e.g. grr.example.com", + "^[\\.A-Za-z0-9-]+$", + self.external_hostname, + ) - self.https_port = RetryIntQuestion("Fleetspeak public HTTPS port", - self.https_port) + self.https_port = RetryIntQuestion( + "Fleetspeak public HTTPS port", self.https_port + ) self._PromptMySQL(config) @@ -376,8 +417,9 @@ def RotateKey(self): os.rename(self._ConfigPath(cert_file), self._ConfigPath(old_file)) # Run fleetspeak-config to regenerate them subprocess.check_call([ - self._fleetspeak_config_command_path, "-config", - self._ConfigPath("fleetspeak_config.config") + self._fleetspeak_config_command_path, + "-config", + self._ConfigPath("fleetspeak_config.config"), ]) def _ConfigPath(self, *path_components: str) -> str: @@ -393,28 +435,40 @@ def _IsFleetspeakPresent(self) -> bool: def _PromptMySQLOnce(self, config): """Prompt the MySQL configuration once.""" - self.mysql_host = RetryQuestion("Fleetspeak MySQL Host", - "^[\\.A-Za-z0-9-]+$", self.mysql_host or - config["Mysql.host"]) - self.mysql_port = RetryIntQuestion( - "Fleetspeak MySQL Port (0 for local socket)", self.mysql_port or - 0) or None + self.mysql_host = RetryQuestion( + "Fleetspeak MySQL Host", + "^[\\.A-Za-z0-9-]+$", + self.mysql_host or config["Mysql.host"], + ) + self.mysql_port = ( + RetryIntQuestion( + "Fleetspeak MySQL Port (0 for local socket)", self.mysql_port or 0 + ) + or None + ) if self.mysql_port is None: # golang's mysql connector needs the socket specified explicitly. 
self.mysql_unix_socket = RetryQuestion( - "Fleetspeak MySQL local socket path", ".+", - self._FindMysqlUnixSocket() or "") - - self.mysql_database = RetryQuestion("Fleetspeak MySQL Database", - "^[A-Za-z0-9-]+$", - self.mysql_database or "fleetspeak") + "Fleetspeak MySQL local socket path", + ".+", + self._FindMysqlUnixSocket() or "", + ) + + self.mysql_database = RetryQuestion( + "Fleetspeak MySQL Database", + "^[A-Za-z0-9-]+$", + self.mysql_database or "fleetspeak", + ) self.mysql_username = RetryQuestion( - "Fleetspeak MySQL Username", "[A-Za-z0-9-@]+$", self.mysql_username or - config["Mysql.database_username"]) + "Fleetspeak MySQL Username", + "[A-Za-z0-9-@]+$", + self.mysql_username or config["Mysql.database_username"], + ) self.mysql_password = GetPassword( - f"Please enter password for database user {self.mysql_username}: ") + f"Please enter password for database user {self.mysql_username}: " + ) def _PromptMySQL(self, config): """Prompts the MySQL configuration, retrying if the configuration is invalid.""" @@ -425,8 +479,9 @@ def _PromptMySQL(self, config): return else: print("Error: Could not connect to MySQL with the given configuration.") - retry = RetryBoolQuestion("Do you want to retry MySQL configuration?", - True) + retry = RetryBoolQuestion( + "Do you want to retry MySQL configuration?", True + ) if not retry: raise ConfigInitError() @@ -439,15 +494,18 @@ def _WriteDisabled(self, config): if self._IsFleetspeakPresent(): with open(self._ConfigPath("disabled"), "w") as f: - f.write("The existence of this file disables the " - "fleetspeak-server.service systemd unit.\n") + f.write( + "The existence of this file disables the " + "fleetspeak-server.service systemd unit.\n" + ) def _WriteEnabled(self, config): """Applies the in-memory configuration for the use_fleetspeak case.""" service_config = services_pb2.ServiceConfig(name="GRR", factory="GRPC") grpc_config = grpcservice_pb2.Config( - target="localhost:{}".format(self.grr_port), insecure=True) + target="localhost:{}".format(self.grr_port), insecure=True + ) service_config.config.Pack(grpc_config) server_conf = server_pb2.ServerConfig(services=[service_config]) server_conf.broadcast_poll_time.seconds = 1 @@ -463,7 +521,9 @@ def _WriteEnabled(self, config): user=self.mysql_username, password=self.mysql_password, socket=self.mysql_unix_socket, - db=self.mysql_database)) + db=self.mysql_database, + ) + ) else: cp.components_config.mysql_data_source_name = ( "{user}:{password}@tcp({host}:{port})/{db}".format( @@ -471,31 +531,39 @@ def _WriteEnabled(self, config): password=self.mysql_password, host=self.mysql_host, port=self.mysql_port, - db=self.mysql_database)) + db=self.mysql_database, + ) + ) cp.components_config.https_config.listen_address = "{}:{}".format( - self.external_hostname, self.https_port) + self.external_hostname, self.https_port + ) cp.components_config.https_config.disable_streaming = False cp.components_config.admin_config.listen_address = "localhost:{}".format( - self.admin_port) + self.admin_port + ) cp.public_host_port.append(cp.components_config.https_config.listen_address) cp.server_component_configuration_file = self._ConfigPath( - "server.components.config") + "server.components.config" + ) cp.trusted_cert_file = self._ConfigPath("trusted_cert.pem") cp.trusted_cert_key_file = self._ConfigPath("trusted_cert_key.pem") cp.server_cert_file = self._ConfigPath("server_cert.pem") cp.server_cert_key_file = self._ConfigPath("server_cert_key.pem") cp.linux_client_configuration_file = 
self._ConfigPath("linux_client.config") cp.windows_client_configuration_file = self._ConfigPath( - "windows_client.config") + "windows_client.config" + ) cp.darwin_client_configuration_file = self._ConfigPath( - "darwin_client.config") + "darwin_client.config" + ) with open(self._ConfigPath("fleetspeak_config.config"), "w") as f: f.write(text_format.MessageToString(cp)) subprocess.check_call([ - self._fleetspeak_config_command_path, "-config", - self._ConfigPath("fleetspeak_config.config") + self._fleetspeak_config_command_path, + "-config", + self._ConfigPath("fleetspeak_config.config"), ]) # These modules don't exist on Windows, so importing locally. @@ -504,11 +572,15 @@ def _WriteEnabled(self, config): import pwd # pylint: enable=g-import-not-at-top - if (os.geteuid() == 0 and pwd.getpwnam("fleetspeak") and - grp.getgrnam("fleetspeak") and - os.path.exists("/etc/fleetspeak-server")): + if ( + os.geteuid() == 0 + and pwd.getpwnam("fleetspeak") + and grp.getgrnam("fleetspeak") + and os.path.exists("/etc/fleetspeak-server") + ): subprocess.check_call( - ["chown", "-R", "fleetspeak:fleetspeak", "/etc/fleetspeak-server"]) + ["chown", "-R", "fleetspeak:fleetspeak", "/etc/fleetspeak-server"] + ) try: os.unlink(self._ConfigPath("disabled")) @@ -519,24 +591,37 @@ def _WriteEnabled(self, config): config.Set("Client.fleetspeak_enabled", True) config.Set("ClientBuilder.fleetspeak_bundled", True) config.Set( - "Target:Linux", { - "ClientBuilder.fleetspeak_client_config": + "Target:Linux", + { + "ClientBuilder.fleetspeak_client_config": ( cp.linux_client_configuration_file - }) + ) + }, + ) config.Set( - "Target:Windows", { - "ClientBuilder.fleetspeak_client_config": + "Target:Windows", + { + "ClientBuilder.fleetspeak_client_config": ( cp.windows_client_configuration_file - }) + ) + }, + ) config.Set( - "Target:Darwin", { - "ClientBuilder.fleetspeak_client_config": + "Target:Darwin", + { + "ClientBuilder.fleetspeak_client_config": ( cp.darwin_client_configuration_file - }) - config.Set("Server.fleetspeak_server", - cp.components_config.admin_config.listen_address) - config.Set("FleetspeakFrontend Context", - {"Server.fleetspeak_message_listen_address": grpc_config.target}) + ) + }, + ) + config.Set( + "Server.fleetspeak_server", + cp.components_config.admin_config.listen_address, + ) + config.Set( + "FleetspeakFrontend Context", + {"Server.fleetspeak_message_listen_address": grpc_config.target}, + ) def _CheckMySQLConnection(self): """Checks the MySQL configuration by attempting a connection.""" @@ -588,36 +673,43 @@ def _FindMysqlUnixSocket(self) -> Optional[str]: def ConfigureDatastore(config): """Guides the user through configuration of the datastore.""" - print("\n\n-=GRR Datastore=-\n" - "For GRR to work each GRR server has to be able to communicate with\n" - "the datastore. To do this we need to configure a datastore.\n") + print( + "\n\n-=GRR Datastore=-\n" + "For GRR to work each GRR server has to be able to communicate with\n" + "the datastore. To do this we need to configure a datastore.\n" + ) ConfigureMySQLDatastore(config) -def ConfigureUrls(config, external_hostname: Optional[Text] = None): +def ConfigureUrls(config, external_hostname: Optional[str] = None): """Guides the user through configuration of various URLs used by GRR.""" - print("\n\n-=GRR URLs=-\n" - "For GRR to work each client has to be able to communicate with the\n" - "server. To do this we normally need a public dns name or IP address\n" - "to communicate with. 
In the standard configuration this will be used\n" - "to host both the client facing server and the admin user interface.\n") + print( + "\n\n-=GRR URLs=-\n" + "For GRR to work each client has to be able to communicate with the\n" + "server. To do this we normally need a public dns name or IP address\n" + "to communicate with. In the standard configuration this will be used\n" + "to host both the client facing server and the admin user interface.\n" + ) existing_ui_urn = grr_config.CONFIG.Get("AdminUI.url", default=None) existing_frontend_urns = grr_config.CONFIG.Get("Client.server_urls") if not existing_frontend_urns: # Port from older deprecated setting Client.control_urls. existing_control_urns = grr_config.CONFIG.Get( - "Client.control_urls", default=None) + "Client.control_urls", default=None + ) if existing_control_urns is not None: existing_frontend_urns = [] for existing_control_urn in existing_control_urns: if not existing_control_urn.endswith("control"): - raise RuntimeError("Invalid existing control URL: %s" % - existing_control_urn) + raise RuntimeError( + "Invalid existing control URL: %s" % existing_control_urn + ) existing_frontend_urns.append( - existing_control_urn.rsplit("/", 1)[0] + "/") + existing_control_urn.rsplit("/", 1)[0] + "/" + ) config.Set("Client.server_urls", existing_frontend_urns) config.Set("Client.control_urls", ["deprecated use Client.server_urls"]) @@ -625,50 +717,67 @@ def ConfigureUrls(config, external_hostname: Optional[Text] = None): if not existing_frontend_urns or not existing_ui_urn: ConfigureHostnames(config, external_hostname=external_hostname) else: - print("Found existing settings:\n AdminUI URL: %s\n " - "Frontend URL(s): %s\n" % (existing_ui_urn, existing_frontend_urns)) + print( + "Found existing settings:\n AdminUI URL: %s\n Frontend URL(s): %s\n" + % (existing_ui_urn, existing_frontend_urns) + ) if not RetryBoolQuestion("Do you want to keep this configuration?", True): ConfigureHostnames(config, external_hostname=external_hostname) def ConfigureEmails(config): """Guides the user through email setup.""" - print("\n\n-=GRR Emails=-\n" - "GRR needs to be able to send emails for various logging and\n" - "alerting functions. The email domain will be appended to GRR\n" - "usernames when sending emails to users.\n") + print( + "\n\n-=GRR Emails=-\n" + "GRR needs to be able to send emails for various logging and\n" + "alerting functions. 
The email domain will be appended to GRR\n" + "usernames when sending emails to users.\n" + ) existing_log_domain = grr_config.CONFIG.Get("Logging.domain", default=None) existing_al_email = grr_config.CONFIG.Get( - "Monitoring.alert_email", default=None) + "Monitoring.alert_email", default=None + ) existing_em_email = grr_config.CONFIG.Get( - "Monitoring.emergency_access_email", default=None) + "Monitoring.emergency_access_email", default=None + ) if existing_log_domain and existing_al_email and existing_em_email: - print("Found existing settings:\n" - " Email Domain: %s\n Alert Email Address: %s\n" - " Emergency Access Email Address: %s\n" % - (existing_log_domain, existing_al_email, existing_em_email)) + print( + "Found existing settings:\n" + " Email Domain: %s\n Alert Email Address: %s\n" + " Emergency Access Email Address: %s\n" + % (existing_log_domain, existing_al_email, existing_em_email) + ) if RetryBoolQuestion("Do you want to keep this configuration?", True): return - print("\n\n-=Monitoring/Email Domain=-\n" - "Emails concerning alerts or updates must be sent to this domain.\n") - domain = RetryQuestion("Email Domain e.g example.com", - "^([\\.A-Za-z0-9-]+)*$", - grr_config.CONFIG.Get("Logging.domain")) + print( + "\n\n-=Monitoring/Email Domain=-\n" + "Emails concerning alerts or updates must be sent to this domain.\n" + ) + domain = RetryQuestion( + "Email Domain e.g. example.com", + "^([\\.A-Za-z0-9-]+)*$", + grr_config.CONFIG.Get("Logging.domain"), + ) config.Set("Logging.domain", domain) - print("\n\n-=Alert Email Address=-\n" - "Address where monitoring events get sent, e.g. crashed clients, \n" - "broken server, etc.\n") + print( + "\n\n-=Alert Email Address=-\n" + "Address where monitoring events get sent, e.g. crashed clients, \n" + "broken server, etc.\n" + ) email = RetryQuestion("Alert Email Address", "", "grr-monitoring@%s" % domain) config.Set("Monitoring.alert_email", email) - print("\n\n-=Emergency Email Address=-\n" - "Address where high priority events such as an emergency ACL bypass " - "are sent.\n") - emergency_email = RetryQuestion("Emergency Access Email Address", "", - "grr-emergency@%s" % domain) + print( + "\n\n-=Emergency Email Address=-\n" + "Address where high priority events such as an emergency ACL bypass " + "are sent.\n" + ) + emergency_email = RetryQuestion( + "Emergency Access Email Address", "", "grr-emergency@%s" % domain + ) config.Set("Monitoring.emergency_access_email", emergency_email) @@ -680,21 +789,29 @@ def InstallTemplatePackage(): # Install the GRR server component to satisfy the dependency below. 
major_minor_version = ".".join( - pkg_resources.get_distribution("grr-response-core").version.split(".") - [0:2]) + pkg_resources.get_distribution("grr-response-core").version.split(".")[ + 0:2 + ] + ) # Note that this version spec requires a recent version of pip subprocess.check_call([ - sys.executable, pip, "install", "--upgrade", "-f", + sys.executable, + pip, + "install", + "--upgrade", + "-f", "https://storage.googleapis.com/releases.grr-response.com/index.html", - "grr-response-templates==%s.*" % major_minor_version + "grr-response-templates==%s.*" % major_minor_version, ]) -def FinalizeConfigInit(config, - admin_password: Optional[Text] = None, - redownload_templates: bool = False, - repack_templates: bool = True, - prompt: bool = True): +def FinalizeConfigInit( + config, + admin_password: Optional[str] = None, + redownload_templates: bool = False, + repack_templates: bool = True, + prompt: bool = True, +): """Performs the final steps of config initialization.""" config.Set("Server.initialized", True) print("\nWriting configuration to %s." % config["Config.writeback"]) @@ -709,8 +826,13 @@ def FinalizeConfigInit(config, except UserAlreadyExistsError: if prompt: # pytype: disable=wrong-arg-count - if ((input("User 'admin' already exists, do you want to " - "reset the password? [yN]: ").upper() or "N") == "Y"): + if ( + input( + "User 'admin' already exists, do you want to " + "reset the password? [yN]: " + ).upper() + or "N" + ) == "Y": UpdateUser("admin", password=admin_password, is_admin=True) # pytype: enable=wrong-arg-count else: @@ -719,19 +841,23 @@ def FinalizeConfigInit(config, print("\nStep 4: Repackaging clients with new configuration.") if prompt: redownload_templates = RetryBoolQuestion( - "Server debs include client templates. Re-download templates?", False) + "Server debs include client templates. Re-download templates?", False + ) repack_templates = RetryBoolQuestion("Repack client templates?", True) if redownload_templates: InstallTemplatePackage() # Build debug binaries, then build release binaries. if repack_templates: repacking.TemplateRepacker().RepackAllTemplates(upload=True) - print("\nGRR Initialization complete! You can edit the new configuration " - "in %s.\n" % config["Config.writeback"]) + print( + "\nGRR Initialization complete! 
You can edit the new configuration " + "in %s.\n" + % config["Config.writeback"] + ) if prompt and os.geteuid() == 0: restart = RetryBoolQuestion( - "Restart service for the new configuration " - "to take effect?", True) + "Restart service for the new configuration to take effect?", True + ) if restart: for service in ("grr-server", "fleetspeak-server"): try: @@ -741,15 +867,18 @@ def FinalizeConfigInit(config, print(f"Failed to restart: {service}.") print(e, file=sys.stderr) else: - print("Please restart the service for the new configuration to take " - "effect.\n") + print( + "Please restart the service for the new configuration to take effect.\n" + ) -def Initialize(config=None, - external_hostname: Optional[Text] = None, - admin_password: Optional[Text] = None, - redownload_templates: bool = False, - repack_templates: bool = True): +def Initialize( + config=None, + external_hostname: Optional[str] = None, + admin_password: Optional[str] = None, + redownload_templates: bool = False, + repack_templates: bool = True, +): """Initialize or update a GRR configuration.""" print("Checking write access on config %s" % config["Config.writeback"]) @@ -762,8 +891,7 @@ def Initialize(config=None, if prev_config_file and os.access(prev_config_file, os.R_OK): print("Found config file %s." % prev_config_file) # pytype: disable=wrong-arg-count - if input("Do you want to import this configuration? " - "[yN]: ").upper() == "Y": + if input("Do you want to import this configuration? [yN]: ").upper() == "Y": options_imported = ImportConfig(prev_config_file, config) # pytype: enable=wrong-arg-count else: @@ -778,13 +906,20 @@ def Initialize(config=None, ConfigureEmails(config) print("\nStep 2: Key Generation") - if config.Get("PrivateKeys.server_key", default=None): + if config.Get("PrivateKeys.executable_signing_private_key", default=None): if options_imported > 0: - print("Since you have imported keys from another installation in the " - "last step,\nyou probably do not want to generate new keys now.") + print( + "Since you have imported keys from another installation in the " + "last step,\nyou probably do not want to generate new keys now." + ) # pytype: disable=wrong-arg-count - if (input("You already have keys in your config, do you want to" - " overwrite them? [yN]: ").upper() or "N") == "Y": + if ( + input( + "You already have keys in your config, do you want to" + " overwrite them? 
[yN]: " + ).upper() + or "N" + ) == "Y": config_updater_keys_util.GenerateKeys(config, overwrite_keys=True) # pytype: enable=wrong-arg-count else: @@ -796,25 +931,26 @@ def Initialize(config=None, admin_password=admin_password, redownload_templates=redownload_templates, repack_templates=repack_templates, - prompt=True) + prompt=True, + ) def InitializeNoPrompt( config=None, - external_hostname: Optional[Text] = None, - admin_password: Optional[Text] = None, - mysql_hostname: Optional[Text] = None, + external_hostname: Optional[str] = None, + admin_password: Optional[str] = None, + mysql_hostname: Optional[str] = None, mysql_port: Optional[int] = None, - mysql_username: Optional[Text] = None, - mysql_password: Optional[Text] = None, - mysql_db: Optional[Text] = None, - mysql_client_key_path: Optional[Text] = None, - mysql_client_cert_path: Optional[Text] = None, - mysql_ca_cert_path: Optional[Text] = None, + mysql_username: Optional[str] = None, + mysql_password: Optional[str] = None, + mysql_db: Optional[str] = None, + mysql_client_key_path: Optional[str] = None, + mysql_client_cert_path: Optional[str] = None, + mysql_ca_cert_path: Optional[str] = None, redownload_templates: bool = False, repack_templates: bool = True, use_fleetspeak: bool = False, - mysql_fleetspeak_db: Optional[Text] = None, + mysql_fleetspeak_db: Optional[str] = None, ): """Initialize GRR with no prompts. @@ -850,7 +986,8 @@ def InitializeNoPrompt( raise ValueError("Config has already been initialized.") if not external_hostname: raise ValueError( - "--noprompt set, but --external_hostname was not provided.") + "--noprompt set, but --external_hostname was not provided." + ) if not admin_password: raise ValueError("--noprompt set, but --admin_password was not provided.") if mysql_password is None: @@ -866,24 +1003,31 @@ def InitializeNoPrompt( config_dict["Mysql.host"] = mysql_hostname or config["Mysql.host"] config_dict["Mysql.port"] = mysql_port or config["Mysql.port"] - config_dict["Mysql.database_name"] = config_dict[ - "Mysql.database"] = mysql_db or config["Mysql.database_name"] + config_dict["Mysql.database_name"] = config_dict["Mysql.database"] = ( + mysql_db or config["Mysql.database_name"] + ) config_dict["Mysql.database_username"] = config_dict["Mysql.username"] = ( - mysql_username or config["Mysql.database_username"]) + mysql_username or config["Mysql.database_username"] + ) config_dict["Client.server_urls"] = [ "http://%s:%s/" % (external_hostname, config["Frontend.bind_port"]) ] - config_dict["AdminUI.url"] = "http://%s:%s" % (external_hostname, - config["AdminUI.port"]) + config_dict["AdminUI.url"] = "http://%s:%s" % ( + external_hostname, + config["AdminUI.port"], + ) config_dict["Logging.domain"] = external_hostname - config_dict["Monitoring.alert_email"] = ("grr-monitoring@%s" % - external_hostname) - config_dict["Monitoring.emergency_access_email"] = ("grr-emergency@%s" % - external_hostname) + config_dict["Monitoring.alert_email"] = ( + "grr-monitoring@%s" % external_hostname + ) + config_dict["Monitoring.emergency_access_email"] = ( + "grr-emergency@%s" % external_hostname + ) # Print all configuration options, except for the MySQL password. 
print("Setting configuration as:\n\n%s" % config_dict) - config_dict["Mysql.database_password"] = config_dict[ - "Mysql.password"] = mysql_password + config_dict["Mysql.database_password"] = config_dict["Mysql.password"] = ( + mysql_password + ) if mysql_client_key_path is not None: config_dict["Mysql.client_key_path"] = mysql_client_key_path @@ -915,13 +1059,13 @@ def InitializeNoPrompt( admin_password=admin_password, redownload_templates=redownload_templates, repack_templates=repack_templates, - prompt=False) + prompt=False, + ) -def UploadSignedBinary(source_path, - binary_type, - platform, - upload_subdirectory=""): +def UploadSignedBinary( + source_path, binary_type, platform, upload_subdirectory="" +): """Signs a binary and uploads it to the datastore. Args: @@ -939,11 +1083,13 @@ def UploadSignedBinary(source_path, if file_size > _MAX_SIGNED_BINARY_BYTES: raise BinaryTooLargeError( "File [%s] is of size %d (bytes), which exceeds the allowed maximum " - "of %d bytes." % (source_path, file_size, _MAX_SIGNED_BINARY_BYTES)) + "of %d bytes." % (source_path, file_size, _MAX_SIGNED_BINARY_BYTES) + ) context = ["Platform:%s" % platform.title(), "Client Context"] signing_key = grr_config.CONFIG.Get( - "PrivateKeys.executable_signing_private_key", context=context) + "PrivateKeys.executable_signing_private_key", context=context + ) root_api = maintenance_utils.InitGRRRootAPI() binary_path = "/".join([ @@ -957,7 +1103,9 @@ def UploadSignedBinary(source_path, binary.Upload( fd, sign_fn=binary.DefaultUploadSigner( - private_key=signing_key.GetRawPrivateKey())) + private_key=signing_key.GetRawPrivateKey() + ), + ) print("Uploaded %s to %s" % (binary_type, binary_path)) @@ -972,15 +1120,18 @@ def CreateUser(username, password=None, is_admin=False): if user_exists: raise UserAlreadyExistsError("User '%s' already exists." 
% username) user_type, password = _GetUserTypeAndPassword( - username, password=password, is_admin=is_admin) + username, password=password, is_admin=is_admin + ) grr_api.CreateGrrUser( - username=username, user_type=user_type, password=password) + username=username, user_type=user_type, password=password + ) def UpdateUser(username, password=None, is_admin=False): """Updates the password or privilege-level for a user.""" user_type, password = _GetUserTypeAndPassword( - username, password=password, is_admin=is_admin) + username, password=password, is_admin=is_admin + ) grr_api = maintenance_utils.InitGRRRootAPI() grr_user = grr_api.GrrUser(username).Get() grr_user.Modify(user_type=user_type, password=password) @@ -1005,8 +1156,10 @@ def GetAllUserSummaries(): def _Summarize(user_info): """Returns a string with summary info for a user.""" - return "Username: %s\nIs Admin: %s" % (user_info.username, user_info.user_type - == api_root.GrrUser.USER_TYPE_ADMIN) + return "Username: %s\nIs Admin: %s" % ( + user_info.username, + user_info.user_type == api_root.GrrUser.USER_TYPE_ADMIN, + ) def DeleteUser(username): @@ -1038,40 +1191,52 @@ def _GetUserTypeAndPassword(username, password=None, is_admin=False): def SwitchToRelDB(config): """Switches a given config from using AFF4 to using REL_DB.""" - print("***************************************************************\n" - "Make sure to back up the existing configuration writeback file.\n" - "Writeback file path:\n%s\n" - "***************************************************************\n" % - config["Config.writeback"]) + print( + "***************************************************************\n" + "Make sure to back up the existing configuration writeback file.\n" + "Writeback file path:\n%s\n" + "***************************************************************\n" + % config["Config.writeback"] + ) RetryBoolQuestion("Continue?", True) config.Set("Database.implementation", "MysqlDB") - if (config["Blobstore.implementation"] != "DbBlobStore" or RetryBoolQuestion( + if config["Blobstore.implementation"] != "DbBlobStore" or RetryBoolQuestion( "You have a custom 'Blobstore.implementation' setting. Do you want\n" "to switch to DbBlobStore (default option for REL_DB, meaning that\n" - "blobs will be stored inside the MySQL database)?", True)): + "blobs will be stored inside the MySQL database)?", + True, + ): config.Set("Blobstore.implementation", "DbBlobStore") - if (RetryBoolQuestion( + if RetryBoolQuestion( "Do you want to use a different MySQL database for the REL_DB datastore?", - True)): - db_name = RetryQuestion("MySQL Database", "^[A-Za-z0-9-]+$", - config["Mysql.database_name"]) + True, + ): + db_name = RetryQuestion( + "MySQL Database", "^[A-Za-z0-9-]+$", config["Mysql.database_name"] + ) else: db_name = config["Mysql.database_name"] config.Set("Mysql.database", db_name) - if (input("Do you want to use previously set up MySQL username and password\n" - "to connect to MySQL database '%s'? [Yn]: " % db_name).upper() or - "Y") == "Y": + if ( + input( + "Do you want to use previously set up MySQL username and password\n" + "to connect to MySQL database '%s'? 
[Yn]: " % db_name + ).upper() + or "Y" + ) == "Y": username = config["Mysql.database_username"] password = config["Mysql.database_password"] else: - username = RetryQuestion("MySQL Username", "[A-Za-z0-9-@]+$", - config["Mysql.database_username"]) - password = GetPassword("Please enter password for database user %s: " % - username) + username = RetryQuestion( + "MySQL Username", "[A-Za-z0-9-@]+$", config["Mysql.database_username"] + ) + password = GetPassword( + "Please enter password for database user %s: " % username + ) config.Set("Mysql.username", username) config.Set("Mysql.password", password) @@ -1099,8 +1264,9 @@ def ArgparseBool(raw_value): 'True' or 'False'. """ if not isinstance(raw_value, str): - raise argparse.ArgumentTypeError("Unexpected type: %s. Expected a string." % - type(raw_value).__name__) + raise argparse.ArgumentTypeError( + "Unexpected type: %s. Expected a string." % type(raw_value).__name__ + ) if raw_value.lower() == "true": return True @@ -1108,4 +1274,5 @@ def ArgparseBool(raw_value): return False else: raise argparse.ArgumentTypeError( - "Invalid value encountered. Expected 'True' or 'False'.") + "Invalid value encountered. Expected 'True' or 'False'." + ) diff --git a/grr/server/grr_response_server/bin/config_updater_util_test.py b/grr/server/grr_response_server/bin/config_updater_util_test.py index 11fd58e425..7cafa920fe 100644 --- a/grr/server/grr_response_server/bin/config_updater_util_test.py +++ b/grr/server/grr_response_server/bin/config_updater_util_test.py @@ -51,15 +51,19 @@ def testConfigureMySQLDatastore(self, getpass_mock, connect_mock): db="grr-test-db", user="grr-test-user", passwd="grr-test-password", - charset="utf8") + charset="utf8", + ) self.assertEqual(config.writeback_data["Mysql.host"], "localhost") self.assertEqual(config.writeback_data["Mysql.port"], 1234) - self.assertEqual(config.writeback_data["Mysql.database_name"], - "grr-test-db") - self.assertEqual(config.writeback_data["Mysql.database_username"], - "grr-test-user") - self.assertEqual(config.writeback_data["Mysql.database_password"], - "grr-test-password") + self.assertEqual( + config.writeback_data["Mysql.database_name"], "grr-test-db" + ) + self.assertEqual( + config.writeback_data["Mysql.database_username"], "grr-test-user" + ) + self.assertEqual( + config.writeback_data["Mysql.database_password"], "grr-test-password" + ) @mock.patch.object(MySQLdb, "connect") @mock.patch.object(getpass, "getpass") @@ -93,28 +97,36 @@ def testConfigureMySQLDatastoreWithSSL(self, getpass_mock, connect_mock): "key": "key_file_path", "cert": "cert_file_path", "ca": "ca_cert_file_path", - }) + }, + ) self.assertEqual(config.writeback_data["Mysql.host"], "localhost") self.assertEqual(config.writeback_data["Mysql.port"], 1234) - self.assertEqual(config.writeback_data["Mysql.database_name"], - "grr-test-db") - self.assertEqual(config.writeback_data["Mysql.database_username"], - "grr-test-user") - self.assertEqual(config.writeback_data["Mysql.database_password"], - "grr-test-password") - self.assertEqual(config.writeback_data["Mysql.client_key_path"], - "key_file_path") - self.assertEqual(config.writeback_data["Mysql.client_cert_path"], - "cert_file_path") - self.assertEqual(config.writeback_data["Mysql.ca_cert_path"], - "ca_cert_file_path") + self.assertEqual( + config.writeback_data["Mysql.database_name"], "grr-test-db" + ) + self.assertEqual( + config.writeback_data["Mysql.database_username"], "grr-test-user" + ) + self.assertEqual( + config.writeback_data["Mysql.database_password"], 
"grr-test-password" + ) + self.assertEqual( + config.writeback_data["Mysql.client_key_path"], "key_file_path" + ) + self.assertEqual( + config.writeback_data["Mysql.client_cert_path"], "cert_file_path" + ) + self.assertEqual( + config.writeback_data["Mysql.ca_cert_path"], "ca_cert_file_path" + ) @mock.patch.object(MySQLdb, "connect") @mock.patch.object(getpass, "getpass") @mock.patch.object(config_updater_util, "_MYSQL_MAX_RETRIES", new=1) @mock.patch.object(config_updater_util, "_MYSQL_RETRY_WAIT_SECS", new=0.1) - def testConfigureMySQLDatastore_ConnectionRetry(self, getpass_mock, - connect_mock): + def testConfigureMySQLDatastore_ConnectionRetry( + self, getpass_mock, connect_mock + ): # Mock user-inputs for MySQL prompts. self.input_mock.side_effect = [ "Y", # Use REL_DB as the primary data store. @@ -123,11 +135,12 @@ def testConfigureMySQLDatastore_ConnectionRetry(self, getpass_mock, "grr-test-db", # GRR db name. "grr-test-user", # GRR db user. "n", # No SSL. - "n" # Exit config initialization after retries are depleted. + "n", # Exit config initialization after retries are depleted. ] getpass_mock.return_value = "grr-test-password" # DB password for GRR. connect_mock.side_effect = MySQLdb.OperationalError( - mysql_conn_errors.CONNECTION_ERROR, "Fake connection error.") + mysql_conn_errors.CONNECTION_ERROR, "Fake connection error." + ) config = grr_config.CONFIG.CopyConfig() with self.assertRaises(config_updater_util.ConfigInitError): config_updater_util.ConfigureMySQLDatastore(config) @@ -142,13 +155,17 @@ def testUploadPythonHack(self): python_hack_path, config_pb2.ApiGrrBinary.Type.PYTHON_HACK, "linux", - upload_subdirectory="test") + upload_subdirectory="test", + ) python_hack_urn = rdfvalue.RDFURN( - "aff4:/config/python_hacks/linux/test/hello_world.py") + "aff4:/config/python_hacks/linux/test/hello_world.py" + ) blob_iterator, _ = signed_binary_utils.FetchBlobsForSignedBinaryByURN( - python_hack_urn) + python_hack_urn + ) uploaded_blobs = list( - signed_binary_utils.StreamSignedBinaryContents(blob_iterator)) + signed_binary_utils.StreamSignedBinaryContents(blob_iterator) + ) uploaded_content = b"".join(uploaded_blobs) self.assertEqual(uploaded_content, b"print('Hello, world!')") @@ -161,14 +178,17 @@ def testUploadExecutable(self): executable_path, config_pb2.ApiGrrBinary.Type.EXECUTABLE, "windows", - upload_subdirectory="anti-malware/registry-tools") + upload_subdirectory="anti-malware/registry-tools", + ) executable_urn = rdfvalue.RDFURN( - "aff4:/config/executables/windows/anti-malware/registry-tools/" - "foo.exe") + "aff4:/config/executables/windows/anti-malware/registry-tools/foo.exe" + ) blob_iterator, _ = signed_binary_utils.FetchBlobsForSignedBinaryByURN( - executable_urn) + executable_urn + ) uploaded_blobs = list( - signed_binary_utils.StreamSignedBinaryContents(blob_iterator)) + signed_binary_utils.StreamSignedBinaryContents(blob_iterator) + ) uploaded_content = b"".join(uploaded_blobs) self.assertEqual(uploaded_content, b"\xaa\xbb\xcc\xdd") @@ -180,12 +200,16 @@ def testUploadOverlyLargeSignedBinary(self): f.write(b"\xaa\xbb\xcc\xdd\xee\xff") expected_message = ( "File [%s] is of size 6 (bytes), which exceeds the allowed maximum " - "of 5 bytes." % executable_path) + "of 5 bytes." 
% executable_path + ) with self.assertRaisesWithLiteralMatch( - config_updater_util.BinaryTooLargeError, expected_message): + config_updater_util.BinaryTooLargeError, expected_message + ): config_updater_util.UploadSignedBinary( - executable_path, config_pb2.ApiGrrBinary.Type.EXECUTABLE, - "windows") + executable_path, + config_pb2.ApiGrrBinary.Type.EXECUTABLE, + "windows", + ) @mock.patch.object(getpass, "getpass") def testCreateAdminUser(self, getpass_mock): @@ -195,7 +219,8 @@ def testCreateAdminUser(self, getpass_mock): def testCreateStandardUser(self): config_updater_util.CreateUser( - "foo_user", password="foo_password", is_admin=False) + "foo_user", password="foo_password", is_admin=False + ) self._AssertStoredUserDetailsAre("foo_user", "foo_password", False) def testCreateAlreadyExistingUser(self): @@ -205,32 +230,42 @@ def testCreateAlreadyExistingUser(self): def testUpdateUser(self): config_updater_util.CreateUser( - "foo_user", password="foo_password1", is_admin=False) + "foo_user", password="foo_password1", is_admin=False + ) self._AssertStoredUserDetailsAre("foo_user", "foo_password1", False) config_updater_util.UpdateUser( - "foo_user", password="foo_password2", is_admin=True) + "foo_user", password="foo_password2", is_admin=True + ) self._AssertStoredUserDetailsAre("foo_user", "foo_password2", True) def testGetUserSummary(self): config_updater_util.CreateUser( - "foo_user", password="foo_password", is_admin=False) + "foo_user", password="foo_password", is_admin=False + ) self.assertMultiLineEqual( config_updater_util.GetUserSummary("foo_user"), - "Username: foo_user\nIs Admin: False") + "Username: foo_user\nIs Admin: False", + ) def testGetAllUserSummaries(self): config_updater_util.CreateUser( - "foo_user1", password="foo_password1", is_admin=False) + "foo_user1", password="foo_password1", is_admin=False + ) config_updater_util.CreateUser( - "foo_user2", password="foo_password2", is_admin=True) - expected_summaries = ("Username: foo_user1\nIs Admin: False\n\n" - "Username: foo_user2\nIs Admin: True") - self.assertMultiLineEqual(config_updater_util.GetAllUserSummaries(), - expected_summaries) + "foo_user2", password="foo_password2", is_admin=True + ) + expected_summaries = ( + "Username: foo_user1\nIs Admin: False\n\n" + "Username: foo_user2\nIs Admin: True" + ) + self.assertMultiLineEqual( + config_updater_util.GetAllUserSummaries(), expected_summaries + ) def testDeleteUser(self): config_updater_util.CreateUser( - "foo_user", password="foo_password", is_admin=False) + "foo_user", password="foo_password", is_admin=False + ) self.assertNotEmpty(config_updater_util.GetUserSummary("foo_user")) config_updater_util.DeleteUser("foo_user") with self.assertRaises(config_updater_util.UserNotFoundError): @@ -258,23 +293,27 @@ def testArgparseBool_CaseInsensitive(self): def testArgparseBool_DefaultValue(self): parser = argparse.ArgumentParser() parser.add_argument( - "--foo", default=True, type=config_updater_util.ArgparseBool) + "--foo", default=True, type=config_updater_util.ArgparseBool + ) parser.add_argument( - "--bar", default=False, type=config_updater_util.ArgparseBool) + "--bar", default=False, type=config_updater_util.ArgparseBool + ) namespace = parser.parse_args([]) self.assertTrue(namespace.foo) self.assertFalse(namespace.bar) def testArgparseBool_InvalidType(self): expected_error = "Unexpected type: float. Expected a string." 
- with self.assertRaisesWithLiteralMatch(argparse.ArgumentTypeError, - expected_error): + with self.assertRaisesWithLiteralMatch( + argparse.ArgumentTypeError, expected_error + ): config_updater_util.ArgparseBool(1.23) def testArgparseBool_InvalidValue(self): expected_error = "Invalid value encountered. Expected 'True' or 'False'." - with self.assertRaisesWithLiteralMatch(argparse.ArgumentTypeError, - expected_error): + with self.assertRaisesWithLiteralMatch( + argparse.ArgumentTypeError, expected_error + ): config_updater_util.ArgparseBool("baz") diff --git a/grr/server/grr_response_server/bin/console.py b/grr/server/grr_response_server/bin/console.py index 945be7b1ab..0fea21c9ff 100644 --- a/grr/server/grr_response_server/bin/console.py +++ b/grr/server/grr_response_server/bin/console.py @@ -25,26 +25,33 @@ from grr_response_server import server_startup _CODE_TO_EXECUTE = flags.DEFINE_string( - "code_to_execute", None, + "code_to_execute", + None, "If present, no console is started but the code given in " "the flag is run instead (comparable to the -c option of " - "IPython).") + "IPython).", +) _COMMAND_FILE = flags.DEFINE_string( - "command_file", None, + "command_file", + None, "If present, no console is started but the code given in " - "command file is supplied as input instead.") + "command file is supplied as input instead.", +) _EXIT_ON_COMPLETE = flags.DEFINE_bool( - "exit_on_complete", True, + "exit_on_complete", + True, "If set to False and command_file or code_to_execute is " - "set we keep the console alive after the code completes.") + "set we keep the console alive after the code completes.", +) _VERSION = flags.DEFINE_bool( "version", default=False, allow_override=True, - help="Print the GRR console version number and exit immediately.") + help="Print the GRR console version number and exit immediately.", +) def main(argv): @@ -55,11 +62,13 @@ def main(argv): print("GRR console {}".format(config_server.VERSION["packageversion"])) return - banner = ("\nWelcome to the GRR console\n") + banner = "\nWelcome to the GRR console\n" config.CONFIG.AddContext(contexts.COMMAND_LINE_CONTEXT) - config.CONFIG.AddContext(contexts.CONSOLE_CONTEXT, - "Context applied when running the console binary.") + config.CONFIG.AddContext( + contexts.CONSOLE_CONTEXT, + "Context applied when running the console binary.", + ) server_startup.Init() fleetspeak_connector.Init() @@ -78,8 +87,9 @@ def main(argv): with open(_COMMAND_FILE.value, "r") as filedesc: exec(filedesc.read()) # pylint: disable=exec-used - if (_EXIT_ON_COMPLETE.value and - (_CODE_TO_EXECUTE.value or _COMMAND_FILE.value)): + if _EXIT_ON_COMPLETE.value and ( + _CODE_TO_EXECUTE.value or _COMMAND_FILE.value + ): return else: # We want the normal shell. diff --git a/grr/server/grr_response_server/bin/fleetspeak_frontend_server.py b/grr/server/grr_response_server/bin/fleetspeak_frontend_server.py index 7e742d99a0..5f528e1717 100644 --- a/grr/server/grr_response_server/bin/fleetspeak_frontend_server.py +++ b/grr/server/grr_response_server/bin/fleetspeak_frontend_server.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """This is the GRR frontend FS Server.""" + import logging from typing import FrozenSet, Sequence, Tuple @@ -32,10 +33,8 @@ # TODO: remove after the issue is fixed. 
-CLIENT_ID_SKIP_LIST = frozenset( - [ - ] -) +CLIENT_ID_SKIP_LIST = frozenset([ +]) MIN_DELAY_BETWEEN_METADATA_UPDATES = rdfvalue.Duration.From( @@ -73,8 +72,10 @@ def __init__(self): self.frontend = frontend_lib.FrontEndServer( max_queue_size=config.CONFIG["Frontend.max_queue_size"], message_expiry_time=config.CONFIG["Frontend.message_expiry_time"], - max_retransmission_time=config - .CONFIG["Frontend.max_retransmission_time"]) + max_retransmission_time=config.CONFIG[ + "Frontend.max_retransmission_time" + ], + ) @FRONTEND_REQUEST_COUNT.Counted(fields=["fleetspeak"]) @FRONTEND_REQUEST_LATENCY.Timed(fields=["fleetspeak"]) @@ -143,7 +144,8 @@ def _LogDelayed(msg: str) -> None: INCOMING_FLEETSPEAK_MESSAGES.Increment(fields=["PROCESS_GRR"]) grr_message = rdf_flows.GrrMessage.FromSerializedBytes( - fs_msg.data.value) + fs_msg.data.value + ) _LogDelayed("Starting processing GRR message") self._ProcessGRRMessages(grr_client_id, [grr_message]) _LogDelayed("Finished processing GRR message") @@ -153,9 +155,11 @@ def _LogDelayed(msg: str) -> None: ) packed_messages = rdf_flows.PackedMessageList.FromSerializedBytes( - fs_msg.data.value) + fs_msg.data.value + ) message_list = communicator.Communicator.DecompressMessageList( - packed_messages) + packed_messages + ) _LogDelayed("Starting processing GRR message list") self._ProcessGRRMessages(grr_client_id, message_list.job) _LogDelayed("Finished processing GRR message list") @@ -180,8 +184,10 @@ def _LogDelayed(msg: str) -> None: else: INCOMING_FLEETSPEAK_MESSAGES.Increment(fields=["INVALID"]) - logging.error("Received message with unrecognized message_type: %s", - fs_msg.message_type) + logging.error( + "Received message with unrecognized message_type: %s", + fs_msg.message_type, + ) context.set_code(grpc.StatusCode.INVALID_ARGUMENT) except Exception: logging.exception("Exception processing message: %s", fs_msg) @@ -217,9 +223,11 @@ def _ProcessGRRMessages( for grr_message in grr_messages: grr_message.source = grr_client_id grr_message.auth_state = ( - rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED) + rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED + ) self.frontend.ReceiveMessages( - client_id=grr_client_id, messages=grr_messages) + client_id=grr_client_id, messages=grr_messages + ) except Exception: logging.exception("Exception receiving messages from: %s", grr_client_id) raise diff --git a/grr/server/grr_response_server/bin/fleetspeak_frontend_server_test.py b/grr/server/grr_response_server/bin/fleetspeak_frontend_server_test.py index 6738f20903..d3d9eb9a51 100644 --- a/grr/server/grr_response_server/bin/fleetspeak_frontend_server_test.py +++ b/grr/server/grr_response_server/bin/fleetspeak_frontend_server_test.py @@ -8,6 +8,7 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import flows as rdf_flows +from grr_response_proto import flows_pb2 from grr_response_server import communicator from grr_response_server import data_store from grr_response_server import events @@ -15,7 +16,6 @@ from grr_response_server.bin import fleetspeak_frontend_server from grr_response_server.flows.general import processes as flow_processes from grr_response_server.models import clients -from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects from grr.test_lib import action_mocks from grr.test_lib import flow_test_lib from grr.test_lib import test_lib @@ -40,16 +40,16 @@ def testReceiveMessages(self): - rdfvalue.Duration("1s"), ) - rdf_flow = 
rdf_flow_objects.Flow( - client_id=client_id, - flow_id=flow_id, - create_time=rdfvalue.RDFDatetime.Now()) - data_store.REL_DB.WriteFlowObject(rdf_flow) + flow = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) + data_store.REL_DB.WriteFlowObject(flow) - flow_request = rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=1) + flow_request = flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=1 + ) + before_write = data_store.REL_DB.Now() data_store.REL_DB.WriteFlowRequests([flow_request]) + after_write = data_store.REL_DB.Now() session_id = "%s/%s" % (client_id, flow_id) fs_client_id = fleetspeak_utils.GRRIDToFleetspeakID(client_id) fs_messages = [] @@ -58,11 +58,13 @@ def testReceiveMessages(self): request_id=1, response_id=i + 1, session_id=session_id, - payload=rdfvalue.RDFInteger(i)) + payload=rdfvalue.RDFInteger(i), + ) fs_message = fs_common_pb2.Message( message_type="GrrMessage", source=fs_common_pb2.Address( - client_id=fs_client_id, service_name=FS_SERVICE_NAME), + client_id=fs_client_id, service_name=FS_SERVICE_NAME + ), ) fs_message.data.Pack(grr_message.AsPrimitiveProto()) fs_message.validation_info.tags["foo"] = "bar" @@ -83,10 +85,14 @@ def testReceiveMessages(self): ) flow_data = data_store.REL_DB.ReadAllFlowRequestsAndResponses( - client_id, flow_id) + client_id, flow_id + ) self.assertLen(flow_data, 1) stored_flow_request, flow_responses = flow_data[0] - self.assertEqual(stored_flow_request, flow_request) + self.assertEqual(stored_flow_request.client_id, flow_request.client_id) + self.assertEqual(stored_flow_request.flow_id, flow_request.flow_id) + self.assertEqual(stored_flow_request.request_id, flow_request.request_id) + self.assertBetween(stored_flow_request.timestamp, before_write, after_write) self.assertLen(flow_responses, 9) def testReceiveMessageList(self): @@ -101,16 +107,16 @@ def testReceiveMessageList(self): - rdfvalue.Duration("1s"), ) - rdf_flow = rdf_flow_objects.Flow( - client_id=client_id, - flow_id=flow_id, - create_time=rdfvalue.RDFDatetime.Now()) - data_store.REL_DB.WriteFlowObject(rdf_flow) - - flow_request = rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=1) + flow = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) + data_store.REL_DB.WriteFlowObject(flow) + flow_request = flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=1 + ) + before_write = data_store.REL_DB.Now() data_store.REL_DB.WriteFlowRequests([flow_request]) + after_write = data_store.REL_DB.Now() + session_id = "%s/%s" % (client_id, flow_id) fs_client_id = fleetspeak_utils.GRRIDToFleetspeakID(client_id) grr_messages = [] @@ -119,15 +125,19 @@ def testReceiveMessageList(self): request_id=1, response_id=i + 1, session_id=session_id, - payload=rdfvalue.RDFInteger(i)) + payload=rdfvalue.RDFInteger(i), + ) grr_messages.append(grr_message) packed_messages = rdf_flows.PackedMessageList() communicator.Communicator.EncodeMessageList( - rdf_flows.MessageList(job=grr_messages), packed_messages) + rdf_flows.MessageList(job=grr_messages), packed_messages + ) fs_message = fs_common_pb2.Message( message_type="MessageList", source=fs_common_pb2.Address( - client_id=fs_client_id, service_name=FS_SERVICE_NAME)) + client_id=fs_client_id, service_name=FS_SERVICE_NAME + ), + ) fs_message.data.Pack(packed_messages.AsPrimitiveProto()) fs_message.validation_info.tags["foo"] = "bar" @@ -145,10 +155,14 @@ def testReceiveMessageList(self): ) flow_data = data_store.REL_DB.ReadAllFlowRequestsAndResponses( 
- client_id, flow_id) + client_id, flow_id + ) self.assertLen(flow_data, 1) stored_flow_request, flow_responses = flow_data[0] - self.assertEqual(stored_flow_request, flow_request) + self.assertEqual(stored_flow_request.client_id, flow_request.client_id) + self.assertEqual(stored_flow_request.flow_id, flow_request.flow_id) + self.assertEqual(stored_flow_request.request_id, flow_request.request_id) + self.assertBetween(stored_flow_request.timestamp, before_write, after_write) self.assertLen(flow_responses, 9) def testMetadataDoesNotGetUpdatedIfPreviousUpdateIsTooRecent(self): @@ -158,13 +172,12 @@ def testMetadataDoesNotGetUpdatedIfPreviousUpdateIsTooRecent(self): now = rdfvalue.RDFDatetime.Now() data_store.REL_DB.WriteClientMetadata(client_id, last_ping=now) - rdf_flow = rdf_flow_objects.Flow( + flow = flows_pb2.Flow( client_id=client_id, flow_id=flow_id, - create_time=rdfvalue.RDFDatetime.Now(), ) - data_store.REL_DB.WriteFlowObject(rdf_flow) - flow_request = rdf_flow_objects.FlowRequest( + data_store.REL_DB.WriteFlowObject(flow) + flow_request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=1 ) data_store.REL_DB.WriteFlowRequests([flow_request]) @@ -200,13 +213,12 @@ def testMetadataGetsUpdatedIfPreviousUpdateIsOldEnough(self): ) data_store.REL_DB.WriteClientMetadata(client_id, last_ping=past) - rdf_flow = rdf_flow_objects.Flow( + flow = flows_pb2.Flow( client_id=client_id, flow_id=flow_id, - create_time=rdfvalue.RDFDatetime.Now(), ) - data_store.REL_DB.WriteFlowObject(rdf_flow) - flow_request = rdf_flow_objects.FlowRequest( + data_store.REL_DB.WriteFlowObject(flow) + flow_request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=1 ) data_store.REL_DB.WriteFlowRequests([flow_request]) @@ -242,21 +254,25 @@ def testWriteLastPingForNewClients(self): request_id=1, response_id=1, session_id=session_id, - payload=rdfvalue.RDFInteger(1)) + payload=rdfvalue.RDFInteger(1), + ) fs_message = fs_common_pb2.Message( message_type="GrrMessage", source=fs_common_pb2.Address( - client_id=fs_client_id, service_name=FS_SERVICE_NAME)) + client_id=fs_client_id, service_name=FS_SERVICE_NAME + ), + ) fs_message.data.Pack(grr_message.AsPrimitiveProto()) fake_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(123) with mock.patch.object( - events.Events, "PublishEvent", - wraps=events.Events.PublishEvent) as publish_event_fn: + events.Events, "PublishEvent", wraps=events.Events.PublishEvent + ) as publish_event_fn: with mock.patch.object( data_store.REL_DB, "WriteClientMetadata", - wraps=data_store.REL_DB.WriteClientMetadata) as write_metadata_fn: + wraps=data_store.REL_DB.WriteClientMetadata, + ) as write_metadata_fn: with test_lib.FakeTime(fake_time): fs_server.Process(fs_message, None) self.assertEqual(write_metadata_fn.call_count, 1) @@ -286,18 +302,20 @@ def testProcessListingOnlyFleetspeak(self): ppid=1, cmdline=["cmd.exe"], exe=r"c:\windows\cmd.exe", - ctime=1333718907167083) + ctime=1333718907167083, + ) ]) flow_id = flow_test_lib.TestFlowHelper( flow_processes.ListProcesses.__name__, client_mock, client_id=client_id, - creator=self.test_username) + creator=self.test_username, + ) processes = flow_test_lib.GetFlowResults(client_id, flow_id) self.assertLen(processes, 1) - process, = processes + (process,) = processes self.assertEqual(process.ctime, 1333718907167083) self.assertEqual(process.cmdline, ["cmd.exe"]) diff --git a/grr/server/grr_response_server/bin/fleetspeak_server_wrapper.py b/grr/server/grr_response_server/bin/fleetspeak_server_wrapper.py 
index 33c9290ed7..92a5ca213f 100644 --- a/grr/server/grr_response_server/bin/fleetspeak_server_wrapper.py +++ b/grr/server/grr_response_server/bin/fleetspeak_server_wrapper.py @@ -10,6 +10,7 @@ * `grr_config_updater initialize` has been run. * Fleetspeak has been enabled. """ + import os import subprocess @@ -24,19 +25,22 @@ class Error(Exception): def main(argv): config_dir = package.ResourcePath( - "fleetspeak-server-bin", "fleetspeak-server-bin/etc/fleetspeak-server") + "fleetspeak-server-bin", "fleetspeak-server-bin/etc/fleetspeak-server" + ) if not os.path.exists(config_dir): raise Error( f"Configuration directory not found: {config_dir}. " - "Please make sure `grr_config_updater initialize` has been run.") + "Please make sure `grr_config_updater initialize` has been run." + ) fleetspeak_server = package.ResourcePath( - "fleetspeak-server-bin", - "fleetspeak-server-bin/usr/bin/fleetspeak-server") + "fleetspeak-server-bin", "fleetspeak-server-bin/usr/bin/fleetspeak-server" + ) if not os.path.exists(fleetspeak_server): raise Error( f"Fleetspeak server binary not found: {fleetspeak_server}. " "Please make sure that the package `fleetspeak-server-bin` has been " - "installed.") + "installed." + ) command = [ fleetspeak_server, "--logtostderr", diff --git a/grr/server/grr_response_server/bin/grr_server.py b/grr/server/grr_response_server/bin/grr_server.py index a4e4e44088..2c587ed402 100644 --- a/grr/server/grr_response_server/bin/grr_server.py +++ b/grr/server/grr_response_server/bin/grr_server.py @@ -8,7 +8,6 @@ from absl import flags from grr_response_core.config import server as config_server - from grr_response_server import server_startup from grr_response_server.bin import fleetspeak_frontend from grr_response_server.bin import fleetspeak_server_wrapper @@ -21,11 +20,14 @@ "version", default=False, allow_override=True, - help="Print the GRR server version number and exit immediately.") + help="Print the GRR server version number and exit immediately.", +) _COMPONENT = flags.DEFINE_string( - "component", None, - "Component to start: [frontend|admin_ui|worker|grrafana].") + "component", + None, + "Component to start: [frontend|admin_ui|worker|grrafana].", +) def main(argv): @@ -64,8 +66,9 @@ def main(argv): # Raise on invalid component. else: - raise ValueError("No valid component specified. Got: " - "%s." % _COMPONENT.value) + raise ValueError( + "No valid component specified. Got: %s." 
% _COMPONENT.value + ) if __name__ == "__main__": diff --git a/grr/server/grr_response_server/bin/grrafana.py b/grr/server/grr_response_server/bin/grrafana.py index ca5320ca51..29996caf3c 100644 --- a/grr/server/grr_response_server/bin/grrafana.py +++ b/grr/server/grr_response_server/bin/grrafana.py @@ -29,7 +29,8 @@ "version", default=False, allow_override=True, - help="Print the GRR console version number and exit immediately.") + help="Print the GRR console version number and exit immediately.", +) class _Datapoint(NamedTuple): @@ -79,8 +80,11 @@ class ClientResourceUsageMetric(Metric): """A metric that represents resource usage data for a single client.""" def __init__( - self, name: str, record_values_extract_fn: Callable[ - [List[resource_pb2.ClientResourceUsageRecord]], List[float]] + self, + name: str, + record_values_extract_fn: Callable[ + [List[resource_pb2.ClientResourceUsageRecord]], List[float] + ], ) -> None: super().__init__(name) self._record_values_extract_fn = record_values_extract_fn @@ -97,37 +101,46 @@ def ProcessQuery( end_range_ts = TimeToProtoTimestamp(req_json["range"]["to"]) records_list = fleetspeak_utils.FetchClientResourceUsageRecords( - client_id, start_range_ts, end_range_ts) + client_id, start_range_ts, end_range_ts + ) record_values = self._record_values_extract_fn(records_list) datapoints = [] - for (v, r) in zip(record_values, records_list): + for v, r in zip(record_values, records_list): datapoints.append( _Datapoint( nanos=v, - value=r.server_timestamp.seconds * 1000 + - r.server_timestamp.nanos // 1000000)) + value=r.server_timestamp.seconds * 1000 + + r.server_timestamp.nanos // 1000000, + ) + ) return _TargetWithDatapoints(target=self._name, datapoints=datapoints) AVAILABLE_METRICS_LIST: List[Metric] AVAILABLE_METRICS_LIST = [ - ClientResourceUsageMetric("Mean User CPU Rate", - lambda rl: [r.mean_user_cpu_rate for r in rl]), - ClientResourceUsageMetric("Max User CPU Rate", - lambda rl: [r.max_user_cpu_rate for r in rl]), - ClientResourceUsageMetric("Mean System CPU Rate", - lambda rl: [r.mean_system_cpu_rate for r in rl]), - ClientResourceUsageMetric("Max System CPU Rate", - lambda rl: [r.max_system_cpu_rate for r in rl]), + ClientResourceUsageMetric( + "Mean User CPU Rate", lambda rl: [r.mean_user_cpu_rate for r in rl] + ), + ClientResourceUsageMetric( + "Max User CPU Rate", lambda rl: [r.max_user_cpu_rate for r in rl] + ), + ClientResourceUsageMetric( + "Mean System CPU Rate", lambda rl: [r.mean_system_cpu_rate for r in rl] + ), + ClientResourceUsageMetric( + "Max System CPU Rate", lambda rl: [r.max_system_cpu_rate for r in rl] + ), # Converting MiB to MB ClientResourceUsageMetric( "Mean Resident Memory MB", - lambda rl: [r.mean_resident_memory_mib * 1.049 for r in rl]), + lambda rl: [r.mean_resident_memory_mib * 1.049 for r in rl], + ), ClientResourceUsageMetric( "Max Resident Memory MB", - lambda rl: [r.max_resident_memory_mib * 1.049 for r in rl]), + lambda rl: [r.max_resident_memory_mib * 1.049 for r in rl], + ), ] AVAILABLE_METRICS_BY_NAME = { @@ -148,11 +161,14 @@ def __init__(self) -> None: self._url_map = werkzeug_routing.Map([ werkzeug_routing.Rule("/", endpoint=self._OnRoot, methods=["GET"]), # pytype: disable=wrong-arg-types werkzeug_routing.Rule( - "/search", endpoint=self._OnSearch, methods=["POST"]), # pytype: disable=wrong-arg-types + "/search", endpoint=self._OnSearch, methods=["POST"] + ), # pytype: disable=wrong-arg-types werkzeug_routing.Rule( - "/query", endpoint=self._OnQuery, methods=["POST"]), # pytype: 
disable=wrong-arg-types + "/query", endpoint=self._OnQuery, methods=["POST"] + ), # pytype: disable=wrong-arg-types werkzeug_routing.Rule( - "/annotations", endpoint=self._OnAnnotations, methods=["POST"]), # pytype: disable=wrong-arg-types + "/annotations", endpoint=self._OnAnnotations, methods=["POST"] + ), # pytype: disable=wrong-arg-types ]) def _DispatchRequest( @@ -184,7 +200,14 @@ def _DispatchRequest( def __call__( self, environ: dict[str, Any], - start_response: Callable[[str, list[tuple[str, Any]], tuple[Type[Exception], Exception, types.TracebackType]], Any] + start_response: Callable[ + [ + str, + list[tuple[str, Any]], + tuple[Type[Exception], Exception, types.TracebackType], + ], + Any, + ], ) -> Iterable[bytes]: request = werkzeug_wrappers.Request(environ) response = self._DispatchRequest(request) @@ -215,7 +238,8 @@ def _OnSearch( """ response = list(AVAILABLE_METRICS_BY_NAME.keys()) return werkzeug_wrappers.Response( - response=json.dumps(response), content_type=JSON_MIME_TYPE) + response=json.dumps(response), content_type=JSON_MIME_TYPE + ) def _OnQuery( self, @@ -240,7 +264,8 @@ def _OnQuery( ] response = [t._asdict() for t in targets_with_datapoints] return werkzeug_wrappers.Response( - response=json.dumps(response), content_type=JSON_MIME_TYPE) + response=json.dumps(response), content_type=JSON_MIME_TYPE + ) def _OnAnnotations( self, @@ -252,7 +277,8 @@ def _OnAnnotations( def TimeToProtoTimestamp(grafana_time: str) -> timestamp_pb2.Timestamp: date = parser.parse(grafana_time) return timestamp_pb2.Timestamp( - seconds=int(date.timestamp()), nanos=date.microsecond * 1000) + seconds=int(date.timestamp()), nanos=date.microsecond * 1000 + ) def main(argv: Any) -> None: @@ -263,12 +289,14 @@ def main(argv: Any) -> None: print(f"GRRafana server {config_server.VERSION['packageversion']}") return - config.CONFIG.AddContext(contexts.GRRAFANA_CONTEXT, - "Context applied when running GRRafana server.") + config.CONFIG.AddContext( + contexts.GRRAFANA_CONTEXT, "Context applied when running GRRafana server." 
+ ) server_startup.Init() fleetspeak_connector.Init() - werkzeug_serving.run_simple(config.CONFIG["GRRafana.bind"], - config.CONFIG["GRRafana.port"], Grrafana()) + werkzeug_serving.run_simple( + config.CONFIG["GRRafana.bind"], config.CONFIG["GRRafana.port"], Grrafana() + ) if __name__ == "__main__": diff --git a/grr/server/grr_response_server/bin/grrafana_test.py b/grr/server/grr_response_server/bin/grrafana_test.py index 8ec13cd0d2..1498caf001 100644 --- a/grr/server/grr_response_server/bin/grrafana_test.py +++ b/grr/server/grr_response_server/bin/grrafana_test.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Unittest for GRRafana HTTP server.""" + import copy from unittest import mock @@ -9,10 +10,8 @@ from werkzeug import wrappers as werkzeug_wrappers from google.protobuf import timestamp_pb2 - from grr_response_server import fleetspeak_connector from grr_response_server.bin import grrafana - from fleetspeak.src.server.proto.fleetspeak_server import admin_pb2 from fleetspeak.src.server.proto.fleetspeak_server import resource_pb2 @@ -23,46 +22,28 @@ _TEST_CLIENT_RESOURCE_USAGE_RECORD_1 = { "scope": "system", "pid": 2714460, - "process_start_time": { - "seconds": 1597327815, - "nanos": 817468715 - }, - "client_timestamp": { - "seconds": 1597328416, - "nanos": 821525280 - }, - "server_timestamp": { - "seconds": 1597328417, - "nanos": 823124057 - }, + "process_start_time": {"seconds": 1597327815, "nanos": 817468715}, + "client_timestamp": {"seconds": 1597328416, "nanos": 821525280}, + "server_timestamp": {"seconds": 1597328417, "nanos": 823124057}, "mean_user_cpu_rate": 0.31883034110069275, "max_user_cpu_rate": 4.999776840209961, "mean_system_cpu_rate": 0.31883034110069275, "max_system_cpu_rate": 4.999776840209961, "mean_resident_memory_mib": 20, - "max_resident_memory_mib": 20 + "max_resident_memory_mib": 20, } _TEST_CLIENT_RESOURCE_USAGE_RECORD_2 = { "scope": "GRR", "pid": 2714474, - "process_start_time": { - "seconds": 1597327815, - "nanos": 818657389 - }, - "client_timestamp": { - "seconds": 1597328418, - "nanos": 402023428 - }, - "server_timestamp": { - "seconds": 1597328419, - "nanos": 403123025 - }, + "process_start_time": {"seconds": 1597327815, "nanos": 818657389}, + "client_timestamp": {"seconds": 1597328418, "nanos": 402023428}, + "server_timestamp": {"seconds": 1597328419, "nanos": 403123025}, "mean_user_cpu_rate": 0.492735356092453, "max_user_cpu_rate": 4.999615669250488, "mean_system_cpu_rate": 0.07246342301368713, "max_system_cpu_rate": 0.3333326578140259, "mean_resident_memory_mib": 59, - "max_resident_memory_mib": 59 + "max_resident_memory_mib": 59, } _TEST_VALID_RUD_QUERY = { @@ -74,48 +55,36 @@ "range": { "from": _START_RANGE_TIMESTAMP, "to": _END_RANGE_TIMESTAMP, - "raw": { - "from": _START_RANGE_TIMESTAMP, - "to": _END_RANGE_TIMESTAMP - } + "raw": {"from": _START_RANGE_TIMESTAMP, "to": _END_RANGE_TIMESTAMP}, }, "timeInfo": "", "interval": "10m", "intervalMs": 600000, - "targets": [{ - "data": None, - "target": "Max User CPU Rate", - "refId": "A", - "hide": False, - "type": "timeseries" - }, { - "data": None, - "target": "Mean System CPU Rate", - "refId": "A", - "hide": False, - "type": "timeseries" - }], - "maxDataPoints": 800, - "scopedVars": { - "ClientID": { - "text": _TEST_CLIENT_ID_1, - "value": _TEST_CLIENT_ID_1 + "targets": [ + { + "data": None, + "target": "Max User CPU Rate", + "refId": "A", + "hide": False, + "type": "timeseries", }, - "__interval": { - "text": "10m", - "value": "10m" + { + "data": None, + "target": "Mean System CPU Rate", + "refId": "A", + "hide": 
False, + "type": "timeseries", }, - "__interval_ms": { - "text": "600000", - "value": 600000 - } + ], + "maxDataPoints": 800, + "scopedVars": { + "ClientID": {"text": _TEST_CLIENT_ID_1, "value": _TEST_CLIENT_ID_1}, + "__interval": {"text": "10m", "value": "10m"}, + "__interval_ms": {"text": "600000", "value": 600000}, }, "startTime": 1598782453496, - "rangeRaw": { - "from": _START_RANGE_TIMESTAMP, - "to": _END_RANGE_TIMESTAMP - }, - "adhocFilters": [] + "rangeRaw": {"from": _START_RANGE_TIMESTAMP, "to": _END_RANGE_TIMESTAMP}, + "adhocFilters": [], } _TEST_INVALID_TARGET_QUERY = copy.deepcopy(_TEST_VALID_RUD_QUERY) @@ -139,9 +108,12 @@ def _MockConnReturningRecords(client_ruds): mean_system_cpu_rate=record["mean_system_cpu_rate"], max_system_cpu_rate=record["max_system_cpu_rate"], mean_resident_memory_mib=record["mean_resident_memory_mib"], - max_resident_memory_mib=record["max_resident_memory_mib"])) - conn.outgoing.FetchClientResourceUsageRecords.return_value = admin_pb2.FetchClientResourceUsageRecordsResponse( - records=records) + max_resident_memory_mib=record["max_resident_memory_mib"], + ) + ) + conn.outgoing.FetchClientResourceUsageRecords.return_value = ( + admin_pb2.FetchClientResourceUsageRecordsResponse(records=records) + ) return conn @@ -152,7 +124,8 @@ def setUp(self): super().setUp() self.client = werkzeug_test.Client( application=grrafana.Grrafana(), - response_wrapper=werkzeug_wrappers.Response) + response_wrapper=werkzeug_wrappers.Response, + ) def testRoot(self): response = self.client.get("/") @@ -160,43 +133,52 @@ def testRoot(self): def testSearchMetrics(self): response = self.client.post( - "/search", json={ - "type": "timeseries", - "target": "" - }) + "/search", json={"type": "timeseries", "target": ""} + ) self.assertEqual(200, response.status_code) expected_res = [ - "Mean User CPU Rate", "Max User CPU Rate", "Mean System CPU Rate", - "Max System CPU Rate", "Mean Resident Memory MB", - "Max Resident Memory MB" + "Mean User CPU Rate", + "Max User CPU Rate", + "Mean System CPU Rate", + "Max System CPU Rate", + "Mean Resident Memory MB", + "Max Resident Memory MB", ] self.assertListEqual(response.json, expected_res) def testClientResourceUsageMetricQuery(self): conn = _MockConnReturningRecords([ _TEST_CLIENT_RESOURCE_USAGE_RECORD_1, - _TEST_CLIENT_RESOURCE_USAGE_RECORD_2 + _TEST_CLIENT_RESOURCE_USAGE_RECORD_2, ]) with mock.patch.object(fleetspeak_connector, "CONN", conn): valid_response = self.client.post("/query", json=_TEST_VALID_RUD_QUERY) self.assertEqual(200, valid_response.status_code) - self.assertEqual(valid_response.json, [{ - "target": - "Max User CPU Rate", - "datapoints": [[4.999776840209961, 1597328417823], - [4.999615669250488, 1597328419403]] - }, { - "target": - "Mean System CPU Rate", - "datapoints": [[0.31883034110069275, 1597328417823], - [0.07246342301368713, 1597328419403]] - }]) + self.assertEqual( + valid_response.json, + [ + { + "target": "Max User CPU Rate", + "datapoints": [ + [4.999776840209961, 1597328417823], + [4.999615669250488, 1597328419403], + ], + }, + { + "target": "Mean System CPU Rate", + "datapoints": [ + [0.31883034110069275, 1597328417823], + [0.07246342301368713, 1597328419403], + ], + }, + ], + ) def testQueryInvalidRequest(self): conn = _MockConnReturningRecords([ _TEST_CLIENT_RESOURCE_USAGE_RECORD_1, - _TEST_CLIENT_RESOURCE_USAGE_RECORD_2 + _TEST_CLIENT_RESOURCE_USAGE_RECORD_2, ]) with mock.patch.object(fleetspeak_connector, "CONN", conn): with self.assertRaises(KeyError): @@ -209,10 +191,12 @@ class 
TimeToProtoTimestampTest(absltest.TestCase): def testTimeToProtoTimestamp(self): self.assertEqual( grrafana.TimeToProtoTimestamp(_START_RANGE_TIMESTAMP), - timestamp_pb2.Timestamp(seconds=1597328417, nanos=(158 * 1000000))) + timestamp_pb2.Timestamp(seconds=1597328417, nanos=(158 * 1000000)), + ) self.assertEqual( grrafana.TimeToProtoTimestamp(_END_RANGE_TIMESTAMP), - timestamp_pb2.Timestamp(seconds=1597770958, nanos=(761 * 1000000))) + timestamp_pb2.Timestamp(seconds=1597770958, nanos=(761 * 1000000)), + ) def main(argv): diff --git a/grr/server/grr_response_server/bin/worker.py b/grr/server/grr_response_server/bin/worker.py index c27212bee8..d45b7fe227 100644 --- a/grr/server/grr_response_server/bin/worker.py +++ b/grr/server/grr_response_server/bin/worker.py @@ -16,7 +16,8 @@ "version", default=False, allow_override=True, - help="Print the GRR worker version number and exit immediately.") + help="Print the GRR worker version number and exit immediately.", +) def main(argv): @@ -27,8 +28,9 @@ def main(argv): print("GRR worker {}".format(config_server.VERSION["packageversion"])) return - config.CONFIG.AddContext(contexts.WORKER_CONTEXT, - "Context applied when running a worker.") + config.CONFIG.AddContext( + contexts.WORKER_CONTEXT, "Context applied when running a worker." + ) # Initialise flows and config_lib server_startup.Init() diff --git a/grr/server/grr_response_server/bin/worker_test.py b/grr/server/grr_response_server/bin/worker_test.py index ed881b7162..5c58f4f427 100644 --- a/grr/server/grr_response_server/bin/worker_test.py +++ b/grr/server/grr_response_server/bin/worker_test.py @@ -7,11 +7,12 @@ from absl import app from grr_response_core.lib.rdfvalues import client as rdf_client +from grr_response_core.lib.rdfvalues import mig_protodict from grr_response_core.lib.rdfvalues import protodict as rdf_protodict +from grr_response_proto import objects_pb2 from grr_response_server import data_store from grr_response_server import foreman from grr_response_server import worker_lib -from grr_response_server.rdfvalues import objects as rdf_objects from grr.test_lib import action_mocks from grr.test_lib import flow_test_lib from grr.test_lib import test_lib @@ -30,22 +31,24 @@ def handle(l): done.set() data_store.REL_DB.RegisterMessageHandler( - handle, worker_lib.GRRWorker.message_handler_lease_time, limit=1000) + handle, worker_lib.GRRWorker.message_handler_lease_time, limit=1000 + ) startup_info = rdf_client.StartupInfo() startup_info.client_info.client_version = 4321 - - data_store.REL_DB.WriteMessageHandlerRequests( - [ - rdf_objects.MessageHandlerRequest( - client_id=client_id, - handler_name="ClientStartupHandler", - request_id=12345, - request=startup_info, - ) - ] + emb = mig_protodict.ToProtoEmbeddedRDFValue( + rdf_protodict.EmbeddedRDFValue(startup_info) ) + data_store.REL_DB.WriteMessageHandlerRequests([ + objects_pb2.MessageHandlerRequest( + client_id=client_id, + handler_name="ClientStartupHandler", + request_id=12345, + request=emb, + ) + ]) + self.assertTrue(done.wait(10)) result = data_store.REL_DB.ReadClientStartupInfo(client_id) @@ -61,7 +64,8 @@ def testCPULimitForFlows(self): client_id = self.SetupClient(0) client_mock = action_mocks.CPULimitClientMock( - user_cpu_usage=[10], system_cpu_usage=[10], network_usage=[1000]) + user_cpu_usage=[10], system_cpu_usage=[10], network_usage=[1000] + ) flow_test_lib.TestFlowHelper( flow_test_lib.CPULimitFlow.__name__, @@ -69,7 +73,8 @@ def testCPULimitForFlows(self): creator=self.test_username, client_id=client_id, 
cpu_limit=1000, - network_bytes_limit=10000) + network_bytes_limit=10000, + ) self.assertEqual(client_mock.storage["cpulimit"], [1000, 980, 960]) self.assertEqual(client_mock.storage["networklimit"], [10000, 9000, 8000]) @@ -80,11 +85,14 @@ def testForemanMessageHandler(self): client_id = "C.1100110011001100" data_store.REL_DB.WriteMessageHandlerRequests([ - rdf_objects.MessageHandlerRequest( + objects_pb2.MessageHandlerRequest( client_id=client_id, handler_name="ForemanHandler", request_id=12345, - request=rdf_protodict.DataBlob()) + request=mig_protodict.ToProtoEmbeddedRDFValue( + rdf_protodict.EmbeddedRDFValue(rdf_protodict.DataBlob()) + ), + ) ]) done = threading.Event() @@ -94,7 +102,8 @@ def handle(l): done.set() data_store.REL_DB.RegisterMessageHandler( - handle, worker_lib.GRRWorker.message_handler_lease_time, limit=1000) + handle, worker_lib.GRRWorker.message_handler_lease_time, limit=1000 + ) try: self.assertTrue(done.wait(10)) diff --git a/grr/server/grr_response_server/blob_stores/benchmark.py b/grr/server/grr_response_server/blob_stores/benchmark.py index f5089e30cd..837c210bb2 100644 --- a/grr/server/grr_response_server/blob_stores/benchmark.py +++ b/grr/server/grr_response_server/blob_stores/benchmark.py @@ -77,7 +77,8 @@ def _PrintStats(size, size_b, durations): p90=np.percentile(durations_ms, 90), p95=np.percentile(durations_ms, 95), p99=np.percentile(durations_ms, 99), - )) + ) + ) def _RunBenchmark(bs, size_b, duration_sec, random_fd): diff --git a/grr/server/grr_response_server/blob_stores/db_blob_store.py b/grr/server/grr_response_server/blob_stores/db_blob_store.py index 9ec38b01fd..20a8f5c313 100644 --- a/grr/server/grr_response_server/blob_stores/db_blob_store.py +++ b/grr/server/grr_response_server/blob_stores/db_blob_store.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """REL_DB blobstore implementation.""" + from typing import Dict from typing import Iterable from typing import Optional @@ -25,8 +26,10 @@ def __init__(self, delegate: Optional[blob_store.BlobStore] = None) -> None: # interface). 
delegate = data_store.REL_DB.delegate # pytype: disable=attribute-error if not isinstance(delegate, blob_store.BlobStore): - raise TypeError(f"Database blobstore delegate of '{type(delegate)}' " - f"type does not implement the blobstore interface") + raise TypeError( + f"Database blobstore delegate of '{type(delegate)}' " + "type does not implement the blobstore interface" + ) self._delegate = delegate diff --git a/grr/server/grr_response_server/blob_stores/db_blob_store_test.py b/grr/server/grr_response_server/blob_stores/db_blob_store_test.py index 4cbe2b3be8..efe819b48b 100644 --- a/grr/server/grr_response_server/blob_stores/db_blob_store_test.py +++ b/grr/server/grr_response_server/blob_stores/db_blob_store_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for REL_DB-based blob store.""" - from absl import app from grr_response_server import blob_store_test_mixin @@ -9,8 +8,9 @@ from grr.test_lib import test_lib -class DbBlobStoreTest(blob_store_test_mixin.BlobStoreTestMixin, - test_lib.GRRBaseTest): +class DbBlobStoreTest( + blob_store_test_mixin.BlobStoreTestMixin, test_lib.GRRBaseTest +): def CreateBlobStore(self): return (db_blob_store.DbBlobStore(), lambda: None) diff --git a/grr/server/grr_response_server/blob_stores/encrypted_blob_store.py b/grr/server/grr_response_server/blob_stores/encrypted_blob_store.py index f57d0124ab..9fd2bd3f9e 100644 --- a/grr/server/grr_response_server/blob_stores/encrypted_blob_store.py +++ b/grr/server/grr_response_server/blob_stores/encrypted_blob_store.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """An module with implementation of the encrypted blobstore.""" + import logging from typing import Iterable from typing import Optional @@ -52,8 +53,12 @@ def WriteBlobs( encrypted_blobs[blob_id] = crypter.Encrypt(blob, bytes(blob_id)) key_names[blob_id] = self._key_name - logging.info("Writing %s encrypted blobs using key '%s' (%s)", len(blobs), - self._key_name, ", ".join(map(str, blobs))) + logging.info( + "Writing %s encrypted blobs using key '%s' (%s)", + len(blobs), + self._key_name, + ", ".join(map(str, blobs)), + ) self._bs.WriteBlobs(encrypted_blobs) self._db.WriteBlobEncryptionKeys(key_names) diff --git a/grr/server/grr_response_server/blob_stores/encrypted_blob_store_test.py b/grr/server/grr_response_server/blob_stores/encrypted_blob_store_test.py index 18f5010838..29e4b0999d 100644 --- a/grr/server/grr_response_server/blob_stores/encrypted_blob_store_test.py +++ b/grr/server/grr_response_server/blob_stores/encrypted_blob_store_test.py @@ -30,7 +30,8 @@ class EncryptedBlobStoreTest( # Test methods are defined in the base mixin class. 
def CreateBlobStore( - self) -> tuple[blob_store.BlobStore, Optional[Callable[[], None]]]: + self, + ) -> tuple[blob_store.BlobStore, Optional[Callable[[], None]]]: db = mem_db.InMemoryDB() bs = db ks = mem_ks.MemKeystore(["foo"]) @@ -79,7 +80,8 @@ def testReadBlobEncryptedWithoutKeysOutdated(self): del db.blob_keys[blob_id] with self.assertRaises( - encrypted_blob_store.EncryptedBlobWithoutKeysError) as context: + encrypted_blob_store.EncryptedBlobWithoutKeysError + ) as context: bs.ReadBlob(blob_id) self.assertEqual(context.exception.blob_id, blob_id) diff --git a/grr/server/grr_response_server/blob_stores/registry_init.py b/grr/server/grr_response_server/blob_stores/registry_init.py index 95c3ba56b1..c3a40a4af1 100644 --- a/grr/server/grr_response_server/blob_stores/registry_init.py +++ b/grr/server/grr_response_server/blob_stores/registry_init.py @@ -8,4 +8,5 @@ def RegisterBlobStores(): """Registers all BlobStore implementations in blob_store.REGISTRY.""" blob_store.REGISTRY[db_blob_store.DbBlobStore.__name__] = ( - db_blob_store.DbBlobStore) + db_blob_store.DbBlobStore + ) diff --git a/grr/server/grr_response_server/data_store_utils.py b/grr/server/grr_response_server/data_store_utils.py index 3601fa4cdc..e7aa036584 100644 --- a/grr/server/grr_response_server/data_store_utils.py +++ b/grr/server/grr_response_server/data_store_utils.py @@ -10,6 +10,7 @@ from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import mig_client from grr_response_server import data_store +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects @@ -37,7 +38,9 @@ def GetFileHashEntry(fd): path_type, components = rdf_objects.ParseCategorizedPath(vfs_path) path_info = data_store.REL_DB.ReadPathInfo(client_id, path_type, components) - return path_info.hash_entry + if path_info is None: + return None + return mig_objects.ToRDFPathInfo(path_info).hash_entry def GetClientKnowledgeBase(client_id): diff --git a/grr/server/grr_response_server/databases/db.py b/grr/server/grr_response_server/databases/db.py index 49bde7fb2e..4fd62de309 100644 --- a/grr/server/grr_response_server/databases/db.py +++ b/grr/server/grr_response_server/databases/db.py @@ -4,9 +4,14 @@ This defines the Database abstraction, which defines the methods used by GRR on a logical relational database model. 
""" + import abc import collections +import dataclasses +import enum import re +from typing import AbstractSet +from typing import Callable from typing import Collection from typing import Dict from typing import Iterable @@ -22,22 +27,17 @@ from typing import Tuple from grr_response_core.lib import rdfvalue -from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import crypto as rdf_crypto -from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_core.lib.rdfvalues import search as rdf_search -from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_core.lib.util import precondition from grr_response_proto import artifact_pb2 from grr_response_proto import flows_pb2 from grr_response_proto import hunts_pb2 from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 +from grr_response_proto import output_plugin_pb2 from grr_response_proto import user_pb2 from grr_response_server.models import blobs -from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects -from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner -from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr_response_proto.rrg import startup_pb2 as rrg_startup_pb2 @@ -140,7 +140,7 @@ def __init__(self, name: str, cause: Optional[Exception] = None): class UnknownClientError(NotFoundError): - """"An exception class representing errors about uninitialized client. + """An exception class representing errors about uninitialized client. Attributes: client_id: An id of the non-existing client that was referenced. 
@@ -162,7 +162,8 @@ def __init__(self, client_ids, cause=None): self.client_ids = client_ids self.message = "At least one client in '%s' does not exist" % ",".join( - client_ids) + client_ids + ) class UnknownPathError(NotFoundError): @@ -207,9 +208,10 @@ def __init__(self, client_id, path_type, components, cause=None): self.path_type = path_type self.components = components - self.message = ("Listing descendants of path '%s' of type '%s' on client " - "'%s' that is not a directory") - self.message %= ("/".join(self.components), self.path_type, self.client_id) + self.message = ( + "Listing descendants of path '%s' of type '%s' on client " + "'%s' that is not a directory" + ) % ("/".join(self.components), self.path_type, self.client_id) class UnknownRuleError(NotFoundError): @@ -256,8 +258,10 @@ def __init__( super().__init__(binary_id, cause=cause) self.binary_id = binary_id - self.message = ("Signed binary of type %s and path %s was not found" % - (self.binary_id.binary_type, self.binary_id.path)) + self.message = "Signed binary of type %s and path %s was not found" % ( + self.binary_id.binary_type, + self.binary_id.path, + ) class UnknownFlowError(NotFoundError): @@ -268,8 +272,10 @@ def __init__(self, client_id, flow_id, cause=None): self.client_id = client_id self.flow_id = flow_id - self.message = ("Flow with client id '%s' and flow id '%s' does not exist" % - (self.client_id, self.flow_id)) + self.message = ( + "Flow with client id '%s' and flow id '%s' does not exist" + % (self.client_id, self.flow_id) + ) class UnknownScheduledFlowError(NotFoundError): @@ -283,7 +289,8 @@ def __init__(self, client_id, creator, scheduled_flow_id, cause=None): self.scheduled_flow_id = scheduled_flow_id self.message = "ScheduledFlow {}/{}/{} does not exist.".format( - self.client_id, self.creator, self.scheduled_flow_id) + self.client_id, self.creator, self.scheduled_flow_id + ) class UnknownHuntError(NotFoundError): @@ -312,9 +319,10 @@ def __init__(self, hunt_id, state_index): self.hunt_id = hunt_id self.state_index = state_index - self.message = ("Hunt output plugin state for hunt '%s' with " - "index %d does not exist" % - (self.hunt_id, self.state_index)) + self.message = ( + "Hunt output plugin state for hunt '%s' with index %d does not exist" + % (self.hunt_id, self.state_index) + ) class AtLeastOneUnknownFlowError(NotFoundError): @@ -324,8 +332,10 @@ def __init__(self, flow_keys, cause=None): self.flow_keys = flow_keys - self.message = ("At least one flow with client id/flow_id in '%s' " - "does not exist" % (self.flow_keys)) + self.message = ( + "At least one flow with client id/flow_id in '%s' does not exist" + % (self.flow_keys) + ) class UnknownFlowRequestError(NotFoundError): @@ -340,7 +350,8 @@ def __init__(self, client_id, flow_id, request_id, cause=None): self.message = ( "Flow request %d for flow with client id '%s' and flow id '%s' " - "does not exist" % (self.request_id, self.client_id, self.flow_id)) + "does not exist" % (self.request_id, self.client_id, self.flow_id) + ) class AtLeastOneUnknownRequestError(NotFoundError): @@ -350,8 +361,10 @@ def __init__(self, request_keys, cause=None): self.request_keys = request_keys - self.message = ("At least one request with client id/flow_id/request_id in " - "'%s' does not exist" % (self.request_keys)) + self.message = ( + "At least one request with client id/flow_id/request_id in " + "'%s' does not exist" % (self.request_keys) + ) class ParentHuntIsNotRunningError(Error): @@ -367,8 +380,9 @@ def __init__(self, client_id, flow_id, hunt_id, 
hunt_state): self.message = ( "Parent hunt %s of the flow with client id '%s' and " - "flow id '%s' is not running: %s" % - (self.hunt_id, self.client_id, self.flow_id, self.hunt_state)) + "flow id '%s' is not running: %s" + % (self.hunt_id, self.client_id, self.flow_id, self.hunt_state) + ) class HuntOutputPluginsStatesAreNotInitializedError(Error): @@ -379,9 +393,12 @@ def __init__(self, hunt_obj): self.hunt_obj = hunt_obj - self.message = ("Hunt %r has output plugins but no output plugins states. " - "Make sure it was created with hunt.CreateHunt and not " - "simply written to the database." % self.hunt_obj) + self.message = ( + "Hunt %r has output plugins but no output plugins states. " + "Make sure it was created with hunt.CreateHunt and not " + "simply written to the database." + % self.hunt_obj + ) class ConflictingUpdateFlowArgumentsError(Error): @@ -393,10 +410,11 @@ def __init__(self, client_id, flow_id, param_name): self.flow_id = flow_id self.param_name = param_name - self.message = ("Conflicting parameter when updating flow " - "%s (client %s). Can't call UpdateFlow with " - "flow_obj and %s passed together." % - (flow_id, client_id, param_name)) + self.message = ( + "Conflicting parameter when updating flow " + "%s (client %s). Can't call UpdateFlow with " + "flow_obj and %s passed together." % (flow_id, client_id, param_name) + ) class FlowExistsError(Error): @@ -412,39 +430,49 @@ class StringTooLongError(ValueError): """Validation error raised if a string is too long.""" -# TODO(user): migrate to Python 3 enums as soon as Python 3 is default. -class HuntFlowsCondition(object): +class HuntFlowsCondition(enum.Enum): """Constants to be used with ReadHuntFlows/CountHuntFlows methods.""" - UNSET = 0 - FAILED_FLOWS_ONLY = 1 - SUCCEEDED_FLOWS_ONLY = 2 - COMPLETED_FLOWS_ONLY = 3 - FLOWS_IN_PROGRESS_ONLY = 4 - CRASHED_FLOWS_ONLY = 5 - - @classmethod - def MaxValue(cls): - return cls.CRASHED_FLOWS_ONLY - - -HuntCounters = collections.namedtuple("HuntCounters", [ - "num_clients", - "num_successful_clients", - "num_failed_clients", - "num_clients_with_results", - "num_crashed_clients", - "num_running_clients", - "num_results", - "total_cpu_seconds", - "total_network_bytes_sent", -]) - -FlowStateAndTimestamps = collections.namedtuple("FlowStateAndTimestamps", [ - "flow_state", - "create_time", - "last_update_time", -]) + UNSET = enum.auto() + FAILED_FLOWS_ONLY = enum.auto() + SUCCEEDED_FLOWS_ONLY = enum.auto() + COMPLETED_FLOWS_ONLY = enum.auto() + FLOWS_IN_PROGRESS_ONLY = enum.auto() + CRASHED_FLOWS_ONLY = enum.auto() + + +HuntCounters = collections.namedtuple( + "HuntCounters", + [ + "num_clients", + "num_successful_clients", + "num_failed_clients", + "num_clients_with_results", + "num_crashed_clients", + "num_running_clients", + "num_results", + "total_cpu_seconds", + "total_network_bytes_sent", + ], +) + +FlowStateAndTimestamps = collections.namedtuple( + "FlowStateAndTimestamps", + [ + "flow_state", + "create_time", + "last_update_time", + ], +) + + +@dataclasses.dataclass +class FlowErrorInfo: + """Information about what caused flow to error-out.""" + + message: str + time: rdfvalue.RDFDatetime + backtrace: Optional[str] = None class SearchClientsResult(NamedTuple): @@ -473,38 +501,40 @@ class ClientPath(object): def __init__(self, client_id, path_type, components): precondition.ValidateClientId(client_id) - _ValidateEnumType(path_type, rdf_objects.PathInfo.PathType) + _ValidateProtoEnumType(path_type, objects_pb2.PathInfo.PathType) _ValidatePathComponents(components) self._repr = 
(client_id, path_type, tuple(components)) @classmethod def OS(cls, client_id, components): - path_type = rdf_objects.PathInfo.PathType.OS + path_type = objects_pb2.PathInfo.PathType.OS return cls(client_id=client_id, path_type=path_type, components=components) @classmethod def TSK(cls, client_id, components): - path_type = rdf_objects.PathInfo.PathType.TSK + path_type = objects_pb2.PathInfo.PathType.TSK return cls(client_id=client_id, path_type=path_type, components=components) @classmethod def NTFS(cls, client_id, components): - path_type = rdf_objects.PathInfo.PathType.NTFS + path_type = objects_pb2.PathInfo.PathType.NTFS return cls(client_id=client_id, path_type=path_type, components=components) @classmethod def Registry(cls, client_id, components): - path_type = rdf_objects.PathInfo.PathType.REGISTRY + path_type = objects_pb2.PathInfo.PathType.REGISTRY return cls(client_id=client_id, path_type=path_type, components=components) @classmethod def Temp(cls, client_id, components): - path_type = rdf_objects.PathInfo.PathType.TEMP + path_type = objects_pb2.PathInfo.PathType.TEMP return cls(client_id=client_id, path_type=path_type, components=components) @classmethod def FromPathSpec(cls, client_id, path_spec): - path_info = rdf_objects.PathInfo.FromPathSpec(path_spec) + path_info = mig_objects.ToProtoPathInfo( + rdf_objects.PathInfo.FromPathSpec(path_spec) + ) return cls.FromPathInfo(client_id, path_info) @classmethod @@ -512,7 +542,8 @@ def FromPathInfo(cls, client_id, path_info): return cls( client_id=client_id, path_type=path_info.path_type, - components=tuple(path_info.components)) + components=tuple(path_info.components), + ) @property def client_id(self): @@ -555,14 +586,18 @@ def Path(self): def __repr__(self): return "<%s client_id=%r path_type=%r components=%r>" % ( - self.__class__.__name__, self.client_id, self.path_type, - self.components) + self.__class__.__name__, + self.client_id, + self.path_type, + self.components, + ) class Database(metaclass=abc.ABCMeta): """The GRR relational database abstraction.""" - unchanged = "__unchanged__" + UNCHANGED_TYPE = Literal["__unchanged__"] + UNCHANGED: UNCHANGED_TYPE = "__unchanged__" @abc.abstractmethod def Now(self) -> rdfvalue.RDFDatetime: @@ -625,7 +660,6 @@ def DeleteArtifact(self, name: str) -> None: def MultiWriteClientMetadata( self, client_ids: Collection[str], - certificate: Optional[rdf_crypto.RDFX509Cert] = None, first_seen: Optional[rdfvalue.RDFDatetime] = None, last_ping: Optional[rdfvalue.RDFDatetime] = None, last_clock: Optional[rdfvalue.RDFDatetime] = None, @@ -641,8 +675,6 @@ def MultiWriteClientMetadata( Args: client_ids: A collection of GRR client id strings, e.g. ["C.ea3b2b71840d6fa7", "C.ea3b2b71840d6fa8"] - certificate: If set, should be an rdfvalues.crypto.RDFX509 protocol - buffer. Normally only set during initial client record creation. first_seen: An rdfvalue.Datetime, indicating the first time the client contacted the server. last_ping: An rdfvalue.Datetime, indicating the last time the client @@ -659,7 +691,6 @@ def MultiWriteClientMetadata( def WriteClientMetadata( self, client_id: str, - certificate: Optional[rdf_crypto.RDFX509Cert] = None, first_seen: Optional[rdfvalue.RDFDatetime] = None, last_ping: Optional[rdfvalue.RDFDatetime] = None, last_clock: Optional[rdfvalue.RDFDatetime] = None, @@ -674,8 +705,6 @@ def WriteClientMetadata( Args: client_id: A GRR client id string, e.g. "C.ea3b2b71840d6fa7". - certificate: If set, should be an rdfvalues.crypto.RDFX509 protocol - buffer. 
Normally only set during initial client record creation. first_seen: An rdfvalue.Datetime, indicating the first time the client contacted the server. last_ping: An rdfvalue.Datetime, indicating the last time the client @@ -690,7 +719,6 @@ def WriteClientMetadata( """ self.MultiWriteClientMetadata( client_ids=[client_id], - certificate=certificate, first_seen=first_seen, last_ping=last_ping, last_clock=last_clock, @@ -839,9 +867,11 @@ def ReadClientFullInfo( except KeyError: raise UnknownClientError(client_id) - def ReadAllClientIDs(self, - min_last_ping=None, - batch_size=CLIENT_IDS_BATCH_SIZE): + def ReadAllClientIDs( + self, + min_last_ping=None, + batch_size=CLIENT_IDS_BATCH_SIZE, + ): """Yields lists of client ids for all clients in the database. Args: @@ -855,14 +885,17 @@ def ReadAllClientIDs(self, """ for results in self.ReadClientLastPings( - min_last_ping=min_last_ping, batch_size=batch_size): + min_last_ping=min_last_ping, batch_size=batch_size + ): yield list(results.keys()) @abc.abstractmethod - def ReadClientLastPings(self, - min_last_ping=None, - max_last_ping=None, - batch_size=CLIENT_IDS_BATCH_SIZE): + def ReadClientLastPings( + self, + min_last_ping=None, + max_last_ping=None, + batch_size=CLIENT_IDS_BATCH_SIZE, + ): """Yields dicts of last-ping timestamps for clients in the DB. Args: @@ -1355,10 +1388,10 @@ def IterateAllClientsFullInfo( def ReadPathInfo( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> rdf_objects.PathInfo: + ) -> objects_pb2.PathInfo: """Retrieves a path info record for a given path. The `timestamp` parameter specifies for what moment in time the path @@ -1376,16 +1409,16 @@ def ReadPathInfo( If none is provided, the latest known path information is returned. Returns: - An `rdf_objects.PathInfo` instance. + A PathInfo instance. """ @abc.abstractmethod def ReadPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components_list: Collection[Sequence[str]], - ) -> dict[Sequence[str], rdf_objects.PathInfo]: + ) -> dict[tuple[str, ...], Optional[objects_pb2.PathInfo]]: """Retrieves path info records for given paths. Args: @@ -1395,16 +1428,16 @@ def ReadPathInfos( paths to retrieve path information for. Returns: - A dictionary mapping path components to `rdf_objects.PathInfo` instances. + A dictionary mapping path components to PathInfo instances. """ def ListChildPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> Sequence[rdf_objects.PathInfo]: + ) -> Sequence[objects_pb2.PathInfo]: """Lists path info records that correspond to children of given path. Args: @@ -1416,20 +1449,21 @@ def ListChildPathInfos( timestamp. Returns: - A list of `rdf_objects.PathInfo` instances sorted by path components. + A list of PathInfo instances sorted by path components. 
""" return self.ListDescendantPathInfos( - client_id, path_type, components, max_depth=1, timestamp=timestamp) + client_id, path_type, components, max_depth=1, timestamp=timestamp + ) @abc.abstractmethod def ListDescendantPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, max_depth: Optional[int] = None, - ) -> Sequence[rdf_objects.PathInfo]: + ) -> Sequence[objects_pb2.PathInfo]: """Lists path info records that correspond to descendants of given path. Args: @@ -1442,14 +1476,14 @@ def ListDescendantPathInfos( unlimited. Returns: - A list of `rdf_objects.PathInfo` instances sorted by path components. + A list of objects_pb2.PathInfo instances sorted by path components. """ @abc.abstractmethod def WritePathInfos( self, client_id: str, - path_infos: Iterable[rdf_objects.PathInfo], + path_infos: Iterable[objects_pb2.PathInfo], ) -> None: """Writes a collection of path_info records for a client. @@ -1465,10 +1499,10 @@ def WritePathInfos( def ReadPathInfosHistories( self, client_id: Text, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components_list: Iterable[Sequence[Text]], cutoff: Optional[rdfvalue.RDFDatetime] = None, - ) -> Dict[Sequence[Text], Sequence[rdf_objects.PathInfo]]: + ) -> Dict[tuple[str, ...], Sequence[objects_pb2.PathInfo]]: """Reads a collection of hash and stat entries for given paths. Args: @@ -1480,17 +1514,17 @@ def ReadPathInfosHistories( collected. Returns: - A dictionary mapping path components to lists of `rdf_objects.PathInfo` + A dictionary mapping path components to lists of PathInfo ordered by timestamp in ascending order. """ def ReadPathInfoHistory( self, client_id: Text, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[Text], cutoff: Optional[rdfvalue.RDFDatetime] = None, - ) -> Sequence[rdf_objects.PathInfo]: + ) -> Sequence[objects_pb2.PathInfo]: """Reads a collection of hash and stat entry for given path. Args: @@ -1502,13 +1536,14 @@ def ReadPathInfoHistory( collected. Returns: - A list of `rdf_objects.PathInfo` ordered by timestamp in ascending order. + A list of PathInfo ordered by timestamp in ascending order. """ histories = self.ReadPathInfosHistories( client_id=client_id, path_type=path_type, components_list=[components], - cutoff=cutoff) + cutoff=cutoff, + ) return histories[components] @@ -1517,7 +1552,7 @@ def ReadLatestPathInfosWithHashBlobReferences( self, client_paths: Collection[ClientPath], max_timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> Dict[ClientPath, Optional[rdf_objects.PathInfo]]: + ) -> Dict[ClientPath, Optional[objects_pb2.PathInfo]]: """Returns PathInfos that have corresponding HashBlobReferences. Args: @@ -1593,8 +1628,8 @@ def ReadAPIAuditEntries( username: Optional[Text] = None, router_method_names: Optional[List[Text]] = None, min_timestamp: Optional[rdfvalue.RDFDatetime] = None, - max_timestamp: Optional[rdfvalue.RDFDatetime] = None - ) -> List[rdf_objects.APIAuditEntry]: + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> List[objects_pb2.APIAuditEntry]: """Returns audit entries stored in the database. The event log is sorted according to their timestamp (with the oldest @@ -1607,14 +1642,14 @@ def ReadAPIAuditEntries( max_timestamp: maximum rdfvalue.RDFDateTime (inclusive) Returns: - List of `rdfvalues.objects.APIAuditEntry` instances. 
+ List of `APIAuditEntry` instances. """ @abc.abstractmethod def CountAPIAuditEntriesByUserAndDay( self, min_timestamp: Optional[rdfvalue.RDFDatetime] = None, - max_timestamp: Optional[rdfvalue.RDFDatetime] = None + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, ) -> Dict[Tuple[Text, rdfvalue.RDFDatetime], int]: """Returns audit entry counts grouped by user and calendar day. @@ -1634,45 +1669,56 @@ def CountAPIAuditEntriesByUserAndDay( """ @abc.abstractmethod - def WriteAPIAuditEntry(self, entry): + def WriteAPIAuditEntry(self, entry: objects_pb2.APIAuditEntry) -> None: """Writes an audit entry to the database. Args: - entry: An `audit.APIAuditEntry` instance. + entry: An `APIAuditEntry` instance. """ @abc.abstractmethod - def WriteMessageHandlerRequests(self, requests): + def WriteMessageHandlerRequests( + self, requests: Iterable[objects_pb2.MessageHandlerRequest] + ) -> None: """Writes a list of message handler requests to the database. Args: - requests: List of objects.MessageHandlerRequest. + requests: List of MessageHandlerRequest. """ @abc.abstractmethod - def ReadMessageHandlerRequests(self): + def ReadMessageHandlerRequests( + self, + ) -> Sequence[objects_pb2.MessageHandlerRequest]: """Reads all message handler requests from the database. Returns: - A list of objects.MessageHandlerRequest, sorted by timestamp, + A list of MessageHandlerRequest, sorted by timestamp, newest first. """ @abc.abstractmethod - def DeleteMessageHandlerRequests(self, requests): + def DeleteMessageHandlerRequests( + self, requests: Iterable[objects_pb2.MessageHandlerRequest] + ) -> None: """Deletes a list of message handler requests from the database. Args: - requests: List of objects.MessageHandlerRequest. + requests: List of MessageHandlerRequest. """ @abc.abstractmethod - def RegisterMessageHandler(self, handler, lease_time, limit=1000): + def RegisterMessageHandler( + self, + handler: Callable[[Sequence[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: """Registers a message handler to receive batches of messages. Args: handler: Method, which will be called repeatedly with lists of leased - objects.MessageHandlerRequest. Required. + MessageHandlerRequest. Required. lease_time: rdfvalue.Duration indicating how long the lease should be valid. Required. limit: Limit for the number of leased requests to give one execution of @@ -1680,7 +1726,9 @@ def RegisterMessageHandler(self, handler, lease_time, limit=1000): """ @abc.abstractmethod - def UnregisterMessageHandler(self, timeout=None): + def UnregisterMessageHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: """Unregisters any registered message handler. Args: @@ -1924,11 +1972,15 @@ def ReadHashBlobReferences( CLIENT_MESSAGES_TTL = 5 @abc.abstractmethod - def WriteFlowObject(self, flow_obj, allow_update=True): + def WriteFlowObject( + self, + flow_obj: flows_pb2.Flow, + allow_update: bool = True, + ) -> None: """Writes a flow object to the database. Args: - flow_obj: An rdf_flow_objects.Flow object to write. + flow_obj: A Flow object to write. allow_update: If False, raises AlreadyExistsError if the flow already exists in the database. If True, the flow will be updated. @@ -1938,7 +1990,7 @@ def WriteFlowObject(self, flow_obj, allow_update=True): """ @abc.abstractmethod - def ReadFlowObject(self, client_id, flow_id): + def ReadFlowObject(self, client_id: str, flow_id: str) -> flows_pb2.Flow: """Reads a flow object from the database. 
Args: @@ -1946,7 +1998,7 @@ def ReadFlowObject(self, client_id: str, flow_id: str) -> flows_pb2.Flow: flow_id: The id of the flow to read. Returns: - An rdf_flow_objects.Flow object. + A Flow object. """ @abc.abstractmethod @@ -1958,7 +2010,7 @@ def ReadAllFlowObjects( max_create_time: Optional[rdfvalue.RDFDatetime] = None, include_child_flows: bool = True, not_created_by: Optional[Iterable[str]] = None, - ) -> List[rdf_flow_objects.Flow]: + ) -> List[flows_pb2.Flow]: """Returns all flow objects. Args: @@ -1971,10 +2023,12 @@ def ReadAllFlowObjects( not_created_by: exclude flows created by any of the users in this list. Returns: - A list of rdf_flow_objects.Flow objects. + A list of Flow objects. """ - def ReadChildFlowObjects(self, client_id, flow_id): + def ReadChildFlowObjects( + self, client_id: str, flow_id: str + ) -> List[flows_pb2.Flow]: """Reads flow objects that were started by a given flow from the database. Args: @@ -1985,10 +2039,16 @@ def ReadChildFlowObjects(self, client_id, flow_id): A list of rdf_flow_objects.Flow objects. """ return self.ReadAllFlowObjects( - client_id=client_id, parent_flow_id=flow_id, include_child_flows=True) + client_id=client_id, parent_flow_id=flow_id, include_child_flows=True + ) @abc.abstractmethod - def LeaseFlowForProcessing(self, client_id, flow_id, processing_time): + def LeaseFlowForProcessing( + self, + client_id: str, + flow_id: str, + processing_time: rdfvalue.Duration, + ) -> flows_pb2.Flow: """Marks a flow as being processed on this worker and returns it. Args: @@ -2003,11 +2063,11 @@ def LeaseFlowForProcessing(self, client_id, flow_id, processing_time): completed. Returns: - And rdf_flow_objects.Flow object. + A Flow object. """ @abc.abstractmethod - def ReleaseProcessedFlow(self, flow_obj): + def ReleaseProcessedFlow(self, flow_obj: flows_pb2.Flow) -> bool: """Releases a flow that the worker was processing to the database. This method will check if there are currently more requests ready for @@ -2015,7 +2075,7 @@ def ReleaseProcessedFlow(self, flow_obj): the method will return false. Args: - flow_obj: The rdf_flow_objects.Flow object to return. + flow_obj: The Flow object to return to the database. Returns: A boolean indicating if it was possible to return the flow to the @@ -2024,15 +2084,25 @@ def ReleaseProcessedFlow(self, flow_obj): """ @abc.abstractmethod - def UpdateFlow(self, - client_id, - flow_id, - flow_obj=unchanged, - flow_state=unchanged, - client_crash_info=unchanged, - processing_on=unchanged, - processing_since=unchanged, - processing_deadline=unchanged): + def UpdateFlow( + self, + client_id: str, + flow_id: str, + flow_obj: Union[flows_pb2.Flow, UNCHANGED_TYPE] = UNCHANGED, + flow_state: Union[ + flows_pb2.Flow.FlowState.ValueType, UNCHANGED_TYPE + ] = UNCHANGED, + client_crash_info: Union[ + jobs_pb2.ClientCrash, UNCHANGED_TYPE + ] = UNCHANGED, + processing_on: Union[str, UNCHANGED_TYPE] = UNCHANGED, + processing_since: Optional[ + Union[rdfvalue.RDFDatetime, UNCHANGED_TYPE] + ] = UNCHANGED, + processing_deadline: Optional[ + Union[rdfvalue.RDFDatetime, UNCHANGED_TYPE] + ] = UNCHANGED, + ) -> None: """Updates flow objects in the database. Args: @@ -2048,16 +2118,23 @@ def UpdateFlow(self, """ @abc.abstractmethod - def WriteFlowRequests(self, requests): + def WriteFlowRequests( + self, + requests: Collection[flows_pb2.FlowRequest], + ) -> None: """Writes a list of flow requests to the database. Args: - requests: List of rdf_flow_objects.FlowRequest objects. + requests: List of FlowRequest objects. 
""" @abc.abstractmethod - def UpdateIncrementalFlowRequests(self, client_id: str, flow_id: str, - next_response_id_updates: Dict[int, int]): + def UpdateIncrementalFlowRequests( + self, + client_id: str, + flow_id: str, + next_response_id_updates: Mapping[int, int], + ) -> None: """Updates next response ids of given requests. Used to update incremental requests (requests with a callback_state @@ -2071,19 +2148,30 @@ def UpdateIncrementalFlowRequests(self, client_id: str, flow_id: str, """ @abc.abstractmethod - def DeleteFlowRequests(self, requests): + def DeleteFlowRequests( + self, + requests: Sequence[flows_pb2.FlowRequest], + ) -> None: """Deletes a list of flow requests from the database. Note: This also deletes all corresponding responses. Args: - requests: List of rdf_flow_objects.FlowRequest objects. + requests: List of FlowRequest objects. """ @abc.abstractmethod def WriteFlowResponses( - self, responses: Iterable[rdf_flow_objects.FlowMessage]) -> None: - """Writes FlowMessages and updates corresponding requests. + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ) -> None: + """Writes Flow responses and updates corresponding requests. This method not only stores the list of responses given in the database but also updates flow status information at the same time. Specifically, it @@ -2093,11 +2181,30 @@ def WriteFlowResponses( for processing, it also writes a FlowProcessingRequest to notify the worker. Args: - responses: List of rdf_flow_objects.FlowMessage rdfvalues to write. + responses: List of FlowResponses, FlowStatuses or FlowIterators values to + write. """ @abc.abstractmethod - def ReadAllFlowRequestsAndResponses(self, client_id, flow_id): + def ReadAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> Iterable[ + Tuple[ + flows_pb2.FlowRequest, + Dict[ + int, + Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ] + ]: """Reads all requests and responses for a given flow from the database. Args: @@ -2110,7 +2217,11 @@ def ReadAllFlowRequestsAndResponses(self, client_id, flow_id): """ @abc.abstractmethod - def DeleteAllFlowRequestsAndResponses(self, client_id, flow_id): + def DeleteAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> None: """Deletes all requests and responses for a given flow from the database. Args: @@ -2119,10 +2230,24 @@ def DeleteAllFlowRequestsAndResponses(self, client_id, flow_id): """ @abc.abstractmethod - def ReadFlowRequestsReadyForProcessing(self, - client_id, - flow_id, - next_needed_request=None): + def ReadFlowRequestsReadyForProcessing( + self, + client_id: str, + flow_id: str, + next_needed_request: Optional[int] = None, + ) -> Dict[ + int, + Tuple[ + flows_pb2.FlowRequest, + Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ]: """Reads all requests for a flow that can be processed by the worker. There are 2 kinds of requests that are going to be returned by this call: @@ -2149,24 +2274,31 @@ def ReadFlowRequestsReadyForProcessing(self, """ @abc.abstractmethod - def WriteFlowProcessingRequests(self, requests): + def WriteFlowProcessingRequests( + self, + requests: Sequence[flows_pb2.FlowProcessingRequest], + ) -> None: """Writes a list of flow processing requests to the database. Args: - requests: List of rdf_flows.FlowProcessingRequest. + requests: List of FlowProcessingRequest. 
""" @abc.abstractmethod - def ReadFlowProcessingRequests(self): + def ReadFlowProcessingRequests( + self, + ) -> Sequence[flows_pb2.FlowProcessingRequest]: """Reads all flow processing requests from the database. Returns: - A list of rdf_flows.FlowProcessingRequest, sorted by timestamp, + A list of FlowProcessingRequest, sorted by timestamp, newest first. """ @abc.abstractmethod - def AckFlowProcessingRequests(self, requests): + def AckFlowProcessingRequests( + self, requests: Iterable[flows_pb2.FlowProcessingRequest] + ) -> None: """Acknowledges and deletes flow processing requests. Args: @@ -2174,11 +2306,13 @@ def AckFlowProcessingRequests(self, requests): """ @abc.abstractmethod - def DeleteAllFlowProcessingRequests(self): + def DeleteAllFlowProcessingRequests(self) -> None: """Deletes all flow processing requests from the database.""" @abc.abstractmethod - def RegisterFlowProcessingHandler(self, handler): + def RegisterFlowProcessingHandler( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ) -> None: """Registers a handler to receive flow processing messages. Args: @@ -2187,7 +2321,9 @@ def RegisterFlowProcessingHandler(self, handler): """ @abc.abstractmethod - def UnregisterFlowProcessingHandler(self, timeout=None): + def UnregisterFlowProcessingHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: """Unregisters any registered flow processing handler. Args: @@ -2332,11 +2468,13 @@ def ReadFlowErrors( """ @abc.abstractmethod - def CountFlowErrors(self, - client_id: str, - flow_id: str, - with_tag: Optional[str] = None, - with_type: Optional[str] = None) -> int: + def CountFlowErrors( + self, + client_id: str, + flow_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: """Counts flow errors of a given flow using given query options. If both with_tag and with_type arguments are provided, they will be applied @@ -2548,15 +2686,17 @@ def WriteHuntObject(self, hunt_obj: hunts_pb2.Hunt) -> None: @abc.abstractmethod def UpdateHuntObject( self, - hunt_id, - duration=None, - client_rate=None, - client_limit=None, - hunt_state=None, - hunt_state_reason=None, - hunt_state_comment=None, - start_time=None, - num_clients_at_start_time=None, + hunt_id: str, + duration: Optional[rdfvalue.Duration] = None, + client_rate: Optional[float] = None, + client_limit: Optional[int] = None, + hunt_state: Optional[hunts_pb2.Hunt.HuntState.ValueType] = None, + hunt_state_reason: Optional[ + hunts_pb2.Hunt.HuntStateReason.ValueType + ] = None, + hunt_state_comment: Optional[str] = None, + start_time: Optional[rdfvalue.RDFDatetime] = None, + num_clients_at_start_time: Optional[int] = None, ): """Updates the hunt object by applying the update function. @@ -2577,7 +2717,9 @@ def UpdateHuntObject( """ @abc.abstractmethod - def ReadHuntOutputPluginsStates(self, hunt_id): + def ReadHuntOutputPluginsStates( + self, hunt_id: str + ) -> List[output_plugin_pb2.OutputPluginState]: """Reads all hunt output plugins states of a given hunt. Args: @@ -2591,7 +2733,11 @@ def ReadHuntOutputPluginsStates(self, hunt_id): """ @abc.abstractmethod - def WriteHuntOutputPluginsStates(self, hunt_id, states): + def WriteHuntOutputPluginsStates( + self, + hunt_id: str, + states: Collection[output_plugin_pb2.OutputPluginState], + ) -> None: """Writes hunt output plugin states for a given hunt. 
Args: @@ -2604,7 +2750,15 @@ def WriteHuntOutputPluginsStates(self, hunt_id, states): pass @abc.abstractmethod - def UpdateHuntOutputPluginState(self, hunt_id, state_index, update_fn): + def UpdateHuntOutputPluginState( + self, + hunt_id: str, + state_index: int, + update_fn: Callable[ + [jobs_pb2.AttributedDict], + jobs_pb2.AttributedDict, + ], + ) -> jobs_pb2.AttributedDict: """Updates hunt output plugin state for a given output plugin. Args: @@ -2627,7 +2781,7 @@ def UpdateHuntOutputPluginState(self, hunt_id, state_index, update_fn): """ @abc.abstractmethod - def DeleteHuntObject(self, hunt_id): + def DeleteHuntObject(self, hunt_id: str) -> None: """Deletes a hunt object with a given id. Args: @@ -2635,7 +2789,7 @@ def DeleteHuntObject(self, hunt_id): """ @abc.abstractmethod - def ReadHuntObject(self, hunt_id): + def ReadHuntObject(self, hunt_id: str) -> hunts_pb2.Hunt: """Reads a hunt object from the database. Args: @@ -2651,15 +2805,17 @@ def ReadHuntObject(self, hunt_id): @abc.abstractmethod def ReadHuntObjects( self, - offset, - count, - with_creator=None, - created_after=None, - with_description_match=None, - created_by=None, - not_created_by=None, - with_states=None, - ): + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[AbstractSet[str]] = None, + not_created_by: Optional[AbstractSet[str]] = None, + with_states: Optional[ + Collection[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, + ) -> List[hunts_pb2.Hunt]: """Reads hunt objects from the database. Args: @@ -2691,15 +2847,17 @@ def ReadHuntObjects( @abc.abstractmethod def ListHuntObjects( self, - offset, - count, - with_creator=None, - created_after=None, - with_description_match=None, - created_by=None, - not_created_by=None, - with_states=None, - ): + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[AbstractSet[str]] = None, + not_created_by: Optional[AbstractSet[str]] = None, + with_states: Optional[ + Collection[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, + ) -> List[hunts_pb2.HuntMetadata]: """Reads metadata for hunt objects from the database. Args: @@ -2765,14 +2923,16 @@ def CountHuntLogEntries(self, hunt_id: str) -> int: """ @abc.abstractmethod - def ReadHuntResults(self, - hunt_id, - offset, - count, - with_tag=None, - with_type=None, - with_substring=None, - with_timestamp=None): + def ReadHuntResults( + self, + hunt_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, + with_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> Iterable[flows_pb2.FlowResult]: """Reads hunt results of a given hunt using given query options. If both with_tag and with_type and/or with_substring arguments are provided, @@ -2800,7 +2960,12 @@ def ReadHuntResults(self, """ @abc.abstractmethod - def CountHuntResults(self, hunt_id, with_tag=None, with_type=None): + def CountHuntResults( + self, + hunt_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: """Counts hunt results of a given hunt using given query options. 
If both with_tag and with_type arguments are provided, they will be applied @@ -2818,7 +2983,7 @@ def CountHuntResults(self, hunt_id, with_tag=None, with_type=None): """ @abc.abstractmethod - def CountHuntResultsByType(self, hunt_id: str) -> Dict[str, int]: + def CountHuntResultsByType(self, hunt_id: str) -> Mapping[str, int]: """Returns counts of items in hunt results grouped by type. Args: @@ -2829,11 +2994,13 @@ def CountHuntResultsByType(self, hunt_id: str) -> Dict[str, int]: """ @abc.abstractmethod - def ReadHuntFlows(self, - hunt_id, - offset, - count, - filter_condition=HuntFlowsCondition.UNSET): + def ReadHuntFlows( + self, + hunt_id: str, + offset: int, + count: int, + filter_condition: HuntFlowsCondition = HuntFlowsCondition.UNSET, + ) -> Sequence[flows_pb2.Flow]: """Reads hunt flows matching given conditins. If more than one condition is specified, all of them have to be fulfilled @@ -2854,7 +3021,11 @@ def ReadHuntFlows(self, """ @abc.abstractmethod - def CountHuntFlows(self, hunt_id, filter_condition=HuntFlowsCondition.UNSET): + def CountHuntFlows( + self, + hunt_id: str, + filter_condition: Optional[HuntFlowsCondition] = HuntFlowsCondition.UNSET, + ) -> int: """Counts hunt flows matching given conditions. If more than one condition is specified, all of them have to be fulfilled @@ -2870,8 +3041,46 @@ def CountHuntFlows(self, hunt_id, filter_condition=HuntFlowsCondition.UNSET): A number of flows matching the specified condition. """ - @abc.abstractmethod - def ReadHuntCounters(self, hunt_id): + def ReadHuntFlowErrors( + self, + hunt_id: str, + offset: int, + count: int, + ) -> Mapping[str, FlowErrorInfo]: + """Returns errors for flows of the given hunt. + + Args: + hunt_id: Identifier of the hunt for which to retrieve errors. + offset: Offset from which we start returning errors. + count: Number of rows + + Returns: + A mapping from client identifiers to information about errors. + """ + results = {} + + for flow_obj in self.ReadHuntFlows( + hunt_id, + offset=offset, + count=count, + filter_condition=HuntFlowsCondition.FAILED_FLOWS_ONLY, + ): + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) + info = FlowErrorInfo( + message=flow_obj.error_message, + time=flow_obj.last_update_time, + ) + if flow_obj.HasField("backtrace"): + info.backtrace = flow_obj.backtrace + + results[flow_obj.client_id] = info + + return results + + def ReadHuntCounters( + self, + hunt_id: str, + ) -> HuntCounters: """Reads hunt counters. Args: @@ -2880,20 +3089,40 @@ def ReadHuntCounters(self, hunt_id): Returns: HuntCounters object. """ + return self.ReadHuntsCounters([hunt_id])[hunt_id] @abc.abstractmethod - def ReadHuntClientResourcesStats(self, hunt_id): + def ReadHuntsCounters( + self, + hunt_ids: Collection[str], + ) -> Mapping[str, HuntCounters]: + """Reads hunt counters for several hunt_ids. + + Args: + hunt_ids: The ids of the hunts to read counters for. + + Returns: + A mapping from hunt_ids to HuntCounters objects. + """ + + @abc.abstractmethod + def ReadHuntClientResourcesStats( + self, hunt_id: str + ) -> jobs_pb2.ClientResourcesStats: """Read hunt client resources stats. Args: hunt_id: The id of the hunt to read counters for. Returns: - rdf_stats.ClientResourcesStats object. + ClientResourcesStats object. """ @abc.abstractmethod - def ReadHuntFlowsStatesAndTimestamps(self, hunt_id): + def ReadHuntFlowsStatesAndTimestamps( + self, + hunt_id: str, + ) -> Sequence[FlowStateAndTimestamps]: """Reads hunt flows states and timestamps. 
Args: @@ -3018,10 +3247,13 @@ def ListScheduledFlows( """Lists all ScheduledFlows for the client and creator.""" @abc.abstractmethod - def StructuredSearchClients(self, expression: rdf_search.SearchExpression, - sort_order: rdf_search.SortOrder, - continuation_token: bytes, - number_of_results: int) -> SearchClientsResult: + def StructuredSearchClients( + self, + expression: rdf_search.SearchExpression, + sort_order: rdf_search.SortOrder, + continuation_token: bytes, + number_of_results: int, + ) -> SearchClientsResult: """Perform a search for clients. Args: @@ -3078,8 +3310,9 @@ def WriteArtifact(self, artifact: artifact_pb2.Artifact) -> None: precondition.AssertType(artifact, artifact_pb2.Artifact) if not artifact.name: raise ValueError("Empty artifact name") - _ValidateStringLength("Artifact names", artifact.name, - MAX_ARTIFACT_NAME_LENGTH) + _ValidateStringLength( + "Artifact names", artifact.name, MAX_ARTIFACT_NAME_LENGTH + ) return self.delegate.WriteArtifact(artifact) @@ -3097,7 +3330,6 @@ def DeleteArtifact(self, name: str) -> None: def MultiWriteClientMetadata( self, client_ids: Collection[str], - certificate: Optional[rdf_crypto.RDFX509Cert] = None, first_seen: Optional[rdfvalue.RDFDatetime] = None, last_ping: Optional[rdfvalue.RDFDatetime] = None, last_clock: Optional[rdfvalue.RDFDatetime] = None, @@ -3106,7 +3338,6 @@ def MultiWriteClientMetadata( fleetspeak_validation_info: Optional[Mapping[str, str]] = None, ) -> None: _ValidateClientIds(client_ids) - precondition.AssertOptionalType(certificate, rdf_crypto.RDFX509Cert) precondition.AssertOptionalType(first_seen, rdfvalue.RDFDatetime) precondition.AssertOptionalType(last_ping, rdfvalue.RDFDatetime) precondition.AssertOptionalType(last_clock, rdfvalue.RDFDatetime) @@ -3118,7 +3349,6 @@ def MultiWriteClientMetadata( return self.delegate.MultiWriteClientMetadata( client_ids=client_ids, - certificate=certificate, first_seen=first_seen, last_ping=last_ping, last_clock=last_clock, @@ -3146,8 +3376,9 @@ def WriteClientSnapshot( snapshot: objects_pb2.ClientSnapshot, ) -> None: precondition.AssertType(snapshot, objects_pb2.ClientSnapshot) - _ValidateStringLength("Platform", snapshot.knowledge_base.os, - _MAX_CLIENT_PLATFORM_LENGTH) + _ValidateStringLength( + "Platform", snapshot.knowledge_base.os, _MAX_CLIENT_PLATFORM_LENGTH + ) return self.delegate.WriteClientSnapshot(snapshot) def MultiReadClientSnapshot( @@ -3164,25 +3395,29 @@ def MultiReadClientFullInfo( ) -> Mapping[str, objects_pb2.ClientFullInfo]: _ValidateClientIds(client_ids) return self.delegate.MultiReadClientFullInfo( - client_ids, min_last_ping=min_last_ping) + client_ids, min_last_ping=min_last_ping + ) - def ReadClientLastPings(self, - min_last_ping=None, - max_last_ping=None, - batch_size=CLIENT_IDS_BATCH_SIZE): + def ReadClientLastPings( + self, + min_last_ping: Optional[rdfvalue.RDFDatetime] = None, + max_last_ping: Optional[rdfvalue.RDFDatetime] = None, + batch_size: int = CLIENT_IDS_BATCH_SIZE, + ) -> Iterator[Mapping[str, Optional[rdfvalue.RDFDatetime]]]: precondition.AssertOptionalType(min_last_ping, rdfvalue.RDFDatetime) precondition.AssertOptionalType(max_last_ping, rdfvalue.RDFDatetime) precondition.AssertType(batch_size, int) if batch_size < 1: raise ValueError( - "batch_size needs to be a positive integer, got {}".format( - batch_size)) + "batch_size needs to be a positive integer, got {}".format(batch_size) + ) return self.delegate.ReadClientLastPings( min_last_ping=min_last_ping, max_last_ping=max_last_ping, - batch_size=batch_size) + 
batch_size=batch_size, + ) def ReadClientSnapshotHistory( self, @@ -3196,7 +3431,8 @@ def ReadClientSnapshotHistory( self._ValidateTimeRange(timerange) return self.delegate.ReadClientSnapshotHistory( - client_id, timerange=timerange) + client_id, timerange=timerange + ) def WriteClientStartupInfo( self, @@ -3285,7 +3521,8 @@ def ListClientsForKeywords( self._ValidateTimestamp(start_time) result = self.delegate.ListClientsForKeywords( - keywords, start_time=start_time) + keywords, start_time=start_time + ) precondition.AssertDictType(result, str, Collection) for value in result.values(): precondition.AssertIterableType(value, str) @@ -3389,7 +3626,8 @@ def WriteGRRUser( ui_mode=ui_mode, canary_mode=canary_mode, user_type=user_type, - email=email) + email=email, + ) def ReadGRRUser(self, username) -> objects_pb2.GRRUser: _ValidateUsername(username) @@ -3449,41 +3687,44 @@ def ReadApprovalRequests( requestor_username, approval_type, subject_id=subject_id, - include_expired=include_expired) + include_expired=include_expired, + ) def GrantApproval(self, requestor_username, approval_id, grantor_username): _ValidateUsername(requestor_username) _ValidateApprovalId(approval_id) _ValidateUsername(grantor_username) - return self.delegate.GrantApproval(requestor_username, approval_id, - grantor_username) + return self.delegate.GrantApproval( + requestor_username, approval_id, grantor_username + ) def ReadPathInfo( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> rdf_objects.PathInfo: + ) -> objects_pb2.PathInfo: precondition.ValidateClientId(client_id) - _ValidateEnumType(path_type, rdf_objects.PathInfo.PathType) + _ValidateProtoEnumType(path_type, objects_pb2.PathInfo.PathType) _ValidatePathComponents(components) if timestamp is not None: self._ValidateTimestamp(timestamp) return self.delegate.ReadPathInfo( - client_id, path_type, components, timestamp=timestamp) + client_id, path_type, components, timestamp=timestamp + ) def ReadPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components_list: Collection[Sequence[str]], - ) -> dict[Sequence[str], rdf_objects.PathInfo]: + ) -> dict[tuple[str, ...], Optional[objects_pb2.PathInfo]]: precondition.ValidateClientId(client_id) - _ValidateEnumType(path_type, rdf_objects.PathInfo.PathType) + _ValidateProtoEnumType(path_type, objects_pb2.PathInfo.PathType) precondition.AssertType(components_list, list) for components in components_list: _ValidatePathComponents(components) @@ -3493,28 +3734,29 @@ def ReadPathInfos( def ListChildPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> Sequence[rdf_objects.PathInfo]: + ) -> Sequence[objects_pb2.PathInfo]: precondition.ValidateClientId(client_id) - _ValidateEnumType(path_type, rdf_objects.PathInfo.PathType) + _ValidateProtoEnumType(path_type, objects_pb2.PathInfo.PathType) _ValidatePathComponents(components) precondition.AssertOptionalType(timestamp, rdfvalue.RDFDatetime) return self.delegate.ListChildPathInfos( - client_id, path_type, components, timestamp=timestamp) + client_id, path_type, components, timestamp=timestamp + ) def ListDescendantPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: 
objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, max_depth: Optional[int] = None, - ) -> Sequence[rdf_objects.PathInfo]: + ) -> Sequence[objects_pb2.PathInfo]: precondition.ValidateClientId(client_id) - _ValidateEnumType(path_type, rdf_objects.PathInfo.PathType) + _ValidateProtoEnumType(path_type, objects_pb2.PathInfo.PathType) _ValidatePathComponents(components) precondition.AssertOptionalType(timestamp, rdfvalue.RDFDatetime) precondition.AssertOptionalType(max_depth, int) @@ -3524,12 +3766,13 @@ def ListDescendantPathInfos( path_type, components, timestamp=timestamp, - max_depth=max_depth) + max_depth=max_depth, + ) def WritePathInfos( self, client_id: str, - path_infos: Iterable[rdf_objects.PathInfo], + path_infos: Iterable[objects_pb2.PathInfo], ) -> None: precondition.ValidateClientId(client_id) _ValidatePathInfos(path_infos) @@ -3560,17 +3803,18 @@ def ReadUserNotifications( _ValidateNotificationState(state) return self.delegate.ReadUserNotifications( - username, state=state, timerange=timerange) + username, state=state, timerange=timerange + ) def ReadPathInfosHistories( self, client_id: Text, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components_list: Iterable[Sequence[Text]], cutoff: Optional[rdfvalue.RDFDatetime] = None, - ) -> Dict[Sequence[Text], Sequence[rdf_objects.PathInfo]]: + ) -> Dict[tuple[str, ...], Sequence[objects_pb2.PathInfo]]: precondition.ValidateClientId(client_id) - _ValidateEnumType(path_type, rdf_objects.PathInfo.PathType) + _ValidateProtoEnumType(path_type, objects_pb2.PathInfo.PathType) precondition.AssertType(components_list, list) for components in components_list: _ValidatePathComponents(components) @@ -3580,17 +3824,19 @@ def ReadPathInfosHistories( client_id=client_id, path_type=path_type, components_list=components_list, - cutoff=cutoff) + cutoff=cutoff, + ) def ReadLatestPathInfosWithHashBlobReferences( self, client_paths: Collection[ClientPath], max_timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> Dict[ClientPath, Optional[rdf_objects.PathInfo]]: + ) -> Dict[ClientPath, Optional[objects_pb2.PathInfo]]: precondition.AssertIterableType(client_paths, ClientPath) precondition.AssertOptionalType(max_timestamp, rdfvalue.RDFDatetime) return self.delegate.ReadLatestPathInfosWithHashBlobReferences( - client_paths, max_timestamp=max_timestamp) + client_paths, max_timestamp=max_timestamp + ) def UpdateUserNotifications( self, @@ -3602,56 +3848,73 @@ def UpdateUserNotifications( _ValidateNotificationState(state) return self.delegate.UpdateUserNotifications( - username, timestamps, state=state) + username, timestamps, state=state + ) def ReadAPIAuditEntries( self, username: Optional[Text] = None, router_method_names: Optional[List[Text]] = None, min_timestamp: Optional[rdfvalue.RDFDatetime] = None, - max_timestamp: Optional[rdfvalue.RDFDatetime] = None - ) -> List[rdf_objects.APIAuditEntry]: + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> List[objects_pb2.APIAuditEntry]: return self.delegate.ReadAPIAuditEntries( username=username, router_method_names=router_method_names, min_timestamp=min_timestamp, - max_timestamp=max_timestamp) + max_timestamp=max_timestamp, + ) def CountAPIAuditEntriesByUserAndDay( self, min_timestamp: Optional[rdfvalue.RDFDatetime] = None, - max_timestamp: Optional[rdfvalue.RDFDatetime] = None + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, ) -> Dict[Tuple[Text, rdfvalue.RDFDatetime], int]: 
precondition.AssertOptionalType(min_timestamp, rdfvalue.RDFDatetime) precondition.AssertOptionalType(max_timestamp, rdfvalue.RDFDatetime) return self.delegate.CountAPIAuditEntriesByUserAndDay( - min_timestamp=min_timestamp, max_timestamp=max_timestamp) + min_timestamp=min_timestamp, max_timestamp=max_timestamp + ) - def WriteAPIAuditEntry(self, entry): - precondition.AssertType(entry, rdf_objects.APIAuditEntry) + def WriteAPIAuditEntry(self, entry: objects_pb2.APIAuditEntry) -> None: + precondition.AssertType(entry, objects_pb2.APIAuditEntry) return self.delegate.WriteAPIAuditEntry(entry) - def WriteMessageHandlerRequests(self, requests): - precondition.AssertIterableType(requests, rdf_objects.MessageHandlerRequest) + def WriteMessageHandlerRequests( + self, requests: Iterable[objects_pb2.MessageHandlerRequest] + ) -> None: + precondition.AssertIterableType(requests, objects_pb2.MessageHandlerRequest) for request in requests: _ValidateMessageHandlerName(request.handler_name) return self.delegate.WriteMessageHandlerRequests(requests) - def DeleteMessageHandlerRequests(self, requests): + def DeleteMessageHandlerRequests( + self, requests: Iterable[objects_pb2.MessageHandlerRequest] + ) -> None: return self.delegate.DeleteMessageHandlerRequests(requests) - def ReadMessageHandlerRequests(self): + def ReadMessageHandlerRequests( + self, + ) -> Sequence[objects_pb2.MessageHandlerRequest]: return self.delegate.ReadMessageHandlerRequests() - def RegisterMessageHandler(self, handler, lease_time, limit=1000): + def RegisterMessageHandler( + self, + handler: Callable[[Sequence[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: if handler is None: raise ValueError("handler must be provided") _ValidateDuration(lease_time) return self.delegate.RegisterMessageHandler( - handler, lease_time, limit=limit) + handler, lease_time, limit=limit + ) - def UnregisterMessageHandler(self, timeout=None): + def UnregisterMessageHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: return self.delegate.UnregisterMessageHandler(timeout=timeout) def WriteCronJob(self, cronjob: flows_pb2.CronJob): @@ -3697,15 +3960,15 @@ def UpdateCronJob( forced_run_requested: Union[bool, Literal[UNCHANGED]] = UNCHANGED, ) -> None: _ValidateCronJobId(cronjob_id) - if current_run_id is not None and current_run_id != Database.unchanged: + if current_run_id is not None and current_run_id != Database.UNCHANGED: _ValidateCronJobRunId(current_run_id) - if last_run_time is not None and last_run_time != Database.unchanged: + if last_run_time is not None and last_run_time != Database.UNCHANGED: precondition.AssertType(last_run_time, rdfvalue.RDFDatetime) - if state is not None and state != Database.unchanged: + if state is not None and state != Database.UNCHANGED: precondition.AssertType(state, jobs_pb2.AttributedDict) if ( forced_run_requested is not None - and forced_run_requested != Database.unchanged + and forced_run_requested != Database.UNCHANGED ): precondition.AssertType(forced_run_requested, bool) @@ -3715,7 +3978,8 @@ def UpdateCronJob( last_run_time=last_run_time, current_run_id=current_run_id, state=state, - forced_run_requested=forced_run_requested) + forced_run_requested=forced_run_requested, + ) def LeaseCronJobs( self, @@ -3727,7 +3991,8 @@ def LeaseCronJobs( _ValidateCronJobId(cronjob_id) _ValidateDuration(lease_time) return self.delegate.LeaseCronJobs( - cronjob_ids=cronjob_ids, lease_time=lease_time) + cronjob_ids=cronjob_ids, lease_time=lease_time 
+ ) def ReturnLeasedCronJobs(self, jobs: Sequence[flows_pb2.CronJob]) -> None: for job in jobs: @@ -3774,12 +4039,16 @@ def ReadHashBlobReferences( precondition.AssertIterableType(hashes, rdf_objects.SHA256HashID) return self.delegate.ReadHashBlobReferences(hashes) - def WriteFlowObject(self, flow_obj, allow_update=True): - precondition.AssertType(flow_obj, rdf_flow_objects.Flow) + def WriteFlowObject( + self, + flow_obj: flows_pb2.Flow, + allow_update: bool = True, + ) -> None: + precondition.AssertType(flow_obj, flows_pb2.Flow) precondition.AssertType(allow_update, bool) - if flow_obj.HasField("creation_time"): - raise ValueError(f"Creation time set on the flow object: {flow_obj}") + if flow_obj.HasField("create_time"): + raise ValueError(f"Create time set on the flow object: {flow_obj}") return self.delegate.WriteFlowObject(flow_obj, allow_update=allow_update) @@ -3796,15 +4065,16 @@ def ReadAllFlowObjects( max_create_time: Optional[rdfvalue.RDFDatetime] = None, include_child_flows: bool = True, not_created_by: Optional[Iterable[str]] = None, - ) -> List[rdf_flow_objects.Flow]: + ) -> List[flows_pb2.Flow]: if client_id is not None: precondition.ValidateClientId(client_id) precondition.AssertOptionalType(min_create_time, rdfvalue.RDFDatetime) precondition.AssertOptionalType(max_create_time, rdfvalue.RDFDatetime) if parent_flow_id is not None and not include_child_flows: - raise ValueError(f"Parent flow id specified ('{parent_flow_id}') in the " - f"childless mode") + raise ValueError( + f"Parent flow id specified ('{parent_flow_id}') in the childless mode" + ) if not_created_by is not None: precondition.AssertIterableType(not_created_by, str) @@ -3815,50 +4085,73 @@ def ReadAllFlowObjects( min_create_time=min_create_time, max_create_time=max_create_time, include_child_flows=include_child_flows, - not_created_by=not_created_by) + not_created_by=not_created_by, + ) - def ReadChildFlowObjects(self, client_id, flow_id): + def ReadChildFlowObjects( + self, + client_id: str, + flow_id: str, + ) -> List[flows_pb2.Flow]: precondition.ValidateClientId(client_id) precondition.ValidateFlowId(flow_id) return self.delegate.ReadChildFlowObjects(client_id, flow_id) - def LeaseFlowForProcessing(self, client_id, flow_id, processing_time): + def LeaseFlowForProcessing( + self, + client_id: str, + flow_id: str, + processing_time: rdfvalue.Duration, + ) -> flows_pb2.Flow: precondition.ValidateClientId(client_id) precondition.ValidateFlowId(flow_id) _ValidateDuration(processing_time) - return self.delegate.LeaseFlowForProcessing(client_id, flow_id, - processing_time) + return self.delegate.LeaseFlowForProcessing( + client_id, flow_id, processing_time + ) - def ReleaseProcessedFlow(self, flow_obj): - precondition.AssertType(flow_obj, rdf_flow_objects.Flow) + def ReleaseProcessedFlow(self, flow_obj: flows_pb2.Flow) -> bool: + precondition.AssertType(flow_obj, flows_pb2.Flow) return self.delegate.ReleaseProcessedFlow(flow_obj) - def UpdateFlow(self, - client_id, - flow_id, - flow_obj=Database.unchanged, - flow_state=Database.unchanged, - client_crash_info=Database.unchanged, - processing_on=Database.unchanged, - processing_since=Database.unchanged, - processing_deadline=Database.unchanged): + def UpdateFlow( + self, + client_id: str, + flow_id: str, + flow_obj: Union[ + flows_pb2.Flow, Database.UNCHANGED_TYPE + ] = Database.UNCHANGED, + flow_state: Union[ + flows_pb2.Flow.FlowState.ValueType, Database.UNCHANGED_TYPE + ] = Database.UNCHANGED, + client_crash_info: Union[ + jobs_pb2.ClientCrash, 
Database.UNCHANGED_TYPE + ] = Database.UNCHANGED, + processing_on: Union[str, Database.UNCHANGED_TYPE] = Database.UNCHANGED, + processing_since: Optional[ + Union[rdfvalue.RDFDatetime, Database.UNCHANGED_TYPE] + ] = Database.UNCHANGED, + processing_deadline: Optional[ + Union[rdfvalue.RDFDatetime, Database.UNCHANGED_TYPE] + ] = Database.UNCHANGED, + ) -> None: precondition.ValidateClientId(client_id) precondition.ValidateFlowId(flow_id) - if flow_obj != Database.unchanged: - precondition.AssertType(flow_obj, rdf_flow_objects.Flow) - - if flow_state != Database.unchanged: - raise ConflictingUpdateFlowArgumentsError(client_id, flow_id, - "flow_state") - - if flow_state != Database.unchanged: - _ValidateEnumType(flow_state, rdf_flow_objects.Flow.FlowState) - if client_crash_info != Database.unchanged: - precondition.AssertType(client_crash_info, rdf_client.ClientCrash) - if processing_since != Database.unchanged: + if flow_obj != Database.UNCHANGED: + precondition.AssertType(flow_obj, flows_pb2.Flow) + + if flow_state != Database.UNCHANGED: + raise ConflictingUpdateFlowArgumentsError( + client_id, flow_id, "flow_state" + ) + if flow_state != Database.UNCHANGED: + _ValidateProtoEnumType(flow_state, flows_pb2.Flow.FlowState) + if client_crash_info != Database.UNCHANGED: + precondition.AssertType(client_crash_info, jobs_pb2.ClientCrash) + if processing_since != Database.UNCHANGED: if processing_since is not None: self._ValidateTimestamp(processing_since) - if processing_deadline != Database.unchanged: + if processing_deadline != Database.UNCHANGED: if processing_deadline is not None: self._ValidateTimestamp(processing_deadline) return self.delegate.UpdateFlow( @@ -3869,70 +4162,143 @@ def UpdateFlow(self, client_crash_info=client_crash_info, processing_on=processing_on, processing_since=processing_since, - processing_deadline=processing_deadline) + processing_deadline=processing_deadline, + ) - def WriteFlowRequests(self, requests): - precondition.AssertIterableType(requests, rdf_flow_objects.FlowRequest) + def WriteFlowRequests( + self, + requests: Collection[flows_pb2.FlowRequest], + ) -> None: + precondition.AssertIterableType(requests, flows_pb2.FlowRequest) return self.delegate.WriteFlowRequests(requests) - def UpdateIncrementalFlowRequests(self, client_id: str, flow_id: str, - next_response_id_updates: Dict[int, int]): + def UpdateIncrementalFlowRequests( + self, + client_id: str, + flow_id: str, + next_response_id_updates: Mapping[int, int], + ) -> None: precondition.ValidateClientId(client_id) precondition.ValidateFlowId(flow_id) precondition.AssertDictType(next_response_id_updates, int, int) return self.delegate.UpdateIncrementalFlowRequests( - client_id, flow_id, next_response_id_updates) + client_id, flow_id, next_response_id_updates + ) - def DeleteFlowRequests(self, requests): - precondition.AssertIterableType(requests, rdf_flow_objects.FlowRequest) + def DeleteFlowRequests( + self, + requests: Sequence[flows_pb2.FlowRequest], + ) -> None: + precondition.AssertIterableType(requests, flows_pb2.FlowRequest) return self.delegate.DeleteFlowRequests(requests) def WriteFlowResponses( - self, responses: Iterable[rdf_flow_objects.FlowMessage]) -> None: - precondition.AssertIterableType(responses, rdf_flow_objects.FlowMessage) + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ) -> None: + for r in responses: + precondition.AssertType(r.request_id, int) + precondition.AssertType(r.response_id, int) + 
precondition.AssertType(r.client_id, str) + precondition.AssertType(r.flow_id, str) + return self.delegate.WriteFlowResponses(responses) - def ReadAllFlowRequestsAndResponses(self, client_id, flow_id): + def ReadAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> Iterable[ + Tuple[ + flows_pb2.FlowRequest, + Dict[ + int, + Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ] + ]: precondition.ValidateClientId(client_id) precondition.ValidateFlowId(flow_id) return self.delegate.ReadAllFlowRequestsAndResponses(client_id, flow_id) - def DeleteAllFlowRequestsAndResponses(self, client_id, flow_id): + def DeleteAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> None: precondition.ValidateClientId(client_id) precondition.ValidateFlowId(flow_id) return self.delegate.DeleteAllFlowRequestsAndResponses(client_id, flow_id) - def ReadFlowRequestsReadyForProcessing(self, - client_id, - flow_id, - next_needed_request=None): + def ReadFlowRequestsReadyForProcessing( + self, + client_id: str, + flow_id: str, + next_needed_request: Optional[int] = None, + ) -> Dict[ + int, + Tuple[ + flows_pb2.FlowRequest, + Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ]: precondition.ValidateClientId(client_id) precondition.ValidateFlowId(flow_id) if next_needed_request is None: raise ValueError("next_needed_request must be provided.") return self.delegate.ReadFlowRequestsReadyForProcessing( - client_id, flow_id, next_needed_request=next_needed_request) + client_id, flow_id, next_needed_request=next_needed_request + ) - def WriteFlowProcessingRequests(self, requests): - precondition.AssertIterableType(requests, rdf_flows.FlowProcessingRequest) + def WriteFlowProcessingRequests( + self, + requests: Sequence[flows_pb2.FlowProcessingRequest], + ) -> None: + precondition.AssertIterableType(requests, flows_pb2.FlowProcessingRequest) return self.delegate.WriteFlowProcessingRequests(requests) - def ReadFlowProcessingRequests(self): + def ReadFlowProcessingRequests( + self, + ) -> Sequence[flows_pb2.FlowProcessingRequest]: return self.delegate.ReadFlowProcessingRequests() - def AckFlowProcessingRequests(self, requests): - precondition.AssertIterableType(requests, rdf_flows.FlowProcessingRequest) + def AckFlowProcessingRequests( + self, requests: Iterable[flows_pb2.FlowProcessingRequest] + ) -> None: + precondition.AssertIterableType(requests, flows_pb2.FlowProcessingRequest) return self.delegate.AckFlowProcessingRequests(requests) - def DeleteAllFlowProcessingRequests(self): + def DeleteAllFlowProcessingRequests(self) -> None: return self.delegate.DeleteAllFlowProcessingRequests() - def RegisterFlowProcessingHandler(self, handler): + def RegisterFlowProcessingHandler( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ) -> None: if handler is None: raise ValueError("handler must be provided") return self.delegate.RegisterFlowProcessingHandler(handler) - def UnregisterFlowProcessingHandler(self, timeout=None): + def UnregisterFlowProcessingHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: return self.delegate.UnregisterFlowProcessingHandler(timeout=timeout) def WriteFlowResults(self, results): @@ -3945,14 +4311,16 @@ def WriteFlowResults(self, results): return self.delegate.WriteFlowResults(results) - def ReadFlowResults(self, - client_id, - flow_id, - offset, - count, - with_tag=None, - with_type=None, - 
with_substring=None): + def ReadFlowResults( + self, + client_id, + flow_id, + offset, + count, + with_tag=None, + with_type=None, + with_substring=None, + ): precondition.ValidateClientId(client_id) precondition.ValidateFlowId(flow_id) precondition.AssertOptionalType(with_tag, Text) @@ -3966,7 +4334,8 @@ def ReadFlowResults(self, count, with_tag=with_tag, with_type=with_type, - with_substring=with_substring) + with_substring=with_substring, + ) def CountFlowResults( self, @@ -3981,7 +4350,8 @@ def CountFlowResults( precondition.AssertOptionalType(with_type, Text) return self.delegate.CountFlowResults( - client_id, flow_id, with_tag=with_tag, with_type=with_type) + client_id, flow_id, with_tag=with_tag, with_type=with_type + ) def CountFlowResultsByType( self, @@ -4031,20 +4401,24 @@ def ReadFlowErrors( offset, count, with_tag=with_tag, - with_type=with_type) + with_type=with_type, + ) - def CountFlowErrors(self, - client_id: str, - flow_id: str, - with_tag: Optional[str] = None, - with_type: Optional[str] = None) -> int: + def CountFlowErrors( + self, + client_id: str, + flow_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: precondition.ValidateClientId(client_id) precondition.ValidateFlowId(flow_id) precondition.AssertOptionalType(with_tag, Text) precondition.AssertOptionalType(with_type, Text) return self.delegate.CountFlowErrors( - client_id, flow_id, with_tag=with_tag, with_type=with_type) + client_id, flow_id, with_tag=with_tag, with_type=with_type + ) def WriteFlowLogEntry(self, entry: flows_pb2.FlowLogEntry) -> None: precondition.ValidateClientId(entry.client_id) @@ -4067,7 +4441,8 @@ def ReadFlowLogEntries( precondition.AssertOptionalType(with_substring, str) return self.delegate.ReadFlowLogEntries( - client_id, flow_id, offset, count, with_substring=with_substring) + client_id, flow_id, offset, count, with_substring=with_substring + ) def CountFlowLogEntries(self, client_id: str, flow_id: str) -> int: precondition.ValidateClientId(client_id) @@ -4109,12 +4484,8 @@ def ReadFlowOutputPluginLogEntries( ) return self.delegate.ReadFlowOutputPluginLogEntries( - client_id, - flow_id, - output_plugin_id, - offset, - count, - with_type=with_type) + client_id, flow_id, output_plugin_id, offset, count, with_type=with_type + ) def CountFlowOutputPluginLogEntries( self, @@ -4130,14 +4501,12 @@ def CountFlowOutputPluginLogEntries( _ValidateOutputPluginId(output_plugin_id) return self.delegate.CountFlowOutputPluginLogEntries( - client_id, flow_id, output_plugin_id, with_type=with_type) - - def ReadHuntOutputPluginLogEntries(self, - hunt_id, - output_plugin_id, - offset, - count, - with_type=None): + client_id, flow_id, output_plugin_id, with_type=with_type + ) + + def ReadHuntOutputPluginLogEntries( + self, hunt_id, output_plugin_id, offset, count, with_type=None + ): _ValidateHuntId(hunt_id) _ValidateOutputPluginId(output_plugin_id) if with_type is not None: @@ -4146,12 +4515,12 @@ def ReadHuntOutputPluginLogEntries(self, ) return self.delegate.ReadHuntOutputPluginLogEntries( - hunt_id, output_plugin_id, offset, count, with_type=with_type) + hunt_id, output_plugin_id, offset, count, with_type=with_type + ) - def CountHuntOutputPluginLogEntries(self, - hunt_id, - output_plugin_id, - with_type=None): + def CountHuntOutputPluginLogEntries( + self, hunt_id, output_plugin_id, with_type=None + ): _ValidateHuntId(hunt_id) _ValidateOutputPluginId(output_plugin_id) if with_type is not None: @@ -4160,7 +4529,8 @@ def CountHuntOutputPluginLogEntries(self, ) return 
self.delegate.CountHuntOutputPluginLogEntries( - hunt_id, output_plugin_id, with_type=with_type) + hunt_id, output_plugin_id, with_type=with_type + ) def WriteHuntObject(self, hunt_obj: hunts_pb2.Hunt) -> None: precondition.AssertType(hunt_obj, hunts_pb2.Hunt) @@ -4172,15 +4542,17 @@ def WriteHuntObject(self, hunt_obj: hunts_pb2.Hunt) -> None: def UpdateHuntObject( self, - hunt_id, - duration=None, - client_rate=None, - client_limit=None, - hunt_state=None, - hunt_state_reason=None, - hunt_state_comment=None, - start_time=None, - num_clients_at_start_time=None, + hunt_id: str, + duration: Optional[rdfvalue.Duration] = None, + client_rate: Optional[float] = None, + client_limit: Optional[int] = None, + hunt_state: Optional[hunts_pb2.Hunt.HuntState.ValueType] = None, + hunt_state_reason: Optional[ + hunts_pb2.Hunt.HuntStateReason.ValueType + ] = None, + hunt_state_comment: Optional[str] = None, + start_time: Optional[rdfvalue.RDFDatetime] = None, + num_clients_at_start_time: Optional[int] = None, ): """Updates the hunt object by applying the update function.""" _ValidateHuntId(hunt_id) @@ -4188,9 +4560,10 @@ def UpdateHuntObject( precondition.AssertOptionalType(client_rate, (float, int)) precondition.AssertOptionalType(client_limit, int) if hunt_state is not None: - _ValidateEnumType(hunt_state, rdf_hunt_objects.Hunt.HuntState) + _ValidateProtoEnumType(hunt_state, hunts_pb2.Hunt.HuntState) if hunt_state_reason is not None: - _ValidateEnumType(hunt_state, rdf_hunt_objects.Hunt.HuntStateReason) + _ValidateProtoEnumType(hunt_state, hunts_pb2.Hunt.HuntStateReason) + precondition.AssertOptionalType(hunt_state_comment, str) precondition.AssertOptionalType(start_time, rdfvalue.RDFDatetime) precondition.AssertOptionalType(num_clients_at_start_time, int) @@ -4207,45 +4580,62 @@ def UpdateHuntObject( num_clients_at_start_time=num_clients_at_start_time, ) - def ReadHuntOutputPluginsStates(self, hunt_id): + def ReadHuntOutputPluginsStates( + self, + hunt_id: str, + ) -> List[output_plugin_pb2.OutputPluginState]: _ValidateHuntId(hunt_id) return self.delegate.ReadHuntOutputPluginsStates(hunt_id) - def WriteHuntOutputPluginsStates(self, hunt_id, states): - + def WriteHuntOutputPluginsStates( + self, + hunt_id: str, + states: Collection[output_plugin_pb2.OutputPluginState], + ) -> None: + """Writes a list of output plugin states to the database.""" if not states: return _ValidateHuntId(hunt_id) - precondition.AssertIterableType(states, rdf_flow_runner.OutputPluginState) + precondition.AssertIterableType(states, output_plugin_pb2.OutputPluginState) self.delegate.WriteHuntOutputPluginsStates(hunt_id, states) - def UpdateHuntOutputPluginState(self, hunt_id, state_index, update_fn): + def UpdateHuntOutputPluginState( + self, + hunt_id: str, + state_index: int, + update_fn: Callable[ + [jobs_pb2.AttributedDict], + jobs_pb2.AttributedDict, + ], + ) -> jobs_pb2.AttributedDict: _ValidateHuntId(hunt_id) precondition.AssertType(state_index, int) + return self.delegate.UpdateHuntOutputPluginState( + hunt_id, state_index, update_fn + ) - return self.delegate.UpdateHuntOutputPluginState(hunt_id, state_index, - update_fn) - - def DeleteHuntObject(self, hunt_id): + def DeleteHuntObject(self, hunt_id: str) -> None: _ValidateHuntId(hunt_id) return self.delegate.DeleteHuntObject(hunt_id) - def ReadHuntObject(self, hunt_id): + def ReadHuntObject(self, hunt_id: str) -> hunts_pb2.Hunt: _ValidateHuntId(hunt_id) return self.delegate.ReadHuntObject(hunt_id) def ReadHuntObjects( self, - offset, - count, - with_creator=None, - 
created_after=None, - with_description_match=None, - created_by=None, - not_created_by=None, - with_states=None, - ): + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[AbstractSet[str]] = None, + not_created_by: Optional[AbstractSet[str]] = None, + with_states: Optional[ + Collection[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, + ) -> List[hunts_pb2.Hunt]: precondition.AssertOptionalType(offset, int) precondition.AssertOptionalType(count, int) precondition.AssertOptionalType(with_creator, Text) @@ -4256,7 +4646,8 @@ def ReadHuntObjects( if not_created_by is not None: precondition.AssertIterableType(not_created_by, str) if with_states is not None: - precondition.AssertIterableType(with_states, rdf_structs.EnumNamedValue) + for state in with_states: + _ValidateProtoEnumType(state, hunts_pb2.Hunt.HuntState) return self.delegate.ReadHuntObjects( offset, @@ -4271,14 +4662,16 @@ def ReadHuntObjects( def ListHuntObjects( self, - offset, - count, - with_creator=None, - created_after=None, - with_description_match=None, - created_by=None, - not_created_by=None, - with_states=None, + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[AbstractSet[str]] = None, + not_created_by: Optional[AbstractSet[str]] = None, + with_states: Optional[ + Collection[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, ): precondition.AssertOptionalType(offset, int) precondition.AssertOptionalType(count, int) @@ -4290,7 +4683,8 @@ def ListHuntObjects( if not_created_by is not None: precondition.AssertIterableType(not_created_by, str) if with_states is not None: - precondition.AssertIterableType(with_states, rdf_structs.EnumNamedValue) + for state in with_states: + _ValidateProtoEnumType(state, hunts_pb2.Hunt.HuntState) return self.delegate.ListHuntObjects( offset, @@ -4314,20 +4708,23 @@ def ReadHuntLogEntries( precondition.AssertOptionalType(with_substring, Text) return self.delegate.ReadHuntLogEntries( - hunt_id, offset, count, with_substring=with_substring) + hunt_id, offset, count, with_substring=with_substring + ) def CountHuntLogEntries(self, hunt_id: str) -> int: _ValidateHuntId(hunt_id) return self.delegate.CountHuntLogEntries(hunt_id) - def ReadHuntResults(self, - hunt_id, - offset, - count, - with_tag=None, - with_type=None, - with_substring=None, - with_timestamp=None): + def ReadHuntResults( + self, + hunt_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, + with_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> Iterable[flows_pb2.FlowResult]: _ValidateHuntId(hunt_id) precondition.AssertOptionalType(with_tag, Text) precondition.AssertOptionalType(with_type, Text) @@ -4340,44 +4737,67 @@ def ReadHuntResults(self, with_tag=with_tag, with_type=with_type, with_substring=with_substring, - with_timestamp=with_timestamp) + with_timestamp=with_timestamp, + ) def CountHuntResults(self, hunt_id, with_tag=None, with_type=None): _ValidateHuntId(hunt_id) precondition.AssertOptionalType(with_tag, Text) precondition.AssertOptionalType(with_type, Text) return self.delegate.CountHuntResults( - hunt_id, with_tag=with_tag, with_type=with_type) + hunt_id, with_tag=with_tag, with_type=with_type + ) - def CountHuntResultsByType(self, hunt_id): + def 
CountHuntResultsByType(self, hunt_id: str) -> Mapping[str, int]: _ValidateHuntId(hunt_id) return self.delegate.CountHuntResultsByType(hunt_id) - def ReadHuntFlows(self, - hunt_id, - offset, - count, - filter_condition=HuntFlowsCondition.UNSET): + def ReadHuntFlows( + self, + hunt_id: str, + offset: int, + count: int, + filter_condition: HuntFlowsCondition = HuntFlowsCondition.UNSET, + ) -> Sequence[flows_pb2.Flow]: _ValidateHuntId(hunt_id) _ValidateHuntFlowCondition(filter_condition) return self.delegate.ReadHuntFlows( - hunt_id, offset, count, filter_condition=filter_condition) + hunt_id, offset, count, filter_condition=filter_condition + ) - def CountHuntFlows(self, hunt_id, filter_condition=HuntFlowsCondition.UNSET): + def CountHuntFlows( + self, + hunt_id: str, + filter_condition: Optional[HuntFlowsCondition] = HuntFlowsCondition.UNSET, + ) -> int: _ValidateHuntId(hunt_id) _ValidateHuntFlowCondition(filter_condition) return self.delegate.CountHuntFlows( - hunt_id, filter_condition=filter_condition) + hunt_id, filter_condition=filter_condition + ) - def ReadHuntCounters(self, hunt_id): + def ReadHuntCounters(self, hunt_id: str) -> HuntCounters: _ValidateHuntId(hunt_id) return self.delegate.ReadHuntCounters(hunt_id) - def ReadHuntClientResourcesStats(self, hunt_id): + def ReadHuntsCounters( + self, + hunt_ids: Collection[str], + ) -> Mapping[str, HuntCounters]: + for hunt_id in hunt_ids: + _ValidateHuntId(hunt_id) + return self.delegate.ReadHuntsCounters(hunt_ids) + + def ReadHuntClientResourcesStats( + self, hunt_id: str + ) -> jobs_pb2.ClientResourcesStats: _ValidateHuntId(hunt_id) return self.delegate.ReadHuntClientResourcesStats(hunt_id) - def ReadHuntFlowsStatesAndTimestamps(self, hunt_id): + def ReadHuntFlowsStatesAndTimestamps( + self, + hunt_id: str, + ) -> Sequence[FlowStateAndTimestamps]: _ValidateHuntId(hunt_id) return self.delegate.ReadHuntFlowsStatesAndTimestamps(hunt_id) @@ -4432,13 +4852,15 @@ def WriteScheduledFlow( precondition.ValidateClientId(scheduled_flow.client_id) return self.delegate.WriteScheduledFlow(scheduled_flow) - def DeleteScheduledFlow(self, client_id: str, creator: str, - scheduled_flow_id: str) -> None: + def DeleteScheduledFlow( + self, client_id: str, creator: str, scheduled_flow_id: str + ) -> None: precondition.ValidateClientId(client_id) _ValidateUsername(creator) _ValidateStringId("scheduled_flow_id", scheduled_flow_id) - return self.delegate.DeleteScheduledFlow(client_id, creator, - scheduled_flow_id) + return self.delegate.DeleteScheduledFlow( + client_id, creator, scheduled_flow_id + ) def ListScheduledFlows( self, @@ -4449,9 +4871,12 @@ def ListScheduledFlows( _ValidateUsername(creator) return self.delegate.ListScheduledFlows(client_id, creator) - def StructuredSearchClients(self, expression: rdf_search.SearchExpression, - continuation_token: bytes, - number_of_results: int) -> SearchClientsResult: + def StructuredSearchClients( + self, + expression: rdf_search.SearchExpression, + continuation_token: bytes, + number_of_results: int, + ) -> SearchClientsResult: return self.delegate.StructuredSearchClient(expression, continuation_token) # pytype: disable=attribute-error def WriteBlobEncryptionKeys( @@ -4560,8 +4985,10 @@ def _ValidateApprovalId(approval_id): def _ValidateApprovalType(approval_type): - if (approval_type == - rdf_objects.ApprovalRequest.ApprovalType.APPROVAL_TYPE_NONE): + if ( + approval_type + == rdf_objects.ApprovalRequest.ApprovalType.APPROVAL_TYPE_NONE + ): raise ValueError("Unexpected approval type: %s" % approval_type) @@ 
-4569,7 +4996,9 @@ def _ValidateStringLength(name, string, max_length): if len(string) > max_length: raise StringTooLongError( "{} can have at most {} characters, got {}.".format( - name, max_length, len(string))) + name, max_length, len(string) + ) + ) def _ValidateUsername(username): @@ -4584,29 +5013,34 @@ def _ValidateLabel(label): _ValidateStringLength("Labels", label, MAX_LABEL_LENGTH) -def _ValidatePathInfo(path_info): - precondition.AssertType(path_info, rdf_objects.PathInfo) +def _ValidatePathInfo(path_info: objects_pb2.PathInfo) -> None: + precondition.AssertType(path_info, objects_pb2.PathInfo) if not path_info.path_type: - raise ValueError("Expected path_type to be set, got: %s" % - path_info.path_type) + raise ValueError( + "Expected path_type to be set, got: %s" % path_info.path_type + ) -def _ValidatePathInfos(path_infos): +def _ValidatePathInfos(path_infos: Iterable[objects_pb2.PathInfo]) -> None: """Validates a sequence of path infos.""" - precondition.AssertIterableType(path_infos, rdf_objects.PathInfo) + precondition.AssertIterableType(path_infos, objects_pb2.PathInfo) validated = set() for path_info in path_infos: _ValidatePathInfo(path_info) - path_key = (path_info.path_type, path_info.GetPathID()) + path_key = ( + path_info.path_type, + rdf_objects.PathID.FromComponents(path_info.components), + ) if path_key in validated: message = "Conflicting writes for path: '{path}' ({path_type})".format( - path="/".join(path_info.components), path_type=path_info.path_type) + path="/".join(path_info.components), path_type=path_info.path_type + ) raise ValueError(message) if path_info.HasField("hash_entry"): - if path_info.hash_entry.sha256 is None: + if not path_info.hash_entry.sha256: message = "Path with hash entry without SHA256: {}".format(path_info) raise ValueError(message) @@ -4658,13 +5092,13 @@ def _ValidateSHA256HashID(sha256_hash_id): def _ValidateHuntFlowCondition(value): - if value < 0 or value > HuntFlowsCondition.MaxValue(): - raise ValueError("Invalid hunt flow condition: %r" % value) + precondition.AssertType(value, HuntFlowsCondition) def _ValidateMessageHandlerName(name): - _ValidateStringLength("MessageHandler names", name, - MAX_MESSAGE_HANDLER_NAME_LENGTH) + _ValidateStringLength( + "MessageHandler names", name, MAX_MESSAGE_HANDLER_NAME_LENGTH + ) def _ValidateClientActivityBuckets(buckets): diff --git a/grr/server/grr_response_server/databases/db_blob_keys_test_lib.py b/grr/server/grr_response_server/databases/db_blob_keys_test_lib.py index 6d04f58eb4..24c9c4b72c 100644 --- a/grr/server/grr_response_server/databases/db_blob_keys_test_lib.py +++ b/grr/server/grr_response_server/databases/db_blob_keys_test_lib.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with test cases for blob encryption key methods.""" + import os from grr_response_server.databases import db as abstract_db @@ -8,6 +9,7 @@ class DatabaseTestBlobKeysMixin: """A mixin class for testing blob encryption keys database methods.""" + db: abstract_db.Database # This is test-only module, we don't need docstrings for test methods. 
@@ -50,11 +52,14 @@ def testReadBlobEncryptionKeysMultiple(self): }) results = self.db.ReadBlobEncryptionKeys([blob_id_1, blob_id_2, blob_id_3]) - self.assertEqual(results, { - blob_id_1: "foo", - blob_id_2: "bar", - blob_id_3: "quux", - }) + self.assertEqual( + results, + { + blob_id_1: "foo", + blob_id_2: "bar", + blob_id_3: "quux", + }, + ) def testReadBlobEncryptionKeysOverridden(self): blob_id = blobs.BlobID(os.urandom(32)) diff --git a/grr/server/grr_response_server/databases/db_blob_references_test.py b/grr/server/grr_response_server/databases/db_blob_references_test.py index 1ec22ec3e6..babf33c0d7 100644 --- a/grr/server/grr_response_server/databases/db_blob_references_test.py +++ b/grr/server/grr_response_server/databases/db_blob_references_test.py @@ -45,10 +45,13 @@ def testCorrectlyHandlesRequestWithOneExistingAndOneMissingHash(self): missing_hash_id = rdf_objects.SHA256HashID(b"00000000" * 4) results = self.db.ReadHashBlobReferences([missing_hash_id, hash_id]) - self.assertEqual(results, { - hash_id: [blob_ref], - missing_hash_id: None, - }) + self.assertEqual( + results, + { + hash_id: [blob_ref], + missing_hash_id: None, + }, + ) def testMultipleHashBlobReferencesCanBeWrittenAndReadBack(self): blob_ref_1 = objects_pb2.BlobReference( diff --git a/grr/server/grr_response_server/databases/db_clients_test.py b/grr/server/grr_response_server/databases/db_clients_test.py index cc9bf67eb1..c70bdb81ba 100644 --- a/grr/server/grr_response_server/databases/db_clients_test.py +++ b/grr/server/grr_response_server/databases/db_clients_test.py @@ -3,10 +3,9 @@ from unittest import mock from grr_response_core.lib import rdfvalue -from grr_response_core.lib.rdfvalues import crypto as rdf_crypto -from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_core.lib.rdfvalues import mig_client from grr_response_core.lib.util import collection +from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 from grr_response_proto import knowledge_base_pb2 from grr_response_proto import objects_pb2 @@ -14,50 +13,9 @@ from grr_response_server.databases import db from grr_response_server.databases import db_test_utils from grr_response_server.models import clients -from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr_response_proto.rrg import startup_pb2 as rrg_startup_pb2 -CERT = rdf_crypto.RDFX509Cert(b"""-----BEGIN CERTIFICATE----- -MIIF7zCCA9egAwIBAgIBATANBgkqhkiG9w0BAQUFADA+MQswCQYDVQQGEwJVUzEM -MAoGA1UECBMDQ0FMMQswCQYDVQQHEwJTRjEUMBIGA1UEAxMLR1JSIFRlc3QgQ0Ew -HhcNMTEwNTI3MTIxNTExWhcNMTIwNTI2MTIxNTExWjBCMQswCQYDVQQGEwJVUzEM -MAoGA1UECBMDQ0FMMQswCQYDVQQHEwJTRjEYMBYGA1UEAxMPR1JSIFRlc3QgU2Vy -dmVyMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAwUXBNzWSoEr88dGQ -qZWSwgJ+n/A/QQyNn/ZM57XsqI6IMO6plFmA+DZv2FkTTdniNPmhuL9mjWYA5yg4 -KYMbz5igOiBoF9RBeIm2/v2Sg65VFoyCgJNgl3V34mpoDCHBYTi2A/OfoKeSQISb -UfMHsYhPHdGfhjk8dEuMo7MxjrtfAO3Y4QtjTiE07eNdoRQkFtzF0m9oSaytJ95c -BAe1eQ/2zcvxPvnF5yavR4fwKQtk8o1hc21XVG0JvqJ7da79C27cQQP3E/6EYzpN -pkh9n4berPBHV/oxlB2np4zKgXCQ4zDdiw1uEUY9+iFmVEuvzO2e5NJcfnu74sGb -oX+2a2/ph65sMZ2/NF8lRgetvIrtYUl15yypXmH3VobBYvpfGpab1rLt0J1HoVUh -V5Nsrdav0n8EQ+hln/sHz+G5rNe4ZSJbZ8w8b1TOwTENdzOYKAQH/NN9IrsbXNgE -8RHSHfPwibWnhfKS/fy7GO8qah/u2HPQ5S33gao409zbwS6c4sn0nAQhr5H6pHVD -iMLcBPFQ+w6zIk28hOv3GMa5XQtm8ONb/QhOLTbtB+ZCHKCw3bXASVDt7EwvnM/b -cSYS58wKmUQhH3unizXyihLhxC8ck/KMTkGnuGBC0Pz2d6YgcdL4BxAK6udSjSQQ 
-DB8sWYKJJrmlCnaN2E1eBbPV5PMCAwEAAaOB8zCB8DAJBgNVHRMEAjAAMBEGCWCG -SAGG+EIBAQQEAwIGQDArBglghkgBhvhCAQ0EHhYcVGlueUNBIEdlbmVyYXRlZCBD -ZXJ0aWZpY2F0ZTAdBgNVHQ4EFgQUywgOS64OISRSFNqpMpF83qXKDPIwbgYDVR0j -BGcwZYAUO4+Xefeqvq3W6/eaPxaNv8IHpcuhQqRAMD4xCzAJBgNVBAYTAlVTMQww -CgYDVQQIEwNDQUwxCzAJBgNVBAcTAlNGMRQwEgYDVQQDEwtHUlIgVGVzdCBDQYIJ -AIayxnA7Bp+3MAkGA1UdEgQCMAAwCQYDVR0RBAIwADANBgkqhkiG9w0BAQUFAAOC -AgEAY6z2VZdS83i6N88hVk3Y8qt0xNhP10+tfgsI7auPq2n3PsDNOLPvp2OcUcLI -csMQ/3GTI84uRm0GFnLMAc+A8BQZ14+3kPRju5jWe3KMfP1Ohz5Hm36Uf47tFhgV -VYnyIPwwCE1QPOgbnFt5jR+d3pjhx9TvjfeFKmavxMpxnDD2KWgGZfuE1UqC0DXm -rkimG2Q+dHUFBOMBUKzaklZsr7v4hlc+7XY1n5vRhiuczS9m5mVB05Cg4mrJFcVs -AUsxSuwgMhJqxuNaFw8qMmdkX7ujo5HAtwJqIi91Sdj8xNRqDysd1OagqL3Mx172 -wTJu7ZIAURpw52AXxn3PpK5NS3NSvL/PE6SnpHCtfkxaHl/80W2oq7MjSaHbQt2g -8vYuwLEKYVhgEBzEK0p5AqDyabAn49bw9hfT10NElJ/tYEPCKZZwrARBHnpCxLeC -jJVIIMzPOczWnTDw92ls3l6+l075MOzXGo94GNlxt0/HLCQktl9cuF1APmRkiGUe -EaQA1dggxMyZGyZpYmEbrWCiEjKqfIXXnpyw5pxL5Rvoe4kYrQBvbJ1aaWJ87Pcz -gXJvjIkzp4x/MMAgdBOqJm5tJ4nhCHTbXWuIbYymPLn7hqXhyrDZwqnH7kQKPF2/ -z5KjO8gWio6YOhsDwrketcBcIANMDYws2+TzrLs9ttuHNS0= ------END CERTIFICATE-----""") - - -def _DaysSinceEpoch(days): - return rdfvalue.RDFDatetime( - rdfvalue.Duration.From(days, rdfvalue.DAYS).microseconds) - def _FlattenDicts(dicts): """Merges an iterable of dicts into one dict.""" @@ -132,7 +90,6 @@ def testClientMetadataInitialWrite(self): # Typical initial non-FS write d.WriteClientMetadata( client_id_2, - certificate=CERT, first_seen=rdfvalue.RDFDatetime(100000000), ) @@ -144,7 +101,6 @@ def testClientMetadataInitialWrite(self): m2 = res[client_id_2] self.assertIsInstance(m2, objects_pb2.ClientMetadata) - self.assertEqual(m2.certificate, CERT.SerializeToBytes()) self.assertEqual(m2.first_seen, int(rdfvalue.RDFDatetime(100000000))) def testClientMetadataDefaultValues(self): @@ -170,7 +126,6 @@ def testClientMetadataSkipFields(self): client_id = "C.fc413187fefa1dcf" self.db.WriteClientMetadata( client_id, - certificate=CERT, first_seen=rdfvalue.RDFDatetime(100000000), last_clock=rdfvalue.RDFDatetime(100000001), last_foreman=rdfvalue.RDFDatetime(100000002), @@ -183,7 +138,6 @@ def testClientMetadataSkipFields(self): # Skip fields self.db.WriteClientMetadata( client_id, - certificate=None, first_seen=None, last_clock=None, last_foreman=None, @@ -193,7 +147,6 @@ def testClientMetadataSkipFields(self): ) md = self.db.ReadClientMetadata(client_id) - self.assertEqual(md.certificate, CERT.SerializeToBytes()) self.assertEqual(md.first_seen, int(rdfvalue.RDFDatetime(100000000))) self.assertEqual(md.clock, int(rdfvalue.RDFDatetime(100000001))) self.assertEqual(md.last_foreman_time, int(rdfvalue.RDFDatetime(100000002))) @@ -213,7 +166,6 @@ def testClientMetadataSubsecond(self): client_id = "C.fc413187fefa1dcf" self.db.WriteClientMetadata( client_id, - certificate=CERT, first_seen=rdfvalue.RDFDatetime(100000001), last_clock=rdfvalue.RDFDatetime(100000011), last_foreman=rdfvalue.RDFDatetime(100000021), @@ -301,8 +253,9 @@ def testReadAllClientIDsSome(self): client_ids = list(self.db.ReadAllClientIDs()) self.assertLen(client_ids, 1) - self.assertCountEqual(client_ids[0], - [client_a_id, client_b_id, client_c_id]) + self.assertCountEqual( + client_ids[0], [client_a_id, client_b_id, client_c_id] + ) def testReadAllClientIDsNotEvenlyDivisibleByBatchSize(self): client_a_id = db_test_utils.InitializeClient(self.db) @@ -312,7 +265,8 @@ def testReadAllClientIDsNotEvenlyDivisibleByBatchSize(self): client_ids = list(self.db.ReadAllClientIDs(batch_size=2)) 
self.assertEqual([len(batch) for batch in client_ids], [2, 1]) self.assertCountEqual( - collection.Flatten(client_ids), [client_a_id, client_b_id, client_c_id]) + collection.Flatten(client_ids), [client_a_id, client_b_id, client_c_id] + ) def testReadAllClientIDsEvenlyDivisibleByBatchSize(self): client_a_id = db_test_utils.InitializeClient(self.db) @@ -324,76 +278,113 @@ def testReadAllClientIDsEvenlyDivisibleByBatchSize(self): self.assertEqual([len(batch) for batch in client_ids], [2, 2]) self.assertCountEqual( collection.Flatten(client_ids), - [client_a_id, client_b_id, client_c_id, client_d_id]) + [client_a_id, client_b_id, client_c_id, client_d_id], + ) def testReadAllClientIDsFilterLastPing(self): self.db.WriteClientMetadata("C.0000000000000001") self.db.WriteClientMetadata( "C.0000000000000002", - last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(2)) + last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(2), + ) self.db.WriteClientMetadata( "C.0000000000000003", - last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3)) + last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3), + ) self.db.WriteClientMetadata( "C.0000000000000004", - last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4)) + last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4), + ) client_ids = self.db.ReadAllClientIDs( - min_last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3)) + min_last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3) + ) self.assertCountEqual( collection.Flatten(client_ids), - ["C.0000000000000003", "C.0000000000000004"]) + ["C.0000000000000003", "C.0000000000000004"], + ) def testReadClientLastPings_ResultsDivisibleByBatchSize(self): client_ids = self._WriteClientLastPingData() - (client_id5, client_id6, client_id7, client_id8, client_id9, - client_id10) = client_ids[4:] + ( + client_id5, + client_id6, + client_id7, + client_id8, + client_id9, + client_id10, + ) = client_ids[4:] results = list( self.db.ReadClientLastPings( min_last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3), - batch_size=3)) + batch_size=3, + ) + ) self.assertEqual([len(batch) for batch in results], [3, 3]) self.assertEqual( - _FlattenDicts(results), { + _FlattenDicts(results), + { client_id5: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3), client_id6: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3), client_id7: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4), client_id8: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4), client_id9: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(5), client_id10: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(5), - }) + }, + ) def testReadClientLastPings_ResultsNotDivisibleByBatchSize(self): client_ids = self._WriteClientLastPingData() - (client_id5, client_id6, client_id7, client_id8, client_id9, - client_id10) = client_ids[4:] + ( + client_id5, + client_id6, + client_id7, + client_id8, + client_id9, + client_id10, + ) = client_ids[4:] results = list( self.db.ReadClientLastPings( min_last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3), - batch_size=4)) + batch_size=4, + ) + ) self.assertEqual([len(batch) for batch in results], [4, 2]) self.assertEqual( - _FlattenDicts(results), { + _FlattenDicts(results), + { client_id5: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3), client_id6: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3), client_id7: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4), client_id8: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4), client_id9: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(5), client_id10: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(5), - }) + }, + ) def 
testReadClientLastPings_NoFilter(self): client_ids = self._WriteClientLastPingData() - (client_id1, client_id2, client_id3, client_id4, client_id5, client_id6, - client_id7, client_id8, client_id9, client_id10) = client_ids + ( + client_id1, + client_id2, + client_id3, + client_id4, + client_id5, + client_id6, + client_id7, + client_id8, + client_id9, + client_id10, + ) = client_ids self.assertEqual( - list(self.db.ReadClientLastPings()), [{ + list(self.db.ReadClientLastPings()), + [{ client_id1: None, client_id2: None, client_id3: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(2), @@ -404,7 +395,8 @@ def testReadClientLastPings_NoFilter(self): client_id8: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4), client_id9: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(5), client_id10: rdfvalue.RDFDatetime.FromSecondsSinceEpoch(5), - }]) + }], + ) def testReadClientLastPings_AllFilters(self): client_ids = self._WriteClientLastPingData() @@ -472,28 +464,42 @@ def testReadClientLastPings_MaxPingFilter(self): def _WriteClientLastPingData(self): """Writes test data for ReadClientLastPings() tests.""" client_ids = tuple("C.00000000000000%02d" % i for i in range(1, 11)) - (client_id1, client_id2, client_id3, client_id4, client_id5, client_id6, - client_id7, client_id8, client_id9, client_id10) = client_ids + ( + client_id1, + client_id2, + client_id3, + client_id4, + client_id5, + client_id6, + client_id7, + client_id8, + client_id9, + client_id10, + ) = client_ids self.db.WriteClientMetadata(client_id1) self.db.WriteClientMetadata(client_id2) self.db.WriteClientMetadata( - client_id3, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(2)) + client_id3, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(2) + ) self.db.WriteClientMetadata( client_id4, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(2) ) self.db.WriteClientMetadata( - client_id5, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3)) + client_id5, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3) + ) self.db.WriteClientMetadata( client_id6, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3) ) self.db.WriteClientMetadata( - client_id7, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4)) + client_id7, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4) + ) self.db.WriteClientMetadata( client_id8, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(4) ) self.db.WriteClientMetadata( - client_id9, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(5)) + client_id9, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(5) + ) self.db.WriteClientMetadata( client_id10, last_ping=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(5) ) @@ -698,11 +704,13 @@ def testClientSummary(self): self.assertIsInstance(res[client_id_2], objects_pb2.ClientSnapshot) self.assertIsNotNone(res[client_id_1].timestamp) self.assertIsNotNone(res[client_id_2].timestamp) - self.assertEqual(res[client_id_1].knowledge_base.fqdn, - "test1234.examples.com") + self.assertEqual( + res[client_id_1].knowledge_base.fqdn, "test1234.examples.com" + ) self.assertEqual(res[client_id_1].kernel, "12.4") - self.assertEqual(res[client_id_2].knowledge_base.fqdn, - "test1235.examples.com") + self.assertEqual( + res[client_id_2].knowledge_base.fqdn, "test1235.examples.com" + ) self.assertFalse(res[client_id_3]) def testMultiReadClientSnapshotInfoWithEmptyList(self): @@ -724,14 +732,28 @@ def testClientKeywords(self): client_id_3 = db_test_utils.InitializeClient(self.db) # Typical keywords are usernames and prefixes of hostnames. 
- d.AddClientKeywords(client_id_1, [ - "joe", "machine.test.example1.com", "machine.test.example1", - "machine.test", "machine", "🚀" - ]) - d.AddClientKeywords(client_id_2, [ - "fred", "machine.test.example2.com", "machine.test.example2", - "machine.test", "machine", "🚀🚀" - ]) + d.AddClientKeywords( + client_id_1, + [ + "joe", + "machine.test.example1.com", + "machine.test.example1", + "machine.test", + "machine", + "🚀", + ], + ) + d.AddClientKeywords( + client_id_2, + [ + "fred", + "machine.test.example2.com", + "machine.test.example2", + "machine.test", + "machine", + "🚀🚀", + ], + ) d.AddClientKeywords(client_id_3, ["foo", "bar", "baz"]) res = d.ListClientsForKeywords(["fred", "machine", "missing"]) @@ -742,9 +764,11 @@ def testClientKeywords(self): for kw, client_id in [("🚀", client_id_1), ("🚀🚀", client_id_2)]: res = d.ListClientsForKeywords([kw]) self.assertEqual( - res[kw], [client_id], - "Expected [%s] when reading keyword %s, got %s" % - (client_id, kw, res[kw])) + res[kw], + [client_id], + "Expected [%s] when reading keyword %s, got %s" + % (client_id, kw, res[kw]), + ) def testClientKeywordsTimeRanges(self): d = self.db @@ -754,8 +778,9 @@ def testClientKeywordsTimeRanges(self): change_time = rdfvalue.RDFDatetime.Now() d.AddClientKeywords(client_id, ["hostname2"]) - res = d.ListClientsForKeywords(["hostname1", "hostname2"], - start_time=change_time) + res = d.ListClientsForKeywords( + ["hostname1", "hostname2"], start_time=change_time + ) self.assertEqual(res["hostname1"], []) self.assertEqual(res["hostname2"], [client_id]) @@ -763,12 +788,19 @@ def testRemoveClientKeyword(self): d = self.db client_id = db_test_utils.InitializeClient(self.db) temporary_kw = "investigation42" - d.AddClientKeywords(client_id, [ - "joe", "machine.test.example.com", "machine.test.example", - "machine.test", temporary_kw - ]) + d.AddClientKeywords( + client_id, + [ + "joe", + "machine.test.example.com", + "machine.test.example", + "machine.test", + temporary_kw, + ], + ) self.assertEqual( - d.ListClientsForKeywords([temporary_kw])[temporary_kw], [client_id]) + d.ListClientsForKeywords([temporary_kw])[temporary_kw], [client_id] + ) d.RemoveClientKeyword(client_id, temporary_kw) self.assertEqual(d.ListClientsForKeywords([temporary_kw])[temporary_kw], []) self.assertEqual(d.ListClientsForKeywords(["joe"])["joe"], [client_id]) @@ -1208,8 +1240,9 @@ def testCrashHistory(self): hist = d.ReadClientCrashInfoHistory(client_id) self.assertLen(hist, 3) - self.assertEqual([ci.crash_message for ci in hist], - ["Crash #3", "Crash #2", "Crash #1"]) + self.assertEqual( + [ci.crash_message for ci in hist], ["Crash #3", "Crash #2", "Crash #1"] + ) self.assertGreater(hist[0].timestamp, hist[1].timestamp) self.assertGreater(hist[1].timestamp, hist[2].timestamp) @@ -1242,7 +1275,6 @@ def testReadClientFullInfoReturnsCorrectResult(self): kernel="12.3", ) d.WriteClientSnapshot(cl) - d.WriteClientMetadata(client_id, certificate=CERT) si = jobs_pb2.StartupInfo(boot_time=1) d.WriteClientStartupInfo(client_id, si) d.AddClientLabels(client_id, "test_owner", ["test_label"]) @@ -1260,7 +1292,6 @@ def testReadClientFullInfoReturnsCorrectResult(self): full_info.last_snapshot.knowledge_base.fqdn, "test1234.examples.com" ) - self.assertEqual(full_info.metadata.certificate, CERT.SerializeToBytes()) self.assertEqual(full_info.last_startup_info.boot_time, 1) self.assertLen(full_info.labels, 1) @@ -1280,7 +1311,8 @@ def testReadClientFullInfoTimestamps(self): first_seen=first_seen_time, last_clock=last_clock_time, last_ping=last_ping_time, 
- last_foreman=last_foreman_time) + last_foreman=last_foreman_time, + ) pre_time = self.db.Now() @@ -1313,8 +1345,9 @@ def _SetupFullInfoClients(self): self.db.WriteGRRUser("test_owner") for i in range(10): - client_id = db_test_utils.InitializeClient(self.db, - "C.000000005000000%d" % i) + client_id = db_test_utils.InitializeClient( + self.db, "C.000000005000000%d" % i + ) cl = objects_pb2.ClientSnapshot( client_id=client_id, @@ -1324,12 +1357,13 @@ def _SetupFullInfoClients(self): kernel="12.3.%d" % i, ) self.db.WriteClientSnapshot(cl) - self.db.WriteClientMetadata(client_id, certificate=CERT) si = jobs_pb2.StartupInfo(boot_time=i) self.db.WriteClientStartupInfo(client_id, si) self.db.AddClientLabels( - client_id, "test_owner", - ["test_label-a-%d" % i, "test_label-b-%d" % i]) + client_id, + "test_owner", + ["test_label-a-%d" % i, "test_label-b-%d" % i], + ) def _VerifySnapshots(self, snapshots): snapshots = sorted(snapshots, key=lambda s: s.client_id) @@ -1341,16 +1375,21 @@ def _VerifySnapshots(self, snapshots): def _VerifyFullInfos(self, c_infos): c_infos = sorted(c_infos, key=lambda c: c.last_snapshot.client_id) for i, full_info in enumerate(c_infos): - self.assertEqual(full_info.last_snapshot.client_id, - "C.000000005000000%d" % i) - self.assertEqual(full_info.metadata.certificate, CERT) + self.assertEqual( + full_info.last_snapshot.client_id, "C.000000005000000%d" % i + ) self.assertEqual(full_info.last_startup_info.boot_time, i) - self.assertCountEqual(full_info.labels, [ - rdf_objects.ClientLabel( - owner="test_owner", name="test_label-a-%d" % i), - rdf_objects.ClientLabel( - owner="test_owner", name="test_label-b-%d" % i) - ]) + self.assertCountEqual( + full_info.labels, + [ + rdf_objects.ClientLabel( + owner="test_owner", name="test_label-a-%d" % i + ), + rdf_objects.ClientLabel( + owner="test_owner", name="test_label-b-%d" % i + ), + ], + ) def testIterateAllClientsFullInfo(self): self._SetupFullInfoClients() @@ -1370,7 +1409,7 @@ def _SetupLastPingClients(self, now): self.db.WriteClientSnapshot( objects_pb2.ClientSnapshot(client_id=client_id) ) - ping = (time_past if i % 2 == 0 else now) + ping = time_past if i % 2 == 0 else now self.db.WriteClientMetadata(client_id, last_ping=ping) client_ids_to_ping[client_id] = ping @@ -1388,7 +1427,8 @@ def testMultiReadClientsFullInfoFiltersClientsByLastPingTime(self): cid for cid, ping in client_ids_to_ping.items() if ping == base_time ] full_infos = d.MultiReadClientFullInfo( - list(client_ids_to_ping.keys()), min_last_ping=cutoff_time) + list(client_ids_to_ping.keys()), min_last_ping=cutoff_time + ) self.assertCountEqual(expected_client_ids, full_infos) def testMultiReadClientsFullInfoWithEmptyList(self): @@ -1406,7 +1446,8 @@ def testMultiReadClientsFullInfoSkipsMissingClients(self): missing_client_id = "C.00413187fefa1dcf" full_infos = d.MultiReadClientFullInfo( - [present_client_id, missing_client_id]) + [present_client_id, missing_client_id] + ) self.assertEqual(list(full_infos.keys()), [present_client_id]) def testMultiReadClientsFullInfoNoSnapshot(self): @@ -1489,22 +1530,26 @@ def _AddClientKeyedData(self, client_id): # A flow. flow_id = flow.RandomFlowId() self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id)) + flows_pb2.Flow(client_id=client_id, flow_id=flow_id) + ) # A flow request. 
self.db.WriteFlowRequests([ - rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=1) + flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=1 + ) ]) # A flow response. self.db.WriteFlowResponses([ - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=1, response_id=1) + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=1 + ) ]) # A flow processing request. self.db.WriteFlowProcessingRequests( - [rdf_flows.FlowProcessingRequest(client_id=client_id, flow_id=flow_id)]) + [flows_pb2.FlowProcessingRequest(client_id=client_id, flow_id=flow_id)] + ) return flow_id @@ -1585,13 +1630,22 @@ def testDeleteClientWithAssociatedMetadata(self): def testDeleteClientWithPaths(self): client_id = db_test_utils.InitializeClient(self.db) - path_info_0 = rdf_objects.PathInfo.OS(components=("foo", "bar", "baz")) + path_info_0 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) path_info_0.stat_entry.st_size = 42 - path_info_1 = rdf_objects.PathInfo.OS(components=("foo", "bar", "quux")) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "quux"), + ) path_info_1.hash_entry.sha256 = b"quux" - path_info_2 = rdf_objects.PathInfo.OS(components=("foo", "norf", "thud")) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "norf", "thud"), + ) path_info_2.stat_entry.st_size = 1337 path_info_2.hash_entry.sha256 = b"norf" @@ -1605,7 +1659,8 @@ def testDeleteClientWithPaths(self): def testFleetspeakValidationInfoIsInitiallyUnset(self): client_id = "C.fc413187fefa1dcf" self.db.WriteClientMetadata( - client_id, first_seen=rdfvalue.RDFDatetime(100000000)) + client_id, first_seen=rdfvalue.RDFDatetime(100000000) + ) res = self.db.MultiReadClientMetadata([client_id]) self.assertLen(res, 1) @@ -1616,10 +1671,8 @@ def testWritesFleetspeakValidationInfo(self): client_id = "C.fc413187fefa1dcf" self.db.WriteClientMetadata( - client_id, fleetspeak_validation_info={ - "foo": "bar", - "12": "34" - }) + client_id, fleetspeak_validation_info={"foo": "bar", "12": "34"} + ) res = self.db.MultiReadClientMetadata([client_id]) self.assertLen(res, 1) @@ -1633,15 +1686,11 @@ def testOverwritesFleetspeakValidationInfo(self): client_id = "C.fc413187fefa1dcf" self.db.WriteClientMetadata( - client_id, fleetspeak_validation_info={ - "foo": "bar", - "12": "34" - }) + client_id, fleetspeak_validation_info={"foo": "bar", "12": "34"} + ) self.db.WriteClientMetadata( - client_id, fleetspeak_validation_info={ - "foo": "bar", - "new": "1234" - }) + client_id, fleetspeak_validation_info={"foo": "bar", "new": "1234"} + ) res = self.db.MultiReadClientMetadata([client_id]) self.assertLen(res, 1) @@ -1655,7 +1704,8 @@ def testRemovesFleetspeakValidationInfoWhenValidationInfoIsEmpty(self): client_id = "C.fc413187fefa1dcf" self.db.WriteClientMetadata( - client_id, fleetspeak_validation_info={"foo": "bar"}) + client_id, fleetspeak_validation_info={"foo": "bar"} + ) self.db.WriteClientMetadata(client_id, fleetspeak_validation_info={}) res = self.db.MultiReadClientMetadata([client_id]) @@ -1667,7 +1717,8 @@ def testKeepsFleetspeakValidationInfoWhenValidationInfoIsNotPresent(self): client_id = "C.fc413187fefa1dcf" self.db.WriteClientMetadata( - client_id, fleetspeak_validation_info={"foo": "bar"}) + client_id, fleetspeak_validation_info={"foo": "bar"} + ) 
self.db.WriteClientMetadata(client_id) res = self.db.MultiReadClientMetadata([client_id]) diff --git a/grr/server/grr_response_server/databases/db_cronjob_test.py b/grr/server/grr_response_server/databases/db_cronjob_test.py index 923885e70f..7203a67c68 100644 --- a/grr/server/grr_response_server/databases/db_cronjob_test.py +++ b/grr/server/grr_response_server/databases/db_cronjob_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Mixin tests for storing cronjob objects in the relational db.""" - from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import mig_protodict from grr_response_core.lib.util import random @@ -263,8 +262,9 @@ def testCronJobReturning(self): ) self.assertTrue(leased) - with test_lib.FakeTime(current_time + - rdfvalue.Duration.From(1, rdfvalue.MINUTES)): + with test_lib.FakeTime( + current_time + rdfvalue.Duration.From(1, rdfvalue.MINUTES) + ): self.db.ReturnLeasedCronJobs([leased[0]]) returned_job = self.db.ReadCronJob(leased[0].cron_job_id) @@ -279,13 +279,15 @@ def testCronJobReturningMultiple(self): current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000) with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs( - lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES)) + lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES) + ) self.assertLen(leased, 3) current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10001) with test_lib.FakeTime(current_time): unleased_jobs = self.db.LeaseCronJobs( - lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES)) + lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES) + ) self.assertEmpty(unleased_jobs) self.db.ReturnLeasedCronJobs(leased) @@ -375,7 +377,8 @@ def testCronJobRunExpiry(self): self.db.WriteCronJob(job) fake_time = rdfvalue.RDFDatetime.Now() - rdfvalue.Duration.From( - 7, rdfvalue.DAYS) + 7, rdfvalue.DAYS + ) with test_lib.FakeTime(fake_time): run = flows_pb2.CronJobRun( cron_job_id=job_id, @@ -385,7 +388,8 @@ def testCronJobRunExpiry(self): self.db.WriteCronJobRun(run) fake_time_one_day_later = fake_time + rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) with test_lib.FakeTime(fake_time_one_day_later): run = flows_pb2.CronJobRun( cron_job_id=job_id, @@ -395,7 +399,8 @@ def testCronJobRunExpiry(self): self.db.WriteCronJobRun(run) fake_time_two_days_later = fake_time + rdfvalue.Duration.From( - 2, rdfvalue.DAYS) + 2, rdfvalue.DAYS + ) with test_lib.FakeTime(fake_time_two_days_later): run = flows_pb2.CronJobRun( cron_job_id=job_id, diff --git a/grr/server/grr_response_server/databases/db_events_test.py b/grr/server/grr_response_server/databases/db_events_test.py index 635fcaf4a2..02732e2205 100644 --- a/grr/server/grr_response_server/databases/db_events_test.py +++ b/grr/server/grr_response_server/databases/db_events_test.py @@ -2,10 +2,7 @@ from typing import Optional from grr_response_core.lib import rdfvalue -from grr_response_server.rdfvalues import objects as rdf_objects - - -APIAuditEntry = rdf_objects.APIAuditEntry +from grr_response_proto import objects_pb2 def _Date(date: str) -> rdfvalue.RDFDatetime: @@ -19,27 +16,30 @@ def _MakeEntry( http_request_path: str = "/test", router_method_name: str = "TestHandler", username: str = "user", - response_code: APIAuditEntry.Code = APIAuditEntry.Code.OK, + response_code: objects_pb2.APIAuditEntry.Code = objects_pb2.APIAuditEntry.Code.OK, timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> APIAuditEntry: + ) -> objects_pb2.APIAuditEntry: self.db.WriteGRRUser(username) - return APIAuditEntry( + res = 
objects_pb2.APIAuditEntry( http_request_path=http_request_path, router_method_name=router_method_name, username=username, response_code=response_code, - timestamp=timestamp, ) + if timestamp is not None: + res.timestamp = timestamp.AsMicrosecondsSinceEpoch() + return res - def _WriteEntry(self, **kwargs) -> rdf_objects.APIAuditEntry: + def _WriteEntry(self, **kwargs) -> objects_pb2.APIAuditEntry: entry = self._MakeEntry(**kwargs) self.db.WriteAPIAuditEntry(entry) return entry def testWriteDoesNotMutate(self): entry = self._MakeEntry() - copy = entry.Copy() + copy = objects_pb2.APIAuditEntry() + copy.CopyFrom(entry) self.db.WriteAPIAuditEntry(entry) self.assertEqual(entry, copy) @@ -50,7 +50,7 @@ def testWriteAuditEntry(self): self.assertLen(entries, 1) # We should not compare timestamps. - entries[0].timestamp = None + entries[0].ClearField("timestamp") self.assertCountEqual(entries, [entry]) def testWriteEntriesWithMicrosecondDifference(self): @@ -68,23 +68,25 @@ def testWriteEntriesWithMicrosecondDifference(self): def testReadEntries(self): entry1 = self._WriteEntry() - entry2 = self._WriteEntry(response_code=APIAuditEntry.Code.ERROR) + entry2 = self._WriteEntry( + response_code=objects_pb2.APIAuditEntry.Code.ERROR + ) entries = self.db.ReadAPIAuditEntries() self.assertLen(entries, 2) # We should not compare timestamps. - entries[0].timestamp = None - entries[1].timestamp = None + entries[0].ClearField("timestamp") + entries[1].ClearField("timestamp") self.assertCountEqual(entries, [entry1, entry2]) def testReadEntriesOrder(self): status_codes = [ - APIAuditEntry.Code.OK, - APIAuditEntry.Code.ERROR, - APIAuditEntry.Code.FORBIDDEN, - APIAuditEntry.Code.NOT_FOUND, - APIAuditEntry.Code.NOT_IMPLEMENTED, + objects_pb2.APIAuditEntry.Code.OK, + objects_pb2.APIAuditEntry.Code.ERROR, + objects_pb2.APIAuditEntry.Code.FORBIDDEN, + objects_pb2.APIAuditEntry.Code.NOT_FOUND, + objects_pb2.APIAuditEntry.Code.NOT_IMPLEMENTED, ] for status_code in status_codes: @@ -104,7 +106,7 @@ def testReadEntriesFilterUsername(self): self.assertLen(entries, 1) # We should not compare timestamps. 
- entries[0].timestamp = None + entries[0].ClearField("timestamp") self.assertCountEqual(entries, [entry]) def testReadEntriesFilterRouterMethodName(self): @@ -117,13 +119,13 @@ def testReadEntriesFilterRouterMethodName(self): self.assertCountEqual(router_method_names, ["foo", "bar"]) def testReadEntriesFilterTimestamp(self): - self._WriteEntry(response_code=APIAuditEntry.Code.OK) + self._WriteEntry(response_code=objects_pb2.APIAuditEntry.Code.OK) ok_timestamp = self.db.Now() - self._WriteEntry(response_code=APIAuditEntry.Code.ERROR) + self._WriteEntry(response_code=objects_pb2.APIAuditEntry.Code.ERROR) error_timestamp = self.db.Now() - self._WriteEntry(response_code=APIAuditEntry.Code.NOT_FOUND) + self._WriteEntry(response_code=objects_pb2.APIAuditEntry.Code.NOT_FOUND) not_found_timestamp = self.db.Now() entries = self.db.ReadAPIAuditEntries(min_timestamp=not_found_timestamp) @@ -133,17 +135,30 @@ def testReadEntriesFilterTimestamp(self): self.assertLen(entries, 3) entries = self.db.ReadAPIAuditEntries(min_timestamp=ok_timestamp) - self.assertEqual([e.response_code for e in entries], - [APIAuditEntry.Code.ERROR, APIAuditEntry.Code.NOT_FOUND]) + self.assertEqual( + [e.response_code for e in entries], + [ + objects_pb2.APIAuditEntry.Code.ERROR, + objects_pb2.APIAuditEntry.Code.NOT_FOUND, + ], + ) entries = self.db.ReadAPIAuditEntries(max_timestamp=error_timestamp) - self.assertEqual([e.response_code for e in entries], - [APIAuditEntry.Code.OK, APIAuditEntry.Code.ERROR]) + self.assertEqual( + [e.response_code for e in entries], + [ + objects_pb2.APIAuditEntry.Code.OK, + objects_pb2.APIAuditEntry.Code.ERROR, + ], + ) entries = self.db.ReadAPIAuditEntries( - min_timestamp=ok_timestamp, max_timestamp=error_timestamp) - self.assertEqual([e.response_code for e in entries], - [APIAuditEntry.Code.ERROR]) + min_timestamp=ok_timestamp, max_timestamp=error_timestamp + ) + self.assertEqual( + [e.response_code for e in entries], + [objects_pb2.APIAuditEntry.Code.ERROR], + ) def testCountEntries(self): day = _Date("2019-02-02") @@ -151,17 +166,17 @@ def testCountEntries(self): self._WriteEntry(username="user1", timestamp=_Date("2019-02-02 00:00")) self._WriteEntry(username="user2", timestamp=_Date("2019-02-02 00:00")) - self.assertEqual({ - ("user1", day): 1, - ("user2", day): 1 - }, self.db.CountAPIAuditEntriesByUserAndDay()) + self.assertEqual( + {("user1", day): 1, ("user2", day): 1}, + self.db.CountAPIAuditEntriesByUserAndDay(), + ) self._WriteEntry(username="user1", timestamp=_Date("2019-02-02 23:59:59")) - self.assertEqual({ - ("user1", day): 2, - ("user2", day): 1 - }, self.db.CountAPIAuditEntriesByUserAndDay()) + self.assertEqual( + {("user1", day): 2, ("user2", day): 1}, + self.db.CountAPIAuditEntriesByUserAndDay(), + ) def testCountEntriesFilteredByTimestamp(self): self._WriteEntry(username="user", timestamp=_Date("2019-02-01")) @@ -173,13 +188,16 @@ def testCountEntriesFilteredByTimestamp(self): counts = self.db.CountAPIAuditEntriesByUserAndDay( min_timestamp=_Date("2019-02-02"), - max_timestamp=_Date("2019-02-03 23:59:59")) + max_timestamp=_Date("2019-02-03 23:59:59"), + ) self.assertEqual( { ("user1", _Date("2019-02-02")): 2, ("user2", _Date("2019-02-02")): 1, ("user1", _Date("2019-02-03")): 1, - }, counts) + }, + counts, + ) def testDeleteUsersRetainsApiAuditEntries(self): self._WriteEntry(username="foo") @@ -192,9 +210,9 @@ def testDeleteUsersRetainsApiAuditEntries(self): def testWriteAndReadWithCommitTimestamp(self): entry = self._MakeEntry(username="foo") - before = self.db.Now() + 
before = self.db.Now().AsMicrosecondsSinceEpoch() self.db.WriteAPIAuditEntry(entry) - after = self.db.Now() + after = self.db.Now().AsMicrosecondsSinceEpoch() entries = self.db.ReadAPIAuditEntries(username="foo") self.assertLen(entries, 1) diff --git a/grr/server/grr_response_server/databases/db_flows_test.py b/grr/server/grr_response_server/databases/db_flows_test.py index 58ca02e0ed..13d7150cd9 100644 --- a/grr/server/grr_response_server/databases/db_flows_test.py +++ b/grr/server/grr_response_server/databases/db_flows_test.py @@ -5,7 +5,7 @@ import random import threading import time -from typing import Optional, Sequence +from typing import List, Optional, Sequence, Tuple, Union from unittest import mock from google.protobuf import any_pb2 @@ -13,8 +13,8 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_proto import flows_pb2 +from grr_response_proto import hunts_pb2 from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 from grr_response_server import flow @@ -22,7 +22,7 @@ from grr_response_server.databases import db_test_utils from grr_response_server.flows import file from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects -from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr.test_lib import test_lib @@ -33,86 +33,102 @@ class DatabaseTestFlowMixin(object): """ def testFlowWritingUnknownClient(self): - flow_id = u"1234ABCD" - client_id = u"C.1234567890123456" + flow_id = "1234ABCD" + client_id = "C.1234567890123456" - rdf_flow = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) + flow_obj = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) with self.assertRaises(db.UnknownClientError): - self.db.WriteFlowObject(rdf_flow) + self.db.WriteFlowObject(flow_obj) def testFlowWriting(self): - flow_id = u"1234ABCD" - client_id = u"C.1234567890123456" + flow_id = "1234ABCD" + client_id = "C.1234567890123456" self.db.WriteClientMetadata(client_id) - rdf_flow = rdf_flow_objects.Flow( + flow_obj = flows_pb2.Flow( client_id=client_id, flow_id=flow_id, long_flow_id=f"{client_id}/{flow_id}", - next_request_to_process=4) - self.db.WriteFlowObject(rdf_flow) + next_request_to_process=4, + creator="foo", + flow_class_name="bar", + ) + flow_obj.cpu_time_used.user_cpu_time = 123 + flow_obj.cpu_time_used.system_cpu_time = 456 + flow_obj.network_bytes_sent = 789 + self.db.WriteFlowObject(flow_obj) read_flow = self.db.ReadFlowObject(client_id, flow_id) # Last update and creation times have changed, everything else should be # equal. - read_flow.create_time = None - read_flow.last_update_time = None + read_flow.ClearField("create_time") + read_flow.ClearField("last_update_time") - self.assertEqual(read_flow, rdf_flow) + self.assertEqual(read_flow, flow_obj) # Invalid flow id or client id raises. 
with self.assertRaises(db.UnknownFlowError): - self.db.ReadFlowObject(client_id, u"1234AAAA") + self.db.ReadFlowObject(client_id, "1234AAAA") with self.assertRaises(db.UnknownFlowError): - self.db.ReadFlowObject(u"C.1234567890000000", flow_id) + self.db.ReadFlowObject("C.1234567890000000", flow_id) def testFlowOverwrite(self): - flow_id = u"1234ABCD" - client_id = u"C.1234567890123456" + flow_id = "1234ABCD" + client_id = "C.1234567890123456" self.db.WriteClientMetadata(client_id) - rdf_flow = rdf_flow_objects.Flow( - client_id=client_id, flow_id=flow_id, next_request_to_process=4) - self.db.WriteFlowObject(rdf_flow) + flow_obj = flows_pb2.Flow( + client_id=client_id, + flow_id=flow_id, + long_flow_id=f"{client_id}/{flow_id}", + next_request_to_process=4, + creator="foo", + flow_class_name="bar", + ) + flow_obj.cpu_time_used.user_cpu_time = 123 + flow_obj.cpu_time_used.system_cpu_time = 456 + flow_obj.network_bytes_sent = 789 + self.db.WriteFlowObject(flow_obj) read_flow = self.db.ReadFlowObject(client_id, flow_id) # Last update and creation times have changed, everything else should be # equal. - read_flow.create_time = None - read_flow.last_update_time = None + read_flow.ClearField("create_time") + read_flow.ClearField("last_update_time") - self.assertEqual(read_flow, rdf_flow) + self.assertEqual(read_flow, flow_obj) # Now change the flow object. - rdf_flow.next_request_to_process = 5 + flow_obj.next_request_to_process = 5 - self.db.WriteFlowObject(rdf_flow) + self.db.WriteFlowObject(flow_obj) read_flow_after_update = self.db.ReadFlowObject(client_id, flow_id) self.assertEqual(read_flow_after_update.next_request_to_process, 5) def testFlowOverwriteFailsWithAllowUpdateFalse(self): - flow_id = u"1234ABCD" - client_id = u"C.1234567890123456" + flow_id = "1234ABCD" + client_id = "C.1234567890123456" self.db.WriteClientMetadata(client_id) - rdf_flow = rdf_flow_objects.Flow( - client_id=client_id, flow_id=flow_id, next_request_to_process=4) - self.db.WriteFlowObject(rdf_flow, allow_update=False) + flow_obj = flows_pb2.Flow( + client_id=client_id, flow_id=flow_id, next_request_to_process=4 + ) + self.db.WriteFlowObject(flow_obj, allow_update=False) # Now change the flow object. 
- rdf_flow.next_request_to_process = 5 + flow_obj.next_request_to_process = 5 with self.assertRaises(db.FlowExistsError) as context: - self.db.WriteFlowObject(rdf_flow, allow_update=False) + self.db.WriteFlowObject(flow_obj, allow_update=False) self.assertEqual(context.exception.client_id, client_id) self.assertEqual(context.exception.flow_id, flow_id) @@ -126,12 +142,12 @@ def testFlowTimestamp(self): self.db.WriteClientMetadata(client_id) - before_timestamp = self.db.Now() + before_timestamp = self.db.Now().AsMicrosecondsSinceEpoch() - flow_obj = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) + flow_obj = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) self.db.WriteFlowObject(flow_obj) - after_timestamp = self.db.Now() + after_timestamp = self.db.Now().AsMicrosecondsSinceEpoch() flow_obj = self.db.ReadFlowObject(client_id=client_id, flow_id=flow_id) self.assertBetween(flow_obj.create_time, before_timestamp, after_timestamp) @@ -142,13 +158,12 @@ def testFlowTimestampWithMissingCreationTime(self): self.db.WriteClientMetadata(client_id) - before_timestamp = self.db.Now() + before_timestamp = self.db.Now().AsMicrosecondsSinceEpoch() - flow_obj = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) - flow_obj.create_time = None + flow_obj = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) self.db.WriteFlowObject(flow_obj) - after_timestamp = self.db.Now() + after_timestamp = self.db.Now().AsMicrosecondsSinceEpoch() flow_obj = self.db.ReadFlowObject(client_id=client_id, flow_id=flow_id) self.assertBetween(flow_obj.create_time, before_timestamp, after_timestamp) @@ -159,11 +174,11 @@ def testFlowNameWithMissingNameInProtobuf(self): self.db.WriteClientMetadata(client_id) - flow_obj = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) + flow_obj = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) flow_obj.flow_class_name = "Quux" self.db.WriteFlowObject(flow_obj) - flow_obj.flow_class_name = None + flow_obj.ClearField("flow_class_name") self.db.UpdateFlow(client_id=client_id, flow_id=flow_id, flow_obj=flow_obj) flow_obj = self.db.ReadFlowObject(client_id=client_id, flow_id=flow_id) @@ -175,7 +190,7 @@ def testFlowKeyMetadataUnchangable(self): self.db.WriteClientMetadata(client_id) - flow_obj = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) + flow_obj = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) flow_obj.long_flow_id = f"{client_id}/{flow_id}" self.db.WriteFlowObject(flow_obj) @@ -193,7 +208,7 @@ def testFlowParentMetadataUnchangable(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = "0F00B430" - flow_obj = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) + flow_obj = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) parent_flow_id_1 = db_test_utils.InitializeFlow(self.db, client_id) parent_hunt_id_1 = db_test_utils.InitializeHunt(self.db) @@ -219,7 +234,7 @@ def testFlowNameUnchangable(self): self.db.WriteClientMetadata(client_id) - flow_obj = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) + flow_obj = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) flow_obj.flow_class_name = "Quux" self.db.WriteFlowObject(flow_obj) @@ -235,7 +250,7 @@ def testFlowCreatorUnchangable(self): self.db.WriteClientMetadata(client_id) - flow_obj = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) + flow_obj = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) flow_obj.creator = "norf" self.db.WriteFlowObject(flow_obj) @@ -251,11 +266,11 @@ def testFlowCreatorUnsetInProtobuf(self): 
self.db.WriteClientMetadata(client_id) - flow_obj = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) + flow_obj = flows_pb2.Flow(client_id=client_id, flow_id=flow_id) flow_obj.creator = "norf" self.db.WriteFlowObject(flow_obj) - flow_obj.creator = None + flow_obj.ClearField("creator") self.db.UpdateFlow(client_id=client_id, flow_id=flow_id, flow_obj=flow_obj) flow_obj = self.db.ReadFlowObject(client_id=client_id, flow_id=flow_id) @@ -270,30 +285,37 @@ def testReadAllFlowObjects(self): # Write a flow and a child flow for client 1. flow1 = rdf_flow_objects.Flow(client_id=client_id_1, flow_id="000A0001") - self.db.WriteFlowObject(flow1) + proto_flow = mig_flow_objects.ToProtoFlow(flow1) + self.db.WriteFlowObject(proto_flow) flow2 = rdf_flow_objects.Flow( - client_id=client_id_1, flow_id="000A0002", parent_flow_id="000A0001") - self.db.WriteFlowObject(flow2) + client_id=client_id_1, flow_id="000A0002", parent_flow_id="000A0001" + ) + proto_flow = mig_flow_objects.ToProtoFlow(flow2) + self.db.WriteFlowObject(proto_flow) # Same flow id for client 2. flow3 = rdf_flow_objects.Flow(client_id=client_id_2, flow_id="000A0001") - self.db.WriteFlowObject(flow3) + proto_flow = mig_flow_objects.ToProtoFlow(flow3) + self.db.WriteFlowObject(proto_flow) flows = self.db.ReadAllFlowObjects() - self.assertCountEqual([f.flow_id for f in flows], - ["000A0001", "000A0002", "000A0001"]) + self.assertCountEqual( + [f.flow_id for f in flows], ["000A0001", "000A0002", "000A0001"] + ) def testReadAllFlowObjectsWithMinCreateTime(self): client_id_1 = "C.1111111111111111" self.db.WriteClientMetadata(client_id_1) self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_1, flow_id="0000001A")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000001A") + ) timestamp = self.db.Now() self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_1, flow_id="0000001B")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000001B") + ) flows = self.db.ReadAllFlowObjects(min_create_time=timestamp) self.assertEqual([f.flow_id for f in flows], ["0000001B"]) @@ -303,12 +325,14 @@ def testReadAllFlowObjectsWithMaxCreateTime(self): self.db.WriteClientMetadata(client_id_1) self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_1, flow_id="0000001A")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000001A") + ) timestamp = self.db.Now() self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_1, flow_id="0000001B")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000001B") + ) flows = self.db.ReadAllFlowObjects(max_create_time=timestamp) self.assertEqual([f.flow_id for f in flows], ["0000001A"]) @@ -320,9 +344,11 @@ def testReadAllFlowObjectsWithClientID(self): self.db.WriteClientMetadata(client_id_2) self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_1, flow_id="0000001A")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000001A") + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_2, flow_id="0000001B")) + flows_pb2.Flow(client_id=client_id_2, flow_id="0000001B") + ) flows = self.db.ReadAllFlowObjects(client_id=client_id_1) self.assertEqual([f.flow_id for f in flows], ["0000001A"]) @@ -333,29 +359,32 @@ def testReadAllFlowObjectsWitParentFlowID(self): parent_flow = rdf_flow_objects.Flow() parent_flow.client_id = client_id parent_flow.flow_id = "AAAAAAAA" - self.db.WriteFlowObject(parent_flow) + proto_flow = mig_flow_objects.ToProtoFlow(parent_flow) + self.db.WriteFlowObject(proto_flow) child_flow_1 = rdf_flow_objects.Flow() 
child_flow_1.client_id = client_id child_flow_1.flow_id = "CCCC1111" child_flow_1.parent_flow_id = "AAAAAAAA" - self.db.WriteFlowObject(child_flow_1) + proto_flow = mig_flow_objects.ToProtoFlow(child_flow_1) + self.db.WriteFlowObject(proto_flow) child_flow_2 = rdf_flow_objects.Flow() child_flow_2.client_id = client_id child_flow_2.flow_id = "CCCC2222" child_flow_2.parent_flow_id = "AAAAAAAA" - self.db.WriteFlowObject(child_flow_2) + proto_flow = mig_flow_objects.ToProtoFlow(child_flow_2) + self.db.WriteFlowObject(proto_flow) not_child_flow = rdf_flow_objects.Flow() not_child_flow.client_id = client_id not_child_flow.flow_id = "FFFFFFFF" - self.db.WriteFlowObject(not_child_flow) + proto_flow = mig_flow_objects.ToProtoFlow(not_child_flow) + self.db.WriteFlowObject(proto_flow) result = self.db.ReadAllFlowObjects( - client_id=client_id, - parent_flow_id="AAAAAAAA", - include_child_flows=True) + client_id=client_id, parent_flow_id="AAAAAAAA", include_child_flows=True + ) result_flow_ids = set(_.flow_id for _ in result) self.assertIn("CCCC1111", result_flow_ids) @@ -369,25 +398,28 @@ def testReadAllFlowObjectsWithParentFlowIDWithoutChildren(self): parent_flow = rdf_flow_objects.Flow() parent_flow.client_id = client_id parent_flow.flow_id = "AAAAAAAA" - self.db.WriteFlowObject(parent_flow) + proto_flow = mig_flow_objects.ToProtoFlow(parent_flow) + self.db.WriteFlowObject(proto_flow) with self.assertRaises(ValueError): self.db.ReadAllFlowObjects( client_id=client_id, parent_flow_id="AAAAAAAAA", - include_child_flows=False) + include_child_flows=False, + ) def testReadAllFlowObjectsWithoutChildren(self): client_id_1 = "C.1111111111111111" self.db.WriteClientMetadata(client_id_1) self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_1, flow_id="0000001A")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000001A") + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow( - client_id=client_id_1, - flow_id="0000001B", - parent_flow_id="0000001A")) + flows_pb2.Flow( + client_id=client_id_1, flow_id="0000001B", parent_flow_id="0000001A" + ) + ) flows = self.db.ReadAllFlowObjects(include_child_flows=False) self.assertEqual([f.flow_id for f in flows], ["0000001A"]) @@ -397,14 +429,14 @@ def testReadAllFlowObjectsWithNotCreatedBy(self): self.db.WriteClientMetadata(client_id_1) self.db.WriteFlowObject( - rdf_flow_objects.Flow( - client_id=client_id_1, flow_id="000A0001", creator="foo")) + flows_pb2.Flow(client_id=client_id_1, flow_id="000A0001", creator="foo") + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow( - client_id=client_id_1, flow_id="000A0002", creator="bar")) + flows_pb2.Flow(client_id=client_id_1, flow_id="000A0002", creator="bar") + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow( - client_id=client_id_1, flow_id="000A0003", creator="baz")) + flows_pb2.Flow(client_id=client_id_1, flow_id="000A0003", creator="baz") + ) flows = self.db.ReadAllFlowObjects(not_created_by=frozenset(["baz", "foo"])) self.assertCountEqual([f.flow_id for f in flows], ["000A0002"]) @@ -418,49 +450,53 @@ def testReadAllFlowObjectsWithAllConditions(self): min_timestamp = self.db.Now() self.db.WriteFlowObject( - rdf_flow_objects.Flow( - client_id=client_id_1, flow_id="0000000A", creator="bar")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000000A", creator="bar") + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow( - client_id=client_id_1, flow_id="0000000F", creator="foo")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000000F", creator="foo") + ) max_timestamp = self.db.Now() 
self.db.WriteFlowObject( - rdf_flow_objects.Flow( - client_id=client_id_1, - flow_id="0000000B", - parent_flow_id="0000000A")) + flows_pb2.Flow( + client_id=client_id_1, flow_id="0000000B", parent_flow_id="0000000A" + ) + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_1, flow_id="0000000C")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000000C") + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_1, flow_id="0000000D")) + flows_pb2.Flow(client_id=client_id_1, flow_id="0000000D") + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow(client_id=client_id_2, flow_id="0000000E")) + flows_pb2.Flow(client_id=client_id_2, flow_id="0000000E") + ) flows = self.db.ReadAllFlowObjects( client_id=client_id_1, min_create_time=min_timestamp, max_create_time=max_timestamp, include_child_flows=False, - not_created_by=frozenset(["baz", "foo"])) + not_created_by=frozenset(["baz", "foo"]), + ) self.assertEqual([f.flow_id for f in flows], ["0000000A"]) def testUpdateUnknownFlow(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - crash_info = rdf_client.ClientCrash(crash_message="oh no") + crash_info = jobs_pb2.ClientCrash(crash_message="oh no") with self.assertRaises(db.UnknownFlowError): self.db.UpdateFlow( - u"C.1234567890AAAAAA", flow_id, client_crash_info=crash_info) + "C.1234567890AAAAAA", flow_id, client_crash_info=crash_info + ) def testFlowUpdateChangesAllFields(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) flow_obj = self.db.ReadFlowObject(client_id, flow_id) - flow_obj.cpu_time_used.user_cpu_time = 0.5 flow_obj.cpu_time_used.system_cpu_time = 1.5 flow_obj.num_replies_sent = 10 @@ -471,9 +507,8 @@ def testFlowUpdateChangesAllFields(self): read_flow = self.db.ReadFlowObject(client_id, flow_id) # Last update times will differ. - read_flow.last_update_time = None - flow_obj.last_update_time = None - + read_flow.ClearField("last_update_time") + flow_obj.ClearField("last_update_time") self.assertEqual(read_flow, flow_obj) def testFlowStateUpdate(self): @@ -482,20 +517,19 @@ def testFlowStateUpdate(self): # Check that just updating flow_state works fine. self.db.UpdateFlow( - client_id, flow_id, flow_state=rdf_flow_objects.Flow.FlowState.CRASHED) + client_id, flow_id, flow_state=flows_pb2.Flow.FlowState.CRASHED + ) read_flow = self.db.ReadFlowObject(client_id, flow_id) - self.assertEqual(read_flow.flow_state, - rdf_flow_objects.Flow.FlowState.CRASHED) + self.assertEqual(read_flow.flow_state, flows_pb2.Flow.FlowState.CRASHED) # TODO(user): remove an option to update the flow by updating flow_obj. # It makes the DB API unnecessary complicated. # Check that changing flow_state through flow_obj works too. 
- read_flow.flow_state = rdf_flow_objects.Flow.FlowState.RUNNING + read_flow.flow_state = flows_pb2.Flow.FlowState.RUNNING self.db.UpdateFlow(client_id, flow_id, flow_obj=read_flow) read_flow_2 = self.db.ReadFlowObject(client_id, flow_id) - self.assertEqual(read_flow_2.flow_state, - rdf_flow_objects.Flow.FlowState.RUNNING) + self.assertEqual(read_flow_2.flow_state, flows_pb2.Flow.FlowState.RUNNING) def testUpdatingFlowObjAndFlowStateInSingleUpdateRaises(self): client_id = db_test_utils.InitializeClient(self.db) @@ -507,13 +541,14 @@ def testUpdatingFlowObjAndFlowStateInSingleUpdateRaises(self): client_id, flow_id, flow_obj=read_flow, - flow_state=rdf_flow_objects.Flow.FlowState.CRASHED) + flow_state=flows_pb2.Flow.FlowState.CRASHED, + ) def testCrashInfoUpdate(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - crash_info = rdf_client.ClientCrash(crash_message="oh no") + crash_info = jobs_pb2.ClientCrash(crash_message="oh no") self.db.UpdateFlow(client_id, flow_id, client_crash_info=crash_info) read_flow = self.db.ReadFlowObject(client_id, flow_id) self.assertEqual(read_flow.client_crash_info, crash_info) @@ -529,7 +564,8 @@ def testProcessingInformationUpdate(self): flow_id, processing_on="Worker1", processing_since=now, - processing_deadline=deadline) + processing_deadline=deadline, + ) read_flow = self.db.ReadFlowObject(client_id, flow_id) self.assertEqual(read_flow.processing_on, "Worker1") self.assertEqual(read_flow.processing_since, now) @@ -538,51 +574,74 @@ def testProcessingInformationUpdate(self): # None can be used to clear some fields. self.db.UpdateFlow(client_id, flow_id, processing_on=None) read_flow = self.db.ReadFlowObject(client_id, flow_id) - self.assertEqual(read_flow.processing_on, "") + self.assertFalse(read_flow.HasField("processing_on")) self.db.UpdateFlow(client_id, flow_id, processing_since=None) read_flow = self.db.ReadFlowObject(client_id, flow_id) - self.assertEqual(read_flow.processing_since, None) + self.assertFalse(read_flow.HasField("processing_since")) self.db.UpdateFlow(client_id, flow_id, processing_deadline=None) read_flow = self.db.ReadFlowObject(client_id, flow_id) - self.assertEqual(read_flow.processing_deadline, None) + self.assertFalse(read_flow.HasField("processing_deadline")) + + def testUpdateFlowUpdateTime(self): + client_id = db_test_utils.InitializeClient(self.db) + flow_id = db_test_utils.InitializeFlow(self.db, client_id) + + pre_update_time = self.db.Now().AsMicrosecondsSinceEpoch() + + self.db.UpdateFlow( + client_id, + flow_id, + flow_state=flows_pb2.Flow.FlowState.FINISHED, + ) + + post_update_time = self.db.Now().AsMicrosecondsSinceEpoch() + + flow_obj = self.db.ReadFlowObject(client_id, flow_id) + self.assertBetween( + flow_obj.last_update_time, pre_update_time, post_update_time + ) def testRequestWriting(self): - client_id_1 = u"C.1234567890123456" - client_id_2 = u"C.1234567890123457" - flow_id_1 = u"1234ABCD" - flow_id_2 = u"ABCD1234" + client_id_1 = "C.1234567890123456" + client_id_2 = "C.1234567890123457" + flow_id_1 = "1234ABCD" + flow_id_2 = "ABCD1234" with self.assertRaises(db.AtLeastOneUnknownFlowError): - self.db.WriteFlowRequests([ - rdf_flow_objects.FlowRequest( - client_id=client_id_1, flow_id=flow_id_1) - ]) + self.db.WriteFlowRequests( + [flows_pb2.FlowRequest(client_id=client_id_1, flow_id=flow_id_1)] + ) for client_id in [client_id_1, client_id_2]: self.db.WriteClientMetadata(client_id) requests = [] for flow_id in [flow_id_1, flow_id_2]: for 
client_id in [client_id_1, client_id_2]: - rdf_flow = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) - self.db.WriteFlowObject(rdf_flow) + self.db.WriteFlowObject( + flows_pb2.Flow(client_id=client_id, flow_id=flow_id) + ) for i in range(1, 4): requests.append( - rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=i)) + flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=i + ) + ) self.db.WriteFlowRequests(requests) for flow_id in [flow_id_1, flow_id_2]: for client_id in [client_id_1, client_id_2]: read = self.db.ReadAllFlowRequestsAndResponses( - client_id=client_id, flow_id=flow_id) + client_id=client_id, flow_id=flow_id + ) self.assertLen(read, 3) - self.assertEqual([req.request_id for (req, _) in read], - list(range(1, 4))) + self.assertEqual( + [req.request_id for (req, _) in read], list(range(1, 4)) + ) for _, responses in read: self.assertEqual(responses, {}) @@ -592,7 +651,8 @@ def _WriteRequestForProcessing(self, client_id, flow_id, request_id): self.addCleanup(self.db.UnregisterFlowProcessingHandler) marked_flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=3) + self.db, client_id, next_request_to_process=3 + ) # We write 2 requests, one after another: # First request is the request provided by the user. Second is @@ -605,16 +665,18 @@ def _WriteRequestForProcessing(self, client_id, flow_id, request_id): # for the user-supplied request. Effectively, callback's invocation for # the marked requests acts as a checkpoint: after it we can make # assertions. - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( flow_id=flow_id, client_id=client_id, request_id=request_id, - needs_processing=True) - marked_request = rdf_flow_objects.FlowRequest( + needs_processing=True, + ) + marked_request = flows_pb2.FlowRequest( flow_id=marked_flow_id, client_id=client_id, request_id=3, - needs_processing=True) + needs_processing=True, + ) self.db.WriteFlowRequests([request, marked_request]) @@ -640,13 +702,15 @@ def _WriteRequestForProcessing(self, client_id, flow_id, request_id): time.sleep(0.1) if rdfvalue.RDFDatetime.Now() - cur_time > rdfvalue.Duration.From( - 10, rdfvalue.SECONDS): + 10, rdfvalue.SECONDS + ): self.fail("Flow request was not processed in time.") def testRequestWritingHighIDDoesntTriggerFlowProcessing(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=3) + self.db, client_id, next_request_to_process=3 + ) requests_triggered = self._WriteRequestForProcessing(client_id, flow_id, 4) # Not the expected request. @@ -655,7 +719,8 @@ def testRequestWritingHighIDDoesntTriggerFlowProcessing(self): def testRequestWritingLowIDDoesntTriggerFlowProcessing(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=3) + self.db, client_id, next_request_to_process=3 + ) requests_triggered = self._WriteRequestForProcessing(client_id, flow_id, 2) # Not the expected request. @@ -664,7 +729,8 @@ def testRequestWritingLowIDDoesntTriggerFlowProcessing(self): def testRequestWritingExpectedIDTriggersFlowProcessing(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=3) + self.db, client_id, next_request_to_process=3 + ) requests_triggered = self._WriteRequestForProcessing(client_id, flow_id, 3) # This one is. 
@@ -673,19 +739,21 @@ def testRequestWritingExpectedIDTriggersFlowProcessing(self): def testFlowRequestsWithStartTimeAreCorrectlyDelayed(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=3) + self.db, client_id, next_request_to_process=3 + ) req_func = mock.Mock() self.db.RegisterFlowProcessingHandler(req_func) self.addCleanup(self.db.UnregisterFlowProcessingHandler) cur_time = rdfvalue.RDFDatetime.Now() - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( flow_id=flow_id, client_id=client_id, request_id=3, - start_time=cur_time + rdfvalue.Duration.From(2, rdfvalue.SECONDS), - needs_processing=True) + start_time=int(cur_time + rdfvalue.Duration.From(2, rdfvalue.SECONDS)), + needs_processing=True, + ) self.db.WriteFlowRequests([request]) self.assertEqual(req_func.call_count, 0) @@ -693,11 +761,14 @@ def testFlowRequestsWithStartTimeAreCorrectlyDelayed(self): while req_func.call_count == 0: time.sleep(0.1) if rdfvalue.RDFDatetime.Now() - cur_time > rdfvalue.Duration.From( - 10, rdfvalue.SECONDS): + 10, rdfvalue.SECONDS + ): self.fail("Flow request was not processed in time.") - self.assertGreaterEqual(rdfvalue.RDFDatetime.Now() - cur_time, - rdfvalue.Duration.From(2, rdfvalue.SECONDS)) + self.assertGreaterEqual( + rdfvalue.RDFDatetime.Now() - cur_time, + rdfvalue.Duration.From(2, rdfvalue.SECONDS), + ) def testDeleteFlowRequests(self): client_id = db_test_utils.InitializeClient(self.db) @@ -707,21 +778,27 @@ def testDeleteFlowRequests(self): responses = [] for request_id in range(1, 4): requests.append( - rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=request_id)) + flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=request_id + ) + ) responses.append( - rdf_flow_objects.FlowResponse( + flows_pb2.FlowResponse( client_id=client_id, flow_id=flow_id, request_id=request_id, - response_id=1)) + response_id=1, + ) + ) self.db.WriteFlowRequests(requests) self.db.WriteFlowResponses(responses) request_list = self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id) - self.assertCountEqual([req.request_id for req, _ in request_list], - [req.request_id for req in requests]) + self.assertCountEqual( + [req.request_id for req, _ in request_list], + [req.request_id for req in requests], + ) random.shuffle(requests) @@ -729,18 +806,21 @@ def testDeleteFlowRequests(self): request = requests.pop() self.db.DeleteFlowRequests([request]) request_list = self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id) - self.assertCountEqual([req.request_id for req, _ in request_list], - [req.request_id for req in requests]) + self.assertCountEqual( + [req.request_id for req, _ in request_list], + [req.request_id for req in requests], + ) def testResponsesForUnknownFlow(self): - client_id = u"C.1234567890123456" - flow_id = u"1234ABCD" + client_id = "C.1234567890123456" + flow_id = "1234ABCD" # This will not raise but also not write anything. 
with test_lib.SuppressLogs(): self.db.WriteFlowResponses([ - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=1, response_id=1) + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=1 + ) ]) read = self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id) self.assertEqual(read, []) @@ -749,18 +829,20 @@ def testResponsesForUnknownRequest(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - request = rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=1) + request = flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=1 + ) self.db.WriteFlowRequests([request]) # Write two responses at a time, one request exists, the other doesn't. with test_lib.SuppressLogs(): self.db.WriteFlowResponses([ - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=1, - response_id=1), - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=2, response_id=1) + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=1 + ), + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=2, response_id=1 + ), ]) # We should have one response in the db. @@ -773,20 +855,20 @@ def testWriteResponsesConcurrent(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - request = rdf_flow_objects.FlowRequest() + request = flows_pb2.FlowRequest() request.client_id = client_id request.flow_id = flow_id request.request_id = 1 request.next_state = "FOO" self.db.WriteFlowRequests([request]) - response = rdf_flow_objects.FlowResponse() + response = flows_pb2.FlowResponse() response.client_id = client_id response.flow_id = flow_id response.request_id = 1 response.response_id = 1 - status = rdf_flow_objects.FlowStatus() + status = flows_pb2.FlowStatus() status.client_id = client_id status.flow_id = flow_id status.request_id = 1 @@ -821,19 +903,21 @@ def testStatusForUnknownRequest(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - request = rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=1) + request = flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=1 + ) self.db.WriteFlowRequests([request]) # Write two status responses at a time, one for the request that exists, one # for a request that doesn't. with test_lib.SuppressLogs(): self.db.WriteFlowResponses([ - rdf_flow_objects.FlowStatus( - client_id=client_id, flow_id=flow_id, request_id=1, - response_id=1), - rdf_flow_objects.FlowStatus( - client_id=client_id, flow_id=flow_id, request_id=2, response_id=1) + flows_pb2.FlowStatus( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=1 + ), + flows_pb2.FlowStatus( + client_id=client_id, flow_id=flow_id, request_id=2, response_id=1 + ), ]) # We should have one response in the db. 
@@ -848,16 +932,21 @@ def testResponseWriting(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=1, - needs_processing=False) + needs_processing=False, + ) + + before_write = self.db.Now() self.db.WriteFlowRequests([request]) + after_write = self.db.Now() responses = [ - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=1, response_id=i) + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=i + ) for i in range(3) ] @@ -867,7 +956,11 @@ def testResponseWriting(self): self.assertLen(all_requests, 1) read_request, read_responses = all_requests[0] - self.assertEqual(read_request, request) + self.assertEqual(read_request.client_id, request.client_id) + self.assertEqual(read_request.flow_id, request.flow_id) + self.assertEqual(read_request.request_id, request.request_id) + self.assertEqual(read_request.needs_processing, request.needs_processing) + self.assertBetween(read_request.timestamp, before_write, after_write) self.assertEqual(list(read_responses), [0, 1, 2]) for response_id, response in read_responses.items(): @@ -877,18 +970,24 @@ def testResponseWritingForDuplicateResponses(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=1, - needs_processing=False) + needs_processing=False, + ) + + before_write = self.db.Now() self.db.WriteFlowRequests([request]) + after_write = self.db.Now() responses = [ - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=1, response_id=0), - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=1, response_id=0) + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=0 + ), + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=0 + ), ] self.db.WriteFlowResponses(responses) @@ -897,7 +996,11 @@ def testResponseWritingForDuplicateResponses(self): self.assertLen(all_requests, 1) read_request, read_responses = all_requests[0] - self.assertEqual(read_request, request) + self.assertEqual(read_request.client_id, request.client_id) + self.assertEqual(read_request.flow_id, request.flow_id) + self.assertEqual(read_request.request_id, request.request_id) + self.assertEqual(read_request.needs_processing, request.needs_processing) + self.assertBetween(read_request.timestamp, before_write, after_write) self.assertEqual(list(read_responses), [0]) for response_id, response in read_responses.items(): @@ -911,19 +1014,23 @@ def testCompletingMultipleRequests(self): responses = [] for i in range(5): requests.append( - rdf_flow_objects.FlowRequest( + flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=i, - needs_processing=False)) + needs_processing=False, + ) + ) responses.append( - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=i, - response_id=1)) + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=i, response_id=1 + ) + ) responses.append( - rdf_flow_objects.FlowStatus( - client_id=client_id, flow_id=flow_id, request_id=i, - response_id=2)) + flows_pb2.FlowStatus( + client_id=client_id, flow_id=flow_id, request_id=i, 
response_id=2 + ) + ) self.db.WriteFlowRequests(requests) @@ -931,7 +1038,8 @@ def testCompletingMultipleRequests(self): self.db.WriteFlowResponses(responses) read = self.db.ReadAllFlowRequestsAndResponses( - client_id=client_id, flow_id=flow_id) + client_id=client_id, flow_id=flow_id + ) self.assertEqual(len(read), 5) for req, _ in read: self.assertTrue(req.needs_processing) @@ -940,25 +1048,31 @@ def testStatusMessagesCanBeWrittenAndRead(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=1, - needs_processing=False) + needs_processing=False, + ) self.db.WriteFlowRequests([request]) responses = [ - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=1, response_id=i) + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=i + ) for i in range(3) ] # Also store an Iterator, why not. responses.append( - rdf_flow_objects.FlowIterator( - client_id=client_id, flow_id=flow_id, request_id=1, response_id=3)) + flows_pb2.FlowIterator( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=3 + ) + ) responses.append( - rdf_flow_objects.FlowStatus( - client_id=client_id, flow_id=flow_id, request_id=1, response_id=4)) + flows_pb2.FlowStatus( + client_id=client_id, flow_id=flow_id, request_id=1, response_id=4 + ) + ) self.db.WriteFlowResponses(responses) all_requests = self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id) @@ -967,9 +1081,9 @@ def testStatusMessagesCanBeWrittenAndRead(self): _, read_responses = all_requests[0] self.assertEqual(list(read_responses), [0, 1, 2, 3, 4]) for i in range(3): - self.assertIsInstance(read_responses[i], rdf_flow_objects.FlowResponse) - self.assertIsInstance(read_responses[3], rdf_flow_objects.FlowIterator) - self.assertIsInstance(read_responses[4], rdf_flow_objects.FlowStatus) + self.assertIsInstance(read_responses[i], flows_pb2.FlowResponse) + self.assertIsInstance(read_responses[3], flows_pb2.FlowIterator) + self.assertIsInstance(read_responses[4], flows_pb2.FlowStatus) def _ReadRequest(self, client_id, flow_id, request_id): all_requests = self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id) @@ -977,41 +1091,53 @@ def _ReadRequest(self, client_id, flow_id, request_id): if request.request_id == request_id: return request - def _Responses(self, client_id, flow_id, request_id, num_responses): + def _Responses( + self, client_id, flow_id, request_id, num_responses + ) -> list[flows_pb2.FlowResponse]: return [ - rdf_flow_objects.FlowResponse( + flows_pb2.FlowResponse( client_id=client_id, flow_id=flow_id, request_id=request_id, - response_id=i) for i in range(1, num_responses + 1) + response_id=i, + ) + for i in range(1, num_responses + 1) ] - def _ResponsesAndStatus(self, client_id, flow_id, request_id, num_responses): + def _ResponsesAndStatus( + self, client_id, flow_id, request_id, num_responses + ) -> list[Union[flows_pb2.FlowResponse, flows_pb2.FlowStatus]]: return self._Responses(client_id, flow_id, request_id, num_responses) + [ - rdf_flow_objects.FlowStatus( + flows_pb2.FlowStatus( client_id=client_id, flow_id=flow_id, request_id=request_id, - response_id=num_responses + 1) + response_id=num_responses + 1, + ) ] - def _WriteRequestAndCompleteResponses(self, client_id, flow_id, request_id, - num_responses): - request = rdf_flow_objects.FlowRequest( - client_id=client_id, 
flow_id=flow_id, request_id=request_id) + def _WriteRequestAndCompleteResponses( + self, client_id, flow_id, request_id, num_responses + ): + request = flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=request_id + ) self.db.WriteFlowRequests([request]) return self._WriteCompleteResponses( client_id=client_id, flow_id=flow_id, request_id=request_id, - num_responses=num_responses) + num_responses=num_responses, + ) - def _WriteCompleteResponses(self, client_id, flow_id, request_id, - num_responses): + def _WriteCompleteResponses( + self, client_id, flow_id, request_id, num_responses + ): # Write responses and a status in random order. - responses = self._ResponsesAndStatus(client_id, flow_id, request_id, - num_responses) + responses = self._ResponsesAndStatus( + client_id, flow_id, request_id, num_responses + ) random.shuffle(responses) for response in responses: @@ -1034,9 +1160,11 @@ def testResponsesForEarlierRequestDontTriggerFlowProcessing(self): # Write a flow that is waiting for request #2. client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) requests_triggered = self._WriteRequestAndCompleteResponses( - client_id, flow_id, request_id=1, num_responses=3) + client_id, flow_id, request_id=1, num_responses=3 + ) # No flow processing request generated for request 1 (we are waiting # for #2). @@ -1045,84 +1173,96 @@ def testResponsesForEarlierRequestDontTriggerFlowProcessing(self): def testResponsesForEarlierIncrementalRequestDontTriggerFlowProcessing(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) request_id = 1 - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=request_id, next_state="Next", - callback_state="Callback") + callback_state="Callback", + ) self.db.WriteFlowRequests([request]) responses = self._Responses(client_id, flow_id, request_id, 1) self.db.WriteFlowResponses(responses) requests_to_process = self.db.ReadFlowProcessingRequests() - self.assertLen(requests_to_process, 0) + self.assertEmpty(requests_to_process) def testResponsesForLaterRequestDontTriggerFlowProcessing(self): # Write a flow that is waiting for request #2. client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) requests_triggered = self._WriteRequestAndCompleteResponses( - client_id, flow_id, request_id=3, num_responses=7) + client_id, flow_id, request_id=3, num_responses=7 + ) # No flow processing request generated for request 3 (we are waiting # for #2). 
self.assertEqual(requests_triggered, 0) def testResponsesForLaterIncrementalRequestDoNotTriggerIncrementalProcessing( - self): + self, + ): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) request_id = 3 - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=request_id, next_state="Next", - callback_state="Callback") + callback_state="Callback", + ) self.db.WriteFlowRequests([request]) responses = self._Responses(client_id, flow_id, request_id, 1) self.db.WriteFlowResponses(responses) requests_to_process = self.db.ReadFlowProcessingRequests() - self.assertLen(requests_to_process, 0) + self.assertEmpty(requests_to_process) def testResponsesForExpectedRequestTriggerFlowProcessing(self): # Write a flow that is waiting for request #2. client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) requests_triggered = self._WriteRequestAndCompleteResponses( - client_id, flow_id, request_id=2, num_responses=5) + client_id, flow_id, request_id=2, num_responses=5 + ) # This one generates a request. self.assertEqual(requests_triggered, 1) def testResponsesForExpectedIncrementalRequestTriggerIncrementalProcessing( - self): + self, + ): request_id = 2 client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=request_id) + self.db, client_id, next_request_to_process=request_id + ) - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=request_id, next_state="Next", - callback_state="Callback") + callback_state="Callback", + ) self.db.WriteFlowRequests([request]) requests_to_process = self.db.ReadFlowProcessingRequests() - self.assertLen(requests_to_process, 0) + self.assertEmpty(requests_to_process) responses = self._Responses(client_id, flow_id, request_id, 1) self.db.WriteFlowResponses(responses) @@ -1130,28 +1270,98 @@ def testResponsesForExpectedIncrementalRequestTriggerIncrementalProcessing( requests_to_process = self.db.ReadFlowProcessingRequests() self.assertLen(requests_to_process, 1) + def testCompletingRequestsWithResponsesTriggersDelayedProcessingCorrectly( + self, + ): + # Pretend that the flow currently processes request #1. 
+    client_id = db_test_utils.InitializeClient(self.db)
+    flow_id = db_test_utils.InitializeFlow(
+        self.db, client_id, next_request_to_process=2
+    )
+
+    start_time = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From(
+        6, rdfvalue.SECONDS
+    )
+
+    request = flows_pb2.FlowRequest(
+        client_id=client_id,
+        flow_id=flow_id,
+        request_id=2,
+        nr_responses_expected=2,
+        start_time=int(start_time),
+        next_state="Foo",
+        next_response_id=1,
+        needs_processing=False,
+    )
+
+    responses = []
+    responses.append(
+        flows_pb2.FlowResponse(
+            client_id=client_id, flow_id=flow_id, request_id=2, response_id=0
+        )
+    )
+    responses.append(
+        flows_pb2.FlowStatus(
+            client_id=client_id,
+            flow_id=flow_id,
+            request_id=2,
+            response_id=2,
+            status=rdf_flow_objects.FlowStatus.Status.OK,
+        )
+    )
+
+    request_queue = queue.Queue()
+
+    def Callback(request: flows_pb2.FlowProcessingRequest):
+      self.db.AckFlowProcessingRequests([request])
+      request_queue.put(request)
+
+    self.db.RegisterFlowProcessingHandler(Callback)
+    self.addCleanup(self.db.UnregisterFlowProcessingHandler)
+
+    self.db.WriteFlowRequests([request])
+    self.db.WriteFlowResponses(responses)
+
+    # Request #2 shouldn't be processed within 3 seconds...
+    try:
+      request_queue.get(True, timeout=3)
+      self.fail("Notification arrived too quickly")
+    except queue.Empty:
+      pass
+
+    # ...but should be processed within the next 10 seconds.
+    try:
+      request_queue.get(True, timeout=10)
+    except queue.Empty:
+      self.fail("Notification didn't arrive when it was expected to.")
+
   def testRewritingResponsesForRequestDoesNotTriggerAdditionalProcessing(self):
     # Write a flow that is waiting for request #2.
     client_id = db_test_utils.InitializeClient(self.db)
     marked_client_id = db_test_utils.InitializeClient(self.db)
     flow_id = db_test_utils.InitializeFlow(
-        self.db, client_id, next_request_to_process=2)
+        self.db, client_id, next_request_to_process=2
+    )
     marked_flow_id = db_test_utils.InitializeFlow(
-        self.db, marked_client_id, next_request_to_process=2)
+        self.db, marked_client_id, next_request_to_process=2
+    )
 
-    request = rdf_flow_objects.FlowRequest(
-        client_id=client_id, flow_id=flow_id, request_id=2)
+    request = flows_pb2.FlowRequest(
+        client_id=client_id, flow_id=flow_id, request_id=2
+    )
     self.db.WriteFlowRequests([request])
 
-    marked_request = rdf_flow_objects.FlowRequest(
-        client_id=marked_client_id, flow_id=marked_flow_id, request_id=2)
+    marked_request = flows_pb2.FlowRequest(
+        client_id=marked_client_id, flow_id=marked_flow_id, request_id=2
+    )
     self.db.WriteFlowRequests([marked_request])
 
     # Generate responses together with a status message.
responses = self._ResponsesAndStatus(client_id, flow_id, 2, 4) - marked_responses = self._ResponsesAndStatus(marked_client_id, - marked_flow_id, 2, 4) + marked_responses = self._ResponsesAndStatus( + marked_client_id, marked_flow_id, 2, 4 + ) req_func = mock.Mock() self.db.RegisterFlowProcessingHandler(req_func) @@ -1165,7 +1375,8 @@ def testRewritingResponsesForRequestDoesNotTriggerAdditionalProcessing(self): break time.sleep(0.1) if rdfvalue.RDFDatetime.Now() - cur_time > rdfvalue.Duration.From( - 10, rdfvalue.SECONDS): + 10, rdfvalue.SECONDS + ): self.fail("Flow request was not processed in time.") req_func.reset_mock() @@ -1204,21 +1415,24 @@ def testRewritingResponsesForRequestDoesNotTriggerAdditionalProcessing(self): time.sleep(0.1) if rdfvalue.RDFDatetime.Now() - cur_time > rdfvalue.Duration.From( - 10, rdfvalue.SECONDS): + 10, rdfvalue.SECONDS + ): self.fail("Flow request was not processed in time.") def testRewritingResponsesForIncrementalRequestsTriggersMoreProcessing(self): request_id = 2 client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=request_id) + self.db, client_id, next_request_to_process=request_id + ) - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=request_id, next_state="Next", - callback_state="Callback") + callback_state="Callback", + ) self.db.WriteFlowRequests([request]) responses = self._Responses(client_id, flow_id, request_id, 1) @@ -1229,7 +1443,7 @@ def testRewritingResponsesForIncrementalRequestsTriggersMoreProcessing(self): self.db.DeleteAllFlowProcessingRequests() requests_to_process = self.db.ReadFlowProcessingRequests() - self.assertLen(requests_to_process, 0) + self.assertEmpty(requests_to_process) # Writing same responses second time triggers a processing requests. self.db.WriteFlowResponses(responses) @@ -1239,24 +1453,28 @@ def testRewritingResponsesForIncrementalRequestsTriggersMoreProcessing(self): def testLeaseFlowForProcessingRaisesIfParentHuntIsStoppedOrCompleted(self): hunt_id = db_test_utils.InitializeHunt(self.db) self.db.UpdateHuntObject( - hunt_id, hunt_state=rdf_hunt_objects.Hunt.HuntState.STOPPED) + hunt_id, hunt_state=hunts_pb2.Hunt.HuntState.STOPPED + ) client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, parent_hunt_id=hunt_id) + self.db, client_id, parent_hunt_id=hunt_id + ) processing_time = rdfvalue.Duration.From(60, rdfvalue.SECONDS) with self.assertRaises(db.ParentHuntIsNotRunningError): self.db.LeaseFlowForProcessing(client_id, flow_id, processing_time) self.db.UpdateHuntObject( - hunt_id, hunt_state=rdf_hunt_objects.Hunt.HuntState.COMPLETED) + hunt_id, hunt_state=hunts_pb2.Hunt.HuntState.COMPLETED + ) with self.assertRaises(db.ParentHuntIsNotRunningError): self.db.LeaseFlowForProcessing(client_id, flow_id, processing_time) self.db.UpdateHuntObject( - hunt_id, hunt_state=rdf_hunt_objects.Hunt.HuntState.STARTED) + hunt_id, hunt_state=hunts_pb2.Hunt.HuntState.STARTED + ) # Should work again. self.db.LeaseFlowForProcessing(client_id, flow_id, processing_time) @@ -1267,13 +1485,14 @@ def testLeaseFlowForProcessingThatIsAlreadyBeingProcessed(self): processing_time = rdfvalue.Duration.From(60, rdfvalue.SECONDS) flow_for_processing = self.db.LeaseFlowForProcessing( - client_id, flow_id, processing_time) + client_id, flow_id, processing_time + ) # Already marked as being processed. 
with self.assertRaises(ValueError): self.db.LeaseFlowForProcessing(client_id, flow_id, processing_time) - self.db.ReleaseProcessedFlow(flow_for_processing) + self.assertTrue(self.db.ReleaseProcessedFlow(flow_for_processing)) # Should work again. self.db.LeaseFlowForProcessing(client_id, flow_id, processing_time) @@ -1291,8 +1510,9 @@ def testLeaseFlowForProcessingAfterProcessingTimeExpiration(self): with self.assertRaises(ValueError): self.db.LeaseFlowForProcessing(client_id, flow_id, processing_time) - after_deadline = now + processing_time + rdfvalue.Duration.From( - 1, rdfvalue.SECONDS) + after_deadline = ( + now + processing_time + rdfvalue.Duration.From(1, rdfvalue.SECONDS) + ) with test_lib.FakeTime(after_deadline): # Should work again. self.db.LeaseFlowForProcessing(client_id, flow_id, processing_time) @@ -1304,10 +1524,12 @@ def testLeaseFlowForProcessingUpdatesHuntCounters(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, flow_id=hunt_id, parent_hunt_id=hunt_id) + self.db, client_id, flow_id=hunt_id, parent_hunt_id=hunt_id + ) flow_for_processing = self.db.LeaseFlowForProcessing( - client_id, flow_id, processing_time) + client_id, flow_id, processing_time + ) flow_for_processing.num_replies_sent = 10 client_summary_result = flows_pb2.FlowResult( @@ -1335,23 +1557,29 @@ def testLeaseFlowForProcessingUpdatesFlowObjects(self): with test_lib.FakeTime(now): flow_for_processing = self.db.LeaseFlowForProcessing( - client_id, flow_id, processing_time) + client_id, flow_id, processing_time + ) self.assertEqual(flow_for_processing.processing_on, utils.ProcessIdString()) # Using assertGreaterEqual and assertLess, as processing_since might come # from the commit timestamp and thus not be influenced by test_lib.FakeTime. - self.assertGreaterEqual(flow_for_processing.processing_since, now) - self.assertLess(flow_for_processing.processing_since, - now + rdfvalue.Duration("5s")) - self.assertEqual(flow_for_processing.processing_deadline, - processing_deadline) + self.assertGreaterEqual(flow_for_processing.processing_since, int(now)) + self.assertLess( + flow_for_processing.processing_since, int(now + rdfvalue.Duration("5s")) + ) + self.assertEqual( + flow_for_processing.processing_deadline, int(processing_deadline) + ) read_flow = self.db.ReadFlowObject(client_id, flow_id) self.assertEqual(read_flow.processing_on, flow_for_processing.processing_on) - self.assertEqual(flow_for_processing.processing_since, - flow_for_processing.processing_since) - self.assertEqual(read_flow.processing_deadline, - flow_for_processing.processing_deadline) + self.assertEqual( + read_flow.processing_since, flow_for_processing.processing_since + ) + self.assertEqual( + read_flow.processing_deadline, + flow_for_processing.processing_deadline, + ) self.assertEqual(read_flow.num_replies_sent, 0) flow_for_processing.next_request_to_process = 5 @@ -1360,37 +1588,41 @@ def testLeaseFlowForProcessingUpdatesFlowObjects(self): self.assertTrue(self.db.ReleaseProcessedFlow(flow_for_processing)) # Check that returning the flow doesn't change the flow object. 
self.assertEqual(read_flow.processing_on, flow_for_processing.processing_on) - self.assertEqual(flow_for_processing.processing_since, - flow_for_processing.processing_since) - self.assertEqual(read_flow.processing_deadline, - flow_for_processing.processing_deadline) + self.assertEqual( + read_flow.processing_since, flow_for_processing.processing_since + ) + self.assertEqual( + read_flow.processing_deadline, + flow_for_processing.processing_deadline, + ) read_flow = self.db.ReadFlowObject(client_id, flow_id) self.assertFalse(read_flow.processing_on) - self.assertIsNone(read_flow.processing_since) - self.assertIsNone(read_flow.processing_deadline) + self.assertFalse(read_flow.HasField("processing_since")) + self.assertFalse(read_flow.HasField("processing_deadline")) self.assertEqual(read_flow.next_request_to_process, 5) self.assertEqual(read_flow.num_replies_sent, 10) def testFlowLastUpdateTime(self): processing_time = rdfvalue.Duration.From(60, rdfvalue.SECONDS) - t0 = self.db.Now() + t0 = self.db.Now().AsMicrosecondsSinceEpoch() client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - t1 = self.db.Now() + t1 = self.db.Now().AsMicrosecondsSinceEpoch() read_flow = self.db.ReadFlowObject(client_id, flow_id) self.assertBetween(read_flow.last_update_time, t0, t1) flow_for_processing = self.db.LeaseFlowForProcessing( - client_id, flow_id, processing_time) + client_id, flow_id, processing_time + ) self.assertBetween(flow_for_processing.last_update_time, t0, t1) - t2 = self.db.Now() + t2 = self.db.Now().AsMicrosecondsSinceEpoch() self.db.ReleaseProcessedFlow(flow_for_processing) - t3 = self.db.Now() + t3 = self.db.Now().AsMicrosecondsSinceEpoch() read_flow = self.db.ReadFlowObject(client_id, flow_id) self.assertBetween(read_flow.last_update_time, t2, t3) @@ -1398,44 +1630,49 @@ def testFlowLastUpdateTime(self): def testReleaseProcessedFlow(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) processing_time = rdfvalue.Duration.From(60, rdfvalue.SECONDS) - processed_flow = self.db.LeaseFlowForProcessing(client_id, flow_id, - processing_time) + processed_flow = self.db.LeaseFlowForProcessing( + client_id, flow_id, processing_time + ) # Let's say we processed one request on this flow. processed_flow.next_request_to_process = 2 # There are some requests ready for processing but not #2. self.db.WriteFlowRequests([ - rdf_flow_objects.FlowRequest( + flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=1, - needs_processing=True), - rdf_flow_objects.FlowRequest( + needs_processing=True, + ), + flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=4, - needs_processing=True) + needs_processing=True, + ), ]) - self.assertTrue(self.db.ReleaseProcessedFlow(processed_flow)) - processed_flow = self.db.LeaseFlowForProcessing(client_id, flow_id, - processing_time) + processed_flow = self.db.LeaseFlowForProcessing( + client_id, flow_id, processing_time + ) # And another one. processed_flow.next_request_to_process = 3 # But in the meantime, request 3 is ready for processing. 
self.db.WriteFlowRequests([ - rdf_flow_objects.FlowRequest( + flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=3, - needs_processing=True) + needs_processing=True, + ) ]) self.assertFalse(self.db.ReleaseProcessedFlow(processed_flow)) @@ -1458,97 +1695,104 @@ def testReleaseProcessedFlowWithRequestScheduledInFuture(self): # There is a request ready for processing but only in the future. # It shouldn't be returned as ready for processing. self.db.WriteFlowRequests([ - rdf_flow_objects.FlowRequest( + flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=2, - start_time=rdfvalue.RDFDatetime.Now() - + rdfvalue.Duration.From(60, rdfvalue.SECONDS), + start_time=int( + rdfvalue.RDFDatetime.Now() + + rdfvalue.Duration.From(60, rdfvalue.SECONDS) + ), needs_processing=True, ), ]) - self.assertTrue(self.db.ReleaseProcessedFlow(processed_flow)) def testReleaseProcessedFlowWithProcessedFlowRequest(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - flow_request = rdf_flow_objects.FlowRequest() + flow_request = flows_pb2.FlowRequest() flow_request.client_id = client_id flow_request.flow_id = flow_id flow_request.request_id = 1 flow_request.needs_processing = False self.db.WriteFlowRequests([flow_request]) - flow_obj = rdf_flow_objects.Flow() + flow_obj = flows_pb2.Flow() flow_obj.client_id = client_id flow_obj.flow_id = flow_id flow_obj.next_request_to_process = 1 self.assertTrue(self.db.ReleaseProcessedFlow(flow_obj)) def testReadChildFlows(self): - client_id = u"C.1234567890123456" + client_id = "C.1234567890123456" self.db.WriteClientMetadata(client_id) self.db.WriteFlowObject( - rdf_flow_objects.Flow(flow_id=u"00000001", client_id=client_id)) + flows_pb2.Flow(flow_id="00000001", client_id=client_id) + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow( - flow_id=u"00000002", - client_id=client_id, - parent_flow_id=u"00000001")) + flows_pb2.Flow( + flow_id="00000002", client_id=client_id, parent_flow_id="00000001" + ) + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow( - flow_id=u"00000003", - client_id=client_id, - parent_flow_id=u"00000002")) + flows_pb2.Flow( + flow_id="00000003", client_id=client_id, parent_flow_id="00000002" + ) + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow( - flow_id=u"00000004", - client_id=client_id, - parent_flow_id=u"00000001")) + flows_pb2.Flow( + flow_id="00000004", client_id=client_id, parent_flow_id="00000001" + ) + ) # This one is completely unrelated (different client id). 
self.db.WriteClientMetadata("C.1234567890123457") self.db.WriteFlowObject( - rdf_flow_objects.Flow( - flow_id=u"00000001", client_id=u"C.1234567890123457")) + flows_pb2.Flow(flow_id="00000001", client_id="C.1234567890123457") + ) self.db.WriteFlowObject( - rdf_flow_objects.Flow( - flow_id=u"00000002", - client_id=u"C.1234567890123457", - parent_flow_id=u"00000001")) + flows_pb2.Flow( + flow_id="00000002", + client_id="C.1234567890123457", + parent_flow_id="00000001", + ) + ) - children = self.db.ReadChildFlowObjects(client_id, u"00000001") + children = self.db.ReadChildFlowObjects(client_id, "00000001") self.assertLen(children, 2) for c in children: - self.assertEqual(c.parent_flow_id, u"00000001") + self.assertEqual(c.parent_flow_id, "00000001") - children = self.db.ReadChildFlowObjects(client_id, u"00000002") + children = self.db.ReadChildFlowObjects(client_id, "00000002") self.assertLen(children, 1) - self.assertEqual(children[0].parent_flow_id, u"00000002") - self.assertEqual(children[0].flow_id, u"00000003") + self.assertEqual(children[0].parent_flow_id, "00000002") + self.assertEqual(children[0].flow_id, "00000003") - children = self.db.ReadChildFlowObjects(client_id, u"00000003") + children = self.db.ReadChildFlowObjects(client_id, "00000003") self.assertEmpty(children) def _WriteRequestAndResponses(self, client_id, flow_id): - rdf_flow = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) - self.db.WriteFlowObject(rdf_flow) + self.db.WriteFlowObject( + flows_pb2.Flow(client_id=client_id, flow_id=flow_id) + ) for request_id in range(1, 4): - request = rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=request_id) + request = flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=request_id + ) self.db.WriteFlowRequests([request]) for response_id in range(1, 3): - response = rdf_flow_objects.FlowResponse( + response = flows_pb2.FlowResponse( client_id=client_id, flow_id=flow_id, request_id=request_id, - response_id=response_id) + response_id=response_id, + ) self.db.WriteFlowResponses([response]) def _CheckRequestsAndResponsesAreThere(self, client_id, flow_id): @@ -1558,10 +1802,10 @@ def _CheckRequestsAndResponsesAreThere(self, client_id, flow_id): self.assertLen(responses, 2) def testDeleteAllFlowRequestsAndResponses(self): - client_id1 = u"C.1234567890123456" - client_id2 = u"C.1234567890123457" - flow_id1 = u"1234ABCD" - flow_id2 = u"1234ABCE" + client_id1 = "C.1234567890123456" + client_id2 = "C.1234567890123457" + flow_id1 = "1234ABCD" + flow_id2 = "1234ABCE" self.db.WriteClientMetadata(client_id1) self.db.WriteClientMetadata(client_id2) @@ -1586,35 +1830,42 @@ def testDeleteAllFlowRequestsAndResponses(self): self.assertEqual(all_requests, []) def testReadFlowRequestsReadyForProcessing(self): - client_id = u"C.1234567890000000" - flow_id = u"12344321" + client_id = "C.1234567890000000" + flow_id = "12344321" requests_for_processing = self.db.ReadFlowRequestsReadyForProcessing( - client_id, flow_id, next_needed_request=1) + client_id, flow_id, next_needed_request=1 + ) self.assertEqual(requests_for_processing, {}) client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) for request_id in [1, 3, 4, 5, 7]: - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=request_id, - needs_processing=True) + needs_processing=True, 
+ ) self.db.WriteFlowRequests([request]) # Request 4 has some responses. responses = [ - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=4, response_id=i) + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=4, response_id=i + ) for i in range(3) ] + before_write_ts = self.db.Now() self.db.WriteFlowResponses(responses) + after_write_ts = self.db.Now() requests_for_processing = self.db.ReadFlowRequestsReadyForProcessing( - client_id, flow_id, next_needed_request=3) + client_id, flow_id, next_needed_request=3 + ) # We expect three requests here. Req #1 is old and should not be there, req # #7 can't be processed since we are missing #6 in between. That leaves @@ -1626,36 +1877,47 @@ def testReadFlowRequestsReadyForProcessing(self): request, _ = requests_for_processing[request_id] self.assertEqual(request_id, request.request_id) - self.assertEqual(requests_for_processing[4][1], responses) + for res, exp in zip(requests_for_processing[4][1], responses): + self.assertEqual(res.client_id, exp.client_id) + self.assertEqual(res.flow_id, exp.flow_id) + self.assertEqual(res.request_id, exp.request_id) + self.assertEqual(res.response_id, exp.response_id) + self.assertBetween(res.timestamp, before_write_ts, after_write_ts) def testReadFlowRequestsReadyForProcessingHandlesIncrementalResponses(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) for request_id in [1, 3, 4, 5, 7]: - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=request_id, callback_state="Callback" if request_id != 5 else None, - next_state="Next") + next_state="Next", + ) self.db.WriteFlowRequests([request]) # Request 4 has some responses. responses = [] for i in range(3): responses.append( - rdf_flow_objects.FlowResponse( + flows_pb2.FlowResponse( client_id=client_id, flow_id=flow_id, request_id=4, - response_id=i + 1)) + response_id=i + 1, + ) + ) + before_write_ts = self.db.Now() self.db.WriteFlowResponses(responses) + after_write_ts = self.db.Now() # An incremental request is always returned as ready for processing. requests_for_processing = self.db.ReadFlowRequestsReadyForProcessing( - client_id, flow_id, next_needed_request=3) + client_id, flow_id, next_needed_request=3 + ) # We expect three requests here. Req #1 is old and should not be there, req # #5 can't be processed since it's not incremental and is not done. That @@ -1665,34 +1927,47 @@ def testReadFlowRequestsReadyForProcessingHandlesIncrementalResponses(self): # Requests for processing contains pairs of (request, responses) as values. 
self.assertEqual(requests_for_processing[3][1], []) - self.assertEqual(requests_for_processing[4][1], responses) self.assertEqual(requests_for_processing[7][1], []) + for req, exp in zip(requests_for_processing[4][1], responses): + self.assertEqual(req.client_id, exp.client_id) + self.assertEqual(req.flow_id, exp.flow_id) + self.assertEqual(req.request_id, exp.request_id) + self.assertEqual(req.response_id, exp.response_id) + + self.assertBetween(req.timestamp, before_write_ts, after_write_ts) + def testReadFlowRequestsReadyForProcessingWithUnorderedIncrementalResponses( - self): + self, + ): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=3, next_response_id=4, callback_state="Callback", - next_state="Next") + next_state="Next", + ) self.db.WriteFlowRequests([request]) # Request 4 has some responses. responses = [ - rdf_flow_objects.FlowResponse( - client_id=client_id, flow_id=flow_id, request_id=3, response_id=i) + flows_pb2.FlowResponse( + client_id=client_id, flow_id=flow_id, request_id=3, response_id=i + ) for i in [1, 2, 5, 7] ] + before_write_ts = self.db.Now() self.db.WriteFlowResponses(responses) + after_write_ts = self.db.Now() # An incremental request is always returned as ready for processing. requests_for_processing = self.db.ReadFlowRequestsReadyForProcessing( - client_id, flow_id, next_needed_request=3) + client_id, flow_id, next_needed_request=3 + ) # requests_for_processing is a dict that's expected to have a single element # with key 3. @@ -1702,7 +1977,12 @@ def testReadFlowRequestsReadyForProcessingWithUnorderedIncrementalResponses( _, fetched_responses = requests_for_processing[3] # Only responses with response_id >= than request.next_response_id should # be returned. 
- self.assertListEqual(fetched_responses, [responses[2], responses[3]]) + for res, exp in zip(fetched_responses, [responses[2], responses[3]]): + self.assertEqual(res.client_id, exp.client_id) + self.assertEqual(res.flow_id, exp.flow_id) + self.assertEqual(res.request_id, exp.request_id) + self.assertEqual(res.response_id, exp.response_id) + self.assertBetween(res.timestamp, before_write_ts, after_write_ts) def testUpdateIncrementalFlowRequests(self): client_id = db_test_utils.InitializeClient(self.db) @@ -1711,12 +1991,14 @@ def testUpdateIncrementalFlowRequests(self): requests = [] for request_id in range(10): requests.append( - rdf_flow_objects.FlowRequest( + flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=request_id, next_state="Next", - callback_state="Callback")) + callback_state="Callback", + ) + ) self.db.WriteFlowRequests(requests) update_map = dict((i, i * 2) for i in range(10)) @@ -1736,7 +2018,7 @@ def testFlowProcessingRequestsQueue(self): request_queue = queue.Queue() - def Callback(request): + def Callback(request: flows_pb2.FlowProcessingRequest): self.db.AckFlowProcessingRequests([request]) request_queue.put(request) @@ -1746,20 +2028,33 @@ def Callback(request): requests = [] for flow_id in flow_ids: requests.append( - rdf_flows.FlowProcessingRequest(client_id=client_id, flow_id=flow_id)) + flows_pb2.FlowProcessingRequest(client_id=client_id, flow_id=flow_id) + ) + pre_creation_time = self.db.Now() self.db.WriteFlowProcessingRequests(requests) + post_creation_time = self.db.Now() got = [] while len(got) < 5: try: l = request_queue.get(True, timeout=6) + got.append(l) except queue.Empty: - self.fail("Timed out waiting for messages, expected 5, got %d" % - len(got)) - got.append(l) + self.fail( + "Timed out waiting for messages, expected 5, got %d" % len(got) + ) - self.assertCountEqual(requests, got) + self.assertCountEqual( + [r.client_id for r in requests], [g.client_id for g in got] + ) + self.assertCountEqual( + [r.flow_id for r in requests], [g.flow_id for g in got] + ) + for g in got: + self.assertBetween( + int(g.creation_time), pre_creation_time, post_creation_time + ) def testFlowProcessingRequestsQueueWithDelay(self): client_id = db_test_utils.InitializeClient(self.db) @@ -1771,7 +2066,7 @@ def testFlowProcessingRequestsQueueWithDelay(self): request_queue = queue.Queue() - def Callback(request): + def Callback(request: flows_pb2.FlowProcessingRequest): self.db.AckFlowProcessingRequests([request]) request_queue.put(request) @@ -1780,30 +2075,44 @@ def Callback(request): now = rdfvalue.RDFDatetime.Now() delivery_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch( - now.AsSecondsSinceEpoch() + 0.5) + now.AsSecondsSinceEpoch() + 0.5 + ) requests = [] for flow_id in flow_ids: requests.append( - rdf_flows.FlowProcessingRequest( - client_id=client_id, flow_id=flow_id, - delivery_time=delivery_time)) + flows_pb2.FlowProcessingRequest( + client_id=client_id, + flow_id=flow_id, + delivery_time=int(delivery_time), + ) + ) + pre_creation_time = self.db.Now() self.db.WriteFlowProcessingRequests(requests) + post_creation_time = self.db.Now() got = [] while len(got) < 5: try: l = request_queue.get(True, timeout=6) + got.append(l) except queue.Empty: - self.fail("Timed out waiting for messages, expected 5, got %d" % - len(got)) - got.append(l) + self.fail( + "Timed out waiting for messages, expected 5, got %d" % len(got) + ) self.assertGreater(rdfvalue.RDFDatetime.Now(), l.delivery_time) - self.assertCountEqual(requests, got) + self.assertCountEqual( + 
[r.client_id for r in requests], [g.client_id for g in got] + ) + self.assertCountEqual( + [r.flow_id for r in requests], [g.flow_id for g in got] + ) + for g in got: + self.assertBetween(g.creation_time, pre_creation_time, post_creation_time) leftover = self.db.ReadFlowProcessingRequests() - self.assertEqual(leftover, []) + self.assertEmpty(leftover) def testFlowRequestsStartTimeIsRespectedWhenResponsesAreWritten(self): client_id = db_test_utils.InitializeClient(self.db) @@ -1811,7 +2120,7 @@ def testFlowRequestsStartTimeIsRespectedWhenResponsesAreWritten(self): request_queue = queue.Queue() - def Callback(request): + def Callback(request: flows_pb2.FlowProcessingRequest): self.db.AckFlowProcessingRequests([request]) request_queue.put(request) @@ -1822,25 +2131,28 @@ def Callback(request): delivery_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch( now.AsSecondsSinceEpoch() + 5 ) - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=1, - start_time=delivery_time, + start_time=int(delivery_time), nr_responses_expected=1, next_state="Foo", needs_processing=True, ) self.db.WriteFlowRequests([request]) - response = rdf_flow_objects.FlowResponse( + payload_any = any_pb2.Any() + payload_any.Pack(flows_pb2.FlowRequest()) + + response = flows_pb2.FlowResponse( client_id=client_id, flow_id=flow_id, request_id=request.request_id, response_id=0, # For the purpose of the test, the payload can be arbitrary, # using rdf_flow_objects.FlowRequest as a sample struct. - payload=rdf_flow_objects.FlowRequest(), + payload=payload_any, ) self.db.WriteFlowResponses([response]) @@ -1866,7 +2178,7 @@ def testFlowProcessingRequestIsAlwaysWrittenIfStartTimeIsSpecified(self): now = rdfvalue.RDFDatetime.Now() delivery_time = now + rdfvalue.Duration("5s") - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, # Note that the request_id is different from the next_request_id @@ -1877,7 +2189,7 @@ def testFlowProcessingRequestIsAlwaysWrittenIfStartTimeIsSpecified(self): # next_request_id might be equivalent to the FlowRequest's request_id # and the state will get processed. 
request_id=42, - start_time=delivery_time, + start_time=int(delivery_time), nr_responses_expected=1, next_state="Foo", needs_processing=True, @@ -1894,7 +2206,7 @@ def testFPRNotWrittenIfStartTimeNotSpecifiedAndIdDoesNotMatch(self): fprs = self.db.ReadFlowProcessingRequests() self.assertEmpty(fprs) - request = rdf_flow_objects.FlowRequest( + request = flows_pb2.FlowRequest( client_id=client_id, flow_id=flow_id, request_id=42, @@ -1924,9 +2236,12 @@ def testAcknowledgingFlowProcessingRequestsWorks(self): requests = [] for flow_id in flow_ids: requests.append( - rdf_flows.FlowProcessingRequest( - client_id=client_id, flow_id=flow_id, - delivery_time=delivery_time)) + flows_pb2.FlowProcessingRequest( + client_id=client_id, + flow_id=flow_id, + delivery_time=int(delivery_time), + ) + ) self.db.WriteFlowProcessingRequests(requests) @@ -1941,12 +2256,14 @@ def testAcknowledgingFlowProcessingRequestsWorks(self): self.db.AckFlowProcessingRequests(stored_requests[1:3]) stored_requests = self.db.ReadFlowProcessingRequests() self.assertLen(stored_requests, 3) - self.assertCountEqual([r.flow_id for r in stored_requests], - [flow_ids[0], flow_ids[3], flow_ids[4]]) + self.assertCountEqual( + [r.flow_id for r in stored_requests], + [flow_ids[0], flow_ids[3], flow_ids[4]], + ) # Make sure DeleteAllFlowProcessingRequests removes all requests. self.db.DeleteAllFlowProcessingRequests() - self.assertEqual(self.db.ReadFlowProcessingRequests(), []) + self.assertEmpty(self.db.ReadFlowProcessingRequests()) self.db.UnregisterFlowProcessingHandler() @@ -2078,7 +2395,8 @@ def testWritesAndReadsMultipleFlowResultsOfSingleType(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_results = self._WriteFlowResults( - self._SampleResults(client_id, flow_id)) + self._SampleResults(client_id, flow_id) + ) results = self.db.ReadFlowResults(client_id, flow_id, 0, 100) self.assertLen(results, len(sample_results)) @@ -2096,7 +2414,8 @@ def testWritesAndReadsMultipleFlowResultsWithDifferentTimestamps(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_results = self._WriteFlowResults( - self._SampleResults(client_id, flow_id), multiple_timestamps=True) + self._SampleResults(client_id, flow_id), multiple_timestamps=True + ) results = self.db.ReadFlowResults(client_id, flow_id, 0, 100) self.assertLen(results, len(sample_results)) @@ -2155,37 +2474,44 @@ def testReadFlowResultsCorrectlyAppliesOffsetAndCountFilters(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_results = self._WriteFlowResults( - self._SampleResults(client_id, flow_id), multiple_timestamps=True) + self._SampleResults(client_id, flow_id), multiple_timestamps=True + ) for l in range(1, 11): for i in range(10): results = self.db.ReadFlowResults(client_id, flow_id, i, l) - expected = sample_results[i:i + l] + expected = sample_results[i : i + l] result_payloads = [x.payload for x in results] expected_payloads = [x.payload for x in expected] self.assertEqual( - result_payloads, expected_payloads, - "Results differ from expected (from %d, size %d): %s vs %s" % - (i, l, result_payloads, expected_payloads)) + result_payloads, + expected_payloads, + "Results differ from expected (from %d, size %d): %s vs %s" + % (i, l, result_payloads, expected_payloads), + ) def testReadFlowResultsCorrectlyAppliesWithTagFilter(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_results = self._WriteFlowResults( - self._SampleResults(client_id, flow_id), 
multiple_timestamps=True) + self._SampleResults(client_id, flow_id), multiple_timestamps=True + ) results = self.db.ReadFlowResults( - client_id, flow_id, 0, 100, with_tag="blah") + client_id, flow_id, 0, 100, with_tag="blah" + ) self.assertFalse(results) results = self.db.ReadFlowResults( - client_id, flow_id, 0, 100, with_tag="tag") + client_id, flow_id, 0, 100, with_tag="tag" + ) self.assertFalse(results) results = self.db.ReadFlowResults( - client_id, flow_id, 0, 100, with_tag="tag_1") + client_id, flow_id, 0, 100, with_tag="tag_1" + ) self.assertEqual([i.payload for i in results], [sample_results[1].payload]) def testReadFlowResultsCorrectlyAppliesWithTypeFilter(self): @@ -2224,11 +2550,13 @@ def testReadFlowResultsCorrectlyAppliesWithTypeFilter(self): flow_id, 0, 100, - with_type=rdf_client.ClientInformation.__name__) + with_type=rdf_client.ClientInformation.__name__, + ) self.assertFalse(results) results = self.db.ReadFlowResults( - client_id, flow_id, 0, 100, with_type=rdf_client.ClientSummary.__name__) + client_id, flow_id, 0, 100, with_type=rdf_client.ClientSummary.__name__ + ) self.assertCountEqual( [i.payload for i in results], [i.payload for i in sample_results_summary], @@ -2239,21 +2567,25 @@ def testReadFlowResultsCorrectlyAppliesWithSubstringFilter(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_results = self._WriteFlowResults( - self._SampleResults(client_id, flow_id), multiple_timestamps=True) + self._SampleResults(client_id, flow_id), multiple_timestamps=True + ) results = self.db.ReadFlowResults( - client_id, flow_id, 0, 100, with_substring="blah") + client_id, flow_id, 0, 100, with_substring="blah" + ) self.assertFalse(results) results = self.db.ReadFlowResults( - client_id, flow_id, 0, 100, with_substring="manufacturer") + client_id, flow_id, 0, 100, with_substring="manufacturer" + ) self.assertEqual( [i.payload for i in results], [i.payload for i in sample_results], ) results = self.db.ReadFlowResults( - client_id, flow_id, 0, 100, with_substring="manufacturer_1") + client_id, flow_id, 0, 100, with_substring="manufacturer_1" + ) self.assertEqual([i.payload for i in results], [sample_results[1].payload]) def testReadFlowResultsCorrectlyAppliesVariousCombinationsOfFilters(self): @@ -2261,14 +2593,15 @@ def testReadFlowResultsCorrectlyAppliesVariousCombinationsOfFilters(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_results = self._WriteFlowResults( - self._SampleResults(client_id, flow_id), multiple_timestamps=True) + self._SampleResults(client_id, flow_id), multiple_timestamps=True + ) tags = {None: list(sample_results), "tag_1": [sample_results[1]]} substrings = { None: list(sample_results), "manufacturer": list(sample_results), - "manufacturer_1": [sample_results[1]] + "manufacturer_1": [sample_results[1]], } types = { @@ -2280,7 +2613,8 @@ def testReadFlowResultsCorrectlyAppliesVariousCombinationsOfFilters(self): for substring_value, substring_expected in substrings.items(): for type_value, type_expected in types.items(): expected = [ - r for r in tag_expected + r + for r in tag_expected if r in substring_expected and r in type_expected ] results = self.db.ReadFlowResults( @@ -2290,20 +2624,24 @@ def testReadFlowResultsCorrectlyAppliesVariousCombinationsOfFilters(self): 100, with_tag=tag_value, with_type=type_value, - with_substring=substring_value) + with_substring=substring_value, + ) self.assertCountEqual( - [i.payload for i in expected], [i.payload for i in results], + [i.payload for i in expected], + 
[i.payload for i in results], "Result items do not match for " - "(tag=%s, type=%s, substring=%s): %s vs %s" % - (tag_value, type_value, substring_value, expected, results)) + "(tag=%s, type=%s, substring=%s): %s vs %s" + % (tag_value, type_value, substring_value, expected, results), + ) def testReadFlowResultsReturnsPayloadWithMissingTypeAsSpecialValue(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_results = self._WriteFlowResults( - self._SampleResults(client_id, flow_id), multiple_timestamps=True) + self._SampleResults(client_id, flow_id), multiple_timestamps=True + ) type_name = jobs_pb2.ClientSummary.__name__ try: @@ -2327,7 +2665,8 @@ def testCountFlowResultsReturnsCorrectResultsCount(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_results = self._WriteFlowResults( - self._SampleResults(client_id, flow_id), multiple_timestamps=True) + self._SampleResults(client_id, flow_id), multiple_timestamps=True + ) num_results = self.db.CountFlowResults(client_id, flow_id) self.assertEqual(num_results, len(sample_results)) @@ -2337,7 +2676,8 @@ def testCountFlowResultsCorrectlyAppliesWithTagFilter(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) self._WriteFlowResults( - self._SampleResults(client_id, flow_id), multiple_timestamps=True) + self._SampleResults(client_id, flow_id), multiple_timestamps=True + ) num_results = self.db.CountFlowResults(client_id, flow_id, with_tag="blah") self.assertEqual(num_results, 0) @@ -2365,15 +2705,18 @@ def testCountFlowResultsCorrectlyAppliesWithTypeFilter(self): self._WriteFlowResults(sample_results=[client_crash_result] * 10) num_results = self.db.CountFlowResults( - client_id, flow_id, with_type=rdf_client.ClientInformation.__name__) + client_id, flow_id, with_type=rdf_client.ClientInformation.__name__ + ) self.assertEqual(num_results, 0) num_results = self.db.CountFlowResults( - client_id, flow_id, with_type=rdf_client.ClientSummary.__name__) + client_id, flow_id, with_type=rdf_client.ClientSummary.__name__ + ) self.assertEqual(num_results, 10) num_results = self.db.CountFlowResults( - client_id, flow_id, with_type=rdf_client.ClientCrash.__name__) + client_id, flow_id, with_type=rdf_client.ClientCrash.__name__ + ) self.assertEqual(num_results, 10) def testCountFlowResultsCorrectlyAppliesWithTagAndWithTypeFilters(self): @@ -2381,7 +2724,8 @@ def testCountFlowResultsCorrectlyAppliesWithTagAndWithTypeFilters(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) self._WriteFlowResults( - self._SampleResults(client_id, flow_id), multiple_timestamps=True) + self._SampleResults(client_id, flow_id), multiple_timestamps=True + ) num_results = self.db.CountFlowResults( client_id, @@ -2415,10 +2759,13 @@ def testCountFlowResultsByTypeReturnsCorrectNumbers(self): ) counts_by_type = self.db.CountFlowResultsByType(client_id, flow_id) - self.assertEqual(counts_by_type, { - "ClientSummary": 3, - "ClientCrash": 5, - }) + self.assertEqual( + counts_by_type, + { + "ClientSummary": 3, + "ClientCrash": 5, + }, + ) def _CreateErrors(self, client_id, flow_id, hunt_id=None): sample_errors = [] @@ -2480,7 +2827,8 @@ def testWritesAndReadsMultipleFlowErrorsOfSingleType(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_errors = self._WriteFlowErrors( - self._CreateErrors(client_id, flow_id)) + self._CreateErrors(client_id, flow_id) + ) errors = self.db.ReadFlowErrors(client_id, flow_id, 0, 100) self.assertLen(errors, 
len(sample_errors)) @@ -2498,7 +2846,8 @@ def testWritesAndReadsMultipleFlowErrorsWithDifferentTimestamps(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_errors = self._WriteFlowErrors( - self._CreateErrors(client_id, flow_id), multiple_timestamps=True) + self._CreateErrors(client_id, flow_id), multiple_timestamps=True + ) errors = self.db.ReadFlowErrors(client_id, flow_id, 0, 100) self.assertLen(errors, len(sample_errors)) @@ -2527,7 +2876,8 @@ def SampleClientSummaryError(i): return error sample_errors = self._WriteFlowErrors( - sample_errors=[SampleClientSummaryError(i) for i in range(10)]) + sample_errors=[SampleClientSummaryError(i) for i in range(10)] + ) def SampleClientCrashError(i): error = flows_pb2.FlowError(client_id=client_id, flow_id=flow_id) @@ -2543,7 +2893,9 @@ def SampleClientCrashError(i): sample_errors.extend( self._WriteFlowErrors( - sample_errors=[SampleClientCrashError(i) for i in range(10)])) + sample_errors=[SampleClientCrashError(i) for i in range(10)] + ) + ) def SampleClientInformationError(i): error = flows_pb2.FlowError(client_id=client_id, flow_id=flow_id) @@ -2552,7 +2904,9 @@ def SampleClientInformationError(i): sample_errors.extend( self._WriteFlowErrors( - sample_errors=[SampleClientInformationError(i) for i in range(10)])) + sample_errors=[SampleClientInformationError(i) for i in range(10)] + ) + ) errors = self.db.ReadFlowErrors(client_id, flow_id, 0, 100) self.assertLen(errors, len(sample_errors)) @@ -2567,26 +2921,30 @@ def testReadFlowErrorsCorrectlyAppliesOffsetAndCountFilters(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_errors = self._WriteFlowErrors( - self._CreateErrors(client_id, flow_id), multiple_timestamps=True) + self._CreateErrors(client_id, flow_id), multiple_timestamps=True + ) for l in range(1, 11): for i in range(10): errors = self.db.ReadFlowErrors(client_id, flow_id, i, l) - expected = sample_errors[i:i + l] + expected = sample_errors[i : i + l] error_payloads = [x.payload for x in errors] expected_payloads = [x.payload for x in expected] self.assertEqual( - error_payloads, expected_payloads, - "Errors differ from expected (from %d, size %d): %s vs %s" % - (i, l, error_payloads, expected_payloads)) + error_payloads, + expected_payloads, + "Errors differ from expected (from %d, size %d): %s vs %s" + % (i, l, error_payloads, expected_payloads), + ) def testReadFlowErrorsCorrectlyAppliesWithTagFilter(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_errors = self._WriteFlowErrors( - self._CreateErrors(client_id, flow_id), multiple_timestamps=True) + self._CreateErrors(client_id, flow_id), multiple_timestamps=True + ) errors = self.db.ReadFlowErrors(client_id, flow_id, 0, 100, with_tag="blah") self.assertFalse(errors) @@ -2595,7 +2953,8 @@ def testReadFlowErrorsCorrectlyAppliesWithTagFilter(self): self.assertFalse(errors) errors = self.db.ReadFlowErrors( - client_id, flow_id, 0, 100, with_tag="tag_1") + client_id, flow_id, 0, 100, with_tag="tag_1" + ) self.assertEqual([i.payload for i in errors], [sample_errors[1].payload]) def testReadFlowErrorsCorrectlyAppliesWithTypeFilter(self): @@ -2615,7 +2974,8 @@ def SampleClientSummaryError(i): return error sample_errors = self._WriteFlowErrors( - sample_errors=[SampleClientSummaryError(i) for i in range(10)]) + sample_errors=[SampleClientSummaryError(i) for i in range(10)] + ) def SampleClientCrashError(i): error = flows_pb2.FlowError(client_id=client_id, 
flow_id=flow_id) @@ -2631,18 +2991,22 @@ def SampleClientCrashError(i): sample_errors.extend( self._WriteFlowErrors( - sample_errors=[SampleClientCrashError(i) for i in range(10)])) + sample_errors=[SampleClientCrashError(i) for i in range(10)] + ) + ) errors = self.db.ReadFlowErrors( client_id, flow_id, 0, 100, - with_type=rdf_client.ClientInformation.__name__) + with_type=rdf_client.ClientInformation.__name__, + ) self.assertFalse(errors) errors = self.db.ReadFlowErrors( - client_id, flow_id, 0, 100, with_type=rdf_client.ClientSummary.__name__) + client_id, flow_id, 0, 100, with_type=rdf_client.ClientSummary.__name__ + ) self.assertCountEqual( [i.payload for i in errors], [i.payload for i in sample_errors[:10]], @@ -2653,7 +3017,8 @@ def testReadFlowErrorsCorrectlyAppliesVariousCombinationsOfFilters(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_errors = self._WriteFlowErrors( - self._CreateErrors(client_id, flow_id), multiple_timestamps=True) + self._CreateErrors(client_id, flow_id), multiple_timestamps=True + ) tags = {None: list(sample_errors), "tag_1": [sample_errors[1]]} @@ -2666,25 +3031,23 @@ def testReadFlowErrorsCorrectlyAppliesVariousCombinationsOfFilters(self): for type_value, type_expected in types.items(): expected = [r for r in tag_expected if r in type_expected] errors = self.db.ReadFlowErrors( - client_id, - flow_id, - 0, - 100, - with_tag=tag_value, - with_type=type_value) + client_id, flow_id, 0, 100, with_tag=tag_value, with_type=type_value + ) - self.assertCountEqual([i.payload for i in expected], - [i.payload for i in errors], - "Error items do not match for " - "(tag=%s, type=%s): %s vs %s" % - (tag_value, type_value, expected, errors)) + self.assertCountEqual( + [i.payload for i in expected], + [i.payload for i in errors], + "Error items do not match for (tag=%s, type=%s): %s vs %s" + % (tag_value, type_value, expected, errors), + ) def testReadFlowErrorsReturnsPayloadWithMissingTypeAsSpecialValue(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_errors = self._WriteFlowErrors( - self._CreateErrors(client_id, flow_id), multiple_timestamps=True) + self._CreateErrors(client_id, flow_id), multiple_timestamps=True + ) type_name = rdf_client.ClientSummary.__name__ try: @@ -2708,7 +3071,8 @@ def testCountFlowErrorsReturnsCorrectErrorsCount(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) sample_errors = self._WriteFlowErrors( - self._CreateErrors(client_id, flow_id), multiple_timestamps=True) + self._CreateErrors(client_id, flow_id), multiple_timestamps=True + ) num_errors = self.db.CountFlowErrors(client_id, flow_id) self.assertEqual(num_errors, len(sample_errors)) @@ -2718,7 +3082,8 @@ def testCountFlowErrorsCorrectlyAppliesWithTagFilter(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) self._WriteFlowErrors( - self._CreateErrors(client_id, flow_id), multiple_timestamps=True) + self._CreateErrors(client_id, flow_id), multiple_timestamps=True + ) num_errors = self.db.CountFlowErrors(client_id, flow_id, with_tag="blah") self.assertEqual(num_errors, 0) @@ -2746,15 +3111,18 @@ def testCountFlowErrorsCorrectlyAppliesWithTypeFilter(self): self._WriteFlowErrors(sample_errors=[client_crash_error] * 10) num_errors = self.db.CountFlowErrors( - client_id, flow_id, with_type=rdf_client.ClientInformation.__name__) + client_id, flow_id, with_type=rdf_client.ClientInformation.__name__ + ) self.assertEqual(num_errors, 0) num_errors = 
self.db.CountFlowErrors( - client_id, flow_id, with_type=rdf_client.ClientSummary.__name__) + client_id, flow_id, with_type=rdf_client.ClientSummary.__name__ + ) self.assertEqual(num_errors, 10) num_errors = self.db.CountFlowErrors( - client_id, flow_id, with_type=rdf_client.ClientCrash.__name__) + client_id, flow_id, with_type=rdf_client.ClientCrash.__name__ + ) self.assertEqual(num_errors, 10) def testCountFlowErrorsCorrectlyAppliesWithTagAndWithTypeFilters(self): @@ -2762,13 +3130,15 @@ def testCountFlowErrorsCorrectlyAppliesWithTagAndWithTypeFilters(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) self._WriteFlowErrors( - self._CreateErrors(client_id, flow_id), multiple_timestamps=True) + self._CreateErrors(client_id, flow_id), multiple_timestamps=True + ) num_errors = self.db.CountFlowErrors( client_id, flow_id, with_tag="tag_1", - with_type=rdf_client.ClientSummary.__name__) + with_type=rdf_client.ClientSummary.__name__, + ) self.assertEqual(num_errors, 1) def testCountFlowErrorsByTypeReturnsCorrectNumbers(self): @@ -2795,10 +3165,13 @@ def testCountFlowErrorsByTypeReturnsCorrectNumbers(self): ) counts_by_type = self.db.CountFlowErrorsByType(client_id, flow_id) - self.assertEqual(counts_by_type, { - "ClientSummary": 3, - "ClientCrash": 5, - }) + self.assertEqual( + counts_by_type, + { + "ClientSummary": 3, + "ClientCrash": 5, + }, + ) def testWritesAndReadsSingleFlowLogEntry(self): client_id = db_test_utils.InitializeClient(self.db) @@ -2842,7 +3215,7 @@ def testReadFlowLogEntriesCorrectlyAppliesOffsetAndCountFilters(self): for i in range(10): for size in range(1, 10): entries = self.db.ReadFlowLogEntries(client_id, flow_id, i, size) - self.assertEqual([e.message for e in entries], messages[i:i + size]) + self.assertEqual([e.message for e in entries], messages[i : i + size]) def testReadFlowLogEntriesCorrectlyAppliesWithSubstringFilter(self): client_id = db_test_utils.InitializeClient(self.db) @@ -2850,15 +3223,18 @@ def testReadFlowLogEntriesCorrectlyAppliesWithSubstringFilter(self): messages = self._WriteFlowLogEntries(client_id, flow_id) entries = self.db.ReadFlowLogEntries( - client_id, flow_id, 0, 100, with_substring="foobar") + client_id, flow_id, 0, 100, with_substring="foobar" + ) self.assertFalse(entries) entries = self.db.ReadFlowLogEntries( - client_id, flow_id, 0, 100, with_substring="blah") + client_id, flow_id, 0, 100, with_substring="blah" + ) self.assertEqual([e.message for e in entries], messages) entries = self.db.ReadFlowLogEntries( - client_id, flow_id, 0, 100, with_substring="blah_1") + client_id, flow_id, 0, 100, with_substring="blah_1" + ) self.assertEqual([e.message for e in entries], [messages[1]]) def testReadFlowLogEntriesCorrectlyAppliesVariousCombinationsOfFilters(self): @@ -2867,15 +3243,18 @@ def testReadFlowLogEntriesCorrectlyAppliesVariousCombinationsOfFilters(self): messages = self._WriteFlowLogEntries(client_id, flow_id) entries = self.db.ReadFlowLogEntries( - client_id, flow_id, 0, 100, with_substring="foobar") + client_id, flow_id, 0, 100, with_substring="foobar" + ) self.assertFalse(entries) entries = self.db.ReadFlowLogEntries( - client_id, flow_id, 1, 2, with_substring="blah") + client_id, flow_id, 1, 2, with_substring="blah" + ) self.assertEqual([e.message for e in entries], [messages[1], messages[2]]) entries = self.db.ReadFlowLogEntries( - client_id, flow_id, 0, 1, with_substring="blah_1") + client_id, flow_id, 0, 1, with_substring="blah_1" + ) self.assertEqual([e.message for e in entries], [messages[1]]) def 
testCountFlowLogEntriesReturnsCorrectFlowLogEntriesCount(self): @@ -2900,8 +3279,9 @@ def testFlowLogsAndErrorsForUnknownFlowsRaise(self): self.assertEqual(context.exception.client_id, client_id) self.assertEqual(context.exception.flow_id, flow_id) - def _WriteFlowOutputPluginLogEntries(self, client_id, flow_id, - output_plugin_id): + def _WriteFlowOutputPluginLogEntries( + self, client_id, flow_id, output_plugin_id + ): entries = [] for i in range(10): message = "blah_🚀_%d" % i @@ -2946,13 +3326,16 @@ def testFlowOutputPluginLogEntriesCanBeWrittenAndThenRead(self): output_plugin_id = "1" written_entries = self._WriteFlowOutputPluginLogEntries( - client_id, flow_id, output_plugin_id) + client_id, flow_id, output_plugin_id + ) read_entries = self.db.ReadFlowOutputPluginLogEntries( - client_id, flow_id, output_plugin_id, 0, 100) + client_id, flow_id, output_plugin_id, 0, 100 + ) self.assertLen(written_entries, len(read_entries)) - self.assertEqual([e.message for e in written_entries], - [e.message for e in read_entries]) + self.assertEqual( + [e.message for e in written_entries], [e.message for e in read_entries] + ) def testFlowOutputPluginLogEntryWith1MbMessageCanBeWrittenAndThenRead(self): client_id = db_test_utils.InitializeClient(self.db) @@ -2969,7 +3352,8 @@ def testFlowOutputPluginLogEntryWith1MbMessageCanBeWrittenAndThenRead(self): self.db.WriteFlowOutputPluginLogEntry(entry) read_entries = self.db.ReadFlowOutputPluginLogEntries( - client_id, flow_id, output_plugin_id, 0, 100) + client_id, flow_id, output_plugin_id, 0, 100 + ) self.assertLen(read_entries, 1) self.assertEqual(read_entries[0].message, entry.message) @@ -3006,29 +3390,34 @@ def testReadFlowOutputPluginLogEntriesCorrectlyAppliesOffsetCounter(self): flow_id = db_test_utils.InitializeFlow(self.db, client_id) output_plugin_id = "1" - entries = self._WriteFlowOutputPluginLogEntries(client_id, flow_id, - output_plugin_id) + entries = self._WriteFlowOutputPluginLogEntries( + client_id, flow_id, output_plugin_id + ) for l in range(1, 11): for i in range(10): results = self.db.ReadFlowOutputPluginLogEntries( - client_id, flow_id, output_plugin_id, i, l) - expected = entries[i:i + l] + client_id, flow_id, output_plugin_id, i, l + ) + expected = entries[i : i + l] result_messages = [x.message for x in results] expected_messages = [x.message for x in expected] self.assertEqual( - result_messages, expected_messages, - "Results differ from expected (from %d, size %d): %s vs %s" % - (i, l, result_messages, expected_messages)) + result_messages, + expected_messages, + "Results differ from expected (from %d, size %d): %s vs %s" + % (i, l, result_messages, expected_messages), + ) def testReadFlowOutputPluginLogEntriesAppliesOffsetCounterWithType(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) output_plugin_id = "1" - entries = self._WriteFlowOutputPluginLogEntries(client_id, flow_id, - output_plugin_id) + entries = self._WriteFlowOutputPluginLogEntries( + client_id, flow_id, output_plugin_id + ) for l in range(1, 11): for i in range(10): @@ -3037,34 +3426,46 @@ def testReadFlowOutputPluginLogEntriesAppliesOffsetCounterWithType(self): flows_pb2.FlowOutputPluginLogEntry.LogEntryType.ERROR, ]: results = self.db.ReadFlowOutputPluginLogEntries( - client_id, flow_id, output_plugin_id, i, l, with_type=with_type) - expected = [e for e in entries if e.log_entry_type == with_type - ][i:i + l] + client_id, flow_id, output_plugin_id, i, l, with_type=with_type + ) + expected = [e 
for e in entries if e.log_entry_type == with_type][ + i : i + l + ] result_messages = [x.message for x in results] expected_messages = [x.message for x in expected] self.assertEqual( - result_messages, expected_messages, - "Results differ from expected (from %d, size %d): %s vs %s" % - (i, l, result_messages, expected_messages)) + result_messages, + expected_messages, + "Results differ from expected (from %d, size %d): %s vs %s" + % (i, l, result_messages, expected_messages), + ) def testFlowOutputPluginLogEntriesCanBeCountedPerPlugin(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow(self.db, client_id) output_plugin_id_1 = "1" - self._WriteFlowOutputPluginLogEntries(client_id, flow_id, - output_plugin_id_1) + self._WriteFlowOutputPluginLogEntries( + client_id, flow_id, output_plugin_id_1 + ) output_plugin_id_2 = "2" - self._WriteFlowOutputPluginLogEntries(client_id, flow_id, - output_plugin_id_2) + self._WriteFlowOutputPluginLogEntries( + client_id, flow_id, output_plugin_id_2 + ) self.assertEqual( - self.db.CountFlowOutputPluginLogEntries(client_id, flow_id, - output_plugin_id_1), 10) + self.db.CountFlowOutputPluginLogEntries( + client_id, flow_id, output_plugin_id_1 + ), + 10, + ) self.assertEqual( - self.db.CountFlowOutputPluginLogEntries(client_id, flow_id, - output_plugin_id_2), 10) + self.db.CountFlowOutputPluginLogEntries( + client_id, flow_id, output_plugin_id_2 + ), + 10, + ) def testCountFlowOutputPluginLogEntriesRespectsWithTypeFilter(self): client_id = db_test_utils.InitializeClient(self.db) @@ -3175,7 +3576,8 @@ def testWriteScheduledFlowUpdatesExistingEntry(self): client_id=client_id, creator=username, scheduled_flow_id=sf.scheduled_flow_id, - error="foobar") + error="foobar", + ) results = self.db.ListScheduledFlows(client_id, username) self.assertLen(results, 1) @@ -3222,10 +3624,12 @@ def testListScheduledFlowsFiltersCorrectly(self): self.assertEqual(results[0].creator, username2) self.assertEmpty( - self.db.ListScheduledFlows("C.1234123412341234", username1)) + self.db.ListScheduledFlows("C.1234123412341234", username1) + ) self.assertEmpty(self.db.ListScheduledFlows(client_id1, "nonexistent")) self.assertEmpty( - self.db.ListScheduledFlows("C.1234123412341234", "nonexistent")) + self.db.ListScheduledFlows("C.1234123412341234", "nonexistent") + ) self.assertEmpty(self.db.ListScheduledFlows(client_id3, username1)) self.assertEmpty(self.db.ListScheduledFlows(client_id1, username3)) @@ -3296,7 +3700,8 @@ def testDeleteScheduledFlowRaisesForUnknownScheduledFlow(self): username = db_test_utils.InitializeUser(self.db) self._SetupScheduledFlow( - scheduled_flow_id="1", client_id=client_id, creator=username) + scheduled_flow_id="1", client_id=client_id, creator=username + ) with self.assertRaises(db.UnknownScheduledFlowError) as e: self.db.DeleteScheduledFlow(client_id, username, "2") @@ -3339,35 +3744,47 @@ class DatabaseLargeTestFlowMixin(object): # TODO(hanuszczak): Remove code duplication in the three methods below shared # with the `DatabaseTestFlowMixin` class. - def _Responses(self, client_id, flow_id, request_id, num_responses): + def _Responses( + self, client_id, flow_id, request_id, num_responses + ) -> List[flows_pb2.FlowResponse]: # TODO(hanuszczak): Fix this lint properly. 
# pylint: disable=g-complex-comprehension return [ - rdf_flow_objects.FlowResponse( + flows_pb2.FlowResponse( client_id=client_id, flow_id=flow_id, request_id=request_id, - response_id=i) for i in range(1, num_responses + 1) + response_id=i, + ) + for i in range(1, num_responses + 1) ] # pylint: enable=g-complex-comprehension - def _ResponsesAndStatus(self, client_id, flow_id, request_id, num_responses): + def _ResponsesAndStatus( + self, client_id, flow_id, request_id, num_responses + ) -> List[Union[flows_pb2.FlowResponse, flows_pb2.FlowStatus]]: return self._Responses(client_id, flow_id, request_id, num_responses) + [ - rdf_flow_objects.FlowStatus( + flows_pb2.FlowStatus( client_id=client_id, flow_id=flow_id, request_id=request_id, - response_id=num_responses + 1) + response_id=num_responses + 1, + ) ] - def _WriteResponses(self, num): + def _WriteResponses(self, num) -> Tuple[ + flows_pb2.FlowRequest, + List[Union[flows_pb2.FlowResponse, flows_pb2.FlowStatus]], + ]: client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) - request = rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=2) + request = flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=2 + ) self.db.WriteFlowRequests([request]) # Generate responses together with a status message. @@ -3381,81 +3798,132 @@ def _WriteResponses(self, num): def test40001RequestsCanBeWrittenAndRead(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) requests = [ - rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=i) + flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=i + ) for i in range(40001) ] self.db.WriteFlowRequests(requests) self.assertLen( - self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id), 40001) + self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id), 40001 + ) def test40001ResponsesCanBeWrittenAndRead(self): + before_write = self.db.Now() request, responses = self._WriteResponses(40001) + after_write = self.db.Now() - expected_request = rdf_flow_objects.FlowRequest( + expected_request = flows_pb2.FlowRequest( client_id=request.client_id, flow_id=request.flow_id, request_id=request.request_id, needs_processing=True, - nr_responses_expected=40002) + nr_responses_expected=40002, + ) rrp = self.db.ReadFlowRequestsReadyForProcessing( request.client_id, request.flow_id, - next_needed_request=request.request_id) + next_needed_request=request.request_id, + ) self.assertLen(rrp, 1) fetched_request, fetched_responses = rrp[request.request_id] - self.assertEqual(fetched_request, expected_request) + + self.assertIsInstance(fetched_request, flows_pb2.FlowRequest) + self.assertEqual(fetched_request.client_id, expected_request.client_id) + self.assertEqual(fetched_request.flow_id, expected_request.flow_id) + self.assertEqual(fetched_request.request_id, expected_request.request_id) + self.assertEqual( + fetched_request.needs_processing, expected_request.needs_processing + ) + self.assertEqual( + fetched_request.nr_responses_expected, + expected_request.nr_responses_expected, + ) + self.assertBetween(fetched_request.timestamp, before_write, after_write) + + for r in fetched_responses: + # `responses` does not have the timestamp as it is only 
available after + # reading, not writing, so we compare it manually and remove it from the + # proto. + self.assertBetween(r.timestamp, before_write, after_write) + r.ClearField("timestamp") self.assertEqual(fetched_responses, responses) - arrp = self.db.ReadAllFlowRequestsAndResponses(request.client_id, - request.flow_id) + arrp = self.db.ReadAllFlowRequestsAndResponses( + request.client_id, request.flow_id + ) self.assertLen(arrp, 1) fetched_request, fetched_responses = arrp[0] - self.assertEqual(fetched_request, expected_request) - self.assertEqual([r for _, r in sorted(fetched_responses.items())], - responses) + self.assertIsInstance(fetched_request, flows_pb2.FlowRequest) + self.assertEqual(fetched_request.client_id, expected_request.client_id) + self.assertEqual(fetched_request.flow_id, expected_request.flow_id) + self.assertEqual(fetched_request.request_id, expected_request.request_id) + self.assertEqual( + fetched_request.needs_processing, expected_request.needs_processing + ) + self.assertEqual( + fetched_request.nr_responses_expected, + expected_request.nr_responses_expected, + ) + self.assertBetween(fetched_request.timestamp, before_write, after_write) + for r in fetched_responses.values(): + # `responses` does not have the timestamp as it is only available after + # reading, not writing, so we compare it manually and remove it from the + # proto. + self.assertBetween(r.timestamp, before_write, after_write) + r.ClearField("timestamp") + self.assertEqual( + [r for _, r in sorted(fetched_responses.items())], responses + ) def testDeleteAllFlowRequestsAndResponsesHandles11000Responses(self): request, _ = self._WriteResponses(11000) - self.db.DeleteAllFlowRequestsAndResponses(request.client_id, - request.flow_id) - arrp = self.db.ReadAllFlowRequestsAndResponses(request.client_id, - request.flow_id) + self.db.DeleteAllFlowRequestsAndResponses( + request.client_id, request.flow_id + ) + arrp = self.db.ReadAllFlowRequestsAndResponses( + request.client_id, request.flow_id + ) self.assertEmpty(arrp) def testDeleteFlowRequestsHandles11000Responses(self): request, _ = self._WriteResponses(11000) - self.db.DeleteFlowRequests([request]) - arrp = self.db.ReadAllFlowRequestsAndResponses(request.client_id, - request.flow_id) + arrp = self.db.ReadAllFlowRequestsAndResponses( + request.client_id, request.flow_id + ) self.assertEmpty(arrp) def testDeleteFlowRequestsHandles11000Requests(self): client_id = db_test_utils.InitializeClient(self.db) flow_id = db_test_utils.InitializeFlow( - self.db, client_id, next_request_to_process=2) + self.db, client_id, next_request_to_process=2 + ) requests = [ - rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=i) + flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=i + ) for i in range(2, 11002) ] self.db.WriteFlowRequests(requests) self.assertLen( - self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id), 11000) + self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id), 11000 + ) self.db.DeleteFlowRequests(requests) self.assertEmpty( - self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id)) + self.db.ReadAllFlowRequestsAndResponses(client_id, flow_id) + ) def testWritesAndCounts40001FlowResults(self): client_id = db_test_utils.InitializeClient(self.db) diff --git a/grr/server/grr_response_server/databases/db_foreman_rules_test.py b/grr/server/grr_response_server/databases/db_foreman_rules_test.py index f14faf982a..04e1fdd45b 100644 --- 
a/grr/server/grr_response_server/databases/db_foreman_rules_test.py +++ b/grr/server/grr_response_server/databases/db_foreman_rules_test.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Mixin tests for storing Foreman rules in the relational db.""" + from typing import Optional from grr_response_core.lib import rdfvalue diff --git a/grr/server/grr_response_server/databases/db_hunts_test.py b/grr/server/grr_response_server/databases/db_hunts_test.py index ecb8168b2f..13b049fc1e 100644 --- a/grr/server/grr_response_server/databases/db_hunts_test.py +++ b/grr/server/grr_response_server/databases/db_hunts_test.py @@ -3,26 +3,29 @@ import collections import random -from typing import List +from typing import List, Optional +from google.protobuf import any_pb2 from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import client_stats as rdf_client_stats from grr_response_core.lib.rdfvalues import stats as rdf_stats +from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import flows_pb2 +from grr_response_proto import hunts_pb2 from grr_response_proto import jobs_pb2 +from grr_response_proto import objects_pb2 +from grr_response_proto import output_plugin_pb2 from grr_response_server import flow from grr_response_server.databases import db from grr_response_server.databases import db_test_utils from grr_response_server.output_plugins import email_plugin from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects -from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects -from grr_response_server.rdfvalues import output_plugin as rdf_output_plugin class DatabaseTestHuntMixin(object): @@ -31,93 +34,117 @@ class DatabaseTestHuntMixin(object): This mixin adds methods to test the handling of hunts. """ - def _SetupHuntClientAndFlow(self, - client_id=None, - hunt_id=None, - flow_id=None, - **additional_flow_args): + def _SetupHuntClientAndFlow( + self, + hunt_id: str = None, + client_id: Optional[str] = None, + flow_id: Optional[str] = None, + flow_state: Optional[rdf_structs.EnumNamedValue] = None, + parent_flow_id: Optional[str] = None, + ): client_id = db_test_utils.InitializeClient(self.db, client_id=client_id) # Top-level hunt-induced flows should have hunt's id. 
flow_id = flow_id or hunt_id - self.db.WriteClientMetadata(client_id) - rdf_flow = rdf_flow_objects.Flow( - client_id=client_id, + flow_id = db_test_utils.InitializeFlow( + self.db, + client_id, flow_id=flow_id, + flow_state=flow_state, parent_hunt_id=hunt_id, - **additional_flow_args) - self.db.WriteFlowObject(rdf_flow) + parent_flow_id=parent_flow_id, + ) return client_id, flow_id def testWritingAndReadingHuntObjectWorks(self): then = rdfvalue.RDFDatetime.Now() - self.db.WriteGRRUser("Foo") - hunt_obj = rdf_hunt_objects.Hunt(creator="Foo", description="Lorem ipsum.") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + db_test_utils.InitializeUser(self.db, "Foo") + hunt_id = db_test_utils.InitializeHunt( + self.db, + creator="Foo", + description="Lorem ipsum.", + ) - read_hunt_obj = self.db.ReadHuntObject(hunt_obj.hunt_id) + read_hunt_obj = self.db.ReadHuntObject(hunt_id) self.assertEqual(read_hunt_obj.creator, "Foo") self.assertEqual(read_hunt_obj.description, "Lorem ipsum.") - self.assertGreater(read_hunt_obj.create_time, then) - self.assertGreater(read_hunt_obj.last_update_time, then) + self.assertGreater( + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + read_hunt_obj.create_time + ), + then, + ) + self.assertGreater( + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + read_hunt_obj.last_update_time + ), + then, + ) def testWritingHuntObjectIntegralClientRate(self): - creator = db_test_utils.InitializeUser(self.db) + creator = db_test_utils.InitializeUser(self.db, "user") - hunt_obj = rdf_hunt_objects.Hunt() - hunt_obj.creator = creator - hunt_obj.description = "Lorem ipsum." - hunt_obj.client_rate = 42 - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt( + self.db, + creator="user", + description="Lorem ipsum.", + client_rate=42, + ) - hunt_obj = self.db.ReadHuntObject(hunt_obj.hunt_id) + hunt_obj = self.db.ReadHuntObject(hunt_id) self.assertEqual(hunt_obj.creator, creator) self.assertEqual(hunt_obj.description, "Lorem ipsum.") self.assertAlmostEqual(hunt_obj.client_rate, 42, places=5) def testWritingHuntObjectFractionalClientRate(self): - creator = db_test_utils.InitializeUser(self.db) + creator = db_test_utils.InitializeUser(self.db, "user") - hunt_obj = rdf_hunt_objects.Hunt() - hunt_obj.creator = creator - hunt_obj.description = "Lorem ipsum." 
- hunt_obj.client_rate = 3.14 - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt( + self.db, + creator="user", + description="Lorem ipsum.", + client_rate=3.14, + ) - hunt_obj = self.db.ReadHuntObject(hunt_obj.hunt_id) + hunt_obj = self.db.ReadHuntObject(hunt_id) self.assertEqual(hunt_obj.creator, creator) self.assertEqual(hunt_obj.description, "Lorem ipsum.") self.assertAlmostEqual(hunt_obj.client_rate, 3.14, places=5) def testHuntObjectCannotBeOverwritten(self): - self.db.WriteGRRUser("user") + db_test_utils.InitializeUser(self.db, "user") hunt_id = "ABCDEF42" - hunt_obj_v1 = rdf_hunt_objects.Hunt( - hunt_id=hunt_id, description="foo", creator="user") - hunt_obj_v2 = rdf_hunt_objects.Hunt( - hunt_id=hunt_id, description="bar", creator="user") + hunt_obj_v1 = hunts_pb2.Hunt( + hunt_id=hunt_id, + hunt_state=hunts_pb2.Hunt.HuntState.PAUSED, + description="foo", + creator="user", + ) + hunt_obj_v2 = hunts_pb2.Hunt( + hunt_id=hunt_id, + hunt_state=hunts_pb2.Hunt.HuntState.PAUSED, + description="bar", + creator="user", + ) hunt_obj_v2.hunt_id = hunt_obj_v1.hunt_id - hunt_obj_v1 = mig_hunt_objects.ToProtoHunt(hunt_obj_v1) self.db.WriteHuntObject(hunt_obj_v1) - hunt_obj_v2 = mig_hunt_objects.ToProtoHunt(hunt_obj_v2) with self.assertRaises(db.DuplicatedHuntError) as context: self.db.WriteHuntObject(hunt_obj_v2) self.assertEqual(context.exception.hunt_id, hunt_id) def testHuntObjectCannotBeWrittenInNonPausedState(self): - self.db.WriteGRRUser("user") - hunt_object = rdf_hunt_objects.Hunt( - hunt_state=rdf_hunt_objects.Hunt.HuntState.STARTED, creator="user") + db_test_utils.InitializeUser(self.db, "user") + hunt_object = hunts_pb2.Hunt( + hunt_id=rdf_hunt_objects.RandomHuntId(), + hunt_state=rdf_hunt_objects.Hunt.HuntState.STARTED, + creator="user", + ) - hunt_object = mig_hunt_objects.ToProtoHunt(hunt_object) with self.assertRaises(ValueError): self.db.WriteHuntObject(hunt_object) @@ -129,69 +156,82 @@ def testUpdateHuntObjectRaisesIfHuntDoesNotExist(self): with self.assertRaises(db.UnknownHuntError): self.db.UpdateHuntObject( rdf_hunt_objects.RandomHuntId(), - hunt_state=rdf_hunt_objects.Hunt.HuntState.STARTED) + hunt_state=hunts_pb2.Hunt.HuntState.STARTED, + ) def testUpdateHuntObjectCorrectlyUpdatesHuntObject(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) self.db.UpdateHuntObject( - hunt_obj.hunt_id, + hunt_id, duration=rdfvalue.Duration.From(1, rdfvalue.WEEKS), client_rate=33, client_limit=48, - hunt_state=rdf_hunt_objects.Hunt.HuntState.STOPPED, + hunt_state=hunts_pb2.Hunt.HuntState.STOPPED, hunt_state_comment="foo", start_time=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(43), - num_clients_at_start_time=44) + num_clients_at_start_time=44, + ) - updated_hunt_obj = self.db.ReadHuntObject(hunt_obj.hunt_id) - self.assertEqual(updated_hunt_obj.duration, - rdfvalue.Duration.From(1, rdfvalue.WEEKS)) + updated_hunt_obj = self.db.ReadHuntObject(hunt_id) + self.assertEqual( + rdfvalue.Duration.From(updated_hunt_obj.duration, rdfvalue.SECONDS), + rdfvalue.Duration.From(1, rdfvalue.WEEKS), + ) self.assertEqual(updated_hunt_obj.client_rate, 33) self.assertEqual(updated_hunt_obj.client_limit, 48) - self.assertEqual(updated_hunt_obj.hunt_state, - rdf_hunt_objects.Hunt.HuntState.STOPPED) + self.assertEqual( + updated_hunt_obj.hunt_state, 
hunts_pb2.Hunt.HuntState.STOPPED + ) self.assertEqual(updated_hunt_obj.hunt_state_comment, "foo") - self.assertEqual(updated_hunt_obj.init_start_time, - rdfvalue.RDFDatetime.FromSecondsSinceEpoch(43)) - self.assertEqual(updated_hunt_obj.last_start_time, - rdfvalue.RDFDatetime.FromSecondsSinceEpoch(43)) + self.assertEqual( + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + updated_hunt_obj.init_start_time + ), + rdfvalue.RDFDatetime.FromSecondsSinceEpoch(43), + ) + self.assertEqual( + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + updated_hunt_obj.last_start_time + ), + rdfvalue.RDFDatetime.FromSecondsSinceEpoch(43), + ) self.assertEqual(updated_hunt_obj.num_clients_at_start_time, 44) def testUpdateHuntObjectCorrectlyUpdatesInitAndLastStartTime(self): - self.db.WriteGRRUser("user") - hunt_object = rdf_hunt_objects.Hunt( - description="Lorem ipsum.", creator="user") - hunt_object = mig_hunt_objects.ToProtoHunt(hunt_object) - self.db.WriteHuntObject(hunt_object) + hunt_id = db_test_utils.InitializeHunt(self.db) timestamp_1 = rdfvalue.RDFDatetime.Now() - self.db.UpdateHuntObject(hunt_object.hunt_id, start_time=timestamp_1) + self.db.UpdateHuntObject(hunt_id, start_time=timestamp_1) timestamp_2 = rdfvalue.RDFDatetime.Now() - self.db.UpdateHuntObject(hunt_object.hunt_id, start_time=timestamp_2) + self.db.UpdateHuntObject(hunt_id, start_time=timestamp_2) - updated_hunt_object = self.db.ReadHuntObject(hunt_object.hunt_id) - self.assertEqual(updated_hunt_object.init_start_time, timestamp_1) - self.assertEqual(updated_hunt_object.last_start_time, timestamp_2) + updated_hunt_object = self.db.ReadHuntObject(hunt_id) + self.assertEqual( + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + updated_hunt_object.init_start_time + ), + timestamp_1, + ) + self.assertEqual( + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + updated_hunt_object.last_start_time + ), + timestamp_2, + ) def testDeletingHuntObjectWorks(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) # This shouldn't raise. - self.db.ReadHuntObject(hunt_obj.hunt_id) + self.db.ReadHuntObject(hunt_id) - self.db.DeleteHuntObject(hunt_obj.hunt_id) + self.db.DeleteHuntObject(hunt_id) # The hunt is deleted: this should raise now. 
with self.assertRaises(db.UnknownHuntError): - self.db.ReadHuntObject(hunt_obj.hunt_id) + self.db.ReadHuntObject(hunt_id) def testDeleteHuntObjectWithApprovalRequest(self): creator = db_test_utils.InitializeUser(self.db) @@ -220,9 +260,8 @@ def testDeleteHuntObjectWithApprovalRequest(self): def testReadHuntObjectsReturnsEmptyListWhenNoHunts(self): self.assertEqual(self.db.ReadHuntObjects(offset=0, count=db.MAX_COUNT), []) - def _CreateMultipleHunts(self): - self.db.WriteGRRUser("user-a") - self.db.WriteGRRUser("user-b") + def _CreateMultipleHunts(self) -> List[hunts_pb2.Hunt]: + self._CreateMultipleUsers(["user-a", "user-b"]) result = [] for i in range(10): @@ -230,35 +269,41 @@ def _CreateMultipleHunts(self): creator = "user-a" else: creator = "user-b" - hunt_obj = rdf_hunt_objects.Hunt( - description="foo_%d" % i, creator=creator) - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - result.append(self.db.ReadHuntObject(hunt_obj.hunt_id)) + hunt_id = db_test_utils.InitializeHunt( + self.db, + creator=creator, + description="foo_%d" % i, + ) + result.append(self.db.ReadHuntObject(hunt_id)) return result def _CreateMultipleUsers(self, users: List[str]): for user in users: - self.db.WriteGRRUser(user) + db_test_utils.InitializeUser(self.db, user) - def _CreateMultipleHuntsForUser(self, user: str, count: int): + def _CreateMultipleHuntsForUser( + self, + user: str, + count: int, + ) -> List[hunts_pb2.Hunt]: result = [] for i in range(count): - hunt_obj = rdf_hunt_objects.Hunt(description="foo_%d" % i, creator=user) - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - result.append(self.db.ReadHuntObject(hunt_obj.hunt_id)) + hunt_id = db_test_utils.InitializeHunt( + self.db, + creator=user, + description="foo_%d" % i, + ) + result.append(self.db.ReadHuntObject(hunt_id)) return result def _CreateHuntWithState( - self, creator: str, state: rdf_hunt_objects.Hunt.HuntState - ) -> rdf_hunt_objects.Hunt: - hunt_obj = rdf_hunt_objects.Hunt(creator=creator) - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - self.db.UpdateHuntObject(hunt_obj.hunt_id, hunt_state=state) - return self.db.ReadHuntObject(hunt_obj.hunt_id) + self, creator: str, state: hunts_pb2.Hunt.HuntState + ) -> hunts_pb2.Hunt: + hunt_id = db_test_utils.InitializeHunt(self.db, creator=creator) + + self.db.UpdateHuntObject(hunt_id, hunt_state=state) + return self.db.ReadHuntObject(hunt_id) def testReadHuntObjectsWithoutFiltersReadsAllHunts(self): expected = self._CreateMultipleHunts() @@ -416,52 +461,66 @@ def testReadHuntObjectsCreatedAfterFilterIsAppliedCorrectly(self): got = self.db.ReadHuntObjects( 0, db.MAX_COUNT, - created_after=all_hunts[0].create_time - - rdfvalue.Duration.From(1, rdfvalue.SECONDS)) + created_after=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + all_hunts[0].create_time + ) + - rdfvalue.Duration.From(1, rdfvalue.SECONDS), + ) self.assertListEqual(got, list(reversed(all_hunts))) got = self.db.ReadHuntObjects( - 0, db.MAX_COUNT, created_after=all_hunts[2].create_time) + 0, + db.MAX_COUNT, + created_after=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + all_hunts[2].create_time + ), + ) self.assertListEqual(got, list(reversed(all_hunts[3:]))) got = self.db.ReadHuntObjects( - 0, db.MAX_COUNT, created_after=all_hunts[-1].create_time) + 0, + db.MAX_COUNT, + created_after=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + all_hunts[-1].create_time + ), + ) self.assertEmpty(got) def 
testReadHuntObjectsWithDescriptionMatchFilterIsAppliedCorrectly(self): all_hunts = self._CreateMultipleHunts() got = self.db.ReadHuntObjects( - 0, db.MAX_COUNT, with_description_match="foo_") + 0, db.MAX_COUNT, with_description_match="foo_" + ) self.assertListEqual(got, list(reversed(all_hunts))) got = self.db.ReadHuntObjects( - 0, db.MAX_COUNT, with_description_match="blah") + 0, db.MAX_COUNT, with_description_match="blah" + ) self.assertEmpty(got) got = self.db.ReadHuntObjects( - 0, db.MAX_COUNT, with_description_match="foo_3") + 0, db.MAX_COUNT, with_description_match="foo_3" + ) self.assertListEqual(got, [all_hunts[3]]) def testReadHuntObjectsWithStatesFilterIsAppliedCorrectly(self): creator = "testuser" - self.db.WriteGRRUser(creator) + db_test_utils.InitializeUser(self.db, creator) paused_hunt = self._CreateHuntWithState( - creator, rdf_hunt_objects.Hunt.HuntState.PAUSED + creator, hunts_pb2.Hunt.HuntState.PAUSED ) - self._CreateHuntWithState(creator, rdf_hunt_objects.Hunt.HuntState.STARTED) + self._CreateHuntWithState(creator, hunts_pb2.Hunt.HuntState.STARTED) stopped_hunt = self._CreateHuntWithState( - creator, rdf_hunt_objects.Hunt.HuntState.STOPPED - ) - self._CreateHuntWithState( - creator, rdf_hunt_objects.Hunt.HuntState.COMPLETED + creator, hunts_pb2.Hunt.HuntState.STOPPED ) + self._CreateHuntWithState(creator, hunts_pb2.Hunt.HuntState.COMPLETED) got = self.db.ReadHuntObjects( 0, db.MAX_COUNT, with_states=[ - rdf_hunt_objects.Hunt.HuntState.PAUSED, + hunts_pb2.Hunt.HuntState.PAUSED, ], ) self.assertListEqual(got, [paused_hunt]) @@ -470,8 +529,8 @@ def testReadHuntObjectsWithStatesFilterIsAppliedCorrectly(self): 0, db.MAX_COUNT, with_states=[ - rdf_hunt_objects.Hunt.HuntState.PAUSED, - rdf_hunt_objects.Hunt.HuntState.STOPPED, + hunts_pb2.Hunt.HuntState.PAUSED, + hunts_pb2.Hunt.HuntState.STOPPED, ], ) self.assertCountEqual(got, [paused_hunt, stopped_hunt]) @@ -489,26 +548,30 @@ def testReadHuntObjectsCombinationsOfFiltersAreAppliedCorrectly(self): self.db.ReadHuntObjects, conditions=dict( with_creator="user-a", - created_after=expected[2].create_time, - with_description_match="foo_4"), - error_desc="ReadHuntObjects") + created_after=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + expected[2].create_time + ), + with_description_match="foo_4", + ), + error_desc="ReadHuntObjects", + ) def testListHuntObjectsReturnsEmptyListWhenNoHunts(self): self.assertEqual(self.db.ListHuntObjects(offset=0, count=db.MAX_COUNT), []) def testListHuntObjectsWithoutFiltersReadsAllHunts(self): - expected = [ - rdf_hunt_objects.HuntMetadata.FromHunt(h) - for h in self._CreateMultipleHunts() - ] + hunts = self._CreateMultipleHunts() + hunts = [mig_hunt_objects.ToRDFHunt(h) for h in hunts] + expected = [rdf_hunt_objects.HuntMetadata.FromHunt(h) for h in hunts] + expected = [mig_hunt_objects.ToProtoHuntMetadata(h) for h in expected] got = self.db.ListHuntObjects(0, db.MAX_COUNT) self.assertListEqual(got, list(reversed(expected))) def testListHuntObjectsWithCreatorFilterIsAppliedCorrectly(self): - all_hunts = [ - rdf_hunt_objects.HuntMetadata.FromHunt(h) - for h in self._CreateMultipleHunts() - ] + all_hunts = self._CreateMultipleHunts() + all_hunts = [mig_hunt_objects.ToRDFHunt(h) for h in all_hunts] + all_hunts = [rdf_hunt_objects.HuntMetadata.FromHunt(h) for h in all_hunts] + all_hunts = [mig_hunt_objects.ToProtoHuntMetadata(h) for h in all_hunts] got = self.db.ListHuntObjects(0, db.MAX_COUNT, with_creator="user-a") self.assertListEqual(got, list(reversed(all_hunts[:5]))) @@ -521,7 +584,9 @@ def 
testListHuntObjectsWithCreatedByFilterIsAppliedCorrectly(self): hunts = self._CreateMultipleHuntsForUser( "user-a", 5 ) + self._CreateMultipleHuntsForUser("user-b", 5) + hunts = [mig_hunt_objects.ToRDFHunt(h) for h in hunts] all_hunts = [rdf_hunt_objects.HuntMetadata.FromHunt(h) for h in hunts] + all_hunts = [mig_hunt_objects.ToProtoHuntMetadata(h) for h in all_hunts] got = self.db.ListHuntObjects(0, db.MAX_COUNT, created_by=frozenset([])) self.assertListEqual(got, []) @@ -548,7 +613,9 @@ def testListHuntObjectsWithCreatorAndCreatedByFilterIsAppliedCorrectly( hunts = self._CreateMultipleHuntsForUser( "user-a", 5 ) + self._CreateMultipleHuntsForUser("user-b", 5) + hunts = [mig_hunt_objects.ToRDFHunt(h) for h in hunts] all_hunts = [rdf_hunt_objects.HuntMetadata.FromHunt(h) for h in hunts] + all_hunts = [mig_hunt_objects.ToProtoHuntMetadata(h) for h in all_hunts] got = self.db.ListHuntObjects( 0, @@ -587,7 +654,9 @@ def testListHuntObjectsWithNotCreatedByFilterIsAppliedCorrectly(self): hunts = self._CreateMultipleHuntsForUser( "user-a", 5 ) + self._CreateMultipleHuntsForUser("user-b", 5) + hunts = [mig_hunt_objects.ToRDFHunt(h) for h in hunts] all_hunts = [rdf_hunt_objects.HuntMetadata.FromHunt(h) for h in hunts] + all_hunts = [mig_hunt_objects.ToProtoHuntMetadata(h) for h in all_hunts] got = self.db.ListHuntObjects(0, db.MAX_COUNT, not_created_by=frozenset([])) self.assertListEqual(got, list(reversed(all_hunts))) @@ -614,7 +683,10 @@ def testListHuntObjectsWithCreatorAndNotCreatedByFilterIsAppliedCorrectly( hunts = self._CreateMultipleHuntsForUser( "user-a", 5 ) + self._CreateMultipleHuntsForUser("user-b", 5) + hunts = [mig_hunt_objects.ToRDFHunt(h) for h in hunts] + all_hunts = [rdf_hunt_objects.HuntMetadata.FromHunt(h) for h in hunts] + all_hunts = [mig_hunt_objects.ToProtoHuntMetadata(h) for h in all_hunts] got = self.db.ListHuntObjects( 0, @@ -657,31 +729,44 @@ def testListHuntObjectsWithCreatorAndNotCreatedByFilterIsAppliedCorrectly( self.assertListEqual(got, list(reversed(all_hunts[5:]))) def testListHuntObjectsCreatedAfterFilterIsAppliedCorrectly(self): - all_hunts = [ - rdf_hunt_objects.HuntMetadata.FromHunt(h) - for h in self._CreateMultipleHunts() - ] + hunts = self._CreateMultipleHunts() + hunts = [mig_hunt_objects.ToRDFHunt(h) for h in hunts] + all_hunts = [rdf_hunt_objects.HuntMetadata.FromHunt(h) for h in hunts] + all_hunts = [mig_hunt_objects.ToProtoHuntMetadata(h) for h in all_hunts] got = self.db.ListHuntObjects( 0, db.MAX_COUNT, - created_after=all_hunts[0].create_time - - rdfvalue.Duration.From(1, rdfvalue.SECONDS)) + created_after=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + all_hunts[0].create_time + ) + - rdfvalue.Duration.From(1, rdfvalue.SECONDS), + ) self.assertListEqual(got, list(reversed(all_hunts))) got = self.db.ListHuntObjects( - 0, db.MAX_COUNT, created_after=all_hunts[2].create_time) + 0, + db.MAX_COUNT, + created_after=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + all_hunts[2].create_time + ), + ) self.assertListEqual(got, list(reversed(all_hunts[3:]))) got = self.db.ListHuntObjects( - 0, db.MAX_COUNT, created_after=all_hunts[-1].create_time) + 0, + db.MAX_COUNT, + created_after=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + all_hunts[-1].create_time + ), + ) self.assertEmpty(got) def testListHuntObjectsWithDescriptionMatchFilterIsAppliedCorrectly(self): - all_hunts = [ - rdf_hunt_objects.HuntMetadata.FromHunt(h) - for h in self._CreateMultipleHunts() - ] + hunts = self._CreateMultipleHunts() + hunts = [mig_hunt_objects.ToRDFHunt(h) for h in 
hunts] + all_hunts = [rdf_hunt_objects.HuntMetadata.FromHunt(h) for h in hunts] + all_hunts = [mig_hunt_objects.ToProtoHuntMetadata(h) for h in all_hunts] got = self.db.ListHuntObjects( 0, db.MAX_COUNT, with_description_match="foo_" @@ -700,27 +785,31 @@ def testListHuntObjectsWithDescriptionMatchFilterIsAppliedCorrectly(self): def testListHuntObjectsWithStatesFilterIsAppliedCorrectly(self): creator = "testuser" - self.db.WriteGRRUser(creator) - paused_hunt_metadata = rdf_hunt_objects.HuntMetadata.FromHunt( - self._CreateHuntWithState( - creator, rdf_hunt_objects.Hunt.HuntState.PAUSED - ) + db_test_utils.InitializeUser(self.db, creator) + paused_hunt = self._CreateHuntWithState( + creator, hunts_pb2.Hunt.HuntState.PAUSED + ) + paused_hunt = mig_hunt_objects.ToRDFHunt(paused_hunt) + paused_hunt_metadata = rdf_hunt_objects.HuntMetadata.FromHunt(paused_hunt) + paused_hunt_metadata = mig_hunt_objects.ToProtoHuntMetadata( + paused_hunt_metadata ) self._CreateHuntWithState(creator, rdf_hunt_objects.Hunt.HuntState.STARTED) - stopped_hunt_metadata = rdf_hunt_objects.HuntMetadata.FromHunt( - self._CreateHuntWithState( - creator, rdf_hunt_objects.Hunt.HuntState.STOPPED - ) + stopped_hunt = self._CreateHuntWithState( + creator, hunts_pb2.Hunt.HuntState.STOPPED ) - self._CreateHuntWithState( - creator, rdf_hunt_objects.Hunt.HuntState.COMPLETED + stopped_hunt = mig_hunt_objects.ToRDFHunt(stopped_hunt) + stopped_hunt_metadata = rdf_hunt_objects.HuntMetadata.FromHunt(stopped_hunt) + stopped_hunt_metadata = mig_hunt_objects.ToProtoHuntMetadata( + stopped_hunt_metadata ) + self._CreateHuntWithState(creator, hunts_pb2.Hunt.HuntState.COMPLETED) got = self.db.ListHuntObjects( 0, db.MAX_COUNT, with_states=[ - rdf_hunt_objects.Hunt.HuntState.PAUSED, + hunts_pb2.Hunt.HuntState.PAUSED, ], ) self.assertCountEqual(got, [paused_hunt_metadata]) @@ -729,8 +818,8 @@ def testListHuntObjectsWithStatesFilterIsAppliedCorrectly(self): 0, db.MAX_COUNT, with_states=[ - rdf_hunt_objects.Hunt.HuntState.PAUSED, - rdf_hunt_objects.Hunt.HuntState.STOPPED, + hunts_pb2.Hunt.HuntState.PAUSED, + hunts_pb2.Hunt.HuntState.STOPPED, ], ) self.assertCountEqual(got, [paused_hunt_metadata, stopped_hunt_metadata]) @@ -744,11 +833,14 @@ def testListHuntObjectsWithStatesFilterIsAppliedCorrectly(self): def testListHuntObjectsCombinationsOfFiltersAreAppliedCorrectly(self): expected = self._CreateMultipleHunts() + self.DoFilterCombinationsAndOffsetCountTest( self.db.ListHuntObjects, conditions=dict( with_creator="user-a", - created_after=expected[2].create_time, + created_after=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + expected[2].create_time + ), with_description_match="foo_4", created_by=frozenset(["user-a"]), not_created_by=frozenset(["user-b"]), @@ -758,47 +850,79 @@ def testListHuntObjectsCombinationsOfFiltersAreAppliedCorrectly(self): ) def testWritingAndReadingHuntOutputPluginsStatesWorks(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - plugin_descriptor = rdf_output_plugin.OutputPluginDescriptor( - plugin_name=email_plugin.EmailOutputPlugin.__name__, - args=email_plugin.EmailOutputPluginArgs(emails_limit=42)) - state_1 = rdf_flow_runner.OutputPluginState( - plugin_descriptor=plugin_descriptor, plugin_state={}) + email_args = output_plugin_pb2.EmailOutputPluginArgs(emails_limit=42) + email_args_any = any_pb2.Any() + 
email_args_any.Pack(email_args) + plugin_descriptor = output_plugin_pb2.OutputPluginDescriptor( + plugin_name=email_plugin.EmailOutputPlugin.__name__, args=email_args_any + ) + plugin_state = jobs_pb2.AttributedDict( + dat=[ + jobs_pb2.KeyValue( + k=jobs_pb2.DataBlob(string="a_foo1"), + v=jobs_pb2.DataBlob(string="a_bar1"), + ), + jobs_pb2.KeyValue( + k=jobs_pb2.DataBlob(string="a_foo2"), + v=jobs_pb2.DataBlob(string="a_bar2"), + ), + ] + ) + state_1 = output_plugin_pb2.OutputPluginState( + plugin_descriptor=plugin_descriptor, + plugin_state=plugin_state, + ) - plugin_descriptor = rdf_output_plugin.OutputPluginDescriptor( + email_args_2 = output_plugin_pb2.EmailOutputPluginArgs(emails_limit=43) + email_args_any_2 = any_pb2.Any() + email_args_any_2.Pack(email_args_2) + plugin_descriptor_2 = output_plugin_pb2.OutputPluginDescriptor( plugin_name=email_plugin.EmailOutputPlugin.__name__, - args=email_plugin.EmailOutputPluginArgs(emails_limit=43)) - state_2 = rdf_flow_runner.OutputPluginState( - plugin_descriptor=plugin_descriptor, plugin_state={}) + args=email_args_any_2, + ) + plugin_state_2 = jobs_pb2.AttributedDict( + dat=[ + jobs_pb2.KeyValue( + k=jobs_pb2.DataBlob(string="b_foo1"), + v=jobs_pb2.DataBlob(string="b_bar1"), + ), + jobs_pb2.KeyValue( + k=jobs_pb2.DataBlob(string="b_foo2"), + v=jobs_pb2.DataBlob(string="b_bar2"), + ), + ] + ) + state_2 = output_plugin_pb2.OutputPluginState( + plugin_descriptor=plugin_descriptor_2, + plugin_state=plugin_state_2, + ) written_states = [state_1, state_2] - self.db.WriteHuntOutputPluginsStates(hunt_obj.hunt_id, written_states) + self.db.WriteHuntOutputPluginsStates(hunt_id, written_states) - read_states = self.db.ReadHuntOutputPluginsStates(hunt_obj.hunt_id) + read_states = self.db.ReadHuntOutputPluginsStates(hunt_id) self.assertEqual(read_states, written_states) def testReadingHuntOutputPluginsReturnsThemInOrderOfWriting(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) states = [] for i in range(100): states.append( - rdf_flow_runner.OutputPluginState( - plugin_descriptor=rdf_output_plugin.OutputPluginDescriptor( - plugin_name="DummyHuntOutputPlugin_%d" % i), - plugin_state={})) + output_plugin_pb2.OutputPluginState( + plugin_descriptor=output_plugin_pb2.OutputPluginDescriptor( + plugin_name="DummyHuntOutputPlugin_%d" % i + ), + plugin_state=jobs_pb2.AttributedDict(), + ) + ) random.shuffle(states) - self.db.WriteHuntOutputPluginsStates(hunt_obj.hunt_id, states) + self.db.WriteHuntOutputPluginsStates(hunt_id, states) - read_states = self.db.ReadHuntOutputPluginsStates(hunt_obj.hunt_id) + read_states = self.db.ReadHuntOutputPluginsStates(hunt_id) self.assertEqual(read_states, states) def testWritingHuntOutputStatesForZeroPlugins(self): @@ -807,21 +931,21 @@ def testWritingHuntOutputStatesForZeroPlugins(self): self.db.WriteHuntOutputPluginsStates(rdf_hunt_objects.RandomHuntId(), []) def testWritingHuntOutputStatesForUnknownHuntRaises(self): - state = rdf_flow_runner.OutputPluginState( - plugin_descriptor=rdf_output_plugin.OutputPluginDescriptor( - plugin_name="DummyHuntOutputPlugin1"), - plugin_state={}) + state = output_plugin_pb2.OutputPluginState( + plugin_descriptor=output_plugin_pb2.OutputPluginDescriptor( + plugin_name="DummyHuntOutputPlugin1" + ), + plugin_state=jobs_pb2.AttributedDict(), + ) with self.assertRaises(db.UnknownHuntError): - 
self.db.WriteHuntOutputPluginsStates(rdf_hunt_objects.RandomHuntId(), - [state]) + self.db.WriteHuntOutputPluginsStates( + rdf_hunt_objects.RandomHuntId(), [state] + ) def testReadingHuntOutputPluginsWithoutStates(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - res = self.db.ReadHuntOutputPluginsStates(hunt_obj.hunt_id) + hunt_id = db_test_utils.InitializeHunt(self.db) + res = self.db.ReadHuntOutputPluginsStates(hunt_id) self.assertEqual(res, []) def testReadingHuntOutputStatesForUnknownHuntRaises(self): @@ -830,76 +954,109 @@ def testReadingHuntOutputStatesForUnknownHuntRaises(self): def testUpdatingHuntOutputStateForUnknownHuntRaises(self): with self.assertRaises(db.UnknownHuntError): - self.db.UpdateHuntOutputPluginState(rdf_hunt_objects.RandomHuntId(), - 0, lambda x: x) + self.db.UpdateHuntOutputPluginState( + rdf_hunt_objects.RandomHuntId(), 0, lambda x: x + ) def testUpdatingHuntOutputStateWorksCorrectly(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - state_1 = rdf_flow_runner.OutputPluginState( - plugin_descriptor=rdf_output_plugin.OutputPluginDescriptor( - plugin_name="DummyHuntOutputPlugin1"), - plugin_state={}) + state_1 = output_plugin_pb2.OutputPluginState( + plugin_descriptor=output_plugin_pb2.OutputPluginDescriptor( + plugin_name="DummyHuntOutputPlugin1" + ), + plugin_state=jobs_pb2.AttributedDict(), + ) - state_2 = rdf_flow_runner.OutputPluginState( - plugin_descriptor=rdf_output_plugin.OutputPluginDescriptor( - plugin_name="DummyHuntOutputPlugin2"), - plugin_state={}) + state_2 = output_plugin_pb2.OutputPluginState( + plugin_descriptor=output_plugin_pb2.OutputPluginDescriptor( + plugin_name="DummyHuntOutputPlugin2" + ), + plugin_state=jobs_pb2.AttributedDict(), + ) - self.db.WriteHuntOutputPluginsStates(hunt_obj.hunt_id, [state_1, state_2]) + self.db.WriteHuntOutputPluginsStates(hunt_id, [state_1, state_2]) - def Update(s): - s["foo"] = "bar" + def Update(s: jobs_pb2.AttributedDict) -> jobs_pb2.AttributedDict: + el = s.dat.add() + el.k.CopyFrom(jobs_pb2.DataBlob(string="foo")) + el.v.CopyFrom(jobs_pb2.DataBlob(string="bar")) return s - self.db.UpdateHuntOutputPluginState(hunt_obj.hunt_id, 0, Update) + self.db.UpdateHuntOutputPluginState(hunt_id, 0, Update) - states = self.db.ReadHuntOutputPluginsStates(hunt_obj.hunt_id) - self.assertEqual(states[0].plugin_state, {"foo": "bar"}) - self.assertEqual(states[1].plugin_state, {}) + states = self.db.ReadHuntOutputPluginsStates(hunt_id) + self.assertEqual( + states[0].plugin_state, + jobs_pb2.AttributedDict( + dat=[ + jobs_pb2.KeyValue( + k=jobs_pb2.DataBlob(string="foo"), + v=jobs_pb2.DataBlob(string="bar"), + ), + ] + ), + ) + self.assertEmpty(states[1].plugin_state.dat) - self.db.UpdateHuntOutputPluginState(hunt_obj.hunt_id, 1, Update) + self.db.UpdateHuntOutputPluginState(hunt_id, 1, Update) - states = self.db.ReadHuntOutputPluginsStates(hunt_obj.hunt_id) - self.assertEqual(states[0].plugin_state, {"foo": "bar"}) - self.assertEqual(states[1].plugin_state, {"foo": "bar"}) + states = self.db.ReadHuntOutputPluginsStates(hunt_id) + self.assertEqual( + states[0].plugin_state, + jobs_pb2.AttributedDict( + dat=[ + jobs_pb2.KeyValue( + k=jobs_pb2.DataBlob(string="foo"), + 
v=jobs_pb2.DataBlob(string="bar"), + ), + ] + ), + ) + + self.assertEqual( + states[1].plugin_state, + jobs_pb2.AttributedDict( + dat=[ + jobs_pb2.KeyValue( + k=jobs_pb2.DataBlob(string="foo"), + v=jobs_pb2.DataBlob(string="bar"), + ), + ] + ), + ) def testReadHuntLogEntriesReturnsEntryFromSingleHuntFlow(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) client_id, flow_id = self._SetupHuntClientAndFlow( - client_id="C.12345678901234aa", hunt_id=hunt_obj.hunt_id) + client_id="C.12345678901234aa", hunt_id=hunt_id + ) self.db.WriteFlowLogEntry( flows_pb2.FlowLogEntry( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, + hunt_id=hunt_id, message="blah", ) ) - hunt_log_entries = self.db.ReadHuntLogEntries(hunt_obj.hunt_id, 0, 10) + hunt_log_entries = self.db.ReadHuntLogEntries(hunt_id, 0, 10) self.assertLen(hunt_log_entries, 1) self.assertIsInstance(hunt_log_entries[0], flows_pb2.FlowLogEntry) - self.assertEqual(hunt_log_entries[0].hunt_id, hunt_obj.hunt_id) + self.assertEqual(hunt_log_entries[0].hunt_id, hunt_id) self.assertEqual(hunt_log_entries[0].client_id, client_id) self.assertEqual(hunt_log_entries[0].flow_id, flow_id) self.assertEqual(hunt_log_entries[0].message, "blah") - def _WriteNestedAndNonNestedLogEntries(self, hunt_obj): - client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + def _WriteNestedAndNonNestedLogEntries(self, hunt_id: str): + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) # Top-level hunt-induced flows should have the hunt's ID. self.db.WriteFlowLogEntry( flows_pb2.FlowLogEntry( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, + hunt_id=hunt_id, message="blah_a", ) ) @@ -907,7 +1064,7 @@ def _WriteNestedAndNonNestedLogEntries(self, hunt_obj): flows_pb2.FlowLogEntry( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, + hunt_id=hunt_id, message="blah_b", ) ) @@ -916,13 +1073,14 @@ def _WriteNestedAndNonNestedLogEntries(self, hunt_obj): _, nested_flow_id = self._SetupHuntClientAndFlow( client_id=client_id, parent_flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, - flow_id=flow.RandomFlowId()) + hunt_id=hunt_id, + flow_id=flow.RandomFlowId(), + ) self.db.WriteFlowLogEntry( flows_pb2.FlowLogEntry( client_id=client_id, flow_id=nested_flow_id, - hunt_id=hunt_obj.hunt_id, + hunt_id=hunt_id, message="blah_a_%d" % i, ) ) @@ -930,117 +1088,114 @@ def _WriteNestedAndNonNestedLogEntries(self, hunt_obj): flows_pb2.FlowLogEntry( client_id=client_id, flow_id=nested_flow_id, - hunt_id=hunt_obj.hunt_id, + hunt_id=hunt_id, message="blah_b_%d" % i, ) ) def testReadHuntLogEntriesIgnoresNestedFlows(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - self._WriteNestedAndNonNestedLogEntries(hunt_obj) + self._WriteNestedAndNonNestedLogEntries(hunt_id) - hunt_log_entries = self.db.ReadHuntLogEntries(hunt_obj.hunt_id, 0, 10) + hunt_log_entries = self.db.ReadHuntLogEntries(hunt_id, 0, 10) self.assertLen(hunt_log_entries, 2) self.assertEqual(hunt_log_entries[0].message, "blah_a") self.assertEqual(hunt_log_entries[1].message, "blah_b") def testCountHuntLogEntriesIgnoresNestedFlows(self): - self.db.WriteGRRUser("user") - hunt_obj 
= rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - self._WriteNestedAndNonNestedLogEntries(hunt_obj) + self._WriteNestedAndNonNestedLogEntries(hunt_id) - num_hunt_log_entries = self.db.CountHuntLogEntries(hunt_obj.hunt_id) + num_hunt_log_entries = self.db.CountHuntLogEntries(hunt_id) self.assertEqual(num_hunt_log_entries, 2) - def _WriteHuntLogEntries(self, msg="blah"): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + def _WriteHuntLogEntries(self, msg="blah") -> str: + hunt_id = db_test_utils.InitializeHunt(self.db) for i in range(10): client_id, flow_id = self._SetupHuntClientAndFlow( - client_id="C.12345678901234a%d" % i, hunt_id=hunt_obj.hunt_id) + client_id="C.12345678901234a%d" % i, hunt_id=hunt_id + ) self.db.WriteFlowLogEntry( flows_pb2.FlowLogEntry( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, + hunt_id=hunt_id, message="%s%d" % (msg, i), ) ) - return hunt_obj + return hunt_id def testReadHuntLogEntriesReturnsEntryFromMultipleHuntFlows(self): - hunt_obj = self._WriteHuntLogEntries() + hunt_id = self._WriteHuntLogEntries() - hunt_log_entries = self.db.ReadHuntLogEntries(hunt_obj.hunt_id, 0, 100) + hunt_log_entries = self.db.ReadHuntLogEntries(hunt_id, 0, 100) self.assertLen(hunt_log_entries, 10) # Make sure messages are returned in timestamps-ascending order. for i, e in enumerate(hunt_log_entries): self.assertEqual(e.message, "blah%d" % i) def testReadHuntLogEntriesCorrectlyAppliesOffsetAndCountFilters(self): - hunt_obj = self._WriteHuntLogEntries() + hunt_id = self._WriteHuntLogEntries() for i in range(10): - hunt_log_entries = self.db.ReadHuntLogEntries(hunt_obj.hunt_id, i, 1) + hunt_log_entries = self.db.ReadHuntLogEntries(hunt_id, i, 1) self.assertLen(hunt_log_entries, 1) self.assertEqual(hunt_log_entries[0].message, "blah%d" % i) def testReadHuntLogEntriesCorrectlyAppliesWithSubstringFilter(self): - hunt_obj = self._WriteHuntLogEntries() + hunt_id = self._WriteHuntLogEntries() hunt_log_entries = self.db.ReadHuntLogEntries( - hunt_obj.hunt_id, 0, 100, with_substring="foo") + hunt_id, 0, 100, with_substring="foo" + ) self.assertEmpty(hunt_log_entries) hunt_log_entries = self.db.ReadHuntLogEntries( - hunt_obj.hunt_id, 0, 100, with_substring="blah") + hunt_id, 0, 100, with_substring="blah" + ) self.assertLen(hunt_log_entries, 10) # Make sure messages are returned in timestamps-ascending order. 
for i, e in enumerate(hunt_log_entries): self.assertEqual(e.message, "blah%d" % i) hunt_log_entries = self.db.ReadHuntLogEntries( - hunt_obj.hunt_id, 0, 100, with_substring="blah1") + hunt_id, 0, 100, with_substring="blah1" + ) self.assertLen(hunt_log_entries, 1) self.assertEqual(hunt_log_entries[0].message, "blah1") def testReadHuntLogEntriesSubstringFilterIsCorrectlyEscaped(self): - hunt_obj = self._WriteHuntLogEntries("ABC%1") + hunt_id = self._WriteHuntLogEntries("ABC%1") hunt_log_entries = self.db.ReadHuntLogEntries( - hunt_obj.hunt_id, 0, 100, with_substring="BC%1") + hunt_id, 0, 100, with_substring="BC%1" + ) self.assertLen(hunt_log_entries, 10) hunt_log_entries = self.db.ReadHuntLogEntries( - hunt_obj.hunt_id, 0, 100, with_substring="B%1") - self.assertLen(hunt_log_entries, 0) + hunt_id, 0, 100, with_substring="B%1" + ) + self.assertEmpty(hunt_log_entries) def testReadHuntLogEntriesCorrectlyAppliesCombinationOfFilters(self): - hunt_obj = self._WriteHuntLogEntries() + hunt_id = self._WriteHuntLogEntries() hunt_log_entries = self.db.ReadHuntLogEntries( - hunt_obj.hunt_id, 0, 1, with_substring="blah") + hunt_id, 0, 1, with_substring="blah" + ) self.assertLen(hunt_log_entries, 1) self.assertEqual(hunt_log_entries[0].message, "blah0") def testCountHuntLogEntriesReturnsCorrectHuntLogEntriesCount(self): - hunt_obj = self._WriteHuntLogEntries() + hunt_id = self._WriteHuntLogEntries() - num_entries = self.db.CountHuntLogEntries(hunt_obj.hunt_id) + num_entries = self.db.CountHuntLogEntries(hunt_id) self.assertEqual(num_entries, 10) - def _WriteHuntResults(self, sample_results=None): - for r in sample_results: - self.db.WriteFlowResults([mig_flow_objects.ToProtoFlowResult(r)]) + def _WriteHuntResults(self, sample_results: List[flows_pb2.FlowResult]): + self.db.WriteFlowResults(sample_results) # Update num_replies_sent for all flows referenced in sample_results: # in case the DB implementation relies on this data when @@ -1054,355 +1209,374 @@ def _WriteHuntResults(self, sample_results=None): f_obj.num_replies_sent += delta self.db.UpdateFlow(client_id, flow_id, flow_obj=f_obj) - def _SampleSingleTypeHuntResults(self, - client_id=None, - flow_id=None, - hunt_id=None, - serial_number=None, - count=10): + def _SampleSingleTypeHuntResults( + self, + client_id=None, + flow_id=None, + hunt_id=None, + serial_number=None, + count=10, + ) -> List[flows_pb2.FlowResult]: self.assertIsNotNone(client_id) self.assertIsNotNone(flow_id) self.assertIsNotNone(hunt_id) res = [] for i in range(count): + payload_any = any_pb2.Any() + payload_any.Pack( + jobs_pb2.ClientSummary( + client_id=client_id, + system_manufacturer="manufacturer_%d" % i, + serial_number=serial_number, + install_date=int( + rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10 + i) + ), + ), + ) res.append( - rdf_flow_objects.FlowResult( + flows_pb2.FlowResult( client_id=client_id, flow_id=flow_id, hunt_id=hunt_id, tag="tag_%d" % i, - payload=rdf_client.ClientSummary( - client_id=client_id, - system_manufacturer="manufacturer_%d" % i, - serial_number=serial_number, - install_date=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10 + - i)))) + payload=payload_any, + ), + ) return res - def _SampleTwoTypeHuntResults(self, - client_id=None, - flow_id=None, - hunt_id=None, - serial_number=None, - count_per_type=5, - timestamp_start=10): + def _SampleTwoTypeHuntResults( + self, + client_id=None, + flow_id=None, + hunt_id=None, + serial_number=None, + count_per_type=5, + timestamp_start=10, + ) -> List[flows_pb2.FlowResult]: self.assertIsNotNone(client_id) 
 self.assertIsNotNone(flow_id)
     self.assertIsNotNone(hunt_id)
     res = []
     for i in range(count_per_type):
+      payload_any = any_pb2.Any()
+      payload_any.Pack(
+          jobs_pb2.ClientSummary(
+              client_id=client_id,
+              system_manufacturer="manufacturer_%d" % i,
+              serial_number=serial_number,
+              install_date=int(
+                  rdfvalue.RDFDatetime.FromSecondsSinceEpoch(
+                      timestamp_start + i
+                  )
+              ),
+          ),
+      )
       res.append(
-          rdf_flow_objects.FlowResult(
+          flows_pb2.FlowResult(
               client_id=client_id,
               flow_id=flow_id,
               hunt_id=hunt_id,
               tag="tag_%d" % i,
-              payload=rdf_client.ClientSummary(
-                  client_id=client_id,
-                  system_manufacturer="manufacturer_%d" % i,
-                  serial_number=serial_number,
-                  install_date=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(
-                      timestamp_start + i))))
+              payload=payload_any,
+          )
+      )
     for i in range(count_per_type):
+      payload_any = any_pb2.Any()
+      payload_any.Pack(
+          jobs_pb2.ClientCrash(
+              client_id=client_id,
+              timestamp=int(
+                  rdfvalue.RDFDatetime.FromSecondsSinceEpoch(
+                      timestamp_start + i
+                  )
+              ),
+          ),
+      )
       res.append(
-          rdf_flow_objects.FlowResult(
+          flows_pb2.FlowResult(
               client_id=client_id,
               flow_id=flow_id,
               hunt_id=hunt_id,
               tag="tag_%d" % i,
-              payload=rdf_client.ClientCrash(
-                  client_id=client_id,
-                  timestamp=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(
-                      timestamp_start + i))))
+              payload=payload_any,
+          )
+      )
     return res
   def testReadHuntResultsReadsSingleResultOfSingleType(self):
-    self.db.WriteGRRUser("user")
-    hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user")
-    hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj)
-    self.db.WriteHuntObject(hunt_obj)
+    hunt_id = db_test_utils.InitializeHunt(self.db)
-    client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id)
+    client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id)
     sample_results = self._SampleSingleTypeHuntResults(
-        client_id=client_id, flow_id=flow_id, hunt_id=hunt_obj.hunt_id, count=1)
+        client_id=client_id, flow_id=flow_id, hunt_id=hunt_id, count=1
+    )
     self._WriteHuntResults(sample_results)
-    results = self.db.ReadHuntResults(hunt_obj.hunt_id, 0, 10)
+    results = self.db.ReadHuntResults(hunt_id, 0, 10)
     self.assertLen(results, 1)
-    self.assertEqual(results[0].hunt_id, hunt_obj.hunt_id)
+    self.assertEqual(results[0].hunt_id, hunt_id)
     self.assertEqual(results[0].payload, sample_results[0].payload)
   def testReadHuntResultsReadsMultipleResultOfSingleType(self):
-    self.db.WriteGRRUser("user")
-    hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user")
-    hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj)
-    self.db.WriteHuntObject(hunt_obj)
+    hunt_id = db_test_utils.InitializeHunt(self.db)
-    client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id)
+    client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id)
     sample_results = self._SampleSingleTypeHuntResults(
-        client_id=client_id,
-        flow_id=flow_id,
-        hunt_id=hunt_obj.hunt_id,
-        count=10)
+        client_id=client_id, flow_id=flow_id, hunt_id=hunt_id, count=10
+    )
     self._WriteHuntResults(sample_results)
-    results = self.db.ReadHuntResults(hunt_obj.hunt_id, 0, 1000)
+    results = self.db.ReadHuntResults(hunt_id, 0, 1000)
     self.assertLen(results, 10)
     for i in range(10):
-      self.assertEqual(results[i].hunt_id, hunt_obj.hunt_id)
+      self.assertEqual(results[i].hunt_id, hunt_id)
       self.assertEqual(results[i].payload, sample_results[i].payload)
   def testReadHuntResultsReadsMultipleResultOfMultipleTypes(self):
-    self.db.WriteGRRUser("user")
-    hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user")
-    hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj)
-
self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id_1, flow_id_1 = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + client_id_1, flow_id_1 = self._SetupHuntClientAndFlow(hunt_id=hunt_id) sample_results_1 = self._SampleTwoTypeHuntResults( - client_id=client_id_1, flow_id=flow_id_1, hunt_id=hunt_obj.hunt_id) + client_id=client_id_1, flow_id=flow_id_1, hunt_id=hunt_id + ) self._WriteHuntResults(sample_results_1) - client_id_2, flow_id_2 = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + client_id_2, flow_id_2 = self._SetupHuntClientAndFlow(hunt_id=hunt_id) sample_results_2 = self._SampleTwoTypeHuntResults( - client_id=client_id_2, flow_id=flow_id_2, hunt_id=hunt_obj.hunt_id) + client_id=client_id_2, flow_id=flow_id_2, hunt_id=hunt_id + ) self._WriteHuntResults(sample_results_2) sample_results = sample_results_1 + sample_results_2 - results = self.db.ReadHuntResults(hunt_obj.hunt_id, 0, 1000) + results = self.db.ReadHuntResults(hunt_id, 0, 1000) self.assertLen(results, len(sample_results)) - self.assertListEqual([i.payload for i in results], - [i.payload for i in sample_results]) + self.assertListEqual( + [i.payload for i in results], [i.payload for i in sample_results] + ) def testReadHuntResultsCorrectlyAppliedOffsetAndCountFilters(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) sample_results = [] - for i in range(10): - client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + for _ in range(10): + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) results = self._SampleSingleTypeHuntResults( - client_id=client_id, - flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, - count=1) + client_id=client_id, flow_id=flow_id, hunt_id=hunt_id, count=1 + ) sample_results.extend(results) self._WriteHuntResults(results) for l in range(1, 11): for i in range(10): - results = self.db.ReadHuntResults(hunt_obj.hunt_id, i, l) - expected = sample_results[i:i + l] + results = self.db.ReadHuntResults(hunt_id, i, l) + expected = sample_results[i : i + l] result_payloads = [x.payload for x in results] expected_payloads = [x.payload for x in expected] self.assertEqual( - result_payloads, expected_payloads, - "Results differ from expected (from %d, size %d): %s vs %s" % - (i, l, result_payloads, expected_payloads)) + result_payloads, + expected_payloads, + "Results differ from expected (from %d, size %d): %s vs %s" + % (i, l, result_payloads, expected_payloads), + ) def testReadHuntResultsCorrectlyAppliesWithTagFilter(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) sample_results = self._SampleSingleTypeHuntResults( - client_id=client_id, flow_id=flow_id, hunt_id=hunt_obj.hunt_id) + client_id=client_id, flow_id=flow_id, hunt_id=hunt_id + ) self._WriteHuntResults(sample_results) - results = self.db.ReadHuntResults(hunt_obj.hunt_id, 0, 100, with_tag="blah") + results = self.db.ReadHuntResults(hunt_id, 0, 100, with_tag="blah") self.assertFalse(results) - results = 
self.db.ReadHuntResults(hunt_obj.hunt_id, 0, 100, with_tag="tag") + results = self.db.ReadHuntResults(hunt_id, 0, 100, with_tag="tag") self.assertFalse(results) - results = self.db.ReadHuntResults( - hunt_obj.hunt_id, 0, 100, with_tag="tag_1") - self.assertEqual([i.payload for i in results], - [i.payload for i in sample_results if i.tag == "tag_1"]) + results = self.db.ReadHuntResults(hunt_id, 0, 100, with_tag="tag_1") + self.assertEqual( + [i.payload for i in results], + [i.payload for i in sample_results if i.tag == "tag_1"], + ) def testReadHuntResultsCorrectlyAppliesWithTypeFilter(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) sample_results = [] - for i in range(10): - client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + for _ in range(10): + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) results = self._SampleTwoTypeHuntResults( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, - count_per_type=1) + hunt_id=hunt_id, + count_per_type=1, + ) sample_results.extend(results) self._WriteHuntResults(results) results = self.db.ReadHuntResults( - hunt_obj.hunt_id, - 0, - 100, - with_type=rdf_client.ClientInformation.__name__) + hunt_id, 0, 100, with_type=jobs_pb2.ClientInformation.__name__ + ) self.assertFalse(results) results = self.db.ReadHuntResults( - hunt_obj.hunt_id, 0, 100, with_type=rdf_client.ClientSummary.__name__) + hunt_id, 0, 100, with_type=jobs_pb2.ClientSummary.__name__ + ) self.assertCountEqual( [i.payload for i in results], [ i.payload for i in sample_results - if isinstance(i.payload, rdf_client.ClientSummary) + if i.payload.Is(jobs_pb2.ClientSummary.DESCRIPTOR) ], ) def testReadHuntResultsCorrectlyAppliesWithSubstringFilter(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) sample_results = self._SampleSingleTypeHuntResults( - client_id=client_id, flow_id=flow_id, hunt_id=hunt_obj.hunt_id) + client_id=client_id, flow_id=flow_id, hunt_id=hunt_id + ) self._WriteHuntResults(sample_results) - results = self.db.ReadHuntResults( - hunt_obj.hunt_id, 0, 100, with_substring="blah") + results = self.db.ReadHuntResults(hunt_id, 0, 100, with_substring="blah") self.assertEmpty(results) results = self.db.ReadHuntResults( - hunt_obj.hunt_id, 0, 100, with_substring="manufacturer") + hunt_id, 0, 100, with_substring="manufacturer" + ) self.assertEqual( [i.payload for i in results], [i.payload for i in sample_results], ) results = self.db.ReadHuntResults( - hunt_obj.hunt_id, 0, 100, with_substring="manufacturer_1") + hunt_id, 0, 100, with_substring="manufacturer_1" + ) self.assertEqual([i.payload for i in results], [sample_results[1].payload]) def testReadHuntResultsSubstringFilterIsCorrectlyEscaped(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id, flow_id = 
self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) sample_results = self._SampleSingleTypeHuntResults( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, - serial_number="ABC%123") + hunt_id=hunt_id, + serial_number="ABC%123", + ) self._WriteHuntResults(sample_results) - results = self.db.ReadHuntResults( - hunt_obj.hunt_id, 0, 100, with_substring="ABC%123") + results = self.db.ReadHuntResults(hunt_id, 0, 100, with_substring="ABC%123") self.assertLen(results, 10) - results = self.db.ReadHuntResults( - hunt_obj.hunt_id, 0, 100, with_substring="AB%23") - self.assertLen(results, 0) + results = self.db.ReadHuntResults(hunt_id, 0, 100, with_substring="AB%23") + self.assertEmpty(results) def testReadHuntResultsCorrectlyAppliesVariousCombinationsOfFilters(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) sample_results = [] for _ in range(10): - client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) results = self._SampleTwoTypeHuntResults( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, - count_per_type=5) + hunt_id=hunt_id, + count_per_type=5, + ) sample_results.extend(results) self._WriteHuntResults(results) + # TODO: Clean up this test. tags = { None: list(sample_results), - "tag_1": [s for s in sample_results if s.tag == "tag_1"] + "tag_1": [s for s in sample_results if s.tag == "tag_1"], } substrings = { - None: - list(sample_results), - "manufacturer": [ - s for s in sample_results - if "manufacturer" in getattr(s.payload, "system_manufacturer", "") - ], - "manufacturer_1": [ - s for s in sample_results - if "manufacturer_1" in getattr(s.payload, "system_manufacturer", "") - ] + None: list(sample_results), } + manufacturer = [] + manufacturer_1 = [] + for s in sample_results: + if s.payload.Is(jobs_pb2.ClientSummary.DESCRIPTOR): + payload = jobs_pb2.ClientSummary() + elif s.payload.Is(jobs_pb2.ClientCrash.DESCRIPTOR): + payload = jobs_pb2.ClientCrash() + else: + continue + s.payload.Unpack(payload) + if "manufacturer" in getattr(payload, "system_manufacturer", ""): + manufacturer.append(s) + if "manufacturer_1" in getattr(payload, "system_manufacturer", ""): + manufacturer_1.append(s) + + substrings["manufacturer"] = manufacturer + substrings["manufacturer_1"] = manufacturer_1 + types = { - None: - list(sample_results), - rdf_client.ClientSummary.__name__: [ - s for s in sample_results - if isinstance(s.payload, rdf_client.ClientSummary) - ] + None: list(sample_results), + jobs_pb2.ClientSummary.__name__: [ + s + for s in sample_results + if s.payload.Is(jobs_pb2.ClientSummary.DESCRIPTOR) + ], } for tag_value, tag_expected in tags.items(): for substring_value, substring_expected in substrings.items(): for type_value, type_expected in types.items(): expected = [ - e for e in tag_expected + e + for e in tag_expected if e in substring_expected and e in type_expected ] results = self.db.ReadHuntResults( - hunt_obj.hunt_id, + hunt_id, 0, 100, with_tag=tag_value, with_type=type_value, - with_substring=substring_value) + with_substring=substring_value, + ) self.assertCountEqual( - [i.payload for i in expected], [i.payload for i in results], + [i.payload for i in expected], + [i.payload for i in results], 
"Result items do not match for " - "(tag=%s, type=%s, substring=%s): %s vs %s" % - (tag_value, type_value, substring_value, expected, results)) + "(tag=%s, type=%s, substring=%s): %s vs %s" + % (tag_value, type_value, substring_value, expected, results), + ) def testReadHuntResultsReturnsPayloadWithMissingTypeAsSpecialValue(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) sample_results = self._SampleSingleTypeHuntResults( - client_id=client_id, flow_id=flow_id, hunt_id=hunt_obj.hunt_id) + client_id=client_id, flow_id=flow_id, hunt_id=hunt_id + ) self._WriteHuntResults(sample_results) type_name = rdf_client.ClientSummary.__name__ - try: - cls = rdfvalue.RDFValue.classes.pop(type_name) + cls = rdfvalue.RDFValue.classes.pop(type_name) - results = self.db.ReadHuntResults(hunt_obj.hunt_id, 0, 100) - finally: - rdfvalue.RDFValue.classes[type_name] = cls + results = self.db.ReadHuntResults(hunt_id, 0, 100) + rdfvalue.RDFValue.classes[type_name] = cls self.assertLen(sample_results, len(results)) for r in results: - self.assertIsInstance(r.payload, - rdf_objects.SerializedValueOfUnrecognizedType) - self.assertEqual(r.payload.type_name, type_name) + self.assertTrue( + r.payload.Is(objects_pb2.SerializedValueOfUnrecognizedType.DESCRIPTOR) + ) + payload = objects_pb2.SerializedValueOfUnrecognizedType() + r.payload.Unpack(payload) + self.assertEqual(payload.type_name, type_name) def testReadHuntResultsIgnoresChildFlowsResults(self): client_id = db_test_utils.InitializeClient(self.db) @@ -1437,7 +1611,6 @@ def testReadHuntResultsIgnoresChildFlowsResults(self): self.db.WriteFlowResults([child_flow_result]) results = self.db.ReadHuntResults(hunt_id, offset=0, count=1024) - results = list(map(mig_flow_objects.ToProtoFlowResult, results)) self.assertLen(results, 1) @@ -1446,241 +1619,228 @@ def testReadHuntResultsIgnoresChildFlowsResults(self): self.assertEqual(result.fqdn, "hunt.example.com") def testCountHuntResultsReturnsCorrectResultsCount(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) sample_results = self._SampleSingleTypeHuntResults( - client_id=client_id, flow_id=flow_id, hunt_id=hunt_obj.hunt_id) + client_id=client_id, flow_id=flow_id, hunt_id=hunt_id + ) self._WriteHuntResults(sample_results) - num_results = self.db.CountHuntResults(hunt_obj.hunt_id) + num_results = self.db.CountHuntResults(hunt_id) self.assertLen(sample_results, num_results) def testCountHuntResultsCorrectlyAppliesWithTagFilter(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) sample_results = 
self._SampleSingleTypeHuntResults( - client_id=client_id, flow_id=flow_id, hunt_id=hunt_obj.hunt_id) + client_id=client_id, flow_id=flow_id, hunt_id=hunt_id + ) self._WriteHuntResults(sample_results) - num_results = self.db.CountHuntResults(hunt_obj.hunt_id, with_tag="blah") + num_results = self.db.CountHuntResults(hunt_id, with_tag="blah") self.assertEqual(num_results, 0) - num_results = self.db.CountHuntResults(hunt_obj.hunt_id, with_tag="tag_1") + num_results = self.db.CountHuntResults(hunt_id, with_tag="tag_1") self.assertEqual(num_results, 1) def testCountHuntResultsCorrectlyAppliesWithTypeFilter(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) sample_results = [] for _ in range(10): - client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) results = self._SampleTwoTypeHuntResults( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, - count_per_type=1) + hunt_id=hunt_id, + count_per_type=1, + ) sample_results.extend(results) self._WriteHuntResults(results) num_results = self.db.CountHuntResults( - hunt_obj.hunt_id, with_type=rdf_client.ClientInformation.__name__) + hunt_id, with_type=rdf_client.ClientInformation.__name__ + ) self.assertEqual(num_results, 0) num_results = self.db.CountHuntResults( - hunt_obj.hunt_id, with_type=rdf_client.ClientSummary.__name__) + hunt_id, with_type=rdf_client.ClientSummary.__name__ + ) self.assertEqual(num_results, 10) num_results = self.db.CountHuntResults( - hunt_obj.hunt_id, with_type=rdf_client.ClientCrash.__name__) + hunt_id, with_type=rdf_client.ClientCrash.__name__ + ) self.assertEqual(num_results, 10) def testCountHuntResultsCorrectlyAppliesWithTagAndWithTypeFilters(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) sample_results = [] for _ in range(10): - client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) results = self._SampleTwoTypeHuntResults( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, - count_per_type=5) + hunt_id=hunt_id, + count_per_type=5, + ) sample_results.extend(results) self._WriteHuntResults(results) num_results = self.db.CountHuntResults( - hunt_obj.hunt_id, - with_tag="tag_1", - with_type=rdf_client.ClientSummary.__name__) + hunt_id, with_tag="tag_1", with_type=rdf_client.ClientSummary.__name__ + ) self.assertEqual(num_results, 10) def testCountHuntResultsCorrectlyAppliesWithTimestampFilter(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - sample_results = [] for _ in range(10): - client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) sample_results = self._SampleSingleTypeHuntResults( - client_id=client_id, - flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, - count=10) + client_id=client_id, flow_id=flow_id, hunt_id=hunt_id, count=10 + ) 
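As an illustration of the filter semantics these count tests rely on: `with_tag`, `with_type` and `with_timestamp` combine as a logical AND, which is what the expected counts encode. Below is a minimal in-memory sketch of that behaviour; `FakeResult` and `count_results` are hypothetical names for illustration only, not GRR APIs, and this is not the database implementation.

from typing import NamedTuple, Optional, Sequence


class FakeResult(NamedTuple):
  tag: str
  type_name: str
  timestamp: int  # Microseconds since epoch, as an opaque write time.


def count_results(
    results: Sequence[FakeResult],
    with_tag: Optional[str] = None,
    with_type: Optional[str] = None,
    with_timestamp: Optional[int] = None,
) -> int:
  """Counts results matching every supplied filter (AND semantics)."""
  return sum(
      1
      for r in results
      if (with_tag is None or r.tag == with_tag)
      and (with_type is None or r.type_name == with_type)
      and (with_timestamp is None or r.timestamp == with_timestamp)
  )


sample = [
    FakeResult(tag="tag_1", type_name="ClientSummary", timestamp=1),
    FakeResult(tag="tag_1", type_name="ClientCrash", timestamp=1),
    FakeResult(tag="tag_2", type_name="ClientSummary", timestamp=2),
]
# Only the entry carrying both the requested tag and the requested type counts.
assert count_results(sample, with_tag="tag_1", with_type="ClientSummary") == 1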
self._WriteHuntResults(sample_results[:5]) self._WriteHuntResults(sample_results[5:]) - hunt_results = self.db.ReadHuntResults(hunt_obj.hunt_id, 0, 10) + hunt_results = self.db.ReadHuntResults(hunt_id, 0, 10) for hr in hunt_results: - self.assertEqual([hr], - self.db.ReadHuntResults( - hunt_obj.hunt_id, 0, 10, - with_timestamp=hr.timestamp)) + self.assertEqual( + [hr], + self.db.ReadHuntResults( + hunt_id, + 0, + 10, + with_timestamp=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + hr.timestamp + ), + ), + ) def testCountHuntResultsByTypeGroupsResultsCorrectly(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) results = self._SampleTwoTypeHuntResults( - client_id=client_id, - flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, - count_per_type=5) + client_id=client_id, flow_id=flow_id, hunt_id=hunt_id, count_per_type=5 + ) self._WriteHuntResults(results) - counts = self.db.CountHuntResultsByType(hunt_obj.hunt_id) + counts = self.db.CountHuntResultsByType(hunt_id) for key in counts: self.assertIsInstance(key, str) - self.assertEqual(counts, { - rdf_client.ClientSummary.__name__: 5, - rdf_client.ClientCrash.__name__: 5 - }) + self.assertEqual( + counts, + { + rdf_client.ClientSummary.__name__: 5, + rdf_client.ClientCrash.__name__: 5, + }, + ) def testReadHuntFlowsReturnsEmptyListWhenNoFlows(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - self.assertEmpty(self.db.ReadHuntFlows(hunt_obj.hunt_id, 0, 10)) + self.assertEmpty(self.db.ReadHuntFlows(hunt_id, 0, 10)) def testReadHuntFlowsReturnsAllHuntFlowsWhenNoFilterCondition(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - _, flow_id_1 = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) - _, flow_id_2 = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + _, flow_id_1 = self._SetupHuntClientAndFlow(hunt_id=hunt_id) + _, flow_id_2 = self._SetupHuntClientAndFlow(hunt_id=hunt_id) - flows = self.db.ReadHuntFlows(hunt_obj.hunt_id, 0, 10) + flows = self.db.ReadHuntFlows(hunt_id, 0, 10) self.assertCountEqual([f.flow_id for f in flows], [flow_id_1, flow_id_2]) - def _BuildFilterConditionExpectations(self, hunt_obj): + def _BuildFilterConditionExpectations(self, hunt_id): _, running_flow_id = self._SetupHuntClientAndFlow( - flow_state=rdf_flow_objects.Flow.FlowState.RUNNING, - hunt_id=hunt_obj.hunt_id) + flow_state=rdf_flow_objects.Flow.FlowState.RUNNING, hunt_id=hunt_id + ) _, succeeded_flow_id = self._SetupHuntClientAndFlow( - flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, - hunt_id=hunt_obj.hunt_id) + flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, hunt_id=hunt_id + ) _, failed_flow_id = self._SetupHuntClientAndFlow( - flow_state=rdf_flow_objects.Flow.FlowState.ERROR, - hunt_id=hunt_obj.hunt_id) + flow_state=rdf_flow_objects.Flow.FlowState.ERROR, hunt_id=hunt_id + ) _, crashed_flow_id = 
self._SetupHuntClientAndFlow( - flow_state=rdf_flow_objects.Flow.FlowState.CRASHED, - hunt_id=hunt_obj.hunt_id) + flow_state=rdf_flow_objects.Flow.FlowState.CRASHED, hunt_id=hunt_id + ) client_id, flow_with_results_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + hunt_id=hunt_id + ) sample_results = self._SampleSingleTypeHuntResults( - client_id=client_id, - flow_id=flow_with_results_id, - hunt_id=hunt_obj.hunt_id) + client_id=client_id, flow_id=flow_with_results_id, hunt_id=hunt_id + ) self._WriteHuntResults(sample_results) return { db.HuntFlowsCondition.UNSET: [ - running_flow_id, succeeded_flow_id, failed_flow_id, crashed_flow_id, - flow_with_results_id + running_flow_id, + succeeded_flow_id, + failed_flow_id, + crashed_flow_id, + flow_with_results_id, ], db.HuntFlowsCondition.FAILED_FLOWS_ONLY: [failed_flow_id], db.HuntFlowsCondition.SUCCEEDED_FLOWS_ONLY: [succeeded_flow_id], db.HuntFlowsCondition.COMPLETED_FLOWS_ONLY: [ - failed_flow_id, succeeded_flow_id + failed_flow_id, + succeeded_flow_id, ], db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY: [running_flow_id], db.HuntFlowsCondition.CRASHED_FLOWS_ONLY: [crashed_flow_id], } def testReadHuntFlowsAppliesFilterConditionCorrectly(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - expectations = self._BuildFilterConditionExpectations(hunt_obj) + expectations = self._BuildFilterConditionExpectations(hunt_id) for filter_condition, expected in expectations.items(): results = self.db.ReadHuntFlows( - hunt_obj.hunt_id, 0, 10, filter_condition=filter_condition) + hunt_id, 0, 10, filter_condition=filter_condition + ) results_ids = [r.flow_id for r in results] self.assertCountEqual( - results_ids, expected, "Result items do not match for " - "(filter_condition=%d): %s vs %s" % - (filter_condition, expected, results_ids)) + results_ids, + expected, + "Result items do not match for (filter_condition=%s): %s vs %s" + % (filter_condition, expected, results_ids), + ) def testReadHuntFlowsCorrectlyAppliesOffsetAndCountFilters(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - expectations = self._BuildFilterConditionExpectations(hunt_obj) + expectations = self._BuildFilterConditionExpectations(hunt_id) for filter_condition, _ in expectations.items(): full_results = self.db.ReadHuntFlows( - hunt_obj.hunt_id, 0, 1024, filter_condition=filter_condition) + hunt_id, 0, 1024, filter_condition=filter_condition + ) full_results_ids = [r.flow_id for r in full_results] for index in range(0, 2): for count in range(1, 3): results = self.db.ReadHuntFlows( - hunt_obj.hunt_id, index, count, filter_condition=filter_condition) + hunt_id, index, count, filter_condition=filter_condition + ) results_ids = [r.flow_id for r in results] - expected_ids = full_results_ids[index:index + count] + expected_ids = full_results_ids[index : index + count] self.assertCountEqual( - results_ids, expected_ids, "Result items do not match for " - "(filter_condition=%d, index=%d, count=%d): %s vs %s" % - (filter_condition, index, count, expected_ids, results_ids)) + results_ids, + expected_ids, + "Result items do not match for " + "(filter_condition=%s, index=%d, count=%d): %s vs %s" + % 
(filter_condition, index, count, expected_ids, results_ids), + ) def testReadHuntFlowsIgnoresSubflows(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - hunt_id = hunt_obj.hunt_id + hunt_id = db_test_utils.InitializeHunt(self.db) client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.RUNNING) + hunt_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.RUNNING + ) # Whatever state the subflow is in, it should be ignored. self._SetupHuntClientAndFlow( @@ -1688,37 +1848,37 @@ def testReadHuntFlowsIgnoresSubflows(self): hunt_id=hunt_id, flow_id=flow.RandomFlowId(), parent_flow_id=flow_id, - flow_state=rdf_flow_objects.Flow.FlowState.ERROR) + flow_state=rdf_flow_objects.Flow.FlowState.ERROR, + ) self._SetupHuntClientAndFlow( client_id=client_id, hunt_id=hunt_id, flow_id=flow.RandomFlowId(), parent_flow_id=flow_id, - flow_state=rdf_flow_objects.Flow.FlowState.FINISHED) + flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, + ) self._SetupHuntClientAndFlow( client_id=client_id, hunt_id=hunt_id, flow_id=flow.RandomFlowId(), parent_flow_id=flow_id, - flow_state=rdf_flow_objects.Flow.FlowState.RUNNING) + flow_state=rdf_flow_objects.Flow.FlowState.RUNNING, + ) for state, expected_results in [ (db.HuntFlowsCondition.COMPLETED_FLOWS_ONLY, 0), (db.HuntFlowsCondition.SUCCEEDED_FLOWS_ONLY, 0), - (db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY, 1) + (db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY, 1), ]: results = self.db.ReadHuntFlows(hunt_id, 0, 10, filter_condition=state) self.assertLen(results, expected_results) def testCountHuntFlowsIgnoresSubflows(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - hunt_id = hunt_obj.hunt_id + hunt_id = db_test_utils.InitializeHunt(self.db) client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.RUNNING) + hunt_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.RUNNING + ) # Whatever state the subflow is in, it should be ignored. 
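To spell out the rule stated in the comment above: ReadHuntFlows and CountHuntFlows only consider flows started directly by the hunt, and anything with a parent flow is a subflow that gets skipped regardless of its state. The following minimal sketch shows that selection rule with hypothetical names (`FakeFlow`, `count_hunt_flows`); it is an illustration, not the real schema or implementation.

import dataclasses
from typing import Optional, Sequence


@dataclasses.dataclass(frozen=True)
class FakeFlow:
  flow_id: str
  hunt_id: str
  parent_flow_id: Optional[str] = None
  flow_state: str = "RUNNING"


def count_hunt_flows(flows: Sequence[FakeFlow], hunt_id: str) -> int:
  """Counts only top-level hunt flows; subflows are ignored."""
  return sum(
      1
      for f in flows
      if f.hunt_id == hunt_id and f.parent_flow_id is None
  )


flows = [
    FakeFlow("F1", "H1"),                                           # counted
    FakeFlow("F2", "H1", parent_flow_id="F1"),                      # subflow
    FakeFlow("F3", "H1", parent_flow_id="F1", flow_state="ERROR"),  # subflow
]
assert count_hunt_flows(flows, "H1") == 1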
self._SetupHuntClientAndFlow( @@ -1726,62 +1886,178 @@ def testCountHuntFlowsIgnoresSubflows(self): hunt_id=hunt_id, flow_id=flow.RandomFlowId(), parent_flow_id=flow_id, - flow_state=rdf_flow_objects.Flow.FlowState.ERROR) + flow_state=rdf_flow_objects.Flow.FlowState.ERROR, + ) self._SetupHuntClientAndFlow( client_id=client_id, hunt_id=hunt_id, flow_id=flow.RandomFlowId(), parent_flow_id=flow_id, - flow_state=rdf_flow_objects.Flow.FlowState.FINISHED) + flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, + ) self._SetupHuntClientAndFlow( client_id=client_id, hunt_id=hunt_id, flow_id=flow.RandomFlowId(), parent_flow_id=flow_id, - flow_state=rdf_flow_objects.Flow.FlowState.RUNNING) + flow_state=rdf_flow_objects.Flow.FlowState.RUNNING, + ) self.assertEqual(self.db.CountHuntFlows(hunt_id), 1) def testCountHuntFlowsReturnsEmptyListWhenNoFlows(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - self.assertEqual(self.db.CountHuntFlows(hunt_obj.hunt_id), 0) + self.assertEqual(self.db.CountHuntFlows(hunt_id), 0) def testCountHuntFlowsReturnsAllHuntFlowsWhenNoFilterCondition(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) - self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + self._SetupHuntClientAndFlow(hunt_id=hunt_id) + self._SetupHuntClientAndFlow(hunt_id=hunt_id) - self.assertEqual(self.db.CountHuntFlows(hunt_obj.hunt_id), 2) + self.assertEqual(self.db.CountHuntFlows(hunt_id), 2) def testCountHuntFlowsAppliesFilterConditionCorrectly(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - expectations = self._BuildFilterConditionExpectations(hunt_obj) + expectations = self._BuildFilterConditionExpectations(hunt_id) for filter_condition, expected in expectations.items(): result = self.db.CountHuntFlows( - hunt_obj.hunt_id, filter_condition=filter_condition) + hunt_id, filter_condition=filter_condition + ) self.assertLen( - expected, result, "Result count does not match for " - "(filter_condition=%d): %d vs %d" % - (filter_condition, len(expected), result)) + expected, + result, + "Result count does not match for (filter_condition=%s): %d vs %d" + % (filter_condition, len(expected), result), + ) + + def testReadHuntFlowErrors(self): + hunt_id = db_test_utils.InitializeHunt(self.db) + + client_id_1 = db_test_utils.InitializeClient(self.db) + client_id_2 = db_test_utils.InitializeClient(self.db) + + flow_id_1 = db_test_utils.InitializeFlow( + self.db, + client_id=client_id_1, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + flow_id_2 = db_test_utils.InitializeFlow( + self.db, + client_id=client_id_2, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + + pre_update_1_time = self.db.Now() + + flow_obj_1 = self.db.ReadFlowObject(client_id_1, flow_id_1) + flow_obj_1.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_1.error_message = "ERROR_1" + flow_obj_1.backtrace = "File 'foo.py', line 1, in 'foo'" + self.db.UpdateFlow(client_id_1, flow_id_1, flow_obj_1) + + 
pre_update_2_time = self.db.Now() + + flow_obj_2 = self.db.ReadFlowObject(client_id_2, flow_id_2) + flow_obj_2.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_2.error_message = "ERROR_2" + self.db.UpdateFlow(client_id_2, flow_id_2, flow_obj_2) + + results = self.db.ReadHuntFlowErrors(hunt_id, offset=0, count=1024) + self.assertLen(results, 2) + + self.assertEqual(results[client_id_1].message, "ERROR_1") + self.assertGreater(results[client_id_1].time, pre_update_1_time) + self.assertIsNotNone(results[client_id_1].backtrace) + + self.assertEqual(results[client_id_2].message, "ERROR_2") + self.assertGreater(results[client_id_2].time, pre_update_2_time) + self.assertIsNone(results[client_id_2].backtrace) + + def testReadHuntFlowErrorsIgnoreSubflows(self): + hunt_id = db_test_utils.InitializeHunt(self.db) + + client_id = db_test_utils.InitializeClient(self.db) + + flow_id = db_test_utils.InitializeFlow( + self.db, + client_id=client_id, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + child_flow_id = db_test_utils.InitializeFlow( + self.db, + client_id=client_id, + parent_flow_id=flow_id, + parent_hunt_id=hunt_id, + ) + + flow_obj = self.db.ReadFlowObject(client_id, flow_id) + flow_obj.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj.error_message = "ERROR" + self.db.UpdateFlow(client_id, flow_id, flow_obj) + + child_flow_obj = self.db.ReadFlowObject(client_id, child_flow_id) + child_flow_obj.flow_state = flows_pb2.Flow.FlowState.ERROR + child_flow_obj.error_message = "CHILD_ERROR" + self.db.UpdateFlow(client_id, child_flow_id, child_flow_obj) + + results = self.db.ReadHuntFlowErrors(hunt_id, offset=0, count=1024) + self.assertLen(results, 1) + self.assertEqual(results[client_id].message, "ERROR") + + def testReadHuntFlowErrorsOffsetAndCount(self): + hunt_id = db_test_utils.InitializeHunt(self.db) + + client_id_1 = db_test_utils.InitializeClient(self.db) + client_id_2 = db_test_utils.InitializeClient(self.db) + client_id_3 = db_test_utils.InitializeClient(self.db) + + flow_id_1 = db_test_utils.InitializeFlow( + self.db, + client_id=client_id_1, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + flow_id_2 = db_test_utils.InitializeFlow( + self.db, + client_id=client_id_2, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + flow_id_3 = db_test_utils.InitializeFlow( + self.db, + client_id=client_id_3, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + + flow_obj_1 = self.db.ReadFlowObject(client_id_1, flow_id_1) + flow_obj_1.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_1.error_message = "ERROR_1" + self.db.UpdateFlow(client_id_1, flow_id_1, flow_obj_1) + + flow_obj_2 = self.db.ReadFlowObject(client_id_2, flow_id_2) + flow_obj_2.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_2.error_message = "ERROR_2" + self.db.UpdateFlow(client_id_2, flow_id_2, flow_obj_2) + + flow_obj_3 = self.db.ReadFlowObject(client_id_3, flow_id_3) + flow_obj_3.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_3.error_message = "ERROR_3" + self.db.UpdateFlow(client_id_3, flow_id_3, flow_obj_3) + + results = self.db.ReadHuntFlowErrors(hunt_id, offset=1, count=1) + self.assertLen(results, 1) + self.assertEqual(results[client_id_2].message, "ERROR_2") def testReadHuntCountersForNewHunt(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - hunt_counters = self.db.ReadHuntCounters(hunt_obj.hunt_id) + hunt_id = db_test_utils.InitializeHunt(self.db) + 
hunt_counters = self.db.ReadHuntCounters(hunt_id) self.assertEqual(hunt_counters.num_clients, 0) self.assertEqual(hunt_counters.num_successful_clients, 0) self.assertEqual(hunt_counters.num_failed_clients, 0) @@ -1792,28 +2068,80 @@ def testReadHuntCountersForNewHunt(self): self.assertEqual(hunt_counters.total_cpu_seconds, 0) self.assertEqual(hunt_counters.total_network_bytes_sent, 0) + def testReadHuntsCountersForEmptyList(self): + hunts_counters = self.db.ReadHuntsCounters([]) + self.assertEmpty(hunts_counters) + + def testReadHuntsCountersForSeveralHunts(self): + hunt_id_1 = db_test_utils.InitializeHunt(self.db) + hunt_id_2 = db_test_utils.InitializeHunt(self.db) + + hunts_counters = self.db.ReadHuntsCounters([hunt_id_1, hunt_id_2]) + self.assertLen(hunts_counters, 2) + self.assertIsInstance(hunts_counters[hunt_id_1], db.HuntCounters) + self.assertIsInstance(hunts_counters[hunt_id_2], db.HuntCounters) + + def testReadHuntsCountersForASubsetOfCreatedHunts(self): + db_test_utils.InitializeUser(self.db, "user") + + hunt_id_1 = db_test_utils.InitializeHunt(self.db, creator="user") + _ = db_test_utils.InitializeHunt(self.db, creator="user") + hunt_id_3 = db_test_utils.InitializeHunt(self.db, creator="user") + + hunts_counters = self.db.ReadHuntsCounters([hunt_id_1, hunt_id_3]) + + self.assertLen(hunts_counters, 2) + self.assertIsInstance(hunts_counters[hunt_id_1], db.HuntCounters) + self.assertIsInstance(hunts_counters[hunt_id_3], db.HuntCounters) + + def testReadHuntsCountersReturnsSameResultAsReadHuntCounters(self): + db_test_utils.InitializeUser(self.db, "user") + + hunt_id_1 = db_test_utils.InitializeHunt(self.db, creator="user") + self._BuildFilterConditionExpectations(hunt_id_1) + + hunt_id_2 = db_test_utils.InitializeHunt(self.db, creator="user") + self._BuildFilterConditionExpectations(hunt_id_2) + + hunts_counters = self.db.ReadHuntsCounters([hunt_id_1, hunt_id_2]) + self.assertLen(hunts_counters, 2) + self.assertEqual( + hunts_counters[hunt_id_1], + self.db.ReadHuntCounters(hunt_id_1), + ) + self.assertEqual( + hunts_counters[hunt_id_2], + self.db.ReadHuntCounters(hunt_id_2), + ) + def testReadHuntCountersCorrectlyAggregatesResultsAmongDifferentFlows(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - - expectations = self._BuildFilterConditionExpectations(hunt_obj) - - hunt_counters = self.db.ReadHuntCounters(hunt_obj.hunt_id) - self.assertLen(expectations[db.HuntFlowsCondition.UNSET], - hunt_counters.num_clients) - self.assertLen(expectations[db.HuntFlowsCondition.SUCCEEDED_FLOWS_ONLY], - hunt_counters.num_successful_clients) - self.assertLen(expectations[db.HuntFlowsCondition.FAILED_FLOWS_ONLY], - hunt_counters.num_failed_clients) - self.assertLen(expectations[db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY], - hunt_counters.num_running_clients) + hunt_id = db_test_utils.InitializeHunt(self.db) + + expectations = self._BuildFilterConditionExpectations(hunt_id) + + hunt_counters = self.db.ReadHuntCounters(hunt_id) + self.assertLen( + expectations[db.HuntFlowsCondition.UNSET], hunt_counters.num_clients + ) + self.assertLen( + expectations[db.HuntFlowsCondition.SUCCEEDED_FLOWS_ONLY], + hunt_counters.num_successful_clients, + ) + self.assertLen( + expectations[db.HuntFlowsCondition.FAILED_FLOWS_ONLY], + hunt_counters.num_failed_clients, + ) + self.assertLen( + expectations[db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY], + 
hunt_counters.num_running_clients, + ) # _BuildFilterConditionExpectations writes 10 sample results for one client. self.assertEqual(hunt_counters.num_clients_with_results, 1) - self.assertLen(expectations[db.HuntFlowsCondition.CRASHED_FLOWS_ONLY], - hunt_counters.num_crashed_clients) + self.assertLen( + expectations[db.HuntFlowsCondition.CRASHED_FLOWS_ONLY], + hunt_counters.num_crashed_clients, + ) # _BuildFilterConditionExpectations writes 10 sample results. self.assertEqual(hunt_counters.num_results, 10) @@ -1823,28 +2151,37 @@ def testReadHuntCountersCorrectlyAggregatesResultsAmongDifferentFlows(self): # Check that after adding a flow with resource metrics, total counters # get updated. - self._SetupHuntClientAndFlow( - flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, + client_id = db_test_utils.InitializeClient(self.db) + db_test_utils.InitializeFlow( + self.db, + client_id, + flow_id=hunt_id, + parent_hunt_id=hunt_id, cpu_time_used=rdf_client_stats.CpuSeconds( - user_cpu_time=4.5, system_cpu_time=10), + user_cpu_time=4.5, system_cpu_time=10 + ), network_bytes_sent=42, - hunt_id=hunt_obj.hunt_id) - hunt_counters = self.db.ReadHuntCounters(hunt_obj.hunt_id) + ) + + hunt_counters = self.db.ReadHuntCounters(hunt_id) self.assertAlmostEqual(hunt_counters.total_cpu_seconds, 14.5) self.assertEqual(hunt_counters.total_network_bytes_sent, 42) def testReadHuntClientResourcesStatsIgnoresSubflows(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id, flow_id = self._SetupHuntClientAndFlow( + client_id = db_test_utils.InitializeClient(self.db) + flow_id = db_test_utils.InitializeFlow( + self.db, + client_id, + flow_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, + parent_hunt_id=hunt_id, cpu_time_used=rdf_client_stats.CpuSeconds( - user_cpu_time=100, system_cpu_time=200), + user_cpu_time=100, system_cpu_time=200 + ), network_bytes_sent=300, - hunt_id=hunt_obj.hunt_id) + ) # Create a subflow that used some resources too. 
This resource usage is # already accounted for in the parent flow so the overall hunt resource @@ -1853,13 +2190,16 @@ def testReadHuntClientResourcesStatsIgnoresSubflows(self): client_id=client_id, flow_id="12345678", parent_flow_id=flow_id, - parent_hunt_id=hunt_obj.hunt_id, + parent_hunt_id=hunt_id, cpu_time_used=rdf_client_stats.CpuSeconds( - user_cpu_time=10, system_cpu_time=20), - network_bytes_sent=30) + user_cpu_time=10, system_cpu_time=20 + ), + network_bytes_sent=30, + ) + sub_flow = mig_flow_objects.ToProtoFlow(sub_flow) self.db.WriteFlowObject(sub_flow) - usage_stats = self.db.ReadHuntClientResourcesStats(hunt_obj.hunt_id) + usage_stats = self.db.ReadHuntClientResourcesStats(hunt_id) network_bins = usage_stats.network_bytes_sent_stats.histogram.bins user_cpu_bins = usage_stats.user_cpu_stats.histogram.bins system_cpu_bins = usage_stats.system_cpu_stats.histogram.bins @@ -1873,130 +2213,263 @@ def testReadHuntClientResourcesStatsIgnoresSubflows(self): self.assertLen(usage_stats.worst_performers, 1) def testReadHuntClientResourcesStatsCorrectlyAggregatesData(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) flow_data = [] - expected_user_cpu_histogram = rdf_stats.StatsHistogram.FromBins( - rdf_stats.ClientResourcesStats.CPU_STATS_BINS) - expected_system_cpu_histogram = rdf_stats.StatsHistogram.FromBins( - rdf_stats.ClientResourcesStats.CPU_STATS_BINS) - expected_network_histogram = rdf_stats.StatsHistogram.FromBins( - rdf_stats.ClientResourcesStats.NETWORK_STATS_BINS) + for i in range(10): user_cpu_time = 4.5 + i system_cpu_time = 10 + i * 2 network_bytes_sent = 42 + i * 3 - client_id, flow_id = self._SetupHuntClientAndFlow( + client_id = db_test_utils.InitializeClient(self.db) + flow_id = db_test_utils.InitializeFlow( + self.db, + client_id, + flow_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, + parent_hunt_id=hunt_id, cpu_time_used=rdf_client_stats.CpuSeconds( - user_cpu_time=user_cpu_time, system_cpu_time=system_cpu_time), + user_cpu_time=user_cpu_time, system_cpu_time=system_cpu_time + ), network_bytes_sent=network_bytes_sent, - hunt_id=hunt_obj.hunt_id) - - expected_user_cpu_histogram.RegisterValue(user_cpu_time) - expected_system_cpu_histogram.RegisterValue(system_cpu_time) - expected_network_histogram.RegisterValue(network_bytes_sent) + ) - flow_data.append((client_id, flow_id, (user_cpu_time, system_cpu_time, - network_bytes_sent))) + flow_data.append(( + client_id, + flow_id, + (user_cpu_time, system_cpu_time, network_bytes_sent), + )) + + usage_stats = self.db.ReadHuntClientResourcesStats(hunt_id) + + expected_cpu_bins = rdf_stats.ClientResourcesStats.CPU_STATS_BINS + expected_network_bins = rdf_stats.ClientResourcesStats.NETWORK_STATS_BINS + expected_user_cpu_histogram = jobs_pb2.StatsHistogram( + bins=[ + jobs_pb2.StatsHistogramBin(num=num, range_max_value=max_range) + for num, max_range in zip( + 12 * [0] + [1, 1, 1, 1, 1, 1, 4, 0], expected_cpu_bins + ) + ] + ) - usage_stats = self.db.ReadHuntClientResourcesStats(hunt_obj.hunt_id) + expected_system_cpu_histogram = jobs_pb2.StatsHistogram( + bins=[ + jobs_pb2.StatsHistogramBin(num=num, range_max_value=max_range) + for num, max_range in zip(18 * [0] + [3, 7], expected_cpu_bins) + ] + ) + expected_network_histogram = jobs_pb2.StatsHistogram( + bins=[ + jobs_pb2.StatsHistogramBin(num=num, 
range_max_value=max_range) + for num, max_range in zip( + [0, 0, 8, 2] + 14 * [0], expected_network_bins + ) + ] + ) self.assertEqual(usage_stats.user_cpu_stats.num, 10) - self.assertAlmostEqual(usage_stats.user_cpu_stats.mean, 9) + self.assertAlmostEqual(usage_stats.user_cpu_stats.sum, 90) self.assertAlmostEqual(usage_stats.user_cpu_stats.stddev, 2.87228, 5) - self.assertLen(usage_stats.user_cpu_stats.histogram.bins, - len(expected_user_cpu_histogram.bins)) - for b, model_b in zip(usage_stats.user_cpu_stats.histogram.bins, - expected_user_cpu_histogram.bins): - self.assertAlmostEqual(b.range_max_value, model_b.range_max_value) - self.assertEqual(b.num, model_b.num) + self.assertLen( + usage_stats.user_cpu_stats.histogram.bins, + len(expected_user_cpu_histogram.bins), + ) + for bin_index, (b, exp_b) in enumerate( + zip( + usage_stats.user_cpu_stats.histogram.bins, + expected_user_cpu_histogram.bins, + ) + ): + self.assertAlmostEqual( + b.range_max_value, exp_b.range_max_value, msg=f"bin index {bin_index}" + ) + self.assertEqual(b.num, exp_b.num, msg=f"bin index {bin_index}") self.assertEqual(usage_stats.system_cpu_stats.num, 10) - self.assertAlmostEqual(usage_stats.system_cpu_stats.mean, 19) + self.assertAlmostEqual(usage_stats.system_cpu_stats.sum, 190) self.assertAlmostEqual(usage_stats.system_cpu_stats.stddev, 5.74456, 5) - self.assertLen(usage_stats.system_cpu_stats.histogram.bins, - len(expected_system_cpu_histogram.bins)) - for b, model_b in zip(usage_stats.system_cpu_stats.histogram.bins, - expected_system_cpu_histogram.bins): - self.assertAlmostEqual(b.range_max_value, model_b.range_max_value) - self.assertEqual(b.num, model_b.num) + self.assertLen( + usage_stats.system_cpu_stats.histogram.bins, + len(expected_system_cpu_histogram.bins), + ) + for bin_index, (b, exp_b) in enumerate( + zip( + usage_stats.system_cpu_stats.histogram.bins, + expected_system_cpu_histogram.bins, + ) + ): + self.assertAlmostEqual( + b.range_max_value, exp_b.range_max_value, msg=f"bin index {bin_index}" + ) + self.assertEqual(b.num, exp_b.num, msg=f"bin index {bin_index}") self.assertEqual(usage_stats.network_bytes_sent_stats.num, 10) - self.assertAlmostEqual(usage_stats.network_bytes_sent_stats.mean, 55.5) - self.assertAlmostEqual(usage_stats.network_bytes_sent_stats.stddev, 8.6168, - 4) - self.assertLen(usage_stats.network_bytes_sent_stats.histogram.bins, - len(expected_network_histogram.bins)) - for b, model_b in zip(usage_stats.network_bytes_sent_stats.histogram.bins, - expected_network_histogram.bins): - self.assertAlmostEqual(b.range_max_value, model_b.range_max_value) - self.assertEqual(b.num, model_b.num) + self.assertAlmostEqual(usage_stats.network_bytes_sent_stats.sum, 555) + self.assertAlmostEqual( + usage_stats.network_bytes_sent_stats.stddev, 8.6168, 4 + ) + self.assertLen( + usage_stats.network_bytes_sent_stats.histogram.bins, + len(expected_network_histogram.bins), + ) + for bin_index, (b, model_b) in enumerate( + zip( + usage_stats.network_bytes_sent_stats.histogram.bins, + expected_network_histogram.bins, + ) + ): + self.assertAlmostEqual( + b.range_max_value, + model_b.range_max_value, + msg=f"bin index {bin_index}", + ) + self.assertEqual(b.num, model_b.num, msg=f"bin index {bin_index}") self.assertLen(usage_stats.worst_performers, 10) - for worst_performer, flow_d in zip(usage_stats.worst_performers, - reversed(flow_data)): - client_id, flow_id, (user_cpu_time, system_cpu_time, - network_bytes_sent) = flow_d - self.assertEqual(worst_performer.client_id.Basename(), client_id) - 
self.assertAlmostEqual(worst_performer.cpu_usage.user_cpu_time, - user_cpu_time) - self.assertAlmostEqual(worst_performer.cpu_usage.system_cpu_time, - system_cpu_time) + for worst_performer, flow_d in zip( + usage_stats.worst_performers, reversed(flow_data) + ): + ( + client_id, + flow_id, + (user_cpu_time, system_cpu_time, network_bytes_sent), + ) = flow_d + self.assertEqual( + rdf_client.ClientURN.FromHumanReadable( + worst_performer.client_id + ).Basename(), + client_id, + ) + self.assertAlmostEqual( + worst_performer.cpu_usage.user_cpu_time, user_cpu_time + ) + self.assertAlmostEqual( + worst_performer.cpu_usage.system_cpu_time, system_cpu_time + ) self.assertEqual(worst_performer.network_bytes_sent, network_bytes_sent) - self.assertEqual(worst_performer.session_id.Path(), - "/%s/%s" % (client_id, flow_id)) + self.assertEqual( + rdfvalue.SessionID.FromHumanReadable( + worst_performer.session_id + ).Path(), + "/%s/%s" % (client_id, flow_id), + ) def testReadHuntClientResourcesStatsCorrectlyAggregatesVeryLargeNumbers(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - self._SetupHuntClientAndFlow( + client_id_1 = db_test_utils.InitializeClient(self.db) + db_test_utils.InitializeFlow( + self.db, + client_id_1, + flow_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, + parent_hunt_id=hunt_id, cpu_time_used=rdf_client_stats.CpuSeconds( - user_cpu_time=3810072130, system_cpu_time=3810072130), + user_cpu_time=3810072130, system_cpu_time=3810072130 + ), network_bytes_sent=3810072130, - hunt_id=hunt_obj.hunt_id) - self._SetupHuntClientAndFlow( + ) + + client_id_2 = db_test_utils.InitializeClient(self.db) + db_test_utils.InitializeFlow( + self.db, + client_id_2, + flow_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, + parent_hunt_id=hunt_id, cpu_time_used=rdf_client_stats.CpuSeconds( - user_cpu_time=2143939532, system_cpu_time=2143939532), + user_cpu_time=2143939532, system_cpu_time=2143939532 + ), network_bytes_sent=2143939532, - hunt_id=hunt_obj.hunt_id) + ) - usage_stats = self.db.ReadHuntClientResourcesStats(hunt_obj.hunt_id) + usage_stats = self.db.ReadHuntClientResourcesStats(hunt_id) self.assertEqual(usage_stats.user_cpu_stats.num, 2) - self.assertAlmostEqual(usage_stats.user_cpu_stats.mean, 2977005831, 5) + self.assertAlmostEqual(usage_stats.user_cpu_stats.sum, 5954011662, 5) self.assertAlmostEqual(usage_stats.user_cpu_stats.stddev, 833066299, 5) - self.assertAlmostEqual(usage_stats.system_cpu_stats.mean, 2977005831, 5) + self.assertAlmostEqual(usage_stats.system_cpu_stats.sum, 5954011662, 5) self.assertAlmostEqual(usage_stats.system_cpu_stats.stddev, 833066299, 5) - self.assertAlmostEqual(usage_stats.network_bytes_sent_stats.mean, - 2977005831, 5) - self.assertAlmostEqual(usage_stats.network_bytes_sent_stats.stddev, - 833066299, 5) + self.assertAlmostEqual( + usage_stats.network_bytes_sent_stats.sum, 5954011662, 5 + ) + self.assertAlmostEqual( + usage_stats.network_bytes_sent_stats.stddev, 833066299, 5 + ) + self.assertLen(usage_stats.worst_performers, 2) + self.assertEqual( + rdf_client.ClientURN.FromHumanReadable( + usage_stats.worst_performers[0].client_id + ).Path(), + f"/{client_id_1}", + ) + self.assertAlmostEqual( + usage_stats.worst_performers[0].cpu_usage.user_cpu_time, 3810072130.0 + ) + self.assertAlmostEqual( + 
usage_stats.worst_performers[0].cpu_usage.system_cpu_time, 3810072130.0 + ) + self.assertEqual( + usage_stats.worst_performers[0].network_bytes_sent, 3810072130 + ) + self.assertEqual( + rdfvalue.SessionID.FromHumanReadable( + usage_stats.worst_performers[0].session_id + ), + f"/{client_id_1}/{hunt_id}", + ) + self.assertEqual( + rdf_client.ClientURN.FromHumanReadable( + usage_stats.worst_performers[1].client_id + ).Path(), + f"/{client_id_2}", + ) + self.assertAlmostEqual( + usage_stats.worst_performers[1].cpu_usage.user_cpu_time, 2143939532.0 + ) + self.assertAlmostEqual( + usage_stats.worst_performers[1].cpu_usage.system_cpu_time, 2143939532.0 + ) + self.assertEqual( + usage_stats.worst_performers[1].network_bytes_sent, 2143939532 + ) + self.assertEqual( + rdfvalue.SessionID.FromHumanReadable( + usage_stats.worst_performers[1].session_id + ), + f"/{client_id_2}/{hunt_id}", + ) + + def testReadHuntClientResourcesStatsFiltersDirectFlowIdToMatchTheHuntID(self): + client_id = db_test_utils.InitializeClient(self.db) + hunt_id = db_test_utils.InitializeHunt(self.db) + + # The `flow_id` is randomly initialized, so it will not match the `hunt_id`. + db_test_utils.InitializeFlow( + self.db, + client_id, + parent_hunt_id=hunt_id, + ) + + usage_stats = self.db.ReadHuntClientResourcesStats(hunt_id) + self.assertEqual(usage_stats.user_cpu_stats.num, 0) + self.assertEqual(usage_stats.system_cpu_stats.num, 0) + self.assertEqual(usage_stats.network_bytes_sent_stats.num, 0) + self.assertEmpty(usage_stats.worst_performers) def testReadHuntFlowsStatesAndTimestampsWorksCorrectlyForMultipleFlows(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) expected = [] for i in range(10): - client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) if i % 2 == 0: - flow_state = rdf_flow_objects.Flow.FlowState.RUNNING + flow_state = flows_pb2.Flow.FlowState.RUNNING else: - flow_state = rdf_flow_objects.Flow.FlowState.FINISHED + flow_state = flows_pb2.Flow.FlowState.FINISHED self.db.UpdateFlow(client_id, flow_id, flow_state=flow_state) flow_obj = self.db.ReadFlowObject(client_id, flow_id) @@ -2004,25 +2477,25 @@ def testReadHuntFlowsStatesAndTimestampsWorksCorrectlyForMultipleFlows(self): db.FlowStateAndTimestamps( flow_state=flow_obj.flow_state, create_time=flow_obj.create_time, - last_update_time=flow_obj.last_update_time)) + last_update_time=flow_obj.last_update_time, + ) + ) - state_and_times = self.db.ReadHuntFlowsStatesAndTimestamps(hunt_obj.hunt_id) + state_and_times = self.db.ReadHuntFlowsStatesAndTimestamps(hunt_id) self.assertCountEqual(state_and_times, expected) def testReadHuntFlowsStatesAndTimestampsIgnoresNestedFlows(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) - client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_obj.hunt_id) + client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) self._SetupHuntClientAndFlow( - hunt_id=hunt_obj.hunt_id, + hunt_id=hunt_id, client_id=client_id, flow_id=flow.RandomFlowId(), - parent_flow_id=flow_id) + parent_flow_id=flow_id, + ) - state_and_times = 
self.db.ReadHuntFlowsStatesAndTimestamps(hunt_obj.hunt_id) + state_and_times = self.db.ReadHuntFlowsStatesAndTimestamps(hunt_id) self.assertLen(state_and_times, 1) flow_obj = self.db.ReadFlowObject(client_id, flow_id) @@ -2031,48 +2504,47 @@ def testReadHuntFlowsStatesAndTimestampsIgnoresNestedFlows(self): db.FlowStateAndTimestamps( flow_state=flow_obj.flow_state, create_time=flow_obj.create_time, - last_update_time=flow_obj.last_update_time)) + last_update_time=flow_obj.last_update_time, + ), + ) def testReadHuntOutputPluginLogEntriesReturnsEntryFromSingleHuntFlow(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + hunt_id = db_test_utils.InitializeHunt(self.db) output_plugin_id = "1" client_id, flow_id = self._SetupHuntClientAndFlow( - client_id="C.12345678901234aa", hunt_id=hunt_obj.hunt_id) + client_id="C.12345678901234aa", hunt_id=hunt_id + ) self.db.WriteFlowOutputPluginLogEntry( flows_pb2.FlowOutputPluginLogEntry( client_id=client_id, flow_id=flow_id, output_plugin_id=output_plugin_id, - hunt_id=hunt_obj.hunt_id, + hunt_id=hunt_id, message="blah", ) ) hunt_op_log_entries = self.db.ReadHuntOutputPluginLogEntries( - hunt_obj.hunt_id, output_plugin_id, 0, 10) + hunt_id, output_plugin_id, 0, 10 + ) self.assertLen(hunt_op_log_entries, 1) self.assertIsInstance( hunt_op_log_entries[0], flows_pb2.FlowOutputPluginLogEntry ) - self.assertEqual(hunt_op_log_entries[0].hunt_id, hunt_obj.hunt_id) + self.assertEqual(hunt_op_log_entries[0].hunt_id, hunt_id) self.assertEqual(hunt_op_log_entries[0].client_id, client_id) self.assertEqual(hunt_op_log_entries[0].flow_id, flow_id) self.assertEqual(hunt_op_log_entries[0].message, "blah") - def _WriteHuntOutputPluginLogEntries(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) + def _WriteHuntOutputPluginLogEntries(self) -> str: + hunt_id = db_test_utils.InitializeHunt(self.db) output_plugin_id = "1" for i in range(10): client_id, flow_id = self._SetupHuntClientAndFlow( - client_id="C.12345678901234a%d" % i, hunt_id=hunt_obj.hunt_id) + client_id="C.12345678901234a%d" % i, hunt_id=hunt_id + ) enum = flows_pb2.FlowOutputPluginLogEntry.LogEntryType if i % 3 == 0: log_entry_type = enum.ERROR @@ -2082,40 +2554,43 @@ def _WriteHuntOutputPluginLogEntries(self): flows_pb2.FlowOutputPluginLogEntry( client_id=client_id, flow_id=flow_id, - hunt_id=hunt_obj.hunt_id, + hunt_id=hunt_id, output_plugin_id=output_plugin_id, log_entry_type=log_entry_type, message="blah%d" % i, ) ) - return hunt_obj + return hunt_id def testReadHuntOutputPluginLogEntriesReturnsEntryFromMultipleHuntFlows(self): - hunt_obj = self._WriteHuntOutputPluginLogEntries() + hunt_id = self._WriteHuntOutputPluginLogEntries() hunt_op_log_entries = self.db.ReadHuntOutputPluginLogEntries( - hunt_obj.hunt_id, "1", 0, 100) + hunt_id, "1", 0, 100 + ) self.assertLen(hunt_op_log_entries, 10) # Make sure messages are returned in timestamps-ascending order. 
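The loop below depends on entries coming back in ascending write order, and the with_type assertions further down depend on the `i % 3` split used when the ten entries were written. A quick derivation of those expected counts, as a plain-Python sanity check rather than anything from the GRR API:

# Every index divisible by 3 was written as ERROR, the rest as LOG, so ten
# entries split into 4 ERROR entries (i = 0, 3, 6, 9) and 6 LOG entries --
# the numbers asserted by the with_type filter and count tests.
log_entry_types = ["ERROR" if i % 3 == 0 else "LOG" for i in range(10)]
assert log_entry_types.count("ERROR") == 4
assert log_entry_types.count("LOG") == 6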
for i, e in enumerate(hunt_op_log_entries): self.assertEqual(e.message, "blah%d" % i) def testReadHuntOutputPluginLogEntriesCorrectlyAppliesOffsetAndCountFilters( - self): - hunt_obj = self._WriteHuntOutputPluginLogEntries() + self, + ): + hunt_id = self._WriteHuntOutputPluginLogEntries() for i in range(10): hunt_op_log_entries = self.db.ReadHuntOutputPluginLogEntries( - hunt_obj.hunt_id, "1", i, 1) + hunt_id, "1", i, 1 + ) self.assertLen(hunt_op_log_entries, 1) self.assertEqual(hunt_op_log_entries[0].message, "blah%d" % i) def testReadHuntOutputPluginLogEntriesCorrectlyAppliesWithTypeFilter(self): - hunt_obj = self._WriteHuntOutputPluginLogEntries() + hunt_id = self._WriteHuntOutputPluginLogEntries() hunt_op_log_entries = self.db.ReadHuntOutputPluginLogEntries( - hunt_obj.hunt_id, + hunt_id, "1", 0, 100, @@ -2124,7 +2599,7 @@ def testReadHuntOutputPluginLogEntriesCorrectlyAppliesWithTypeFilter(self): self.assertEmpty(hunt_op_log_entries) hunt_op_log_entries = self.db.ReadHuntOutputPluginLogEntries( - hunt_obj.hunt_id, + hunt_id, "1", 0, 100, @@ -2133,7 +2608,7 @@ def testReadHuntOutputPluginLogEntriesCorrectlyAppliesWithTypeFilter(self): self.assertLen(hunt_op_log_entries, 4) hunt_op_log_entries = self.db.ReadHuntOutputPluginLogEntries( - hunt_obj.hunt_id, + hunt_id, "1", 0, 100, @@ -2142,11 +2617,12 @@ def testReadHuntOutputPluginLogEntriesCorrectlyAppliesWithTypeFilter(self): self.assertLen(hunt_op_log_entries, 6) def testReadHuntOutputPluginLogEntriesCorrectlyAppliesCombinationOfFilters( - self): - hunt_obj = self._WriteHuntOutputPluginLogEntries() + self, + ): + hunt_id = self._WriteHuntOutputPluginLogEntries() hunt_log_entries = self.db.ReadHuntOutputPluginLogEntries( - hunt_obj.hunt_id, + hunt_id, "1", 0, 1, @@ -2156,91 +2632,87 @@ def testReadHuntOutputPluginLogEntriesCorrectlyAppliesCombinationOfFilters( self.assertEqual(hunt_log_entries[0].message, "blah1") def testCountHuntOutputPluginLogEntriesReturnsCorrectCount(self): - hunt_obj = self._WriteHuntOutputPluginLogEntries() + hunt_id = self._WriteHuntOutputPluginLogEntries() - num_entries = self.db.CountHuntOutputPluginLogEntries(hunt_obj.hunt_id, "1") + num_entries = self.db.CountHuntOutputPluginLogEntries(hunt_id, "1") self.assertEqual(num_entries, 10) def testCountHuntOutputPluginLogEntriesRespectsWithTypeFilter(self): - hunt_obj = self._WriteHuntOutputPluginLogEntries() + hunt_id = self._WriteHuntOutputPluginLogEntries() num_entries = self.db.CountHuntOutputPluginLogEntries( - hunt_obj.hunt_id, + hunt_id, "1", with_type=flows_pb2.FlowOutputPluginLogEntry.LogEntryType.LOG, ) self.assertEqual(num_entries, 6) num_entries = self.db.CountHuntOutputPluginLogEntries( - hunt_obj.hunt_id, + hunt_id, "1", with_type=flows_pb2.FlowOutputPluginLogEntry.LogEntryType.ERROR, ) self.assertEqual(num_entries, 4) def testFlowStateUpdateUsingUpdateFlow(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - hunt_id = hunt_obj.hunt_id + hunt_id = db_test_utils.InitializeHunt(self.db) client_id, flow_id = self._SetupHuntClientAndFlow( - hunt_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.RUNNING) + hunt_id=hunt_id, flow_state=rdf_flow_objects.Flow.FlowState.RUNNING + ) results = self.db.ReadHuntFlows( hunt_id, 0, 10, - filter_condition=db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY) + filter_condition=db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY, + ) self.assertLen(results, 1) results = 
self.db.ReadHuntFlows( hunt_id, 0, 10, - filter_condition=db.HuntFlowsCondition.COMPLETED_FLOWS_ONLY) + filter_condition=db.HuntFlowsCondition.COMPLETED_FLOWS_ONLY, + ) self.assertEmpty(results) - rdf_flow = self.db.ReadFlowObject(client_id, flow_id) - rdf_flow.flow_state = rdf_flow_objects.Flow.FlowState.FINISHED - self.db.UpdateFlow(client_id, rdf_flow.flow_id, flow_obj=rdf_flow) + proto_flow = self.db.ReadFlowObject(client_id, flow_id) + proto_flow.flow_state = flows_pb2.Flow.FlowState.FINISHED + self.db.UpdateFlow(client_id, proto_flow.flow_id, flow_obj=proto_flow) results = self.db.ReadHuntFlows( hunt_id, 0, 10, - filter_condition=db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY) + filter_condition=db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY, + ) self.assertEmpty(results) results = self.db.ReadHuntFlows( hunt_id, 0, 10, - filter_condition=db.HuntFlowsCondition.COMPLETED_FLOWS_ONLY) + filter_condition=db.HuntFlowsCondition.COMPLETED_FLOWS_ONLY, + ) self.assertLen(results, 1) def testFlowStateUpdateUsingReleaseProcessedFlow(self): - self.db.WriteGRRUser("user") - hunt_obj = rdf_hunt_objects.Hunt(description="foo", creator="user") - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) - self.db.WriteHuntObject(hunt_obj) - hunt_id = hunt_obj.hunt_id + hunt_id = db_test_utils.InitializeHunt(self.db) client_id, flow_id = self._SetupHuntClientAndFlow(hunt_id=hunt_id) flow_obj = self.db.LeaseFlowForProcessing( - client_id, flow_id, rdfvalue.Duration.From(1, rdfvalue.MINUTES)) - self.assertEqual(flow_obj.flow_state, rdf_flow_objects.Flow.FlowState.UNSET) + client_id, flow_id, rdfvalue.Duration.From(1, rdfvalue.MINUTES) + ) + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.UNSET) - flow_obj.flow_state = rdf_flow_objects.Flow.FlowState.ERROR + flow_obj.flow_state = flows_pb2.Flow.FlowState.ERROR self.db.ReleaseProcessedFlow(flow_obj) results = self.db.ReadHuntFlows( - hunt_id, - 0, - 10, - filter_condition=db.HuntFlowsCondition.FAILED_FLOWS_ONLY) + hunt_id, 0, 10, filter_condition=db.HuntFlowsCondition.FAILED_FLOWS_ONLY + ) self.assertLen(results, 1) diff --git a/grr/server/grr_response_server/databases/db_message_handler_test.py b/grr/server/grr_response_server/databases/db_message_handler_test.py index c5d2100c5d..68d321051d 100644 --- a/grr/server/grr_response_server/databases/db_message_handler_test.py +++ b/grr/server/grr_response_server/databases/db_message_handler_test.py @@ -5,8 +5,9 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils - -from grr_response_server.rdfvalues import objects as rdf_objects +from grr_response_core.lib.rdfvalues import mig_protodict +from grr_response_core.lib.rdfvalues import protodict as rdf_protodict +from grr_response_proto import objects_pb2 class DatabaseTestHandlerMixin(object): @@ -17,44 +18,56 @@ class DatabaseTestHandlerMixin(object): def testMessageHandlerRequests(self): - requests = [ - rdf_objects.MessageHandlerRequest( - client_id="C.1000000000000000", - handler_name="Testhandler", - request_id=i * 100, - request=rdfvalue.RDFInteger(i)) for i in range(5) - ] + requests = [] + for i in range(5): + emb = mig_protodict.ToProtoEmbeddedRDFValue( + rdf_protodict.EmbeddedRDFValue(rdfvalue.RDFInteger(i)) + ) + requests.append( + objects_pb2.MessageHandlerRequest( + client_id="C.1000000000000000", + handler_name="Testhandler", + request_id=i * 100, + request=emb, + ) + ) self.db.WriteMessageHandlerRequests(requests) read = self.db.ReadMessageHandlerRequests() for r in read: self.assertTrue(r.timestamp) - 
r.timestamp = None + r.ClearField("timestamp") - self.assertEqual(sorted(read, key=lambda req: req.request_id), requests) + self.assertCountEqual(read, requests) self.db.DeleteMessageHandlerRequests(requests[:2]) self.db.DeleteMessageHandlerRequests(requests[4:5]) read = self.db.ReadMessageHandlerRequests() self.assertLen(read, 2) - read = sorted(read, key=lambda req: req.request_id) for r in read: - r.timestamp = None + r.ClearField("timestamp") - self.assertEqual(requests[2:4], read) + self.assertCountEqual(requests[2:4], read) self.db.DeleteMessageHandlerRequests(read) def testMessageHandlerRequestLeasing(self): - requests = [ - rdf_objects.MessageHandlerRequest( - client_id="C.1000000000000000", - handler_name="Testhandler", - request_id=i * 100, - request=rdfvalue.RDFInteger(i)) for i in range(10) - ] + requests = [] + for i in range(10): + emb = mig_protodict.ToProtoEmbeddedRDFValue( + rdf_protodict.EmbeddedRDFValue(rdfvalue.RDFInteger(i)) + ) + requests.append( + objects_pb2.MessageHandlerRequest( + client_id="C.1000000000000000", + handler_name="Testhandler", + request_id=i * 100, + request=emb, + ) + ) + lease_time = rdfvalue.Duration.From(5, rdfvalue.MINUTES) leased = queue.Queue() @@ -67,16 +80,17 @@ def testMessageHandlerRequestLeasing(self): try: l = leased.get(True, timeout=6) except queue.Empty: - self.fail("Timed out waiting for messages, expected 10, got %d" % - len(got)) + self.fail( + "Timed out waiting for messages, expected 10, got %d" % len(got) + ) self.assertLessEqual(len(l), 5) for m in l: self.assertEqual(m.leased_by, utils.ProcessIdString()) self.assertGreater(m.leased_until, rdfvalue.RDFDatetime.Now()) self.assertLess(m.timestamp, rdfvalue.RDFDatetime.Now()) - m.leased_by = None - m.leased_until = None - m.timestamp = None + m.ClearField("leased_by") + m.ClearField("leased_until") + m.ClearField("timestamp") got += l self.db.DeleteMessageHandlerRequests(got) diff --git a/grr/server/grr_response_server/databases/db_paths_test.py b/grr/server/grr_response_server/databases/db_paths_test.py index c2a93d65de..532b6282f8 100644 --- a/grr/server/grr_response_server/databases/db_paths_test.py +++ b/grr/server/grr_response_server/databases/db_paths_test.py @@ -4,11 +4,12 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs -from grr_response_core.lib.rdfvalues import crypto as rdf_crypto from grr_response_core.lib.rdfvalues import paths as rdf_paths +from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 from grr_response_server.databases import db from grr_response_server.databases import db_test_utils +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects @@ -22,21 +23,36 @@ def testWritePathInfosValidatesClientId(self): path = ["usr", "local"] with self.assertRaises(ValueError): - self.db.WritePathInfos("", [rdf_objects.PathInfo.OS(components=path)]) + self.db.WritePathInfos( + "", + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=path + ) + ], + ) def testWritePathInfosValidatesPathType(self): path = ["usr", "local"] client_id = db_test_utils.InitializeClient(self.db) with self.assertRaises(ValueError): - self.db.WritePathInfos(client_id, [rdf_objects.PathInfo(components=path)]) + self.db.WritePathInfos(client_id, [objects_pb2.PathInfo(components=path)]) def testWritePathInfosValidatesClient(self): client_id = "C.0123456789012345" with self.assertRaises(db.UnknownClientError) as 
context: self.db.WritePathInfos( - client_id, [rdf_objects.PathInfo.OS(components=[], directory=True)]) + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=[], + directory=True, + ) + ], + ) self.assertEqual(context.exception.client_id, client_id) @@ -52,10 +68,21 @@ def testWritePathInfosValidateConflictingWrites(self): client_id = db_test_utils.InitializeClient(self.db) with self.assertRaises(ValueError): - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar"], directory=False), - rdf_objects.PathInfo.OS(components=["foo", "bar"], directory=True), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar"], + directory=False, + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar"], + directory=True, + ), + ], + ) def testWritePathInfosEmpty(self): client_id = db_test_utils.InitializeClient(self.db) @@ -66,15 +93,23 @@ def testWritePathInfosMetadata(self): self.db.WritePathInfos( client_id, - [rdf_objects.PathInfo.TSK(components=["foo", "bar"], directory=True)]) + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.TSK, + components=["foo", "bar"], + directory=True, + ) + ], + ) - results = self.db.ReadPathInfos(client_id, - rdf_objects.PathInfo.PathType.TSK, - [("foo", "bar")]) + results = self.db.ReadPathInfos( + client_id, objects_pb2.PathInfo.PathType.TSK, [("foo", "bar")] + ) result_path_info = results[("foo", "bar")] - self.assertEqual(result_path_info.path_type, - rdf_objects.PathInfo.PathType.TSK) + self.assertEqual( + result_path_info.path_type, objects_pb2.PathInfo.PathType.TSK + ) self.assertEqual(result_path_info.components, ["foo", "bar"]) self.assertEqual(result_path_info.directory, True) @@ -85,26 +120,36 @@ def testWritePathInfosMetadataTimestampUpdate(self): timestamp_0 = now() - self.db.WritePathInfos(client_id, - [rdf_objects.PathInfo.OS(components=["foo"])]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) + ], + ) result = self.db.ReadPathInfo( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertEqual(result.components, ["foo"]) self.assertGreater(result.timestamp, timestamp_0) self.assertLess(result.timestamp, now()) - self.assertEqual(result.last_stat_entry_timestamp, None) - self.assertEqual(result.last_hash_entry_timestamp, None) + self.assertFalse(result.HasField("last_stat_entry_timestamp")) + self.assertFalse(result.HasField("last_hash_entry_timestamp")) timestamp_1 = now() - stat_entry = rdf_client_fs.StatEntry(st_mode=42) - self.db.WritePathInfos( - client_id, - [rdf_objects.PathInfo.OS(components=["foo"], stat_entry=stat_entry)]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], + ) + path_info.stat_entry.st_mode = 42 + self.db.WritePathInfos(client_id, [path_info]) result = self.db.ReadPathInfo( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertEqual(result.components, ["foo"]) self.assertEqual(result.stat_entry.st_mode, 42) self.assertGreater(result.timestamp, timestamp_1) @@ -114,13 +159,16 @@ def testWritePathInfosMetadataTimestampUpdate(self): timestamp_2 = now() - hash_entry = 
rdf_crypto.Hash(sha256=b"foo") - self.db.WritePathInfos( - client_id, - [rdf_objects.PathInfo.OS(components=["foo"], hash_entry=hash_entry)]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], + ) + path_info.hash_entry.sha256 = b"foo" + self.db.WritePathInfos(client_id, [path_info]) result = self.db.ReadPathInfo( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertEqual(result.components, ["foo"]) self.assertEqual(result.hash_entry.sha256, b"foo") self.assertGreater(result.timestamp, timestamp_2) @@ -132,10 +180,18 @@ def testWritePathInfosMetadataTimestampUpdate(self): self.db.WritePathInfos( client_id, - [rdf_objects.PathInfo.OS(components=["foo"], directory=True)]) + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], + directory=True, + ) + ], + ) result = self.db.ReadPathInfo( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertEqual(result.components, ["foo"]) self.assertEqual(result.stat_entry.st_mode, 42) self.assertEqual(result.hash_entry.sha256, b"foo") @@ -149,13 +205,16 @@ def testWritePathInfosMetadataTimestampUpdate(self): timestamp_4 = now() - path_info = rdf_objects.PathInfo.OS(components=["foo"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) path_info.stat_entry.st_mode = 108 path_info.hash_entry.sha256 = b"norf" self.db.WritePathInfos(client_id, [path_info]) result = self.db.ReadPathInfo( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertEqual(result.components, ["foo"]) self.assertEqual(result.stat_entry.st_mode, 108) self.assertEqual(result.hash_entry.sha256, b"norf") @@ -177,14 +236,18 @@ def testWritePathInfosStatEntry(self): stat_entry.st_atime = 4815162342 path_info = rdf_objects.PathInfo.FromStatEntry(stat_entry) - self.db.WritePathInfos(client_id, [path_info]) + proto_path_info = mig_objects.ToProtoPathInfo(path_info) + self.db.WritePathInfos(client_id, [proto_path_info]) - results = self.db.ReadPathInfos(client_id, rdf_objects.PathInfo.PathType.OS, - [ - (), - ("foo",), - ("foo", "bar"), - ]) + results = self.db.ReadPathInfos( + client_id, + objects_pb2.PathInfo.PathType.OS, + [ + (), + ("foo",), + ("foo", "bar"), + ], + ) root_path_info = results[()] self.assertFalse(root_path_info.HasField("stat_entry")) @@ -202,19 +265,23 @@ def testWritePathInfosStatEntry(self): def testWritePathInfosHashEntry(self): client_id = db_test_utils.InitializeClient(self.db) - hash_entry = rdf_crypto.Hash() + hash_entry = jobs_pb2.Hash() hash_entry.sha256 = hashlib.sha256(b"foo").digest() hash_entry.md5 = hashlib.md5(b"foo").digest() hash_entry.num_bytes = len(b"foo") - path_info = rdf_objects.PathInfo.OS( - components=["foo", "bar", "baz"], hash_entry=hash_entry) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz"], + ) + path_info.hash_entry.CopyFrom(hash_entry) self.db.WritePathInfos(client_id, [path_info]) result = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, - components=("foo", "bar", "baz")) + objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) self.assertEqual(result.components, ["foo", "bar", "baz"]) 
self.assertTrue(result.HasField("hash_entry")) @@ -226,12 +293,15 @@ def testWritePathInfosHashEntry(self): def testWritePathInfosValidatesHashEntry(self): client_id = db_test_utils.InitializeClient(self.db) - hash_entry = rdf_crypto.Hash() + hash_entry = jobs_pb2.Hash() hash_entry.md5 = hashlib.md5(b"foo").digest() hash_entry.sha1 = hashlib.sha1(b"bar").digest() - path_info = rdf_objects.PathInfo.OS( - components=("foo", "bar", "baz"), hash_entry=hash_entry) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) + path_info.hash_entry.CopyFrom(hash_entry) with self.assertRaises(ValueError): self.db.WritePathInfos(client_id, [path_info]) @@ -257,14 +327,17 @@ def MD5(data: bytes) -> bytes: for name, content in files.items(): content = name.encode("utf-8") - hash_entry = rdf_crypto.Hash() + hash_entry = jobs_pb2.Hash() hash_entry.sha256 = SHA256(content) hash_entry.md5 = MD5(content) hash_entry.num_bytes = len(content) - path_infos.append( - rdf_objects.PathInfo.OS( - components=["foo", "bar", "baz", name], hash_entry=hash_entry)) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz", name], + ) + path_info.hash_entry.CopyFrom(hash_entry) + path_infos.append(path_info) self.db.WritePathInfos(client_id, path_infos) @@ -273,8 +346,9 @@ def MD5(data: bytes) -> bytes: result = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, - components=("foo", "bar", "baz", name)) + objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz", name), + ) self.assertEqual(result.components, ["foo", "bar", "baz", name]) self.assertTrue(result.HasField("hash_entry")) @@ -286,19 +360,22 @@ def MD5(data: bytes) -> bytes: def testWritePathInfosHashAndStatEntry(self): client_id = db_test_utils.InitializeClient(self.db) - stat_entry = rdf_client_fs.StatEntry(st_mode=1337) - hash_entry = rdf_crypto.Hash(sha256=hashlib.sha256(b"foo").digest()) + stat_entry = jobs_pb2.StatEntry(st_mode=1337) + hash_entry = jobs_pb2.Hash(sha256=hashlib.sha256(b"foo").digest()) - path_info = rdf_objects.PathInfo.OS( + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo", "bar", "baz"], - stat_entry=stat_entry, - hash_entry=hash_entry) + ) + path_info.stat_entry.CopyFrom(stat_entry) + path_info.hash_entry.CopyFrom(hash_entry) self.db.WritePathInfos(client_id, [path_info]) result = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, - components=("foo", "bar", "baz")) + objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) self.assertEqual(result.components, ["foo", "bar", "baz"]) self.assertTrue(result.HasField("stat_entry")) @@ -309,22 +386,29 @@ def testWritePathInfosHashAndStatEntry(self): def testWritePathInfoHashAndStatEntrySeparateWrites(self): client_id = db_test_utils.InitializeClient(self.db) - stat_entry = rdf_client_fs.StatEntry(st_mode=1337) - stat_entry_path_info = rdf_objects.PathInfo.OS( - components=["foo"], stat_entry=stat_entry) + stat_entry = jobs_pb2.StatEntry(st_mode=1337) + stat_entry_path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], + ) + stat_entry_path_info.stat_entry.CopyFrom(stat_entry) stat_entry_timestamp = self.db.Now() self.db.WritePathInfos(client_id, [stat_entry_path_info]) - hash_entry = rdf_crypto.Hash(sha256=hashlib.sha256(b"foo").digest()) - hash_entry_path_info = rdf_objects.PathInfo.OS( - components=["foo"], 
hash_entry=hash_entry) + hash_entry = jobs_pb2.Hash(sha256=hashlib.sha256(b"foo").digest()) + hash_entry_path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], + ) + hash_entry_path_info.hash_entry.CopyFrom(hash_entry) hash_entry_timestamp = self.db.Now() self.db.WritePathInfos(client_id, [hash_entry_path_info]) result = self.db.ReadPathInfo( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) now = self.db.Now() @@ -341,16 +425,25 @@ def testWritePathInfoHashAndStatEntrySeparateWrites(self): def testWritePathInfosExpansion(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar", "baz"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz"], + ), + ], + ) - results = self.db.ReadPathInfos(client_id, rdf_objects.PathInfo.PathType.OS, - [ - ("foo",), - ("foo", "bar"), - ("foo", "bar", "baz"), - ]) + results = self.db.ReadPathInfos( + client_id, + objects_pb2.PathInfo.PathType.OS, + [ + ("foo",), + ("foo", "bar"), + ("foo", "bar", "baz"), + ], + ) self.assertLen(results, 3) @@ -369,38 +462,62 @@ def testWritePathInfosExpansion(self): def testWritePathInfosTypeSeparated(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo"], directory=True), - rdf_objects.PathInfo.TSK(components=["foo"], directory=False), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], + directory=True, + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.TSK, + components=["foo"], + directory=False, + ), + ], + ) - os_results = self.db.ReadPathInfos(client_id, - rdf_objects.PathInfo.PathType.OS, - [("foo",)]) + os_results = self.db.ReadPathInfos( + client_id, objects_pb2.PathInfo.PathType.OS, [("foo",)] + ) self.assertLen(os_results, 1) self.assertTrue(os_results[("foo",)].directory) - tsk_results = self.db.ReadPathInfos(client_id, - rdf_objects.PathInfo.PathType.TSK, - [("foo",)]) + tsk_results = self.db.ReadPathInfos( + client_id, objects_pb2.PathInfo.PathType.TSK, [("foo",)] + ) self.assertLen(tsk_results, 1) self.assertFalse(tsk_results[("foo",)].directory) def testWritePathInfosUpdates(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS( - components=["foo", "bar", "baz"], directory=False), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz"], + directory=False, + ), + ], + ) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS( - components=["foo", "bar", "baz"], directory=True), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz"], + directory=True, + ), + ], + ) - results = self.db.ReadPathInfos(client_id, rdf_objects.PathInfo.PathType.OS, - [("foo", "bar", "baz")]) + results = self.db.ReadPathInfos( + client_id, objects_pb2.PathInfo.PathType.OS, [("foo", "bar", "baz")] + ) result_path_info = results[("foo", "bar", "baz")] self.assertTrue(result_path_info.directory) @@ -408,15 +525,29 @@ def testWritePathInfosUpdates(self): def 
testWritePathInfosUpdatesAncestors(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo"], directory=False), - ]) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], + directory=False, + ), + ], + ) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar"], + ), + ], + ) - results = self.db.ReadPathInfos(client_id, rdf_objects.PathInfo.PathType.OS, - [("foo",)]) + results = self.db.ReadPathInfos( + client_id, objects_pb2.PathInfo.PathType.OS, [("foo",)] + ) self.assertLen(results, 1) self.assertTrue(results[("foo",)].directory) @@ -424,12 +555,19 @@ def testWritePathInfosUpdatesAncestors(self): def testWritePathInfosWritesAncestorsWithTimestamps(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["ancestor", "bar"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["ancestor", "bar"], + ), + ], + ) - results = self.db.ReadPathInfos(client_id, rdf_objects.PathInfo.PathType.OS, - [("ancestor",)]) + results = self.db.ReadPathInfos( + client_id, objects_pb2.PathInfo.PathType.OS, [("ancestor",)] + ) self.assertLen(results, 1) self.assertIsNotNone(results[("ancestor",)].timestamp) @@ -437,15 +575,28 @@ def testWritePathInfosWritesAncestorsWithTimestamps(self): def testWritePathInfosDuplicatedData(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar"]), - ]) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar"], + ), + ], + ) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar"], + ), + ], + ) - results = self.db.ReadPathInfos(client_id, rdf_objects.PathInfo.PathType.OS, - [("foo", "bar")]) + results = self.db.ReadPathInfos( + client_id, objects_pb2.PathInfo.PathType.OS, [("foo", "bar")] + ) self.assertLen(results, 1) result_path_info = results[("foo", "bar")] @@ -455,7 +606,9 @@ def testWritePathInfosDuplicatedData(self): def testWritePathInfosStoresCopy(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=["foo", "bar"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo", "bar"] + ) path_info.stat_entry.st_size = 1337 path_info.hash_entry.sha256 = b"foo" @@ -471,39 +624,51 @@ def testWritePathInfosStoresCopy(self): result_1 = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo", "bar"), - timestamp=timestamp_1) + timestamp=timestamp_1, + ) self.assertEqual(result_1.stat_entry.st_size, 1337) self.assertEqual(result_1.hash_entry.sha256, b"foo") result_2 = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo", "bar"), - timestamp=timestamp_2) + timestamp=timestamp_2, + ) 
self.assertEqual(result_2.stat_entry.st_size, 42) self.assertEqual(result_2.hash_entry.sha256, b"bar") def testReadPathInfosEmptyComponentsList(self): client_id = db_test_utils.InitializeClient(self.db) - results = self.db.ReadPathInfos(client_id, rdf_objects.PathInfo.PathType.OS, - []) + results = self.db.ReadPathInfos( + client_id, objects_pb2.PathInfo.PathType.OS, [] + ) self.assertEqual(results, {}) def testReadPathInfosNonExistent(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar"], + ), + ], + ) - results = self.db.ReadPathInfos(client_id, rdf_objects.PathInfo.PathType.OS, - [ - ("foo", "bar"), - ("foo", "baz"), - ("quux", "norf"), - ]) + results = self.db.ReadPathInfos( + client_id, + objects_pb2.PathInfo.PathType.OS, + [ + ("foo", "bar"), + ("foo", "baz"), + ("quux", "norf"), + ], + ) self.assertLen(results, 3) self.assertIsNotNone(results[("foo", "bar")]) self.assertIsNone(results[("foo", "baz")]) @@ -515,9 +680,10 @@ def testReadPathInfoValidatesTimestamp(self): with self.assertRaises(TypeError): self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.REGISTRY, + objects_pb2.PathInfo.PathType.REGISTRY, components=("foo", "bar", "baz"), - timestamp=rdfvalue.Duration.From(10, rdfvalue.SECONDS)) + timestamp=rdfvalue.Duration.From(10, rdfvalue.SECONDS), + ) def testReadPathInfoNonExistent(self): client_id = db_test_utils.InitializeClient(self.db) @@ -525,125 +691,148 @@ def testReadPathInfoNonExistent(self): with self.assertRaises(db.UnknownPathError) as ctx: self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, - components=("foo", "bar", "baz")) + objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) self.assertEqual(ctx.exception.client_id, client_id) - self.assertEqual(ctx.exception.path_type, rdf_objects.PathInfo.PathType.OS) + self.assertEqual(ctx.exception.path_type, objects_pb2.PathInfo.PathType.OS) self.assertEqual(ctx.exception.components, ("foo", "bar", "baz")) def testReadPathInfoTimestampStatEntry(self): client_id = db_test_utils.InitializeClient(self.db) pathspec = rdf_paths.PathSpec( - path="foo/bar/baz", pathtype=rdf_paths.PathSpec.PathType.OS) + path="foo/bar/baz", pathtype=rdf_paths.PathSpec.PathType.OS + ) stat_entry = rdf_client_fs.StatEntry(pathspec=pathspec, st_size=42) - self.db.WritePathInfos(client_id, - [rdf_objects.PathInfo.FromStatEntry(stat_entry)]) + path_info = rdf_objects.PathInfo.FromStatEntry(stat_entry) + self.db.WritePathInfos(client_id, [mig_objects.ToProtoPathInfo(path_info)]) timestamp_1 = self.db.Now() stat_entry = rdf_client_fs.StatEntry(pathspec=pathspec, st_size=101) - self.db.WritePathInfos(client_id, - [rdf_objects.PathInfo.FromStatEntry(stat_entry)]) + path_info = rdf_objects.PathInfo.FromStatEntry(stat_entry) + self.db.WritePathInfos(client_id, [mig_objects.ToProtoPathInfo(path_info)]) timestamp_2 = self.db.Now() stat_entry = rdf_client_fs.StatEntry(pathspec=pathspec, st_size=1337) - self.db.WritePathInfos(client_id, - [rdf_objects.PathInfo.FromStatEntry(stat_entry)]) + path_info = rdf_objects.PathInfo.FromStatEntry(stat_entry) + self.db.WritePathInfos(client_id, [mig_objects.ToProtoPathInfo(path_info)]) timestamp_3 = self.db.Now() path_info_last = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, - components=("foo", "bar", "baz")) + 
objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) self.assertEqual(path_info_last.stat_entry.st_size, 1337) self.assertEqual(path_info_last.components, ["foo", "bar", "baz"]) path_info_1 = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo", "bar", "baz"), - timestamp=timestamp_1) + timestamp=timestamp_1, + ) self.assertEqual(path_info_1.stat_entry.st_size, 42) self.assertEqual(path_info_last.components, ["foo", "bar", "baz"]) path_info_2 = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo", "bar", "baz"), - timestamp=timestamp_2) + timestamp=timestamp_2, + ) self.assertEqual(path_info_2.stat_entry.st_size, 101) self.assertEqual(path_info_last.components, ["foo", "bar", "baz"]) path_info_3 = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo", "bar", "baz"), - timestamp=timestamp_3) + timestamp=timestamp_3, + ) self.assertEqual(path_info_3.stat_entry.st_size, 1337) self.assertEqual(path_info_last.components, ["foo", "bar", "baz"]) def testReadPathInfoTimestampHashEntry(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=["foo"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) - path_info.hash_entry = rdf_crypto.Hash(sha256=b"bar") + path_info.hash_entry.sha256 = b"bar" self.db.WritePathInfos(client_id, [path_info]) bar_timestamp = self.db.Now() - path_info.hash_entry = rdf_crypto.Hash(sha256=b"baz") + path_info.hash_entry.sha256 = b"baz" self.db.WritePathInfos(client_id, [path_info]) baz_timestamp = self.db.Now() - path_info.hash_entry = rdf_crypto.Hash(sha256=b"quux") + path_info.hash_entry.sha256 = b"quux" self.db.WritePathInfos(client_id, [path_info]) quux_timestamp = self.db.Now() bar_path_info = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=bar_timestamp) + timestamp=bar_timestamp, + ) self.assertEqual(bar_path_info.hash_entry.sha256, b"bar") baz_path_info = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=baz_timestamp) + timestamp=baz_timestamp, + ) self.assertEqual(baz_path_info.hash_entry.sha256, b"baz") quux_path_info = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=quux_timestamp) + timestamp=quux_timestamp, + ) self.assertEqual(quux_path_info.hash_entry.sha256, b"quux") def testReadPathInfosMany(self): client_id = db_test_utils.InitializeClient(self.db) - path_info_1 = rdf_objects.PathInfo.OS(components=["foo", "bar"]) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo", "bar"] + ) path_info_1.stat_entry.st_mode = 42 path_info_1.hash_entry.md5 = b"foo" path_info_1.hash_entry.sha256 = b"bar" - path_info_2 = rdf_objects.PathInfo.OS(components=["baz", "quux", "norf"]) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["baz", "quux", "norf"], + ) path_info_2.hash_entry.sha256 = b"bazquuxnorf" - path_info_3 = rdf_objects.PathInfo.OS(components=["blargh"], directory=True) + path_info_3 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + 
components=["blargh"], + directory=True, + ) path_info_3.stat_entry.st_size = 1337 self.db.WritePathInfos(client_id, [path_info_1, path_info_2, path_info_3]) - results = self.db.ReadPathInfos(client_id, rdf_objects.PathInfo.PathType.OS, - [ - ("foo", "bar"), - ("baz", "quux", "norf"), - ("blargh",), - ]) + results = self.db.ReadPathInfos( + client_id, + objects_pb2.PathInfo.PathType.OS, + [ + ("foo", "bar"), + ("baz", "quux", "norf"), + ("blargh",), + ], + ) result_path_info_1 = results[("foo", "bar")] self.assertEqual(result_path_info_1.components, ["foo", "bar"]) self.assertEqual(result_path_info_1.stat_entry.st_mode, 42) @@ -662,80 +851,92 @@ def testReadPathInfosMany(self): def testReadPathInfoTimestampStatAndHashEntry(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=["foo"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) - path_info.stat_entry = rdf_client_fs.StatEntry(st_mode=42) - path_info.hash_entry = None + path_info.stat_entry.st_mode = 42 + path_info.ClearField("hash_entry") self.db.WritePathInfos(client_id, [path_info]) timestamp_1 = self.db.Now() - path_info.stat_entry = None - path_info.hash_entry = rdf_crypto.Hash(sha256=b"quux") + path_info.ClearField("stat_entry") + path_info.hash_entry.sha256 = b"quux" self.db.WritePathInfos(client_id, [path_info]) timestamp_2 = self.db.Now() - path_info.stat_entry = rdf_client_fs.StatEntry(st_mode=1337) - path_info.hash_entry = None + path_info.stat_entry.st_mode = 1337 + path_info.ClearField("hash_entry") self.db.WritePathInfos(client_id, [path_info]) timestamp_3 = self.db.Now() - path_info.stat_entry = rdf_client_fs.StatEntry(st_mode=4815162342) - path_info.hash_entry = rdf_crypto.Hash(sha256=b"norf") + path_info.stat_entry.st_mode = 4815162342 + path_info.hash_entry.sha256 = b"norf" self.db.WritePathInfos(client_id, [path_info]) timestamp_4 = self.db.Now() path_info_1 = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=timestamp_1) + timestamp=timestamp_1, + ) self.assertEqual(path_info_1.stat_entry.st_mode, 42) self.assertFalse(path_info_1.HasField("hash_entry")) path_info_2 = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=timestamp_2) + timestamp=timestamp_2, + ) self.assertEqual(path_info_2.stat_entry.st_mode, 42) self.assertEqual(path_info_2.hash_entry.sha256, b"quux") path_info_3 = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=timestamp_3) + timestamp=timestamp_3, + ) self.assertEqual(path_info_3.stat_entry.st_mode, 1337) self.assertEqual(path_info_3.hash_entry.sha256, b"quux") path_info_4 = self.db.ReadPathInfo( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=timestamp_4) + timestamp=timestamp_4, + ) self.assertEqual(path_info_4.stat_entry.st_mode, 4815162342) self.assertEqual(path_info_4.hash_entry.sha256, b"norf") def testReadPathInfoOlder(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=["foo"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) path_info.stat_entry.st_mode = 42 path_info.hash_entry.sha256 = b"foo" self.db.WritePathInfos(client_id, [path_info]) - 
path_info = rdf_objects.PathInfo.OS(components=["bar"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["bar"] + ) path_info.stat_entry.st_mode = 1337 path_info.hash_entry.sha256 = b"bar" self.db.WritePathInfos(client_id, [path_info]) path_info = self.db.ReadPathInfo( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertEqual(path_info.stat_entry.st_mode, 42) self.assertEqual(path_info.hash_entry.sha256, b"foo") path_info = self.db.ReadPathInfo( - client_id, rdf_objects.PathInfo.PathType.OS, components=("bar",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("bar",) + ) self.assertEqual(path_info.stat_entry.st_mode, 1337) self.assertEqual(path_info.hash_entry.sha256, b"bar") @@ -743,7 +944,8 @@ def testListDescendantPathInfosAlwaysSucceedsOnRoot(self): client_id = db_test_utils.InitializeClient(self.db) results = self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=()) + client_id, objects_pb2.PathInfo.PathType.OS, components=() + ) self.assertEmpty(results) @@ -761,87 +963,135 @@ def testListDescendantPathInfosNonexistentDirectory(self): with self.assertRaises(db.UnknownPathError): self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) def testListDescendantPathInfosNotDirectory(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=("foo",), directory=False) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo",), + directory=False, + ) self.db.WritePathInfos(client_id, [path_info]) with self.assertRaises(db.NotDirectoryPathError): self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) def testListDescendantPathInfosEmptyResult(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=("foo",), directory=True) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo",), + directory=True, + ) self.db.WritePathInfos(client_id, [path_info]) results = self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertEmpty(results) def testListDescendantPathInfosSingleResult(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar"], + ), + ], + ) results = self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertLen(results, 1) - self.assertEqual(results[0].components, ("foo", "bar")) + self.assertEqual(results[0].components, ["foo", "bar"]) def testListDescendantPathInfosSingle(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar", "baz", "quux"]), - ]) + self.db.WritePathInfos( + client_id, + [ + 
objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz", "quux"], + ), + ], + ) results = self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertLen(results, 3) - self.assertEqual(results[0].components, ("foo", "bar")) - self.assertEqual(results[1].components, ("foo", "bar", "baz")) - self.assertEqual(results[2].components, ("foo", "bar", "baz", "quux")) + self.assertEqual(results[0].components, ["foo", "bar"]) + self.assertEqual(results[1].components, ["foo", "bar", "baz"]) + self.assertEqual(results[2].components, ["foo", "bar", "baz", "quux"]) def testListDescendantPathInfosBranching(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar", "quux"]), - rdf_objects.PathInfo.OS(components=["foo", "baz"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "quux"], + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "baz"], + ), + ], + ) results = self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) self.assertLen(results, 3) - self.assertEqual(results[0].components, ("foo", "bar")) - self.assertEqual(results[1].components, ("foo", "bar", "quux")) - self.assertEqual(results[2].components, ("foo", "baz")) + self.assertEqual(results[0].components, ["foo", "bar"]) + self.assertEqual(results[1].components, ["foo", "bar", "quux"]) + self.assertEqual(results[2].components, ["foo", "baz"]) def testListDescendantPathInfosLimited(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar", "baz", "quux"]), - rdf_objects.PathInfo.OS(components=["foo", "bar", "blargh"]), - rdf_objects.PathInfo.OS(components=["foo", "norf", "thud", "plugh"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz", "quux"], + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "blargh"], + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "norf", "thud", "plugh"], + ), + ], + ) results = self.db.ListDescendantPathInfos( client_id, - rdf_objects.PathInfo.PathType.OS, + objects_pb2.PathInfo.PathType.OS, components=("foo",), - max_depth=2) + max_depth=2, + ) components = [tuple(path_info.components) for path_info in results] @@ -855,93 +1105,154 @@ def testListDescendantPathInfosLimited(self): def testListDescendantPathInfosTypeSeparated(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["usr", "bin", "javac"]), - rdf_objects.PathInfo.TSK(components=["usr", "bin", "gdb"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["usr", "bin", "javac"], + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.TSK, + components=["usr", "bin", "gdb"], + ), + ], + ) os_results = self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, 
components=("usr", "bin")) + client_id, objects_pb2.PathInfo.PathType.OS, components=("usr", "bin") + ) self.assertLen(os_results, 1) - self.assertEqual(os_results[0].components, ("usr", "bin", "javac")) + self.assertEqual(os_results[0].components, ["usr", "bin", "javac"]) tsk_results = self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.TSK, components=("usr", "bin")) + client_id, objects_pb2.PathInfo.PathType.TSK, components=("usr", "bin") + ) self.assertLen(tsk_results, 1) - self.assertEqual(tsk_results[0].components, ("usr", "bin", "gdb")) + self.assertEqual(tsk_results[0].components, ["usr", "bin", "gdb"]) def testListDescendantPathInfosAll(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar"]), - rdf_objects.PathInfo.OS(components=["baz", "quux"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar"], + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["baz", "quux"], + ), + ], + ) results = self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=()) + client_id, objects_pb2.PathInfo.PathType.OS, components=() + ) - self.assertEqual(results[0].components, ("baz",)) - self.assertEqual(results[1].components, ("baz", "quux")) - self.assertEqual(results[2].components, ("foo",)) - self.assertEqual(results[3].components, ("foo", "bar")) + self.assertEqual( + results[0].components, + [ + "baz", + ], + ) + self.assertEqual(results[1].components, ["baz", "quux"]) + self.assertEqual( + results[2].components, + [ + "foo", + ], + ) + self.assertEqual(results[3].components, ["foo", "bar"]) def testListDescendantPathInfosLimitedDirectory(self): client_id = db_test_utils.InitializeClient(self.db) - path_info_1 = rdf_objects.PathInfo.OS(components=["foo", "bar", "baz"]) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz"], + ) path_info_1.stat_entry.st_mode = 108 - path_info_2 = rdf_objects.PathInfo.OS(components=["foo", "bar"]) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo", "bar"] + ) path_info_2.stat_entry.st_mode = 1337 - path_info_3 = rdf_objects.PathInfo.OS(components=["foo", "norf", "quux"]) + path_info_3 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "norf", "quux"], + ) path_info_3.stat_entry.st_mode = 707 self.db.WritePathInfos(client_id, [path_info_1, path_info_2, path_info_3]) results = self.db.ListDescendantPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=(), max_depth=2) + client_id, objects_pb2.PathInfo.PathType.OS, components=(), max_depth=2 + ) self.assertLen(results, 3) - self.assertEqual(results[0].components, ("foo",)) - self.assertEqual(results[1].components, ("foo", "bar")) - self.assertEqual(results[2].components, ("foo", "norf")) + self.assertEqual( + results[0].components, + [ + "foo", + ], + ) + self.assertEqual(results[1].components, ["foo", "bar"]) + self.assertEqual(results[2].components, ["foo", "norf"]) self.assertEqual(results[1].stat_entry.st_mode, 1337) def testListDescendantPathInfosDepthZero(self): client_id = db_test_utils.InitializeClient(self.db) - path_info_1 = rdf_objects.PathInfo.OS(components=("foo",)) - path_info_2 = rdf_objects.PathInfo.OS(components=("foo", "bar")) - path_info_3 = 
rdf_objects.PathInfo.OS(components=("baz",)) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) + path_info_3 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("baz",) + ) self.db.WritePathInfos(client_id, [path_info_1, path_info_2, path_info_3]) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=(), - max_depth=0) + max_depth=0, + ) self.assertEmpty(results) def testListDescendantPathInfosTimestampNow(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=["foo", "bar", "baz"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz"], + ) path_info.stat_entry.st_size = 1337 self.db.WritePathInfos(client_id, [path_info]) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=(), - timestamp=self.db.Now()) + timestamp=self.db.Now(), + ) self.assertLen(results, 3) - self.assertEqual(results[0].components, ("foo",)) - self.assertEqual(results[1].components, ("foo", "bar")) - self.assertEqual(results[2].components, ("foo", "bar", "baz")) + self.assertEqual( + results[0].components, + [ + "foo", + ], + ) + self.assertEqual(results[1].components, ["foo", "bar"]) + self.assertEqual(results[2].components, ["foo", "bar", "baz"]) self.assertEqual(results[2].stat_entry.st_size, 1337) def testListDescendantPathInfosTimestampMultiple(self): @@ -949,69 +1260,84 @@ def testListDescendantPathInfosTimestampMultiple(self): timestamp_0 = self.db.Now() - path_info_1 = rdf_objects.PathInfo.OS(components=["foo", "bar", "baz"]) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz"], + ) path_info_1.stat_entry.st_size = 1 self.db.WritePathInfos(client_id, [path_info_1]) timestamp_1 = self.db.Now() - path_info_2 = rdf_objects.PathInfo.OS(components=["foo", "quux", "norf"]) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "quux", "norf"], + ) path_info_2.stat_entry.st_size = 2 self.db.WritePathInfos(client_id, [path_info_2]) timestamp_2 = self.db.Now() - path_info_3 = rdf_objects.PathInfo.OS(components=["foo", "quux", "thud"]) + path_info_3 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "quux", "thud"], + ) path_info_3.stat_entry.st_size = 3 self.db.WritePathInfos(client_id, [path_info_3]) timestamp_3 = self.db.Now() results_0 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=(), - timestamp=timestamp_0) + timestamp=timestamp_0, + ) self.assertEmpty(results_0) results_1 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=(), - timestamp=timestamp_1) + timestamp=timestamp_1, + ) self.assertLen(results_1, 3) - self.assertEqual(results_1[0].components, ("foo",)) - self.assertEqual(results_1[1].components, ("foo", "bar")) - self.assertEqual(results_1[2].components, ("foo", "bar", "baz")) + 
self.assertEqual(results_1[0].components, ["foo"]) + self.assertEqual(results_1[1].components, ["foo", "bar"]) + self.assertEqual(results_1[2].components, ["foo", "bar", "baz"]) results_2 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=(), - timestamp=timestamp_2) + timestamp=timestamp_2, + ) self.assertLen(results_2, 5) - self.assertEqual(results_2[0].components, ("foo",)) - self.assertEqual(results_2[1].components, ("foo", "bar")) - self.assertEqual(results_2[2].components, ("foo", "bar", "baz")) - self.assertEqual(results_2[3].components, ("foo", "quux")) - self.assertEqual(results_2[4].components, ("foo", "quux", "norf")) + self.assertEqual(results_2[0].components, ["foo"]) + self.assertEqual(results_2[1].components, ["foo", "bar"]) + self.assertEqual(results_2[2].components, ["foo", "bar", "baz"]) + self.assertEqual(results_2[3].components, ["foo", "quux"]) + self.assertEqual(results_2[4].components, ["foo", "quux", "norf"]) results_3 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=(), - timestamp=timestamp_3) + timestamp=timestamp_3, + ) self.assertLen(results_3, 6) - self.assertEqual(results_3[0].components, ("foo",)) - self.assertEqual(results_3[1].components, ("foo", "bar")) - self.assertEqual(results_3[2].components, ("foo", "bar", "baz")) - self.assertEqual(results_3[3].components, ("foo", "quux")) - self.assertEqual(results_3[4].components, ("foo", "quux", "norf")) - self.assertEqual(results_3[5].components, ("foo", "quux", "thud")) + self.assertEqual(results_3[0].components, ["foo"]) + self.assertEqual(results_3[1].components, ["foo", "bar"]) + self.assertEqual(results_3[2].components, ["foo", "bar", "baz"]) + self.assertEqual(results_3[3].components, ["foo", "quux"]) + self.assertEqual(results_3[4].components, ["foo", "quux", "norf"]) + self.assertEqual(results_3[5].components, ["foo", "quux", "thud"]) def testListDescendantPathInfosTimestampStatValue(self): client_id = db_test_utils.InitializeClient(self.db) timestamp_0 = self.db.Now() - path_info = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) path_info.stat_entry.st_size = 1337 self.db.WritePathInfos(client_id, [path_info]) @@ -1023,27 +1349,30 @@ def testListDescendantPathInfosTimestampStatValue(self): results_0 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=timestamp_0) + timestamp=timestamp_0, + ) self.assertEmpty(results_0) results_1 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=timestamp_1) + timestamp=timestamp_1, + ) self.assertLen(results_1, 1) - self.assertEqual(results_1[0].components, ("foo", "bar")) + self.assertEqual(results_1[0].components, ["foo", "bar"]) self.assertEqual(results_1[0].stat_entry.st_size, 1337) results_2 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=timestamp_2) + timestamp=timestamp_2, + ) self.assertLen(results_2, 1) - self.assertEqual(results_2[0].components, 
("foo", "bar")) + self.assertEqual(results_2[0].components, ["foo", "bar"]) self.assertEqual(results_2[0].stat_entry.st_size, 42) def testListDescendantPathInfosTimestampStatValue_ListVersion(self): @@ -1051,7 +1380,9 @@ def testListDescendantPathInfosTimestampStatValue_ListVersion(self): timestamp_0 = self.db.Now() - path_info = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) path_info.stat_entry.st_size = 1337 self.db.WritePathInfos(client_id, [path_info]) @@ -1063,36 +1394,30 @@ def testListDescendantPathInfosTimestampStatValue_ListVersion(self): results_0 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=[ - "foo", - ], + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], timestamp=timestamp_0, ) self.assertEmpty(results_0) results_1 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=[ - "foo", - ], + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], timestamp=timestamp_1, ) self.assertLen(results_1, 1) - self.assertEqual(results_1[0].components, ("foo", "bar")) + self.assertEqual(results_1[0].components, ["foo", "bar"]) self.assertEqual(results_1[0].stat_entry.st_size, 1337) results_2 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=[ - "foo", - ], + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo"], timestamp=timestamp_2, ) self.assertLen(results_2, 1) - self.assertEqual(results_2[0].components, ("foo", "bar")) + self.assertEqual(results_2[0].components, ["foo", "bar"]) self.assertEqual(results_2[0].stat_entry.st_size, 42) def testListDescendantPathInfosTimestampHashValue(self): @@ -1100,7 +1425,9 @@ def testListDescendantPathInfosTimestampHashValue(self): timestamp_0 = self.db.Now() - path_info = rdf_objects.PathInfo.OS(components=("foo",)) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) path_info.hash_entry.md5 = b"quux" path_info.hash_entry.sha256 = b"thud" @@ -1116,274 +1443,418 @@ def testListDescendantPathInfosTimestampHashValue(self): results_0 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=(), - timestamp=timestamp_0) + timestamp=timestamp_0, + ) self.assertEmpty(results_0) results_1 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=(), - timestamp=timestamp_1) + timestamp=timestamp_1, + ) self.assertLen(results_1, 1) - self.assertEqual(results_1[0].components, ("foo",)) self.assertEqual(results_1[0].hash_entry.md5, b"quux") self.assertEqual(results_1[0].hash_entry.sha256, b"thud") results_2 = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=(), - timestamp=timestamp_2) + timestamp=timestamp_2, + ) self.assertLen(results_2, 1) - self.assertEqual(results_2[0].components, ("foo",)) + self.assertEqual(results_2[0].components, ["foo"]) self.assertEqual(results_2[0].hash_entry.md5, b"norf") self.assertEqual(results_2[0].hash_entry.sha256, b"blargh") def testListDescendantPathInfosWildcards(self): client_id = db_test_utils.InitializeClient(self.db) - 
self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=("foo", "quux")), - rdf_objects.PathInfo.OS(components=("bar", "norf")), - rdf_objects.PathInfo.OS(components=("___", "thud")), - rdf_objects.PathInfo.OS(components=("%%%", "ztesch")), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "quux"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("bar", "norf"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("___", "thud"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%%%", "ztesch"), + ), + ], + ) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=("___",)) + path_type=objects_pb2.PathInfo.PathType.OS, + components=("___",), + ) self.assertLen(results, 1) - self.assertEqual(results[0].components, ("___", "thud")) + self.assertEqual(results[0].components, ["___", "thud"]) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=("%%%",)) + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%%%",), + ) self.assertLen(results, 1) - self.assertEqual(results[0].components, ("%%%", "ztesch")) + self.assertEqual(results[0].components, ["%%%", "ztesch"]) def testListDescendantPathInfosManyWildcards(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=("%", "%%", "%%%")), - rdf_objects.PathInfo.OS(components=("%", "%%%", "%")), - rdf_objects.PathInfo.OS(components=("%%", "%", "%%%")), - rdf_objects.PathInfo.OS(components=("%%", "%%%", "%")), - rdf_objects.PathInfo.OS(components=("%%%", "%%", "%%")), - rdf_objects.PathInfo.OS(components=("__", "%%", "__")), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%", "%%", "%%%"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%", "%%%", "%"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%%", "%", "%%%"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%%", "%%%", "%"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%%%", "%%", "%%"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("__", "%%", "__"), + ), + ], + ) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=("%",)) + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%",), + ) self.assertLen(results, 4) - self.assertEqual(results[0].components, ("%", "%%")) - self.assertEqual(results[1].components, ("%", "%%", "%%%")) - self.assertEqual(results[2].components, ("%", "%%%")) - self.assertEqual(results[3].components, ("%", "%%%", "%")) + self.assertEqual(results[0].components, ["%", "%%"]) + self.assertEqual(results[1].components, ["%", "%%", "%%%"]) + self.assertEqual(results[2].components, ["%", "%%%"]) + self.assertEqual(results[3].components, ["%", "%%%", "%"]) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=("%%",)) + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%%",), + ) 
self.assertLen(results, 4) - self.assertEqual(results[0].components, ("%%", "%")) - self.assertEqual(results[1].components, ("%%", "%", "%%%")) - self.assertEqual(results[2].components, ("%%", "%%%")) - self.assertEqual(results[3].components, ("%%", "%%%", "%")) + self.assertEqual(results[0].components, ["%%", "%"]) + self.assertEqual(results[1].components, ["%%", "%", "%%%"]) + self.assertEqual(results[2].components, ["%%", "%%%"]) + self.assertEqual(results[3].components, ["%%", "%%%", "%"]) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=("__",)) + path_type=objects_pb2.PathInfo.PathType.OS, + components=("__",), + ) self.assertLen(results, 2) - self.assertEqual(results[0].components, ("__", "%%")) - self.assertEqual(results[1].components, ("__", "%%", "__")) + self.assertEqual(results[0].components, ["__", "%%"]) + self.assertEqual(results[1].components, ["__", "%%", "__"]) def testListDescendantPathInfosWildcardsWithMaxDepth(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=("%", "%%foo", "%%%bar", "%%%%")), - rdf_objects.PathInfo.OS(components=("%", "%%foo", "%%%baz", "%%%%")), - rdf_objects.PathInfo.OS(components=("%", "%%quux", "%%%norf", "%%%%")), - rdf_objects.PathInfo.OS(components=("%", "%%quux", "%%%thud", "%%%%")), - rdf_objects.PathInfo.OS(components=("%%", "%%bar", "%%%quux")), - rdf_objects.PathInfo.OS(components=("%%", "%%baz", "%%%norf")), - rdf_objects.PathInfo.OS(components=("__", "__bar__", "__quux__")), - rdf_objects.PathInfo.OS(components=("__", "__baz__", "__norf__")), - rdf_objects.PathInfo.OS(components=("blargh",)), - rdf_objects.PathInfo.OS(components=("ztesch",)), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%", "%%foo", "%%%bar", "%%%%"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%", "%%foo", "%%%baz", "%%%%"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%", "%%quux", "%%%norf", "%%%%"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%", "%%quux", "%%%thud", "%%%%"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%%", "%%bar", "%%%quux"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("%%", "%%baz", "%%%norf"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("__", "__bar__", "__quux__"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("__", "__baz__", "__norf__"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("blargh",), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("ztesch",), + ), + ], + ) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("%",), - max_depth=2) + max_depth=2, + ) self.assertLen(results, 6) - self.assertEqual(results[0].components, ("%", "%%foo")) - self.assertEqual(results[1].components, ("%", "%%foo", "%%%bar")) - self.assertEqual(results[2].components, ("%", "%%foo", "%%%baz")) - self.assertEqual(results[3].components, ("%", "%%quux")) - self.assertEqual(results[4].components, ("%", "%%quux", 
"%%%norf")) - self.assertEqual(results[5].components, ("%", "%%quux", "%%%thud")) + self.assertEqual(results[0].components, ["%", "%%foo"]) + self.assertEqual(results[1].components, ["%", "%%foo", "%%%bar"]) + self.assertEqual(results[2].components, ["%", "%%foo", "%%%baz"]) + self.assertEqual(results[3].components, ["%", "%%quux"]) + self.assertEqual(results[4].components, ["%", "%%quux", "%%%norf"]) + self.assertEqual(results[5].components, ["%", "%%quux", "%%%thud"]) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("%%",), - max_depth=1) + max_depth=1, + ) self.assertLen(results, 2) - self.assertEqual(results[0].components, ("%%", "%%bar")) - self.assertEqual(results[1].components, ("%%", "%%baz")) + self.assertEqual(results[0].components, ["%%", "%%bar"]) + self.assertEqual(results[1].components, ["%%", "%%baz"]) results = self.db.ListDescendantPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("__",), - max_depth=1) + max_depth=1, + ) self.assertLen(results, 2) - self.assertEqual(results[0].components, ("__", "__bar__")) - self.assertEqual(results[1].components, ("__", "__baz__")) + self.assertEqual(results[0].components, ["__", "__bar__"]) + self.assertEqual(results[1].components, ["__", "__baz__"]) def testListChildPathInfosRoot(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar"]), - rdf_objects.PathInfo.OS(components=["foo", "baz"]), - rdf_objects.PathInfo.OS(components=["quux", "norf"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar"], + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "baz"], + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["quux", "norf"], + ), + ], + ) results = self.db.ListChildPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=()) + client_id, objects_pb2.PathInfo.PathType.OS, components=() + ) - self.assertEqual(results[0].components, ("foo",)) + self.assertEqual(results[0].components, ["foo"]) self.assertTrue(results[0].directory) - self.assertEqual(results[1].components, ("quux",)) + self.assertEqual(results[1].components, ["quux"]) self.assertTrue(results[1].directory) def testListChildPathInfosRootDeeper(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=("foo", "bar", "baz")), - rdf_objects.PathInfo.OS(components=("foo", "bar", "quux")), - rdf_objects.PathInfo.OS(components=("foo", "bar", "norf", "thud")), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "quux"), + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "norf", "thud"), + ), + ], + ) results = self.db.ListChildPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=()) + client_id, objects_pb2.PathInfo.PathType.OS, components=() + ) self.assertLen(results, 1) - self.assertEqual(results[0].components, ("foo",)) + self.assertEqual( + 
results[0].components, + [ + "foo", + ], + ) self.assertTrue(results[0].directory) def testListChildPathInfosDetails(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=["foo", "bar"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo", "bar"] + ) path_info.stat_entry.st_size = 42 self.db.WritePathInfos(client_id, [path_info]) - path_info = rdf_objects.PathInfo.OS(components=["foo", "baz"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo", "baz"] + ) path_info.hash_entry.md5 = b"quux" path_info.hash_entry.sha256 = b"norf" self.db.WritePathInfos(client_id, [path_info]) results = self.db.ListChildPathInfos( - client_id, rdf_objects.PathInfo.PathType.OS, components=("foo",)) - self.assertEqual(results[0].components, ("foo", "bar")) + client_id, objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) + self.assertEqual(results[0].components, ["foo", "bar"]) self.assertEqual(results[0].stat_entry.st_size, 42) - self.assertEqual(results[1].components, ("foo", "baz")) + self.assertEqual(results[1].components, ["foo", "baz"]) self.assertEqual(results[1].hash_entry.md5, b"quux") self.assertEqual(results[1].hash_entry.sha256, b"norf") def testListChildPathInfosDeepSorted(self): client_id = db_test_utils.InitializeClient(self.db) - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=["foo", "bar", "baz", "quux"]), - rdf_objects.PathInfo.OS(components=["foo", "bar", "baz", "norf"]), - rdf_objects.PathInfo.OS(components=["foo", "bar", "baz", "thud"]), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz", "quux"], + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz", "norf"], + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=["foo", "bar", "baz", "thud"], + ), + ], + ) results = self.db.ListChildPathInfos( client_id, - rdf_objects.PathInfo.PathType.OS, - components=("foo", "bar", "baz")) - self.assertEqual(results[0].components, ("foo", "bar", "baz", "norf")) - self.assertEqual(results[1].components, ("foo", "bar", "baz", "quux")) - self.assertEqual(results[2].components, ("foo", "bar", "baz", "thud")) + objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) + self.assertEqual(results[0].components, ["foo", "bar", "baz", "norf"]) + self.assertEqual(results[1].components, ["foo", "bar", "baz", "quux"]) + self.assertEqual(results[2].components, ["foo", "bar", "baz", "thud"]) def testListChildPathInfosTimestamp(self): client_id = db_test_utils.InitializeClient(self.db) timestamp_0 = self.db.Now() - path_info_1 = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) path_info_1.stat_entry.st_size = 1 self.db.WritePathInfos(client_id, [path_info_1]) timestamp_1 = self.db.Now() - path_info_2 = rdf_objects.PathInfo.OS(components=("foo", "baz")) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "baz") + ) path_info_2.stat_entry.st_size = 2 self.db.WritePathInfos(client_id, [path_info_2]) timestamp_2 = self.db.Now() results_0 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, 
components=("foo",), - timestamp=timestamp_0) + timestamp=timestamp_0, + ) self.assertEmpty(results_0) results_1 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=timestamp_1) + timestamp=timestamp_1, + ) self.assertLen(results_1, 1) - self.assertEqual(results_1[0].components, ("foo", "bar")) + self.assertEqual(results_1[0].components, ["foo", "bar"]) self.assertEqual(results_1[0].stat_entry.st_size, 1) results_2 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo",), - timestamp=timestamp_2) + timestamp=timestamp_2, + ) self.assertLen(results_2, 2) - self.assertEqual(results_2[0].components, ("foo", "bar")) + self.assertEqual(results_2[0].components, ["foo", "bar"]) self.assertEqual(results_2[0].stat_entry.st_size, 1) - self.assertEqual(results_2[1].components, ("foo", "baz")) + self.assertEqual(results_2[1].components, ["foo", "baz"]) self.assertEqual(results_2[1].stat_entry.st_size, 2) def testListChildPathInfosTimestampStatAndHashValue(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=("foo", "bar", "baz")) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) path_info.stat_entry.st_size = 42 path_info.hash_entry.sha256 = b"quux" self.db.WritePathInfos(client_id, [path_info]) timestamp_1 = self.db.Now() - path_info = rdf_objects.PathInfo.OS(components=("foo", "bar", "baz")) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) path_info.stat_entry.st_size = 108 path_info.hash_entry.sha256 = b"norf" self.db.WritePathInfos(client_id, [path_info]) timestamp_2 = self.db.Now() - path_info = rdf_objects.PathInfo.OS(components=("foo", "bar", "baz")) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("foo", "bar", "baz"), + ) path_info.stat_entry.st_size = 1337 path_info.hash_entry.sha256 = b"thud" self.db.WritePathInfos(client_id, [path_info]) @@ -1391,107 +1862,129 @@ def testListChildPathInfosTimestampStatAndHashValue(self): results_1 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar"), - timestamp=timestamp_1) + timestamp=timestamp_1, + ) self.assertLen(results_1, 1) - self.assertEqual(results_1[0].components, ("foo", "bar", "baz")) + self.assertEqual(results_1[0].components, ["foo", "bar", "baz"]) self.assertEqual(results_1[0].stat_entry.st_size, 42) self.assertEqual(results_1[0].hash_entry.sha256, b"quux") results_2 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar"), - timestamp=timestamp_2) + timestamp=timestamp_2, + ) self.assertLen(results_2, 1) - self.assertEqual(results_2[0].components, ("foo", "bar", "baz")) + self.assertEqual(results_2[0].components, ["foo", "bar", "baz"]) self.assertEqual(results_2[0].stat_entry.st_size, 108) self.assertEqual(results_2[0].hash_entry.sha256, b"norf") results_3 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar"), - timestamp=timestamp_3) 
+ timestamp=timestamp_3, + ) self.assertLen(results_3, 1) - self.assertEqual(results_3[0].components, ("foo", "bar", "baz")) + self.assertEqual(results_3[0].components, ["foo", "bar", "baz"]) self.assertEqual(results_3[0].stat_entry.st_size, 1337) self.assertEqual(results_3[0].hash_entry.sha256, b"thud") def testListChildPathInfosBackslashes(self): client_id = db_test_utils.InitializeClient(self.db) - path_info_1 = rdf_objects.PathInfo.OS(components=("\\", "\\\\", "\\\\\\")) - path_info_2 = rdf_objects.PathInfo.OS(components=("\\", "\\\\\\", "\\\\")) - path_info_3 = rdf_objects.PathInfo.OS(components=("\\", "foo\\bar", "baz")) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("\\", "\\\\", "\\\\\\"), + ) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("\\", "\\\\\\", "\\\\"), + ) + path_info_3 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=("\\", "foo\\bar", "baz"), + ) self.db.WritePathInfos(client_id, [path_info_1, path_info_2, path_info_3]) results_0 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=("\\",)) + path_type=objects_pb2.PathInfo.PathType.OS, + components=("\\",), + ) self.assertLen(results_0, 3) - self.assertEqual(results_0[0].components, ("\\", "\\\\")) - self.assertEqual(results_0[1].components, ("\\", "\\\\\\")) - self.assertEqual(results_0[2].components, ("\\", "foo\\bar")) + self.assertEqual(results_0[0].components, ["\\", "\\\\"]) + self.assertEqual(results_0[1].components, ["\\", "\\\\\\"]) + self.assertEqual(results_0[2].components, ["\\", "foo\\bar"]) results_1 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=("\\", "\\\\")) + path_type=objects_pb2.PathInfo.PathType.OS, + components=("\\", "\\\\"), + ) self.assertLen(results_1, 1) - self.assertEqual(results_1[0].components, ("\\", "\\\\", "\\\\\\")) + self.assertEqual(results_1[0].components, ["\\", "\\\\", "\\\\\\"]) results_2 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=("\\", "\\\\\\")) + path_type=objects_pb2.PathInfo.PathType.OS, + components=("\\", "\\\\\\"), + ) self.assertLen(results_2, 1) - self.assertEqual(results_2[0].components, ("\\", "\\\\\\", "\\\\")) + self.assertEqual(results_2[0].components, ["\\", "\\\\\\", "\\\\"]) results_3 = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=("\\", "foo\\bar")) + path_type=objects_pb2.PathInfo.PathType.OS, + components=("\\", "foo\\bar"), + ) self.assertLen(results_3, 1) - self.assertEqual(results_3[0].components, ("\\", "foo\\bar", "baz")) + self.assertEqual(results_3[0].components, ["\\", "foo\\bar", "baz"]) def testListChildPathInfosTSKRootVolume(self): client_id = db_test_utils.InitializeClient(self.db) volume = "\\\\?\\Volume{2d4fbbd3-0000-0000-0000-100000000000}" - path_info = rdf_objects.PathInfo.TSK(components=(volume, "foobar.txt")) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.TSK, + components=(volume, "foobar.txt"), + ) path_info.stat_entry.st_size = 42 self.db.WritePathInfos(client_id, [path_info]) results = self.db.ListChildPathInfos( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.TSK, - components=(volume,)) + path_type=objects_pb2.PathInfo.PathType.TSK, + components=(volume,), + ) self.assertLen(results, 1) - 
self.assertEqual(results[0].components, (volume, "foobar.txt")) + self.assertEqual(results[0].components, [volume, "foobar.txt"]) self.assertEqual(results[0].stat_entry.st_size, 42) def testReadPathInfosHistoriesEmpty(self): client_id = db_test_utils.InitializeClient(self.db) - result = self.db.ReadPathInfosHistories(client_id, - rdf_objects.PathInfo.PathType.OS, - []) + result = self.db.ReadPathInfosHistories( + client_id, objects_pb2.PathInfo.PathType.OS, [] + ) self.assertEqual(result, {}) def testReadPathInfosHistoriesDoesNotRaiseOnUnknownClient(self): - results = self.db.ReadPathInfosHistories("C.FFFF111122223333", - rdf_objects.PathInfo.PathType.OS, - [("foo",)]) + results = self.db.ReadPathInfosHistories( + "C.FFFF111122223333", objects_pb2.PathInfo.PathType.OS, [("foo",)] + ) self.assertEqual(results[("foo",)], []) def testReadPathInfosHistoriesWithSingleFileWithSingleHistoryItem(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=["foo"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) path_info.stat_entry.st_size = 42 path_info.hash_entry.sha256 = b"quux" @@ -1501,7 +1994,8 @@ def testReadPathInfosHistoriesWithSingleFileWithSingleHistoryItem(self): now = self.db.Now() path_infos = self.db.ReadPathInfosHistories( - client_id, rdf_objects.PathInfo.PathType.OS, [("foo",)]) + client_id, objects_pb2.PathInfo.PathType.OS, [("foo",)] + ) self.assertLen(path_infos, 1) pi = path_infos[("foo",)] @@ -1513,10 +2007,14 @@ def testReadPathInfosHistoriesWithSingleFileWithSingleHistoryItem(self): def testReadPathInfosHistoriesWithTwoFilesWithSingleHistoryItemEach(self): client_id = db_test_utils.InitializeClient(self.db) - path_info_1 = rdf_objects.PathInfo.OS(components=["foo"]) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) path_info_1.stat_entry.st_mode = 1337 - path_info_2 = rdf_objects.PathInfo.OS(components=["bar"]) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["bar"] + ) path_info_2.hash_entry.sha256 = b"quux" then = self.db.Now() @@ -1524,28 +2022,36 @@ def testReadPathInfosHistoriesWithTwoFilesWithSingleHistoryItemEach(self): now = self.db.Now() path_infos = self.db.ReadPathInfosHistories( - client_id, rdf_objects.PathInfo.PathType.OS, [("foo",), ("bar",)]) + client_id, objects_pb2.PathInfo.PathType.OS, [("foo",), ("bar",)] + ) self.assertLen(path_infos, 2) pi = path_infos[("bar",)] self.assertLen(pi, 1) - self.assertEqual(pi[0].components, ("bar",)) + self.assertEqual(pi[0].components, ["bar"]) self.assertEqual(pi[0].hash_entry.sha256, b"quux") self.assertBetween(pi[0].timestamp, then, now) pi = path_infos[("foo",)] self.assertLen(pi, 1) - self.assertEqual(pi[0].components, ("foo",)) + self.assertEqual( + pi[0].components, + ["foo"], + ) self.assertEqual(pi[0].stat_entry.st_mode, 1337) self.assertBetween(pi[0].timestamp, then, now) def testReadPathInfosHistoriesWithTwoFilesWithTwoHistoryItems(self): client_id = db_test_utils.InitializeClient(self.db) - path_info_1 = rdf_objects.PathInfo.OS(components=["foo"]) - path_info_2 = rdf_objects.PathInfo.OS(components=["bar"]) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["bar"] + ) timestamp_1 = self.db.Now() @@ -1570,35 +2076,38 @@ def 
testReadPathInfosHistoriesWithTwoFilesWithTwoHistoryItems(self): timestamp_5 = self.db.Now() path_infos = self.db.ReadPathInfosHistories( - client_id, rdf_objects.PathInfo.PathType.OS, [("foo",), ("bar",)]) + client_id, objects_pb2.PathInfo.PathType.OS, [("foo",), ("bar",)] + ) self.assertLen(path_infos, 2) pi = path_infos[("bar",)] self.assertLen(pi, 2) - self.assertEqual(pi[0].components, ("bar",)) + self.assertEqual(pi[0].components, ["bar"]) self.assertEqual(pi[0].stat_entry.st_mode, 109) self.assertBetween(pi[0].timestamp, timestamp_3, timestamp_4) - self.assertEqual(pi[1].components, ("bar",)) + self.assertEqual(pi[1].components, ["bar"]) self.assertEqual(pi[1].stat_entry.st_mode, 110) self.assertBetween(pi[1].timestamp, timestamp_4, timestamp_5) pi = path_infos[("foo",)] self.assertLen(pi, 2) - self.assertEqual(pi[0].components, ("foo",)) + self.assertEqual(pi[0].components, ["foo"]) self.assertEqual(pi[0].stat_entry.st_mode, 1337) self.assertBetween(pi[0].timestamp, timestamp_1, timestamp_2) - self.assertEqual(pi[1].components, ("foo",)) + self.assertEqual(pi[1].components, ["foo"]) self.assertEqual(pi[1].stat_entry.st_mode, 1338) self.assertBetween(pi[1].timestamp, timestamp_2, timestamp_3) def testReadPathInfoHistoryTimestamp(self): client_id = db_test_utils.InitializeClient(self.db) - path_info = rdf_objects.PathInfo.OS(components=["foo"]) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=["foo"] + ) path_info.stat_entry.st_size = 0 self.db.WritePathInfos(client_id, [path_info]) @@ -1616,9 +2125,10 @@ def testReadPathInfoHistoryTimestamp(self): path_infos = self.db.ReadPathInfoHistory( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo",), - cutoff=cutoff) + cutoff=cutoff, + ) self.assertLen(path_infos, 3) self.assertEqual(path_infos[0].stat_entry.st_size, 0) @@ -1650,35 +2160,46 @@ def testReadLatestPathInfosReturnsNothingForNonExistingPaths(self): path_2 = db.ClientPath.TSK(client_b_id, components=("foo", "baz")) results = self.db.ReadLatestPathInfosWithHashBlobReferences( - [path_1, path_2]) + [path_1, path_2] + ) self.assertEqual(results, {path_1: None, path_2: None}) def testReadLatestPathInfosReturnsNothingWhenNoFilesCollected(self): client_a_id = db_test_utils.InitializeClient(self.db) client_b_id = db_test_utils.InitializeClient(self.db) - path_info_1 = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) self.db.WritePathInfos(client_a_id, [path_info_1]) - path_info_2 = rdf_objects.PathInfo.TSK(components=("foo", "baz")) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.TSK, components=("foo", "baz") + ) self.db.WritePathInfos(client_b_id, [path_info_2]) path_1 = db.ClientPath.OS(client_a_id, components=("foo", "bar")) path_2 = db.ClientPath.TSK(client_b_id, components=("foo", "baz")) results = self.db.ReadLatestPathInfosWithHashBlobReferences( - [path_1, path_2]) + [path_1, path_2] + ) self.assertEqual(results, {path_1: None, path_2: None}) def testReadLatestPathInfosFindsTwoCollectedFilesWhenTheyAreTheOnlyEntries( - self): + self, + ): client_a_id = db_test_utils.InitializeClient(self.db) client_b_id = db_test_utils.InitializeClient(self.db) hash_id_1, hash_id_2 = self._WriteBlobReferences() - path_info_1 = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info_1 = objects_pb2.PathInfo( + 
path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) path_info_1.hash_entry.sha256 = hash_id_1.AsBytes() self.db.WritePathInfos(client_a_id, [path_info_1]) - path_info_2 = rdf_objects.PathInfo.TSK(components=("foo", "baz")) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.TSK, components=("foo", "baz") + ) path_info_2.hash_entry.sha256 = hash_id_2.AsBytes() self.db.WritePathInfos(client_b_id, [path_info_2]) @@ -1686,7 +2207,8 @@ def testReadLatestPathInfosFindsTwoCollectedFilesWhenTheyAreTheOnlyEntries( path_2 = db.ClientPath.TSK(client_b_id, components=("foo", "baz")) results = self.db.ReadLatestPathInfosWithHashBlobReferences( - [path_1, path_2]) + [path_1, path_2] + ) self.assertCountEqual(results.keys(), [path_1, path_2]) self.assertEqual(results[path_1].hash_entry, path_info_1.hash_entry) self.assertEqual(results[path_2].hash_entry, path_info_2.hash_entry) @@ -1694,15 +2216,20 @@ def testReadLatestPathInfosFindsTwoCollectedFilesWhenTheyAreTheOnlyEntries( self.assertTrue(results[path_2].timestamp) def testReadLatestPathInfosCorrectlyFindsCollectedFileWithNonLatestEntry( - self): + self, + ): client_id = db_test_utils.InitializeClient(self.db) hash_id, _ = self._WriteBlobReferences() - path_info_1 = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) path_info_1.hash_entry.sha256 = hash_id.AsBytes() self.db.WritePathInfos(client_id, [path_info_1]) - path_info_2 = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) self.db.WritePathInfos(client_id, [path_info_2]) path = db.ClientPath.OS(client_id, components=("foo", "bar")) @@ -1716,11 +2243,15 @@ def testReadLatestPathInfosCorrectlyFindsLatestOfTwoCollectedFiles(self): client_id = db_test_utils.InitializeClient(self.db) hash_id_1, hash_id_2 = self._WriteBlobReferences() - path_info_1 = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) path_info_1.hash_entry.sha256 = hash_id_1.AsBytes() self.db.WritePathInfos(client_id, [path_info_1]) - path_info_2 = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) path_info_2.hash_entry.sha256 = hash_id_2.AsBytes() self.db.WritePathInfos(client_id, [path_info_2]) @@ -1731,23 +2262,29 @@ def testReadLatestPathInfosCorrectlyFindsLatestOfTwoCollectedFiles(self): self.assertTrue(results[path].timestamp) def testReadLatestPathInfosCorrectlyFindsLatestCollectedFileBeforeTimestamp( - self): + self, + ): client_id = db_test_utils.InitializeClient(self.db) hash_id_1, hash_id_2 = self._WriteBlobReferences() - path_info_1 = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) path_info_1.hash_entry.sha256 = hash_id_1.AsBytes() self.db.WritePathInfos(client_id, [path_info_1]) time_checkpoint = self.db.Now() - path_info_2 = rdf_objects.PathInfo.OS(components=("foo", "bar")) + path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) path_info_2.hash_entry.sha256 = hash_id_2.AsBytes() self.db.WritePathInfos(client_id, [path_info_2]) path = 
db.ClientPath.OS(client_id, components=("foo", "bar")) results = self.db.ReadLatestPathInfosWithHashBlobReferences( - [path], max_timestamp=time_checkpoint) + [path], max_timestamp=time_checkpoint + ) self.assertCountEqual(results.keys(), [path]) self.assertEqual(results[path].hash_entry, path_info_1.hash_entry) self.assertTrue(results[path].timestamp) @@ -1766,20 +2303,28 @@ def testReadLatestPathInfosMaxTimestampMultiplePaths(self): rdf_objects.SHA256HashID(after_hash_id): [blob_ref], }) - before_path_info_1 = rdf_objects.PathInfo.OS(components=("foo",)) + before_path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) before_path_info_1.hash_entry.sha256 = before_hash_id - before_path_info_2 = rdf_objects.PathInfo.OS(components=("bar",)) + before_path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("bar",) + ) before_path_info_2.hash_entry.sha256 = before_hash_id self.db.WritePathInfos(client_id, [before_path_info_1, before_path_info_2]) timestamp = self.db.Now() - after_path_info_1 = rdf_objects.PathInfo.OS(components=("foo",)) + after_path_info_1 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo",) + ) after_path_info_1.hash_entry.sha256 = after_hash_id - after_path_info_2 = rdf_objects.PathInfo.OS(components=("bar",)) + after_path_info_2 = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("bar",) + ) after_path_info_2.hash_entry.sha256 = after_hash_id self.db.WritePathInfos(client_id, [after_path_info_1, after_path_info_2]) @@ -1788,19 +2333,23 @@ def testReadLatestPathInfosMaxTimestampMultiplePaths(self): client_path_2 = db.ClientPath.OS(client_id, ("bar",)) results = self.db.ReadLatestPathInfosWithHashBlobReferences( - [client_path_1, client_path_2], max_timestamp=timestamp) + [client_path_1, client_path_2], max_timestamp=timestamp + ) self.assertLen(results, 2) self.assertEqual(results[client_path_1].hash_entry.sha256, before_hash_id) self.assertEqual(results[client_path_2].hash_entry.sha256, before_hash_id) def testReadLatestPathInfosIncludesStatEntryIfThereIsOneWithSameTimestamp( - self): + self, + ): client_id = db_test_utils.InitializeClient(self.db) hash_id, _ = self._WriteBlobReferences() - path_info = rdf_objects.PathInfo.OS(components=("foo", "bar")) - path_info.stat_entry = rdf_client_fs.StatEntry(st_mode=42) + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, components=("foo", "bar") + ) + path_info.stat_entry.st_mode = 42 path_info.hash_entry.sha256 = hash_id.AsBytes() self.db.WritePathInfos(client_id, [path_info]) @@ -1815,25 +2364,36 @@ def testReadLatestPathInfosIncludesStatEntryIfThereIsOneWithSameTimestamp( def testWriteLongPathInfosWithCommonPrefix(self): client_id = db_test_utils.InitializeClient(self.db) - prefix = ("foobarbaz",) * 303 - quux_components = prefix + ("quux",) - norf_components = prefix + ("norf",) + prefix = ["foobarbaz"] * 303 + quux_components = prefix + ["quux"] + norf_components = prefix + ["norf"] - self.db.WritePathInfos(client_id, [ - rdf_objects.PathInfo.OS(components=quux_components), - rdf_objects.PathInfo.OS(components=norf_components), - ]) + self.db.WritePathInfos( + client_id, + [ + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=quux_components, + ), + objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.OS, + components=norf_components, + ), + ], + ) quux_path_info = self.db.ReadPathInfo( 
client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=quux_components) + path_type=objects_pb2.PathInfo.PathType.OS, + components=quux_components, + ) self.assertEqual(quux_path_info.components, quux_components) norf_path_info = self.db.ReadPathInfo( client_id=client_id, - path_type=rdf_objects.PathInfo.PathType.OS, - components=norf_components) + path_type=objects_pb2.PathInfo.PathType.OS, + components=norf_components, + ) self.assertEqual(norf_path_info.components, norf_components) diff --git a/grr/server/grr_response_server/databases/db_signed_binaries_test.py b/grr/server/grr_response_server/databases/db_signed_binaries_test.py index 2b0d2cbaa5..24eb17a9e6 100644 --- a/grr/server/grr_response_server/databases/db_signed_binaries_test.py +++ b/grr/server/grr_response_server/databases/db_signed_binaries_test.py @@ -46,7 +46,8 @@ def testUpdateSignedBinaryReferences(self): self.assertEqual(stored_references1, _test_references1) self.db.WriteSignedBinaryReferences(_test_id1, _test_references2) stored_references2, timestamp2 = self.db.ReadSignedBinaryReferences( - _test_id1) + _test_id1 + ) self.assertEqual(stored_references2, _test_references2) self.assertGreater(timestamp2, timestamp1) diff --git a/grr/server/grr_response_server/databases/db_test_mixin.py b/grr/server/grr_response_server/databases/db_test_mixin.py index 26673ac8ca..61ca3f5ff3 100644 --- a/grr/server/grr_response_server/databases/db_test_mixin.py +++ b/grr/server/grr_response_server/databases/db_test_mixin.py @@ -77,8 +77,9 @@ def setUp(self): # Set up database before calling super.setUp(), in case any other mixin # depends on db during its setup. db_obj, cleanup = self.CreateDatabase() - patcher = mock.patch.object(data_store, "REL_DB", - db.DatabaseValidationWrapper(db_obj)) + patcher = mock.patch.object( + data_store, "REL_DB", db.DatabaseValidationWrapper(db_obj) + ) patcher.start() self.addCleanup(patcher.stop) diff --git a/grr/server/grr_response_server/databases/db_test_utils.py b/grr/server/grr_response_server/databases/db_test_utils.py index 8d3e5fa767..0f2f52a5fe 100644 --- a/grr/server/grr_response_server/databases/db_test_utils.py +++ b/grr/server/grr_response_server/databases/db_test_utils.py @@ -7,11 +7,12 @@ from typing import Any, Callable, Dict, Iterable, Optional, Text from grr_response_core.lib import rdfvalue +from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import flows_pb2 +from grr_response_proto import hunts_pb2 from grr_response_server.databases import db as abstract_db from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects -from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects -from grr_response_server.rdfvalues import mig_hunt_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr_response_proto.rrg import startup_pb2 as rrg_startup_pb2 @@ -257,9 +258,13 @@ def InitializeFlow( db: abstract_db.Database, client_id: str, flow_id: Optional[str] = None, + flow_state: Optional[rdf_structs.EnumNamedValue] = None, parent_flow_id: Optional[str] = None, parent_hunt_id: Optional[str] = None, - **kwargs, + next_request_to_process: Optional[int] = None, + cpu_time_used: Optional[float] = None, + network_bytes_sent: Optional[int] = None, + creator: Optional[str] = None, ) -> str: """Initializes a test flow. @@ -268,9 +273,13 @@ def InitializeFlow( client_id: A client id of the client to run the flow on. flow_id: A specific flow id to use for initialized flow. 
If none is provided a randomly generated one is used. + flow_state: A flow state (optional). parent_flow_id: Identifier of the parent flow (optional). parent_hunt_id: Identifier of the parent hunt (optional). - **kwargs: Parameters to initialize the flow object with. + next_request_to_process: The next request to process (optional). + cpu_time_used: The used CPU time (optional). + network_bytes_sent: The number of bytes sent (optional). + creator: The creator of the flow (optional). Returns: A flow id of the initialized flow. @@ -279,9 +288,10 @@ def InitializeFlow( random_digit = lambda: random.choice(string.hexdigits).upper() flow_id = "".join(random_digit() for _ in range(16)) - flow_obj = rdf_flow_objects.Flow(**kwargs) - flow_obj.client_id = client_id - flow_obj.flow_id = flow_id + flow_obj = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id) + + if flow_state is not None: + flow_obj.flow_state = flow_state if parent_flow_id is not None: flow_obj.parent_flow_id = parent_flow_id @@ -289,7 +299,19 @@ def InitializeFlow( if parent_hunt_id is not None: flow_obj.parent_hunt_id = parent_hunt_id - db.WriteFlowObject(flow_obj) + if cpu_time_used is not None: + flow_obj.cpu_time_used = cpu_time_used + + if network_bytes_sent is not None: + flow_obj.network_bytes_sent = network_bytes_sent + + if next_request_to_process is not None: + flow_obj.next_request_to_process = next_request_to_process + + if creator is not None: + flow_obj.creator = creator + + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) return flow_id @@ -298,6 +320,8 @@ def InitializeHunt( db: abstract_db.Database, hunt_id: Optional[str] = None, creator: Optional[str] = None, + description: Optional[str] = None, + client_rate: Optional[float] = None, ) -> str: """Initializes a test user. @@ -307,6 +331,8 @@ def InitializeHunt( a randomly generated one is used. creator: A username of the hunt creator. If none is provided a randomly generated one is used (and initialized). + description: A hunt description. + client_rate: The client rate Returns: A hunt id of the initialized hunt. 
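The refactored InitializeFlow test helper above replaces its **kwargs pass-through with explicit keyword parameters, so callers now name exactly which Flow fields they want pre-populated and the helper persists the result via mig_flow_objects.ToProtoFlow. A minimal usage sketch, assuming an abstract_db.Database instance named db; the literal values and the ReadFlowObject read-back are illustrative assumptions, not part of this change:

from grr_response_server.databases import db_test_utils

# Create a client and a flow with explicitly named attributes (no **kwargs).
client_id = db_test_utils.InitializeClient(db)
flow_id = db_test_utils.InitializeFlow(
    db,
    client_id,
    creator="testuser",          # hypothetical username
    next_request_to_process=2,   # hypothetical value
    network_bytes_sent=1024,     # hypothetical value
)

# Read the flow back to confirm the helper persisted the requested fields.
flow_obj = db.ReadFlowObject(client_id, flow_id)
assert flow_obj.creator == "testuser"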
@@ -317,10 +343,15 @@ def InitializeHunt( if creator is None: creator = InitializeUser(db) - hunt_obj = rdf_hunt_objects.Hunt() - hunt_obj.hunt_id = hunt_id - hunt_obj.creator = creator - hunt_obj = mig_hunt_objects.ToProtoHunt(hunt_obj) + hunt_obj = hunts_pb2.Hunt( + hunt_id=hunt_id, + hunt_state=hunts_pb2.Hunt.HuntState.PAUSED, + creator=creator, + ) + if description is not None: + hunt_obj.description = description + if client_rate is not None: + hunt_obj.client_rate = client_rate db.WriteHuntObject(hunt_obj) return hunt_id diff --git a/grr/server/grr_response_server/databases/db_test_utils_test.py b/grr/server/grr_response_server/databases/db_test_utils_test.py index 69f22f5c72..03862a460a 100644 --- a/grr/server/grr_response_server/databases/db_test_utils_test.py +++ b/grr/server/grr_response_server/databases/db_test_utils_test.py @@ -10,15 +10,17 @@ from grr_response_server.rdfvalues import mig_objects -class TestOffsetAndCountTest(db_test_utils.QueryTestHelpersMixin, - absltest.TestCase): +class TestOffsetAndCountTest( + db_test_utils.QueryTestHelpersMixin, absltest.TestCase +): def testDoesNotRaiseWhenWorksAsExpected(self): items = range(10) self.DoOffsetAndCountTest( lambda: items, - lambda offset, count: items[offset:offset + count], - error_desc="foo") + lambda offset, count: items[offset : offset + count], + error_desc="foo", + ) def testRaisesWhenDoesNotWorkAsExpected(self): items = range(10) @@ -28,18 +30,20 @@ def FetchRangeFn(offset, count): if offset > 5: return [] else: - return items[offset:offset + count] + return items[offset : offset + count] with self.assertRaisesRegex( AssertionError, re.escape( - "Results differ from expected (offset 6, count 1, foo): [] vs [6]") + "Results differ from expected (offset 6, count 1, foo): [] vs [6]" + ), ): self.DoOffsetAndCountTest(lambda: items, FetchRangeFn, error_desc="foo") -class TestFilterCombinations(db_test_utils.QueryTestHelpersMixin, - absltest.TestCase): +class TestFilterCombinations( + db_test_utils.QueryTestHelpersMixin, absltest.TestCase +): def testDoesNotRaiseWhenWorkingAsExpected(self): @@ -62,7 +66,8 @@ def FetchFn(bigger_than_3_only=None, less_than_7_only=None, even_only=None): self.DoFilterCombinationsTest( FetchFn, dict(bigger_than_3_only=True, less_than_7_only=True, even_only=True), - error_desc="foo") + error_desc="foo", + ) def testRaisesWhenDoesNotWorkAsExpected(self): @@ -91,23 +96,29 @@ def FetchFn(bigger_than_3_only=None, less_than_7_only=None, even_only=None): re.escape( "Results differ from expected " "({'bigger_than_3_only': True, 'less_than_7_only': True}, foo): " - "[5, 6] vs [4, 5, 6]")): + "[5, 6] vs [4, 5, 6]" + ), + ): self.DoFilterCombinationsTest( FetchFn, dict(bigger_than_3_only=True, less_than_7_only=True, even_only=True), - error_desc="foo") + error_desc="foo", + ) class TestFilterCombinationsAndOffsetCountTest( - db_test_utils.QueryTestHelpersMixin, absltest.TestCase): + db_test_utils.QueryTestHelpersMixin, absltest.TestCase +): def testDoesNotRaiseWhenWorksAsExpected(self): - def FetchFn(offset, - count, - bigger_than_3_only=None, - less_than_7_only=None, - even_only=None): + def FetchFn( + offset, + count, + bigger_than_3_only=None, + less_than_7_only=None, + even_only=None, + ): result = [] for i in range(10): if bigger_than_3_only and i <= 3: @@ -121,20 +132,23 @@ def FetchFn(offset, result.append(i) - return result[offset:offset + count] + return result[offset : offset + count] self.DoFilterCombinationsAndOffsetCountTest( FetchFn, dict(bigger_than_3_only=True, less_than_7_only=True, 
even_only=True), - error_desc="foo") + error_desc="foo", + ) def testRaisesWhenDoesNotWorkAsExpected(self): - def FetchFn(offset, - count, - bigger_than_3_only=None, - less_than_7_only=None, - even_only=None): + def FetchFn( + offset, + count, + bigger_than_3_only=None, + less_than_7_only=None, + even_only=None, + ): del offset # Unused. result = [] @@ -156,13 +170,17 @@ def FetchFn(offset, with self.assertRaisesRegex( AssertionError, - re.escape("Results differ from expected " - "(offset 1, count 1, {'bigger_than_3_only': True}, foo): " - "[4] vs [5]")): + re.escape( + "Results differ from expected " + "(offset 1, count 1, {'bigger_than_3_only': True}, foo): " + "[4] vs [5]" + ), + ): self.DoFilterCombinationsAndOffsetCountTest( FetchFn, dict(bigger_than_3_only=True, less_than_7_only=True, even_only=True), - error_desc="foo") + error_desc="foo", + ) class InitializeClientTest(absltest.TestCase): diff --git a/grr/server/grr_response_server/databases/db_users_test.py b/grr/server/grr_response_server/databases/db_users_test.py index 76bbaf16a8..670ea8a89f 100644 --- a/grr/server/grr_response_server/databases/db_users_test.py +++ b/grr/server/grr_response_server/databases/db_users_test.py @@ -107,7 +107,8 @@ def testReadingMultipleGRRUsersEntriesWorks(self): username="f🧙oo", ui_mode="ADVANCED", canary_mode=True, - user_type=rdf_objects.GRRUser.UserType.USER_TYPE_ADMIN) + user_type=rdf_objects.GRRUser.UserType.USER_TYPE_ADMIN, + ) proto_u = mig_objects.ToProtoGRRUser(u_foo) d.WriteGRRUser( u_foo.username, @@ -258,20 +259,26 @@ def testReadWriteApprovalRequestsWithFilledInUsersEmailsAndGrants(self): self.db.GrantApproval( approval_id=approval_id, requestor_username="requestor", - grantor_username="user_foo") + grantor_username="user_foo", + ) self.db.GrantApproval( approval_id=approval_id, requestor_username="requestor", - grantor_username="user_bar") + grantor_username="user_bar", + ) read_request = d.ReadApprovalRequest("requestor", approval_id) - self.assertCountEqual(approval_request.notified_users, - read_request.notified_users) - self.assertCountEqual(approval_request.email_cc_addresses, - read_request.email_cc_addresses) - self.assertCountEqual([g.grantor_username for g in read_request.grants], - ["user_foo", "user_bar"]) + self.assertCountEqual( + approval_request.notified_users, read_request.notified_users + ) + self.assertCountEqual( + approval_request.email_cc_addresses, read_request.email_cc_addresses + ) + self.assertCountEqual( + [g.grantor_username for g in read_request.grants], + ["user_foo", "user_bar"], + ) def testGrantApprovalAddsNewGrantor(self): d = self.db @@ -322,8 +329,9 @@ def testGrantApprovalAddsMultipleGrantorsWithSameName(self): read_request = d.ReadApprovalRequest("requestor", approval_id) self.assertLen(read_request.grants, 3) - self.assertEqual([g.grantor_username for g in read_request.grants], - ["grantor"] * 3) + self.assertEqual( + [g.grantor_username for g in read_request.grants], ["grantor"] * 3 + ) def testReadApprovalRequeststReturnsNothingWhenNoApprovals(self): d = self.db @@ -359,7 +367,9 @@ def testReadApprovalRequestsReturnsSingleApproval(self): approvals = list( d.ReadApprovalRequests( "requestor", - rdf_objects.ApprovalRequest.ApprovalType.APPROVAL_TYPE_CLIENT)) + rdf_objects.ApprovalRequest.ApprovalType.APPROVAL_TYPE_CLIENT, + ) + ) self.assertLen(approvals, 1) self.assertEqual(approvals[0].approval_id, approval_id) @@ -380,7 +390,8 @@ def testReadApprovalRequestsReturnsMultipleApprovals(self): d.WriteGRRUser("requestor") expiration_time = 
rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) approval_ids = set() for _ in range(10): @@ -450,7 +461,8 @@ def testWriteApprovalRequestSubject(self): with self.subTest(case="Read many with subject", type=approval_type): requests = self.db.ReadApprovalRequests( - "requestor", approval_type, subject_id=subject_id) + "requestor", approval_type, subject_id=subject_id + ) self.assertLen(requests, 1) self.assertEqual(requests[0].subject_id, subject_id) @@ -478,11 +490,13 @@ def testReadApprovalRequestsIncludesGrantsIntoSingleApproval(self): self.db.GrantApproval( approval_id=approval_id, requestor_username="requestor", - grantor_username="grantor1") + grantor_username="grantor1", + ) self.db.GrantApproval( approval_id=approval_id, requestor_username="requestor", - grantor_username="grantor2") + grantor_username="grantor2", + ) approvals = list( d.ReadApprovalRequests( @@ -494,8 +508,10 @@ def testReadApprovalRequestsIncludesGrantsIntoSingleApproval(self): self.assertLen(approvals, 1) self.assertEqual(approvals[0].approval_id, approval_id) - self.assertCountEqual([g.grantor_username for g in approvals[0].grants], - ["grantor1", "grantor2"]) + self.assertCountEqual( + [g.grantor_username for g in approvals[0].grants], + ["grantor1", "grantor2"], + ) def testReadApprovalRequestsIncludesGrantsIntoMultipleResults(self): d = self.db @@ -551,9 +567,11 @@ def testReadApprovalRequestsFiltersOutExpiredApprovals(self): d.WriteGRRUser("requestor") time_future = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) time_past = rdfvalue.RDFDatetime.Now() - rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) non_expired_approval_ids = set() for i in range(10): @@ -582,7 +600,8 @@ def testReadApprovalRequestsFiltersOutExpiredApprovals(self): self.assertLen(approvals, 5) self.assertEqual( - set(a.approval_id for a in approvals), non_expired_approval_ids) + set(a.approval_id for a in approvals), non_expired_approval_ids + ) def testReadApprovalRequestsKeepsExpiredApprovalsWhenAsked(self): d = self.db @@ -590,9 +609,11 @@ def testReadApprovalRequestsKeepsExpiredApprovalsWhenAsked(self): d.WriteGRRUser("requestor") time_future = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) time_past = rdfvalue.RDFDatetime.Now() - rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) approval_ids = set() for i in range(10): @@ -680,7 +701,8 @@ def testReadApprovalRequestsForSubjectReturnsManyNonExpiredApproval(self): d.WriteGRRUser("requestor") expiration_time = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) approval_ids = set() for _ in range(10): @@ -727,11 +749,13 @@ def testReadApprovalRequestsForSubjectIncludesGrantsIntoSingleResult(self): self.db.GrantApproval( requestor_username="requestor", approval_id=approval_id, - grantor_username="grantor1") + grantor_username="grantor1", + ) self.db.GrantApproval( requestor_username="requestor", approval_id=approval_id, - grantor_username="grantor2") + grantor_username="grantor2", + ) approvals = list( d.ReadApprovalRequests( @@ -744,8 +768,10 @@ def testReadApprovalRequestsForSubjectIncludesGrantsIntoSingleResult(self): self.assertLen(approvals, 1) self.assertEqual(approvals[0].approval_id, approval_id) - self.assertCountEqual([g.grantor_username for g in approvals[0].grants], - ["grantor1", "grantor2"]) + self.assertCountEqual( + [g.grantor_username for g in 
approvals[0].grants], + ["grantor1", "grantor2"], + ) def testReadApprovalRequestsForSubjectIncludesGrantsIntoMultipleResults(self): client_id = db_test_utils.InitializeClient(self.db) @@ -803,9 +829,11 @@ def testReadApprovalRequestsForSubjectFiltersOutExpiredApprovals(self): d.WriteGRRUser("requestor") time_future = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) time_past = rdfvalue.RDFDatetime.Now() - rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) non_expired_approval_ids = set() for i in range(10): @@ -833,7 +861,8 @@ def testReadApprovalRequestsForSubjectFiltersOutExpiredApprovals(self): self.assertLen(approvals, 5) self.assertEqual( - set(a.approval_id for a in approvals), non_expired_approval_ids) + set(a.approval_id for a in approvals), non_expired_approval_ids + ) def testReadApprovalRequestsForSubjectKeepsExpiredApprovalsWhenAsked(self): client_id = db_test_utils.InitializeClient(self.db) @@ -842,9 +871,11 @@ def testReadApprovalRequestsForSubjectKeepsExpiredApprovalsWhenAsked(self): d.WriteGRRUser("requestor") time_future = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) time_past = rdfvalue.RDFDatetime.Now() - rdfvalue.Duration.From( - 1, rdfvalue.DAYS) + 1, rdfvalue.DAYS + ) approval_ids = set() for i in range(10): @@ -947,21 +978,28 @@ def testMultipleNotificationsCanBeWrittenAndRead(self): self.assertLen(notifications, 3) self.assertEqual(notifications[0].username, username) - self.assertEqual(notifications[0].notification_type, - NotificationType.TYPE_CLIENT_INTERROGATED) + self.assertEqual( + notifications[0].notification_type, + NotificationType.TYPE_CLIENT_INTERROGATED, + ) self.assertEqual(notifications[0].state, NotificationState.STATE_PENDING) self.assertEqual(notifications[0].message, "Lorem ipsum.") self.assertEqual(notifications[1].username, username) - self.assertEqual(notifications[1].notification_type, - NotificationType.TYPE_CLIENT_APPROVAL_REQUESTED) - self.assertEqual(notifications[1].state, - NotificationState.STATE_NOT_PENDING) + self.assertEqual( + notifications[1].notification_type, + NotificationType.TYPE_CLIENT_APPROVAL_REQUESTED, + ) + self.assertEqual( + notifications[1].state, NotificationState.STATE_NOT_PENDING + ) self.assertEqual(notifications[1].message, "Dolor sit amet.") self.assertEqual(notifications[2].username, username) - self.assertEqual(notifications[2].notification_type, - NotificationType.TYPE_FLOW_RUN_FAILED) + self.assertEqual( + notifications[2].notification_type, + NotificationType.TYPE_FLOW_RUN_FAILED, + ) self.assertEqual(notifications[2].state, NotificationState.STATE_PENDING) self.assertEqual(notifications[2].message, "Consectetur adipiscing elit.") @@ -1105,11 +1143,13 @@ def testReadUserNotificationsWithStateFilter(self): self._SetupUserNotificationTimerangeTest() ns = d.ReadUserNotifications( - username, state=rdf_objects.UserNotification.State.STATE_NOT_PENDING) + username, state=rdf_objects.UserNotification.State.STATE_NOT_PENDING + ) self.assertEmpty(ns) ns = d.ReadUserNotifications( - username, state=rdf_objects.UserNotification.State.STATE_PENDING) + username, state=rdf_objects.UserNotification.State.STATE_PENDING + ) self.assertLen(ns, 2) def testReadUserNotificationsWithStateAndTimerange(self): @@ -1121,13 +1161,15 @@ def testReadUserNotificationsWithStateAndTimerange(self): ns = d.ReadUserNotifications( username, timerange=(ts[0], ts[1]), - state=rdf_objects.UserNotification.State.STATE_NOT_PENDING) + 
state=rdf_objects.UserNotification.State.STATE_NOT_PENDING, + ) self.assertEmpty(ns) ns = d.ReadUserNotifications( username, timerange=(ts[0], ts[1]), - state=rdf_objects.UserNotification.State.STATE_PENDING) + state=rdf_objects.UserNotification.State.STATE_PENDING, + ) self.assertLen(ns, 1) self.assertEqual(ns[0].message, "n0") @@ -1254,11 +1296,13 @@ def testDeleteUserDeletesApprovalGrantsForGrantor(self): self.db.GrantApproval( requestor_username="requestor", approval_id=approval_id, - grantor_username="grantor") + grantor_username="grantor", + ) self.db.GrantApproval( requestor_username="requestor", approval_id=approval_id, - grantor_username="grantor2") + grantor_username="grantor2", + ) d.DeleteGRRUser("grantor") result = d.ReadApprovalRequest("requestor", approval_id) @@ -1270,7 +1314,8 @@ def testDeleteUserDeletesNotifications(self): username = "test" self._SetupUserNotificationTimerangeTest(username) self.assertNotEmpty( - d.ReadUserNotifications(username, timerange=(None, None))) + d.ReadUserNotifications(username, timerange=(None, None)) + ) d.DeleteGRRUser(username) self.assertEmpty(d.ReadUserNotifications(username, timerange=(None, None))) @@ -1312,4 +1357,5 @@ def testCountGRRUsersMultiple(self): self.assertEqual(self.db.CountGRRUsers(), 3) + # This file is a test library and thus does not require a __main__ block. diff --git a/grr/server/grr_response_server/databases/db_utils.py b/grr/server/grr_response_server/databases/db_utils.py index 84fcdc9984..c1981a0c25 100644 --- a/grr/server/grr_response_server/databases/db_utils.py +++ b/grr/server/grr_response_server/databases/db_utils.py @@ -1,9 +1,9 @@ #!/usr/bin/env python """Utility functions/decorators for DB implementations.""" + import functools import logging import time - from typing import Generic from typing import List from typing import Sequence @@ -14,7 +14,6 @@ from google.protobuf import any_pb2 from google.protobuf import wrappers_pb2 from grr_response_core.lib import rdfvalue -from grr_response_core.lib import utils from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_core.lib.util import precondition from grr_response_core.stats import metrics @@ -28,9 +27,11 @@ DB_REQUEST_LATENCY = metrics.Event( "db_request_latency", fields=[("call", str)], - bins=[0.05 * 1.2**x for x in range(30)]) # 50ms to ~10 secs + bins=[0.05 * 1.2**x for x in range(30)], +) # 50ms to ~10 secs DB_REQUEST_ERRORS = metrics.Counter( - "db_request_errors", fields=[("call", str), ("type", str)]) + "db_request_errors", fields=[("call", str), ("type", str)] +) class Error(Exception): @@ -57,8 +58,8 @@ class InvalidTypeURLError(Error): pass -def CallLoggedAndAccounted(f): - """Decorator to log and account for a DB call.""" +def CallLogged(f): + """Decorator used to add automatic logging of the database call.""" @functools.wraps(f) def Decorator(*args, **kwargs): @@ -67,19 +68,37 @@ def Decorator(*args, **kwargs): result = f(*args, **kwargs) latency = time.time() - start_time - DB_REQUEST_LATENCY.RecordEvent(latency, fields=[f.__name__]) logging.debug("DB request %s SUCCESS (%.3fs)", f.__name__, latency) return result except db.Error as e: - DB_REQUEST_ERRORS.Increment(fields=[f.__name__, "grr"]) - logging.debug("DB request %s GRR ERROR: %s", f.__name__, - utils.SmartUnicode(e)) + logging.debug("DB request %s GRR ERROR: %s", f.__name__, e) raise except Exception as e: + logging.debug("DB request %s INTERNAL DB ERROR : %s", f.__name__, e) + raise + + return Decorator + + +def CallAccounted(f): + """Decorator used to 
add automatic metric accounting of the database call.""" + + @functools.wraps(f) + def Decorator(*args, **kwargs): + try: + start_time = time.time() + result = f(*args, **kwargs) + latency = time.time() - start_time + + DB_REQUEST_LATENCY.RecordEvent(latency, fields=[f.__name__]) + + return result + except db.Error: + DB_REQUEST_ERRORS.Increment(fields=[f.__name__, "grr"]) + raise + except Exception: DB_REQUEST_ERRORS.Increment(fields=[f.__name__, "db"]) - logging.debug("DB request %s INTERNAL DB ERROR : %s", f.__name__, - utils.SmartUnicode(e)) raise return Decorator @@ -202,8 +221,9 @@ def MicrosToSeconds(ms): return ms / 1e6 -def ParseAndUnpackAny(payload_type: str, - payload_bytes: bytes) -> rdf_structs.RDFProtoStruct: +def ParseAndUnpackAny( + payload_type: str, payload_bytes: bytes +) -> rdf_structs.RDFProtoStruct: """Parses a google.protobuf.Any payload and unpack it into RDFProtoStruct. Args: @@ -227,7 +247,8 @@ def ParseAndUnpackAny(payload_type: str, if payload_type not in rdfvalue.RDFValue.classes: return rdf_objects.SerializedValueOfUnrecognizedType( - type_name=payload_type, value=payload_value_bytes) + type_name=payload_type, value=payload_value_bytes + ) rdf_class = rdfvalue.RDFValue.classes[payload_type] return rdf_class.FromSerializedBytes(payload_value_bytes) @@ -397,7 +418,9 @@ def RDFTypeNameToTypeURL(rdf_type_name: str) -> str: return f"type.googleapis.com/grr.{rdf_type_name}" -_BYTES_VALUE_TYPE_URL = f"type.googleapis.com/{wrappers_pb2.BytesValue.DESCRIPTOR.full_name}" +_BYTES_VALUE_TYPE_URL = ( + f"type.googleapis.com/{wrappers_pb2.BytesValue.DESCRIPTOR.full_name}" +) _RDF_TYPE_NAME_BY_WRAPPER_TYPE_NAME = { diff --git a/grr/server/grr_response_server/databases/db_utils_test.py b/grr/server/grr_response_server/databases/db_utils_test.py index 174fbd476d..5b6373eb90 100644 --- a/grr/server/grr_response_server/databases/db_utils_test.py +++ b/grr/server/grr_response_server/databases/db_utils_test.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import logging -from unittest import mock +from typing import Tuple from absl import app from absl.testing import absltest @@ -90,9 +90,12 @@ def testSingleOperationLowerThanLimit(self): batch_planner = db_utils.BatchPlanner(10) batch_planner.PlanOperation("a", 9) - self.assertEqual(batch_planner.batches, [ - [("a", 0, 9)], - ]) + self.assertEqual( + batch_planner.batches, + [ + [("a", 0, 9)], + ], + ) def testMultipleOperationsLowerThanLimitInTotal(self): batch_planner = db_utils.BatchPlanner(10) @@ -100,28 +103,37 @@ def testMultipleOperationsLowerThanLimitInTotal(self): batch_planner.PlanOperation("b", 3) batch_planner.PlanOperation("c", 3) - self.assertEqual(batch_planner.batches, [ - [("a", 0, 3), ("b", 0, 3), ("c", 0, 3)], - ]) + self.assertEqual( + batch_planner.batches, + [ + [("a", 0, 3), ("b", 0, 3), ("c", 0, 3)], + ], + ) def testSingleOperationBiggerThanLimit(self): batch_planner = db_utils.BatchPlanner(10) batch_planner.PlanOperation("a", 12) - self.assertEqual(batch_planner.batches, [ - [("a", 0, 10)], - [("a", 10, 2)], - ]) + self.assertEqual( + batch_planner.batches, + [ + [("a", 0, 10)], + [("a", 10, 2)], + ], + ) def testSingleOperationMoreThanTwiceBiggerThanLimit(self): batch_planner = db_utils.BatchPlanner(10) batch_planner.PlanOperation("a", 22) - self.assertEqual(batch_planner.batches, [ - [("a", 0, 10)], - [("a", 10, 10)], - [("a", 20, 2)], - ]) + self.assertEqual( + batch_planner.batches, + [ + [("a", 0, 10)], + [("a", 10, 10)], + [("a", 20, 2)], + ], + ) def testMultipleOperationsBiggerThanLimitInTotal(self): 
batch_planner = db_utils.BatchPlanner(10) @@ -130,10 +142,13 @@ def testMultipleOperationsBiggerThanLimitInTotal(self): batch_planner.PlanOperation("c", 3) batch_planner.PlanOperation("d", 3) - self.assertEqual(batch_planner.batches, [ - [("a", 0, 3), ("b", 0, 3), ("c", 0, 3), ("d", 0, 1)], - [("d", 1, 2)], - ]) + self.assertEqual( + batch_planner.batches, + [ + [("a", 0, 3), ("b", 0, 3), ("c", 0, 3), ("d", 0, 1)], + [("d", 1, 2)], + ], + ) def testMultipleOperationsTwiceBiggerThanLimitInTotal(self): batch_planner = db_utils.BatchPlanner(10) @@ -145,57 +160,70 @@ def testMultipleOperationsTwiceBiggerThanLimitInTotal(self): batch_planner.PlanOperation("f", 3) batch_planner.PlanOperation("g", 3) - self.assertEqual(batch_planner.batches, [ - [("a", 0, 3), ("b", 0, 3), ("c", 0, 3), ("d", 0, 1)], - [("d", 1, 2), ("e", 0, 3), ("f", 0, 3), ("g", 0, 2)], - [("g", 2, 1)], - ]) + self.assertEqual( + batch_planner.batches, + [ + [("a", 0, 3), ("b", 0, 3), ("c", 0, 3), ("d", 0, 1)], + [("d", 1, 2), ("e", 0, 3), ("f", 0, 3), ("g", 0, 2)], + [("g", 2, 1)], + ], + ) def testMultipleOperationsEachBiggerThanLimit(self): batch_planner = db_utils.BatchPlanner(10) batch_planner.PlanOperation("a", 12) batch_planner.PlanOperation("b", 12) - self.assertEqual(batch_planner.batches, [ - [("a", 0, 10)], - [("a", 10, 2), ("b", 0, 8)], - [("b", 8, 4)], - ]) + self.assertEqual( + batch_planner.batches, + [ + [("a", 0, 10)], + [("a", 10, 2), ("b", 0, 8)], + [("b", 8, 4)], + ], + ) -class CallLoggedAndAccountedTest(stats_test_lib.StatsTestMixin, - absltest.TestCase): +class CallAccountedTest(stats_test_lib.StatsTestMixin, absltest.TestCase): - @db_utils.CallLoggedAndAccounted + @db_utils.CallAccounted def SampleCall(self): return 42 - @db_utils.CallLoggedAndAccounted + @db_utils.CallAccounted def SampleCallWithGRRError(self): raise db.UnknownGRRUserError("Unknown") - @db_utils.CallLoggedAndAccounted + @db_utils.CallAccounted def SampleCallWithDBError(self): raise RuntimeError("some") def testReturnValueIsPropagated(self): self.assertEqual(self.SampleCall(), 42) - def _ExpectIncrements(self, fn, latency_count_increment, - grr_errors_count_increment, db_errors_count_increment): + def _ExpectIncrements( + self, + fn, + latency_count_increment, + grr_errors_count_increment, + db_errors_count_increment, + ): with self.assertStatsCounterDelta( latency_count_increment, db_utils.DB_REQUEST_LATENCY, - fields=[fn.__name__]): + fields=[fn.__name__], + ): with self.assertStatsCounterDelta( grr_errors_count_increment, db_utils.DB_REQUEST_ERRORS, - fields=[fn.__name__, "grr"]): + fields=[fn.__name__, "grr"], + ): with self.assertStatsCounterDelta( db_errors_count_increment, db_utils.DB_REQUEST_ERRORS, - fields=[fn.__name__, "db"]): + fields=[fn.__name__, "db"], + ): try: fn() except Exception: # pylint: disable=broad-except @@ -210,34 +238,80 @@ def testCallRaisingLogicalErrorIsCorretlyAccounted(self): def testCallRaisingRuntimeDBErrorIsCorretlyAccounted(self): self._ExpectIncrements(self.SampleCallWithDBError, 0, 0, 1) - @mock.patch.object(logging, "debug") - def testSuccessfulCallIsCorretlyLogged(self, debug_mock): - self.SampleCall() - self.assertTrue(debug_mock.called) - got = debug_mock.call_args_list[0][0] - self.assertIn("SUCCESS", got[0]) - self.assertEqual(got[1], "SampleCall") +class CallLoggedTest(absltest.TestCase): + + def setUp(self): + super().setUp() + + logger = logging.getLogger() + + class Handler(logging.Handler): + + def __init__(self): + super().__init__() + self.logs: list[logging.LogRecord] = [] + + def 
emit(self, record: logging.LogRecord): + self.logs.append(record) + + # We create our own log handler that stores all records in a simple list. + self.handler = Handler() + logger.addHandler(self.handler) + self.addCleanup(lambda: logger.removeHandler(self.handler)) + + # We adjust log level to `DEBUG` to catch all logs. + old_log_level = logger.level + logger.setLevel(logging.DEBUG) + self.addCleanup(lambda: logger.setLevel(old_log_level)) + + def testArgsAndResultPropagated(self): + @db_utils.CallLogged + def SampleCall(arg: int, kwarg: int = 0) -> Tuple[int, int]: + return (arg, kwarg) + + self.assertEqual(SampleCall(42, 1337), (42, 1337)) + + def testCallSuccessLogged(self): + @db_utils.CallLogged + def SampleCall(): + return 42 + + SampleCall() + + self.assertLen(self.handler.logs, 1) + + message = self.handler.logs[0].getMessage() + self.assertIn("SUCCESS", message) + self.assertIn("SampleCall", message) + + def testCallRaisedDBErrorLogged(self): + @db_utils.CallLogged + def SampleCallWithDBError(): + raise db.UnknownGRRUserError("Unknown") - @mock.patch.object(logging, "debug") - def testCallRaisingLogicalErrorIsCorretlyLogged(self, debug_mock): with self.assertRaises(db.UnknownGRRUserError): - self.SampleCallWithGRRError() + SampleCallWithDBError() + + self.assertLen(self.handler.logs, 1) - self.assertTrue(debug_mock.called) - got = debug_mock.call_args_list[0][0] - self.assertIn("GRR ERROR", got[0]) - self.assertEqual(got[1], "SampleCallWithGRRError") + message = self.handler.logs[0].getMessage() + self.assertIn("GRR ERROR", message) + self.assertIn("SampleCallWithDBError", message) + + def testCallRaisedGenericErrorLogged(self): + @db_utils.CallLogged + def SampleCallWithGenericError(): + raise RuntimeError() - @mock.patch.object(logging, "debug") - def testCallRaisingRuntimeDBErrorIsCorretlyLogged(self, debug_mock): with self.assertRaises(RuntimeError): - self.SampleCallWithDBError() + SampleCallWithGenericError() + + self.assertLen(self.handler.logs, 1) - self.assertTrue(debug_mock.called) - got = debug_mock.call_args_list[0][0] - self.assertIn("INTERNAL DB ERROR", got[0]) - self.assertEqual(got[1], "SampleCallWithDBError") + message = self.handler.logs[0].getMessage() + self.assertIn("INTERNAL DB ERROR", message) + self.assertIn("SampleCallWithGenericError", message) class IdToIntConversionTest(absltest.TestCase): @@ -248,7 +322,8 @@ def testFlowIdToInt(self): self.assertEqual(db_utils.FlowIDToInt("FFFFFFFF"), 0xFFFFFFFF) self.assertEqual(db_utils.FlowIDToInt("0000000100000000"), 0x100000000) self.assertEqual( - db_utils.FlowIDToInt("FFFFFFFFFFFFFFFF"), 0xFFFFFFFFFFFFFFFF) + db_utils.FlowIDToInt("FFFFFFFFFFFFFFFF"), 0xFFFFFFFFFFFFFFFF + ) def testIntToFlowId(self): self.assertEqual(db_utils.IntToFlowID(1), "00000001") @@ -256,7 +331,8 @@ def testIntToFlowId(self): self.assertEqual(db_utils.IntToFlowID(0xFFFFFFFF), "FFFFFFFF") self.assertEqual(db_utils.IntToFlowID(0x100000000), "0000000100000000") self.assertEqual( - db_utils.IntToFlowID(0xFFFFFFFFFFFFFFFF), "FFFFFFFFFFFFFFFF") + db_utils.IntToFlowID(0xFFFFFFFFFFFFFFFF), "FFFFFFFFFFFFFFFF" + ) def testHuntIdToInt(self): self.assertEqual(db_utils.HuntIDToInt("00000001"), 1) @@ -264,7 +340,8 @@ def testHuntIdToInt(self): self.assertEqual(db_utils.HuntIDToInt("FFFFFFFF"), 0xFFFFFFFF) self.assertEqual(db_utils.HuntIDToInt("0000000100000000"), 0x100000000) self.assertEqual( - db_utils.HuntIDToInt("FFFFFFFFFFFFFFFF"), 0xFFFFFFFFFFFFFFFF) + db_utils.HuntIDToInt("FFFFFFFFFFFFFFFF"), 0xFFFFFFFFFFFFFFFF + ) def testIntToHuntId(self): 
self.assertEqual(db_utils.IntToHuntID(1), "00000001") @@ -272,7 +349,8 @@ def testIntToHuntId(self): self.assertEqual(db_utils.IntToHuntID(0xFFFFFFFF), "FFFFFFFF") self.assertEqual(db_utils.IntToHuntID(0x100000000), "0000000100000000") self.assertEqual( - db_utils.IntToHuntID(0xFFFFFFFFFFFFFFFF), "FFFFFFFFFFFFFFFF") + db_utils.IntToHuntID(0xFFFFFFFFFFFFFFFF), "FFFFFFFFFFFFFFFF" + ) class ParseAndUnpackAnyTest(absltest.TestCase): diff --git a/grr/server/grr_response_server/databases/mem_artifacts_test.py b/grr/server/grr_response_server/databases/mem_artifacts_test.py index 1b353466c3..5424462a41 100644 --- a/grr/server/grr_response_server/databases/mem_artifacts_test.py +++ b/grr/server/grr_response_server/databases/mem_artifacts_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBArtifactsTest(db_artifacts_test.DatabaseTestArtifactsMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBArtifactsTest( + db_artifacts_test.DatabaseTestArtifactsMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_blob_keys.py b/grr/server/grr_response_server/databases/mem_blob_keys.py index de02dc37b4..8f5a33fd12 100644 --- a/grr/server/grr_response_server/databases/mem_blob_keys.py +++ b/grr/server/grr_response_server/databases/mem_blob_keys.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """The in-memory database methods for blob encryption keys.""" + from typing import Collection from typing import Dict from typing import Optional @@ -9,6 +10,7 @@ class InMemoryDBBlobKeysMixin: """A mixin proving blob encryption key methods for in-memory database.""" + blob_keys: Dict[blobs.BlobID, str] def WriteBlobEncryptionKeys( diff --git a/grr/server/grr_response_server/databases/mem_blob_references_test.py b/grr/server/grr_response_server/databases/mem_blob_references_test.py index fc5bdded2b..8fd875d3ae 100644 --- a/grr/server/grr_response_server/databases/mem_blob_references_test.py +++ b/grr/server/grr_response_server/databases/mem_blob_references_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -10,7 +9,9 @@ class MemoryDBBlobReferencesTest( db_blob_references_test.DatabaseTestBlobReferencesMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_blobs.py b/grr/server/grr_response_server/databases/mem_blobs.py index 262dbe2ffe..112fe2da83 100644 --- a/grr/server/grr_response_server/databases/mem_blobs.py +++ b/grr/server/grr_response_server/databases/mem_blobs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """DB mixin for blobs-related methods.""" + from typing import Collection, Mapping, Optional from grr_response_core.lib import utils diff --git a/grr/server/grr_response_server/databases/mem_blobs_test.py b/grr/server/grr_response_server/databases/mem_blobs_test.py index 14249d05ce..3bd9e9597e 100644 --- a/grr/server/grr_response_server/databases/mem_blobs_test.py +++ b/grr/server/grr_response_server/databases/mem_blobs_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for the memory-based blob store.""" - from absl import app from absl.testing import absltest @@ -11,8 +10,11 @@ from grr.test_lib import test_lib -class MemoryDBBlobStoreTest(blob_store_test_mixin.BlobStoreTestMixin, - mem_test_base.MemoryDBTestBase, 
absltest.TestCase): +class MemoryDBBlobStoreTest( + blob_store_test_mixin.BlobStoreTestMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): def CreateBlobStore(self): db, db_cleanup_fn = self.CreateDatabase() diff --git a/grr/server/grr_response_server/databases/mem_clients.py b/grr/server/grr_response_server/databases/mem_clients.py index 2ee59da0fa..96081103ee 100644 --- a/grr/server/grr_response_server/databases/mem_clients.py +++ b/grr/server/grr_response_server/databases/mem_clients.py @@ -1,6 +1,7 @@ #!/usr/bin/env python """The in memory database methods for client handling.""" -from typing import Collection, Mapping, Optional, Sequence, Tuple, TypedDict + +from typing import Collection, Iterator, Mapping, Optional, Sequence, Tuple, TypedDict from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils @@ -42,7 +43,6 @@ class InMemoryDBClientMixin(object): def MultiWriteClientMetadata( self, client_ids: Collection[str], - certificate: Optional[rdf_crypto.RDFX509Cert] = None, first_seen: Optional[rdfvalue.RDFDatetime] = None, last_ping: Optional[rdfvalue.RDFDatetime] = None, last_clock: Optional[rdfvalue.RDFDatetime] = None, @@ -52,9 +52,6 @@ def MultiWriteClientMetadata( ) -> None: """Writes metadata about the clients.""" md = {} - if certificate is not None: - md["certificate"] = certificate - if first_seen is not None: md["first_seen"] = first_seen @@ -211,10 +208,12 @@ def MultiReadClientFullInfo( return res @utils.Synchronized - def ReadClientLastPings(self, - min_last_ping=None, - max_last_ping=None, - batch_size=db.CLIENT_IDS_BATCH_SIZE): + def ReadClientLastPings( + self, + min_last_ping: Optional[rdfvalue.RDFDatetime] = None, + max_last_ping: Optional[rdfvalue.RDFDatetime] = None, + batch_size: int = db.CLIENT_IDS_BATCH_SIZE, + ) -> Iterator[Mapping[str, Optional[rdfvalue.RDFDatetime]]]: """Yields dicts of last-ping timestamps for clients in the DB.""" last_pings = {} for client_id, metadata in self.metadatas.items(): @@ -501,10 +500,13 @@ def DeleteClient( for kw in self.keywords: self.keywords[kw].pop(client_id, None) - def StructuredSearchClients(self, expression: rdf_search.SearchExpression, - sort_order: rdf_search.SortOrder, - continuation_token: bytes, - number_of_results: int) -> db.SearchClientsResult: + def StructuredSearchClients( + self, + expression: rdf_search.SearchExpression, + sort_order: rdf_search.SortOrder, + continuation_token: bytes, + number_of_results: int, + ) -> db.SearchClientsResult: # Unused arguments del self, expression, sort_order, continuation_token, number_of_results raise NotImplementedError diff --git a/grr/server/grr_response_server/databases/mem_clients_test.py b/grr/server/grr_response_server/databases/mem_clients_test.py index 8e2f3dc05a..bb0779d233 100644 --- a/grr/server/grr_response_server/databases/mem_clients_test.py +++ b/grr/server/grr_response_server/databases/mem_clients_test.py @@ -7,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBClientsTest(db_clients_test.DatabaseTestClientsMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBClientsTest( + db_clients_test.DatabaseTestClientsMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_cronjob_test.py b/grr/server/grr_response_server/databases/mem_cronjob_test.py index 540d0fea38..f04fbbbc7e 100644 --- a/grr/server/grr_response_server/databases/mem_cronjob_test.py +++ b/grr/server/grr_response_server/databases/mem_cronjob_test.py @@ -1,5 
+1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBCronJobTest(db_cronjob_test.DatabaseTestCronJobMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBCronJobTest( + db_cronjob_test.DatabaseTestCronJobMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_cronjobs.py b/grr/server/grr_response_server/databases/mem_cronjobs.py index 6f368e23b2..9f725bdc5e 100644 --- a/grr/server/grr_response_server/databases/mem_cronjobs.py +++ b/grr/server/grr_response_server/databases/mem_cronjobs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """The in memory database methods for cron job handling.""" + from typing import Optional, Sequence, Tuple from grr_response_core.lib import rdfvalue @@ -56,30 +57,32 @@ def ReadCronJobs( return res @utils.Synchronized - def UpdateCronJob(self, - cronjob_id, - last_run_status=db.Database.unchanged, - last_run_time=db.Database.unchanged, - current_run_id=db.Database.unchanged, - state=db.Database.unchanged, - forced_run_requested=db.Database.unchanged): + def UpdateCronJob( + self, + cronjob_id, + last_run_status=db.Database.UNCHANGED, + last_run_time=db.Database.UNCHANGED, + current_run_id=db.Database.UNCHANGED, + state=db.Database.UNCHANGED, + forced_run_requested=db.Database.UNCHANGED, + ): """Updates run information for an existing cron job.""" job = self.cronjobs.get(cronjob_id) if job is None: raise db.UnknownCronJobError(f"Cron job {cronjob_id} not known.") - if last_run_status != db.Database.unchanged: + if last_run_status != db.Database.UNCHANGED: job.last_run_status = last_run_status - if last_run_time != db.Database.unchanged: + if last_run_time != db.Database.UNCHANGED: job.last_run_time = last_run_time.AsMicrosecondsSinceEpoch() - if current_run_id != db.Database.unchanged: + if current_run_id != db.Database.UNCHANGED: if current_run_id is None: job.ClearField("current_run_id") else: job.current_run_id = current_run_id - if state != db.Database.unchanged: + if state != db.Database.UNCHANGED: job.state.CopyFrom(state) - if forced_run_requested != db.Database.unchanged: + if forced_run_requested != db.Database.UNCHANGED: job.forced_run_requested = forced_run_requested @utils.Synchronized diff --git a/grr/server/grr_response_server/databases/mem_events.py b/grr/server/grr_response_server/databases/mem_events.py index 54319b23cc..f81d2e121f 100644 --- a/grr/server/grr_response_server/databases/mem_events.py +++ b/grr/server/grr_response_server/databases/mem_events.py @@ -2,36 +2,49 @@ """The in memory database methods for event handling.""" import collections +from typing import Optional from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils -from grr_response_server.rdfvalues import objects as rdf_objects +from grr_response_proto import objects_pb2 class InMemoryDBEventMixin(object): """InMemoryDB mixin for event handling.""" + api_audit_entries: list[objects_pb2.APIAuditEntry] + @utils.Synchronized - def ReadAPIAuditEntries(self, - username=None, - router_method_names=None, - min_timestamp=None, - max_timestamp=None): + def ReadAPIAuditEntries( + self, + username: Optional[str] = None, + router_method_names: Optional[list[str]] = None, + min_timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> list[objects_pb2.APIAuditEntry]: """Returns audit entries stored in the database.""" 
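# Minimal sketch of the UNCHANGED sentinel pattern used by UpdateCronJob
# above (the names below are illustrative, not the GRR API): a dedicated
# sentinel object distinguishes "argument omitted" from an explicit None, so
# passing None clears a stored value while omitted arguments leave it alone.
_UNCHANGED = object()


def _UpdateJob(job: dict, current_run_id=_UNCHANGED) -> None:
  if current_run_id is not _UNCHANGED:
    if current_run_id is None:
      job.pop("current_run_id", None)  # explicit None clears the field
    else:
      job["current_run_id"] = current_run_id


job = {"current_run_id": 42}
_UpdateJob(job)                       # argument omitted: value stays 42
_UpdateJob(job, current_run_id=None)  # explicit None: field is removed
assert "current_run_id" not in job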
results = [] for entry in self.api_audit_entries: - if username is not None and entry.username != username: + if username and entry.username != username: continue - if (router_method_names and - entry.router_method_name not in router_method_names): + if ( + router_method_names + and entry.router_method_name not in router_method_names + ): continue - if min_timestamp is not None and entry.timestamp < min_timestamp: + if ( + min_timestamp is not None + and entry.timestamp < min_timestamp.AsMicrosecondsSinceEpoch() + ): continue - if max_timestamp is not None and entry.timestamp > max_timestamp: + if ( + max_timestamp is not None + and entry.timestamp > max_timestamp.AsMicrosecondsSinceEpoch() + ): continue results.append(entry) @@ -39,28 +52,38 @@ def ReadAPIAuditEntries(self, return sorted(results, key=lambda entry: entry.timestamp) @utils.Synchronized - def CountAPIAuditEntriesByUserAndDay(self, - min_timestamp=None, - max_timestamp=None): + def CountAPIAuditEntriesByUserAndDay( + self, + min_timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> dict[tuple[str, rdfvalue.RDFDatetime], int]: """Returns audit entry counts grouped by user and calendar day.""" results = collections.Counter() for entry in self.api_audit_entries: - if min_timestamp is not None and entry.timestamp < min_timestamp: + if ( + min_timestamp is not None + and entry.timestamp < min_timestamp.AsMicrosecondsSinceEpoch() + ): continue - if max_timestamp is not None and entry.timestamp > max_timestamp: + if ( + max_timestamp is not None + and entry.timestamp > max_timestamp.AsMicrosecondsSinceEpoch() + ): continue # Truncate DateTime by removing the time-part to allow grouping by date. - day = rdfvalue.RDFDatetime.FromDate(entry.timestamp.AsDatetime().date()) + rdf_dt = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(entry.timestamp) + day = rdfvalue.RDFDatetime.FromDate(rdf_dt.AsDatetime().date()) results[(entry.username, day)] += 1 return dict(results) @utils.Synchronized - def WriteAPIAuditEntry(self, entry: rdf_objects.APIAuditEntry): + def WriteAPIAuditEntry(self, entry: objects_pb2.APIAuditEntry) -> None: """Writes an audit entry to the database.""" - copy = entry.Copy() - if copy.timestamp is None: - copy.timestamp = rdfvalue.RDFDatetime.Now() + copy = objects_pb2.APIAuditEntry() + copy.CopyFrom(entry) + if not copy.HasField("timestamp"): + copy.timestamp = rdfvalue.RDFDatetime.Now().AsMicrosecondsSinceEpoch() self.api_audit_entries.append(copy) diff --git a/grr/server/grr_response_server/databases/mem_events_test.py b/grr/server/grr_response_server/databases/mem_events_test.py index 10274804ae..c179a5d04f 100644 --- a/grr/server/grr_response_server/databases/mem_events_test.py +++ b/grr/server/grr_response_server/databases/mem_events_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBEventsTest(db_events_test.DatabaseTestEventsMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBEventsTest( + db_events_test.DatabaseTestEventsMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_flows.py b/grr/server/grr_response_server/databases/mem_flows.py index e0ded02c6e..f80d52af77 100644 --- a/grr/server/grr_response_server/databases/mem_flows.py +++ b/grr/server/grr_response_server/databases/mem_flows.py @@ -6,6 +6,8 @@ import sys import threading import 
time +from typing import Callable +from typing import Collection from typing import Dict from typing import Iterable from typing import List @@ -16,16 +18,17 @@ from typing import Text from typing import Tuple from typing import TypeVar +from typing import Union from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils -from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_proto import flows_pb2 +from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 from grr_response_server.databases import db from grr_response_server.databases import db_utils -from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects +from grr_response_server.rdfvalues import mig_flow_objects T = TypeVar("T") @@ -33,6 +36,8 @@ ClientID = NewType("ClientID", str) FlowID = NewType("FlowID", str) Username = NewType("Username", str) +HandlerName = NewType("HandlerName", str) +RequestID = NewType("RequestID", int) class Error(Exception): @@ -46,42 +51,62 @@ class TimeOutWhileWaitingForFlowsToBeProcessedError(Error): class InMemoryDBFlowMixin(object): """InMemoryDB mixin for flow handling.""" + flows: Dict[Tuple[ClientID, FlowID], flows_pb2.Flow] flow_results: Dict[Tuple[str, str], List[flows_pb2.FlowResult]] flow_errors: Dict[Tuple[str, str], List[flows_pb2.FlowError]] flow_log_entries: Dict[Tuple[str, str], List[flows_pb2.FlowLogEntry]] flow_output_plugin_log_entries: Dict[ Tuple[str, str], List[flows_pb2.FlowOutputPluginLogEntry] ] + flow_responses: Dict[Tuple[str, str], List[flows_pb2.FlowResponse]] + flow_requests: Dict[Tuple[str, str], List[flows_pb2.FlowRequest]] scheduled_flows: dict[ tuple[ClientID, Username, FlowID], flows_pb2.ScheduledFlow ] + message_handler_requests: Dict[ + HandlerName, Dict[RequestID, objects_pb2.MessageHandlerRequest] + ] + message_handler_leases: Dict[ + HandlerName, Dict[RequestID, int] # lease expiration time in us + ] + @utils.Synchronized - def WriteMessageHandlerRequests(self, requests): + def WriteMessageHandlerRequests( + self, requests: Iterable[objects_pb2.MessageHandlerRequest] + ) -> None: """Writes a list of message handler requests to the database.""" - now = rdfvalue.RDFDatetime.Now() + now = rdfvalue.RDFDatetime.Now().AsMicrosecondsSinceEpoch() for r in requests: flow_dict = self.message_handler_requests.setdefault(r.handler_name, {}) - cloned_request = r.Copy() + cloned_request = objects_pb2.MessageHandlerRequest() + cloned_request.CopyFrom(r) cloned_request.timestamp = now flow_dict[cloned_request.request_id] = cloned_request @utils.Synchronized - def ReadMessageHandlerRequests(self): + def ReadMessageHandlerRequests( + self, + ) -> Sequence[objects_pb2.MessageHandlerRequest]: """Reads all message handler requests from the database.""" res = [] leases = self.message_handler_leases for requests in self.message_handler_requests.values(): for r in requests.values(): - res.append(r.Copy()) + res.append(r) existing_lease = leases.get(r.handler_name, {}).get(r.request_id, None) - res[-1].leased_until = existing_lease + if existing_lease is not None: + res[-1].leased_until = existing_lease + else: + res[-1].ClearField("leased_until") return sorted(res, key=lambda r: r.timestamp, reverse=True) @utils.Synchronized - def DeleteMessageHandlerRequests(self, requests): + def DeleteMessageHandlerRequests( + self, requests: Iterable[objects_pb2.MessageHandlerRequest] + ) -> None: """Deletes a list of message handler 
requests from the database.""" for r in requests: @@ -92,7 +117,12 @@ def DeleteMessageHandlerRequests(self, requests): if r.request_id in flow_dict: del flow_dict[r.request_id] - def RegisterMessageHandler(self, handler, lease_time, limit=1000): + def RegisterMessageHandler( + self, + handler: Callable[[Sequence[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: """Leases a number of message handler requests up to the indicated limit.""" self.UnregisterMessageHandler() @@ -100,11 +130,14 @@ def RegisterMessageHandler(self, handler, lease_time, limit=1000): self.handler_thread = threading.Thread( name="message_handler", target=self._MessageHandlerLoop, - args=(handler, lease_time, limit)) + args=(handler, lease_time, limit), + ) self.handler_thread.daemon = True self.handler_thread.start() - def UnregisterMessageHandler(self, timeout=None): + def UnregisterMessageHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: """Unregisters any registered message handler.""" if self.handler_thread: self.handler_stop = True @@ -113,7 +146,13 @@ def UnregisterMessageHandler(self, timeout=None): raise RuntimeError("Message handler thread did not join in time.") self.handler_thread = None - def _MessageHandlerLoop(self, handler, lease_time, limit): + def _MessageHandlerLoop( + self, + handler: Callable[[Iterable[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: + """Loop to handle outstanding requests.""" while not self.handler_stop: try: msgs = self._LeaseMessageHandlerRequests(lease_time, limit) @@ -125,21 +164,28 @@ def _MessageHandlerLoop(self, handler, lease_time, limit): logging.exception("_LeaseMessageHandlerRequests raised %s.", e) @utils.Synchronized - def _LeaseMessageHandlerRequests(self, lease_time, limit): + def _LeaseMessageHandlerRequests( + self, + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: """Read and lease some outstanding message handler requests.""" leased_requests = [] now = rdfvalue.RDFDatetime.Now() - zero = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0) + now_us = now.AsMicrosecondsSinceEpoch() expiration_time = now + lease_time + expiration_time_us = expiration_time.AsMicrosecondsSinceEpoch() leases = self.message_handler_leases for requests in self.message_handler_requests.values(): for r in requests.values(): - existing_lease = leases.get(r.handler_name, {}).get(r.request_id, zero) - if existing_lease < now: - leases.setdefault(r.handler_name, {})[r.request_id] = expiration_time - r.leased_until = expiration_time + existing_lease = leases.get(r.handler_name, {}).get(r.request_id, 0) + if existing_lease < now_us: + leases.setdefault(r.handler_name, {})[ + r.request_id + ] = expiration_time_us + r.leased_until = expiration_time_us r.leased_by = utils.ProcessIdString() leased_requests.append(r) if len(leased_requests) >= limit: @@ -148,7 +194,11 @@ def _LeaseMessageHandlerRequests(self, lease_time, limit): return leased_requests @utils.Synchronized - def WriteFlowObject(self, flow_obj, allow_update=True): + def WriteFlowObject( + self, + flow_obj: flows_pb2.Flow, + allow_update: bool = True, + ) -> None: """Writes a flow object to the database.""" if flow_obj.client_id not in self.metadatas: raise db.UnknownClientError(flow_obj.client_id) @@ -158,19 +208,20 @@ def WriteFlowObject(self, flow_obj, allow_update=True): if not allow_update and key in self.flows: raise db.FlowExistsError(flow_obj.client_id, flow_obj.flow_id) - 
now = rdfvalue.RDFDatetime.Now() + now = rdfvalue.RDFDatetime.Now().AsMicrosecondsSinceEpoch() - clone = flow_obj.Copy() + clone = flows_pb2.Flow() + clone.CopyFrom(flow_obj) clone.last_update_time = now clone.create_time = now self.flows[key] = clone @utils.Synchronized - def ReadFlowObject(self, client_id, flow_id): + def ReadFlowObject(self, client_id: str, flow_id: str) -> flows_pb2.Flow: """Reads a flow object from the database.""" try: - return self.flows[(client_id, flow_id)].Copy() + return self.flows[(client_id, flow_id)] except KeyError: raise db.UnknownFlowError(client_id, flow_id) @@ -183,127 +234,193 @@ def ReadAllFlowObjects( max_create_time: Optional[rdfvalue.RDFDatetime] = None, include_child_flows: bool = True, not_created_by: Optional[Iterable[str]] = None, - ) -> List[rdf_flow_objects.Flow]: + ) -> List[flows_pb2.Flow]: """Returns all flow objects.""" res = [] for flow in self.flows.values(): - if ((client_id is None or flow.client_id == client_id) and - (parent_flow_id is None or flow.parent_flow_id == parent_flow_id) and - (min_create_time is None or flow.create_time >= min_create_time) and - (max_create_time is None or flow.create_time <= max_create_time) and - (include_child_flows or not flow.parent_flow_id) and - (not_created_by is None or flow.creator not in not_created_by)): - res.append(flow.Copy()) + if client_id is not None and client_id != flow.client_id: + continue + if parent_flow_id is not None and parent_flow_id != flow.parent_flow_id: + continue + if ( + min_create_time is not None + and flow.create_time < min_create_time.AsMicrosecondsSinceEpoch() + ): + continue + if ( + max_create_time is not None + and flow.create_time > max_create_time.AsMicrosecondsSinceEpoch() + ): + continue + if not include_child_flows and flow.parent_flow_id: + continue + if not_created_by is not None and flow.creator in not_created_by: + continue + res.append(flow) return res @utils.Synchronized - def LeaseFlowForProcessing(self, client_id, flow_id, processing_time): + def LeaseFlowForProcessing( + self, + client_id: str, + flow_id: str, + processing_time: rdfvalue.Duration, + ) -> flows_pb2.Flow: """Marks a flow as being processed on this worker and returns it.""" - rdf_flow = self.ReadFlowObject(client_id, flow_id) - if rdf_flow.parent_hunt_id: - rdf_hunt = self.ReadHuntObject(rdf_flow.parent_hunt_id) + flow = self.ReadFlowObject(client_id, flow_id) + if flow.parent_hunt_id: + hunt_obj = self.ReadHuntObject(flow.parent_hunt_id) if not rdf_hunt_objects.IsHuntSuitableForFlowProcessing( - rdf_hunt.hunt_state): - raise db.ParentHuntIsNotRunningError(client_id, flow_id, - rdf_hunt.hunt_id, - rdf_hunt.hunt_state) + hunt_obj.hunt_state + ): + raise db.ParentHuntIsNotRunningError( + client_id, flow_id, hunt_obj.hunt_id, hunt_obj.hunt_state + ) now = rdfvalue.RDFDatetime.Now() - if rdf_flow.processing_on and rdf_flow.processing_deadline > now: - raise ValueError("Flow %s on client %s is already being processed." % - (flow_id, client_id)) + if flow.processing_on and flow.processing_deadline > int(now): + raise ValueError( + "Flow %s on client %s is already being processed." 
+ % (flow_id, client_id) + ) processing_deadline = now + processing_time process_id_string = utils.ProcessIdString() - self.UpdateFlow( - client_id, - flow_id, - processing_on=process_id_string, - processing_since=now, - processing_deadline=processing_deadline) - rdf_flow.processing_on = process_id_string - rdf_flow.processing_since = now - rdf_flow.processing_deadline = processing_deadline - return rdf_flow - - @utils.Synchronized - def UpdateFlow(self, - client_id, - flow_id, - flow_obj=db.Database.unchanged, - flow_state=db.Database.unchanged, - client_crash_info=db.Database.unchanged, - processing_on=db.Database.unchanged, - processing_since=db.Database.unchanged, - processing_deadline=db.Database.unchanged): - """Updates flow objects in the database.""" + # We avoid calling `UpdateFlow` here because it will update the + # `last_update_time` field. Other DB implementations avoid this change, + # so we want to preserve the same behavior here. + flow_clone = flows_pb2.Flow() + flow_clone.CopyFrom(flow) + flow_clone.processing_on = process_id_string + flow_clone.processing_since = int(now) + flow_clone.processing_deadline = int(processing_deadline) + self.flows[(client_id, flow_id)] = flow_clone + + flow.processing_on = process_id_string + flow.processing_since = int(now) + flow.processing_deadline = int(processing_deadline) + return flow + + @utils.Synchronized + def UpdateFlow( + self, + client_id: str, + flow_id: str, + flow_obj: Union[ + flows_pb2.Flow, db.Database.UNCHANGED_TYPE + ] = db.Database.UNCHANGED, + flow_state: Union[ + flows_pb2.Flow.FlowState.ValueType, db.Database.UNCHANGED_TYPE + ] = db.Database.UNCHANGED, + client_crash_info: Union[ + jobs_pb2.ClientCrash, db.Database.UNCHANGED_TYPE + ] = db.Database.UNCHANGED, + processing_on: Optional[ + Union[str, db.Database.UNCHANGED_TYPE] + ] = db.Database.UNCHANGED, + processing_since: Optional[ + Union[rdfvalue.RDFDatetime, db.Database.UNCHANGED_TYPE] + ] = db.Database.UNCHANGED, + processing_deadline: Optional[ + Union[rdfvalue.RDFDatetime, db.Database.UNCHANGED_TYPE] + ] = db.Database.UNCHANGED, + ) -> None: + """Updates flow objects in the database.""" try: flow = self.flows[(client_id, flow_id)] except KeyError: raise db.UnknownFlowError(client_id, flow_id) - if flow_obj != db.Database.unchanged: - new_flow = flow_obj.Copy() + if isinstance(flow_obj, flows_pb2.Flow): + new_flow = flows_pb2.Flow() + new_flow.CopyFrom(flow_obj) # Some fields cannot be updated. 
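# Side note (illustrative, based on the calls visible in this file): the
# proto timestamp fields written by this mixin hold plain integers in
# microseconds since the epoch, so RDFDatetime values are converted with
# AsMicrosecondsSinceEpoch() or int(...) before assignment and rebuilt with
# FromMicrosecondsSinceEpoch() where an RDFDatetime is needed again.
from grr_response_core.lib import rdfvalue

now = rdfvalue.RDFDatetime.Now()
micros = now.AsMicrosecondsSinceEpoch()  # the integer stored in the proto
restored = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(micros)
# int(now) is used interchangeably with AsMicrosecondsSinceEpoch() above.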
new_flow.client_id = flow.client_id new_flow.flow_id = flow.flow_id - new_flow.long_flow_id = flow.long_flow_id - new_flow.parent_flow_id = flow.parent_flow_id - new_flow.parent_hunt_id = flow.parent_hunt_id - new_flow.flow_class_name = flow.flow_class_name - new_flow.creator = flow.creator + if flow.long_flow_id: + new_flow.long_flow_id = flow.long_flow_id + if flow.parent_flow_id: + new_flow.parent_flow_id = flow.parent_flow_id + if flow.parent_hunt_id: + new_flow.parent_hunt_id = flow.parent_hunt_id + if flow.flow_class_name: + new_flow.flow_class_name = flow.flow_class_name + if flow.creator: + new_flow.creator = flow.creator - self.flows[(client_id, flow_id)] = new_flow flow = new_flow - - if flow_state != db.Database.unchanged: + if isinstance(flow_state, flows_pb2.Flow.FlowState.ValueType): flow.flow_state = flow_state - if client_crash_info != db.Database.unchanged: - flow.client_crash_info = client_crash_info - if processing_on != db.Database.unchanged: + if isinstance(client_crash_info, jobs_pb2.ClientCrash): + flow.client_crash_info.CopyFrom(client_crash_info) + if ( + isinstance(processing_on, str) + and processing_on is not db.Database.UNCHANGED + ): flow.processing_on = processing_on - if processing_since != db.Database.unchanged: - flow.processing_since = processing_since - if processing_deadline != db.Database.unchanged: - flow.processing_deadline = processing_deadline - flow.last_update_time = rdfvalue.RDFDatetime.Now() + elif processing_on is None: + flow.ClearField("processing_on") + if isinstance(processing_since, rdfvalue.RDFDatetime): + flow.processing_since = int(processing_since) + elif processing_since is None: + flow.ClearField("processing_since") + if isinstance(processing_deadline, rdfvalue.RDFDatetime): + flow.processing_deadline = int(processing_deadline) + elif processing_deadline is None: + flow.ClearField("processing_deadline") + flow.last_update_time = int(rdfvalue.RDFDatetime.Now()) + + self.flows[(client_id, flow_id)] = flow @utils.Synchronized - def WriteFlowRequests(self, requests): + def WriteFlowRequests( + self, + requests: Collection[flows_pb2.FlowRequest], + ) -> None: """Writes a list of flow requests to the database.""" flow_processing_requests = [] for request in requests: if (request.client_id, request.flow_id) not in self.flows: - raise db.AtLeastOneUnknownFlowError([(request.client_id, - request.flow_id)]) + raise db.AtLeastOneUnknownFlowError( + [(request.client_id, request.flow_id)] + ) for request in requests: key = (request.client_id, request.flow_id) request_dict = self.flow_requests.setdefault(key, {}) - request_dict[request.request_id] = request.Copy() - request_dict[request.request_id].timestamp = rdfvalue.RDFDatetime.Now() + clone = flows_pb2.FlowRequest() + clone.CopyFrom(request) + request_dict[request.request_id] = clone + request_dict[request.request_id].timestamp = int( + rdfvalue.RDFDatetime.Now() + ) if request.needs_processing: flow = self.flows[(request.client_id, request.flow_id)] + flow = mig_flow_objects.ToRDFFlow(flow) if ( flow.next_request_to_process == request.request_id - or request.start_time is not None + or request.start_time ): - flow_processing_requests.append( - rdf_flows.FlowProcessingRequest( - client_id=request.client_id, - flow_id=request.flow_id, - delivery_time=request.start_time)) + processing_request = flows_pb2.FlowProcessingRequest( + client_id=request.client_id, flow_id=request.flow_id + ) + if request.start_time: + processing_request.delivery_time = request.start_time + 
flow_processing_requests.append(processing_request) if flow_processing_requests: self.WriteFlowProcessingRequests(flow_processing_requests) @utils.Synchronized def UpdateIncrementalFlowRequests( - self, client_id: str, flow_id: str, - next_response_id_updates: Dict[int, int]) -> None: + self, + client_id: str, + flow_id: str, + next_response_id_updates: Mapping[int, int], + ) -> None: """Updates incremental flow requests.""" if (client_id, flow_id) not in self.flows: raise db.UnknownFlowError(client_id, flow_id) @@ -311,10 +428,13 @@ def UpdateIncrementalFlowRequests( request_dict = self.flow_requests[(client_id, flow_id)] for request_id, next_response_id in next_response_id_updates.items(): request_dict[request_id].next_response_id = next_response_id - request_dict[request_id].timestamp = rdfvalue.RDFDatetime.Now() + request_dict[request_id].timestamp = int(rdfvalue.RDFDatetime.Now()) @utils.Synchronized - def DeleteFlowRequests(self, requests): + def DeleteFlowRequests( + self, + requests: Sequence[flows_pb2.FlowRequest], + ) -> None: """Deletes a list of flow requests from the database.""" for request in requests: if (request.client_id, request.flow_id) not in self.flows: @@ -325,9 +445,10 @@ def DeleteFlowRequests(self, requests): request_dict = self.flow_requests.get(key, {}) try: del request_dict[request.request_id] - except KeyError: - raise db.UnknownFlowRequestError(request.client_id, request.flow_id, - request.request_id) + except KeyError as e: + raise db.UnknownFlowRequestError( + request.client_id, request.flow_id, request.request_id + ) from e response_dict = self.flow_responses.get(key, {}) try: @@ -336,8 +457,17 @@ def DeleteFlowRequests(self, requests): pass @utils.Synchronized - def WriteFlowResponses(self, responses): - """Writes FlowMessages and updates corresponding requests.""" + def WriteFlowResponses( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ) -> None: + """Writes Flow responses and updates corresponding requests.""" status_available = {} requests_updated = set() task_ids_by_request = {} @@ -345,26 +475,44 @@ def WriteFlowResponses(self, responses): for response in responses: flow_key = (response.client_id, response.flow_id) if flow_key not in self.flows: - logging.error("Received response for unknown flow %s, %s.", - response.client_id, response.flow_id) + logging.error( + "Received response for unknown flow %s, %s.", + response.client_id, + response.flow_id, + ) continue request_dict = self.flow_requests.get(flow_key, {}) if response.request_id not in request_dict: - logging.error("Received response for unknown request %s, %s, %d.", - response.client_id, response.flow_id, response.request_id) + logging.error( + "Received response for unknown request %s, %s, %d.", + response.client_id, + response.flow_id, + response.request_id, + ) continue response_dict = self.flow_responses.setdefault(flow_key, {}) - clone = response.Copy() - clone.timestamp = rdfvalue.RDFDatetime.Now() - - response_dict.setdefault(response.request_id, - {})[response.response_id] = clone - - if isinstance(response, rdf_flow_objects.FlowStatus): - status_available[(response.client_id, response.flow_id, - response.request_id, response.response_id)] = response + clone = flows_pb2.FlowResponse() + if isinstance(response, flows_pb2.FlowIterator): + clone = flows_pb2.FlowIterator() + elif isinstance(response, flows_pb2.FlowStatus): + clone = flows_pb2.FlowStatus() + + clone.CopyFrom(response) + clone.timestamp = 
int(rdfvalue.RDFDatetime.Now()) + + response_dict.setdefault(response.request_id, {})[ + response.response_id + ] = clone + + if isinstance(response, flows_pb2.FlowStatus): + status_available[( + response.client_id, + response.flow_id, + response.request_id, + response.response_id, + )] = response request_key = (response.client_id, response.flow_id, response.request_id) requests_updated.add(request_key) @@ -372,7 +520,6 @@ def WriteFlowResponses(self, responses): task_ids_by_request[request_key] = response.task_id except AttributeError: pass - # Every time we get a status we store how many responses are expected. for status in status_available.values(): request_dict = self.flow_requests[(status.client_id, status.flow_id)] @@ -384,6 +531,7 @@ def WriteFlowResponses(self, responses): for client_id, flow_id, request_id in requests_updated: flow_key = (client_id, flow_id) flow = self.flows[flow_key] + flow = mig_flow_objects.ToRDFFlow(flow) request_dict = self.flow_requests[flow_key] request = request_dict[request_id] @@ -397,24 +545,46 @@ def WriteFlowResponses(self, responses): if flow.next_request_to_process == request_id: added_for_processing = True - needs_processing.append( - rdf_flows.FlowProcessingRequest( - client_id=client_id, flow_id=flow_id)) + flow_processing_request = flows_pb2.FlowProcessingRequest( + client_id=client_id, + flow_id=flow_id, + ) + if request.start_time: + flow_processing_request.delivery_time = request.start_time + needs_processing.append(flow_processing_request) + + if ( + request.callback_state + and flow.next_request_to_process == request_id + and not added_for_processing + ): - if (request.callback_state and - flow.next_request_to_process == request_id and - not added_for_processing): needs_processing.append( - rdf_flows.FlowProcessingRequest( - client_id=client_id, flow_id=flow_id)) - + flows_pb2.FlowProcessingRequest( + client_id=client_id, flow_id=flow_id + ) + ) if needs_processing: self.WriteFlowProcessingRequests(needs_processing) - return needs_processing - @utils.Synchronized - def ReadAllFlowRequestsAndResponses(self, client_id, flow_id): + def ReadAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> Iterable[ + Tuple[ + flows_pb2.FlowRequest, + Dict[ + int, + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ] + ]: """Reads all requests and responses for a given flow from the database.""" flow_key = (client_id, flow_id) try: @@ -432,7 +602,11 @@ def ReadAllFlowRequestsAndResponses(self, client_id, flow_id): return res @utils.Synchronized - def DeleteAllFlowRequestsAndResponses(self, client_id, flow_id): + def DeleteAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> None: """Deletes all requests and responses for a given flow from the database.""" flow_key = (client_id, flow_id) try: @@ -451,10 +625,24 @@ def DeleteAllFlowRequestsAndResponses(self, client_id, flow_id): pass @utils.Synchronized - def ReadFlowRequestsReadyForProcessing(self, - client_id, - flow_id, - next_needed_request=None): + def ReadFlowRequestsReadyForProcessing( + self, + client_id: str, + flow_id: str, + next_needed_request: Optional[int] = None, + ) -> Dict[ + int, + Tuple[ + flows_pb2.FlowRequest, + Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ]: """Reads all requests for a flow that can be processed by the worker.""" request_dict = self.flow_requests.get((client_id, flow_id), {}) response_dict = 
self.flow_responses.get((client_id, flow_id), {}) @@ -475,16 +663,20 @@ def ReadFlowRequestsReadyForProcessing(self, responses = sorted( response_dict.get(request_id, {}).values(), - key=lambda response: response.response_id) + key=lambda response: response.response_id, + ) + # Serialize/deserialize responses to better simulate the # real DB behavior (where serialization/deserialization is almost # guaranteed to be done). # TODO(user): change mem-db implementation to do # serialization/deserialization everywhere in a generic way. - responses = [ - r.__class__.FromSerializedBytes(r.SerializeToBytes()) - for r in responses - ] + reserialized_responses = [] + for r in responses: + response = r.__class__() + response.ParseFromString(r.SerializeToString()) + reserialized_responses.append(response) + responses = reserialized_responses res[request_id] = (request, responses) next_needed_request += 1 @@ -499,6 +691,7 @@ def ReadFlowRequestsReadyForProcessing(self, continue responses = response_dict.get(request_id, {}).values() + responses = [ r for r in responses if r.response_id >= request.next_response_id ] @@ -509,16 +702,18 @@ def ReadFlowRequestsReadyForProcessing(self, # guaranteed to be done). # TODO(user): change mem-db implementation to do # serialization/deserialization everywhere in a generic way. - responses = [ - r.__class__.FromSerializedBytes(r.SerializeToBytes()) - for r in responses - ] + reserialized_responses = [] + for r in responses: + response = r.__class__() + response.ParseFromString(r.SerializeToString()) + reserialized_responses.append(response) + responses = reserialized_responses res[request_id] = (request, responses) return res @utils.Synchronized - def ReleaseProcessedFlow(self, flow_obj): + def ReleaseProcessedFlow(self, flow_obj: flows_pb2.Flow) -> bool: """Releases a flow that the worker was processing to the database.""" key = (flow_obj.client_id, flow_obj.flow_id) next_id_to_process = flow_obj.next_request_to_process @@ -528,7 +723,7 @@ def ReleaseProcessedFlow(self, flow_obj): and request_dict[next_id_to_process].needs_processing ): start_time = request_dict[next_id_to_process].start_time - if start_time is None or start_time < rdfvalue.RDFDatetime.Now(): + if not start_time or start_time < int(rdfvalue.RDFDatetime.Now()): return False self.UpdateFlow( @@ -537,17 +732,30 @@ def ReleaseProcessedFlow(self, flow_obj): flow_obj=flow_obj, processing_on=None, processing_since=None, - processing_deadline=None) + processing_deadline=None, + ) return True - def _InlineProcessingOK(self, requests): + def _InlineProcessingOK( + self, requests: Sequence[flows_pb2.FlowProcessingRequest] + ) -> bool: + """Returns whether inline processing is OK for a list of requests.""" for r in requests: - if r.delivery_time is not None: + if r.delivery_time: + return False + + # If the corresponding flow is already being processed, inline processing + # won't work. + flow = self.flows[r.client_id, r.flow_id] + if flow.HasField("processing_since"): return False return True @utils.Synchronized - def WriteFlowProcessingRequests(self, requests): + def WriteFlowProcessingRequests( + self, + requests: Sequence[flows_pb2.FlowProcessingRequest], + ) -> None: """Writes a list of flow processing requests to the database.""" # If we don't have a handler thread running, we might be able to process the # requests inline. 
If we are not, we start the handler thread for real and @@ -555,26 +763,31 @@ def WriteFlowProcessingRequests(self, requests): if not self.flow_handler_thread and self.flow_handler_target: if self._InlineProcessingOK(requests): for r in requests: + r.creation_time = int(rdfvalue.RDFDatetime.Now()) self.flow_handler_target(r) return else: self._RegisterFlowProcessingHandler(self.flow_handler_target) self.flow_handler_target = None - now = rdfvalue.RDFDatetime.Now() for r in requests: - cloned_request = r.Copy() - cloned_request.timestamp = now + cloned_request = flows_pb2.FlowProcessingRequest() + cloned_request.CopyFrom(r) key = (r.client_id, r.flow_id) + cloned_request.creation_time = int(rdfvalue.RDFDatetime.Now()) self.flow_processing_requests[key] = cloned_request @utils.Synchronized - def ReadFlowProcessingRequests(self): + def ReadFlowProcessingRequests( + self, + ) -> Sequence[flows_pb2.FlowProcessingRequest]: """Reads all flow processing requests from the database.""" return list(self.flow_processing_requests.values()) @utils.Synchronized - def AckFlowProcessingRequests(self, requests): + def AckFlowProcessingRequests( + self, requests: Iterable[flows_pb2.FlowProcessingRequest] + ) -> None: """Deletes a list of flow processing requests from the database.""" for r in requests: key = (r.client_id, r.flow_id) @@ -582,10 +795,12 @@ def AckFlowProcessingRequests(self, requests): del self.flow_processing_requests[key] @utils.Synchronized - def DeleteAllFlowProcessingRequests(self): + def DeleteAllFlowProcessingRequests(self) -> None: self.flow_processing_requests = {} - def RegisterFlowProcessingHandler(self, handler): + def RegisterFlowProcessingHandler( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ) -> None: """Registers a message handler to receive flow processing messages.""" self.UnregisterFlowProcessingHandler() @@ -596,20 +811,26 @@ def RegisterFlowProcessingHandler(self, handler): for request in self._GetFlowRequestsReadyForProcessing(): handler(request) with self.lock: - self.flow_processing_requests.pop((request.client_id, request.flow_id), - None) + self.flow_processing_requests.pop( + (request.client_id, request.flow_id), None + ) - def _RegisterFlowProcessingHandler(self, handler): + def _RegisterFlowProcessingHandler( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ) -> None: """Registers a handler to receive flow processing messages.""" self.flow_handler_stop = False self.flow_handler_thread = threading.Thread( name="flow_processing_handler", target=self._HandleFlowProcessingRequestLoop, - args=(handler,)) + args=(handler,), + ) self.flow_handler_thread.daemon = True self.flow_handler_thread.start() - def UnregisterFlowProcessingHandler(self, timeout=None): + def UnregisterFlowProcessingHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: """Unregisters any registered flow processing handler.""" self.flow_handler_target = None @@ -621,11 +842,13 @@ def UnregisterFlowProcessingHandler(self, timeout=None): self.flow_handler_thread = None @utils.Synchronized - def _GetFlowRequestsReadyForProcessing(self): + def _GetFlowRequestsReadyForProcessing( + self, + ) -> Sequence[flows_pb2.FlowProcessingRequest]: now = rdfvalue.RDFDatetime.Now() todo = [] for r in list(self.flow_processing_requests.values()): - if r.delivery_time is None or r.delivery_time <= now: + if not r.delivery_time or r.delivery_time <= now: todo.append(r) return todo @@ -649,16 +872,18 @@ def WaitUntilNoFlowsToProcess(self, timeout=None): # If 
the thread is dead, or there are no requests # to be processed/being processed, we stop waiting # and return from the function. - if (not t.is_alive() or - (not self._GetFlowRequestsReadyForProcessing() and - not self.flow_handler_num_being_processed)): + if not t.is_alive() or ( + not self._GetFlowRequestsReadyForProcessing() + and not self.flow_handler_num_being_processed + ): return time.sleep(0.2) if timeout and time.time() - start_time > timeout: raise TimeOutWhileWaitingForFlowsToBeProcessedError( - "Flow processing didn't finish in time.") + "Flow processing didn't finish in time." + ) def _HandleFlowProcessingRequestLoop(self, handler): """Handler thread for the FlowProcessingRequest queue.""" @@ -667,8 +892,9 @@ def _HandleFlowProcessingRequestLoop(self, handler): todo = self._GetFlowRequestsReadyForProcessing() for request in todo: self.flow_handler_num_being_processed += 1 - del self.flow_processing_requests[(request.client_id, - request.flow_id)] + del self.flow_processing_requests[ + (request.client_id, request.flow_id) + ] for request in todo: handler(request) @@ -746,7 +972,7 @@ def _ReadFlowResultsOrErrors( if encoded_substring in i.payload.SerializeToString() ] - return results[offset:offset + count] + return results[offset : offset + count] def ReadFlowResults( self, @@ -767,7 +993,8 @@ def ReadFlowResults( count, with_tag=with_tag, with_type=with_type, - with_substring=with_substring) + with_substring=with_substring, + ) @utils.Synchronized def CountFlowResults( @@ -785,7 +1012,9 @@ def CountFlowResults( 0, sys.maxsize, with_tag=with_tag, - with_type=with_type)) + with_type=with_type, + ) + ) @utils.Synchronized def CountFlowResultsByType( @@ -828,7 +1057,8 @@ def ReadFlowErrors( offset, count, with_tag=with_tag, - with_type=with_type) + with_type=with_type, + ) @utils.Synchronized def CountFlowErrors( @@ -846,7 +1076,9 @@ def CountFlowErrors( 0, sys.maxsize, with_tag=with_tag, - with_type=with_type)) + with_type=with_type, + ) + ) @utils.Synchronized def CountFlowErrorsByType( @@ -886,12 +1118,13 @@ def ReadFlowLogEntries( """Reads flow log entries of a given flow using given query options.""" entries = sorted( self.flow_log_entries.get((client_id, flow_id), []), - key=lambda e: e.timestamp) + key=lambda e: e.timestamp, + ) if with_substring is not None: entries = [i for i in entries if with_substring in i.message] - return entries[offset:offset + count] + return entries[offset : offset + count] @utils.Synchronized def CountFlowLogEntries(self, client_id: str, flow_id: str) -> int: @@ -930,14 +1163,15 @@ def ReadFlowOutputPluginLogEntries( """Reads flow output plugin log entries.""" entries = sorted( self.flow_output_plugin_log_entries.get((client_id, flow_id), []), - key=lambda e: e.timestamp) + key=lambda e: e.timestamp, + ) entries = [e for e in entries if e.output_plugin_id == output_plugin_id] if with_type is not None: entries = [e for e in entries if e.log_entry_type == with_type] - return entries[offset:offset + count] + return entries[offset : offset + count] @utils.Synchronized def CountFlowOutputPluginLogEntries( @@ -958,7 +1192,9 @@ def CountFlowOutputPluginLogEntries( output_plugin_id, 0, sys.maxsize, - with_type=with_type)) + with_type=with_type, + ) + ) @utils.Synchronized def WriteScheduledFlow( @@ -999,7 +1235,8 @@ def DeleteScheduledFlow( raise db.UnknownScheduledFlowError( client_id=client_id, creator=creator, - scheduled_flow_id=scheduled_flow_id) + scheduled_flow_id=scheduled_flow_id, + ) @utils.Synchronized def ListScheduledFlows( diff --git 
a/grr/server/grr_response_server/databases/mem_flows_large_test.py b/grr/server/grr_response_server/databases/mem_flows_large_test.py index d77deb4019..b1c6141162 100644 --- a/grr/server/grr_response_server/databases/mem_flows_large_test.py +++ b/grr/server/grr_response_server/databases/mem_flows_large_test.py @@ -7,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBFlowTest(db_flows_test.DatabaseLargeTestFlowMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBFlowTest( + db_flows_test.DatabaseLargeTestFlowMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_flows_test.py b/grr/server/grr_response_server/databases/mem_flows_test.py index 44dc1910ed..1d8f5ae9d2 100644 --- a/grr/server/grr_response_server/databases/mem_flows_test.py +++ b/grr/server/grr_response_server/databases/mem_flows_test.py @@ -7,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBFlowTest(db_flows_test.DatabaseTestFlowMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBFlowTest( + db_flows_test.DatabaseTestFlowMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_foreman_rules.py b/grr/server/grr_response_server/databases/mem_foreman_rules.py index 73fd8e5f0f..8c213e95c2 100644 --- a/grr/server/grr_response_server/databases/mem_foreman_rules.py +++ b/grr/server/grr_response_server/databases/mem_foreman_rules.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """The in memory database methods for foreman rule handling.""" + from typing import Sequence from grr_response_core.lib import rdfvalue diff --git a/grr/server/grr_response_server/databases/mem_foreman_rules_test.py b/grr/server/grr_response_server/databases/mem_foreman_rules_test.py index 182653f4e5..8aaf9fa429 100644 --- a/grr/server/grr_response_server/databases/mem_foreman_rules_test.py +++ b/grr/server/grr_response_server/databases/mem_foreman_rules_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -10,7 +9,9 @@ class MemoryDBForemanRulesTest( db_foreman_rules_test.DatabaseTestForemanRulesMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_hunts.py b/grr/server/grr_response_server/databases/mem_hunts.py index dccb20cea9..827458b5fc 100644 --- a/grr/server/grr_response_server/databases/mem_hunts.py +++ b/grr/server/grr_response_server/databases/mem_hunts.py @@ -1,36 +1,104 @@ #!/usr/bin/env python """The in memory database methods for hunt handling.""" +from collections.abc import Callable +import math import sys -from typing import Dict, Optional -from typing import Sequence +from typing import AbstractSet, Collection, Dict, Iterable, List, Mapping, Optional, Sequence from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils -from grr_response_core.lib.rdfvalues import client_stats as rdf_client_stats +from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import stats as rdf_stats from grr_response_proto import flows_pb2 from grr_response_proto import hunts_pb2 +from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 +from grr_response_proto import output_plugin_pb2 from grr_response_server.databases import db -from grr_response_server.rdfvalues import flow_objects as 
rdf_flow_objects -from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import mig_hunt_objects +def UpdateHistogram(histogram: rdf_stats.StatsHistogram, value: float): + """Puts a given value into an appropriate bin.""" + for b in histogram.bins: + if b.range_max_value > value: + b.num += 1 + return + + if histogram.bins: + histogram.bins[-1].num += 1 + + +def UpdateStats(running_stats: rdf_stats.RunningStats, values: Iterable[float]): + """Updates running stats with the given values.""" + sum_sq = 0 + for value in values: + running_stats.num += 1 + running_stats.sum += value + sum_sq += value**2 + mean = running_stats.sum / running_stats.num if running_stats.num > 0 else 0 + running_stats.stddev = math.sqrt(sum_sq / running_stats.num - mean**2) + UpdateHistogram(running_stats.histogram, value) + + +def InitializeClientResourcesStats( + client_resources: Sequence[jobs_pb2.ClientResources], +) -> jobs_pb2.ClientResourcesStats: + """Initialized ClientResourcesStats with resources consumed by a single client.""" + + stats = jobs_pb2.ClientResourcesStats() + stats.user_cpu_stats.histogram.bins.extend([ + jobs_pb2.StatsHistogramBin(range_max_value=b) + for b in rdf_stats.ClientResourcesStats.CPU_STATS_BINS + ]) + stats.system_cpu_stats.histogram.bins.extend([ + jobs_pb2.StatsHistogramBin(range_max_value=b) + for b in rdf_stats.ClientResourcesStats.CPU_STATS_BINS + ]) + stats.network_bytes_sent_stats.histogram.bins.extend([ + jobs_pb2.StatsHistogramBin(range_max_value=b) + for b in rdf_stats.ClientResourcesStats.NETWORK_STATS_BINS + ]) + UpdateStats( + stats.user_cpu_stats, + [r.cpu_usage.user_cpu_time for r in client_resources], + ) + UpdateStats( + stats.system_cpu_stats, + [r.cpu_usage.system_cpu_time for r in client_resources], + ) + UpdateStats( + stats.network_bytes_sent_stats, + [r.network_bytes_sent for r in client_resources], + ) + + client_resources.sort( + key=lambda s: s.cpu_usage.user_cpu_time + s.cpu_usage.system_cpu_time, + reverse=True, + ) + stats.worst_performers.extend( + client_resources[: rdf_stats.ClientResourcesStats.NUM_WORST_PERFORMERS] + ) + + return stats + + class InMemoryDBHuntMixin(object): """Hunts-related DB methods implementation.""" hunts: Dict[str, hunts_pb2.Hunt] + flows: Dict[str, flows_pb2.Flow] - def _GetHuntFlows(self, hunt_id): + def _GetHuntFlows(self, hunt_id: str) -> List[flows_pb2.Flow]: hunt_flows = [ f for f in self.flows.values() if f.parent_hunt_id == hunt_id and f.flow_id == hunt_id ] + hunt_flows = [mig_flow_objects.ToRDFFlow(f) for f in hunt_flows] return sorted(hunt_flows, key=lambda f: f.client_id) @utils.Synchronized @@ -46,7 +114,13 @@ def WriteHuntObject(self, hunt_obj: hunts_pb2.Hunt): self.hunts[hunt_obj.hunt_id] = clone @utils.Synchronized - def UpdateHuntObject(self, hunt_id, start_time=None, **kwargs): + def UpdateHuntObject( + self, + hunt_id: str, + start_time: Optional[rdfvalue.RDFDatetime] = None, + duration: Optional[rdfvalue.Duration] = None, + **kwargs, + ): """Updates the hunt object by applying the update function.""" hunt_obj = self.ReadHuntObject(hunt_id) @@ -63,58 +137,80 @@ def UpdateHuntObject(self, hunt_id, start_time=None, **kwargs): else: setattr(hunt_obj, k, v) + if duration is not None: + hunt_obj.duration = duration.ToInt(rdfvalue.SECONDS) + if start_time is not None: - hunt_obj.init_start_time = hunt_obj.init_start_time or 
start_time - hunt_obj.last_start_time = start_time + hunt_obj.init_start_time = hunt_obj.init_start_time or int(start_time) + hunt_obj.last_start_time = int(start_time) - hunt_obj.last_update_time = rdfvalue.RDFDatetime.Now() - self.hunts[hunt_obj.hunt_id] = mig_hunt_objects.ToProtoHunt(hunt_obj) + hunt_obj.last_update_time = int(rdfvalue.RDFDatetime.Now()) + self.hunts[hunt_obj.hunt_id] = hunt_obj @utils.Synchronized - def ReadHuntOutputPluginsStates(self, hunt_id): + def ReadHuntOutputPluginsStates( + self, + hunt_id: str, + ) -> List[output_plugin_pb2.OutputPluginState]: + """Reads hunt output plugin states for a given hunt.""" if hunt_id not in self.hunts: raise db.UnknownHuntError(hunt_id) serialized_states = self.hunt_output_plugins_states.get(hunt_id, []) - return [ - rdf_flow_runner.OutputPluginState.FromSerializedBytes(s) - for s in serialized_states - ] + result = [] + for s in serialized_states: + output_plugin_state = output_plugin_pb2.OutputPluginState() + output_plugin_state.ParseFromString(s) + result.append(output_plugin_state) + + return result @utils.Synchronized - def WriteHuntOutputPluginsStates(self, hunt_id, states): + def WriteHuntOutputPluginsStates( + self, + hunt_id: str, + states: Collection[output_plugin_pb2.OutputPluginState], + ) -> None: if hunt_id not in self.hunts: raise db.UnknownHuntError(hunt_id) self.hunt_output_plugins_states[hunt_id] = [ - s.SerializeToBytes() for s in states + s.SerializeToString() for s in states ] @utils.Synchronized - def UpdateHuntOutputPluginState(self, hunt_id, state_index, update_fn): + def UpdateHuntOutputPluginState( + self, + hunt_id: str, + state_index: int, + update_fn: Callable[ + [jobs_pb2.AttributedDict], + jobs_pb2.AttributedDict, + ], + ) -> jobs_pb2.AttributedDict: """Updates hunt output plugin state for a given output plugin.""" - if hunt_id not in self.hunts: raise db.UnknownHuntError(hunt_id) + state = output_plugin_pb2.OutputPluginState() try: - state = rdf_flow_runner.OutputPluginState.FromSerializedBytes( + state.ParseFromString( self.hunt_output_plugins_states[hunt_id][state_index] ) - except KeyError: - raise db.UnknownHuntOutputPluginError(hunt_id, state_index) + except KeyError as ex: + raise db.UnknownHuntOutputPluginError(hunt_id, state_index) from ex - state.plugin_state = update_fn(state.plugin_state) + modified_plugin_state = update_fn(state.plugin_state) + state.plugin_state.CopyFrom(modified_plugin_state) self.hunt_output_plugins_states[hunt_id][ state_index - ] = state.SerializeToBytes() - + ] = state.SerializeToString() return state.plugin_state @utils.Synchronized - def DeleteHuntObject(self, hunt_id): + def DeleteHuntObject(self, hunt_id: str) -> None: """Deletes a hunt object with a given id.""" try: del self.hunts[hunt_id] @@ -135,25 +231,30 @@ def DeleteHuntObject(self, hunt_id): del approvals[approval_id] @utils.Synchronized - def ReadHuntObject(self, hunt_id: str) -> rdf_hunt_objects.Hunt: + def ReadHuntObject(self, hunt_id: str) -> hunts_pb2.Hunt: """Reads a hunt object from the database.""" + hunt = hunts_pb2.Hunt() try: - return self._DeepCopy(mig_hunt_objects.ToRDFHunt(self.hunts[hunt_id])) - except KeyError: - raise db.UnknownHuntError(hunt_id) + hunt_instance = self.hunts[hunt_id] + except KeyError as ex: + raise db.UnknownHuntError(hunt_id) from ex + hunt.CopyFrom(hunt_instance) + return hunt @utils.Synchronized def ReadHuntObjects( self, - offset, - count, - with_creator=None, - created_after=None, - with_description_match=None, - created_by=None, - not_created_by=None, - 
with_states=None, - ): + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[AbstractSet[str]] = None, + not_created_by: Optional[AbstractSet[str]] = None, + with_states: Optional[ + Collection[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, + ) -> List[hunts_pb2.Hunt]: """Reads metadata for hunt objects from the database.""" filter_fns = [] if with_creator is not None: @@ -163,33 +264,36 @@ def ReadHuntObjects( if not_created_by is not None: filter_fns.append(lambda h: h.creator not in not_created_by) if created_after is not None: - filter_fns.append(lambda h: h.create_time > created_after) + filter_fns.append(lambda h: h.create_time > int(created_after)) if with_description_match is not None: filter_fns.append(lambda h: with_description_match in h.description) if with_states is not None: filter_fns.append(lambda h: h.hunt_state in with_states) filter_fn = lambda h: all(f(h) for f in filter_fns) - result = [ - self._DeepCopy(mig_hunt_objects.ToRDFHunt(h)) - for h in self.hunts.values() - if filter_fn(h) - ] + result = [] + for h in self.hunts.values(): + if filter_fn(h): + hunt_obj = hunts_pb2.Hunt() + hunt_obj.CopyFrom(h) + result.append(hunt_obj) result.sort(key=lambda h: h.create_time, reverse=True) return result[offset : offset + (count or db.MAX_COUNT)] @utils.Synchronized def ListHuntObjects( self, - offset, - count, - with_creator=None, - created_after=None, - with_description_match=None, - created_by=None, - not_created_by=None, - with_states=None, - ): + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[AbstractSet[str]] = None, + not_created_by: Optional[AbstractSet[str]] = None, + with_states: Optional[ + Collection[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, + ) -> List[hunts_pb2.HuntMetadata]: """Reads all hunt objects from the database.""" filter_fns = [] if with_creator is not None: @@ -199,7 +303,7 @@ def ListHuntObjects( if not_created_by is not None: filter_fns.append(lambda h: h.creator not in not_created_by) if created_after is not None: - filter_fns.append(lambda h: h.create_time > created_after) + filter_fns.append(lambda h: h.create_time > int(created_after)) if with_description_match is not None: filter_fns.append(lambda h: with_description_match in h.description) if with_states is not None: @@ -208,10 +312,12 @@ def ListHuntObjects( result = [] for h in self.hunts.values(): - h = mig_hunt_objects.ToRDFHunt(h) if not filter_fn(h): continue - result.append(rdf_hunt_objects.HuntMetadata.FromHunt(h)) + h = mig_hunt_objects.ToRDFHunt(h) + hunt_metadata = rdf_hunt_objects.HuntMetadata.FromHunt(h) + hunt_metadata = mig_hunt_objects.ToProtoHuntMetadata(hunt_metadata) + result.append(hunt_metadata) result.sort(key=lambda h: h.create_time, reverse=True) return result[offset : offset + (count or db.MAX_COUNT)] @@ -256,18 +362,18 @@ def CountHuntLogEntries(self, hunt_id: str) -> int: @utils.Synchronized def ReadHuntResults( self, - hunt_id, - offset, - count, - with_tag=None, - with_type=None, - with_substring=None, - with_timestamp=None, - ): + hunt_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, + with_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> Iterable[flows_pb2.FlowResult]: 
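# Worked example (illustrative) for the UpdateStats helper added near the top
# of this file: for the values [1.0, 2.0, 3.0] the mean is 2.0 and the
# population standard deviation is sqrt(sum(v**2)/n - mean**2)
# = sqrt(14/3 - 4), which is roughly 0.816.
import math

values = [1.0, 2.0, 3.0]
n = len(values)
mean = sum(values) / n                                    # 2.0
stddev = math.sqrt(sum(v * v for v in values) / n - mean**2)
assert abs(stddev - math.sqrt(2.0 / 3.0)) < 1e-9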
"""Reads hunt results of a given hunt using given query options.""" all_results = [] for flow_obj in self._GetHuntFlows(hunt_id): - for proto_entry in self.ReadFlowResults( + for entry in self.ReadFlowResults( flow_obj.client_id, flow_obj.flow_id, 0, @@ -276,9 +382,8 @@ def ReadHuntResults( with_type=with_type, with_substring=with_substring, ): - entry = mig_flow_objects.ToRDFFlowResult(proto_entry) all_results.append( - rdf_flow_objects.FlowResult( + flows_pb2.FlowResult( hunt_id=hunt_id, client_id=flow_obj.client_id, flow_id=flow_obj.flow_id, @@ -295,7 +400,12 @@ def ReadHuntResults( return all_results[offset : offset + count] @utils.Synchronized - def CountHuntResults(self, hunt_id, with_tag=None, with_type=None): + def CountHuntResults( + self, + hunt_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: """Counts hunt results of a given hunt using given query options.""" return len( self.ReadHuntResults( @@ -304,9 +414,10 @@ def CountHuntResults(self, hunt_id, with_tag=None, with_type=None): ) @utils.Synchronized - def CountHuntResultsByType(self, hunt_id): + def CountHuntResultsByType(self, hunt_id: str) -> Mapping[str, int]: result = {} for hr in self.ReadHuntResults(hunt_id, 0, sys.maxsize): + hr = mig_flow_objects.ToRDFFlowResult(hr) key = hr.payload.__class__.__name__ result[key] = result.setdefault(key, 0) + 1 @@ -314,8 +425,12 @@ def CountHuntResultsByType(self, hunt_id): @utils.Synchronized def ReadHuntFlows( - self, hunt_id, offset, count, filter_condition=db.HuntFlowsCondition.UNSET - ): + self, + hunt_id: str, + offset: int, + count: int, + filter_condition: db.HuntFlowsCondition = db.HuntFlowsCondition.UNSET, + ) -> Sequence[flows_pb2.Flow]: """Reads hunt flows matching given conditins.""" if filter_condition == db.HuntFlowsCondition.UNSET: filter_fn = lambda _: True @@ -341,12 +456,18 @@ def ReadHuntFlows( if filter_fn(flow_obj) ] results.sort(key=lambda f: f.last_update_time) - return results[offset : offset + count] + results = results[offset : offset + count] + results = [mig_flow_objects.ToProtoFlow(f) for f in results] + return results @utils.Synchronized def CountHuntFlows( - self, hunt_id, filter_condition=db.HuntFlowsCondition.UNSET - ): + self, + hunt_id: str, + filter_condition: Optional[ + db.HuntFlowsCondition + ] = db.HuntFlowsCondition.UNSET, + ) -> int: """Counts hunt flows matching given conditions.""" return len( @@ -356,81 +477,98 @@ def CountHuntFlows( ) @utils.Synchronized - def ReadHuntCounters(self, hunt_id): - """Reads hunt counters.""" - num_clients = self.CountHuntFlows(hunt_id) - num_successful_clients = self.CountHuntFlows( - hunt_id, filter_condition=db.HuntFlowsCondition.SUCCEEDED_FLOWS_ONLY - ) - num_failed_clients = self.CountHuntFlows( - hunt_id, filter_condition=db.HuntFlowsCondition.FAILED_FLOWS_ONLY - ) - num_clients_with_results = len( - set( - r[0].client_id - for r in self.flow_results.values() - if r and r[0].hunt_id == hunt_id - ) - ) - num_crashed_clients = self.CountHuntFlows( - hunt_id, filter_condition=db.HuntFlowsCondition.CRASHED_FLOWS_ONLY - ) - num_running_clients = self.CountHuntFlows( - hunt_id, filter_condition=db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY - ) - num_results = self.CountHuntResults(hunt_id) + def ReadHuntsCounters( + self, + hunt_ids: Collection[str], + ) -> Mapping[str, db.HuntCounters]: + """Reads hunt counters for several hunt ids.""" + hunt_counters = {} + for hunt_id in hunt_ids: + num_clients = self.CountHuntFlows(hunt_id) + num_successful_clients = 
self.CountHuntFlows( + hunt_id, filter_condition=db.HuntFlowsCondition.SUCCEEDED_FLOWS_ONLY + ) + num_failed_clients = self.CountHuntFlows( + hunt_id, filter_condition=db.HuntFlowsCondition.FAILED_FLOWS_ONLY + ) + num_clients_with_results = len( + set( + r[0].client_id + for r in self.flow_results.values() + if r and r[0].hunt_id == hunt_id + ) + ) + num_crashed_clients = self.CountHuntFlows( + hunt_id, filter_condition=db.HuntFlowsCondition.CRASHED_FLOWS_ONLY + ) + num_running_clients = self.CountHuntFlows( + hunt_id, filter_condition=db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY + ) + num_results = self.CountHuntResults(hunt_id) - total_cpu_seconds = 0 - total_network_bytes_sent = 0 - for f in self.ReadHuntFlows(hunt_id, 0, sys.maxsize): - total_cpu_seconds += ( - f.cpu_time_used.user_cpu_time + f.cpu_time_used.system_cpu_time + total_cpu_seconds = 0 + total_network_bytes_sent = 0 + for f in self.ReadHuntFlows(hunt_id, 0, sys.maxsize): + total_cpu_seconds += ( + f.cpu_time_used.user_cpu_time + f.cpu_time_used.system_cpu_time + ) + total_network_bytes_sent += f.network_bytes_sent + + hunt_counters[hunt_id] = db.HuntCounters( + num_clients=num_clients, + num_successful_clients=num_successful_clients, + num_failed_clients=num_failed_clients, + num_clients_with_results=num_clients_with_results, + num_crashed_clients=num_crashed_clients, + num_running_clients=num_running_clients, + num_results=num_results, + total_cpu_seconds=total_cpu_seconds, + total_network_bytes_sent=total_network_bytes_sent, ) - total_network_bytes_sent += f.network_bytes_sent - - return db.HuntCounters( - num_clients=num_clients, - num_successful_clients=num_successful_clients, - num_failed_clients=num_failed_clients, - num_clients_with_results=num_clients_with_results, - num_crashed_clients=num_crashed_clients, - num_running_clients=num_running_clients, - num_results=num_results, - total_cpu_seconds=total_cpu_seconds, - total_network_bytes_sent=total_network_bytes_sent, - ) + return hunt_counters @utils.Synchronized - def ReadHuntClientResourcesStats(self, hunt_id): + def ReadHuntClientResourcesStats( + self, + hunt_id: str, + ) -> rdf_stats.ClientResourcesStats: """Read/calculate hunt client resources stats.""" - result = rdf_stats.ClientResourcesStats() + client_resources = [] + for f in self._GetHuntFlows(hunt_id): - cr = rdf_client_stats.ClientResources( - session_id="%s/%s" % (f.client_id, f.flow_id), - client_id=f.client_id, + f = mig_flow_objects.ToProtoFlow(f) + cr = jobs_pb2.ClientResources( + session_id=str(rdfvalue.RDFURN(f.client_id).Add(f.flow_id)), + client_id=str(rdf_client.ClientURN.FromHumanReadable(f.client_id)), cpu_usage=f.cpu_time_used, network_bytes_sent=f.network_bytes_sent, ) - result.RegisterResources(cr) + client_resources.append(cr) - # TODO(user): remove this hack when compatibility with AFF4 is not - # important. 
- return rdf_stats.ClientResourcesStats.FromSerializedBytes( - result.SerializeToBytes() - ) + result = InitializeClientResourcesStats(client_resources) + + return result @utils.Synchronized - def ReadHuntFlowsStatesAndTimestamps(self, hunt_id): + def ReadHuntFlowsStatesAndTimestamps( + self, + hunt_id: str, + ) -> Sequence[db.FlowStateAndTimestamps]: """Reads hunt flows states and timestamps.""" result = [] for f in self._GetHuntFlows(hunt_id): + f = mig_flow_objects.ToProtoFlow(f) result.append( db.FlowStateAndTimestamps( flow_state=f.flow_state, - create_time=f.create_time, - last_update_time=f.last_update_time, + create_time=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + f.create_time + ), + last_update_time=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + f.last_update_time + ), ) ) diff --git a/grr/server/grr_response_server/databases/mem_hunts_test.py b/grr/server/grr_response_server/databases/mem_hunts_test.py index b41cdac4de..5fede4f156 100644 --- a/grr/server/grr_response_server/databases/mem_hunts_test.py +++ b/grr/server/grr_response_server/databases/mem_hunts_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -9,9 +8,12 @@ from grr.test_lib import test_lib -class MemoryDBHuntTest(db_hunts_test.DatabaseTestHuntMixin, - db_test_utils.QueryTestHelpersMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBHuntTest( + db_hunts_test.DatabaseTestHuntMixin, + db_test_utils.QueryTestHelpersMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_message_handler_test.py b/grr/server/grr_response_server/databases/mem_message_handler_test.py index 2cdbd9a728..5ea50e89fc 100644 --- a/grr/server/grr_response_server/databases/mem_message_handler_test.py +++ b/grr/server/grr_response_server/databases/mem_message_handler_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBHandlerTest(db_message_handler_test.DatabaseTestHandlerMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBHandlerTest( + db_message_handler_test.DatabaseTestHandlerMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_paths.py b/grr/server/grr_response_server/databases/mem_paths.py index fb2ec8f5e0..307d03e2db 100644 --- a/grr/server/grr_response_server/databases/mem_paths.py +++ b/grr/server/grr_response_server/databases/mem_paths.py @@ -13,13 +13,11 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils -from grr_response_core.lib.rdfvalues import mig_client_fs -from grr_response_core.lib.rdfvalues import mig_crypto from grr_response_core.lib.util import collection from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 from grr_response_server.databases import db -from grr_response_server.rdfvalues import mig_objects +from grr_response_server.models import paths from grr_response_server.rdfvalues import objects as rdf_objects @@ -163,7 +161,7 @@ class InMemoryDBPathMixin(object): # Maps (client_id, path_type, components) to a path record. path_records: dict[ - Tuple[str, "rdf_objects.PathInfo.PathType", Tuple[str, ...]], _PathRecord + Tuple[str, "objects_pb2.PathInfo.PathType", Tuple[str, ...]], _PathRecord ] # Maps client_id to client metadata. 
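Throughout this diff the RDF-value round-trip (FromSerializedBytes / SerializeToBytes) is replaced by the plain protobuf API (ParseFromString / SerializeToString, with CopyFrom for defensive copies). A minimal sketch of that round-trip, written against a stock well-known type rather than GRR's generated messages so it only assumes the protobuf package is installed:

from google.protobuf import timestamp_pb2

# Serialize a message to its wire format, as the in-memory stores now do.
original = timestamp_pb2.Timestamp(seconds=1700000000, nanos=42)
serialized = original.SerializeToString()  # bytes

# Parse the bytes back into a fresh message instead of an RDF wrapper.
restored = timestamp_pb2.Timestamp()
restored.ParseFromString(serialized)

# CopyFrom hands callers an independent copy, mirroring the Read* methods above.
defensive_copy = timestamp_pb2.Timestamp()
defensive_copy.CopyFrom(restored)
assert defensive_copy == original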
@@ -178,16 +176,14 @@ class InMemoryDBPathMixin(object): def ReadPathInfo( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> rdf_objects.PathInfo: + ) -> objects_pb2.PathInfo: """Retrieves a path info record for a given path.""" try: path_record = self.path_records[(client_id, path_type, tuple(components))] - return mig_objects.ToRDFPathInfo( - path_record.GetPathInfo(timestamp=timestamp) - ) + return path_record.GetPathInfo(timestamp=timestamp) except KeyError: raise db.UnknownPathError( client_id=client_id, path_type=path_type, components=components @@ -197,9 +193,9 @@ def ReadPathInfo( def ReadPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components_list: Collection[Sequence[str]], - ) -> dict[Sequence[str], Optional[rdf_objects.PathInfo]]: + ) -> dict[tuple[str, ...], Optional[objects_pb2.PathInfo]]: """Retrieves path info records for given paths.""" result = {} @@ -208,11 +204,9 @@ def ReadPathInfos( path_record = self.path_records[ (client_id, path_type, tuple(components)) ] - result[components] = mig_objects.ToRDFPathInfo( - path_record.GetPathInfo() - ) + result[tuple(components)] = path_record.GetPathInfo() except KeyError: - result[components] = None + result[tuple(components)] = None return result @@ -220,11 +214,11 @@ def ReadPathInfos( def ListDescendantPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, max_depth: Optional[int] = None, - ) -> Sequence[rdf_objects.PathInfo]: + ) -> Sequence[objects_pb2.PathInfo]: """Lists path info records that correspond to children of given path.""" result = [] root_dir_exists = False @@ -247,8 +241,10 @@ def ListDescendantPathInfos( continue if not collection.StartsWith(other_components, components): continue - if (max_depth is not None and - len(other_components) - len(components) > max_depth): + if ( + max_depth is not None + and len(other_components) - len(components) > max_depth + ): continue result.append(path_info) @@ -257,10 +253,7 @@ def ListDescendantPathInfos( raise db.UnknownPathError(client_id, path_type, components) if timestamp is None: - return [ - mig_objects.ToRDFPathInfo(info) - for info in sorted(result, key=lambda _: tuple(_.components)) - ] + return sorted(result, key=lambda _: tuple(_.components)) # We need to filter implicit path infos if specific timestamp is given. 
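ListDescendantPathInfos in the hunk above filters stored paths by component prefix and an optional depth limit. A simplified, hypothetical stand-in for that filter (not the GRR implementation; the name is_descendant is illustrative only):

from typing import Optional, Tuple

def is_descendant(
    candidate: Tuple[str, ...],
    root: Tuple[str, ...],
    max_depth: Optional[int] = None,
) -> bool:
  """Checks whether candidate lies strictly under root, optionally depth-limited."""
  if len(candidate) <= len(root):
    return False
  # Prefix check analogous to the collection.StartsWith call in the hunk above.
  if candidate[: len(root)] != root:
    return False
  if max_depth is not None and len(candidate) - len(root) > max_depth:
    return False
  return True

# ("home", "user", "docs", "a.txt") is two levels below ("home", "user").
assert is_descendant(("home", "user", "docs", "a.txt"), ("home", "user"), max_depth=2)
assert not is_descendant(("home", "user"), ("home", "user"))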
@@ -280,9 +273,9 @@ def Add(self, path_info, idx=0): components = path_info.components if idx == len(components): self.path_info = path_info - self.explicit |= ( - path_info.HasField("stat_entry") or - path_info.HasField("hash_entry")) + self.explicit |= path_info.HasField( + "stat_entry" + ) or path_info.HasField("hash_entry") else: child = self.children.setdefault(components[idx], TrieNode()) child.Add(path_info, idx=idx + 1) @@ -301,7 +294,7 @@ def Collect(self, path_infos): explicit_path_infos = [] trie.Collect(explicit_path_infos) - return [mig_objects.ToRDFPathInfo(info) for info in explicit_path_infos] + return explicit_path_infos def _GetPathRecord( self, client_id: str, path_info: objects_pb2.PathInfo @@ -324,27 +317,25 @@ def _WritePathInfo( def WritePathInfos( self, client_id: str, - path_infos: Iterable[rdf_objects.PathInfo], + path_infos: Iterable[objects_pb2.PathInfo], ) -> None: """Writes a collection of path_info records for a client.""" if client_id not in self.metadatas: raise db.UnknownClientError(client_id) for path_info in path_infos: - self._WritePathInfo(client_id, mig_objects.ToProtoPathInfo(path_info)) - for ancestor_path_info in path_info.GetAncestors(): - self._WritePathInfo( - client_id, mig_objects.ToProtoPathInfo(ancestor_path_info) - ) + self._WritePathInfo(client_id, path_info) + for ancestor_path_info in paths.GetAncestorPathInfos(path_info): + self._WritePathInfo(client_id, ancestor_path_info) @utils.Synchronized def ReadPathInfosHistories( self, client_id: Text, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components_list: Iterable[Sequence[Text]], cutoff: Optional[rdfvalue.RDFDatetime] = None, - ) -> Dict[Sequence[Text], Sequence[rdf_objects.PathInfo]]: + ) -> dict[tuple[str, ...], Sequence[objects_pb2.PathInfo]]: """Reads a collection of hash and stat entries for given paths.""" results = {} @@ -385,9 +376,7 @@ def ReadPathInfosHistories( if cutoff is not None and timestamp > cutoff_micros: continue - results[components].append( - mig_objects.ToRDFPathInfo(entries_by_ts[timestamp]) - ) + results[components].append(entries_by_ts[timestamp]) return results @@ -396,7 +385,7 @@ def ReadLatestPathInfosWithHashBlobReferences( self, client_paths: Collection[db.ClientPath], max_timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> Dict[db.ClientPath, Optional[rdf_objects.PathInfo]]: + ) -> Dict[db.ClientPath, Optional[objects_pb2.PathInfo]]: """Returns PathInfos that have corresponding HashBlobReferences.""" results = {} @@ -423,23 +412,20 @@ def ReadLatestPathInfosWithHashBlobReferences( ): continue - rdf_hash_entry = mig_crypto.ToRDFHash(hash_entry) - # TODO: Use protos below when changing signature and - # `blob_refs_by_hashes` is migrated to protos. 
hash_id = rdf_objects.SHA256HashID.FromSerializedBytes( - rdf_hash_entry.sha256.AsBytes() + hash_entry.sha256 ) if hash_id not in self.blob_refs_by_hashes: continue - pi = rdf_objects.PathInfo( + pi = objects_pb2.PathInfo( path_type=cp.path_type, components=tuple(cp.components), timestamp=ts, - hash_entry=rdf_hash_entry, ) + pi.hash_entry.CopyFrom(hash_entry) try: - pi.stat_entry = mig_client_fs.ToRDFStatEntry(stat_entries_by_ts[ts]) + pi.stat_entry.CopyFrom(stat_entries_by_ts[ts]) except KeyError: pass diff --git a/grr/server/grr_response_server/databases/mem_paths_test.py b/grr/server/grr_response_server/databases/mem_paths_test.py index 0fe50f8f0f..f91a0a49d4 100644 --- a/grr/server/grr_response_server/databases/mem_paths_test.py +++ b/grr/server/grr_response_server/databases/mem_paths_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBPathsTest(db_paths_test.DatabaseTestPathsMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBPathsTest( + db_paths_test.DatabaseTestPathsMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_signed_binaries.py b/grr/server/grr_response_server/databases/mem_signed_binaries.py index c9798cc1d9..dfde07837a 100644 --- a/grr/server/grr_response_server/databases/mem_signed_binaries.py +++ b/grr/server/grr_response_server/databases/mem_signed_binaries.py @@ -17,7 +17,7 @@ def _SignedBinaryKeyFromID( def _SignedBinaryIDFromKey( - binary_key: Tuple[int, Text] + binary_key: Tuple[int, Text], ) -> objects_pb2.SignedBinaryID: """Converts a tuple representing a signed binary to a SignedBinaryID.""" return objects_pb2.SignedBinaryID( diff --git a/grr/server/grr_response_server/databases/mem_signed_binaries_test.py b/grr/server/grr_response_server/databases/mem_signed_binaries_test.py index f2d0a16610..f2900f155e 100644 --- a/grr/server/grr_response_server/databases/mem_signed_binaries_test.py +++ b/grr/server/grr_response_server/databases/mem_signed_binaries_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -10,7 +9,9 @@ class MemoryDBSignedBinariesTest( db_signed_binaries_test.DatabaseTestSignedBinariesMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_time_test.py b/grr/server/grr_response_server/databases/mem_time_test.py index 456f9afea5..357e86434e 100644 --- a/grr/server/grr_response_server/databases/mem_time_test.py +++ b/grr/server/grr_response_server/databases/mem_time_test.py @@ -7,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBArtifactsTest(db_time_test.DatabaseTimeTestMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBArtifactsTest( + db_time_test.DatabaseTimeTestMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_users.py b/grr/server/grr_response_server/databases/mem_users.py index 0971a2544f..de642f6316 100644 --- a/grr/server/grr_response_server/databases/mem_users.py +++ b/grr/server/grr_response_server/databases/mem_users.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """The in memory database methods for GRR users and approval handling.""" + import os from typing import Optional, Sequence, Tuple @@ -131,9 +132,10 @@ def ReadApprovalRequest( res = 
objects_pb2.ApprovalRequest() res.CopyFrom(self.approvals_by_username[requestor_username][approval_id]) return res - except KeyError: - raise db.UnknownApprovalRequestError("Can't find approval with id: %s" % - approval_id) + except KeyError as e: + raise db.UnknownApprovalRequestError( + "Can't find approval with id: %s" % approval_id + ) from e @utils.Synchronized def ReadApprovalRequests( diff --git a/grr/server/grr_response_server/databases/mem_users_test.py b/grr/server/grr_response_server/databases/mem_users_test.py index 1cd71bc0ef..b61ea6e6c0 100644 --- a/grr/server/grr_response_server/databases/mem_users_test.py +++ b/grr/server/grr_response_server/databases/mem_users_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MemoryDBUsersTest(db_users_test.DatabaseTestUsersMixin, - mem_test_base.MemoryDBTestBase, absltest.TestCase): +class MemoryDBUsersTest( + db_users_test.DatabaseTestUsersMixin, + mem_test_base.MemoryDBTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mem_yara_test.py b/grr/server/grr_response_server/databases/mem_yara_test.py index fce5471b8c..83c419e4b5 100644 --- a/grr/server/grr_response_server/databases/mem_yara_test.py +++ b/grr/server/grr_response_server/databases/mem_yara_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl.testing import absltest from grr_response_server.databases import db_yara_test_lib diff --git a/grr/server/grr_response_server/databases/mysql.py b/grr/server/grr_response_server/databases/mysql.py index a545db2dba..15e5ae8574 100644 --- a/grr/server/grr_response_server/databases/mysql.py +++ b/grr/server/grr_response_server/databases/mysql.py @@ -63,7 +63,9 @@ COLLATION = "utf8mb4_unicode_ci" CREATE_DATABASE_QUERY = ( "CREATE DATABASE `{}` CHARACTER SET {} COLLATE {}".format( - "{}", CHARACTER_SET, COLLATION)) # Keep first placeholder for later. + "{}", CHARACTER_SET, COLLATION + ) # Keep first placeholder for later. +) def _IsRetryable(error): @@ -120,15 +122,21 @@ def _CheckCollation(cursor): cur_collation_connection = _ReadVariable("collation_connection", cursor) if cur_collation_connection != COLLATION: - logging.warning("Require MySQL collation_connection of %s, got %s.", - COLLATION, cur_collation_connection) + logging.warning( + "Require MySQL collation_connection of %s, got %s.", + COLLATION, + cur_collation_connection, + ) cur_collation_database = _ReadVariable("collation_database", cursor) if cur_collation_database != COLLATION: logging.warning( "Require MySQL collation_database of %s, got %s." - " To create your database, use: %s", COLLATION, cur_collation_database, - CREATE_DATABASE_QUERY) + " To create your database, use: %s", + COLLATION, + cur_collation_database, + CREATE_DATABASE_QUERY, + ) def _SetEncoding(cursor): @@ -159,7 +167,9 @@ def _CheckConnectionEncoding(cursor): if cur_character_set != CHARACTER_SET: raise EncodingEnforcementError( "Require MySQL character_set_connection of {}, got {}.".format( - CHARACTER_SET, cur_character_set)) + CHARACTER_SET, cur_character_set + ) + ) def _CheckDatabaseEncoding(cursor): @@ -168,9 +178,10 @@ def _CheckDatabaseEncoding(cursor): if cur_character_set != CHARACTER_SET: raise EncodingEnforcementError( "Require MySQL character_set_database of {}, got {}." 
- " To create your database, use: {}".format(CHARACTER_SET, - cur_character_set, - CREATE_DATABASE_QUERY)) + " To create your database, use: {}".format( + CHARACTER_SET, cur_character_set, CREATE_DATABASE_QUERY + ) + ) def _SetPacketSizeForFollowingConnections(cursor): @@ -180,16 +191,20 @@ def _SetPacketSizeForFollowingConnections(cursor): if cur_packet_size < MAX_PACKET_SIZE: logging.warning( "MySQL max_allowed_packet of %d is required, got %d. Overwriting.", - MAX_PACKET_SIZE, cur_packet_size) + MAX_PACKET_SIZE, + cur_packet_size, + ) try: _SetGlobalVariable("max_allowed_packet", MAX_PACKET_SIZE, cursor) except MySQLdb.OperationalError as e: logging.error(e) - msg = ("Failed to override max_allowed_packet setting. " - "max_allowed_packet must be < %d. Please update MySQL " - "configuration or grant GRR sufficient privileges to " - "override global variables." % MAX_PACKET_SIZE) + msg = ( + "Failed to override max_allowed_packet setting. " + "max_allowed_packet must be < %d. Please update MySQL " + "configuration or grant GRR sufficient privileges to " + "override global variables." % MAX_PACKET_SIZE + ) logging.error(msg) raise MaxAllowedPacketSettingTooLowError(msg) @@ -201,7 +216,9 @@ def _CheckPacketSize(cursor): raise Error( "MySQL max_allowed_packet of {0} is required, got {1}. " "Please set max_allowed_packet={0} in your MySQL config.".format( - MAX_PACKET_SIZE, cur_packet_size)) + MAX_PACKET_SIZE, cur_packet_size + ) + ) def _CheckLogFileSize(cursor): @@ -220,8 +237,11 @@ def _CheckLogFileSize(cursor): max_blob_size_mib = max_blob_size / 2**20 logging.warning( "MySQL innodb_log_file_size of %d is required, got %d. " - "Storing Blobs bigger than %.4f MiB will fail.", required_size, - innodb_log_file_size, max_blob_size_mib) + "Storing Blobs bigger than %.4f MiB will fail.", + required_size, + innodb_log_file_size, + max_blob_size_mib, + ) def _IsMariaDB(cursor): @@ -254,8 +274,10 @@ def _SetMariaDBMode(cursor): # encounters duplicate keys. This flag disables this behavior for # consistency. if _IsMariaDB(cursor): - cursor.execute("SET @@OLD_MODE = CONCAT(@@OLD_MODE, " - "',NO_DUP_KEY_WARNINGS_WITH_IGNORE')") + cursor.execute( + "SET @@OLD_MODE = CONCAT(@@OLD_MODE, " + "',NO_DUP_KEY_WARNINGS_WITH_IGNORE')" + ) def _CheckForSSL(cursor): @@ -270,14 +292,16 @@ def _CheckForSSL(cursor): raise RuntimeError("Unable to establish SSL connection to MySQL.") -def _SetupDatabase(host=None, - port=None, - user=None, - password=None, - database=None, - client_key_path=None, - client_cert_path=None, - ca_cert_path=None): +def _SetupDatabase( + host=None, + port=None, + user=None, + password=None, + database=None, + client_key_path=None, + client_cert_path=None, + ca_cert_path=None, +): """Connect to the given MySQL host and create a utf8mb4_unicode_ci database. 
Args: @@ -301,7 +325,9 @@ def _SetupDatabase(host=None, database=None, client_key_path=client_key_path, client_cert_path=client_cert_path, - ca_cert_path=ca_cert_path)) as conn: + ca_cert_path=ca_cert_path, + ) + ) as conn: with contextlib.closing(conn.cursor()) as cursor: try: cursor.execute(CREATE_DATABASE_QUERY.format(database)) @@ -322,20 +348,24 @@ def _MigrationConnect(): database=database, client_key_path=client_key_path, client_cert_path=client_cert_path, - ca_cert_path=ca_cert_path) + ca_cert_path=ca_cert_path, + ) - mysql_migration.ProcessMigrations(_MigrationConnect, - config.CONFIG["Mysql.migrations_dir"]) + mysql_migration.ProcessMigrations( + _MigrationConnect, config.CONFIG["Mysql.migrations_dir"] + ) -def _GetConnectionArgs(host=None, - port=None, - user=None, - password=None, - database=None, - client_key_path=None, - client_cert_path=None, - ca_cert_path=None): +def _GetConnectionArgs( + host=None, + port=None, + user=None, + password=None, + database=None, + client_key_path=None, + client_cert_path=None, + ca_cert_path=None, +): """Builds connection arguments for MySQLdb.Connect function.""" connection_args = dict( autocommit=False, @@ -368,14 +398,16 @@ def _GetConnectionArgs(host=None, return connection_args -def _Connect(host=None, - port=None, - user=None, - password=None, - database=None, - client_key_path=None, - client_cert_path=None, - ca_cert_path=None): +def _Connect( + host=None, + port=None, + user=None, + password=None, + database=None, + client_key_path=None, + client_cert_path=None, + ca_cert_path=None, +): """Connect to MySQL and check if server fulfills requirements.""" connection_args = _GetConnectionArgs( host=host, @@ -385,7 +417,8 @@ def _Connect(host=None, database=database, client_key_path=client_key_path, client_cert_path=client_cert_path, - ca_cert_path=ca_cert_path) + ca_cert_path=ca_cert_path, + ) conn = MySQLdb.Connect(**connection_args) with contextlib.closing(conn.cursor()) as cursor: diff --git a/grr/server/grr_response_server/databases/mysql_artifacts_test.py b/grr/server/grr_response_server/databases/mysql_artifacts_test.py index eb570d2133..ed083a7623 100644 --- a/grr/server/grr_response_server/databases/mysql_artifacts_test.py +++ b/grr/server/grr_response_server/databases/mysql_artifacts_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlArtifactsTest(db_artifacts_test.DatabaseTestArtifactsMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlArtifactsTest( + db_artifacts_test.DatabaseTestArtifactsMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_blob_keys.py b/grr/server/grr_response_server/databases/mysql_blob_keys.py index 660cd1e044..355c18595d 100644 --- a/grr/server/grr_response_server/databases/mysql_blob_keys.py +++ b/grr/server/grr_response_server/databases/mysql_blob_keys.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with MySQL implementation of blobstore encryption keys methods.""" + from __future__ import annotations from typing import Collection diff --git a/grr/server/grr_response_server/databases/mysql_blob_references_test.py b/grr/server/grr_response_server/databases/mysql_blob_references_test.py index e1010315e4..884c69623a 100644 --- a/grr/server/grr_response_server/databases/mysql_blob_references_test.py +++ b/grr/server/grr_response_server/databases/mysql_blob_references_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env 
python - from absl import app from absl.testing import absltest @@ -10,7 +9,9 @@ class MysqlBlobReferencesTest( db_blob_references_test.DatabaseTestBlobReferencesMixin, - mysql_test.MysqlTestBase, absltest.TestCase): + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_blobs.py b/grr/server/grr_response_server/databases/mysql_blobs.py index 43ca625734..2aef836145 100644 --- a/grr/server/grr_response_server/databases/mysql_blobs.py +++ b/grr/server/grr_response_server/databases/mysql_blobs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """The MySQL database methods for blobs handling.""" + from typing import Collection, Mapping, Optional import MySQLdb.cursors @@ -34,14 +35,18 @@ def _Insert(cursor, table, values): column_names = list(sorted(values[0])) for value_dict in values: if set(column_names) != set(value_dict): - raise ValueError("Given value dictionaries must have identical keys. " - "Expecting columns {!r}, but got value {!r}".format( - column_names, value_dict)) + raise ValueError( + "Given value dictionaries must have identical keys. " + "Expecting columns {!r}, but got value {!r}".format( + column_names, value_dict + ) + ) query = "INSERT IGNORE INTO %s {cols} VALUES {vals}" % table query = query.format( cols=mysql_utils.Columns(column_names), - vals=mysql_utils.Placeholders(num=len(column_names), values=len(values))) + vals=mysql_utils.Placeholders(num=len(column_names), values=len(values)), + ) values_list = [] for values_dict in values: @@ -59,7 +64,7 @@ def _BlobToChunks(blob_id, blob): chunks.append({ "blob_id": blob_id, "chunk_index": i, - "blob_chunk": blob[chunk_begin:chunk_begin + BLOB_CHUNK_SIZE] + "blob_chunk": blob[chunk_begin : chunk_begin + BLOB_CHUNK_SIZE], }) return chunks @@ -71,8 +76,10 @@ def _PartitionChunks(chunks): for chunk in chunks: cursize = len(chunk["blob_chunk"]) - if (cursize + partition_size > BLOB_CHUNK_SIZE or - len(partitions[-1]) >= CHUNKS_PER_INSERT): + if ( + cursize + partition_size > BLOB_CHUNK_SIZE + or len(partitions[-1]) >= CHUNKS_PER_INSERT + ): partitions.append([]) partition_size = 0 partitions[-1].append(chunk) @@ -102,12 +109,13 @@ def ReadBlobs(self, blob_ids, cursor=None): if not blob_ids: return {} - query = ("SELECT blob_id, blob_chunk " - "FROM blobs " - "FORCE INDEX (PRIMARY) " - "WHERE blob_id IN {} " - "ORDER BY blob_id, chunk_index ASC").format( - mysql_utils.Placeholders(len(blob_ids))) + query = ( + "SELECT blob_id, blob_chunk " + "FROM blobs " + "FORCE INDEX (PRIMARY) " + "WHERE blob_id IN {} " + "ORDER BY blob_id, chunk_index ASC" + ).format(mysql_utils.Placeholders(len(blob_ids))) cursor.execute(query, [bytes(blob_id) for blob_id in blob_ids]) results = {blob_id: None for blob_id in blob_ids} for blob_id_bytes, blob in cursor.fetchall(): @@ -125,11 +133,12 @@ def CheckBlobsExist(self, blob_ids, cursor=None): return {} exists = {blob_id: False for blob_id in blob_ids} - query = ("SELECT blob_id " - "FROM blobs " - "FORCE INDEX (PRIMARY) " - "WHERE blob_id IN {}".format( - mysql_utils.Placeholders(len(blob_ids)))) + query = ( + "SELECT blob_id " + "FROM blobs " + "FORCE INDEX (PRIMARY) " + "WHERE blob_id IN {}".format(mysql_utils.Placeholders(len(blob_ids))) + ) cursor.execute(query, [bytes(blob_id) for blob_id in blob_ids]) for (blob_id_bytes,) in cursor.fetchall(): exists[blobs.BlobID(blob_id_bytes)] = True @@ -167,8 +176,10 @@ def ReadHashBlobReferences( if not hashes: return {} - query = ("SELECT hash_id, blob_references FROM hash_blob_references WHERE " - 
"hash_id IN {}").format(mysql_utils.Placeholders(len(hashes))) + query = ( + "SELECT hash_id, blob_references FROM hash_blob_references WHERE " + "hash_id IN {}" + ).format(mysql_utils.Placeholders(len(hashes))) cursor.execute(query, [hash_id.AsBytes() for hash_id in hashes]) results = {hash_id: None for hash_id in hashes} for hash_id, blob_references in cursor.fetchall(): diff --git a/grr/server/grr_response_server/databases/mysql_blobs_test.py b/grr/server/grr_response_server/databases/mysql_blobs_test.py index 224fb6035f..759cbce4cd 100644 --- a/grr/server/grr_response_server/databases/mysql_blobs_test.py +++ b/grr/server/grr_response_server/databases/mysql_blobs_test.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for the MySQL-based blob store.""" - from absl import app from absl.testing import absltest @@ -11,8 +10,11 @@ from grr.test_lib import test_lib -class MySQLBlobStoreTest(blob_store_test_mixin.BlobStoreTestMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MySQLBlobStoreTest( + blob_store_test_mixin.BlobStoreTestMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): def CreateBlobStore(self): db, db_cleanup_fn = self.CreateDatabase() diff --git a/grr/server/grr_response_server/databases/mysql_clients.py b/grr/server/grr_response_server/databases/mysql_clients.py index 78e1317644..09f214ead6 100644 --- a/grr/server/grr_response_server/databases/mysql_clients.py +++ b/grr/server/grr_response_server/databases/mysql_clients.py @@ -1,7 +1,8 @@ #!/usr/bin/env python """The MySQL database methods for client handling.""" + import itertools -from typing import Collection, Mapping, Optional, Sequence, Tuple +from typing import Collection, Iterator, Mapping, Optional, Sequence, Tuple import MySQLdb from MySQLdb.constants import ER as mysql_error_constants @@ -9,7 +10,6 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client as rdf_client -from grr_response_core.lib.rdfvalues import crypto as rdf_crypto from grr_response_core.lib.rdfvalues import search as rdf_search from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 @@ -28,7 +28,6 @@ class MySQLDBClientMixin(object): def MultiWriteClientMetadata( self, client_ids: Collection[str], - certificate: Optional[rdf_crypto.RDFX509Cert] = None, first_seen: Optional[rdfvalue.RDFDatetime] = None, last_ping: Optional[rdfvalue.RDFDatetime] = None, last_clock: Optional[rdfvalue.RDFDatetime] = None, @@ -49,10 +48,6 @@ def MultiWriteClientMetadata( for i, client_id in enumerate(client_ids): values[f"client_id{i}"] = db_utils.ClientIDToInt(client_id) - if certificate is not None: - column_names.append("certificate") - common_placeholders.append("%(certificate)s") - values["certificate"] = certificate.SerializeToBytes() if first_seen is not None: column_names.append("first_seen") common_placeholders.append("FROM_UNIXTIME(%(first_seen)s)") @@ -182,10 +177,12 @@ def WriteClientSnapshot( insert_history_query = ( "INSERT INTO client_snapshot_history(client_id, timestamp, " - "client_snapshot) VALUES (%s, @now, %s)") + "client_snapshot) VALUES (%s, @now, %s)" + ) insert_startup_query = ( "INSERT INTO client_startup_history(client_id, timestamp, " - "startup_info) VALUES(%s, @now, %s)") + "startup_info) VALUES(%s, @now, %s)" + ) client_info = { "last_platform": snapshot.knowledge_base.os, @@ -198,7 +195,9 @@ def WriteClientSnapshot( update_query = ( "UPDATE clients SET {} WHERE client_id = %(client_id)s".format( - ", ".join(update_clauses))) + ", ".join(update_clauses) + 
) + ) int_client_id = db_utils.ClientIDToInt(snapshot.client_id) client_info["client_id"] = int_client_id @@ -247,7 +246,8 @@ def MultiReadClientSnapshot( "AND s.client_id = c.client_id " "AND h.timestamp = c.last_snapshot_timestamp " "AND s.timestamp = c.last_startup_timestamp " - "AND c.client_id IN ({})").format(", ".join(["%s"] * len(client_ids))) + "AND c.client_id IN ({})" + ).format(", ".join(["%s"] * len(client_ids))) ret = {cid: None for cid in client_ids} cursor.execute(query, int_ids) @@ -288,13 +288,15 @@ def ReadClientSnapshotHistory( client_id_int = db_utils.ClientIDToInt(client_id) - query = ("SELECT sn.client_snapshot, st.startup_info, " - " UNIX_TIMESTAMP(sn.timestamp) FROM " - "client_snapshot_history AS sn, " - "client_startup_history AS st WHERE " - "sn.client_id = st.client_id AND " - "sn.timestamp = st.timestamp AND " - "sn.client_id=%s ") + query = ( + "SELECT sn.client_snapshot, st.startup_info, " + " UNIX_TIMESTAMP(sn.timestamp) FROM " + "client_snapshot_history AS sn, " + "client_startup_history AS st WHERE " + "sn.client_id = st.client_id AND " + "sn.timestamp = st.timestamp AND " + "sn.client_id=%s " + ) args = [client_id_int] if timerange: @@ -345,14 +347,18 @@ def WriteClientStartupInfo( (client_id, timestamp, startup_info) VALUES (%(client_id)s, @now, %(startup_info)s) - """, params) + """, + params, + ) cursor.execute( """ UPDATE clients SET last_startup_timestamp = @now WHERE client_id = %(client_id)s - """, params) + """, + params, + ) except MySQLdb.IntegrityError as e: raise db.UnknownClientError(client_id, cause=e) @@ -585,10 +591,12 @@ def MultiReadClientFullInfo( cursor.execute(query, values) return dict(self._ResponseToClientsFullInfo(cursor.fetchall())) - def ReadClientLastPings(self, - min_last_ping=None, - max_last_ping=None, - batch_size=db.CLIENT_IDS_BATCH_SIZE): + def ReadClientLastPings( + self, + min_last_ping: Optional[rdfvalue.RDFDatetime] = None, + max_last_ping: Optional[rdfvalue.RDFDatetime] = None, + batch_size: int = db.CLIENT_IDS_BATCH_SIZE, + ) -> Iterator[Mapping[str, Optional[rdfvalue.RDFDatetime]]]: """Yields dicts of last-ping timestamps for clients in the DB.""" last_client_id = db_utils.IntToClientID(0) @@ -605,12 +613,14 @@ def ReadClientLastPings(self, break @mysql_utils.WithTransaction(readonly=True) - def _ReadClientLastPings(self, - last_client_id, - count, - min_last_ping=None, - max_last_ping=None, - cursor=None): + def _ReadClientLastPings( + self, + last_client_id: str, + count: int, + min_last_ping: rdfvalue.RDFDatetime = None, + max_last_ping: rdfvalue.RDFDatetime = None, + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> Tuple[str, Mapping[str, Optional[rdfvalue.RDFDatetime]]]: """Yields dicts of last-ping timestamps for clients in the DB.""" where_filters = ["client_id > %s"] query_values = [db_utils.ClientIDToInt(last_client_id)] @@ -619,7 +629,8 @@ def _ReadClientLastPings(self, query_values.append(mysql_utils.RDFDatetimeToTimestamp(min_last_ping)) if max_last_ping is not None: where_filters.append( - "(last_ping IS NULL OR last_ping <= FROM_UNIXTIME(%s))") + "(last_ping IS NULL OR last_ping <= FROM_UNIXTIME(%s))" + ) query_values.append(mysql_utils.RDFDatetimeToTimestamp(max_last_ping)) query = """ @@ -661,9 +672,7 @@ def MultiAddClientKeywords( INSERT INTO client_keywords (client_id, keyword_hash, keyword) VALUES {} ON DUPLICATE KEY UPDATE timestamp = NOW(6) - """.format( - ", ".join(["(%s, %s, %s)"] * len(args)) - ) + """.format(", ".join(["(%s, %s, %s)"] * len(args))) try: cursor.execute(query, 
list(itertools.chain.from_iterable(args))) except MySQLdb.IntegrityError as error: @@ -680,8 +689,8 @@ def RemoveClientKeyword( cursor.execute( "DELETE FROM client_keywords " "WHERE client_id = %s AND keyword_hash = %s", - [db_utils.ClientIDToInt(client_id), - mysql_utils.Hash(keyword)]) + [db_utils.ClientIDToInt(client_id), mysql_utils.Hash(keyword)], + ) @mysql_utils.WithTransaction(readonly=True) def ListClientsForKeywords( @@ -755,10 +764,11 @@ def MultiReadClientLabels( """Reads the user labels for a list of clients.""" int_ids = [db_utils.ClientIDToInt(cid) for cid in client_ids] - query = ("SELECT client_id, owner_username, label " - "FROM client_labels " - "WHERE client_id IN ({})").format(", ".join(["%s"] * - len(client_ids))) + query = ( + "SELECT client_id, owner_username, label " + "FROM client_labels " + "WHERE client_id IN ({})" + ).format(", ".join(["%s"] * len(client_ids))) ret = {client_id: [] for client_id in client_ids} cursor.execute(query, int_ids) @@ -782,13 +792,18 @@ def RemoveClientLabels( ) -> None: """Removes a list of user labels from a given client.""" - query = ("DELETE FROM client_labels " - "WHERE client_id = %s AND owner_username_hash = %s " - "AND label IN ({})").format(", ".join(["%s"] * len(labels))) - args = itertools.chain([ - db_utils.ClientIDToInt(client_id), - mysql_utils.Hash(owner), - ], labels) + query = ( + "DELETE FROM client_labels " + "WHERE client_id = %s AND owner_username_hash = %s " + "AND label IN ({})" + ).format(", ".join(["%s"] * len(labels))) + args = itertools.chain( + [ + db_utils.ClientIDToInt(client_id), + mysql_utils.Hash(owner), + ], + labels, + ) cursor.execute(query, args) @mysql_utils.WithTransaction(readonly=True) @@ -823,14 +838,18 @@ def WriteClientCrashInfo( """ INSERT INTO client_crash_history (client_id, timestamp, crash_info) VALUES (%(client_id)s, @now, %(crash_info)s) - """, params) + """, + params, + ) cursor.execute( """ UPDATE clients SET last_crash_timestamp = @now WHERE client_id = %(client_id)s - """, params) + """, + params, + ) except MySQLdb.IntegrityError as e: raise db.UnknownClientError(client_id, cause=e) @@ -847,7 +866,9 @@ def ReadClientCrashInfo( "FROM clients, client_crash_history WHERE " "clients.client_id = client_crash_history.client_id AND " "clients.last_crash_timestamp = client_crash_history.timestamp AND " - "clients.client_id = %s", [db_utils.ClientIDToInt(client_id)]) + "clients.client_id = %s", + [db_utils.ClientIDToInt(client_id)], + ) row = cursor.fetchone() if not row: return None @@ -869,7 +890,9 @@ def ReadClientCrashInfoHistory( "SELECT UNIX_TIMESTAMP(timestamp), crash_info " "FROM client_crash_history WHERE " "client_crash_history.client_id = %s " - "ORDER BY timestamp DESC", [db_utils.ClientIDToInt(client_id)]) + "ORDER BY timestamp DESC", + [db_utils.ClientIDToInt(client_id)], + ) ret = [] for timestamp, crash_info in cursor.fetchall(): ci = jobs_pb2.ClientCrash() @@ -885,8 +908,10 @@ def DeleteClient( cursor: Optional[MySQLdb.cursors.Cursor] = None, ) -> None: """Deletes a client with all associated metadata.""" - cursor.execute("SELECT COUNT(*) FROM clients WHERE client_id = %s", - [db_utils.ClientIDToInt(client_id)]) + cursor.execute( + "SELECT COUNT(*) FROM clients WHERE client_id = %s", + [db_utils.ClientIDToInt(client_id)], + ) if cursor.fetchone()[0] == 0: raise db.UnknownClientError(client_id) @@ -898,15 +923,22 @@ def DeleteClient( last_crash_timestamp = NULL, last_snapshot_timestamp = NULL, last_startup_timestamp = NULL - WHERE client_id = %s""", 
[db_utils.ClientIDToInt(client_id)]) + WHERE client_id = %s""", + [db_utils.ClientIDToInt(client_id)], + ) - cursor.execute("DELETE FROM clients WHERE client_id = %s", - [db_utils.ClientIDToInt(client_id)]) + cursor.execute( + "DELETE FROM clients WHERE client_id = %s", + [db_utils.ClientIDToInt(client_id)], + ) - def StructuredSearchClients(self, expression: rdf_search.SearchExpression, - sort_order: rdf_search.SortOrder, - continuation_token: bytes, - number_of_results: int) -> db.SearchClientsResult: + def StructuredSearchClients( + self, + expression: rdf_search.SearchExpression, + sort_order: rdf_search.SortOrder, + continuation_token: bytes, + number_of_results: int, + ) -> db.SearchClientsResult: # Unused arguments del self, expression, sort_order, continuation_token, number_of_results raise NotImplementedError diff --git a/grr/server/grr_response_server/databases/mysql_clients_test.py b/grr/server/grr_response_server/databases/mysql_clients_test.py index add066710b..c6a2ee9c8b 100644 --- a/grr/server/grr_response_server/databases/mysql_clients_test.py +++ b/grr/server/grr_response_server/databases/mysql_clients_test.py @@ -7,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlClientsTest(db_clients_test.DatabaseTestClientsMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlClientsTest( + db_clients_test.DatabaseTestClientsMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): # TODO: Enforce foreign key constraint on the `users` table. def testMultiAddClientLabelsUnknownUser(self): diff --git a/grr/server/grr_response_server/databases/mysql_cronjob_test.py b/grr/server/grr_response_server/databases/mysql_cronjob_test.py index 0bdaecc2e3..25eb4d4562 100644 --- a/grr/server/grr_response_server/databases/mysql_cronjob_test.py +++ b/grr/server/grr_response_server/databases/mysql_cronjob_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlCronJobTest(db_cronjob_test.DatabaseTestCronJobMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlCronJobTest( + db_cronjob_test.DatabaseTestCronJobMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_cronjobs.py b/grr/server/grr_response_server/databases/mysql_cronjobs.py index 3f5976fa37..e95219f361 100644 --- a/grr/server/grr_response_server/databases/mysql_cronjobs.py +++ b/grr/server/grr_response_server/databases/mysql_cronjobs.py @@ -25,11 +25,13 @@ def WriteCronJob( cursor: Optional[MySQLdb.cursors.Cursor] = None, ) -> None: """Writes a cronjob to the database.""" - query = ("INSERT INTO cron_jobs " - "(job_id, job, create_time, enabled) " - "VALUES (%s, %s, FROM_UNIXTIME(%s), %s) " - "ON DUPLICATE KEY UPDATE " - "enabled=VALUES(enabled)") + query = ( + "INSERT INTO cron_jobs " + "(job_id, job, create_time, enabled) " + "VALUES (%s, %s, FROM_UNIXTIME(%s), %s) " + "ON DUPLICATE KEY UPDATE " + "enabled=VALUES(enabled)" + ) create_time_str = mysql_utils.RDFDatetimeToTimestamp( rdfvalue.RDFDatetime().FromMicrosecondsSinceEpoch(cronjob.created_at) @@ -173,30 +175,32 @@ def DeleteCronJob( cursor.execute(query, args) @mysql_utils.WithTransaction() - def UpdateCronJob(self, - cronjob_id, - last_run_status=db.Database.unchanged, - last_run_time=db.Database.unchanged, - current_run_id=db.Database.unchanged, - state=db.Database.unchanged, - forced_run_requested=db.Database.unchanged, - cursor=None): + def UpdateCronJob( + self, + 
cronjob_id, + last_run_status=db.Database.UNCHANGED, + last_run_time=db.Database.UNCHANGED, + current_run_id=db.Database.UNCHANGED, + state=db.Database.UNCHANGED, + forced_run_requested=db.Database.UNCHANGED, + cursor=None, + ): """Updates run information for an existing cron job.""" updates = [] args = [] - if last_run_status != db.Database.unchanged: + if last_run_status != db.Database.UNCHANGED: updates.append("last_run_status=%s") args.append(int(last_run_status)) - if last_run_time != db.Database.unchanged: + if last_run_time != db.Database.UNCHANGED: updates.append("last_run_time=FROM_UNIXTIME(%s)") args.append(mysql_utils.RDFDatetimeToTimestamp(last_run_time)) - if current_run_id != db.Database.unchanged: + if current_run_id != db.Database.UNCHANGED: updates.append("current_run_id=%s") args.append(db_utils.CronJobRunIDToInt(current_run_id)) - if state != db.Database.unchanged: + if state != db.Database.UNCHANGED: updates.append("state=%s") args.append(state.SerializeToString()) - if forced_run_requested != db.Database.unchanged: + if forced_run_requested != db.Database.UNCHANGED: updates.append("forced_run_requested=%s") args.append(forced_run_requested) @@ -322,7 +326,8 @@ def WriteCronJobRun( ) except MySQLdb.IntegrityError as e: raise db.UnknownCronJobError( - "CronJob with id %s not found." % run_object.cron_job_id, cause=e) + "CronJob with id %s not found." % run_object.cron_job_id, cause=e + ) def _CronJobRunFromRow( self, row: Tuple[bytes, float] diff --git a/grr/server/grr_response_server/databases/mysql_events.py b/grr/server/grr_response_server/databases/mysql_events.py index 49ca074aaa..2200b4c868 100644 --- a/grr/server/grr_response_server/databases/mysql_events.py +++ b/grr/server/grr_response_server/databases/mysql_events.py @@ -1,14 +1,21 @@ #!/usr/bin/env python """The MySQL database methods for event handling.""" +from typing import Optional + +import MySQLdb + from grr_response_core.lib import rdfvalue +from grr_response_proto import objects_pb2 from grr_response_server.databases import mysql_utils -from grr_response_server.rdfvalues import objects as rdf_objects -def _AuditEntryFromRow(details, timestamp): - entry = rdf_objects.APIAuditEntry.FromSerializedBytes(details) - entry.timestamp = mysql_utils.TimestampToRDFDatetime(timestamp) +def _AuditEntryFromRow( + details: bytes, timestamp: float +) -> objects_pb2.APIAuditEntry: + entry = objects_pb2.APIAuditEntry() + entry.ParseFromString(details) + entry.timestamp = mysql_utils.TimestampToMicrosecondsSinceEpoch(timestamp) return entry @@ -16,12 +23,14 @@ class MySQLDBEventMixin(object): """MySQLDB mixin for event handling.""" @mysql_utils.WithTransaction(readonly=True) - def ReadAPIAuditEntries(self, - username=None, - router_method_names=None, - min_timestamp=None, - max_timestamp=None, - cursor=None): + def ReadAPIAuditEntries( + self, + username: Optional[str] = None, + router_method_names: Optional[list[str]] = None, + min_timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> list[objects_pb2.APIAuditEntry]: """Returns audit entries stored in the database.""" query = """SELECT details, UNIX_TIMESTAMP(timestamp) @@ -30,6 +39,7 @@ def ReadAPIAuditEntries(self, {WHERE_PLACEHOLDER} ORDER BY timestamp ASC """ + assert cursor is not None conditions = [] values = [] @@ -65,11 +75,14 @@ def ReadAPIAuditEntries(self, ] @mysql_utils.WithTransaction() - def CountAPIAuditEntriesByUserAndDay(self, - 
min_timestamp=None, - max_timestamp=None, - cursor=None): + def CountAPIAuditEntriesByUserAndDay( + self, + min_timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> dict[tuple[str, rdfvalue.RDFDatetime], int]: """Returns audit entry counts grouped by user and calendar day.""" + assert cursor is not None query = """ -- Timestamps are timezone-agnostic whereas dates are not. Hence, we are @@ -102,22 +115,30 @@ def CountAPIAuditEntriesByUserAndDay(self, query = query.replace("{WHERE_PLACEHOLDER}", where) cursor.execute(query, values) - return {(username, rdfvalue.RDFDatetime.FromDate(day)): count - for (username, day, count) in cursor.fetchall()} + return { + (username, rdfvalue.RDFDatetime.FromDate(day)): count + for (username, day, count) in cursor.fetchall() + } @mysql_utils.WithTransaction() - def WriteAPIAuditEntry(self, entry: rdf_objects.APIAuditEntry, cursor=None): + def WriteAPIAuditEntry( + self, + entry: objects_pb2.APIAuditEntry, + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> None: """Writes an audit entry to the database.""" - if entry.timestamp is None: - datetime = rdfvalue.RDFDatetime.Now() + assert cursor is not None + + if not entry.HasField("timestamp"): + datetime = rdfvalue.RDFDatetime.Now().AsMicrosecondsSinceEpoch() else: datetime = entry.timestamp args = { "username": entry.username, "router_method_name": entry.router_method_name, - "details": entry.SerializeToBytes(), - "timestamp": mysql_utils.RDFDatetimeToTimestamp(datetime), + "details": entry.SerializeToString(), + "timestamp": mysql_utils.MicrosecondsSinceEpochToTimestamp(datetime), } query = """ INSERT INTO api_audit_entry (username, router_method_name, details, diff --git a/grr/server/grr_response_server/databases/mysql_events_test.py b/grr/server/grr_response_server/databases/mysql_events_test.py index 41c615303c..a2d1521813 100644 --- a/grr/server/grr_response_server/databases/mysql_events_test.py +++ b/grr/server/grr_response_server/databases/mysql_events_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlEventsTest(db_events_test.DatabaseTestEventsMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlEventsTest( + db_events_test.DatabaseTestEventsMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_flows.py b/grr/server/grr_response_server/databases/mysql_flows.py index 261bc18a38..7f25942a1f 100644 --- a/grr/server/grr_response_server/databases/mysql_flows.py +++ b/grr/server/grr_response_server/databases/mysql_flows.py @@ -4,6 +4,9 @@ import logging import threading import time +from typing import AbstractSet +from typing import Callable +from typing import Collection from typing import Dict from typing import Iterable from typing import List @@ -11,8 +14,10 @@ from typing import Optional from typing import Sequence from typing import Text +from typing import Tuple from typing import Type from typing import TypeVar +from typing import Union import MySQLdb from MySQLdb import cursors @@ -21,18 +26,17 @@ from google.protobuf import any_pb2 from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils -from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_core.lib.util import 
collection from grr_response_core.lib.util import random from grr_response_proto import flows_pb2 +from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 from grr_response_server.databases import db from grr_response_server.databases import db_utils from grr_response_server.databases import mysql_utils from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects -from grr_response_server.rdfvalues import objects as rdf_objects T = TypeVar("T") @@ -42,42 +46,62 @@ class MySQLDBFlowMixin(object): """MySQLDB mixin for flow handling.""" @mysql_utils.WithTransaction() - def WriteMessageHandlerRequests(self, requests, cursor=None): + def WriteMessageHandlerRequests( + self, + requests: Iterable[objects_pb2.MessageHandlerRequest], + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> None: """Writes a list of message handler requests to the database.""" - query = ("INSERT IGNORE INTO message_handler_requests " - "(handlername, request_id, request) VALUES ") + query = ( + "INSERT IGNORE INTO message_handler_requests " + "(handlername, request_id, request) VALUES " + ) value_templates = [] args = [] for r in requests: - args.extend([r.handler_name, r.request_id, r.SerializeToBytes()]) + args.extend([r.handler_name, r.request_id, r.SerializeToString()]) value_templates.append("(%s, %s, %s)") query += ",".join(value_templates) cursor.execute(query, args) @mysql_utils.WithTransaction(readonly=True) - def ReadMessageHandlerRequests(self, cursor=None): + def ReadMessageHandlerRequests( + self, + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> Sequence[objects_pb2.MessageHandlerRequest]: """Reads all message handler requests from the database.""" - query = ("SELECT UNIX_TIMESTAMP(timestamp), request," - " UNIX_TIMESTAMP(leased_until), leased_by " - "FROM message_handler_requests " - "ORDER BY timestamp DESC") + query = ( + "SELECT UNIX_TIMESTAMP(timestamp), request," + " UNIX_TIMESTAMP(leased_until), leased_by " + "FROM message_handler_requests " + "ORDER BY timestamp DESC" + ) cursor.execute(query) res = [] for timestamp, request, leased_until, leased_by in cursor.fetchall(): - req = rdf_objects.MessageHandlerRequest.FromSerializedBytes(request) - req.timestamp = mysql_utils.TimestampToRDFDatetime(timestamp) - req.leased_by = leased_by - req.leased_until = mysql_utils.TimestampToRDFDatetime(leased_until) + req = objects_pb2.MessageHandlerRequest() + req.ParseFromString(request) + req.timestamp = mysql_utils.TimestampToMicrosecondsSinceEpoch(timestamp) + if leased_by is not None: + req.leased_by = leased_by + if leased_until is not None: + req.leased_until = mysql_utils.TimestampToMicrosecondsSinceEpoch( + leased_until + ) res.append(req) return res @mysql_utils.WithTransaction() - def DeleteMessageHandlerRequests(self, requests, cursor=None): + def DeleteMessageHandlerRequests( + self, + requests: Iterable[objects_pb2.MessageHandlerRequest], + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> None: """Deletes a list of message handler requests from the database.""" query = "DELETE FROM message_handler_requests WHERE request_id IN ({})" @@ -85,7 +109,12 @@ def DeleteMessageHandlerRequests(self, requests, cursor=None): query = query.format(",".join(["%s"] * len(request_ids))) cursor.execute(query, request_ids) - def RegisterMessageHandler(self, handler, lease_time, limit=1000): + def RegisterMessageHandler( + self, + handler: 
Callable[[Sequence[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: """Leases a number of message handler requests up to the indicated limit.""" self.UnregisterMessageHandler() @@ -94,11 +123,14 @@ def RegisterMessageHandler(self, handler, lease_time, limit=1000): self.handler_thread = threading.Thread( name="message_handler", target=self._MessageHandlerLoop, - args=(handler, lease_time, limit)) + args=(handler, lease_time, limit), + ) self.handler_thread.daemon = True self.handler_thread.start() - def UnregisterMessageHandler(self, timeout=None): + def UnregisterMessageHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: """Unregisters any registered message handler.""" if self.handler_thread: self.handler_stop = True @@ -109,7 +141,13 @@ def UnregisterMessageHandler(self, timeout=None): _MESSAGE_HANDLER_POLL_TIME_SECS = 5 - def _MessageHandlerLoop(self, handler, lease_time, limit): + def _MessageHandlerLoop( + self, + handler: Callable[[Iterable[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: + """Loop to handle outstanding requests.""" while not self.handler_stop: try: msgs = self._LeaseMessageHandlerRequests(lease_time, limit) @@ -121,7 +159,12 @@ def _MessageHandlerLoop(self, handler, lease_time, limit): logging.exception("_LeaseMessageHandlerRequests raised %s.", e) @mysql_utils.WithTransaction() - def _LeaseMessageHandlerRequests(self, lease_time, limit, cursor=None): + def _LeaseMessageHandlerRequests( + self, + lease_time: rdfvalue.Duration, + limit: int = 1000, + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> Iterable[objects_pb2.MessageHandlerRequest]: """Leases a number of message handler requests up to the indicated limit.""" now = rdfvalue.RDFDatetime.Now() @@ -130,10 +173,12 @@ def _LeaseMessageHandlerRequests(self, lease_time, limit, cursor=None): expiry = now + lease_time expiry_str = mysql_utils.RDFDatetimeToTimestamp(expiry) - query = ("UPDATE message_handler_requests " - "SET leased_until=FROM_UNIXTIME(%s), leased_by=%s " - "WHERE leased_until IS NULL OR leased_until < FROM_UNIXTIME(%s) " - "LIMIT %s") + query = ( + "UPDATE message_handler_requests " + "SET leased_until=FROM_UNIXTIME(%s), leased_by=%s " + "WHERE leased_until IS NULL OR leased_until < FROM_UNIXTIME(%s) " + "LIMIT %s" + ) id_str = utils.ProcessIdString() args = (expiry_str, id_str, now_str, limit) @@ -146,19 +191,25 @@ def _LeaseMessageHandlerRequests(self, lease_time, limit, cursor=None): "SELECT UNIX_TIMESTAMP(timestamp), request " "FROM message_handler_requests " "WHERE leased_by=%s AND leased_until=FROM_UNIXTIME(%s) LIMIT %s", - (id_str, expiry_str, updated)) + (id_str, expiry_str, updated), + ) res = [] for timestamp, request in cursor.fetchall(): - req = rdf_objects.MessageHandlerRequest.FromSerializedBytes(request) - req.timestamp = mysql_utils.TimestampToRDFDatetime(timestamp) - req.leased_until = expiry + req = objects_pb2.MessageHandlerRequest() + req.ParseFromString(request) + req.timestamp = mysql_utils.TimestampToMicrosecondsSinceEpoch(timestamp) + req.leased_until = expiry.AsMicrosecondsSinceEpoch() req.leased_by = id_str res.append(req) - return res @mysql_utils.WithTransaction() - def WriteFlowObject(self, flow_obj, allow_update=True, cursor=None): + def WriteFlowObject( + self, + flow_obj: flows_pb2.Flow, + allow_update: bool = True, + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> None: """Writes a flow object to the 
database.""" query = """ @@ -182,9 +233,11 @@ def WriteFlowObject(self, flow_obj, allow_update=True, cursor=None): last_update=VALUES(last_update)""" user_cpu_time_used_micros = db_utils.SecondsToMicros( - flow_obj.cpu_time_used.user_cpu_time) + flow_obj.cpu_time_used.user_cpu_time + ) system_cpu_time_used_micros = db_utils.SecondsToMicros( - flow_obj.cpu_time_used.system_cpu_time) + flow_obj.cpu_time_used.system_cpu_time + ) args = { "client_id": db_utils.ClientIDToInt(flow_obj.client_id), @@ -192,7 +245,7 @@ def WriteFlowObject(self, flow_obj, allow_update=True, cursor=None): "long_flow_id": flow_obj.long_flow_id, "name": flow_obj.flow_class_name, "creator": flow_obj.creator, - "flow": flow_obj.SerializeToBytes(), + "flow": flow_obj.SerializeToString(), "flow_state": int(flow_obj.flow_state), "next_request_to_process": flow_obj.next_request_to_process, "network_bytes_sent": flow_obj.network_bytes_sent, @@ -219,7 +272,7 @@ def WriteFlowObject(self, flow_obj, allow_update=True, cursor=None): else: raise db.UnknownClientError(flow_obj.client_id, cause=e) - def _FlowObjectFromRow(self, row): + def _FlowObjectFromRow(self, row) -> flows_pb2.Flow: """Generates a flow object from a database row.""" datetime = mysql_utils.TimestampToRDFDatetime cpu_time = db_utils.MicrosToSeconds @@ -235,7 +288,8 @@ def _FlowObjectFromRow(self, row): timestamp, last_update_timestamp) = row # pyformat: enable - flow_obj = rdf_flow_objects.Flow.FromSerializedBytes(flow) + flow_obj = flows_pb2.Flow() + flow_obj.ParseFromString(flow) # We treat column values as the source of truth, not the proto. flow_obj.client_id = db_utils.IntToClientID(client_id) @@ -246,35 +300,50 @@ def _FlowObjectFromRow(self, row): flow_obj.parent_flow_id = db_utils.IntToFlowID(parent_flow_id) if parent_hunt_id is not None: flow_obj.parent_hunt_id = db_utils.IntToHuntID(parent_hunt_id) + if name is not None: flow_obj.flow_class_name = name if creator is not None: flow_obj.creator = creator if flow_state not in [None, rdf_flow_objects.Flow.FlowState.UNSET]: flow_obj.flow_state = flow_state - if client_crash_info is not None: - deserialize = rdf_client.ClientCrash.FromSerializedBytes - flow_obj.client_crash_info = deserialize(client_crash_info) if next_request_to_process: flow_obj.next_request_to_process = next_request_to_process - if processing_deadline is not None: - flow_obj.processing_deadline = datetime(processing_deadline) + + # In case the create time is not stored in the serialized flow (which might + # be the case), we fallback to the timestamp information stored in the + # column. 
+ if not flow_obj.HasField("create_time"): + flow_obj.create_time = datetime(timestamp).AsMicrosecondsSinceEpoch() + flow_obj.last_update_time = datetime( + last_update_timestamp + ).AsMicrosecondsSinceEpoch() + + if client_crash_info is not None: + flow_obj.client_crash_info.ParseFromString(client_crash_info) + + flow_obj.ClearField("processing_on") if processing_on is not None: flow_obj.processing_on = processing_on + + flow_obj.ClearField("processing_since") if processing_since is not None: - flow_obj.processing_since = datetime(processing_since) + flow_obj.processing_since = datetime( + processing_since + ).AsMicrosecondsSinceEpoch() + + flow_obj.ClearField("processing_deadline") + if processing_deadline is not None: + flow_obj.processing_deadline = datetime( + processing_deadline + ).AsMicrosecondsSinceEpoch() + flow_obj.cpu_time_used.user_cpu_time = cpu_time(user_cpu_time) flow_obj.cpu_time_used.system_cpu_time = cpu_time(system_cpu_time) flow_obj.network_bytes_sent = network_bytes_sent + if num_replies_sent: flow_obj.num_replies_sent = num_replies_sent - flow_obj.last_update_time = datetime(last_update_timestamp) - - # In case the create time is not stored in the serialized flow (which might - # be the case), we fallback to the timestamp information stored in the - # column. - if flow_obj.create_time is None: - flow_obj.create_time = datetime(timestamp) return flow_obj @@ -302,18 +371,25 @@ def _FlowObjectFromRow(self, row): )) @mysql_utils.WithTransaction(readonly=True) - def ReadFlowObject(self, client_id, flow_id, cursor=None): + def ReadFlowObject( + self, + client_id: str, + flow_id: str, + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> flows_pb2.Flow: """Reads a flow object from the database.""" - query = (f"SELECT {self.FLOW_DB_FIELDS} " - f"FROM flows WHERE client_id=%s AND flow_id=%s") + query = ( + f"SELECT {self.FLOW_DB_FIELDS} " + "FROM flows WHERE client_id=%s AND flow_id=%s" + ) cursor.execute( query, - [db_utils.ClientIDToInt(client_id), - db_utils.FlowIDToInt(flow_id)]) + [db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id)], + ) result = cursor.fetchall() if not result: raise db.UnknownFlowError(client_id, flow_id) - row, = result + (row,) = result return self._FlowObjectFromRow(row) @mysql_utils.WithTransaction(readonly=True) @@ -325,8 +401,8 @@ def ReadAllFlowObjects( max_create_time: Optional[rdfvalue.RDFDatetime] = None, include_child_flows: bool = True, not_created_by: Optional[Iterable[str]] = None, - cursor=None, - ) -> List[rdf_flow_objects.Flow]: + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> List[flows_pb2.Flow]: """Returns all flow objects.""" conditions = [] args = [] @@ -365,48 +441,58 @@ def ReadAllFlowObjects( return [self._FlowObjectFromRow(row) for row in cursor.fetchall()] @mysql_utils.WithTransaction() - def LeaseFlowForProcessing(self, - client_id, - flow_id, - processing_time, - cursor=None): + def LeaseFlowForProcessing( + self, + client_id: str, + flow_id: str, + processing_time: rdfvalue.Duration, + cursor: Optional[cursors.Cursor] = None, + ) -> flows_pb2.Flow: """Marks a flow as being processed on this worker and returns it.""" - query = (f"SELECT {self.FLOW_DB_FIELDS} " - f"FROM flows WHERE client_id=%s AND flow_id=%s") + query = ( + f"SELECT {self.FLOW_DB_FIELDS} " + "FROM flows WHERE client_id=%s AND flow_id=%s" + ) cursor.execute( query, - [db_utils.ClientIDToInt(client_id), - db_utils.FlowIDToInt(flow_id)]) + [db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id)], + ) response = 
cursor.fetchall() if not response: raise db.UnknownFlowError(client_id, flow_id) - row, = response - rdf_flow = self._FlowObjectFromRow(row) + (row,) = response + flow = self._FlowObjectFromRow(row) now = rdfvalue.RDFDatetime.Now() - if rdf_flow.processing_on and rdf_flow.processing_deadline > now: - raise ValueError("Flow %s on client %s is already being processed." % - (flow_id, client_id)) + if flow.processing_on and flow.processing_deadline > int(now): + raise ValueError( + "Flow %s on client %s is already being processed." + % (flow_id, client_id) + ) - if rdf_flow.parent_hunt_id is not None: + if flow.parent_hunt_id is not None: query = "SELECT hunt_state FROM hunts WHERE hunt_id=%s" - args = [db_utils.HuntIDToInt(rdf_flow.parent_hunt_id)] + args = [db_utils.HuntIDToInt(flow.parent_hunt_id)] rows_found = cursor.execute(query, args) if rows_found == 1: - hunt_state, = cursor.fetchone() - if (hunt_state is not None and - not rdf_hunt_objects.IsHuntSuitableForFlowProcessing(hunt_state)): - raise db.ParentHuntIsNotRunningError(client_id, flow_id, - rdf_flow.parent_hunt_id, - hunt_state) - - update_query = ("UPDATE flows SET " - "processing_on=%s, " - "processing_since=FROM_UNIXTIME(%s), " - "processing_deadline=FROM_UNIXTIME(%s) " - "WHERE client_id=%s and flow_id=%s") + (hunt_state,) = cursor.fetchone() + if ( + hunt_state is not None + and not rdf_hunt_objects.IsHuntSuitableForFlowProcessing(hunt_state) + ): + raise db.ParentHuntIsNotRunningError( + client_id, flow_id, flow.parent_hunt_id, hunt_state + ) + + update_query = ( + "UPDATE flows SET " + "processing_on=%s, " + "processing_since=FROM_UNIXTIME(%s), " + "processing_deadline=FROM_UNIXTIME(%s) " + "WHERE client_id=%s and flow_id=%s" + ) processing_deadline = now + processing_time process_id_string = utils.ProcessIdString() @@ -415,59 +501,84 @@ def LeaseFlowForProcessing(self, mysql_utils.RDFDatetimeToTimestamp(now), mysql_utils.RDFDatetimeToTimestamp(processing_deadline), db_utils.ClientIDToInt(client_id), - db_utils.FlowIDToInt(flow_id) + db_utils.FlowIDToInt(flow_id), ] cursor.execute(update_query, args) # This needs to happen after we are sure that the write has succeeded. 
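# Illustrative sketch (assumed helper name) of the lease bookkeeping performed
# once the UPDATE statement is known to have succeeded: the returned proto
# mirrors the new column values, and int() on an RDFDatetime yields
# microseconds since epoch, which is what the proto's integer timestamp
# fields store.
from grr_response_core.lib import rdfvalue
from grr_response_core.lib import utils
from grr_response_proto import flows_pb2


def _mark_flow_leased(
    flow: flows_pb2.Flow,
    now: rdfvalue.RDFDatetime,
    deadline: rdfvalue.RDFDatetime,
) -> None:
  flow.processing_on = utils.ProcessIdString()  # Same worker id used above.
  flow.processing_since = int(now)
  flow.processing_deadline = int(deadline)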
- rdf_flow.processing_on = process_id_string - rdf_flow.processing_since = now - rdf_flow.processing_deadline = processing_deadline - return rdf_flow + flow.processing_on = process_id_string + flow.processing_since = int(now) + flow.processing_deadline = int(processing_deadline) + return flow @mysql_utils.WithTransaction() - def UpdateFlow(self, - client_id, - flow_id, - flow_obj=db.Database.unchanged, - flow_state=db.Database.unchanged, - client_crash_info=db.Database.unchanged, - processing_on=db.Database.unchanged, - processing_since=db.Database.unchanged, - processing_deadline=db.Database.unchanged, - cursor=None): + def UpdateFlow( + self, + client_id: str, + flow_id: str, + flow_obj: Union[ + flows_pb2.Flow, db.Database.UNCHANGED_TYPE + ] = db.Database.UNCHANGED, + flow_state: Union[ + flows_pb2.Flow.FlowState.ValueType, db.Database.UNCHANGED_TYPE + ] = db.Database.UNCHANGED, + client_crash_info: Union[ + jobs_pb2.ClientCrash, db.Database.UNCHANGED_TYPE + ] = db.Database.UNCHANGED, + processing_on: Optional[ + Union[str, db.Database.UNCHANGED_TYPE] + ] = db.Database.UNCHANGED, + processing_since: Optional[ + Union[rdfvalue.RDFDatetime, db.Database.UNCHANGED_TYPE] + ] = db.Database.UNCHANGED, + processing_deadline: Optional[ + Union[rdfvalue.RDFDatetime, db.Database.UNCHANGED_TYPE] + ] = db.Database.UNCHANGED, + cursor: Optional[cursors.Cursor] = None, + ) -> None: """Updates flow objects in the database.""" updates = [] args = [] - if flow_obj != db.Database.unchanged: + if isinstance(flow_obj, flows_pb2.Flow): updates.append("flow=%s") - args.append(flow_obj.SerializeToBytes()) + args.append(flow_obj.SerializeToString()) updates.append("flow_state=%s") args.append(int(flow_obj.flow_state)) updates.append("user_cpu_time_used_micros=%s") args.append( - db_utils.SecondsToMicros(flow_obj.cpu_time_used.user_cpu_time)) + db_utils.SecondsToMicros(flow_obj.cpu_time_used.user_cpu_time) + ) updates.append("system_cpu_time_used_micros=%s") args.append( - db_utils.SecondsToMicros(flow_obj.cpu_time_used.system_cpu_time)) + db_utils.SecondsToMicros(flow_obj.cpu_time_used.system_cpu_time) + ) updates.append("network_bytes_sent=%s") args.append(flow_obj.network_bytes_sent) updates.append("num_replies_sent=%s") args.append(flow_obj.num_replies_sent) - if flow_state != db.Database.unchanged: + if isinstance(flow_state, flows_pb2.Flow.FlowState.ValueType): updates.append("flow_state=%s") args.append(int(flow_state)) - if client_crash_info != db.Database.unchanged: + if isinstance(client_crash_info, jobs_pb2.ClientCrash): updates.append("client_crash_info=%s") - args.append(client_crash_info.SerializeToBytes()) - if processing_on != db.Database.unchanged: + args.append(client_crash_info.SerializeToString()) + if ( + isinstance(processing_on, str) + and processing_on is not db.Database.UNCHANGED + ) or processing_on is None: updates.append("processing_on=%s") args.append(processing_on) - if processing_since != db.Database.unchanged: + if ( + isinstance(processing_since, rdfvalue.RDFDatetime) + or processing_since is None + ): updates.append("processing_since=FROM_UNIXTIME(%s)") args.append(mysql_utils.RDFDatetimeToTimestamp(processing_since)) - if processing_deadline != db.Database.unchanged: + if ( + isinstance(processing_deadline, rdfvalue.RDFDatetime) + or processing_deadline is None + ): updates.append("processing_deadline=FROM_UNIXTIME(%s)") args.append(mysql_utils.RDFDatetimeToTimestamp(processing_deadline)) @@ -484,7 +595,11 @@ def UpdateFlow(self, if updated == 0: raise 
db.UnknownFlowError(client_id, flow_id) - def _WriteFlowProcessingRequests(self, requests, cursor): + def _WriteFlowProcessingRequests( + self, + requests: Sequence[flows_pb2.FlowProcessingRequest], + cursor: Optional[cursors.Cursor], + ) -> None: """Returns a (query, args) tuple that inserts the given requests.""" templates = [] args = [] @@ -492,19 +607,27 @@ def _WriteFlowProcessingRequests(self, requests, cursor): templates.append("(%s, %s, %s, FROM_UNIXTIME(%s))") args.append(db_utils.ClientIDToInt(req.client_id)) args.append(db_utils.FlowIDToInt(req.flow_id)) - args.append(req.SerializeToBytes()) + args.append(req.SerializeToString()) if req.delivery_time: - args.append(mysql_utils.RDFDatetimeToTimestamp(req.delivery_time)) + args.append( + mysql_utils.MicrosecondsSinceEpochToTimestamp(req.delivery_time) + ) else: args.append(None) - query = ("INSERT INTO flow_processing_requests " - "(client_id, flow_id, request, delivery_time) VALUES ") + query = ( + "INSERT INTO flow_processing_requests " + "(client_id, flow_id, request, delivery_time) VALUES " + ) query += ", ".join(templates) cursor.execute(query, args) @mysql_utils.WithTransaction() - def WriteFlowRequests(self, requests, cursor=None): + def WriteFlowRequests( + self, + requests: Collection[flows_pb2.FlowRequest], + cursor: Optional[cursors.Cursor] = None, + ) -> None: """Writes a list of flow requests to the database.""" args = [] templates = [] @@ -516,8 +639,10 @@ def WriteFlowRequests(self, requests, cursor=None): needs_processing.setdefault((r.client_id, r.flow_id), []).append(r) start_time = None - if r.start_time is not None: - start_time = r.start_time.AsDatetime() + if r.start_time: + start_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + r.start_time + ).AsDatetime() flow_keys.append((r.client_id, r.flow_id)) templates.append("(%s, %s, %s, %s, %s, %s, %s, %s)") @@ -529,7 +654,7 @@ def WriteFlowRequests(self, requests, cursor=None): r.callback_state, r.next_response_id, start_time, - r.SerializeToBytes(), + r.SerializeToString(), ]) if needs_processing: @@ -541,8 +666,9 @@ def WriteFlowRequests(self, requests, cursor=None): nr_args.append(db_utils.ClientIDToInt(client_id)) nr_args.append(db_utils.FlowIDToInt(flow_id)) - nr_query = ("SELECT client_id, flow_id, next_request_to_process " - "FROM flows WHERE ") + nr_query = ( + "SELECT client_id, flow_id, next_request_to_process FROM flows WHERE " + ) nr_query += " OR ".join(nr_conditions) cursor.execute(nr_query, nr_args) @@ -553,15 +679,14 @@ def WriteFlowRequests(self, requests, cursor=None): flow_id = db_utils.IntToFlowID(flow_id_int) candidate_requests = needs_processing.get((client_id, flow_id), []) for r in candidate_requests: - if ( - next_request_to_process == r.request_id - or r.start_time is not None - ): - flow_processing_requests.append( - rdf_flows.FlowProcessingRequest( - client_id=client_id, - flow_id=flow_id, - delivery_time=r.start_time)) + if next_request_to_process == r.request_id or r.start_time: + flow_processing_request = flows_pb2.FlowProcessingRequest( + client_id=client_id, + flow_id=flow_id, + ) + if r.start_time: + flow_processing_request.delivery_time = r.start_time + flow_processing_requests.append(flow_processing_request) if flow_processing_requests: self._WriteFlowProcessingRequests(flow_processing_requests, cursor) @@ -578,12 +703,24 @@ def WriteFlowRequests(self, requests, cursor=None): except MySQLdb.IntegrityError as e: raise db.AtLeastOneUnknownFlowError(flow_keys, cause=e) - def _WriteResponses(self, responses, cursor): + def 
_WriteResponses( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + cursor: Optional[cursors.Cursor], + ) -> None: """Builds the writes to store the given responses in the db.""" - query = ("INSERT IGNORE INTO flow_responses " - "(client_id, flow_id, request_id, response_id, " - "response, status, iterator, timestamp) VALUES ") + query = ( + "INSERT IGNORE INTO flow_responses " + "(client_id, flow_id, request_id, response_id, " + "response, status, iterator, timestamp) VALUES " + ) templates = [] args = [] @@ -596,18 +733,18 @@ def _WriteResponses(self, responses, cursor): args.append(flow_id_int) args.append(r.request_id) args.append(r.response_id) - if isinstance(r, rdf_flow_objects.FlowResponse): - args.append(r.SerializeToBytes()) + if isinstance(r, flows_pb2.FlowResponse): + args.append(r.SerializeToString()) args.append("") args.append("") - elif isinstance(r, rdf_flow_objects.FlowStatus): + elif isinstance(r, flows_pb2.FlowStatus): args.append("") - args.append(r.SerializeToBytes()) + args.append(r.SerializeToString()) args.append("") - elif isinstance(r, rdf_flow_objects.FlowIterator): + elif isinstance(r, flows_pb2.FlowIterator): args.append("") args.append("") - args.append(r.SerializeToBytes()) + args.append(r.SerializeToString()) else: # This can't really happen due to db api type checking. raise ValueError("Got unexpected response type: %s %s" % (type(r), r)) @@ -625,7 +762,17 @@ def _WriteResponses(self, responses, cursor): logging.warning("Response for unknown request: %s", responses[0]) @mysql_utils.WithTransaction() - def _WriteFlowResponsesAndExpectedUpdates(self, responses, cursor=None): + def _WriteFlowResponsesAndExpectedUpdates( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ] + ], + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> None: """Writes a flow responses and updates flow requests expected counts.""" self._WriteResponses(responses, cursor) @@ -642,7 +789,7 @@ def _WriteFlowResponsesAndExpectedUpdates(self, responses, cursor=None): for r in responses: # If the response is a FlowStatus, we have to update the FlowRequest with # the number of expected messages. 
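# Illustrative sketch of the three-way column routing implemented in
# _WriteResponses above: each message kind is serialized with
# SerializeToString() into its own column (response, status or iterator) while
# the other two columns are written as empty placeholders. The helper below is
# an assumption for illustration only.
from typing import Tuple, Union
from grr_response_proto import flows_pb2

FlowMessage = Union[
    flows_pb2.FlowResponse, flows_pb2.FlowStatus, flows_pb2.FlowIterator
]


def _response_column_values(r: FlowMessage) -> Tuple[bytes, bytes, bytes]:
  if isinstance(r, flows_pb2.FlowResponse):
    return r.SerializeToString(), b"", b""
  if isinstance(r, flows_pb2.FlowStatus):
    return b"", r.SerializeToString(), b""
  if isinstance(r, flows_pb2.FlowIterator):
    return b"", b"", r.SerializeToString()
  # This can't really happen due to db api type checking (see above).
  raise ValueError("Got unexpected response type: %s %s" % (type(r), r))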
- if isinstance(r, rdf_flow_objects.FlowStatus): + if isinstance(r, flows_pb2.FlowStatus): args = { "client_id": db_utils.ClientIDToInt(r.client_id), "flow_id": db_utils.FlowIDToInt(r.flow_id), @@ -651,7 +798,11 @@ def _WriteFlowResponsesAndExpectedUpdates(self, responses, cursor=None): } cursor.execute(query, args) - def _ReadFlowResponseCounts(self, request_keys, cursor=None): + def _ReadFlowResponseCounts( + self, + request_keys: AbstractSet[Tuple[str, str, str]], + cursor: Optional[cursors.Cursor] = None, + ) -> Mapping[Tuple[str, str, str], int]: """Reads counts of responses for the given requests.""" query = """ @@ -686,12 +837,19 @@ def _ReadFlowResponseCounts(self, request_keys, cursor=None): cursor.execute(query, args) response_counts = {} for client_id_int, flow_id_int, request_id, count in cursor.fetchall(): - request_key = (db_utils.IntToClientID(client_id_int), - db_utils.IntToFlowID(flow_id_int), request_id) + request_key = ( + db_utils.IntToClientID(client_id_int), + db_utils.IntToFlowID(flow_id_int), + request_id, + ) response_counts[request_key] = count return response_counts - def _ReadAndLockNextRequestsToProcess(self, flow_keys, cursor): + def _ReadAndLockNextRequestsToProcess( + self, + flow_keys: AbstractSet[Tuple[str, str]], + cursor: Optional[cursors.Cursor] = None, + ) -> Mapping[Tuple[str, str], str]: """Reads and locks the next_request_to_process for a number of flows.""" query = """ @@ -710,14 +868,25 @@ def _ReadAndLockNextRequestsToProcess(self, flow_keys, cursor): cursor.execute(query, args) next_requests = {} - for client_id_int, flow_id_int, next_request in cursor.fetchall(): - flow_key = (db_utils.IntToClientID(client_id_int), - db_utils.IntToFlowID(flow_id_int)) + for ( + client_id_int, + flow_id_int, + next_request, + ) in cursor.fetchall(): + flow_key = ( + db_utils.IntToClientID(client_id_int), + db_utils.IntToFlowID(flow_id_int), + ) next_requests[flow_key] = next_request + return next_requests - def _ReadLockAndUpdateAffectedRequests(self, request_keys, response_counts, - cursor): + def _ReadLockAndUpdateAffectedRequests( + self, + request_keys: AbstractSet[Tuple[str, str, str]], + response_counts: Mapping[Tuple[str, str, str], int], + cursor: Optional[cursors.Cursor] = None, + ) -> Mapping[Tuple[str, str, str], rdf_flow_objects.FlowRequest]: """Reads, locks, and updates completed requests.""" condition_template = """ @@ -741,7 +910,8 @@ def _ReadLockAndUpdateAffectedRequests(self, request_keys, response_counts, if request_key in response_counts: conditions.append(condition_template) callback_agnostic_conditions.append( - callback_agnostic_condition_template) + callback_agnostic_condition_template + ) args.append(db_utils.ClientIDToInt(client_id)) args.append(db_utils.FlowIDToInt(flow_id)) args.append(request_id) @@ -760,8 +930,11 @@ def _ReadLockAndUpdateAffectedRequests(self, request_keys, response_counts, query = query.format(conditions=" OR ".join(conditions)) cursor.execute(query, args) for client_id_int, flow_id_int, request_id, request in cursor.fetchall(): - request_key = (db_utils.IntToClientID(client_id_int), - db_utils.IntToFlowID(flow_id_int), request_id) + request_key = ( + db_utils.IntToClientID(client_id_int), + db_utils.IntToFlowID(flow_id_int), + request_id, + ) r = rdf_flow_objects.FlowRequest.FromSerializedBytes(request) affected_requests[request_key] = r @@ -776,11 +949,22 @@ def _ReadLockAndUpdateAffectedRequests(self, request_keys, response_counts, return affected_requests @mysql_utils.WithTransaction() - def 
_UpdateRequestsAndScheduleFPRs(self, responses, cursor=None): + def _UpdateRequestsAndScheduleFPRs( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + cursor: Optional[cursors.Cursor] = None, + ) -> Sequence[flows_pb2.FlowProcessingRequest]: """Updates requests and writes FlowProcessingRequests if needed.""" request_keys = set( - (r.client_id, r.flow_id, r.request_id) for r in responses) + (r.client_id, r.flow_id, r.request_id) for r in responses + ) flow_keys = set((r.client_id, r.flow_id) for r in responses) response_counts = self._ReadFlowResponseCounts(request_keys, cursor) @@ -788,7 +972,8 @@ def _UpdateRequestsAndScheduleFPRs(self, responses, cursor=None): next_requests = self._ReadAndLockNextRequestsToProcess(flow_keys, cursor) affected_requests = self._ReadLockAndUpdateAffectedRequests( - request_keys, response_counts, cursor) + request_keys, response_counts, cursor + ) if not affected_requests: return [] @@ -797,20 +982,33 @@ def _UpdateRequestsAndScheduleFPRs(self, responses, cursor=None): for request_key, r in affected_requests.items(): client_id, flow_id, request_id = request_key if next_requests[(client_id, flow_id)] == request_id: - fprs_to_write.append( - rdf_flows.FlowProcessingRequest( - client_id=r.client_id, - flow_id=r.flow_id, - delivery_time=r.start_time)) + flow_processing_request = flows_pb2.FlowProcessingRequest( + client_id=r.client_id, + flow_id=r.flow_id, + ) + if r.start_time is not None: + flow_processing_request.delivery_time = int(r.start_time) + + fprs_to_write.append(flow_processing_request) if fprs_to_write: self._WriteFlowProcessingRequests(fprs_to_write, cursor) return affected_requests - @db_utils.CallLoggedAndAccounted - def WriteFlowResponses(self, responses): - """Writes FlowMessages and updates corresponding requests.""" + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowResponses( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ) -> None: + """Writes FlowResponse/FlowStatus/FlowIterator and updates corresponding requests.""" if not responses: return @@ -824,24 +1022,32 @@ def UpdateIncrementalFlowRequests( self, client_id: str, flow_id: str, - next_response_id_updates: Dict[int, int], - cursor: Optional[cursors.Cursor] = None) -> None: + next_response_id_updates: Mapping[int, int], + cursor: Optional[cursors.Cursor] = None, + ) -> None: """Updates next response ids of given requests.""" if not next_response_id_updates: return for request_id, next_response_id in next_response_id_updates.items(): - query = ("UPDATE flow_requests SET next_response_id=%s WHERE " - "client_id=%s AND flow_id=%s AND request_id=%s") + query = ( + "UPDATE flow_requests SET next_response_id=%s WHERE " + "client_id=%s AND flow_id=%s AND request_id=%s" + ) args = [ next_response_id, db_utils.ClientIDToInt(client_id), - db_utils.FlowIDToInt(flow_id), request_id + db_utils.FlowIDToInt(flow_id), + request_id, ] cursor.execute(query, args) @mysql_utils.WithTransaction() - def DeleteFlowRequests(self, requests, cursor=None): + def DeleteFlowRequests( + self, + requests: Sequence[flows_pb2.FlowRequest], + cursor: Optional[cursors.Cursor] = None, + ) -> None: """Deletes a list of flow requests from the database.""" if not requests: return @@ -873,41 +1079,74 @@ def DeleteFlowRequests(self, requests, cursor=None): cursor.execute(request_query, args) @mysql_utils.WithTransaction(readonly=True) - def 
ReadAllFlowRequestsAndResponses(self, client_id, flow_id, cursor=None): + def ReadAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + cursor: Optional[cursors.Cursor] = None, + ) -> Iterable[ + Tuple[ + flows_pb2.FlowRequest, + Dict[ + int, + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ] + ]: """Reads all requests and responses for a given flow from the database.""" - query = ("SELECT request, needs_processing, responses_expected, " - "callback_state, next_response_id, UNIX_TIMESTAMP(timestamp) " - "FROM flow_requests WHERE client_id=%s AND flow_id=%s") + query = ( + "SELECT request, needs_processing, responses_expected, " + "callback_state, next_response_id, UNIX_TIMESTAMP(timestamp) " + "FROM flow_requests WHERE client_id=%s AND flow_id=%s" + ) args = [db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id)] cursor.execute(query, args) requests = [] - for (req, needs_processing, resp_expected, callback_state, next_response_id, - ts) in cursor.fetchall(): - request = rdf_flow_objects.FlowRequest.FromSerializedBytes(req) + for ( + req, + needs_processing, + resp_expected, + callback_state, + next_response_id, + ts, + ) in cursor.fetchall(): + request = flows_pb2.FlowRequest() + request.ParseFromString(req) request.needs_processing = needs_processing - request.nr_responses_expected = resp_expected + if resp_expected is not None: + request.nr_responses_expected = resp_expected request.callback_state = callback_state request.next_response_id = next_response_id - request.timestamp = mysql_utils.TimestampToRDFDatetime(ts) + request.timestamp = int(mysql_utils.TimestampToRDFDatetime(ts)) requests.append(request) - query = ("SELECT response, status, iterator, UNIX_TIMESTAMP(timestamp) " - "FROM flow_responses WHERE client_id=%s AND flow_id=%s") + query = ( + "SELECT response, status, iterator, UNIX_TIMESTAMP(timestamp) " + "FROM flow_responses WHERE client_id=%s AND flow_id=%s" + ) cursor.execute(query, args) responses = {} for res, status, iterator, ts in cursor.fetchall(): if status: - response = rdf_flow_objects.FlowStatus.FromSerializedBytes(status) + response = flows_pb2.FlowStatus() + response.ParseFromString(status) elif iterator: - response = rdf_flow_objects.FlowIterator.FromSerializedBytes(iterator) + response = flows_pb2.FlowIterator() + response.ParseFromString(iterator) else: - response = rdf_flow_objects.FlowResponse.FromSerializedBytes(res) - response.timestamp = mysql_utils.TimestampToRDFDatetime(ts) - responses.setdefault(response.request_id, - {})[response.response_id] = response + response = flows_pb2.FlowResponse() + response.ParseFromString(res) + response.timestamp = int(mysql_utils.TimestampToRDFDatetime(ts)) + responses.setdefault(response.request_id, {})[ + response.response_id + ] = response ret = [] for req in sorted(requests, key=lambda r: r.request_id): @@ -915,7 +1154,12 @@ def ReadAllFlowRequestsAndResponses(self, client_id, flow_id, cursor=None): return ret @mysql_utils.WithTransaction() - def DeleteAllFlowRequestsAndResponses(self, client_id, flow_id, cursor=None): + def DeleteAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + cursor: Optional[cursors.Cursor] = None, + ) -> None: """Deletes all requests and responses for a given flow from the database.""" args = [db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id)] res_query = "DELETE FROM flow_responses WHERE client_id=%s AND flow_id=%s" @@ -924,45 +1168,74 @@ def 
DeleteAllFlowRequestsAndResponses(self, client_id, flow_id, cursor=None): cursor.execute(req_query, args) @mysql_utils.WithTransaction(readonly=True) - def ReadFlowRequestsReadyForProcessing(self, - client_id, - flow_id, - next_needed_request, - cursor=None): + def ReadFlowRequestsReadyForProcessing( + self, + client_id: str, + flow_id: str, + next_needed_request: Optional[int] = None, + cursor: Optional[cursors.Cursor] = None, + ) -> Dict[ + int, + Tuple[ + flows_pb2.FlowRequest, + Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ]: """Reads all requests for a flow that can be processed by the worker.""" - query = ("SELECT request, needs_processing, responses_expected, " - "callback_state, next_response_id, " - "UNIX_TIMESTAMP(timestamp) " - "FROM flow_requests " - "WHERE client_id=%s AND flow_id=%s") + query = ( + "SELECT request, needs_processing, responses_expected, " + "callback_state, next_response_id, " + "UNIX_TIMESTAMP(timestamp) " + "FROM flow_requests " + "WHERE client_id=%s AND flow_id=%s" + ) args = [db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id)] cursor.execute(query, args) requests = {} - for (req, needs_processing, responses_expected, callback_state, - next_response_id, ts) in cursor.fetchall(): - request = rdf_flow_objects.FlowRequest.FromSerializedBytes(req) + for ( + req, + needs_processing, + responses_expected, + callback_state, + next_response_id, + ts, + ) in cursor.fetchall(): + request = flows_pb2.FlowRequest() + request.ParseFromString(req) request.needs_processing = needs_processing - request.nr_responses_expected = responses_expected + if responses_expected is not None: + request.nr_responses_expected = responses_expected request.callback_state = callback_state request.next_response_id = next_response_id - request.timestamp = mysql_utils.TimestampToRDFDatetime(ts) + request.timestamp = int(mysql_utils.TimestampToRDFDatetime(ts)) requests[request.request_id] = request - query = ("SELECT response, status, iterator, UNIX_TIMESTAMP(timestamp) " - "FROM flow_responses " - "WHERE client_id=%s AND flow_id=%s") + query = ( + "SELECT response, status, iterator, UNIX_TIMESTAMP(timestamp) " + "FROM flow_responses " + "WHERE client_id=%s AND flow_id=%s" + ) cursor.execute(query, args) responses = {} for res, status, iterator, ts in cursor.fetchall(): if status: - response = rdf_flow_objects.FlowStatus.FromSerializedBytes(status) + response = flows_pb2.FlowStatus() + response.ParseFromString(status) elif iterator: - response = rdf_flow_objects.FlowIterator.FromSerializedBytes(iterator) + response = flows_pb2.FlowIterator() + response.ParseFromString(iterator) else: - response = rdf_flow_objects.FlowResponse.FromSerializedBytes(res) - response.timestamp = mysql_utils.TimestampToRDFDatetime(ts) + response = flows_pb2.FlowResponse() + response.ParseFromString(res) + response.timestamp = int(mysql_utils.TimestampToRDFDatetime(ts)) responses.setdefault(response.request_id, []).append(response) res = {} @@ -975,7 +1248,8 @@ def ReadFlowRequestsReadyForProcessing(self, break sorted_responses = sorted( - responses.get(next_needed_request, []), key=lambda r: r.response_id) + responses.get(next_needed_request, []), key=lambda r: r.response_id + ) res[req.request_id] = (req, sorted_responses) next_needed_request += 1 @@ -997,7 +1271,11 @@ def ReadFlowRequestsReadyForProcessing(self, return res @mysql_utils.WithTransaction() - def ReleaseProcessedFlow(self, flow_obj, cursor=None): + def 
ReleaseProcessedFlow( + self, + flow_obj: flows_pb2.Flow, + cursor: Optional[cursors.Cursor] = None, + ) -> bool: """Releases a flow that the worker was processing to the database.""" update_query = """ @@ -1033,56 +1311,65 @@ def ReleaseProcessedFlow(self, flow_obj, cursor=None): needs_processing.needs_processing = FALSE OR needs_processing.needs_processing IS NULL) """ + clone = flows_pb2.Flow() + clone.CopyFrom(flow_obj) + clone.ClearField("processing_on") + clone.ClearField("processing_since") + clone.ClearField("processing_deadline") - clone = flow_obj.Copy() - clone.processing_on = None - clone.processing_since = None - clone.processing_deadline = None args = { - "client_id": - db_utils.ClientIDToInt(flow_obj.client_id), - "flow": - clone.SerializeToBytes(), - "flow_id": - db_utils.FlowIDToInt(flow_obj.flow_id), - "flow_state": - int(clone.flow_state), - "network_bytes_sent": - flow_obj.network_bytes_sent, - "next_request_to_process": - flow_obj.next_request_to_process, - "num_replies_sent": - flow_obj.num_replies_sent, - "system_cpu_time_used_micros": - db_utils.SecondsToMicros(flow_obj.cpu_time_used.system_cpu_time), - "user_cpu_time_used_micros": - db_utils.SecondsToMicros(flow_obj.cpu_time_used.user_cpu_time), + "client_id": db_utils.ClientIDToInt(flow_obj.client_id), + "flow": clone.SerializeToString(), + "flow_id": db_utils.FlowIDToInt(flow_obj.flow_id), + "flow_state": int(clone.flow_state), + "network_bytes_sent": flow_obj.network_bytes_sent, + "next_request_to_process": flow_obj.next_request_to_process, + "num_replies_sent": flow_obj.num_replies_sent, + "system_cpu_time_used_micros": db_utils.SecondsToMicros( + flow_obj.cpu_time_used.system_cpu_time + ), + "user_cpu_time_used_micros": db_utils.SecondsToMicros( + flow_obj.cpu_time_used.user_cpu_time + ), } rows_updated = cursor.execute(update_query, args) return rows_updated == 1 @mysql_utils.WithTransaction() - def WriteFlowProcessingRequests(self, requests, cursor=None): + def WriteFlowProcessingRequests( + self, + requests: Sequence[flows_pb2.FlowProcessingRequest], + cursor: Optional[cursors.Cursor] = None, + ) -> None: """Writes a list of flow processing requests to the database.""" self._WriteFlowProcessingRequests(requests, cursor) @mysql_utils.WithTransaction(readonly=True) - def ReadFlowProcessingRequests(self, cursor=None): + def ReadFlowProcessingRequests( + self, + cursor: Optional[cursors.Cursor] = None, + ) -> Sequence[rdf_flows.FlowProcessingRequest]: """Reads all flow processing requests from the database.""" - query = ("SELECT request, UNIX_TIMESTAMP(timestamp) " - "FROM flow_processing_requests") + query = ( + "SELECT request, UNIX_TIMESTAMP(timestamp) " + "FROM flow_processing_requests" + ) cursor.execute(query) res = [] for serialized_request, ts in cursor.fetchall(): - req = rdf_flows.FlowProcessingRequest.FromSerializedBytes( - serialized_request) - req.timestamp = mysql_utils.TimestampToRDFDatetime(ts) + req = flows_pb2.FlowProcessingRequest() + req.ParseFromString(serialized_request) + req.creation_time = int(mysql_utils.TimestampToRDFDatetime(ts)) res.append(req) return res @mysql_utils.WithTransaction() - def AckFlowProcessingRequests(self, requests, cursor=None): + def AckFlowProcessingRequests( + self, + requests: Iterable[flows_pb2.FlowProcessingRequest], + cursor: Optional[cursors.Cursor] = None, + ) -> None: """Deletes a list of flow processing requests from the database.""" if not requests: return @@ -1093,22 +1380,29 @@ def AckFlowProcessingRequests(self, requests, cursor=None): args = [] for 
r in requests: conditions.append( - "(client_id=%s AND flow_id=%s AND timestamp=FROM_UNIXTIME(%s))") + "(client_id=%s AND flow_id=%s AND timestamp=FROM_UNIXTIME(%s))" + ) args.append(db_utils.ClientIDToInt(r.client_id)) args.append(db_utils.FlowIDToInt(r.flow_id)) - args.append(mysql_utils.RDFDatetimeToTimestamp(r.timestamp)) + args.append( + mysql_utils.MicrosecondsSinceEpochToTimestamp(r.creation_time) + ) query += " OR ".join(conditions) cursor.execute(query, args) @mysql_utils.WithTransaction() - def DeleteAllFlowProcessingRequests(self, cursor=None): + def DeleteAllFlowProcessingRequests( + self, cursor: Optional[cursors.Cursor] = None + ) -> None: """Deletes all flow processing requests from the database.""" query = "DELETE FROM flow_processing_requests WHERE true" cursor.execute(query) @mysql_utils.WithTransaction() - def _LeaseFlowProcessingRequests(self, limit, cursor=None): + def _LeaseFlowProcessingRequests( + self, limit: int, cursor=None + ) -> Sequence[flows_pb2.FlowProcessingRequest]: """Leases a number of flow processing requests.""" now = rdfvalue.RDFDatetime.Now() expiry = now + rdfvalue.Duration.From(10, rdfvalue.MINUTES) @@ -1158,17 +1452,20 @@ def _LeaseFlowProcessingRequests(self, limit, cursor=None): res = [] for timestamp, request in cursor.fetchall(): - req = rdf_flows.FlowProcessingRequest.FromSerializedBytes(request) - req.timestamp = mysql_utils.TimestampToRDFDatetime(timestamp) - req.leased_until = expiry - req.leased_by = id_str + req = flows_pb2.FlowProcessingRequest() + req.ParseFromString(request) + req.creation_time = mysql_utils.TimestampToMicrosecondsSinceEpoch( + timestamp + ) res.append(req) return res _FLOW_REQUEST_POLL_TIME_SECS = 3 - def _FlowProcessingRequestHandlerLoop(self, handler): + def _FlowProcessingRequestHandlerLoop( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ) -> None: """The main loop for the flow processing request queue.""" self.flow_processing_request_handler_pool.Start() @@ -1183,7 +1480,8 @@ def _FlowProcessingRequestHandlerLoop(self, handler): if msgs: for m in msgs: self.flow_processing_request_handler_pool.AddTask( - target=handler, args=(m,)) + target=handler, args=(m,) + ) else: time.sleep(self._FLOW_REQUEST_POLL_TIME_SECS) @@ -1193,7 +1491,9 @@ def _FlowProcessingRequestHandlerLoop(self, handler): self.flow_processing_request_handler_pool.Stop() - def RegisterFlowProcessingHandler(self, handler): + def RegisterFlowProcessingHandler( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ) -> None: """Registers a handler to receive flow processing messages.""" self.UnregisterFlowProcessingHandler() @@ -1202,11 +1502,14 @@ def RegisterFlowProcessingHandler(self, handler): self.flow_processing_request_handler_thread = threading.Thread( name="flow_processing_request_handler", target=self._FlowProcessingRequestHandlerLoop, - args=(handler,)) + args=(handler,), + ) self.flow_processing_request_handler_thread.daemon = True self.flow_processing_request_handler_thread.start() - def UnregisterFlowProcessingHandler(self, timeout=None): + def UnregisterFlowProcessingHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: """Unregisters any registered flow processing handler.""" if self.flow_processing_request_handler_thread: self.flow_processing_request_handler_stop = True @@ -1224,9 +1527,11 @@ def _WriteFlowResultsOrErrors( ): """Writes flow results/errors for a given flow.""" - query = (f"INSERT INTO {table_name} " - "(client_id, flow_id, hunt_id, timestamp, payload, type, tag) 
" - "VALUES ") + query = ( + f"INSERT INTO {table_name} " + "(client_id, flow_id, hunt_id, timestamp, payload, type, tag) " + "VALUES " + ) templates = [] args = [] @@ -1239,7 +1544,8 @@ def _WriteFlowResultsOrErrors( else: args.append(0) args.append( - mysql_utils.RDFDatetimeToTimestamp(rdfvalue.RDFDatetime.Now())) + mysql_utils.RDFDatetimeToTimestamp(rdfvalue.RDFDatetime.Now()) + ) args.append(r.payload.value) args.append(db_utils.TypeURLToRDFTypeName(r.payload.type_url)) args.append(r.tag) @@ -1250,7 +1556,8 @@ def _WriteFlowResultsOrErrors( cursor.execute(query, args) except MySQLdb.IntegrityError as e: raise db.AtLeastOneUnknownFlowError( - [(r.client_id, r.flow_id) for r in results], cause=e) + [(r.client_id, r.flow_id) for r in results], cause=e + ) def WriteFlowResults(self, results: Sequence[flows_pb2.FlowResult]) -> None: """Writes flow results for a given flow.""" @@ -1370,10 +1677,12 @@ def _CountFlowResultsOrErrors( cursor: Optional[cursors.Cursor] = None, ) -> int: """Counts flow results/errors of a given flow using given query options.""" - query = ("SELECT COUNT(*) " - f"FROM {table_name} " - f"FORCE INDEX ({table_name}_by_client_id_flow_id_timestamp) " - "WHERE client_id = %s AND flow_id = %s ") + query = ( + "SELECT COUNT(*) " + f"FROM {table_name} " + f"FORCE INDEX ({table_name}_by_client_id_flow_id_timestamp) " + "WHERE client_id = %s AND flow_id = %s " + ) args = [db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id)] if with_tag is not None: @@ -1400,7 +1709,8 @@ def CountFlowResults( client_id, flow_id, with_tag=with_tag, - with_type=with_type) + with_type=with_type, + ) @mysql_utils.WithTransaction(readonly=True) def _CountFlowResultsOrErrorsByType( @@ -1411,10 +1721,12 @@ def _CountFlowResultsOrErrorsByType( cursor: Optional[cursors.Cursor] = None, ) -> Mapping[str, int]: """Returns counts of flow results/errors grouped by result type.""" - query = (f"SELECT type, COUNT(*) FROM {table_name} " - f"FORCE INDEX ({table_name}_by_client_id_flow_id_timestamp) " - "WHERE client_id = %s AND flow_id = %s " - "GROUP BY type") + query = ( + f"SELECT type, COUNT(*) FROM {table_name} " + f"FORCE INDEX ({table_name}_by_client_id_flow_id_timestamp) " + "WHERE client_id = %s AND flow_id = %s " + "GROUP BY type" + ) args = [db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id)] cursor.execute(query, args) @@ -1425,8 +1737,9 @@ def CountFlowResultsByType( self, client_id: str, flow_id: str ) -> Mapping[str, int]: """Returns counts of flow results grouped by result type.""" - return self._CountFlowResultsOrErrorsByType("flow_results", client_id, - flow_id) + return self._CountFlowResultsOrErrorsByType( + "flow_results", client_id, flow_id + ) def WriteFlowErrors(self, errors: Sequence[flows_pb2.FlowError]) -> None: """Writes flow errors for a given flow.""" @@ -1478,7 +1791,8 @@ def CountFlowErrors( client_id, flow_id, with_tag=with_tag, - with_type=with_type) + with_type=with_type, + ) def CountFlowErrorsByType( self, client_id: str, flow_id: str @@ -1488,8 +1802,9 @@ def CountFlowErrorsByType( # concept. Error is a kind of a negative result. Given the structural # similarity, we can share large chunks of implementation between # errors and results DB code. 
- return self._CountFlowResultsOrErrorsByType("flow_errors", client_id, - flow_id) + return self._CountFlowResultsOrErrorsByType( + "flow_errors", client_id, flow_id + ) @mysql_utils.WithTransaction() def WriteFlowLogEntry( @@ -1506,7 +1821,7 @@ def WriteFlowLogEntry( args = { "client_id": db_utils.ClientIDToInt(entry.client_id), "flow_id": db_utils.FlowIDToInt(entry.flow_id), - "message": entry.message + "message": entry.message, } if entry.hunt_id: @@ -1531,10 +1846,12 @@ def ReadFlowLogEntries( ) -> Sequence[flows_pb2.FlowLogEntry]: """Reads flow log entries of a given flow using given query options.""" - query = ("SELECT message, UNIX_TIMESTAMP(timestamp) " - "FROM flow_log_entries " - "FORCE INDEX (flow_log_entries_by_flow) " - "WHERE client_id = %s AND flow_id = %s ") + query = ( + "SELECT message, UNIX_TIMESTAMP(timestamp) " + "FROM flow_log_entries " + "FORCE INDEX (flow_log_entries_by_flow) " + "WHERE client_id = %s AND flow_id = %s " + ) args = [db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id)] if with_substring is not None: @@ -1569,10 +1886,12 @@ def CountFlowLogEntries( ) -> int: """Returns number of flow log entries of a given flow.""" - query = ("SELECT COUNT(*) " - "FROM flow_log_entries " - "FORCE INDEX (flow_log_entries_by_flow) " - "WHERE client_id = %s AND flow_id = %s ") + query = ( + "SELECT COUNT(*) " + "FROM flow_log_entries " + "FORCE INDEX (flow_log_entries_by_flow) " + "WHERE client_id = %s AND flow_id = %s " + ) args = [db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id)] cursor.execute(query, args) @@ -1593,16 +1912,13 @@ def WriteFlowOutputPluginLogEntry( %(type)s, %(message)s) """ args = { - "client_id": - db_utils.ClientIDToInt(entry.client_id), - "flow_id": - db_utils.FlowIDToInt(entry.flow_id), - "output_plugin_id": - db_utils.OutputPluginIDToInt(entry.output_plugin_id), - "type": - int(entry.log_entry_type), - "message": - entry.message, + "client_id": db_utils.ClientIDToInt(entry.client_id), + "flow_id": db_utils.FlowIDToInt(entry.flow_id), + "output_plugin_id": db_utils.OutputPluginIDToInt( + entry.output_plugin_id + ), + "type": int(entry.log_entry_type), + "message": entry.message, } if entry.hunt_id: @@ -1629,14 +1945,16 @@ def ReadFlowOutputPluginLogEntries( cursor: Optional[cursors.Cursor] = None, ) -> Sequence[flows_pb2.FlowOutputPluginLogEntry]: """Reads flow output plugin log entries.""" - query = ("SELECT log_entry_type, message, UNIX_TIMESTAMP(timestamp) " - "FROM flow_output_plugin_log_entries " - "FORCE INDEX (flow_output_plugin_log_entries_by_flow) " - "WHERE client_id = %s AND flow_id = %s AND output_plugin_id = %s ") + query = ( + "SELECT log_entry_type, message, UNIX_TIMESTAMP(timestamp) " + "FROM flow_output_plugin_log_entries " + "FORCE INDEX (flow_output_plugin_log_entries_by_flow) " + "WHERE client_id = %s AND flow_id = %s AND output_plugin_id = %s " + ) args = [ db_utils.ClientIDToInt(client_id), db_utils.FlowIDToInt(flow_id), - db_utils.OutputPluginIDToInt(output_plugin_id) + db_utils.OutputPluginIDToInt(output_plugin_id), ] if with_type is not None: @@ -1667,20 +1985,27 @@ def ReadFlowOutputPluginLogEntries( return ret @mysql_utils.WithTransaction(readonly=True) - def CountFlowOutputPluginLogEntries(self, - client_id, - flow_id, - output_plugin_id, - with_type=None, - cursor=None): + def CountFlowOutputPluginLogEntries( + self, + client_id: str, + flow_id: str, + output_plugin_id: str, + with_type: Optional[ + flows_pb2.FlowOutputPluginLogEntry.LogEntryType.ValueType + ] = None, + cursor: 
Optional[cursors.Cursor] = None, + ) -> int: """Returns number of flow output plugin log entries of a given flow.""" - query = ("SELECT COUNT(*) " - "FROM flow_output_plugin_log_entries " - "FORCE INDEX (flow_output_plugin_log_entries_by_flow) " - "WHERE client_id = %s AND flow_id = %s AND output_plugin_id = %s ") + query = ( + "SELECT COUNT(*) " + "FROM flow_output_plugin_log_entries " + "FORCE INDEX (flow_output_plugin_log_entries_by_flow) " + "WHERE client_id = %s AND flow_id = %s AND output_plugin_id = %s " + ) args = [ db_utils.ClientIDToInt(client_id), - db_utils.FlowIDToInt(flow_id), output_plugin_id + db_utils.FlowIDToInt(flow_id), + output_plugin_id, ] if with_type is not None: @@ -1718,8 +2043,7 @@ def WriteScheduledFlow( query = """ REPLACE INTO scheduled_flows {cols} VALUES {vals} - """.format( - cols=mysql_utils.Columns(args), vals=vals) + """.format(cols=mysql_utils.Columns(args), vals=vals) try: cursor.execute(query, args) @@ -1748,17 +2072,20 @@ def DeleteScheduledFlow( client_id = %s AND creator_username_hash = %s AND scheduled_flow_id = %s - """, [ - db_utils.ClientIDToInt(client_id), - mysql_utils.Hash(creator), - db_utils.FlowIDToInt(scheduled_flow_id), - ]) + """, + [ + db_utils.ClientIDToInt(client_id), + mysql_utils.Hash(creator), + db_utils.FlowIDToInt(scheduled_flow_id), + ], + ) if cursor.rowcount == 0: raise db.UnknownScheduledFlowError( client_id=client_id, creator=creator, - scheduled_flow_id=scheduled_flow_id) + scheduled_flow_id=scheduled_flow_id, + ) @mysql_utils.WithTransaction(readonly=True) def ListScheduledFlows( diff --git a/grr/server/grr_response_server/databases/mysql_flows_large_test.py b/grr/server/grr_response_server/databases/mysql_flows_large_test.py index 5b693596e0..2af5e0ba1a 100644 --- a/grr/server/grr_response_server/databases/mysql_flows_large_test.py +++ b/grr/server/grr_response_server/databases/mysql_flows_large_test.py @@ -7,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlFlowTest(db_flows_test.DatabaseLargeTestFlowMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlFlowTest( + db_flows_test.DatabaseLargeTestFlowMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_flows_test.py b/grr/server/grr_response_server/databases/mysql_flows_test.py index d22f7dc3ab..1057be7159 100644 --- a/grr/server/grr_response_server/databases/mysql_flows_test.py +++ b/grr/server/grr_response_server/databases/mysql_flows_test.py @@ -7,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlFlowTest(db_flows_test.DatabaseTestFlowMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlFlowTest( + db_flows_test.DatabaseTestFlowMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_foreman_rules.py b/grr/server/grr_response_server/databases/mysql_foreman_rules.py index 00ecace2df..fbd60b9087 100644 --- a/grr/server/grr_response_server/databases/mysql_foreman_rules.py +++ b/grr/server/grr_response_server/databases/mysql_foreman_rules.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """The MySQL database methods for foreman rule handling.""" + from typing import Optional, Sequence import MySQLdb diff --git a/grr/server/grr_response_server/databases/mysql_foreman_rules_test.py b/grr/server/grr_response_server/databases/mysql_foreman_rules_test.py index 395dc745bc..f85568dc54 100644 --- a/grr/server/grr_response_server/databases/mysql_foreman_rules_test.py +++ 
b/grr/server/grr_response_server/databases/mysql_foreman_rules_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlForemanRulesTest(db_foreman_rules_test.DatabaseTestForemanRulesMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlForemanRulesTest( + db_foreman_rules_test.DatabaseTestForemanRulesMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_hunts.py b/grr/server/grr_response_server/databases/mysql_hunts.py index 957de282de..024a7496b7 100644 --- a/grr/server/grr_response_server/databases/mysql_hunts.py +++ b/grr/server/grr_response_server/databases/mysql_hunts.py @@ -1,26 +1,32 @@ #!/usr/bin/env python """The MySQL database methods for flow handling.""" +from collections.abc import Callable +from typing import AbstractSet +from typing import Collection +from typing import Iterable +from typing import List +from typing import Mapping from typing import Optional from typing import Sequence +from typing import Tuple import MySQLdb from MySQLdb import cursors +from google.protobuf import any_pb2 from grr_response_core.lib import rdfvalue -from grr_response_core.lib.rdfvalues import client_stats as rdf_client_stats -from grr_response_core.lib.rdfvalues import protodict as rdf_protodict +from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import stats as rdf_stats from grr_response_proto import flows_pb2 from grr_response_proto import hunts_pb2 +from grr_response_proto import jobs_pb2 +from grr_response_proto import objects_pb2 +from grr_response_proto import output_plugin_pb2 from grr_response_server.databases import db from grr_response_server.databases import db_utils from grr_response_server.databases import mysql_utils from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects -from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner -from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects -from grr_response_server.rdfvalues import objects as rdf_objects -from grr_response_server.rdfvalues import output_plugin as rdf_output_plugin _HUNT_COLUMNS_SELECT = ", ".join(( "UNIX_TIMESTAMP(create_timestamp)", @@ -41,6 +47,7 @@ _HUNT_OUTPUT_PLUGINS_STATES_COLUMNS = ( "plugin_name", "plugin_args", + "plugin_args_any", "plugin_state", ) @@ -83,16 +90,18 @@ def WriteHuntObject( @mysql_utils.WithTransaction() def UpdateHuntObject( self, - hunt_id, - duration=None, - client_rate=None, - client_limit=None, - hunt_state=None, - hunt_state_reason=None, - hunt_state_comment=None, - start_time=None, - num_clients_at_start_time=None, - cursor=None, + hunt_id: str, + duration: Optional[rdfvalue.Duration] = None, + client_rate: Optional[float] = None, + client_limit: Optional[int] = None, + hunt_state: Optional[hunts_pb2.Hunt.HuntState.ValueType] = None, + hunt_state_reason: Optional[ + hunts_pb2.Hunt.HuntStateReason.ValueType + ] = None, + hunt_state_comment: Optional[str] = None, + start_time: Optional[rdfvalue.RDFDatetime] = None, + num_clients_at_start_time: Optional[int] = None, + cursor: Optional[cursors.Cursor] = None, ): """Updates the hunt object by applying the update function.""" vals = [] @@ -149,7 +158,11 @@ def UpdateHuntObject( raise db.UnknownHuntError(hunt_id) @mysql_utils.WithTransaction() - def DeleteHuntObject(self, hunt_id, cursor=None): + def DeleteHuntObject( + self, + hunt_id: str, + 
cursor: Optional[cursors.Cursor] = None, + ) -> None: """Deletes a given hunt object.""" query = "DELETE FROM hunts WHERE hunt_id = %s" hunt_id_int = db_utils.HuntIDToInt(hunt_id) @@ -169,7 +182,7 @@ def DeleteHuntObject(self, hunt_id, cursor=None): """ args = { "approval_type": int( - rdf_objects.ApprovalRequest.ApprovalType.APPROVAL_TYPE_HUNT + objects_pb2.ApprovalRequest.ApprovalType.APPROVAL_TYPE_HUNT ), "hunt_id": hunt_id, } @@ -192,12 +205,18 @@ def _HuntObjectFromRow(self, row): description, body, ) = row - hunt_obj = rdf_hunt_objects.Hunt.FromSerializedBytes(body) - hunt_obj.duration = rdfvalue.Duration.From(duration_micros, - rdfvalue.MICROSECONDS) - hunt_obj.create_time = mysql_utils.TimestampToRDFDatetime(create_time) - hunt_obj.last_update_time = mysql_utils.TimestampToRDFDatetime( - last_update_time) + hunt_obj = hunts_pb2.Hunt() + hunt_obj.ParseFromString(body) + hunt_obj.duration = rdfvalue.DurationSeconds.From( + duration_micros, rdfvalue.MICROSECONDS + ).ToInt(rdfvalue.SECONDS) + + hunt_obj.create_time = mysql_utils.TimestampToMicrosecondsSinceEpoch( + create_time + ) + hunt_obj.last_update_time = mysql_utils.TimestampToMicrosecondsSinceEpoch( + last_update_time + ) # Checks below are needed for hunts that were written to the database before # respective fields became part of F1 schema. @@ -217,12 +236,14 @@ def _HuntObjectFromRow(self, row): hunt_obj.hunt_state_comment = hunt_state_comment if init_start_time is not None: - hunt_obj.init_start_time = mysql_utils.TimestampToRDFDatetime( - init_start_time) + hunt_obj.init_start_time = mysql_utils.TimestampToMicrosecondsSinceEpoch( + init_start_time + ) if last_start_time is not None: - hunt_obj.last_start_time = mysql_utils.TimestampToRDFDatetime( - last_start_time) + hunt_obj.last_start_time = mysql_utils.TimestampToMicrosecondsSinceEpoch( + last_start_time + ) if num_clients_at_start_time is not None: hunt_obj.num_clients_at_start_time = num_clients_at_start_time @@ -233,31 +254,36 @@ def _HuntObjectFromRow(self, row): return hunt_obj @mysql_utils.WithTransaction(readonly=True) - def ReadHuntObject(self, hunt_id, cursor=None): + def ReadHuntObject( + self, hunt_id: str, cursor: Optional[cursors.Cursor] = None + ) -> hunts_pb2.Hunt: """Reads a hunt object from the database.""" - query = ("SELECT {columns} " - "FROM hunts WHERE hunt_id = %s".format( - columns=_HUNT_COLUMNS_SELECT)) + query = "SELECT {columns} FROM hunts WHERE hunt_id = %s".format( + columns=_HUNT_COLUMNS_SELECT + ) nr_results = cursor.execute(query, [db_utils.HuntIDToInt(hunt_id)]) if nr_results == 0: raise db.UnknownHuntError(hunt_id) - return self._HuntObjectFromRow(cursor.fetchone()) + hunt = self._HuntObjectFromRow(cursor.fetchone()) + return hunt @mysql_utils.WithTransaction(readonly=True) def ReadHuntObjects( self, - offset, - count, - with_creator=None, - created_after=None, - with_description_match=None, - created_by=None, - not_created_by=None, - with_states=None, - cursor=None, - ): + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[AbstractSet[str]] = None, + not_created_by: Optional[AbstractSet[str]] = None, + with_states: Optional[ + Collection[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, + cursor: Optional[cursors.Cursor] = None, + ) -> List[hunts_pb2.Hunt]: """Reads multiple hunt objects from the database.""" query = "SELECT {columns} FROM hunts ".format(columns=_HUNT_COLUMNS_SELECT) args = [] @@ 
-283,7 +309,6 @@ def ReadHuntObjects( # implementation does not know how to convert a `frozenset` to a string. # The cursor implementation knows how to convert lists and ordinary sets. args.append(list(not_created_by)) - if created_after is not None: components.append("create_timestamp > FROM_UNIXTIME(%s) ") args.append(mysql_utils.RDFDatetimeToTimestamp(created_after)) @@ -311,16 +336,18 @@ def ReadHuntObjects( @mysql_utils.WithTransaction(readonly=True) def ListHuntObjects( self, - offset, - count, - with_creator=None, - created_after=None, - with_description_match=None, - created_by=None, - not_created_by=None, - with_states=None, - cursor=None, - ): + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[AbstractSet[str]] = None, + not_created_by: Optional[AbstractSet[str]] = None, + with_states: Optional[ + Collection[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, + cursor: Optional[cursors.Cursor] = None, + ) -> List[hunts_pb2.HuntMetadata]: """Reads metadata for hunt objects from the database.""" query = """ SELECT @@ -385,62 +412,113 @@ def ListHuntObjects( cursor.execute(query, args) result = [] for row in cursor.fetchall(): - (hunt_id, create_timestamp, last_update_timestamp, creator, - duration_micros, client_rate, client_limit, hunt_state, - hunt_state_comment, init_start_time, last_start_time, description) = row - result.append( - rdf_hunt_objects.HuntMetadata( - hunt_id=db_utils.IntToHuntID(hunt_id), - description=description or None, - create_time=mysql_utils.TimestampToRDFDatetime(create_timestamp), - creator=creator, - duration=rdfvalue.Duration.From(duration_micros, - rdfvalue.MICROSECONDS), - client_rate=client_rate, - client_limit=client_limit, - hunt_state=hunt_state, - hunt_state_comment=hunt_state_comment or None, - last_update_time=mysql_utils.TimestampToRDFDatetime( - last_update_timestamp), - init_start_time=mysql_utils.TimestampToRDFDatetime( - init_start_time), - last_start_time=mysql_utils.TimestampToRDFDatetime( - last_start_time))) + ( + hunt_id, + create_timestamp, + last_update_timestamp, + creator, + duration_micros, + client_rate, + client_limit, + hunt_state, + hunt_state_comment, + init_start_time, + last_start_time, + description, + ) = row + + hunt_metadata = hunts_pb2.HuntMetadata( + hunt_id=db_utils.IntToHuntID(hunt_id), + create_time=int(mysql_utils.TimestampToRDFDatetime(create_timestamp)), + creator=creator, + duration=rdfvalue.Duration.From( + duration_micros, rdfvalue.MICROSECONDS + ).ToInt(rdfvalue.SECONDS), + client_rate=client_rate, + client_limit=client_limit, + hunt_state=hunt_state, + ) + + if description: + hunt_metadata.description = description + + if hunt_state_comment: + hunt_metadata.hunt_state_comment = hunt_state_comment + + if last_update_timestamp := mysql_utils.TimestampToRDFDatetime( + last_update_timestamp + ): + hunt_metadata.last_update_time = int(last_update_timestamp) + + if init_start_time := mysql_utils.TimestampToRDFDatetime(init_start_time): + hunt_metadata.init_start_time = int(init_start_time) + + if last_start_time := mysql_utils.TimestampToRDFDatetime(last_start_time): + hunt_metadata.last_start_time = int(last_start_time) + + result.append(hunt_metadata) return result - def _HuntOutputPluginStateFromRow(self, row): + def _HuntOutputPluginStateFromRow( + self, row: Tuple[str, bytes, bytes] + ) -> output_plugin_pb2.OutputPluginState: """Builds OutputPluginState object from a DB 
row.""" - plugin_name, plugin_args_bytes, plugin_state_bytes = row - - plugin_descriptor = rdf_output_plugin.OutputPluginDescriptor( - plugin_name=plugin_name) - if plugin_args_bytes is not None: - plugin_args_cls = plugin_descriptor.GetPluginArgsClass() - # If plugin_args_cls is None, we have no clue what class plugin args - # should be and therefore no way to deserialize it. This can happen if - # a plugin got renamed or removed, for example. In this case we - # still want to get plugin's definition and state back and not fail hard, - # so that all other plugins can be read. - if plugin_args_cls is not None: - plugin_descriptor.args = plugin_args_cls.FromSerializedBytes( - plugin_args_bytes) - else: # Avoid missing information even if we cannot unpack it. - plugin_descriptor.args = plugin_args_bytes - - plugin_state = rdf_protodict.AttributedDict.FromSerializedBytes( - plugin_state_bytes) - return rdf_flow_runner.OutputPluginState( - plugin_descriptor=plugin_descriptor, plugin_state=plugin_state) + ( + plugin_name, + plugin_args_bytes, + plugin_args_any_bytes, + plugin_state_bytes, + ) = row + + if plugin_args_any_bytes is not None: + plugin_args_any = any_pb2.Any() + plugin_args_any.ParseFromString(plugin_args_any_bytes) + elif plugin_args_bytes is not None: + # TODO: The db migration added a new column but didn't + # backfill the data, so a fallback to parse the old format is implemented + # here. Remove this fallback mechanism after the new format has been + # adopted and old data is not needed anymore. + if plugin_name in rdfvalue.RDFValue.classes: + plugin_args_any = any_pb2.Any( + type_url=db_utils.RDFTypeNameToTypeURL(plugin_name), + value=plugin_args_bytes, + ) + else: + unrecognized = objects_pb2.SerializedValueOfUnrecognizedType( + type_name=plugin_name, value=plugin_args_bytes + ) + plugin_args_any = any_pb2.Any() + plugin_args_any.Pack(unrecognized) + else: + plugin_args_any = None + + plugin_descriptor = output_plugin_pb2.OutputPluginDescriptor( + plugin_name=plugin_name, args=plugin_args_any + ) + + plugin_state = jobs_pb2.AttributedDict() + plugin_state.ParseFromString(plugin_state_bytes) + + return output_plugin_pb2.OutputPluginState( + plugin_descriptor=plugin_descriptor, + plugin_state=plugin_state, + ) @mysql_utils.WithTransaction(readonly=True) - def ReadHuntOutputPluginsStates(self, hunt_id, cursor=None): + def ReadHuntOutputPluginsStates( + self, + hunt_id: str, + cursor: Optional[cursors.Cursor] = None, + ) -> List[output_plugin_pb2.OutputPluginState]: """Reads all hunt output plugins states of a given hunt.""" columns = ", ".join(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS) - query = ("SELECT {columns} FROM hunt_output_plugins_states " - "WHERE hunt_id = %s".format(columns=columns)) + query = ( + "SELECT {columns} FROM hunt_output_plugins_states " + "WHERE hunt_id = %s".format(columns=columns) + ) rows_returned = cursor.execute(query, [db_utils.HuntIDToInt(hunt_id)]) if rows_returned > 0: states = [] @@ -456,27 +534,38 @@ def ReadHuntOutputPluginsStates(self, hunt_id, cursor=None): return [] @mysql_utils.WithTransaction() - def WriteHuntOutputPluginsStates(self, hunt_id, states, cursor=None): + def WriteHuntOutputPluginsStates( + self, + hunt_id: str, + states: Collection[output_plugin_pb2.OutputPluginState], + cursor: Optional[cursors.Cursor] = None, + ): """Writes hunt output plugin states for a given hunt.""" columns = ", ".join(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS) placeholders = mysql_utils.Placeholders( - 2 + len(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS)) + 2 + 
len(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS) + ) hunt_id_int = db_utils.HuntIDToInt(hunt_id) for index, state in enumerate(states): - query = ("INSERT INTO hunt_output_plugins_states " - "(hunt_id, plugin_id, {columns}) " - "VALUES {placeholders}".format( - columns=columns, placeholders=placeholders)) + query = ( + "INSERT INTO hunt_output_plugins_states " + "(hunt_id, plugin_id, {columns}) " + "VALUES {placeholders}".format( + columns=columns, placeholders=placeholders + ) + ) args = [hunt_id_int, index, state.plugin_descriptor.plugin_name] if state.plugin_descriptor.HasField("args"): - args.append(state.plugin_descriptor.args.SerializeToBytes()) + args.append(state.plugin_descriptor.args.value) + args.append(state.plugin_descriptor.args.SerializeToString()) else: args.append(None) + args.append(None) - args.append(state.plugin_state.SerializeToBytes()) + args.append(state.plugin_state.SerializeToString()) try: cursor.execute(query, args) @@ -484,11 +573,16 @@ def WriteHuntOutputPluginsStates(self, hunt_id, states, cursor=None): raise db.UnknownHuntError(hunt_id=hunt_id, cause=e) @mysql_utils.WithTransaction() - def UpdateHuntOutputPluginState(self, - hunt_id, - state_index, - update_fn, - cursor=None): + def UpdateHuntOutputPluginState( + self, + hunt_id: str, + state_index: int, + update_fn: Callable[ + [jobs_pb2.AttributedDict], + jobs_pb2.AttributedDict, + ], + cursor: Optional[cursors.Cursor] = None, + ) -> jobs_pb2.AttributedDict: """Updates hunt output plugin state for a given output plugin.""" hunt_id_int = db_utils.HuntIDToInt(hunt_id) @@ -499,8 +593,10 @@ def UpdateHuntOutputPluginState(self, raise db.UnknownHuntError(hunt_id) columns = ", ".join(_HUNT_OUTPUT_PLUGINS_STATES_COLUMNS) - query = ("SELECT {columns} FROM hunt_output_plugins_states " - "WHERE hunt_id = %s AND plugin_id = %s".format(columns=columns)) + query = ( + "SELECT {columns} FROM hunt_output_plugins_states " + "WHERE hunt_id = %s AND plugin_id = %s".format(columns=columns) + ) rows_returned = cursor.execute(query, [hunt_id_int, state_index]) if rows_returned == 0: raise db.UnknownHuntOutputPluginStateError(hunt_id, state_index) @@ -508,10 +604,12 @@ def UpdateHuntOutputPluginState(self, state = self._HuntOutputPluginStateFromRow(cursor.fetchone()) modified_plugin_state = update_fn(state.plugin_state) - query = ("UPDATE hunt_output_plugins_states " - "SET plugin_state = %s " - "WHERE hunt_id = %s AND plugin_id = %s") - args = [modified_plugin_state.SerializeToBytes(), hunt_id_int, state_index] + query = ( + "UPDATE hunt_output_plugins_states " + "SET plugin_state = %s " + "WHERE hunt_id = %s AND plugin_id = %s" + ) + args = [modified_plugin_state.SerializeToString(), hunt_id_int, state_index] cursor.execute(query, args) return state @@ -527,10 +625,12 @@ def ReadHuntLogEntries( """Reads hunt log entries of a given hunt using given query options.""" hunt_id_int = db_utils.HuntIDToInt(hunt_id) - query = ("SELECT client_id, flow_id, message, UNIX_TIMESTAMP(timestamp) " - "FROM flow_log_entries " - "FORCE INDEX(flow_log_entries_by_hunt) " - "WHERE hunt_id = %s AND flow_id = hunt_id ") + query = ( + "SELECT client_id, flow_id, message, UNIX_TIMESTAMP(timestamp) " + "FROM flow_log_entries " + "FORCE INDEX(flow_log_entries_by_hunt) " + "WHERE hunt_id = %s AND flow_id = hunt_id " + ) args = [hunt_id_int] @@ -568,27 +668,31 @@ def CountHuntLogEntries( """Returns number of hunt log entries of a given hunt.""" hunt_id_int = db_utils.HuntIDToInt(hunt_id) - query = ("SELECT COUNT(*) FROM flow_log_entries " - "FORCE 
INDEX(flow_log_entries_by_hunt) " - "WHERE hunt_id = %s AND flow_id = hunt_id") + query = ( + "SELECT COUNT(*) FROM flow_log_entries " + "FORCE INDEX(flow_log_entries_by_hunt) " + "WHERE hunt_id = %s AND flow_id = hunt_id" + ) cursor.execute(query, [hunt_id_int]) return cursor.fetchone()[0] @mysql_utils.WithTransaction(readonly=True) - def ReadHuntResults(self, - hunt_id, - offset, - count, - with_tag=None, - with_type=None, - with_substring=None, - with_timestamp=None, - cursor=None): + def ReadHuntResults( + self, + hunt_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, + with_timestamp: Optional[rdfvalue.RDFDatetime] = None, + cursor: Optional[cursors.Cursor] = None, + ) -> Iterable[flows_pb2.FlowResult]: """Reads hunt results of a given hunt using given query options.""" hunt_id_int = db_utils.HuntIDToInt(hunt_id) query = """ - SELECT client_id, flow_id, hunt_id, payload, type, + SELECT client_id, flow_id, payload, type, UNIX_TIMESTAMP(timestamp), tag FROM flow_results FORCE INDEX(flow_results_hunt_id_flow_id_timestamp) @@ -623,25 +727,31 @@ def ReadHuntResults(self, for ( client_id_int, flow_id_int, - hunt_id_int, serialized_payload, payload_type, timestamp, tag, ) in cursor.fetchall(): + if payload_type in rdfvalue.RDFValue.classes: - payload = rdfvalue.RDFValue.classes[payload_type].FromSerializedBytes( - serialized_payload) + payload = any_pb2.Any( + type_url=db_utils.RDFTypeNameToTypeURL(payload_type), + value=serialized_payload, + ) else: - payload = rdf_objects.SerializedValueOfUnrecognizedType( - type_name=payload_type, value=serialized_payload) + unrecognized = objects_pb2.SerializedValueOfUnrecognizedType( + type_name=payload_type, value=serialized_payload + ) + payload = any_pb2.Any() + payload.Pack(unrecognized) - result = rdf_flow_objects.FlowResult( + result = flows_pb2.FlowResult( client_id=db_utils.IntToClientID(client_id_int), flow_id=db_utils.IntToFlowID(flow_id_int), hunt_id=hunt_id, payload=payload, - timestamp=mysql_utils.TimestampToRDFDatetime(timestamp)) + timestamp=mysql_utils.TimestampToMicrosecondsSinceEpoch(timestamp), + ) if tag is not None: result.tag = tag @@ -650,11 +760,13 @@ def ReadHuntResults(self, return ret @mysql_utils.WithTransaction(readonly=True) - def CountHuntResults(self, - hunt_id, - with_tag=None, - with_type=None, - cursor=None): + def CountHuntResults( + self, + hunt_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + cursor: Optional[cursors.Cursor] = None, + ) -> int: """Counts hunt results of a given hunt using given query options.""" hunt_id_int = db_utils.HuntIDToInt(hunt_id) @@ -674,12 +786,18 @@ def CountHuntResults(self, return cursor.fetchone()[0] @mysql_utils.WithTransaction(readonly=True) - def CountHuntResultsByType(self, hunt_id, cursor=None): + def CountHuntResultsByType( + self, + hunt_id: str, + cursor: Optional[cursors.Cursor] = None, + ) -> Mapping[str, int]: """Counts number of hunts results per type.""" hunt_id_int = db_utils.HuntIDToInt(hunt_id) - query = ("SELECT type, COUNT(*) FROM flow_results " - "WHERE hunt_id = %s GROUP BY type") + query = ( + "SELECT type, COUNT(*) FROM flow_results " + "WHERE hunt_id = %s GROUP BY type" + ) cursor.execute(query, [hunt_id_int]) return dict(cursor.fetchall()) @@ -689,62 +807,85 @@ def _HuntFlowCondition(self, condition): if condition == db.HuntFlowsCondition.UNSET: return "", [] elif condition == db.HuntFlowsCondition.FAILED_FLOWS_ONLY: - return ("AND 
flow_state = %s ", - [int(rdf_flow_objects.Flow.FlowState.ERROR)]) + return ( + "AND flow_state = %s ", + [int(rdf_flow_objects.Flow.FlowState.ERROR)], + ) elif condition == db.HuntFlowsCondition.SUCCEEDED_FLOWS_ONLY: - return ("AND flow_state = %s ", - [int(rdf_flow_objects.Flow.FlowState.FINISHED)]) + return ( + "AND flow_state = %s ", + [int(rdf_flow_objects.Flow.FlowState.FINISHED)], + ) elif condition == db.HuntFlowsCondition.COMPLETED_FLOWS_ONLY: - return ("AND (flow_state = %s OR flow_state = %s) ", [ - int(rdf_flow_objects.Flow.FlowState.FINISHED), - int(rdf_flow_objects.Flow.FlowState.ERROR) - ]) + return ( + "AND (flow_state = %s OR flow_state = %s) ", + [ + int(rdf_flow_objects.Flow.FlowState.FINISHED), + int(rdf_flow_objects.Flow.FlowState.ERROR), + ], + ) elif condition == db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY: - return ("AND flow_state = %s ", - [int(rdf_flow_objects.Flow.FlowState.RUNNING)]) + return ( + "AND flow_state = %s ", + [int(rdf_flow_objects.Flow.FlowState.RUNNING)], + ) elif condition == db.HuntFlowsCondition.CRASHED_FLOWS_ONLY: - return ("AND flow_state = %s ", - [int(rdf_flow_objects.Flow.FlowState.CRASHED)]) + return ( + "AND flow_state = %s ", + [int(rdf_flow_objects.Flow.FlowState.CRASHED)], + ) else: raise ValueError("Invalid condition value: %r" % condition) @mysql_utils.WithTransaction(readonly=True) - def ReadHuntFlows(self, - hunt_id, - offset, - count, - filter_condition=db.HuntFlowsCondition.UNSET, - cursor=None): + def ReadHuntFlows( + self, + hunt_id: str, + offset: int, + count: int, + filter_condition: db.HuntFlowsCondition = db.HuntFlowsCondition.UNSET, + cursor: Optional[cursors.Cursor] = None, + ) -> Sequence[flows_pb2.Flow]: """Reads hunt flows matching given conditins.""" hunt_id_int = db_utils.HuntIDToInt(hunt_id) - query = ("SELECT {columns} FROM flows " - "FORCE INDEX(flows_by_hunt) " - "WHERE parent_hunt_id = %s AND parent_flow_id IS NULL " - "{filter_condition} " - "ORDER BY last_update ASC " - "LIMIT %s OFFSET %s") + query = ( + "SELECT {columns} FROM flows " + "FORCE INDEX(flows_by_hunt) " + "WHERE parent_hunt_id = %s AND parent_flow_id IS NULL " + "{filter_condition} " + "ORDER BY last_update ASC " + "LIMIT %s OFFSET %s" + ) filter_query, extra_args = self._HuntFlowCondition(filter_condition) query = query.format( - columns=self.FLOW_DB_FIELDS, filter_condition=filter_query) + columns=self.FLOW_DB_FIELDS, filter_condition=filter_query + ) args = [hunt_id_int] + extra_args + [count, offset] cursor.execute(query, args) - return [self._FlowObjectFromRow(row) for row in cursor.fetchall()] + flows = [self._FlowObjectFromRow(row) for row in cursor.fetchall()] + return flows @mysql_utils.WithTransaction(readonly=True) - def CountHuntFlows(self, - hunt_id, - filter_condition=db.HuntFlowsCondition.UNSET, - cursor=None): + def CountHuntFlows( + self, + hunt_id: str, + filter_condition: Optional[ + db.HuntFlowsCondition + ] = db.HuntFlowsCondition.UNSET, + cursor=None, + ) -> int: """Counts hunt flows matching given conditions.""" hunt_id_int = db_utils.HuntIDToInt(hunt_id) - query = ("SELECT COUNT(*) FROM flows " - "FORCE INDEX(flows_by_hunt) " - "WHERE parent_hunt_id = %s AND parent_flow_id IS NULL " - "{filter_condition}") + query = ( + "SELECT COUNT(*) FROM flows " + "FORCE INDEX(flows_by_hunt) " + "WHERE parent_hunt_id = %s AND parent_flow_id IS NULL " + "{filter_condition}" + ) filter_query, extra_args = self._HuntFlowCondition(filter_condition) args = [hunt_id_int] + extra_args @@ -753,68 +894,99 @@ def CountHuntFlows(self, return 
cursor.fetchone()[0] @mysql_utils.WithTransaction(readonly=True) - def ReadHuntCounters(self, hunt_id, cursor=None): - """Reads hunt counters.""" - hunt_id_int = db_utils.HuntIDToInt(hunt_id) - - query = ("SELECT flow_state, COUNT(*) " - "FROM flows " - "FORCE INDEX(flows_by_hunt) " - "WHERE parent_hunt_id = %s AND parent_flow_id IS NULL " - "GROUP BY flow_state") + def ReadHuntsCounters( + self, + hunt_ids: Collection[str], + cursor: Optional[cursors.Cursor] = None, + ) -> Mapping[str, db.HuntCounters]: + """Reads hunt counters for several hunt ids.""" + if not hunt_ids: + return {} - cursor.execute(query, [hunt_id_int]) - counts_by_state = dict(cursor.fetchall()) - - num_successful_clients = counts_by_state.get( - int(rdf_flow_objects.Flow.FlowState.FINISHED), 0) - num_failed_clients = counts_by_state.get( - int(rdf_flow_objects.Flow.FlowState.ERROR), 0) - num_crashed_clients = counts_by_state.get( - int(rdf_flow_objects.Flow.FlowState.CRASHED), 0) - num_running_clients = counts_by_state.get( - int(rdf_flow_objects.Flow.FlowState.RUNNING), 0) - num_clients = sum(counts_by_state.values()) + hunt_ids_ints = [db_utils.HuntIDToInt(hunt_id) for hunt_id in hunt_ids] query = """ - SELECT * FROM - - ( - SELECT COUNT(client_id) - FROM flows - FORCE INDEX(flows_by_hunt) - WHERE parent_hunt_id = %s AND parent_flow_id IS NULL AND - num_replies_sent > 0) counters, + SELECT parent_hunt_id, flow_state, COUNT(*) + FROM flows + FORCE INDEX(flows_by_hunt) + WHERE parent_hunt_id IN %(hunt_ids)s + AND parent_flow_id IS NULL + GROUP BY parent_hunt_id, flow_state + """ + cursor.execute(query, {"hunt_ids": tuple(hunt_ids_ints)}) + + counts_by_state_per_hunt = {hunt_id: {} for hunt_id in hunt_ids_ints} + for hunt_id, state, count in cursor.fetchall(): + counts_by_state_per_hunt[hunt_id][state] = count + + hunt_counters = dict.fromkeys( + hunt_ids, + db.HuntCounters( + num_clients=0, + num_successful_clients=0, + num_failed_clients=0, + num_clients_with_results=0, + num_crashed_clients=0, + num_running_clients=0, + num_results=0, + total_cpu_seconds=0, + total_network_bytes_sent=0, + ), + ) - ( - SELECT SUM(user_cpu_time_used_micros + system_cpu_time_used_micros), - SUM(network_bytes_sent), - SUM(num_replies_sent) - FROM flows - FORCE INDEX(flows_by_hunt) - WHERE parent_hunt_id = %s AND parent_flow_id IS NULL) resources + query = """ + SELECT + parent_hunt_id, + SUM(user_cpu_time_used_micros + system_cpu_time_used_micros), + SUM(network_bytes_sent), + SUM(num_replies_sent), + COUNT(IF(num_replies_sent > 0, client_id, NULL)) + FROM flows + FORCE INDEX(flows_by_hunt) + WHERE parent_hunt_id IN %(hunt_ids)s + AND parent_flow_id IS NULL + GROUP BY parent_hunt_id """ + cursor.execute(query, {"hunt_ids": tuple(hunt_ids_ints)}) - cursor.execute(query, [hunt_id_int, hunt_id_int]) - ( - num_clients_with_results, + for ( + hunt_id, total_cpu_seconds, total_network_bytes_sent, num_results, - ) = cursor.fetchone() - - return db.HuntCounters( - num_clients=num_clients, - num_successful_clients=num_successful_clients, - num_failed_clients=num_failed_clients, - num_clients_with_results=num_clients_with_results, - num_crashed_clients=num_crashed_clients, - num_running_clients=num_running_clients, - num_results=int(num_results or 0), - total_cpu_seconds=db_utils.MicrosToSeconds(int(total_cpu_seconds or 0)), - total_network_bytes_sent=int(total_network_bytes_sent or 0)) - - def _BinsToQuery(self, bins, column_name): + num_clients_with_results, + ) in cursor.fetchall(): + counts_by_state = counts_by_state_per_hunt[hunt_id] +
num_successful_clients = counts_by_state.get( + int(rdf_flow_objects.Flow.FlowState.FINISHED), 0 + ) + num_failed_clients = counts_by_state.get( + int(rdf_flow_objects.Flow.FlowState.ERROR), 0 + ) + num_crashed_clients = counts_by_state.get( + int(rdf_flow_objects.Flow.FlowState.CRASHED), 0 + ) + num_running_clients = counts_by_state.get( + int(rdf_flow_objects.Flow.FlowState.RUNNING), 0 + ) + num_clients = sum(counts_by_state_per_hunt[hunt_id].values()) + + hunt_counters[db_utils.IntToHuntID(hunt_id)] = db.HuntCounters( + num_clients=num_clients, + num_successful_clients=num_successful_clients, + num_failed_clients=num_failed_clients, + num_clients_with_results=num_clients_with_results, + num_crashed_clients=num_crashed_clients, + num_running_clients=num_running_clients, + num_results=int(num_results or 0), + total_cpu_seconds=db_utils.MicrosToSeconds( + int(total_cpu_seconds or 0) + ), + total_network_bytes_sent=int(total_network_bytes_sent or 0), + ) + return hunt_counters + + def _BinsToQuery(self, bins: Sequence[int], column_name: str) -> str: """Builds an SQL query part to fetch counts corresponding to given bins.""" result = [] # With the current StatsHistogram implementation the last bin simply @@ -833,7 +1005,9 @@ def _BinsToQuery(self, bins, column_name): return ", ".join(result) @mysql_utils.WithTransaction(readonly=True) - def ReadHuntClientResourcesStats(self, hunt_id, cursor=None): + def ReadHuntClientResourcesStats( + self, hunt_id: str, cursor: Optional[cursors.Cursor] = None + ) -> jobs_pb2.ClientResourcesStats: """Read/calculate hunt client resources stats.""" hunt_id_int = db_utils.HuntIDToInt(hunt_id) @@ -852,35 +1026,46 @@ def ReadHuntClientResourcesStats(self, hunt_id, cursor=None): int(1000000 * b) for b in rdf_stats.ClientResourcesStats.CPU_STATS_BINS ] - query += self._BinsToQuery(scaled_bins, "(user_cpu_time_used_micros)") - query += "," - query += self._BinsToQuery(scaled_bins, "(system_cpu_time_used_micros)") - query += "," - query += self._BinsToQuery( - rdf_stats.ClientResourcesStats.NETWORK_STATS_BINS, "network_bytes_sent") + query += ", ".join([ + self._BinsToQuery(scaled_bins, "(user_cpu_time_used_micros)"), + self._BinsToQuery(scaled_bins, "(system_cpu_time_used_micros)"), + self._BinsToQuery( + rdf_stats.ClientResourcesStats.NETWORK_STATS_BINS, + "network_bytes_sent", + ), + ]) query += " FROM flows " query += "FORCE INDEX(flows_by_hunt) " - query += "WHERE parent_hunt_id = %s AND parent_flow_id IS NULL" + query += "WHERE parent_hunt_id = %s " + query += "AND parent_flow_id IS NULL " + query += "AND flow_id = %s" - cursor.execute(query, [hunt_id_int]) + cursor.execute(query, [hunt_id_int, hunt_id_int]) response = cursor.fetchone() - (count, user_sum, user_stddev, system_sum, system_stddev, network_sum, - network_stddev) = response[:7] - - stats = rdf_stats.ClientResourcesStats( - user_cpu_stats=rdf_stats.RunningStats( + ( + count, + user_sum, + user_stddev, + system_sum, + system_stddev, + network_sum, + network_stddev, + ) = response[:7] + + stats = jobs_pb2.ClientResourcesStats( + user_cpu_stats=jobs_pb2.RunningStats( num=count, sum=db_utils.MicrosToSeconds(int(user_sum or 0)), stddev=int(user_stddev or 0) / 1e6, ), - system_cpu_stats=rdf_stats.RunningStats( + system_cpu_stats=jobs_pb2.RunningStats( num=count, sum=db_utils.MicrosToSeconds(int(system_sum or 0)), stddev=int(system_stddev or 0) / 1e6, ), - network_bytes_sent_stats=rdf_stats.RunningStats( + network_bytes_sent_stats=jobs_pb2.RunningStats( num=count, sum=float(network_sum or 0), 
stddev=float(network_stddev or 0), @@ -888,25 +1073,34 @@ def ReadHuntClientResourcesStats(self, hunt_id, cursor=None): ) offset = 7 - stats.user_cpu_stats.histogram = rdf_stats.StatsHistogram() + user_cpu_histogram = jobs_pb2.StatsHistogram() for b_num, b_max_value in zip( - response[offset:], rdf_stats.ClientResourcesStats.CPU_STATS_BINS): - stats.user_cpu_stats.histogram.bins.append( - rdf_stats.StatsHistogramBin(range_max_value=b_max_value, num=b_num)) + response[offset:], rdf_stats.ClientResourcesStats.CPU_STATS_BINS + ): + user_cpu_histogram.bins.append( + jobs_pb2.StatsHistogramBin(range_max_value=b_max_value, num=b_num) + ) + stats.user_cpu_stats.histogram.CopyFrom(user_cpu_histogram) offset += len(rdf_stats.ClientResourcesStats.CPU_STATS_BINS) - stats.system_cpu_stats.histogram = rdf_stats.StatsHistogram() + system_cpu_histogram = jobs_pb2.StatsHistogram() for b_num, b_max_value in zip( - response[offset:], rdf_stats.ClientResourcesStats.CPU_STATS_BINS): - stats.system_cpu_stats.histogram.bins.append( - rdf_stats.StatsHistogramBin(range_max_value=b_max_value, num=b_num)) + response[offset:], rdf_stats.ClientResourcesStats.CPU_STATS_BINS + ): + system_cpu_histogram.bins.append( + jobs_pb2.StatsHistogramBin(range_max_value=b_max_value, num=b_num) + ) + stats.system_cpu_stats.histogram.CopyFrom(system_cpu_histogram) offset += len(rdf_stats.ClientResourcesStats.CPU_STATS_BINS) - stats.network_bytes_sent_stats.histogram = rdf_stats.StatsHistogram() + network_bytes_histogram = jobs_pb2.StatsHistogram() for b_num, b_max_value in zip( - response[offset:], rdf_stats.ClientResourcesStats.NETWORK_STATS_BINS): - stats.network_bytes_sent_stats.histogram.bins.append( - rdf_stats.StatsHistogramBin(range_max_value=b_max_value, num=b_num)) + response[offset:], rdf_stats.ClientResourcesStats.NETWORK_STATS_BINS + ): + network_bytes_histogram.bins.append( + jobs_pb2.StatsHistogramBin(range_max_value=b_max_value, num=b_num) + ) + stats.network_bytes_sent_stats.histogram.CopyFrom(network_bytes_histogram) query = """ SELECT @@ -914,7 +1108,7 @@ def ReadHuntClientResourcesStats(self, hunt_id, cursor=None): system_cpu_time_used_micros, network_bytes_sent FROM flows FORCE INDEX(flows_by_hunt) - WHERE parent_hunt_id = %s AND parent_flow_id IS NULL AND + WHERE parent_hunt_id = %s AND parent_flow_id IS NULL AND flow_id = %s AND (user_cpu_time_used_micros > 0 OR system_cpu_time_used_micros > 0 OR network_bytes_sent > 0) @@ -922,25 +1116,31 @@ def ReadHuntClientResourcesStats(self, hunt_id, cursor=None): LIMIT 10 """ - cursor.execute(query, [hunt_id_int]) + cursor.execute(query, [hunt_id_int, hunt_id_int]) for cid, fid, ucpu, scpu, nbs in cursor.fetchall(): client_id = db_utils.IntToClientID(cid) flow_id = db_utils.IntToFlowID(fid) stats.worst_performers.append( - rdf_client_stats.ClientResources( - client_id=client_id, - session_id=rdfvalue.RDFURN(client_id).Add(flow_id), - cpu_usage=rdf_client_stats.CpuSeconds( + jobs_pb2.ClientResources( + client_id=str(rdf_client.ClientURN.FromHumanReadable(client_id)), + session_id=str(rdfvalue.RDFURN(client_id).Add(flow_id)), + cpu_usage=jobs_pb2.CpuSeconds( user_cpu_time=db_utils.MicrosToSeconds(ucpu), system_cpu_time=db_utils.MicrosToSeconds(scpu), ), - network_bytes_sent=nbs)) + network_bytes_sent=nbs, + ) + ) return stats @mysql_utils.WithTransaction(readonly=True) - def ReadHuntFlowsStatesAndTimestamps(self, hunt_id, cursor=None): + def ReadHuntFlowsStatesAndTimestamps( + self, + hunt_id: str, + cursor: Optional[cursors.Cursor] = None, + ) -> 
Sequence[db.FlowStateAndTimestamps]: """Reads hunt flows states and timestamps.""" query = """ @@ -959,7 +1159,9 @@ def ReadHuntFlowsStatesAndTimestamps(self, hunt_id, cursor=None): db.FlowStateAndTimestamps( flow_state=rdf_flow_objects.Flow.FlowState.FromInt(fs), create_time=mysql_utils.TimestampToRDFDatetime(ct), - last_update_time=mysql_utils.TimestampToRDFDatetime(lup))) + last_update_time=mysql_utils.TimestampToRDFDatetime(lup), + ) + ) return result @@ -976,14 +1178,16 @@ def ReadHuntOutputPluginLogEntries( cursor: Optional[cursors.Cursor] = None, ) -> Sequence[flows_pb2.FlowOutputPluginLogEntry]: """Reads hunt output plugin log entries.""" - query = ("SELECT client_id, flow_id, log_entry_type, message, " - "UNIX_TIMESTAMP(timestamp) " - "FROM flow_output_plugin_log_entries " - "FORCE INDEX (flow_output_plugin_log_entries_by_hunt) " - "WHERE hunt_id = %s AND output_plugin_id = %s ") + query = ( + "SELECT client_id, flow_id, log_entry_type, message, " + "UNIX_TIMESTAMP(timestamp) " + "FROM flow_output_plugin_log_entries " + "FORCE INDEX (flow_output_plugin_log_entries_by_hunt) " + "WHERE hunt_id = %s AND output_plugin_id = %s " + ) args = [ db_utils.HuntIDToInt(hunt_id), - db_utils.OutputPluginIDToInt(output_plugin_id) + db_utils.OutputPluginIDToInt(output_plugin_id), ] if with_type is not None: @@ -997,8 +1201,13 @@ def ReadHuntOutputPluginLogEntries( cursor.execute(query, args) ret = [] - for (client_id_int, flow_id_int, log_entry_type, message, - timestamp) in cursor.fetchall(): + for ( + client_id_int, + flow_id_int, + log_entry_type, + message, + timestamp, + ) in cursor.fetchall(): ret.append( flows_pb2.FlowOutputPluginLogEntry( hunt_id=hunt_id, @@ -1026,13 +1235,15 @@ def CountHuntOutputPluginLogEntries( cursor: Optional[cursors.Cursor] = None, ): """Counts hunt output plugin log entries.""" - query = ("SELECT COUNT(*) " - "FROM flow_output_plugin_log_entries " - "FORCE INDEX (flow_output_plugin_log_entries_by_hunt) " - "WHERE hunt_id = %s AND output_plugin_id = %s ") + query = ( + "SELECT COUNT(*) " + "FROM flow_output_plugin_log_entries " + "FORCE INDEX (flow_output_plugin_log_entries_by_hunt) " + "WHERE hunt_id = %s AND output_plugin_id = %s " + ) args = [ db_utils.HuntIDToInt(hunt_id), - db_utils.OutputPluginIDToInt(output_plugin_id) + db_utils.OutputPluginIDToInt(output_plugin_id), ] if with_type is not None: diff --git a/grr/server/grr_response_server/databases/mysql_hunts_test.py b/grr/server/grr_response_server/databases/mysql_hunts_test.py index 2b118d656e..aaeed4520c 100644 --- a/grr/server/grr_response_server/databases/mysql_hunts_test.py +++ b/grr/server/grr_response_server/databases/mysql_hunts_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -9,9 +8,12 @@ from grr.test_lib import test_lib -class MysqlHuntTest(db_hunts_test.DatabaseTestHuntMixin, - db_test_utils.QueryTestHelpersMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlHuntTest( + db_hunts_test.DatabaseTestHuntMixin, + db_test_utils.QueryTestHelpersMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_message_handler_test.py b/grr/server/grr_response_server/databases/mysql_message_handler_test.py index 5f6a7af3af..fd5cae427e 100644 --- a/grr/server/grr_response_server/databases/mysql_message_handler_test.py +++ b/grr/server/grr_response_server/databases/mysql_message_handler_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing 
import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlHandlerTest(db_message_handler_test.DatabaseTestHandlerMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlHandlerTest( + db_message_handler_test.DatabaseTestHandlerMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_migration.py b/grr/server/grr_response_server/databases/mysql_migration.py index 02331de3c6..efc22bf0c7 100644 --- a/grr/server/grr_response_server/databases/mysql_migration.py +++ b/grr/server/grr_response_server/databases/mysql_migration.py @@ -25,21 +25,24 @@ def _MigrationFilenameToInt(fname: Text) -> int: return int(base) -def ListMigrationsToProcess(migrations_root: Text, - current_migration_number: Optional[int] - ) -> Sequence[Text]: +def ListMigrationsToProcess( + migrations_root: Text, current_migration_number: Optional[int] +) -> Sequence[Text]: """Lists filenames of migrations with numbers bigger than a given one.""" migrations = [] for m in os.listdir(migrations_root): - if (current_migration_number is None or - _MigrationFilenameToInt(m) > current_migration_number): + if ( + current_migration_number is None + or _MigrationFilenameToInt(m) > current_migration_number + ): migrations.append(m) return sorted(migrations, key=_MigrationFilenameToInt) -def ProcessMigrations(open_conn_fn: Callable[[], Connection], - migrations_root: Text) -> None: +def ProcessMigrations( + open_conn_fn: Callable[[], Connection], migrations_root: Text +) -> None: """Processes migrations from a given folder. This function uses LOCK TABLE MySQL command on _migrations @@ -74,8 +77,9 @@ def ProcessMigrations(open_conn_fn: Callable[[], Connection], current_migration = GetLatestMigrationNumber(cursor) to_process = ListMigrationsToProcess(migrations_root, current_migration) - logging.info("Will execute following DB migrations: %s", - ", ".join(to_process)) + logging.info( + "Will execute following DB migrations: %s", ", ".join(to_process) + ) for fname in to_process: start_time = time.time() @@ -86,13 +90,16 @@ def ProcessMigrations(open_conn_fn: Callable[[], Connection], with contextlib.closing(conn.cursor()) as cursor: cursor.execute(sql) - logging.info("Migration %s is done. Took %.2fs", fname, - time.time() - start_time) + logging.info( + "Migration %s is done. Took %.2fs", fname, time.time() - start_time + ) # Update _migrations table with the latest migration. 
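The migration hunks above hinge on migrations being ordered by their parsed number rather than lexically, which is also what testDoesNotAssumeLexicalSortingOrder further down exercises. A minimal standalone sketch of that selection logic, using hypothetical helper names rather than the real mysql_migration functions:

    import os
    from typing import Optional, Sequence

    def migration_number(fname: str) -> int:
      # "0002.sql" -> 2, "10.sql" -> 10: parse the numeric stem of the filename.
      stem, _ = os.path.splitext(fname)
      return int(stem)

    def migrations_to_process(
        migrations_root: str, current: Optional[int]
    ) -> Sequence[str]:
      pending = [
          fname
          for fname in os.listdir(migrations_root)
          if current is None or migration_number(fname) > current
      ]
      # Sorting by the parsed number keeps "10.sql" after "9.sql".
      return sorted(pending, key=migration_number)

With files 7.sql through 10.sql on disk and a current migration number of 8, this returns ['9.sql', '10.sql'].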
with contextlib.closing(conn.cursor()) as cursor: - cursor.execute("INSERT INTO _migrations (migration_id) VALUES (%s)", - [_MigrationFilenameToInt(fname)]) + cursor.execute( + "INSERT INTO _migrations (migration_id) VALUES (%s)", + [_MigrationFilenameToInt(fname)], + ) finally: with contextlib.closing(conn.cursor()) as cursor: cursor.execute('SELECT RELEASE_LOCK("grr_migration")') @@ -100,10 +107,12 @@ def ProcessMigrations(open_conn_fn: Callable[[], Connection], def DumpCurrentSchema(cursor: Cursor) -> Text: """Dumps current database schema.""" - cursor.execute("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " - "WHERE table_schema = (SELECT DATABASE())") + cursor.execute( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE table_schema = (SELECT DATABASE())" + ) defs = [] - for table, in sorted(cursor.fetchall()): + for (table,) in sorted(cursor.fetchall()): cursor.execute("SHOW CREATE TABLE `{}`".format(table)) rows = cursor.fetchall() defs.append(rows[0][1]) diff --git a/grr/server/grr_response_server/databases/mysql_migration_test.py b/grr/server/grr_response_server/databases/mysql_migration_test.py index 6a05a47d1c..5c0e415bd5 100644 --- a/grr/server/grr_response_server/databases/mysql_migration_test.py +++ b/grr/server/grr_response_server/databases/mysql_migration_test.py @@ -32,24 +32,28 @@ def testReturnsAllMigrationsWhenCurrentNumberIsNone(self): fnames = ["0000.sql", "0001.sql", "0002.sql"] self._CreateMigrationFiles(fnames) self.assertListEqual( - mysql_migration.ListMigrationsToProcess(self.temp_dir, None), fnames) + mysql_migration.ListMigrationsToProcess(self.temp_dir, None), fnames + ) def testReturnsOnlyMigrationsWithNumbersBiggerThanCurrentMigrationIndex(self): fnames = ["0000.sql", "0001.sql", "0002.sql", "0003.sql"] self._CreateMigrationFiles(fnames) self.assertListEqual( mysql_migration.ListMigrationsToProcess(self.temp_dir, 1), - ["0002.sql", "0003.sql"]) + ["0002.sql", "0003.sql"], + ) def testDoesNotAssumeLexicalSortingOrder(self): fnames = ["7.sql", "8.sql", "9.sql", "10.sql"] self._CreateMigrationFiles(fnames) self.assertListEqual( - mysql_migration.ListMigrationsToProcess(self.temp_dir, None), fnames) + mysql_migration.ListMigrationsToProcess(self.temp_dir, None), fnames + ) -class MySQLMigrationTest(mysql_test.MySQLDatabaseProviderMixin, - absltest.TestCase): +class MySQLMigrationTest( + mysql_test.MySQLDatabaseProviderMixin, absltest.TestCase +): def _GetLatestMigrationNumber(self, conn): with contextlib.closing(conn.cursor()) as cursor: @@ -57,14 +61,17 @@ def _GetLatestMigrationNumber(self, conn): def testMigrationsTableIsCorrectlyUpdates(self): all_migrations = mysql_migration.ListMigrationsToProcess( - config.CONFIG["Mysql.migrations_dir"], None) + config.CONFIG["Mysql.migrations_dir"], None + ) self.assertEqual( self._conn._RunInTransaction(self._GetLatestMigrationNumber), - len(all_migrations) - 1) + len(all_migrations) - 1, + ) def _DumpSchema(self, conn): with contextlib.closing(conn.cursor()) as cursor: return mysql_migration.DumpCurrentSchema(cursor) + if __name__ == "__main__": app.run(test_lib.main) diff --git a/grr/server/grr_response_server/databases/mysql_migrations/0024.sql b/grr/server/grr_response_server/databases/mysql_migrations/0024.sql new file mode 100644 index 0000000000..f55a7b50e0 --- /dev/null +++ b/grr/server/grr_response_server/databases/mysql_migrations/0024.sql @@ -0,0 +1,2 @@ +ALTER TABLE hunt_output_plugins_states +ADD COLUMN plugin_args_any MEDIUMBLOB DEFAULT NULL AFTER plugin_args; diff --git 
a/grr/server/grr_response_server/databases/mysql_paths.py b/grr/server/grr_response_server/databases/mysql_paths.py index 218a6be9e3..b2d88dd35c 100644 --- a/grr/server/grr_response_server/databases/mysql_paths.py +++ b/grr/server/grr_response_server/databases/mysql_paths.py @@ -16,7 +16,6 @@ from grr_response_server.databases import db_utils from grr_response_server.databases import mysql_utils from grr_response_server.models import paths -from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects @@ -27,21 +26,22 @@ class MySQLDBPathMixin(object): def ReadPathInfo( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, cursor: Optional[MySQLdb.cursors.Cursor] = None, - ) -> rdf_objects.PathInfo: + ) -> objects_pb2.PathInfo: """Retrieves a path info record for a given path.""" assert cursor is not None if timestamp is None: path_infos = self.ReadPathInfos(client_id, path_type, [components]) - path_info = path_infos[components] + path_info = path_infos[tuple(components)] if path_info is None: raise db.UnknownPathError( - client_id=client_id, path_type=path_type, components=components) + client_id=client_id, path_type=path_type, components=components + ) return path_info # If/when support for MySQL 5.x is dropped, this query can be cleaned up @@ -111,7 +111,8 @@ def ReadPathInfo( row = cursor.fetchone() if row is None: raise db.UnknownPathError( - client_id=client_id, path_type=path_type, components=components) + client_id=client_id, path_type=path_type, components=components + ) # pyformat: disable (directory, timestamp, @@ -119,7 +120,7 @@ def ReadPathInfo( hash_entry_bytes, last_hash_entry_timestamp) = row # pyformat: enable - proto_path_info = objects_pb2.PathInfo( + path_info = objects_pb2.PathInfo( path_type=objects_pb2.PathInfo.PathType.Name(path_type), components=components, directory=directory, @@ -127,31 +128,27 @@ def ReadPathInfo( datetime = mysql_utils.TimestampToMicrosecondsSinceEpoch if timestamp is not None: - proto_path_info.timestamp = datetime(timestamp) + path_info.timestamp = datetime(timestamp) if last_stat_entry_timestamp is not None: - proto_path_info.last_stat_entry_timestamp = datetime( - last_stat_entry_timestamp - ) + path_info.last_stat_entry_timestamp = datetime(last_stat_entry_timestamp) if last_hash_entry_timestamp is not None: - proto_path_info.last_hash_entry_timestamp = datetime( - last_hash_entry_timestamp - ) + path_info.last_hash_entry_timestamp = datetime(last_hash_entry_timestamp) if stat_entry_bytes is not None: - proto_path_info.stat_entry.ParseFromString(stat_entry_bytes) + path_info.stat_entry.ParseFromString(stat_entry_bytes) if hash_entry_bytes is not None: - proto_path_info.hash_entry.ParseFromString(hash_entry_bytes) + path_info.hash_entry.ParseFromString(hash_entry_bytes) - return mig_objects.ToRDFPathInfo(proto_path_info) + return path_info @mysql_utils.WithTransaction(readonly=True) def ReadPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components_list: Collection[Sequence[str]], cursor: Optional[MySQLdb.cursors.Cursor] = None, - ) -> dict[Sequence[str], Optional[rdf_objects.PathInfo]]: + ) -> dict[tuple[str, ...], Optional[objects_pb2.PathInfo]]: """Retrieves path info records for given paths.""" assert cursor is not None @@ -160,7 +157,7 @@ def ReadPathInfos( path_ids = 
list(map(rdf_objects.PathID.FromComponents, components_list)) - path_infos = {components: None for components in components_list} + path_infos = {tuple(components): None for components in components_list} query = """ SELECT p.path, p.directory, UNIX_TIMESTAMP(p.timestamp), @@ -203,7 +200,7 @@ def ReadPathInfos( # pyformat: enable components = mysql_utils.PathToComponents(path) - proto_path_info = objects_pb2.PathInfo( + path_info = objects_pb2.PathInfo( path_type=objects_pb2.PathInfo.PathType.Name(path_type), components=components, directory=directory, @@ -211,22 +208,22 @@ def ReadPathInfos( datetime = mysql_utils.TimestampToMicrosecondsSinceEpoch if timestamp is not None: - proto_path_info.timestamp = datetime(timestamp) + path_info.timestamp = datetime(timestamp) if last_stat_entry_timestamp is not None: - proto_path_info.last_stat_entry_timestamp = datetime( + path_info.last_stat_entry_timestamp = datetime( last_stat_entry_timestamp ) if last_hash_entry_timestamp is not None: - proto_path_info.last_hash_entry_timestamp = datetime( + path_info.last_hash_entry_timestamp = datetime( last_hash_entry_timestamp ) if stat_entry_bytes is not None: - proto_path_info.stat_entry.ParseFromString(stat_entry_bytes) + path_info.stat_entry.ParseFromString(stat_entry_bytes) if hash_entry_bytes is not None: - proto_path_info.hash_entry.ParseFromString(hash_entry_bytes) + path_info.hash_entry.ParseFromString(hash_entry_bytes) - path_infos[components] = mig_objects.ToRDFPathInfo(proto_path_info) + path_infos[tuple(components)] = path_info return path_infos @@ -234,21 +231,18 @@ def ReadPathInfos( def WritePathInfos( self, client_id: str, - path_infos: Sequence[rdf_objects.PathInfo], + path_infos: Sequence[objects_pb2.PathInfo], cursor: Optional[MySQLdb.cursors.Cursor] = None, ) -> None: """Writes a collection of path_info records for a client.""" assert cursor is not None now = mysql_utils.RDFDatetimeToTimestamp(rdfvalue.RDFDatetime.Now()) - proto_path_infos = [ - mig_objects.ToProtoPathInfo(info) for info in path_infos - ] int_client_id = db_utils.ClientIDToInt(client_id) # Since we need to validate client id even if there are no paths given, we # cannot rely on foreign key constraints and have to special-case this. 
- if not proto_path_infos: + if not path_infos: query = "SELECT client_id FROM clients WHERE client_id = %(client_id)s" cursor.execute(query, {"client_id": int_client_id}) if not cursor.fetchall(): @@ -263,41 +257,39 @@ def WritePathInfos( hash_entry_keys = [] hash_entry_values = [] - for proto_path_info in proto_path_infos: - path = mysql_utils.ComponentsToPath(proto_path_info.components) + for path_info in path_infos: + path = mysql_utils.ComponentsToPath(path_info.components) key = ( int_client_id, - int(proto_path_info.path_type), - rdf_objects.PathID.FromComponents( - proto_path_info.components - ).AsBytes(), + int(path_info.path_type), + rdf_objects.PathID.FromComponents(path_info.components).AsBytes(), ) details = ( now, path, - bool(proto_path_info.directory), - len(proto_path_info.components), + bool(path_info.directory), + len(path_info.components), ) path_info_values.append(key + details) - if proto_path_info.HasField("stat_entry"): + if path_info.HasField("stat_entry"): stat_entry_keys.extend(key) - details = (now, proto_path_info.stat_entry.SerializeToString()) + details = (now, path_info.stat_entry.SerializeToString()) stat_entry_values.append(key + details) - if proto_path_info.HasField("hash_entry"): + if path_info.HasField("hash_entry"): hash_entry_keys.extend(key) details = ( now, - proto_path_info.hash_entry.SerializeToString(), - proto_path_info.hash_entry.sha256, + path_info.hash_entry.SerializeToString(), + path_info.hash_entry.sha256, ) hash_entry_values.append(key + details) # TODO(hanuszczak): Implement a trie in order to avoid inserting # duplicated records. - for parent_path_info in paths.GetAncestorPathInfos(proto_path_info): + for parent_path_info in paths.GetAncestorPathInfos(path_info): path = mysql_utils.ComponentsToPath(parent_path_info.components) parent_key = ( int_client_id, @@ -361,15 +353,15 @@ def WritePathInfos( def ListDescendantPathInfos( self, client_id: str, - path_type: "rdf_objects.PathInfo.PathType", + path_type: objects_pb2.PathInfo.PathType, components: Sequence[str], timestamp: Optional[rdfvalue.RDFDatetime] = None, max_depth: Optional[int] = None, cursor: Optional[MySQLdb.cursors.Cursor] = None, - ) -> Sequence[rdf_objects.PathInfo]: + ) -> Sequence[objects_pb2.PathInfo]: """Lists path info records that correspond to descendants of given path.""" assert cursor is not None - proto_path_infos = [] + path_infos = [] query = "" @@ -471,7 +463,7 @@ def ListDescendantPathInfos( path_components = mysql_utils.PathToComponents(path) - proto_path_info = objects_pb2.PathInfo( + path_info = objects_pb2.PathInfo( path_type=objects_pb2.PathInfo.PathType.Name(path_type), components=path_components, directory=directory, @@ -479,72 +471,70 @@ def ListDescendantPathInfos( datetime = mysql_utils.TimestampToMicrosecondsSinceEpoch if timestamp is not None: - proto_path_info.timestamp = datetime(timestamp) + path_info.timestamp = datetime(timestamp) if last_stat_entry_timestamp is not None: - proto_path_info.last_stat_entry_timestamp = datetime( + path_info.last_stat_entry_timestamp = datetime( last_stat_entry_timestamp ) if last_hash_entry_timestamp is not None: - proto_path_info.last_hash_entry_timestamp = datetime( + path_info.last_hash_entry_timestamp = datetime( last_hash_entry_timestamp ) if stat_entry_bytes is not None: - proto_path_info.stat_entry.ParseFromString(stat_entry_bytes) + path_info.stat_entry.ParseFromString(stat_entry_bytes) if hash_entry_bytes is not None: - proto_path_info.hash_entry.ParseFromString(hash_entry_bytes) + 
path_info.hash_entry.ParseFromString(hash_entry_bytes) - proto_path_infos.append(proto_path_info) + path_infos.append(path_info) - proto_path_infos.sort(key=lambda _: tuple(_.components)) + path_infos.sort(key=lambda _: tuple(_.components)) # The first entry should be always the base directory itself unless it is a # root directory that was never collected. - if not proto_path_infos and components: + if not path_infos and components: raise db.UnknownPathError(client_id, path_type, components) - if proto_path_infos and not proto_path_infos[0].directory: + if path_infos and not path_infos[0].directory: raise db.NotDirectoryPathError(client_id, path_type, components) - proto_path_infos = proto_path_infos[1:] + path_infos = path_infos[1:] # For specific timestamp, we return information only about explicit paths # (paths that have associated stat or hash entry or have an ancestor that is # explicit). if not only_explicit: - return [mig_objects.ToRDFPathInfo(pi) for pi in proto_path_infos] + return path_infos explicit_path_infos = [] has_explicit_ancestor = set() # This list is sorted according to the keys component, so by traversing it # in the reverse order we make sure that we process deeper paths first. - for proto_path_info in reversed(proto_path_infos): - path_components = tuple(proto_path_info.components) + for path_info in reversed(path_infos): + path_components = tuple(path_info.components) if ( - proto_path_info.HasField("stat_entry") - or proto_path_info.HasField("hash_entry") + path_info.HasField("stat_entry") + or path_info.HasField("hash_entry") or path_components in has_explicit_ancestor ): - explicit_path_infos.append(proto_path_info) + explicit_path_infos.append(path_info) has_explicit_ancestor.add(path_components[:-1]) # Since we collected explicit paths in reverse order, we need to reverse it # again to conform to the interface. 
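The explicit-path filtering above is easiest to follow on plain tuples: entries are sorted by components, walked deepest-first, and every kept entry marks its parent for inclusion, so the directories containing collected files survive the filter. A self-contained sketch of that logic (FakePathInfo is a stand-in for objects_pb2.PathInfo, with explicit meaning "has a stat or hash entry"):

    from typing import NamedTuple, Sequence, Tuple

    class FakePathInfo(NamedTuple):
      components: Tuple[str, ...]
      explicit: bool  # stands in for HasField("stat_entry") or HasField("hash_entry")

    def keep_explicit(entries: Sequence[FakePathInfo]) -> Sequence[FakePathInfo]:
      ordered = sorted(entries, key=lambda entry: entry.components)
      kept = []
      has_explicit_ancestor = set()
      # Deepest paths come last in sorted order, so reverse to process them first.
      for entry in reversed(ordered):
        if entry.explicit or entry.components in has_explicit_ancestor:
          kept.append(entry)
          has_explicit_ancestor.add(entry.components[:-1])
      # Entries were collected deepest-first, so reverse again before returning.
      return list(reversed(kept))

For entries ('a',), ('a', 'b'), ('a', 'b', 'c') where only the deepest one is explicit, all three are kept, while a sibling ('a', 'd') with no collected data is dropped.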
- return list( - reversed([mig_objects.ToRDFPathInfo(pi) for pi in explicit_path_infos]) - ) + return list(reversed(explicit_path_infos)) @mysql_utils.WithTransaction(readonly=True) def ReadPathInfosHistories( self, client_id: Text, - path_type: rdf_objects.PathInfo.PathType, + path_type: objects_pb2.PathInfo.PathType, components_list: Iterable[Sequence[Text]], cutoff: Optional[rdfvalue.RDFDatetime] = None, - cursor: Optional[MySQLdb.cursors.Cursor] = None - ) -> Dict[Sequence[Text], Sequence[rdf_objects.PathInfo]]: + cursor: Optional[MySQLdb.cursors.Cursor] = None, + ) -> Dict[tuple[str, ...], Sequence[objects_pb2.PathInfo]]: """Reads a collection of hash and stat entries for given paths.""" assert cursor is not None @@ -552,12 +542,12 @@ def ReadPathInfosHistories( if not components_list: return {} - path_infos = {components: [] for components in components_list} + path_infos = {tuple(components): [] for components in components_list} - path_id_components = {} + path_id_components: dict[rdf_objects.PathID, tuple[str, ...]] = {} for components in components_list: path_id = rdf_objects.PathID.FromComponents(components) - path_id_components[path_id] = components + path_id_components[path_id] = tuple(components) params = { "client_id": db_utils.ClientIDToInt(client_id), @@ -613,7 +603,8 @@ def ReadPathInfosHistories( """.format( stat_entry_timestamp_condition=stat_entry_timestamp_condition, hash_entry_timestamp_condition=hash_entry_timestamp_condition, - path_id_placeholders=path_id_placeholders) + path_id_placeholders=path_id_placeholders, + ) cursor.execute(query, params) for row in cursor.fetchall(): @@ -624,29 +615,29 @@ def ReadPathInfosHistories( path_id_bytes = stat_entry_path_id_bytes or hash_entry_path_id_bytes path_id = rdf_objects.PathID.FromSerializedBytes(path_id_bytes) - components = path_id_components[path_id] + components: tuple[str, ...] 
= tuple(path_id_components[path_id]) timestamp = stat_entry_timestamp or hash_entry_timestamp - proto_path_info = objects_pb2.PathInfo( + path_info = objects_pb2.PathInfo( path_type=objects_pb2.PathInfo.PathType.Name(path_type), components=components, ) if timestamp is not None: - proto_path_info.timestamp = ( - mysql_utils.TimestampToMicrosecondsSinceEpoch(timestamp) + path_info.timestamp = mysql_utils.TimestampToMicrosecondsSinceEpoch( + timestamp ) if stat_entry_bytes is not None: - proto_path_info.stat_entry.ParseFromString(stat_entry_bytes) + path_info.stat_entry.ParseFromString(stat_entry_bytes) if hash_entry_bytes is not None: - proto_path_info.hash_entry.ParseFromString(hash_entry_bytes) + path_info.hash_entry.ParseFromString(hash_entry_bytes) - path_infos[components].append(mig_objects.ToRDFPathInfo(proto_path_info)) + path_infos[components].append(path_info) - for components in components_list: - path_infos[components].sort(key=lambda path_info: path_info.timestamp) + for comps in components_list: + path_infos[tuple(comps)].sort(key=lambda path_info: path_info.timestamp) return path_infos @@ -656,7 +647,7 @@ def ReadLatestPathInfosWithHashBlobReferences( client_paths: Collection[db.ClientPath], max_timestamp: Optional[rdfvalue.RDFDatetime] = None, cursor: Optional[MySQLdb.cursors.Cursor] = None, - ) -> Dict[db.ClientPath, Optional[rdf_objects.PathInfo]]: + ) -> Dict[db.ClientPath, Optional[objects_pb2.PathInfo]]: """Returns PathInfos that have corresponding HashBlobReferences.""" assert cursor is not None path_infos = {client_path: None for client_path in client_paths} @@ -716,22 +707,23 @@ def ReadLatestPathInfosWithHashBlobReferences( client_path = db.ClientPath( client_id=db_utils.IntToClientID(client_id), path_type=path_type, - components=path_id_components[path_id]) + components=path_id_components[path_id], + ) - proto_path_info = objects_pb2.PathInfo( + path_info = objects_pb2.PathInfo( path_type=objects_pb2.PathInfo.PathType.Name(path_type), components=components, ) datetime = mysql_utils.TimestampToMicrosecondsSinceEpoch if timestamp is not None: - proto_path_info.timestamp = datetime(timestamp) + path_info.timestamp = datetime(timestamp) if stat_entry_bytes is not None: - proto_path_info.stat_entry.ParseFromString(stat_entry_bytes) + path_info.stat_entry.ParseFromString(stat_entry_bytes) if hash_entry_bytes is not None: - proto_path_info.hash_entry.ParseFromString(hash_entry_bytes) + path_info.hash_entry.ParseFromString(hash_entry_bytes) - path_infos[client_path] = mig_objects.ToRDFPathInfo(proto_path_info) + path_infos[client_path] = path_info return path_infos diff --git a/grr/server/grr_response_server/databases/mysql_paths_test.py b/grr/server/grr_response_server/databases/mysql_paths_test.py index 34f0e4ecef..9959f6c67e 100644 --- a/grr/server/grr_response_server/databases/mysql_paths_test.py +++ b/grr/server/grr_response_server/databases/mysql_paths_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlPathsTest(db_paths_test.DatabaseTestPathsMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlPathsTest( + db_paths_test.DatabaseTestPathsMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_pool_test.py b/grr/server/grr_response_server/databases/mysql_pool_test.py index b7bc2c6a5a..db84c8dd74 100644 --- a/grr/server/grr_response_server/databases/mysql_pool_test.py +++ 
b/grr/server/grr_response_server/databases/mysql_pool_test.py @@ -66,8 +66,12 @@ def operational_error(*args, **kwargs): bad_cursor_mock = mock.MagicMock() for m in [ - 'callproc', 'execute', 'executemany', 'fetchone', 'fetchmany', - 'fetchall' + 'callproc', + 'execute', + 'executemany', + 'fetchone', + 'fetchmany', + 'fetchall', ]: getattr(bad_cursor_mock, m).side_effect = operational_error @@ -80,10 +84,12 @@ def gen_bad(): pool = mysql_pool.Pool(gen_bad, max_size=5) for op in [ - lambda c: c.callproc('my_proc'), lambda c: c. - execute('SELECT foo FROM bar'), lambda c: c.executemany( - 'INSERT INTO foo(bar) VALUES %s', ['A', 'B']), lambda c: c.fetchone( - ), lambda c: c.fetchmany(size=5), lambda c: c.fetchall() + lambda c: c.callproc('my_proc'), + lambda c: c.execute('SELECT foo FROM bar'), + lambda c: c.executemany('INSERT INTO foo(bar) VALUES %s', ['A', 'B']), + lambda c: c.fetchone(), + lambda c: c.fetchmany(size=5), + lambda c: c.fetchall(), ]: # If we can fail 10 times, then failed connections aren't consuming # pool capacity. @@ -101,8 +107,12 @@ def testGoodConnection(self): good_cursor_mock = mock.MagicMock() for m in [ - 'callproc', 'execute', 'executemany', 'fetchone', 'fetchmany', - 'fetchall' + 'callproc', + 'execute', + 'executemany', + 'fetchone', + 'fetchmany', + 'fetchall', ]: getattr(good_cursor_mock, m).return_value = m @@ -117,11 +127,15 @@ def gen_good(): for m, op in [ ('callproc', lambda c: c.callproc('my_proc')), ('execute', lambda c: c.execute('SELECT foo FROM bar')), - ('executemany', - lambda c: c.executemany('INSERT INTO foo(bar) VALUES %s', ['A', 'B'])), + ( + 'executemany', + lambda c: c.executemany( + 'INSERT INTO foo(bar) VALUES %s', ['A', 'B'] + ), + ), ('fetchone', lambda c: c.fetchone()), ('fetchmany', lambda c: c.fetchmany(size=5)), - ('fetchall', lambda c: c.fetchall()) + ('fetchall', lambda c: c.fetchall()), ]: # If we can fail 10 times, then idling a connection doesn't consume pool # capacity. diff --git a/grr/server/grr_response_server/databases/mysql_signed_binaries.py b/grr/server/grr_response_server/databases/mysql_signed_binaries.py index 673a1c2597..211291d99e 100644 --- a/grr/server/grr_response_server/databases/mysql_signed_binaries.py +++ b/grr/server/grr_response_server/databases/mysql_signed_binaries.py @@ -37,8 +37,8 @@ def WriteSignedBinaryReferences( ON DUPLICATE KEY UPDATE blob_references = VALUES(blob_references) """.format( - cols=mysql_utils.Columns(args), - vals=mysql_utils.NamedPlaceholders(args)) + cols=mysql_utils.Columns(args), vals=mysql_utils.NamedPlaceholders(args) + ) cursor.execute(query, args) @mysql_utils.WithTransaction(readonly=True) @@ -66,8 +66,9 @@ def ReadSignedBinaryReferences( raw_references, timestamp = row # TODO(hanuszczak): pytype does not understand overloads, so we have to cast # to a non-optional object. 
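The WriteSignedBinaryReferences hunk above builds an upsert from a dict of named arguments; the INSERT prefix itself sits above the visible context, so the exact output of mysql_utils.Columns and mysql_utils.NamedPlaceholders is an assumption here. A rough sketch of how such a statement can be assembled and executed with MySQLdb's pyformat-style named placeholders:

    from typing import Any, Mapping

    def upsert_query(table: str, args: Mapping[str, Any], update_col: str) -> str:
      # Column list and matching %(name)s placeholders derived from the same dict.
      cols = ", ".join(args)
      vals = ", ".join("%({})s".format(col) for col in args)
      return (
          "INSERT INTO {table} ({cols}) VALUES ({vals}) "
          "ON DUPLICATE KEY UPDATE {col} = VALUES({col})"
      ).format(table=table, cols=cols, vals=vals, col=update_col)

    # Hypothetical usage, passing the same dict as query parameters:
    # cursor.execute(upsert_query("signed_binary_references", args, "blob_references"), args)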
- datetime = cast(rdfvalue.RDFDatetime, - mysql_utils.TimestampToRDFDatetime(timestamp)) + datetime = cast( + rdfvalue.RDFDatetime, mysql_utils.TimestampToRDFDatetime(timestamp) + ) references = objects_pb2.BlobReferences() references.ParseFromString(raw_references) @@ -82,7 +83,8 @@ def ReadIDsForAllSignedBinaries( assert cursor is not None cursor.execute( - "SELECT binary_type, binary_path FROM signed_binary_references") + "SELECT binary_type, binary_path FROM signed_binary_references" + ) return [ objects_pb2.SignedBinaryID(binary_type=binary_type, path=binary_path) for binary_type, binary_path in cursor.fetchall() diff --git a/grr/server/grr_response_server/databases/mysql_signed_binaries_test.py b/grr/server/grr_response_server/databases/mysql_signed_binaries_test.py index 4360339041..68cf33d3c4 100644 --- a/grr/server/grr_response_server/databases/mysql_signed_binaries_test.py +++ b/grr/server/grr_response_server/databases/mysql_signed_binaries_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -10,7 +9,9 @@ class MysqlSignedBinariesTest( db_signed_binaries_test.DatabaseTestSignedBinariesMixin, - mysql_test.MysqlTestBase, absltest.TestCase): + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_test.py b/grr/server/grr_response_server/databases/mysql_test.py index 656e1134ac..b23ec23b44 100644 --- a/grr/server/grr_response_server/databases/mysql_test.py +++ b/grr/server/grr_response_server/databases/mysql_test.py @@ -142,9 +142,12 @@ class MysqlTestBase(MySQLDatabaseProviderMixin): pass -class TestMysqlDB(stats_test_lib.StatsTestMixin, - db_test_mixin.DatabaseTestMixin, MysqlTestBase, - absltest.TestCase): +class TestMysqlDB( + stats_test_lib.StatsTestMixin, + db_test_mixin.DatabaseTestMixin, + MysqlTestBase, + absltest.TestCase, +): """Test the mysql.MysqlDB class. Most of the tests in this suite are general blackbox tests of the db.Database @@ -156,27 +159,40 @@ def testIsRetryable(self): self.assertFalse( mysql._IsRetryable( MySQLdb.OperationalError( - 1416, "Cannot get geometry object from data..."))) + 1416, "Cannot get geometry object from data..." + ) + ) + ) self.assertTrue( mysql._IsRetryable( MySQLdb.OperationalError( - 1205, "Lock wait timeout exceeded; try restarting..."))) + 1205, "Lock wait timeout exceeded; try restarting..." 
+ ) + ) + ) self.assertTrue( mysql._IsRetryable( MySQLdb.OperationalError( 1213, - "Deadlock found when trying to get lock; try restarting..."))) + "Deadlock found when trying to get lock; try restarting...", + ) + ) + ) self.assertTrue( mysql._IsRetryable( MySQLdb.OperationalError( - 1637, "Too many active concurrent transactions"))) + 1637, "Too many active concurrent transactions" + ) + ) + ) def AddUser(self, connection, user, password): cursor = connection.cursor() cursor.execute( "INSERT INTO grr_users (username, username_hash, password) " "VALUES (%s, %s, %s)", - (user, mysql_utils.Hash(user), password.encode("utf-8"))) + (user, mysql_utils.Hash(user), password.encode("utf-8")), + ) cursor.close() def ListUsers(self, connection): @@ -222,28 +238,36 @@ def Transaction1(connection): counts[0] += 1 cursor = connection.cursor() cursor.execute( - "SELECT password FROM grr_users WHERE username = 'user1' FOR UPDATE") + "SELECT password FROM grr_users WHERE username = 'user1' FOR UPDATE" + ) t1_halfway.set() self.assertTrue(t2_halfway.wait(5)) - cursor.execute("UPDATE grr_users SET password = 'pw2-updated' " - "WHERE username = 'user2'") + cursor.execute( + "UPDATE grr_users SET password = 'pw2-updated' " + "WHERE username = 'user2'" + ) cursor.close() def Transaction2(connection): counts[1] += 1 cursor = connection.cursor() cursor.execute( - "SELECT password FROM grr_users WHERE username = 'user2' FOR UPDATE") + "SELECT password FROM grr_users WHERE username = 'user2' FOR UPDATE" + ) t2_halfway.set() self.assertTrue(t1_halfway.wait(5)) - cursor.execute("UPDATE grr_users SET password = 'pw1-updated' " - "WHERE username = 'user1'") + cursor.execute( + "UPDATE grr_users SET password = 'pw1-updated' " + "WHERE username = 'user1'" + ) cursor.close() thread_1 = threading.Thread( - target=lambda: self.db.delegate._RunInTransaction(Transaction1)) + target=lambda: self.db.delegate._RunInTransaction(Transaction1) + ) thread_2 = threading.Thread( - target=lambda: self.db.delegate._RunInTransaction(Transaction2)) + target=lambda: self.db.delegate._RunInTransaction(Transaction2) + ) thread_1.start() thread_2.start() @@ -253,8 +277,9 @@ def Transaction2(connection): # Both transaction should have succeeded users = self.db.delegate._RunInTransaction(self.ListUsers, readonly=True) - self.assertEqual(users, - (("user1", b"pw1-updated"), ("user2", b"pw2-updated"))) + self.assertEqual( + users, (("user1", b"pw1-updated"), ("user2", b"pw2-updated")) + ) # At least one should have been retried. 
self.assertGreater(sum(counts), 2) @@ -262,7 +287,8 @@ def Transaction2(connection): def testSuccessfulCallsAreCorrectlyAccounted(self): with self.assertStatsCounterDelta( - 1, db_utils.DB_REQUEST_LATENCY, fields=["ReadGRRUsers"]): + 1, db_utils.DB_REQUEST_LATENCY, fields=["ReadGRRUsers"] + ): self.db.ReadGRRUsers() def testMaxAllowedPacketSettingIsOverriddenWhenTooLow(self): @@ -301,7 +327,8 @@ def SetMaxAllowedPacket(conn): with mock.patch.object( mysql, "_SetGlobalVariable", - side_effect=MySQLdb.OperationalError("SUPER privileges required")): + side_effect=MySQLdb.OperationalError("SUPER privileges required"), + ): with self.assertRaises(mysql.MaxAllowedPacketSettingTooLowError): self.__class__._Connect() @@ -320,8 +347,9 @@ def RaiseServerGoneError(connection): connection.close = mock.Mock(wraps=real_close_fn) connections.append(connection) - raise MySQLdb.OperationalError(mysql_conn_errors.SERVER_GONE_ERROR, - expected_error_msg) + raise MySQLdb.OperationalError( + mysql_conn_errors.SERVER_GONE_ERROR, expected_error_msg + ) with mock.patch.object(self.db.delegate, "_max_pool_size", 6): with self.assertRaises(MySQLdb.OperationalError) as context: @@ -375,8 +403,9 @@ def RaisePermanentError(connection): connection.close = mock.Mock(wraps=real_close_fn) connections.append(connection) - raise MySQLdb.OperationalError(mysql_conn_errors.NOT_IMPLEMENTED, - expected_error_msg) + raise MySQLdb.OperationalError( + mysql_conn_errors.NOT_IMPLEMENTED, expected_error_msg + ) with self.assertRaises(MySQLdb.OperationalError) as context: self.db.delegate._RunInTransaction(RaisePermanentError) diff --git a/grr/server/grr_response_server/databases/mysql_time_test.py b/grr/server/grr_response_server/databases/mysql_time_test.py index fd9cd4d650..08b4989602 100644 --- a/grr/server/grr_response_server/databases/mysql_time_test.py +++ b/grr/server/grr_response_server/databases/mysql_time_test.py @@ -7,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlClientsTest(db_time_test.DatabaseTimeTestMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlClientsTest( + db_time_test.DatabaseTimeTestMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_users.py b/grr/server/grr_response_server/databases/mysql_users.py index 56a50fb705..fc75cc8ebc 100644 --- a/grr/server/grr_response_server/databases/mysql_users.py +++ b/grr/server/grr_response_server/databases/mysql_users.py @@ -370,7 +370,8 @@ def WriteUserNotification( } query = "INSERT INTO user_notification {columns} VALUES {values}".format( columns=mysql_utils.Columns(args), - values=mysql_utils.NamedPlaceholders(args)) + values=mysql_utils.NamedPlaceholders(args), + ) try: cursor.execute(query, args) except MySQLdb.IntegrityError: diff --git a/grr/server/grr_response_server/databases/mysql_users_test.py b/grr/server/grr_response_server/databases/mysql_users_test.py index 7c9147a37b..ea63994893 100644 --- a/grr/server/grr_response_server/databases/mysql_users_test.py +++ b/grr/server/grr_response_server/databases/mysql_users_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl import app from absl.testing import absltest @@ -8,8 +7,11 @@ from grr.test_lib import test_lib -class MysqlUsersTest(db_users_test.DatabaseTestUsersMixin, - mysql_test.MysqlTestBase, absltest.TestCase): +class MysqlUsersTest( + db_users_test.DatabaseTestUsersMixin, + mysql_test.MysqlTestBase, + absltest.TestCase, +): pass diff --git a/grr/server/grr_response_server/databases/mysql_utils.py 
b/grr/server/grr_response_server/databases/mysql_utils.py index 92aa4ceb7e..e4e36ae733 100644 --- a/grr/server/grr_response_server/databases/mysql_utils.py +++ b/grr/server/grr_response_server/databases/mysql_utils.py @@ -4,7 +4,6 @@ import contextlib import functools import hashlib -import inspect from typing import Iterable from typing import Optional from typing import Sequence @@ -94,6 +93,7 @@ def Columns(iterable: Iterable[Text]) -> Text: Args: iterable: The iterable of strings to be used as column names. + Returns: A string containing a tuple of sorted comma-separated column names. """ columns = sorted(iterable) @@ -153,7 +153,7 @@ def ComponentsToPath(components: Sequence[Text]) -> Text: return "" -def PathToComponents(path: Text) -> Sequence[Text]: +def PathToComponents(path: Text) -> tuple[str, ...]: """Converts a canonical path representation to a list of components. Args: @@ -201,33 +201,6 @@ def __init__(self, readonly=False): def __call__(self, func): readonly = self.readonly - takes_args = inspect.getfullargspec(func).args - takes_connection = "connection" in takes_args - takes_cursor = "cursor" in takes_args - - if takes_connection == takes_cursor: - raise TypeError( - "@mysql_utils.WithTransaction requires a function to take exactly " - "one of 'connection', 'cursor', got: %s" % str(takes_args)) - - if takes_connection: - - @functools.wraps(func) - def Decorated(self, *args, **kw): - """A function decorated by WithTransaction to receive a connection.""" - connection = kw.get("connection", None) - if connection: - return func(self, *args, **kw) - - def Closure(connection): - new_kw = kw.copy() - new_kw["connection"] = connection - return func(self, *args, **new_kw) - - return self._RunInTransaction(Closure, readonly) - - return Decorated - @functools.wraps(func) def Decorated(self, *args, **kw): # pylint: disable=function-redefined """A function decorated by WithTransaction to receive a cursor.""" @@ -243,7 +216,7 @@ def Closure(connection): return self._RunInTransaction(Closure, readonly) - return db_utils.CallLoggedAndAccounted(Decorated) + return db_utils.CallLogged(db_utils.CallAccounted(Decorated)) class RetryableError(db_module.Error): diff --git a/grr/server/grr_response_server/databases/mysql_utils_test.py b/grr/server/grr_response_server/databases/mysql_utils_test.py index 50a1eefac2..6cc02a241a 100644 --- a/grr/server/grr_response_server/databases/mysql_utils_test.py +++ b/grr/server/grr_response_server/databases/mysql_utils_test.py @@ -27,7 +27,8 @@ def testZeroValues(self): def testManyValues(self): self.assertEqual( - mysql_utils.Placeholders(3, 2), "(%s, %s, %s), (%s, %s, %s)") + mysql_utils.Placeholders(3, 2), "(%s, %s, %s), (%s, %s, %s)" + ) class NamedPlaceholdersTest(absltest.TestCase): @@ -41,20 +42,20 @@ def testOne(self): def testMany(self): self.assertEqual( mysql_utils.NamedPlaceholders(["bar", "baz", "foo"]), - "(%(bar)s, %(baz)s, %(foo)s)") + "(%(bar)s, %(baz)s, %(foo)s)", + ) def testDictUsesKeys(self): self.assertIn( - mysql_utils.NamedPlaceholders({ - "bar": 42, - "baz": 42, - "foo": 42 - }), ["(%(bar)s, %(baz)s, %(foo)s)"]) + mysql_utils.NamedPlaceholders({"bar": 42, "baz": 42, "foo": 42}), + ["(%(bar)s, %(baz)s, %(foo)s)"], + ) def testSortsNames(self): self.assertEqual( mysql_utils.NamedPlaceholders(["bar", "foo", "baz"]), - "(%(bar)s, %(baz)s, %(foo)s)") + "(%(bar)s, %(baz)s, %(foo)s)", + ) class ColumnsTest(absltest.TestCase): @@ -67,19 +68,19 @@ def testOne(self): def testMany(self): self.assertEqual( - mysql_utils.Columns(["bar", "baz", 
"foo"]), "(`bar`, `baz`, `foo`)") + mysql_utils.Columns(["bar", "baz", "foo"]), "(`bar`, `baz`, `foo`)" + ) def testDictUsesKeys(self): self.assertIn( - mysql_utils.Columns({ - "bar": 42, - "baz": 42, - "foo": 42 - }), ["(`bar`, `baz`, `foo`)"]) + mysql_utils.Columns({"bar": 42, "baz": 42, "foo": 42}), + ["(`bar`, `baz`, `foo`)"], + ) def testSortsNames(self): self.assertEqual( - mysql_utils.Columns(["bar", "foo", "baz"]), "(`bar`, `baz`, `foo`)") + mysql_utils.Columns(["bar", "foo", "baz"]), "(`bar`, `baz`, `foo`)" + ) def testSortsRawNamesWithoutEscape(self): self.assertGreater("`", "_") diff --git a/grr/server/grr_response_server/databases/mysql_yara_test.py b/grr/server/grr_response_server/databases/mysql_yara_test.py index 196f0c2857..8c9330e5b9 100644 --- a/grr/server/grr_response_server/databases/mysql_yara_test.py +++ b/grr/server/grr_response_server/databases/mysql_yara_test.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - from absl.testing import absltest from grr_response_server.databases import db_yara_test_lib diff --git a/grr/server/grr_response_server/export.py b/grr/server/grr_response_server/export.py index b8e68382d7..b3c5bea174 100644 --- a/grr/server/grr_response_server/export.py +++ b/grr/server/grr_response_server/export.py @@ -44,8 +44,6 @@ def GetMetadata(client_id, client_full_info): metadata.hostname = kb.fqdn metadata.os = kb.os - # TODO: Remove this once the field is gone. - metadata.uname = f"{kb.os}-{os_release}-{os_version}" metadata.os_release = os_release metadata.os_version = os_version metadata.usernames = ",".join(user.username for user in kb.users) diff --git a/grr/server/grr_response_server/export_converters/base.py b/grr/server/grr_response_server/export_converters/base.py index 934b4f04df..10edae5f76 100644 --- a/grr/server/grr_response_server/export_converters/base.py +++ b/grr/server/grr_response_server/export_converters/base.py @@ -30,7 +30,7 @@ class ExportOptions(rdf_structs.RDFProtoStruct): protobuf = export_pb2.ExportOptions -class ExportConverter(): +class ExportConverter: """Base ExportConverter class. ExportConverters are used to convert RDFValues to export-friendly RDFValues. 
diff --git a/grr/server/grr_response_server/export_converters/buffer_reference.py b/grr/server/grr_response_server/export_converters/buffer_reference.py index e73f7ec808..278f16ab17 100644 --- a/grr/server/grr_response_server/export_converters/buffer_reference.py +++ b/grr/server/grr_response_server/export_converters/buffer_reference.py @@ -24,11 +24,14 @@ class BufferReferenceToExportedMatchConverter(base.ExportConverter): input_rdf_type = rdf_client.BufferReference def Convert( - self, metadata: base.ExportedMetadata, - buffer_reference: rdf_client.BufferReference) -> Iterator[ExportedMatch]: + self, + metadata: base.ExportedMetadata, + buffer_reference: rdf_client.BufferReference, + ) -> Iterator[ExportedMatch]: yield ExportedMatch( metadata=metadata, offset=buffer_reference.offset, length=buffer_reference.length, data=buffer_reference.data, - urn=buffer_reference.pathspec.AFF4Path(metadata.client_urn)) + urn=buffer_reference.pathspec.AFF4Path(metadata.client_urn), + ) diff --git a/grr/server/grr_response_server/export_converters/buffer_reference_test.py b/grr/server/grr_response_server/export_converters/buffer_reference_test.py index d042c177db..6e167b9c52 100644 --- a/grr/server/grr_response_server/export_converters/buffer_reference_test.py +++ b/grr/server/grr_response_server/export_converters/buffer_reference_test.py @@ -8,12 +8,14 @@ from grr.test_lib import test_lib -class BufferReferenceToExportedMatchConverterTest(export_test_lib.ExportTestBase - ): +class BufferReferenceToExportedMatchConverterTest( + export_test_lib.ExportTestBase +): def testBasicConversion(self): pathspec = rdf_paths.PathSpec( - path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS + ) buffer = rdf_client.BufferReference( length=123, offset=456, diff --git a/grr/server/grr_response_server/export_converters/client_summary.py b/grr/server/grr_response_server/export_converters/client_summary.py index b43ebd5bc3..fae5a766f0 100644 --- a/grr/server/grr_response_server/export_converters/client_summary.py +++ b/grr/server/grr_response_server/export_converters/client_summary.py @@ -18,14 +18,16 @@ class ExportedClient(rdf_structs.RDFProtoStruct): class ClientSummaryToExportedNetworkInterfaceConverter( - network.InterfaceToExportedNetworkInterfaceConverter): + network.InterfaceToExportedNetworkInterfaceConverter +): """Converts a ClientSummary to ExportedNetworkInterfaces.""" input_rdf_type = rdf_client.ClientSummary def Convert( - self, metadata: base.ExportedMetadata, - client_summary: rdf_client.ClientSummary + self, + metadata: base.ExportedMetadata, + client_summary: rdf_client.ClientSummary, ) -> Iterator[network.ExportedNetworkInterface]: """Converts a ClientSummary into ExportedNetworkInterfaces. @@ -49,8 +51,10 @@ class ClientSummaryToExportedClientConverter(base.ExportConverter): input_rdf_type = rdf_client.ClientSummary def Convert( - self, metadata: base.ExportedMetadata, - unused_client_summary: rdf_client.ClientSummary) -> List[ExportedClient]: + self, + metadata: base.ExportedMetadata, + unused_client_summary: rdf_client.ClientSummary, + ) -> List[ExportedClient]: """Returns an ExportedClient using the ExportedMetadata. 
Args: diff --git a/grr/server/grr_response_server/export_converters/client_summary_test.py b/grr/server/grr_response_server/export_converters/client_summary_test.py index 5d8dbcd312..bf3805b148 100644 --- a/grr/server/grr_response_server/export_converters/client_summary_test.py +++ b/grr/server/grr_response_server/export_converters/client_summary_test.py @@ -15,34 +15,44 @@ class ClientSummaryToExportedNetworkInterfaceConverterTest( - export_test_lib.ExportTestBase): + export_test_lib.ExportTestBase +): def testClientSummaryToExportedNetworkInterfaceConverter(self): mac_address_bytes = b"123456" mac_address = text.Hexify(mac_address_bytes) - summary = rdf_client.ClientSummary(interfaces=[ - rdf_client_network.Interface( - mac_address=mac_address_bytes, - ifname="eth0", - addresses=[ - rdf_client_network.NetworkAddress( - address_type=rdf_client_network.NetworkAddress.Family.INET, - packed_bytes=socket.inet_pton(socket.AF_INET, "127.0.0.1"), - ), - rdf_client_network.NetworkAddress( - address_type=rdf_client_network.NetworkAddress.Family.INET, - packed_bytes=socket.inet_pton(socket.AF_INET, "10.0.0.1"), - ), - rdf_client_network.NetworkAddress( - address_type=rdf_client_network.NetworkAddress.Family.INET6, - packed_bytes=socket.inet_pton(socket.AF_INET6, - "2001:720:1500:1::a100"), - ) - ]) - ]) - - converter = client_summary.ClientSummaryToExportedNetworkInterfaceConverter( + summary = rdf_client.ClientSummary( + interfaces=[ + rdf_client_network.Interface( + mac_address=mac_address_bytes, + ifname="eth0", + addresses=[ + rdf_client_network.NetworkAddress( + address_type=rdf_client_network.NetworkAddress.Family.INET, + packed_bytes=socket.inet_pton( + socket.AF_INET, "127.0.0.1" + ), + ), + rdf_client_network.NetworkAddress( + address_type=rdf_client_network.NetworkAddress.Family.INET, + packed_bytes=socket.inet_pton( + socket.AF_INET, "10.0.0.1" + ), + ), + rdf_client_network.NetworkAddress( + address_type=rdf_client_network.NetworkAddress.Family.INET6, + packed_bytes=socket.inet_pton( + socket.AF_INET6, "2001:720:1500:1::a100" + ), + ), + ], + ) + ] + ) + + converter = ( + client_summary.ClientSummaryToExportedNetworkInterfaceConverter() ) results = list(converter.Convert(self.metadata, summary)) self.assertLen(results, 1) @@ -52,8 +62,9 @@ def testClientSummaryToExportedNetworkInterfaceConverter(self): self.assertEqual(results[0].ip6_addresses, "2001:720:1500:1::a100") -class ClientSummaryToExportedClientConverterTest(export_test_lib.ExportTestBase - ): +class ClientSummaryToExportedClientConverterTest( + export_test_lib.ExportTestBase +): def testClientSummaryToExportedClientConverter(self): summary = rdf_client.ClientSummary() diff --git a/grr/server/grr_response_server/export_converters/cron_tab_file.py b/grr/server/grr_response_server/export_converters/cron_tab_file.py index 17cffb622e..fec82446ba 100644 --- a/grr/server/grr_response_server/export_converters/cron_tab_file.py +++ b/grr/server/grr_response_server/export_converters/cron_tab_file.py @@ -20,8 +20,9 @@ class CronTabFileConverter(base.ExportConverter): input_rdf_type = rdf_cronjobs.CronTabFile def Convert( - self, metadata: base.ExportedMetadata, - cron_tab_file: rdf_cronjobs.CronTabFile + self, + metadata: base.ExportedMetadata, + cron_tab_file: rdf_cronjobs.CronTabFile, ) -> Iterator[ExportedCronTabEntry]: for j in cron_tab_file.jobs: yield ExportedCronTabEntry( @@ -33,4 +34,5 @@ def Convert( month=j.month, dayofweek=j.dayofweek, command=j.command, - comment=j.comment) + comment=j.comment, + ) diff --git 
a/grr/server/grr_response_server/export_converters/cron_tab_file_test.py b/grr/server/grr_response_server/export_converters/cron_tab_file_test.py index 63b9dbe4c7..0726cdaa7d 100644 --- a/grr/server/grr_response_server/export_converters/cron_tab_file_test.py +++ b/grr/server/grr_response_server/export_converters/cron_tab_file_test.py @@ -22,7 +22,8 @@ def testExportsFileWithTwoEntries(self): month="4", dayofweek="1", command="bash", - comment="foo"), + comment="foo", + ), rdf_cronjobs.CronTabEntry( minute="aa", hour="bb", @@ -30,12 +31,15 @@ def testExportsFileWithTwoEntries(self): month="dd", dayofweek="ee", command="ps", - comment="some"), - ]) + comment="some", + ), + ], + ) converter = cron_tab_file.CronTabFileConverter() converted = list( - converter.Convert(base.ExportedMetadata(self.metadata), sample)) + converter.Convert(base.ExportedMetadata(self.metadata), sample) + ) self.assertLen(converted, 2) self.assertIsInstance(converted[0], cron_tab_file.ExportedCronTabEntry) diff --git a/grr/server/grr_response_server/export_converters/data_agnostic.py b/grr/server/grr_response_server/export_converters/data_agnostic.py index f6f3649e89..eaa7048ce2 100644 --- a/grr/server/grr_response_server/export_converters/data_agnostic.py +++ b/grr/server/grr_response_server/export_converters/data_agnostic.py @@ -44,33 +44,46 @@ def Flatten(self, metadata, value_to_flatten): # Metadata is always the first field of exported data. descriptors.append( rdf_structs.ProtoEmbedded( - name="metadata", field_number=1, nested=base.ExportedMetadata)) + name="metadata", field_number=1, nested=base.ExportedMetadata + ) + ) for number, desc in sorted(value.type_infos_by_field_number.items()): # Name 'metadata' is reserved to store ExportedMetadata value. if desc.name == "metadata": - logging.debug("Ignoring 'metadata' field in %s.", - value.__class__.__name__) + logging.debug( + "Ignoring 'metadata' field in %s.", value.__class__.__name__ + ) continue # Copy descriptors for primivie values as-is, just make sure their # field number is correct. - if isinstance(desc, (rdf_structs.ProtoBinary, rdf_structs.ProtoString, - rdf_structs.ProtoUnsignedInteger, - rdf_structs.ProtoRDFValue, rdf_structs.ProtoEnum)): + if isinstance( + desc, + ( + rdf_structs.ProtoBinary, + rdf_structs.ProtoString, + rdf_structs.ProtoUnsignedInteger, + rdf_structs.ProtoRDFValue, + rdf_structs.ProtoEnum, + ), + ): # Incrementing field number by 1, as 1 is always occuppied by metadata. descriptors.append(desc.Copy(field_number=number + 1)) - if (isinstance(desc, rdf_structs.ProtoEnum) and - not isinstance(desc, rdf_structs.ProtoBoolean)): + if isinstance(desc, rdf_structs.ProtoEnum) and not isinstance( + desc, rdf_structs.ProtoBoolean + ): # Attach the enum container to the class for easy reference: enums[desc.enum_name] = desc.enum_container # Create the class as late as possible. This will modify a # metaclass registry, we need to make sure there are no problems. 
output_class = type( - self.ExportedClassNameForValue(value), (AutoExportedProtoStruct,), - dict(Flatten=Flatten)) + self.ExportedClassNameForValue(value), + (AutoExportedProtoStruct,), + dict(Flatten=Flatten), + ) for descriptor in descriptors: output_class.AddDescriptor(descriptor) diff --git a/grr/server/grr_response_server/export_converters/data_agnostic_test.py b/grr/server/grr_response_server/export_converters/data_agnostic_test.py index d0aaa508e8..5412c35c11 100644 --- a/grr/server/grr_response_server/export_converters/data_agnostic_test.py +++ b/grr/server/grr_response_server/export_converters/data_agnostic_test.py @@ -14,9 +14,12 @@ class DataAgnosticExportConverterTest(export_test_lib.ExportTestBase): """Tests for DataAgnosticExportConverter.""" def ConvertOriginalValue(self, original_value): - converted_values = list(data_agnostic.DataAgnosticExportConverter().Convert( - base.ExportedMetadata(source_urn=rdfvalue.RDFURN("aff4:/foo")), - original_value)) + converted_values = list( + data_agnostic.DataAgnosticExportConverter().Convert( + base.ExportedMetadata(source_urn=rdfvalue.RDFURN("aff4:/foo")), + original_value, + ) + ) self.assertLen(converted_values, 1) return converted_values[0] @@ -25,30 +28,52 @@ def testAddsMetadataAndIgnoresRepeatedAndMessagesFields(self): converted_value = self.ConvertOriginalValue(original_value) # No 'metadata' field in the original value. - self.assertCountEqual([t.name for t in original_value.type_infos], [ - "string_value", "int_value", "bool_value", "repeated_string_value", - "message_value", "enum_value", "another_enum_value", "urn_value", - "datetime_value" - ]) + self.assertCountEqual( + [t.name for t in original_value.type_infos], + [ + "string_value", + "int_value", + "bool_value", + "repeated_string_value", + "message_value", + "enum_value", + "another_enum_value", + "urn_value", + "datetime_value", + ], + ) # But there's one in the converted value. - self.assertCountEqual([t.name for t in converted_value.type_infos], [ - "metadata", "string_value", "int_value", "bool_value", "enum_value", - "another_enum_value", "urn_value", "datetime_value" - ]) + self.assertCountEqual( + [t.name for t in converted_value.type_infos], + [ + "metadata", + "string_value", + "int_value", + "bool_value", + "enum_value", + "another_enum_value", + "urn_value", + "datetime_value", + ], + ) # Metadata value is correctly initialized from user-supplied metadata. 
- self.assertEqual(converted_value.metadata.source_urn, - rdfvalue.RDFURN("aff4:/foo")) + self.assertEqual( + converted_value.metadata.source_urn, rdfvalue.RDFURN("aff4:/foo") + ) def testIgnoresPredefinedMetadataField(self): original_value = export_test_lib.DataAgnosticConverterTestValueWithMetadata( - metadata=42, value="value") + metadata=42, value="value" + ) converted_value = self.ConvertOriginalValue(original_value) - self.assertCountEqual([t.name for t in converted_value.type_infos], - ["metadata", "value"]) - self.assertEqual(converted_value.metadata.source_urn, - rdfvalue.RDFURN("aff4:/foo")) + self.assertCountEqual( + [t.name for t in converted_value.type_infos], ["metadata", "value"] + ) + self.assertEqual( + converted_value.metadata.source_urn, rdfvalue.RDFURN("aff4:/foo") + ) self.assertEqual(converted_value.value, "value") def testProcessesPrimitiveTypesCorrectly(self): @@ -56,45 +81,55 @@ def testProcessesPrimitiveTypesCorrectly(self): string_value="string value", int_value=42, bool_value=True, - enum_value=export_test_lib.DataAgnosticConverterTestValue.EnumOption - .OPTION_2, + enum_value=export_test_lib.DataAgnosticConverterTestValue.EnumOption.OPTION_2, urn_value=rdfvalue.RDFURN("aff4:/bar"), - datetime_value=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42)) + datetime_value=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42), + ) converted_value = self.ConvertOriginalValue(original_value) - self.assertEqual(converted_value.string_value.__class__, - original_value.string_value.__class__) + self.assertEqual( + converted_value.string_value.__class__, + original_value.string_value.__class__, + ) self.assertEqual(converted_value.string_value, "string value") - self.assertEqual(converted_value.int_value.__class__, - original_value.int_value.__class__) + self.assertEqual( + converted_value.int_value.__class__, original_value.int_value.__class__ + ) self.assertEqual(converted_value.int_value, 42) - self.assertEqual(converted_value.bool_value.__class__, - original_value.bool_value.__class__) + self.assertEqual( + converted_value.bool_value.__class__, + original_value.bool_value.__class__, + ) self.assertEqual(converted_value.bool_value, True) - self.assertEqual(converted_value.enum_value.__class__, - original_value.enum_value.__class__) - self.assertEqual(converted_value.enum_value, - converted_value.EnumOption.OPTION_2) + self.assertEqual( + converted_value.enum_value.__class__, + original_value.enum_value.__class__, + ) + self.assertEqual( + converted_value.enum_value, converted_value.EnumOption.OPTION_2 + ) self.assertIsInstance(converted_value.urn_value, rdfvalue.RDFURN) self.assertEqual(converted_value.urn_value, rdfvalue.RDFURN("aff4:/bar")) self.assertIsInstance(converted_value.datetime_value, rdfvalue.RDFDatetime) - self.assertEqual(converted_value.datetime_value, - rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42)) + self.assertEqual( + converted_value.datetime_value, + rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42), + ) def testConvertedValuesCanBeSerializedAndDeserialized(self): original_value = export_test_lib.DataAgnosticConverterTestValue( string_value="string value", int_value=42, bool_value=True, - enum_value=export_test_lib.DataAgnosticConverterTestValue.EnumOption - .OPTION_2, + enum_value=export_test_lib.DataAgnosticConverterTestValue.EnumOption.OPTION_2, urn_value=rdfvalue.RDFURN("aff4:/bar"), - datetime_value=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42)) + datetime_value=rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42), + ) converted_value = 
self.ConvertOriginalValue(original_value) serialized = converted_value.SerializeToBytes() diff --git a/grr/server/grr_response_server/export_converters/execute_response.py b/grr/server/grr_response_server/export_converters/execute_response.py index 52f06e9997..96f5379bbf 100644 --- a/grr/server/grr_response_server/export_converters/execute_response.py +++ b/grr/server/grr_response_server/export_converters/execute_response.py @@ -20,8 +20,9 @@ class ExecuteResponseConverter(base.ExportConverter): input_rdf_type = rdf_client_action.ExecuteResponse def Convert( - self, metadata: base.ExportedMetadata, - r: rdf_client_action.ExecuteResponse + self, + metadata: base.ExportedMetadata, + r: rdf_client_action.ExecuteResponse, ) -> Iterator[ExportedExecuteResponse]: yield ExportedExecuteResponse( metadata=metadata, @@ -32,4 +33,5 @@ def Convert( stderr=r.stderr, # ExecuteResponse is uint32 (for a reason unknown): to be on the safe # side, making sure it's not negative. - time_used_us=max(0, r.time_used)) + time_used_us=max(0, r.time_used), + ) diff --git a/grr/server/grr_response_server/export_converters/file.py b/grr/server/grr_response_server/export_converters/file.py index 33affaa75a..2078b79d28 100644 --- a/grr/server/grr_response_server/export_converters/file.py +++ b/grr/server/grr_response_server/export_converters/file.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Classes for exporting StatEntry.""" + import hashlib import logging import time @@ -97,7 +98,8 @@ def ParseSignedData(signed_data, result): if auth.has_countersignature: result.cert_countersignature_chain_head_issuer = str( - auth.counter_chain_head[2]) + auth.counter_chain_head[2] + ) certs = [] for (issuer, serial), cert in auth.certificates.items(): @@ -116,7 +118,9 @@ def ParseSignedData(signed_data, result): serial=serial, subject=subject_dn, not_before_time=not_before_time_str, - not_after_time=not_after_time_str)) + not_after_time=not_after_time_str, + ) + ) result.cert_certificates = str(certs) # Verify_sigs library can basically throw all kinds of exceptions so @@ -143,8 +147,9 @@ def ParseFileHash(hash_obj, result): result.pecoff_hash_sha1 = str(hash_obj.pecoff_sha1) if hash_obj.HasField("signed_data"): - StatEntryToExportedFileConverter.ParseSignedData(hash_obj.signed_data[0], - result) + StatEntryToExportedFileConverter.ParseSignedData( + hash_obj.signed_data[0], result + ) def Convert(self, metadata, stat_entry): """Converts StatEntry to ExportedFile. @@ -190,7 +195,8 @@ def _CreateExportedFile(self, metadata, stat_entry): st_blocks=stat_entry.st_blocks, st_blksize=stat_entry.st_blksize, st_rdev=stat_entry.st_rdev, - symlink=stat_entry.symlink) + symlink=stat_entry.symlink, + ) _BATCH_SIZE = 5000 @@ -206,23 +212,27 @@ def _BatchConvert(self, metadata_value_pairs): # TODO(user): Deprecate client_urn in ExportedMetadata in favor of # client_id (to be added). 
client_paths.add( - db.ClientPath.FromPathSpec(metadata.client_urn.Basename(), - stat_entry.pathspec)) + db.ClientPath.FromPathSpec( + metadata.client_urn.Basename(), stat_entry.pathspec + ) + ) data_by_path = {} for chunk in file_store.StreamFilesChunks( - client_paths, max_size=self.MAX_CONTENT_SIZE): + client_paths, max_size=self.MAX_CONTENT_SIZE + ): data_by_path.setdefault(chunk.client_path, []).append(chunk.data) for metadata, stat_entry in fp_batch: result = self._CreateExportedFile(metadata, stat_entry) - clientpath = db.ClientPath.FromPathSpec(metadata.client_urn.Basename(), - stat_entry.pathspec) + clientpath = db.ClientPath.FromPathSpec( + metadata.client_urn.Basename(), stat_entry.pathspec + ) if self.options.export_files_contents: try: data = data_by_path[clientpath] - result.content = b"".join(data)[:self.MAX_CONTENT_SIZE] + result.content = b"".join(data)[: self.MAX_CONTENT_SIZE] result.content_sha256 = hashlib.sha256(result.content).hexdigest() except KeyError: pass @@ -271,10 +281,12 @@ def Convert(self, metadata: base.ExportedMetadata, stat_entry): result = ExportedRegistryKey( metadata=metadata, urn=stat_entry.AFF4Path(metadata.client_urn), - last_modified=stat_entry.st_mtime) + last_modified=stat_entry.st_mtime, + ) - if (stat_entry.HasField("registry_type") and - stat_entry.HasField("registry_data")): + if stat_entry.HasField("registry_type") and stat_entry.HasField( + "registry_data" + ): result.type = stat_entry.registry_type @@ -307,8 +319,10 @@ def _SeparateTypes(self, metadata_value_pairs): file_pairs = [] match_pairs = [] for metadata, result in metadata_value_pairs: - if (result.stat_entry.pathspec.pathtype == - rdf_paths.PathSpec.PathType.REGISTRY): + if ( + result.stat_entry.pathspec.pathtype + == rdf_paths.PathSpec.PathType.REGISTRY + ): registry_pairs.append((metadata, result.stat_entry)) else: file_pairs.append((metadata, result)) @@ -338,7 +352,6 @@ def BatchConvert(self, metadata_value_pairs): similar to statentry exports, and share some code, but different because we already have the hash available without having to go back to the database to retrieve it from the aff4 object. - """ result_generator = self._BatchConvert(metadata_value_pairs) @@ -349,7 +362,8 @@ def BatchConvert(self, metadata_value_pairs): def _BatchConvert(self, metadata_value_pairs): registry_pairs, file_pairs, match_pairs = self._SeparateTypes( - metadata_value_pairs) + metadata_value_pairs + ) for fp_batch in collection.Batch(file_pairs, self._BATCH_SIZE): if self.options.export_files_contents: @@ -358,15 +372,18 @@ def _BatchConvert(self, metadata_value_pairs): # TODO(user): Deprecate client_urn in ExportedMetadata in favor of # client_id (to be added). 
client_path = db.ClientPath.FromPathSpec( - metadata.client_urn.Basename(), ff_result.stat_entry.pathspec) + metadata.client_urn.Basename(), ff_result.stat_entry.pathspec + ) pathspec_by_client_path[client_path] = ff_result.stat_entry.pathspec data_by_pathspec = {} for chunk in file_store.StreamFilesChunks( - pathspec_by_client_path, max_size=self.MAX_CONTENT_SIZE): + pathspec_by_client_path, max_size=self.MAX_CONTENT_SIZE + ): pathspec = pathspec_by_client_path[chunk.client_path] - data_by_pathspec.setdefault(pathspec.CollapsePath(), - []).append(chunk.data) + data_by_pathspec.setdefault(pathspec.CollapsePath(), []).append( + chunk.data + ) for metadata, ff_result in fp_batch: result = self._CreateExportedFile(metadata, ff_result.stat_entry) @@ -379,8 +396,9 @@ def _BatchConvert(self, metadata_value_pairs): if self.options.export_files_contents: try: data = data_by_pathspec[ - ff_result.stat_entry.pathspec.CollapsePath()] - result.content = b"".join(data)[:self.MAX_CONTENT_SIZE] + ff_result.stat_entry.pathspec.CollapsePath() + ] + result.content = b"".join(data)[: self.MAX_CONTENT_SIZE] result.content_sha256 = hashlib.sha256(result.content).hexdigest() except KeyError: pass @@ -389,12 +407,14 @@ def _BatchConvert(self, metadata_value_pairs): # Now export the registry keys for result in export.ConvertValuesWithMetadata( - registry_pairs, options=self.options): + registry_pairs, options=self.options + ): yield result # Now export the grep matches. for result in export.ConvertValuesWithMetadata( - match_pairs, options=self.options): + match_pairs, options=self.options + ): yield result def Convert(self, metadata, result): @@ -410,30 +430,36 @@ def GetExportedResult(self, original_result, converter, metadata=None): """Converts original result via given converter..""" exported_results = list( - converter.Convert(metadata or base.ExportedMetadata(), original_result)) + converter.Convert(metadata or base.ExportedMetadata(), original_result) + ) if not exported_results: - raise export.ExportError("Got 0 exported result when a single one " - "was expected.") + raise export.ExportError( + "Got 0 exported result when a single one was expected." + ) if len(exported_results) > 1: - raise export.ExportError("Got > 1 exported results when a single " - "one was expected, seems like a logical bug.") + raise export.ExportError( + "Got > 1 exported results when a single " + "one was expected, seems like a logical bug." 
+ ) return exported_results[0] def IsRegistryStatEntry(self, original_result): """Checks if given RDFValue is a registry StatEntry.""" - return (original_result.pathspec.pathtype == - rdf_paths.PathSpec.PathType.REGISTRY) + return ( + original_result.pathspec.pathtype + == rdf_paths.PathSpec.PathType.REGISTRY + ) def IsFileStatEntry(self, original_result): """Checks if given RDFValue is a file StatEntry.""" - return (original_result.pathspec.pathtype in [ + return original_result.pathspec.pathtype in [ rdf_paths.PathSpec.PathType.OS, rdf_paths.PathSpec.PathType.TSK, rdf_paths.PathSpec.PathType.NTFS, - ]) + ] def BatchConvert(self, metadata_value_pairs): metadata_value_pairs = list(metadata_value_pairs) @@ -449,16 +475,20 @@ def BatchConvert(self, metadata_value_pairs): exported_registry_key = self.GetExportedResult( original_result, StatEntryToExportedRegistryKeyConverter(), - metadata=metadata) + metadata=metadata, + ) result = ExportedArtifactFilesDownloaderResult( - metadata=metadata, original_registry_key=exported_registry_key) + metadata=metadata, original_registry_key=exported_registry_key + ) elif self.IsFileStatEntry(original_result): exported_file = self.GetExportedResult( original_result, StatEntryToExportedFileConverter(), - metadata=metadata) + metadata=metadata, + ) result = ExportedArtifactFilesDownloaderResult( - metadata=metadata, original_file=exported_file) + metadata=metadata, original_file=exported_file + ) else: # TODO(user): if original_result is not a registry key or a file, # we should still somehow export the data, otherwise the user will get @@ -501,7 +531,8 @@ def BatchConvert(self, metadata_value_pairs): # matter what type it has, we want it in the export output. original_pairs = [(m, v.original_result) for m, v in metadata_value_pairs] for result in export.ConvertValuesWithMetadata( - original_pairs, options=None): + original_pairs, options=None + ): yield result def Convert(self, metadata, value): diff --git a/grr/server/grr_response_server/export_converters/file_test.py b/grr/server/grr_response_server/export_converters/file_test.py index d7c7951b9e..e6a7f2c23d 100644 --- a/grr/server/grr_response_server/export_converters/file_test.py +++ b/grr/server/grr_response_server/export_converters/file_test.py @@ -18,6 +18,7 @@ from grr_response_server.flows.general import collectors from grr_response_server.flows.general import file_finder from grr_response_server.flows.general import transfer +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr.test_lib import action_mocks from grr.test_lib import export_test_lib @@ -31,21 +32,24 @@ class StatEntryToExportedFileConverterTest(export_test_lib.ExportTestBase): def testStatEntryToExportedFileConverterWithMissingAFF4File(self): stat = rdf_client_fs.StatEntry( pathspec=rdf_paths.PathSpec( - path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS), + path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS + ), st_mode=33184, st_ino=1063090, st_atime=1336469177, st_mtime=1336129892, st_ctime=1336129892, - st_btime=1331331331) + st_btime=1331331331, + ) converter = file.StatEntryToExportedFileConverter() results = list(converter.Convert(self.metadata, stat)) self.assertLen(results, 1) self.assertEqual(results[0].basename, "path") - self.assertEqual(results[0].urn, - "aff4:/%s/fs/os/some/path" % self.client_id) + self.assertEqual( + results[0].urn, "aff4:/%s/fs/os/some/path" % self.client_id + ) self.assertEqual(results[0].st_mode, 33184) 
self.assertEqual(results[0].st_ino, 1063090) self.assertEqual(results[0].st_atime, 1336469177) @@ -62,9 +66,11 @@ def testStatEntryToExportedFileConverterWithMissingAFF4File(self): def testStatEntryToExportedFileConverterWithFetchedAFF4File(self): pathspec = rdf_paths.PathSpec( pathtype=rdf_paths.PathSpec.PathType.OS, - path=os.path.join(self.base_path, "winexec_img.dd")) + path=os.path.join(self.base_path, "winexec_img.dd"), + ) pathspec.Append( - path="/Ext2IFS_1_10b.exe", pathtype=rdf_paths.PathSpec.PathType.TSK) + path="/Ext2IFS_1_10b.exe", pathtype=rdf_paths.PathSpec.PathType.TSK + ) client_mock = action_mocks.GetFileClientMock() flow_test_lib.TestFlowHelper( @@ -72,12 +78,15 @@ def testStatEntryToExportedFileConverterWithFetchedAFF4File(self): client_mock, creator=self.test_username, client_id=self.client_id, - pathspec=pathspec) + pathspec=pathspec, + ) - path_info = data_store.REL_DB.ReadPathInfo( + proto_path_info = data_store.REL_DB.ReadPathInfo( self.client_id, rdf_objects.PathInfo.PathType.TSK, - components=tuple(pathspec.CollapsePath().lstrip("/").split("/"))) + components=tuple(pathspec.CollapsePath().lstrip("/").split("/")), + ) + path_info = mig_objects.ToRDFPathInfo(proto_path_info) stat = path_info.stat_entry self.assertTrue(stat) @@ -96,20 +105,24 @@ def testStatEntryToExportedFileConverterWithFetchedAFF4File(self): # Convert again, now specifying export_files_contents=True in options. converter = file.StatEntryToExportedFileConverter( - options=base.ExportOptions(export_files_contents=True)) + options=base.ExportOptions(export_files_contents=True) + ) results = list(converter.Convert(self.metadata, stat)) self.assertTrue(results[0].content) self.assertEqual( results[0].content_sha256, - "69264282ca1a3d4e7f9b1f43720f719a4ea47964f0bfd1b2ba88424f1c61395d") + "69264282ca1a3d4e7f9b1f43720f719a4ea47964f0bfd1b2ba88424f1c61395d", + ) self.assertEqual("", results[0].metadata.annotations) def testStatEntryToExportedFileConverterWithHashedAFF4File(self): pathspec = rdf_paths.PathSpec( pathtype=rdf_paths.PathSpec.PathType.OS, - path=os.path.join(self.base_path, "winexec_img.dd")) + path=os.path.join(self.base_path, "winexec_img.dd"), + ) pathspec.Append( - path="/Ext2IFS_1_10b.exe", pathtype=rdf_paths.PathSpec.PathType.TSK) + path="/Ext2IFS_1_10b.exe", pathtype=rdf_paths.PathSpec.PathType.TSK + ) client_mock = action_mocks.GetFileClientMock() flow_test_lib.TestFlowHelper( @@ -117,20 +130,24 @@ def testStatEntryToExportedFileConverterWithHashedAFF4File(self): client_mock, creator=self.test_username, client_id=self.client_id, - pathspec=pathspec) + pathspec=pathspec, + ) path_info = rdf_objects.PathInfo.FromPathSpec(pathspec) - path_info = data_store.REL_DB.ReadPathInfo(self.client_id, - path_info.path_type, - tuple(path_info.components)) + proto_path_info = data_store.REL_DB.ReadPathInfo( + self.client_id, path_info.path_type, tuple(path_info.components) + ) + path_info = mig_objects.ToRDFPathInfo(proto_path_info) hash_value = path_info.hash_entry self.assertTrue(hash_value) converter = file.StatEntryToExportedFileConverter() results = list( - converter.Convert(self.metadata, - rdf_client_fs.StatEntry(pathspec=pathspec))) + converter.Convert( + self.metadata, rdf_client_fs.StatEntry(pathspec=pathspec) + ) + ) # Even though the file has a hash, it's not stored in StatEntry and # doesn't influence the result. Note: this is a change in behavior. 
@@ -148,17 +165,22 @@ def testExportedFileConverterIgnoresRegistryKeys(self): st_size=51, st_mtime=1247546054, pathspec=rdf_paths.PathSpec( - path="/HKEY_USERS/S-1-5-20/Software/Microsoft/Windows/" - "CurrentVersion/Run/Sidebar", - pathtype=rdf_paths.PathSpec.PathType.REGISTRY)) + path=( + "/HKEY_USERS/S-1-5-20/Software/Microsoft/Windows/" + "CurrentVersion/Run/Sidebar" + ), + pathtype=rdf_paths.PathSpec.PathType.REGISTRY, + ), + ) converter = file.StatEntryToExportedFileConverter() results = list(converter.Convert(self.metadata, stat)) self.assertFalse(results) -class StatEntryToExportedRegistryKeyConverterTest(export_test_lib.ExportTestBase - ): +class StatEntryToExportedRegistryKeyConverterTest( + export_test_lib.ExportTestBase +): """Tests for StatEntryToExportedRegistryKeyConverter.""" def testStatEntryToExportedRegistryKeyConverter(self): @@ -168,34 +190,45 @@ def testStatEntryToExportedRegistryKeyConverter(self): st_mtime=1247546054, registry_type=rdf_client_fs.StatEntry.RegistryType.REG_EXPAND_SZ, pathspec=rdf_paths.PathSpec( - path="/HKEY_USERS/S-1-5-20/Software/Microsoft/Windows/" - "CurrentVersion/Run/Sidebar", - pathtype=rdf_paths.PathSpec.PathType.REGISTRY), - registry_data=rdf_protodict.DataBlob(string="Sidebar.exe")) + path=( + "/HKEY_USERS/S-1-5-20/Software/Microsoft/Windows/" + "CurrentVersion/Run/Sidebar" + ), + pathtype=rdf_paths.PathSpec.PathType.REGISTRY, + ), + registry_data=rdf_protodict.DataBlob(string="Sidebar.exe"), + ) converter = file.StatEntryToExportedRegistryKeyConverter() results = list(converter.Convert(self.metadata, stat)) self.assertLen(results, 1) self.assertEqual( - results[0].urn, "aff4:/%s/registry/HKEY_USERS/S-1-5-20/Software/" - "Microsoft/Windows/CurrentVersion/Run/Sidebar" % self.client_id) - self.assertEqual(results[0].last_modified, - rdfvalue.RDFDatetimeSeconds(1247546054)) - self.assertEqual(results[0].type, - rdf_client_fs.StatEntry.RegistryType.REG_EXPAND_SZ) + results[0].urn, + "aff4:/%s/registry/HKEY_USERS/S-1-5-20/Software/" + "Microsoft/Windows/CurrentVersion/Run/Sidebar" + % self.client_id, + ) + self.assertEqual( + results[0].last_modified, rdfvalue.RDFDatetimeSeconds(1247546054) + ) + self.assertEqual( + results[0].type, rdf_client_fs.StatEntry.RegistryType.REG_EXPAND_SZ + ) self.assertEqual(results[0].data, b"Sidebar.exe") def testRegistryKeyConverterIgnoresNonRegistryStatEntries(self): stat = rdf_client_fs.StatEntry( pathspec=rdf_paths.PathSpec( - path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS), + path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS + ), st_mode=33184, st_ino=1063090, st_atime=1336469177, st_mtime=1336129892, st_ctime=1336129892, - st_btime=1333333331) + st_btime=1333333331, + ) converter = file.StatEntryToExportedRegistryKeyConverter() results = list(converter.Convert(self.metadata, stat)) @@ -209,9 +242,13 @@ def testRegistryKeyConverterWorksWithRegistryKeys(self): st_size=51, st_mtime=1247546054, pathspec=rdf_paths.PathSpec( - path="/HKEY_USERS/S-1-5-20/Software/Microsoft/Windows/" - "CurrentVersion/Run/Sidebar", - pathtype=rdf_paths.PathSpec.PathType.REGISTRY)) + path=( + "/HKEY_USERS/S-1-5-20/Software/Microsoft/Windows/" + "CurrentVersion/Run/Sidebar" + ), + pathtype=rdf_paths.PathSpec.PathType.REGISTRY, + ), + ) converter = file.StatEntryToExportedRegistryKeyConverter() results = list(converter.Convert(self.metadata, stat)) @@ -219,11 +256,15 @@ def testRegistryKeyConverterWorksWithRegistryKeys(self): self.assertLen(results, 1) self.assertEqual( results[0].urn, - 
rdfvalue.RDFURN("aff4:/%s/registry/HKEY_USERS/S-1-5-20/Software/" - "Microsoft/Windows/CurrentVersion/Run/Sidebar" % - self.client_id)) - self.assertEqual(results[0].last_modified, - rdfvalue.RDFDatetimeSeconds(1247546054)) + rdfvalue.RDFURN( + "aff4:/%s/registry/HKEY_USERS/S-1-5-20/Software/" + "Microsoft/Windows/CurrentVersion/Run/Sidebar" + % self.client_id + ), + ) + self.assertEqual( + results[0].last_modified, rdfvalue.RDFDatetimeSeconds(1247546054) + ) self.assertEqual(results[0].data, b"") self.assertEqual(results[0].type, 0) @@ -234,12 +275,15 @@ class FileFinderResultConverterTest(export_test_lib.ExportTestBase): @export_test_lib.WithAllExportConverters def testFileFinderResultExportConverter(self): pathspec = rdf_paths.PathSpec( - path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS + ) match1 = rdf_client.BufferReference( - offset=42, length=43, data=b"somedata1", pathspec=pathspec) + offset=42, length=43, data=b"somedata1", pathspec=pathspec + ) match2 = rdf_client.BufferReference( - offset=44, length=45, data=b"somedata2", pathspec=pathspec) + offset=44, length=45, data=b"somedata2", pathspec=pathspec + ) stat_entry = rdf_client_fs.StatEntry( pathspec=pathspec, st_mode=33184, @@ -247,10 +291,12 @@ def testFileFinderResultExportConverter(self): st_atime=1336469177, st_mtime=1336129892, st_ctime=1336129892, - st_btime=1313131313) + st_btime=1313131313, + ) file_finder_result = rdf_file_finder.FileFinderResult( - stat_entry=stat_entry, matches=[match1, match2]) + stat_entry=stat_entry, matches=[match1, match2] + ) converter = file.FileFinderResultConverter() results = list(converter.Convert(self.metadata, file_finder_result)) @@ -262,8 +308,9 @@ def testFileFinderResultExportConverter(self): self.assertLen(exported_files, 1) self.assertEqual(exported_files[0].basename, "path") - self.assertEqual(exported_files[0].urn, - "aff4:/%s/fs/os/some/path" % self.client_id) + self.assertEqual( + exported_files[0].urn, "aff4:/%s/fs/os/some/path" % self.client_id + ) self.assertEqual(exported_files[0].st_mode, 33184) self.assertEqual(exported_files[0].st_ino, 1063090) self.assertEqual(exported_files[0].st_atime, 1336469177) @@ -279,7 +326,8 @@ def testFileFinderResultExportConverter(self): # We expect 2 ExportedMatch instances in the results exported_matches = [ - result for result in results + result + for result in results if isinstance(result, buffer_reference.ExportedMatch) ] exported_matches = sorted(exported_matches, key=lambda x: x.offset) @@ -288,14 +336,16 @@ def testFileFinderResultExportConverter(self): self.assertEqual(exported_matches[0].offset, 42) self.assertEqual(exported_matches[0].length, 43) self.assertEqual(exported_matches[0].data, b"somedata1") - self.assertEqual(exported_matches[0].urn, - "aff4:/%s/fs/os/some/path" % self.client_id) + self.assertEqual( + exported_matches[0].urn, "aff4:/%s/fs/os/some/path" % self.client_id + ) self.assertEqual(exported_matches[1].offset, 44) self.assertEqual(exported_matches[1].length, 45) self.assertEqual(exported_matches[1].data, b"somedata2") - self.assertEqual(exported_matches[1].urn, - "aff4:/%s/fs/os/some/path" % self.client_id) + self.assertEqual( + exported_matches[1].urn, "aff4:/%s/fs/os/some/path" % self.client_id + ) # Also test registry entries. 
data = rdf_protodict.DataBlob() @@ -304,7 +354,9 @@ def testFileFinderResultExportConverter(self): registry_type="REG_SZ", registry_data=data, pathspec=rdf_paths.PathSpec( - path="HKEY_USERS/S-1-1-1-1/Software", pathtype="REGISTRY")) + path="HKEY_USERS/S-1-1-1-1/Software", pathtype="REGISTRY" + ), + ) file_finder_result = rdf_file_finder.FileFinderResult(stat_entry=stat_entry) converter = file.FileFinderResultConverter() results = list(converter.Convert(self.metadata, file_finder_result)) @@ -316,13 +368,16 @@ def testFileFinderResultExportConverter(self): self.assertEqual(result.data, b"testdata") self.assertEqual( result.urn, - "aff4:/%s/registry/HKEY_USERS/S-1-1-1-1/Software" % self.client_id) + "aff4:/%s/registry/HKEY_USERS/S-1-1-1-1/Software" % self.client_id, + ) @export_test_lib.WithAllExportConverters def testFileFinderResultExportConverterConvertsBufferRefsWithoutPathspecs( - self): + self, + ): pathspec = rdf_paths.PathSpec( - path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS + ) match1 = rdf_client.BufferReference(offset=42, length=43, data=b"somedata1") match2 = rdf_client.BufferReference(offset=44, length=45, data=b"somedata2") @@ -333,17 +388,20 @@ def testFileFinderResultExportConverterConvertsBufferRefsWithoutPathspecs( st_atime=1336469177, st_mtime=1336129892, st_ctime=1336129892, - st_btime=1313131313) + st_btime=1313131313, + ) file_finder_result = rdf_file_finder.FileFinderResult( - stat_entry=stat_entry, matches=[match1, match2]) + stat_entry=stat_entry, matches=[match1, match2] + ) converter = file.FileFinderResultConverter() results = list(converter.Convert(self.metadata, file_finder_result)) # We expect 2 ExportedMatch instances in the results exported_matches = [ - result for result in results + result + for result in results if isinstance(result, buffer_reference.ExportedMatch) ] exported_matches = sorted(exported_matches, key=lambda x: x.offset) @@ -352,23 +410,28 @@ def testFileFinderResultExportConverterConvertsBufferRefsWithoutPathspecs( self.assertEqual(exported_matches[0].offset, 42) self.assertEqual(exported_matches[0].length, 43) self.assertEqual(exported_matches[0].data, b"somedata1") - self.assertEqual(exported_matches[0].urn, - "aff4:/%s/fs/os/some/path" % self.client_id) + self.assertEqual( + exported_matches[0].urn, "aff4:/%s/fs/os/some/path" % self.client_id + ) self.assertEqual(exported_matches[1].offset, 44) self.assertEqual(exported_matches[1].length, 45) self.assertEqual(exported_matches[1].data, b"somedata2") - self.assertEqual(exported_matches[1].urn, - "aff4:/%s/fs/os/some/path" % self.client_id) + self.assertEqual( + exported_matches[1].urn, "aff4:/%s/fs/os/some/path" % self.client_id + ) def testFileFinderResultExportConverterConvertsHashes(self): pathspec = rdf_paths.PathSpec( - path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/some/path", pathtype=rdf_paths.PathSpec.PathType.OS + ) pathspec2 = rdf_paths.PathSpec( - path="/some/path2", pathtype=rdf_paths.PathSpec.PathType.OS) + path="/some/path2", pathtype=rdf_paths.PathSpec.PathType.OS + ) sha256 = binascii.unhexlify( - "0e8dc93e150021bb4752029ebbff51394aa36f069cf19901578e4f06017acdb5") + "0e8dc93e150021bb4752029ebbff51394aa36f069cf19901578e4f06017acdb5" + ) sha1 = binascii.unhexlify("7dd6bee591dfcb6d75eb705405302c3eab65e21a") md5 = binascii.unhexlify("bb0a15eefe63fd41f8dc9dee01c5cf9a") pecoff_md5 = binascii.unhexlify("7dd6bee591dfcb6d75eb705405302c3eab65e21a") @@ -381,16 +444,19 @@ def 
testFileFinderResultExportConverterConvertsHashes(self): st_atime=1336469177, st_mtime=1336129892, st_ctime=1336129892, - st_btime=1331133113) + st_btime=1331133113, + ) hash_entry = rdf_crypto.Hash( sha256=sha256, sha1=sha1, md5=md5, pecoff_md5=pecoff_md5, - pecoff_sha1=pecoff_sha1) + pecoff_sha1=pecoff_sha1, + ) sha256 = binascii.unhexlify( - "9e8dc93e150021bb4752029ebbff51394aa36f069cf19901578e4f06017acdb5") + "9e8dc93e150021bb4752029ebbff51394aa36f069cf19901578e4f06017acdb5" + ) sha1 = binascii.unhexlify("6dd6bee591dfcb6d75eb705405302c3eab65e21a") md5 = binascii.unhexlify("8b0a15eefe63fd41f8dc9dee01c5cf9a") pecoff_md5 = binascii.unhexlify("1dd6bee591dfcb6d75eb705405302c3eab65e21a") @@ -403,65 +469,86 @@ def testFileFinderResultExportConverterConvertsHashes(self): st_atime=1336469177, st_mtime=1336129892, st_ctime=1336129892, - st_btime=1331331331) + st_btime=1331331331, + ) hash_entry2 = rdf_crypto.Hash( sha256=sha256, sha1=sha1, md5=md5, pecoff_md5=pecoff_md5, - pecoff_sha1=pecoff_sha1) + pecoff_sha1=pecoff_sha1, + ) file_finder_result = rdf_file_finder.FileFinderResult( - stat_entry=stat_entry, hash_entry=hash_entry) + stat_entry=stat_entry, hash_entry=hash_entry + ) file_finder_result2 = rdf_file_finder.FileFinderResult( - stat_entry=stat_entry2, hash_entry=hash_entry2) + stat_entry=stat_entry2, hash_entry=hash_entry2 + ) converter = file.FileFinderResultConverter() results = list( - converter.BatchConvert([(self.metadata, file_finder_result), - (self.metadata, file_finder_result2)])) + converter.BatchConvert([ + (self.metadata, file_finder_result), + (self.metadata, file_finder_result2), + ]) + ) exported_files = [ result for result in results if isinstance(result, file.ExportedFile) ] self.assertLen(exported_files, 2) - self.assertCountEqual([x.basename for x in exported_files], - ["path", "path2"]) + self.assertCountEqual( + [x.basename for x in exported_files], ["path", "path2"] + ) for export_result in exported_files: if export_result.basename == "path": self.assertEqual( export_result.hash_sha256, - "0e8dc93e150021bb4752029ebbff51394aa36f069cf19901578e4" - "f06017acdb5") - self.assertEqual(export_result.hash_sha1, - "7dd6bee591dfcb6d75eb705405302c3eab65e21a") - self.assertEqual(export_result.hash_md5, - "bb0a15eefe63fd41f8dc9dee01c5cf9a") - self.assertEqual(export_result.pecoff_hash_md5, - "7dd6bee591dfcb6d75eb705405302c3eab65e21a") - self.assertEqual(export_result.pecoff_hash_sha1, - "7dd6bee591dfcb6d75eb705405302c3eab65e21a") + "0e8dc93e150021bb4752029ebbff51394aa36f069cf19901578e4f06017acdb5", + ) + self.assertEqual( + export_result.hash_sha1, "7dd6bee591dfcb6d75eb705405302c3eab65e21a" + ) + self.assertEqual( + export_result.hash_md5, "bb0a15eefe63fd41f8dc9dee01c5cf9a" + ) + self.assertEqual( + export_result.pecoff_hash_md5, + "7dd6bee591dfcb6d75eb705405302c3eab65e21a", + ) + self.assertEqual( + export_result.pecoff_hash_sha1, + "7dd6bee591dfcb6d75eb705405302c3eab65e21a", + ) elif export_result.basename == "path2": self.assertEqual(export_result.basename, "path2") self.assertEqual( export_result.hash_sha256, - "9e8dc93e150021bb4752029ebbff51394aa36f069cf19901578e4" - "f06017acdb5") - self.assertEqual(export_result.hash_sha1, - "6dd6bee591dfcb6d75eb705405302c3eab65e21a") - self.assertEqual(export_result.hash_md5, - "8b0a15eefe63fd41f8dc9dee01c5cf9a") - self.assertEqual(export_result.pecoff_hash_md5, - "1dd6bee591dfcb6d75eb705405302c3eab65e21a") - self.assertEqual(export_result.pecoff_hash_sha1, - "1dd6bee591dfcb6d75eb705405302c3eab65e21a") + 
"9e8dc93e150021bb4752029ebbff51394aa36f069cf19901578e4f06017acdb5", + ) + self.assertEqual( + export_result.hash_sha1, "6dd6bee591dfcb6d75eb705405302c3eab65e21a" + ) + self.assertEqual( + export_result.hash_md5, "8b0a15eefe63fd41f8dc9dee01c5cf9a" + ) + self.assertEqual( + export_result.pecoff_hash_md5, + "1dd6bee591dfcb6d75eb705405302c3eab65e21a", + ) + self.assertEqual( + export_result.pecoff_hash_sha1, + "1dd6bee591dfcb6d75eb705405302c3eab65e21a", + ) def testFileFinderResultExportConverterConvertsContent(self): client_mock = action_mocks.FileFinderClientMockWithTimestamps() action = rdf_file_finder.FileFinderAction( - action_type=rdf_file_finder.FileFinderAction.Action.DOWNLOAD) + action_type=rdf_file_finder.FileFinderAction.Action.DOWNLOAD + ) path = os.path.join(self.base_path, "winexec_img.dd") flow_id = flow_test_lib.TestFlowHelper( @@ -471,7 +558,8 @@ def testFileFinderResultExportConverterConvertsContent(self): paths=[path], pathtype=rdf_paths.PathSpec.PathType.OS, action=action, - creator=self.test_username) + creator=self.test_username, + ) flow_results = flow_test_lib.GetFlowResults(self.client_id, flow_id) self.assertLen(flow_results, 1) @@ -489,17 +577,20 @@ def testFileFinderResultExportConverterConvertsContent(self): # Convert again, now specifying export_files_contents=True in options. converter = file.FileFinderResultConverter( - options=base.ExportOptions(export_files_contents=True)) + options=base.ExportOptions(export_files_contents=True) + ) results = list(converter.Convert(self.metadata, flow_results[0])) self.assertTrue(results[0].content) self.assertEqual( results[0].content_sha256, - "0652da33d5602c165396856540c173cd37277916fba07a9bf3080bc5a6236f03") + "0652da33d5602c165396856540c173cd37277916fba07a9bf3080bc5a6236f03", + ) -class ArtifactFilesDownloaderResultConverterTest(export_test_lib.ExportTestBase - ): +class ArtifactFilesDownloaderResultConverterTest( + export_test_lib.ExportTestBase +): """Tests for ArtifactFilesDownloaderResultConverter.""" def setUp(self): @@ -508,18 +599,25 @@ def setUp(self): self.registry_stat = rdf_client_fs.StatEntry( registry_type=rdf_client_fs.StatEntry.RegistryType.REG_SZ, pathspec=rdf_paths.PathSpec( - path="/HKEY_USERS/S-1-5-20/Software/Microsoft/Windows/" - "CurrentVersion/Run/Sidebar", - pathtype=rdf_paths.PathSpec.PathType.REGISTRY), - registry_data=rdf_protodict.DataBlob(string="C:\\Windows\\Sidebar.exe")) + path=( + "/HKEY_USERS/S-1-5-20/Software/Microsoft/Windows/" + "CurrentVersion/Run/Sidebar" + ), + pathtype=rdf_paths.PathSpec.PathType.REGISTRY, + ), + registry_data=rdf_protodict.DataBlob(string="C:\\Windows\\Sidebar.exe"), + ) self.file_stat = rdf_client_fs.StatEntry( pathspec=rdf_paths.PathSpec( - path="/tmp/bar.exe", pathtype=rdf_paths.PathSpec.PathType.OS)) + path="/tmp/bar.exe", pathtype=rdf_paths.PathSpec.PathType.OS + ) + ) def testExportsOriginalResultAnywayIfItIsNotStatEntry(self): result = collectors.ArtifactFilesDownloaderResult( - original_result=export_test_lib.DataAgnosticConverterTestValue()) + original_result=export_test_lib.DataAgnosticConverterTestValue() + ) converter = file.ArtifactFilesDownloaderResultConverter() converted = list(converter.Convert(self.metadata, result)) @@ -527,14 +625,18 @@ def testExportsOriginalResultAnywayIfItIsNotStatEntry(self): # Test that something gets exported and that this something wasn't # produced by ArtifactFilesDownloaderResultConverter. 
self.assertLen(converted, 1) - self.assertNotIsInstance(converted[0], - file.ExportedArtifactFilesDownloaderResult) + self.assertNotIsInstance( + converted[0], file.ExportedArtifactFilesDownloaderResult + ) def testExportsOriginalResultIfOriginalResultIsNotRegistryOrFileStatEntry( - self): + self, + ): stat = rdf_client_fs.StatEntry( pathspec=rdf_paths.PathSpec( - path="some/path", pathtype=rdf_paths.PathSpec.PathType.TMPFILE)) + path="some/path", pathtype=rdf_paths.PathSpec.PathType.TMPFILE + ) + ) result = collectors.ArtifactFilesDownloaderResult(original_result=stat) converter = file.ArtifactFilesDownloaderResultConverter() @@ -543,25 +645,29 @@ def testExportsOriginalResultIfOriginalResultIsNotRegistryOrFileStatEntry( # Test that something gets exported and that this something wasn't # produced by ArtifactFilesDownloaderResultConverter. self.assertLen(converted, 1) - self.assertNotIsInstance(converted[0], - file.ExportedArtifactFilesDownloaderResult) + self.assertNotIsInstance( + converted[0], file.ExportedArtifactFilesDownloaderResult + ) def testYieldsOneResultAndOneOriginalValueForFileStatEntry(self): result = collectors.ArtifactFilesDownloaderResult( - original_result=self.file_stat) + original_result=self.file_stat + ) converter = file.ArtifactFilesDownloaderResultConverter() converted = list(converter.Convert(self.metadata, result)) default_exports = [ - v for v in converted + v + for v in converted if not isinstance(v, file.ExportedArtifactFilesDownloaderResult) ] self.assertLen(default_exports, 1) self.assertLen(default_exports, 1) downloader_exports = [ - v for v in converted + v + for v in converted if isinstance(v, file.ExportedArtifactFilesDownloaderResult) ] self.assertLen(downloader_exports, 1) @@ -569,30 +675,36 @@ def testYieldsOneResultAndOneOriginalValueForFileStatEntry(self): def testYieldsOneResultForRegistryStatEntryIfNoPathspecsWereFound(self): result = collectors.ArtifactFilesDownloaderResult( - original_result=self.registry_stat) + original_result=self.registry_stat + ) converter = file.ArtifactFilesDownloaderResultConverter() converted = list(converter.Convert(self.metadata, result)) downloader_exports = [ - v for v in converted + v + for v in converted if isinstance(v, file.ExportedArtifactFilesDownloaderResult) ] self.assertLen(downloader_exports, 1) self.assertEqual(downloader_exports[0].original_registry_key.type, "REG_SZ") - self.assertEqual(downloader_exports[0].original_registry_key.data, - b"C:\\Windows\\Sidebar.exe") + self.assertEqual( + downloader_exports[0].original_registry_key.data, + b"C:\\Windows\\Sidebar.exe", + ) def testIncludesRegistryStatEntryFoundPathspecIntoYieldedResult(self): result = collectors.ArtifactFilesDownloaderResult( original_result=self.registry_stat, - found_pathspec=rdf_paths.PathSpec(path="foo", pathtype="OS")) + found_pathspec=rdf_paths.PathSpec(path="foo", pathtype="OS"), + ) converter = file.ArtifactFilesDownloaderResultConverter() converted = list(converter.Convert(self.metadata, result)) downloader_exports = [ - v for v in converted + v + for v in converted if isinstance(v, file.ExportedArtifactFilesDownloaderResult) ] self.assertLen(downloader_exports, 1) @@ -600,13 +712,15 @@ def testIncludesRegistryStatEntryFoundPathspecIntoYieldedResult(self): def testIncludesFileStatEntryFoundPathspecIntoYieldedResult(self): result = collectors.ArtifactFilesDownloaderResult( - original_result=self.file_stat, found_pathspec=self.file_stat.pathspec) + original_result=self.file_stat, found_pathspec=self.file_stat.pathspec + ) 
converter = file.ArtifactFilesDownloaderResultConverter() converted = list(converter.Convert(self.metadata, result)) downloader_exports = [ - v for v in converted + v + for v in converted if isinstance(v, file.ExportedArtifactFilesDownloaderResult) ] self.assertLen(downloader_exports, 1) @@ -617,13 +731,16 @@ def testIncludesDownloadedFileIntoResult(self): original_result=self.registry_stat, found_pathspec=rdf_paths.PathSpec(path="foo", pathtype="OS"), downloaded_file=rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path="foo", pathtype="OS"))) + pathspec=rdf_paths.PathSpec(path="foo", pathtype="OS") + ), + ) converter = file.ArtifactFilesDownloaderResultConverter() converted = list(converter.Convert(self.metadata, result)) downloader_exports = [ - v for v in converted + v + for v in converted if isinstance(v, file.ExportedArtifactFilesDownloaderResult) ] self.assertLen(downloader_exports, 1) diff --git a/grr/server/grr_response_server/export_converters/grr_message.py b/grr/server/grr_response_server/export_converters/grr_message.py index b487130ab3..7f0b7b8118 100644 --- a/grr/server/grr_response_server/export_converters/grr_message.py +++ b/grr/server/grr_response_server/export_converters/grr_message.py @@ -101,15 +101,17 @@ def BatchConvert(self, metadata_value_pairs): # Create a dict of values for conversion keyed by type, so we can # apply the right converters to the right object types if cls_name not in data_by_type: - converters_classes = export_converters_registry.GetConvertersByValue( - message.payload) + converters_classes = ( + export_converters_registry.GetConvertersByValue(message.payload) + ) data_by_type[cls_name] = { "converters": [cls(self.options) for cls in converters_classes], - "batch_data": [(new_metadata, message.payload)] + "batch_data": [(new_metadata, message.payload)], } else: data_by_type[cls_name]["batch_data"].append( - (new_metadata, message.payload)) + (new_metadata, message.payload) + ) except KeyError: pass diff --git a/grr/server/grr_response_server/export_converters/grr_message_test.py b/grr/server/grr_response_server/export_converters/grr_message_test.py index 232699f0d3..6122507028 100644 --- a/grr/server/grr_response_server/export_converters/grr_message_test.py +++ b/grr/server/grr_response_server/export_converters/grr_message_test.py @@ -85,18 +85,23 @@ def testGrrMessageConverter(self): fixture_test_lib.ClientFixture(self.client_id) metadata = base.ExportedMetadata( - source_urn=rdfvalue.RDFURN("aff4:/hunts/" + str(queues.HUNTS) + - ":000000/Results")) + source_urn=rdfvalue.RDFURN( + "aff4:/hunts/" + str(queues.HUNTS) + ":000000/Results" + ) + ) converter = grr_message.GrrMessageConverter() with test_lib.FakeTime(2): results = list(converter.Convert(metadata, msg)) self.assertLen(results, 1) - self.assertEqual(results[0].timestamp, - rdfvalue.RDFDatetime.FromSecondsSinceEpoch(2)) - self.assertEqual(results[0].source_urn, - "aff4:/hunts/" + str(queues.HUNTS) + ":000000/Results") + self.assertEqual( + results[0].timestamp, rdfvalue.RDFDatetime.FromSecondsSinceEpoch(2) + ) + self.assertEqual( + results[0].source_urn, + "aff4:/hunts/" + str(queues.HUNTS) + ":000000/Results", + ) @export_test_lib.WithExportConverter(DummyTestRDFValue4ToMetadataConverter) def testGrrMessageConverterWithOneMissingClient(self): @@ -113,22 +118,30 @@ def testGrrMessageConverterWithOneMissingClient(self): msg2.source = client_id_2 metadata1 = base.ExportedMetadata( - source_urn=rdfvalue.RDFURN("aff4:/hunts/" + str(queues.HUNTS) + - ":000000/Results")) + 
source_urn=rdfvalue.RDFURN( + "aff4:/hunts/" + str(queues.HUNTS) + ":000000/Results" + ) + ) metadata2 = base.ExportedMetadata( - source_urn=rdfvalue.RDFURN("aff4:/hunts/" + str(queues.HUNTS) + - ":000001/Results")) + source_urn=rdfvalue.RDFURN( + "aff4:/hunts/" + str(queues.HUNTS) + ":000001/Results" + ) + ) converter = grr_message.GrrMessageConverter() with test_lib.FakeTime(3): results = list( - converter.BatchConvert([(metadata1, msg1), (metadata2, msg2)])) + converter.BatchConvert([(metadata1, msg1), (metadata2, msg2)]) + ) self.assertLen(results, 1) - self.assertEqual(results[0].timestamp, - rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3)) - self.assertEqual(results[0].source_urn, - "aff4:/hunts/" + str(queues.HUNTS) + ":000000/Results") + self.assertEqual( + results[0].timestamp, rdfvalue.RDFDatetime.FromSecondsSinceEpoch(3) + ) + self.assertEqual( + results[0].source_urn, + "aff4:/hunts/" + str(queues.HUNTS) + ":000000/Results", + ) @export_test_lib.WithExportConverter(DummyTestRDFValue3ConverterA) @export_test_lib.WithExportConverter(DummyTestRDFValue3ConverterB) @@ -145,22 +158,28 @@ def testGrrMessageConverterMultipleTypes(self): msg2.source = client_id metadata1 = base.ExportedMetadata( - source_urn=rdfvalue.RDFURN("aff4:/hunts/" + str(queues.HUNTS) + - ":000000/Results")) + source_urn=rdfvalue.RDFURN( + "aff4:/hunts/" + str(queues.HUNTS) + ":000000/Results" + ) + ) metadata2 = base.ExportedMetadata( - source_urn=rdfvalue.RDFURN("aff4:/hunts/" + str(queues.HUNTS) + - ":000001/Results")) + source_urn=rdfvalue.RDFURN( + "aff4:/hunts/" + str(queues.HUNTS) + ":000001/Results" + ) + ) converter = grr_message.GrrMessageConverter() with test_lib.FakeTime(3): results = list( - converter.BatchConvert([(metadata1, msg1), (metadata2, msg2)])) + converter.BatchConvert([(metadata1, msg1), (metadata2, msg2)]) + ) self.assertLen(results, 3) # RDFValue3 gets converted to RDFValue2 and RDFValue, RDFValue5 stays at 5. 
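Aside on the grr_message.py hunk above: the reflowed BatchConvert body is easier to follow once the underlying pattern is spelled out — incoming (metadata, payload) pairs are bucketed by payload type, converters are looked up once per type, and each converter then receives a homogeneous batch. A minimal, self-contained sketch of that grouping pattern (generic Python; `converters_for` and the `Converter` callables are hypothetical stand-ins, not GRR APIs):

```python
import collections
from typing import Any, Callable, Dict, Iterable, List, Tuple

# Hypothetical stand-in for a converter: takes a batch of (metadata, value)
# pairs of a single payload type and yields exported values.
Converter = Callable[[List[Tuple[Any, Any]]], Iterable[Any]]


def batch_convert(
    pairs: Iterable[Tuple[Any, Any]],
    converters_for: Callable[[Any], List[Converter]],
) -> List[Any]:
  """Groups payloads by type, then hands each group to its converters."""
  batches: Dict[str, List[Tuple[Any, Any]]] = collections.defaultdict(list)
  converters: Dict[str, List[Converter]] = {}

  for metadata, payload in pairs:
    cls_name = type(payload).__name__
    if cls_name not in converters:
      # Look up converters once per payload type (mirrors the diff's
      # GetConvertersByValue call, here an injected callable).
      converters[cls_name] = converters_for(payload)
    batches[cls_name].append((metadata, payload))

  results: List[Any] = []
  for cls_name, batch in batches.items():
    for convert in converters[cls_name]:
      results.extend(convert(batch))
  return results
```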
self.assertCountEqual( ["DummyTestRDFValue2", "DummyTestRDFValue1", "DummyTestRDFValue5"], - [x.__class__.__name__ for x in results]) + [x.__class__.__name__ for x in results], + ) def main(argv): diff --git a/grr/server/grr_response_server/export_converters/launchd_plist.py b/grr/server/grr_response_server/export_converters/launchd_plist.py index e387b3f8f7..4065b2e1ca 100644 --- a/grr/server/grr_response_server/export_converters/launchd_plist.py +++ b/grr/server/grr_response_server/export_converters/launchd_plist.py @@ -15,12 +15,14 @@ class ExportedLaunchdPlist(rdf_structs.RDFProtoStruct): def _CalendarIntervalToString( - i: rdf_plist.LaunchdStartCalendarIntervalEntry) -> str: + i: rdf_plist.LaunchdStartCalendarIntervalEntry, +) -> str: return f"{i.Month}-{i.Weekday}-{i.Day}-{i.Hour}-{i.Minute}" -def _DictEntryToString(e: Union[rdf_plist.PlistStringDictEntry, - rdf_plist.PlistBoolDictEntry]): +def _DictEntryToString( + e: Union[rdf_plist.PlistStringDictEntry, rdf_plist.PlistBoolDictEntry], +): return f"{e.name}={e.value}" @@ -29,8 +31,9 @@ class LaunchdPlistConverter(base.ExportConverter): input_rdf_type = rdf_plist.LaunchdPlist - def Convert(self, metadata: base.ExportedMetadata, - l: rdf_plist.LaunchdPlist) -> Iterator[ExportedLaunchdPlist]: + def Convert( + self, metadata: base.ExportedMetadata, l: rdf_plist.LaunchdPlist + ) -> Iterator[ExportedLaunchdPlist]: yield ExportedLaunchdPlist( metadata=metadata, launchd_file_path=l.path, @@ -45,16 +48,19 @@ def Convert(self, metadata: base.ExportedMetadata, on_demand=l.OnDemand, run_at_load=l.RunAtLoad, start_calendar_interval=" ".join( - _CalendarIntervalToString(i) for i in l.StartCalendarInterval), + _CalendarIntervalToString(i) for i in l.StartCalendarInterval + ), environment_variables=" ".join( - _DictEntryToString(e) for e in l.EnvironmentVariables), + _DictEntryToString(e) for e in l.EnvironmentVariables + ), standard_in_path=l.StandardInPath, standard_out_path=l.StandardOutPath, standard_error_path=l.StandardErrorPath, limit_load_to_hosts=" ".join(i for i in l.LimitLoadToHosts), limit_load_from_hosts=" ".join(i for i in l.LimitLoadFromHosts), limit_load_to_session_type=" ".join( - i for i in l.LimitLoadToSessionType), + i for i in l.LimitLoadToSessionType + ), enable_globbing=l.EnableGlobbing, enable_transactions=l.EnableTransactions, umask=l.Umask, @@ -81,7 +87,9 @@ def Convert(self, metadata: base.ExportedMetadata, keep_alive_successful_exit=l.KeepAliveDict.SuccessfulExit, keep_alive_network_state=l.KeepAliveDict.NetworkState, keep_alive_path_state=" ".join( - _DictEntryToString(e) for e in l.KeepAliveDict.PathState), + _DictEntryToString(e) for e in l.KeepAliveDict.PathState + ), keep_alive_other_job_enabled=" ".join( - _DictEntryToString(e) for e in l.KeepAliveDict.OtherJobEnabled), + _DictEntryToString(e) for e in l.KeepAliveDict.OtherJobEnabled + ), ) diff --git a/grr/server/grr_response_server/export_converters/launchd_plist_test.py b/grr/server/grr_response_server/export_converters/launchd_plist_test.py index 44bdf5fabf..017a8c6f55 100644 --- a/grr/server/grr_response_server/export_converters/launchd_plist_test.py +++ b/grr/server/grr_response_server/export_converters/launchd_plist_test.py @@ -25,9 +25,11 @@ def testExportsValueCorrectly(self): RunAtLoad=True, StartCalendarInterval=[ rdf_plist.LaunchdStartCalendarIntervalEntry( - Minute=1, Hour=2, Day=3, Weekday=4, Month=5), + Minute=1, Hour=2, Day=3, Weekday=4, Month=5 + ), rdf_plist.LaunchdStartCalendarIntervalEntry( - Minute=2, Hour=3, Day=4, Weekday=5, Month=6), + 
Minute=2, Hour=3, Day=4, Weekday=5, Month=6 + ), ], EnvironmentVariables=[ rdf_plist.PlistStringDictEntry(name="foo", value="bar"), @@ -73,7 +75,8 @@ def testExportsValueCorrectly(self): inetdCompatibilityWait=True, SoftResourceLimits=True, HardResourceLimits=True, - Sockets=True) + Sockets=True, + ) converter = launchd_plist.LaunchdPlistConverter() converted = list(converter.Convert(self.metadata, sample)) diff --git a/grr/server/grr_response_server/export_converters/memory.py b/grr/server/grr_response_server/export_converters/memory.py index 7b624260e1..aacf16dc2d 100644 --- a/grr/server/grr_response_server/export_converters/memory.py +++ b/grr/server/grr_response_server/export_converters/memory.py @@ -22,11 +22,13 @@ class ExportedProcessMemoryError(rdf_structs.RDFProtoStruct): class YaraProcessScanMatchConverter(base.ExportConverter): """Converter for YaraProcessScanMatch.""" + input_rdf_type = rdf_memory.YaraProcessScanMatch def Convert( - self, metadata: base.ExportedMetadata, - value: rdf_memory.YaraProcessScanMatch + self, + metadata: base.ExportedMetadata, + value: rdf_memory.YaraProcessScanMatch, ) -> Iterator[ExportedYaraProcessScanMatch]: """See base class.""" @@ -49,6 +51,7 @@ def Convert( class ProcessMemoryErrorConverter(base.ExportConverter): """Converter for ProcessMemoryError.""" + input_rdf_type = rdf_memory.ProcessMemoryError def Convert( @@ -61,4 +64,5 @@ def Convert( conv = process.ProcessToExportedProcessConverter(options=self.options) proc = next(iter(conv.Convert(metadata, value.process))) yield ExportedProcessMemoryError( - metadata=metadata, process=proc, error=value.error) + metadata=metadata, process=proc, error=value.error + ) diff --git a/grr/server/grr_response_server/export_converters/memory_test.py b/grr/server/grr_response_server/export_converters/memory_test.py index 657392905a..533a332a60 100644 --- a/grr/server/grr_response_server/export_converters/memory_test.py +++ b/grr/server/grr_response_server/export_converters/memory_test.py @@ -17,9 +17,11 @@ def GenerateSample(self, match, **kwargs): ppid=1, cmdline=["cmd.exe"], exe="c:\\windows\\cmd.exe", - ctime=1333718907167083) + ctime=1333718907167083, + ) return rdf_memory.YaraProcessScanMatch( - process=process, match=match, scan_time_us=42, **kwargs) + process=process, match=match, scan_time_us=42, **kwargs + ) def testExportsSingleMatchCorrectly(self): sample = self.GenerateSample([ @@ -27,7 +29,8 @@ def testExportsSingleMatchCorrectly(self): rule_name="foo", string_matches=[ rdf_memory.YaraStringMatch(string_id="bar", offset=5) - ]) + ], + ) ]) converter = memory.YaraProcessScanMatchConverter() @@ -77,12 +80,14 @@ def testExportsOneYaraMatchPerYaraStringMatch(self): string_matches=[ rdf_memory.YaraStringMatch(string_id="bar1", offset=5), rdf_memory.YaraStringMatch(string_id="bar2", offset=10), - ]), + ], + ), rdf_memory.YaraMatch( rule_name="foo2", string_matches=[ rdf_memory.YaraStringMatch(string_id="bar3", offset=15), - ]), + ], + ), ]) converter = memory.YaraProcessScanMatchConverter() @@ -115,7 +120,8 @@ def _GenerateSample(self, **kwargs): ppid=1, cmdline=["cmd.exe"], exe="c:\\windows\\cmd.exe", - ctime=1333718907167083) + ctime=1333718907167083, + ) return rdf_memory.ProcessMemoryError(process=process, **kwargs) def testExportsErrorCorrectly(self): diff --git a/grr/server/grr_response_server/export_converters/mig_base.py b/grr/server/grr_response_server/export_converters/mig_base.py index d9e893293e..ecf3c7365d 100644 --- a/grr/server/grr_response_server/export_converters/mig_base.py +++ 
b/grr/server/grr_response_server/export_converters/mig_base.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import base diff --git a/grr/server/grr_response_server/export_converters/mig_buffer_reference.py b/grr/server/grr_response_server/export_converters/mig_buffer_reference.py index eedd6f0a11..baf0ceff92 100644 --- a/grr/server/grr_response_server/export_converters/mig_buffer_reference.py +++ b/grr/server/grr_response_server/export_converters/mig_buffer_reference.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import buffer_reference diff --git a/grr/server/grr_response_server/export_converters/mig_client_summary.py b/grr/server/grr_response_server/export_converters/mig_client_summary.py index 07068882f0..9710b59171 100644 --- a/grr/server/grr_response_server/export_converters/mig_client_summary.py +++ b/grr/server/grr_response_server/export_converters/mig_client_summary.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import client_summary diff --git a/grr/server/grr_response_server/export_converters/mig_cron_tab_file.py b/grr/server/grr_response_server/export_converters/mig_cron_tab_file.py index 7313989cb8..a62f083375 100644 --- a/grr/server/grr_response_server/export_converters/mig_cron_tab_file.py +++ b/grr/server/grr_response_server/export_converters/mig_cron_tab_file.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import cron_tab_file diff --git a/grr/server/grr_response_server/export_converters/mig_execute_response.py b/grr/server/grr_response_server/export_converters/mig_execute_response.py index d42d466136..15ddd3c05b 100644 --- a/grr/server/grr_response_server/export_converters/mig_execute_response.py +++ b/grr/server/grr_response_server/export_converters/mig_execute_response.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import execute_response diff --git a/grr/server/grr_response_server/export_converters/mig_file.py b/grr/server/grr_response_server/export_converters/mig_file.py index 7c72959a0b..5967bd562e 100644 --- a/grr/server/grr_response_server/export_converters/mig_file.py +++ b/grr/server/grr_response_server/export_converters/mig_file.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import file diff --git a/grr/server/grr_response_server/export_converters/mig_launchd_plist.py b/grr/server/grr_response_server/export_converters/mig_launchd_plist.py index e184af2cb6..0f25712552 100644 --- a/grr/server/grr_response_server/export_converters/mig_launchd_plist.py +++ b/grr/server/grr_response_server/export_converters/mig_launchd_plist.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during 
RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import launchd_plist diff --git a/grr/server/grr_response_server/export_converters/mig_memory.py b/grr/server/grr_response_server/export_converters/mig_memory.py index 1561798a7a..5e19172651 100644 --- a/grr/server/grr_response_server/export_converters/mig_memory.py +++ b/grr/server/grr_response_server/export_converters/mig_memory.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import memory diff --git a/grr/server/grr_response_server/export_converters/mig_network.py b/grr/server/grr_response_server/export_converters/mig_network.py index f710e005cc..cc123f9f61 100644 --- a/grr/server/grr_response_server/export_converters/mig_network.py +++ b/grr/server/grr_response_server/export_converters/mig_network.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import network diff --git a/grr/server/grr_response_server/export_converters/mig_process.py b/grr/server/grr_response_server/export_converters/mig_process.py index 9c20d49396..675490cf67 100644 --- a/grr/server/grr_response_server/export_converters/mig_process.py +++ b/grr/server/grr_response_server/export_converters/mig_process.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import process diff --git a/grr/server/grr_response_server/export_converters/mig_rdf_dict.py b/grr/server/grr_response_server/export_converters/mig_rdf_dict.py index aac7372532..0c9701a5d4 100644 --- a/grr/server/grr_response_server/export_converters/mig_rdf_dict.py +++ b/grr/server/grr_response_server/export_converters/mig_rdf_dict.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import rdf_dict diff --git a/grr/server/grr_response_server/export_converters/mig_rdf_primitives.py b/grr/server/grr_response_server/export_converters/mig_rdf_primitives.py index b30a65b550..ce0cb4475a 100644 --- a/grr/server/grr_response_server/export_converters/mig_rdf_primitives.py +++ b/grr/server/grr_response_server/export_converters/mig_rdf_primitives.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import rdf_primitives diff --git a/grr/server/grr_response_server/export_converters/mig_software_package.py b/grr/server/grr_response_server/export_converters/mig_software_package.py index 0a3aacc92b..e94048468f 100644 --- a/grr/server/grr_response_server/export_converters/mig_software_package.py +++ b/grr/server/grr_response_server/export_converters/mig_software_package.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import software_package diff --git a/grr/server/grr_response_server/export_converters/mig_windows_service_info.py 
b/grr/server/grr_response_server/export_converters/mig_windows_service_info.py index c588608747..1a41b8008a 100644 --- a/grr/server/grr_response_server/export_converters/mig_windows_service_info.py +++ b/grr/server/grr_response_server/export_converters/mig_windows_service_info.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import export_pb2 from grr_response_server.export_converters import windows_service_info diff --git a/grr/server/grr_response_server/export_converters/network.py b/grr/server/grr_response_server/export_converters/network.py index fd707ca6ac..947be0302b 100644 --- a/grr/server/grr_response_server/export_converters/network.py +++ b/grr/server/grr_response_server/export_converters/network.py @@ -32,14 +32,16 @@ class ExportedNetworkInterface(rdf_structs.RDFProtoStruct): class NetworkConnectionToExportedNetworkConnectionConverter( - base.ExportConverter): + base.ExportConverter +): """Converts NetworkConnection to ExportedNetworkConnection.""" input_rdf_type = rdf_client_network.NetworkConnection def Convert( - self, metadata: base.ExportedMetadata, - conn: rdf_client_network.NetworkConnection + self, + metadata: base.ExportedMetadata, + conn: rdf_client_network.NetworkConnection, ) -> List[ExportedNetworkConnection]: """Converts a NetworkConnection into a ExportedNetworkConnection. @@ -60,7 +62,8 @@ def Convert( remote_address=conn.remote_address, state=conn.state, pid=conn.pid, - ctime=conn.ctime) + ctime=conn.ctime, + ) return [result] @@ -70,8 +73,9 @@ class InterfaceToExportedNetworkInterfaceConverter(base.ExportConverter): input_rdf_type = rdf_client_network.Interface def Convert( - self, metadata: base.ExportedMetadata, - interface: rdf_client_network.Interface + self, + metadata: base.ExportedMetadata, + interface: rdf_client_network.Interface, ) -> Iterator[ExportedNetworkInterface]: """Converts a Interface into ExportedNetworkInterfaces. @@ -97,7 +101,8 @@ def Convert( metadata=metadata, ifname=interface.ifname, ip4_addresses=" ".join(ip4_addresses), - ip6_addresses=" ".join(ip6_addresses)) + ip6_addresses=" ".join(ip6_addresses), + ) if interface.mac_address: result.mac_address = interface.mac_address.human_readable_address @@ -106,14 +111,16 @@ def Convert( class DNSClientConfigurationToExportedDNSClientConfiguration( - base.ExportConverter): + base.ExportConverter +): """Converts DNSClientConfiguration to ExportedDNSClientConfiguration.""" input_rdf_type = rdf_client_network.DNSClientConfiguration def Convert( - self, metadata: base.ExportedMetadata, - config: rdf_client_network.DNSClientConfiguration + self, + metadata: base.ExportedMetadata, + config: rdf_client_network.DNSClientConfiguration, ) -> Iterator[ExportedDNSClientConfiguration]: """Converts a DNSClientConfiguration into a ExportedDNSClientConfiguration. 
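The network.py hunks above are formatting-only, but they make the converter contract easy to state: repeated fields (DNS servers, suffixes, interface addresses) are flattened into space-joined strings on the exported message. A hedged usage sketch — import paths inferred from the file paths in this diff, and the bare base.ExportedMetadata() stands in for the metadata a real export pipeline would supply:

```python
from grr_response_core.lib.rdfvalues import client_network as rdf_client_network
from grr_response_server.export_converters import base
from grr_response_server.export_converters import network

# Placeholder metadata; real callers obtain this from the export machinery.
metadata = base.ExportedMetadata()

config = rdf_client_network.DNSClientConfiguration(
    dns_server=["192.168.1.1", "8.8.8.8"],
    dns_suffix=["internal.company.com", "company.com"],
)

converter = network.DNSClientConfigurationToExportedDNSClientConfiguration()
results = list(converter.Convert(metadata, config))

# Repeated fields come out space-joined, per the Convert() body above.
assert results[0].dns_servers == "192.168.1.1 8.8.8.8"
assert results[0].dns_suffixes == "internal.company.com company.com"
```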
@@ -129,5 +136,6 @@ def Convert( result = ExportedDNSClientConfiguration( metadata=metadata, dns_servers=" ".join(config.dns_server), - dns_suffixes=" ".join(config.dns_suffix)) + dns_suffixes=" ".join(config.dns_suffix), + ) yield result diff --git a/grr/server/grr_response_server/export_converters/network_test.py b/grr/server/grr_response_server/export_converters/network_test.py index 315dd224e1..28907866f6 100644 --- a/grr/server/grr_response_server/export_converters/network_test.py +++ b/grr/server/grr_response_server/export_converters/network_test.py @@ -11,7 +11,8 @@ class NetworkConnectionToExportedNetworkConnectionConverterTest( - export_test_lib.ExportTestBase): + export_test_lib.ExportTestBase +): def testBasicConversion(self): conn = rdf_client_network.NetworkConnection( @@ -20,26 +21,34 @@ def testBasicConversion(self): local_address=rdf_client_network.NetworkEndpoint(ip="0.0.0.0", port=22), remote_address=rdf_client_network.NetworkEndpoint(ip="0.0.0.0", port=0), pid=2136, - ctime=123) + ctime=123, + ) converter = network.NetworkConnectionToExportedNetworkConnectionConverter() results = list(converter.Convert(self.metadata, conn)) self.assertLen(results, 1) - self.assertEqual(results[0].state, - rdf_client_network.NetworkConnection.State.LISTEN) - self.assertEqual(results[0].type, - rdf_client_network.NetworkConnection.Type.SOCK_STREAM) - self.assertEqual(results[0].local_address, - rdf_client_network.NetworkEndpoint(ip="0.0.0.0", port=22)) - self.assertEqual(results[0].remote_address, - rdf_client_network.NetworkEndpoint(ip="0.0.0.0", port=0)) + self.assertEqual( + results[0].state, rdf_client_network.NetworkConnection.State.LISTEN + ) + self.assertEqual( + results[0].type, rdf_client_network.NetworkConnection.Type.SOCK_STREAM + ) + self.assertEqual( + results[0].local_address, + rdf_client_network.NetworkEndpoint(ip="0.0.0.0", port=22), + ) + self.assertEqual( + results[0].remote_address, + rdf_client_network.NetworkEndpoint(ip="0.0.0.0", port=0), + ) self.assertEqual(results[0].pid, 2136) self.assertEqual(results[0].ctime, 123) class InterfaceToExportedNetworkInterfaceConverterTest( - export_test_lib.ExportTestBase): + export_test_lib.ExportTestBase +): def testInterfaceToExportedNetworkInterfaceConverter(self): mac_address_bytes = b"123456" @@ -59,10 +68,12 @@ def testInterfaceToExportedNetworkInterfaceConverter(self): ), rdf_client_network.NetworkAddress( address_type=rdf_client_network.NetworkAddress.Family.INET6, - packed_bytes=socket.inet_pton(socket.AF_INET6, - "2001:720:1500:1::a100"), - ) - ]) + packed_bytes=socket.inet_pton( + socket.AF_INET6, "2001:720:1500:1::a100" + ), + ), + ], + ) converter = network.InterfaceToExportedNetworkInterfaceConverter() results = list(converter.Convert(self.metadata, interface)) @@ -74,13 +85,15 @@ def testInterfaceToExportedNetworkInterfaceConverter(self): class DNSClientConfigurationToExportedDNSClientConfigurationTest( - export_test_lib.ExportTestBase): + export_test_lib.ExportTestBase +): def testDNSClientConfigurationToExportedDNSClientConfiguration(self): dns_servers = ["192.168.1.1", "8.8.8.8"] dns_suffixes = ["internal.company.com", "company.com"] config = rdf_client_network.DNSClientConfiguration( - dns_server=dns_servers, dns_suffix=dns_suffixes) + dns_server=dns_servers, dns_suffix=dns_suffixes + ) converter = network.DNSClientConfigurationToExportedDNSClientConfiguration() results = list(converter.Convert(self.metadata, config)) diff --git a/grr/server/grr_response_server/export_converters/osquery.py 
b/grr/server/grr_response_server/export_converters/osquery.py index 302f8c6cbe..c257b3c054 100644 --- a/grr/server/grr_response_server/export_converters/osquery.py +++ b/grr/server/grr_response_server/export_converters/osquery.py @@ -42,10 +42,13 @@ def _RDFClass(cls, table: rdf_osquery.OsqueryTable) -> Type[Any]: rdf_cls = type(rdf_cls_name, (rdf_structs.RDFProtoStruct,), {}) rdf_cls.AddDescriptor( rdf_structs.ProtoEmbedded( - name="metadata", field_number=1, nested=base.ExportedMetadata)) + name="metadata", field_number=1, nested=base.ExportedMetadata + ) + ) rdf_cls.AddDescriptor( - rdf_structs.ProtoString(name="__query__", field_number=2)) + rdf_structs.ProtoString(name="__query__", field_number=2) + ) for idx, column in enumerate(table.header.columns): # It is possible that RDF column is named "metadata". To avoid name clash @@ -61,8 +64,9 @@ def _RDFClass(cls, table: rdf_osquery.OsqueryTable) -> Type[Any]: cls._rdf_cls_cache[rdf_cls_name] = rdf_cls return rdf_cls - def Convert(self, metadata: base.ExportedMetadata, - table: rdf_osquery.OsqueryTable) -> Any: + def Convert( + self, metadata: base.ExportedMetadata, table: rdf_osquery.OsqueryTable + ) -> Any: precondition.AssertType(table, rdf_osquery.OsqueryTable) rdf_cls = self._RDFClass(table) diff --git a/grr/server/grr_response_server/export_converters/process.py b/grr/server/grr_response_server/export_converters/process.py index ae21fceb3f..071ec9d7b8 100644 --- a/grr/server/grr_response_server/export_converters/process.py +++ b/grr/server/grr_response_server/export_converters/process.py @@ -29,8 +29,9 @@ class ProcessToExportedProcessConverter(base.ExportConverter): input_rdf_type = rdf_client.Process - def Convert(self, metadata: base.ExportedMetadata, - process: rdf_client.Process) -> List[ExportedProcess]: + def Convert( + self, metadata: base.ExportedMetadata, process: rdf_client.Process + ) -> List[ExportedProcess]: """Converts a Process into a ExportedProcess. Args: @@ -66,7 +67,8 @@ def Convert(self, metadata: base.ExportedMetadata, cpu_percent=process.cpu_percent, rss_size=process.RSS_size, vms_size=process.VMS_size, - memory_percent=process.memory_percent) + memory_percent=process.memory_percent, + ) return [result] @@ -89,11 +91,14 @@ def Convert( converted Process. """ - conn_converter = network.NetworkConnectionToExportedNetworkConnectionConverter( - options=self.options) - return conn_converter.BatchConvert([ - (metadata, conn) for conn in process.connections - ]) + conn_converter = ( + network.NetworkConnectionToExportedNetworkConnectionConverter( + options=self.options + ) + ) + return conn_converter.BatchConvert( + [(metadata, conn) for conn in process.connections] + ) class ProcessToExportedOpenFileConverter(base.ExportConverter): @@ -101,8 +106,9 @@ class ProcessToExportedOpenFileConverter(base.ExportConverter): input_rdf_type = rdf_client.Process - def Convert(self, metadata: base.ExportedMetadata, - process: rdf_client.Process) -> Iterator[ExportedOpenFile]: + def Convert( + self, metadata: base.ExportedMetadata, process: rdf_client.Process + ) -> Iterator[ExportedOpenFile]: """Converts a Process into a ExportedOpenFile. 
Args: diff --git a/grr/server/grr_response_server/export_converters/process_test.py b/grr/server/grr_response_server/export_converters/process_test.py index 0e45257f0b..adbf0941f0 100644 --- a/grr/server/grr_response_server/export_converters/process_test.py +++ b/grr/server/grr_response_server/export_converters/process_test.py @@ -16,7 +16,8 @@ def testBasicConversion(self): ppid=1, cmdline=["cmd.exe"], exe="c:\\windows\\cmd.exe", - ctime=1333718907167083) + ctime=1333718907167083, + ) converter = process.ProcessToExportedProcessConverter() results = list(converter.Convert(self.metadata, proc)) @@ -38,7 +39,8 @@ def testBasicConversion(self): cmdline=["cmd.exe"], exe="c:\\windows\\cmd.exe", ctime=1333718907167083, - open_files=["/some/a", "/some/b"]) + open_files=["/some/a", "/some/b"], + ) converter = process.ProcessToExportedOpenFileConverter() results = list(converter.Convert(self.metadata, proc)) @@ -51,7 +53,8 @@ def testBasicConversion(self): class ProcessToExportedNetworkConnectionConverterTest( - export_test_lib.ExportTestBase): + export_test_lib.ExportTestBase +): def testBasicConversion(self): conn1 = rdf_client_network.NetworkConnection( @@ -60,16 +63,20 @@ def testBasicConversion(self): local_address=rdf_client_network.NetworkEndpoint(ip="0.0.0.0", port=22), remote_address=rdf_client_network.NetworkEndpoint(ip="0.0.0.0", port=0), pid=2136, - ctime=0) + ctime=0, + ) conn2 = rdf_client_network.NetworkConnection( state=rdf_client_network.NetworkConnection.State.LISTEN, type=rdf_client_network.NetworkConnection.Type.SOCK_STREAM, local_address=rdf_client_network.NetworkEndpoint( - ip="192.168.1.1", port=31337), + ip="192.168.1.1", port=31337 + ), remote_address=rdf_client_network.NetworkEndpoint( - ip="1.2.3.4", port=6667), + ip="1.2.3.4", port=6667 + ), pid=1, - ctime=0) + ctime=0, + ) proc = rdf_client.Process( pid=2, @@ -77,16 +84,19 @@ def testBasicConversion(self): cmdline=["cmd.exe"], exe="c:\\windows\\cmd.exe", ctime=1333718907167083, - connections=[conn1, conn2]) + connections=[conn1, conn2], + ) converter = process.ProcessToExportedNetworkConnectionConverter() results = list(converter.Convert(self.metadata, proc)) self.assertLen(results, 2) - self.assertEqual(results[0].state, - rdf_client_network.NetworkConnection.State.LISTEN) - self.assertEqual(results[0].type, - rdf_client_network.NetworkConnection.Type.SOCK_STREAM) + self.assertEqual( + results[0].state, rdf_client_network.NetworkConnection.State.LISTEN + ) + self.assertEqual( + results[0].type, rdf_client_network.NetworkConnection.Type.SOCK_STREAM + ) self.assertEqual(results[0].local_address.ip, "0.0.0.0") self.assertEqual(results[0].local_address.port, 22) self.assertEqual(results[0].remote_address.ip, "0.0.0.0") @@ -94,10 +104,12 @@ def testBasicConversion(self): self.assertEqual(results[0].pid, 2136) self.assertEqual(results[0].ctime, 0) - self.assertEqual(results[1].state, - rdf_client_network.NetworkConnection.State.LISTEN) - self.assertEqual(results[1].type, - rdf_client_network.NetworkConnection.Type.SOCK_STREAM) + self.assertEqual( + results[1].state, rdf_client_network.NetworkConnection.State.LISTEN + ) + self.assertEqual( + results[1].type, rdf_client_network.NetworkConnection.Type.SOCK_STREAM + ) self.assertEqual(results[1].local_address.ip, "192.168.1.1") self.assertEqual(results[1].local_address.port, 31337) self.assertEqual(results[1].remote_address.ip, "1.2.3.4") diff --git a/grr/server/grr_response_server/export_converters/rdf_dict.py b/grr/server/grr_response_server/export_converters/rdf_dict.py 
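Background for the rdf_dict.py hunk below: DictToExportedDictItemsConverter relies on _IterateDict to flatten arbitrarily nested dicts and lists into (key, value) pairs before export; the hunk only reflows its signature. As an illustrative aside, here is a generic flattener under stated assumptions — the key-joining scheme ("a.b[0]") is invented for the example and is not necessarily what _IterateDict produces:

```python
from typing import Any, Iterator, Tuple


def iterate_nested(d: Any, key: str = "") -> Iterator[Tuple[str, Any]]:
  """Yields (flattened_key, leaf_value) pairs for nested dicts/lists.

  The key format used here is an assumption made for this sketch only.
  """
  if isinstance(d, (list, tuple)):
    for i, v in enumerate(d):
      yield from iterate_nested(v, f"{key}[{i}]" if key else f"[{i}]")
  elif isinstance(d, dict):
    for k, v in sorted(d.items()):
      yield from iterate_nested(v, f"{key}.{k}" if key else str(k))
  else:
    yield key, d


# Example: {"a": {"b": [1, 2]}} -> [("a.b[0]", 1), ("a.b[1]", 2)]
```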
index 4bdb7ffd60..b8173db5d8 100644 --- a/grr/server/grr_response_server/export_converters/rdf_dict.py +++ b/grr/server/grr_response_server/export_converters/rdf_dict.py @@ -19,9 +19,9 @@ class DictToExportedDictItemsConverter(base.ExportConverter): input_rdf_type = rdf_protodict.Dict - def _IterateDict(self, - d: Dict[str, Any], - key: str = "") -> Iterator[Tuple[str, Any]]: + def _IterateDict( + self, d: Dict[str, Any], key: str = "" + ) -> Iterator[Tuple[str, Any]]: """Performs a deeply-nested iteration of a given dictionary.""" if isinstance(d, (list, tuple)): for i, v in enumerate(d): @@ -48,8 +48,9 @@ def _IterateDict(self, else: yield key, d - def Convert(self, metadata: base.ExportedMetadata, - data: rdf_protodict.Dict) -> List[ExportedDictItem]: + def Convert( + self, metadata: base.ExportedMetadata, data: rdf_protodict.Dict + ) -> List[ExportedDictItem]: result = [] d = data.ToDict() for k, v in self._IterateDict(d): diff --git a/grr/server/grr_response_server/export_converters/rdf_primitives.py b/grr/server/grr_response_server/export_converters/rdf_primitives.py index 994e9c1d5b..d2a9ab8927 100644 --- a/grr/server/grr_response_server/export_converters/rdf_primitives.py +++ b/grr/server/grr_response_server/export_converters/rdf_primitives.py @@ -28,8 +28,9 @@ class RDFBytesToExportedBytesConverter(base.ExportConverter): input_rdf_type = rdfvalue.RDFBytes - def Convert(self, metadata: base.ExportedMetadata, - data: rdfvalue.RDFBytes) -> List[ExportedBytes]: + def Convert( + self, metadata: base.ExportedMetadata, data: rdfvalue.RDFBytes + ) -> List[ExportedBytes]: """Converts a RDFBytes into a ExportedNetworkConnection. Args: @@ -41,7 +42,8 @@ def Convert(self, metadata: base.ExportedMetadata, """ result = ExportedBytes( - metadata=metadata, data=data.SerializeToBytes(), length=len(data)) + metadata=metadata, data=data.SerializeToBytes(), length=len(data) + ) return [result] @@ -50,8 +52,9 @@ class RDFStringToExportedStringConverter(base.ExportConverter): input_rdf_type = rdfvalue.RDFString - def Convert(self, metadata: base.ExportedMetadata, - data: rdfvalue.RDFString) -> List[ExportedString]: + def Convert( + self, metadata: base.ExportedMetadata, data: rdfvalue.RDFString + ) -> List[ExportedString]: """Converts a RDFString into a ExportedString. Args: diff --git a/grr/server/grr_response_server/export_converters/registry_init.py b/grr/server/grr_response_server/export_converters/registry_init.py index 416bc06db2..dcd3c71df8 100644 --- a/grr/server/grr_response_server/export_converters/registry_init.py +++ b/grr/server/grr_response_server/export_converters/registry_init.py @@ -22,45 +22,57 @@ # TODO: Test that this function contains all inheritors. 
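For context on the registration function below: converters are registered once at startup and later resolved by the type of the value being exported — the export_converters_registry.GetConvertersByValue call seen in the grr_message.py hunk earlier. A hedged sketch of that flow, reusing a converter from this diff (the registry import path is assumed, not shown in the patch):

```python
from grr_response_core.lib import rdfvalue
from grr_response_server import export_converters_registry  # import path assumed
from grr_response_server.export_converters import base
from grr_response_server.export_converters import rdf_primitives

# Module-level setup; RegisterExportConverters() below does exactly this for
# every converter in the package.
export_converters_registry.Register(
    rdf_primitives.RDFStringToExportedStringConverter
)

# Callers resolve converter classes by the value they want to export rather
# than hard-coding classes; GrrMessageConverter.BatchConvert uses this lookup.
converter_classes = export_converters_registry.GetConvertersByValue(
    rdfvalue.RDFString("hello")
)
converters = [cls(base.ExportOptions()) for cls in converter_classes]
```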
def RegisterExportConverters(): """Registers all ExportConverters.""" - # keep-sorted start export_converters_registry.Register( - buffer_reference.BufferReferenceToExportedMatchConverter) + buffer_reference.BufferReferenceToExportedMatchConverter + ) export_converters_registry.Register( - client_summary.ClientSummaryToExportedClientConverter) + client_summary.ClientSummaryToExportedClientConverter + ) export_converters_registry.Register( - client_summary.ClientSummaryToExportedNetworkInterfaceConverter) + client_summary.ClientSummaryToExportedNetworkInterfaceConverter + ) export_converters_registry.Register(cron_tab_file.CronTabFileConverter) export_converters_registry.Register(execute_response.ExecuteResponseConverter) export_converters_registry.Register( - file.ArtifactFilesDownloaderResultConverter) + file.ArtifactFilesDownloaderResultConverter + ) export_converters_registry.Register(file.FileFinderResultConverter) export_converters_registry.Register(file.StatEntryToExportedFileConverter) export_converters_registry.Register( - file.StatEntryToExportedRegistryKeyConverter) + file.StatEntryToExportedRegistryKeyConverter + ) export_converters_registry.Register(grr_message.GrrMessageConverter) export_converters_registry.Register(launchd_plist.LaunchdPlistConverter) export_converters_registry.Register(memory.ProcessMemoryErrorConverter) export_converters_registry.Register(memory.YaraProcessScanMatchConverter) export_converters_registry.Register( - network.DNSClientConfigurationToExportedDNSClientConfiguration) + network.DNSClientConfigurationToExportedDNSClientConfiguration + ) export_converters_registry.Register( - network.InterfaceToExportedNetworkInterfaceConverter) + network.InterfaceToExportedNetworkInterfaceConverter + ) export_converters_registry.Register( - network.NetworkConnectionToExportedNetworkConnectionConverter) + network.NetworkConnectionToExportedNetworkConnectionConverter + ) export_converters_registry.Register(osquery.OsqueryExportConverter) export_converters_registry.Register( - process.ProcessToExportedNetworkConnectionConverter) + process.ProcessToExportedNetworkConnectionConverter + ) export_converters_registry.Register( - process.ProcessToExportedOpenFileConverter) + process.ProcessToExportedOpenFileConverter + ) export_converters_registry.Register(process.ProcessToExportedProcessConverter) export_converters_registry.Register(rdf_dict.DictToExportedDictItemsConverter) export_converters_registry.Register( - rdf_primitives.RDFBytesToExportedBytesConverter) + rdf_primitives.RDFBytesToExportedBytesConverter + ) export_converters_registry.Register( - rdf_primitives.RDFStringToExportedStringConverter) + rdf_primitives.RDFStringToExportedStringConverter + ) export_converters_registry.Register(software_package.SoftwarePackageConverter) export_converters_registry.Register( - software_package.SoftwarePackagesConverter) + software_package.SoftwarePackagesConverter + ) export_converters_registry.Register( - windows_service_info.WindowsServiceInformationConverter) - # keep-sorted end + windows_service_info.WindowsServiceInformationConverter + ) diff --git a/grr/server/grr_response_server/export_converters/software_package.py b/grr/server/grr_response_server/export_converters/software_package.py index 1b14c16ee0..5677e547ef 100644 --- a/grr/server/grr_response_server/export_converters/software_package.py +++ b/grr/server/grr_response_server/export_converters/software_package.py @@ -22,19 +22,24 @@ class SoftwarePackageConverter(base.ExportConverter): input_rdf_type = 
rdf_client.SoftwarePackage _INSTALL_STATE_MAP = { - rdf_client.SoftwarePackage.InstallState.INSTALLED: - ExportedSoftwarePackage.InstallState.INSTALLED, - rdf_client.SoftwarePackage.InstallState.PENDING: - ExportedSoftwarePackage.InstallState.PENDING, - rdf_client.SoftwarePackage.InstallState.UNINSTALLED: - ExportedSoftwarePackage.InstallState.UNINSTALLED, - rdf_client.SoftwarePackage.InstallState.UNKNOWN: + rdf_client.SoftwarePackage.InstallState.INSTALLED: ( + ExportedSoftwarePackage.InstallState.INSTALLED + ), + rdf_client.SoftwarePackage.InstallState.PENDING: ( + ExportedSoftwarePackage.InstallState.PENDING + ), + rdf_client.SoftwarePackage.InstallState.UNINSTALLED: ( + ExportedSoftwarePackage.InstallState.UNINSTALLED + ), + rdf_client.SoftwarePackage.InstallState.UNKNOWN: ( ExportedSoftwarePackage.InstallState.UNKNOWN + ), } def Convert( - self, metadata: base.ExportedMetadata, - software_package: rdf_client.SoftwarePackage + self, + metadata: base.ExportedMetadata, + software_package: rdf_client.SoftwarePackage, ) -> Iterator[ExportedSoftwarePackage]: yield ExportedSoftwarePackage( metadata=metadata, @@ -45,7 +50,8 @@ def Convert( install_state=self._INSTALL_STATE_MAP[software_package.install_state], description=software_package.description, installed_on=software_package.installed_on, - installed_by=software_package.installed_by) + installed_by=software_package.installed_by, + ) class SoftwarePackagesConverter(base.ExportConverter): @@ -54,8 +60,9 @@ class SoftwarePackagesConverter(base.ExportConverter): input_rdf_type = rdf_client.SoftwarePackages def Convert( - self, metadata: base.ExportedMetadata, - software_packages: rdf_client.SoftwarePackages + self, + metadata: base.ExportedMetadata, + software_packages: rdf_client.SoftwarePackages, ) -> Iterator[ExportedSoftwarePackage]: conv = SoftwarePackageConverter(options=self.options) for p in software_packages.packages: diff --git a/grr/server/grr_response_server/export_converters/software_package_test.py b/grr/server/grr_response_server/export_converters/software_package_test.py index 76970f2f3f..aa1d2f44bb 100644 --- a/grr/server/grr_response_server/export_converters/software_package_test.py +++ b/grr/server/grr_response_server/export_converters/software_package_test.py @@ -17,7 +17,8 @@ def testConvertsCorrectly(self): publisher="somebody", description="desc", installed_on=42, - installed_by="user") + installed_by="user", + ) converter = software_package.SoftwarePackageConverter() converted = list(converter.Convert(self.metadata, result)) @@ -31,11 +32,12 @@ def testConvertsCorrectly(self): version="ver1", architecture="i386", publisher="somebody", - install_state=software_package.ExportedSoftwarePackage.InstallState - .PENDING, + install_state=software_package.ExportedSoftwarePackage.InstallState.PENDING, description="desc", installed_on=42, - installed_by="user")) + installed_by="user", + ), + ) class SoftwarePackagesConverterTest(export_test_lib.ExportTestBase): @@ -51,7 +53,9 @@ def testConvertsCorrectly(self): publisher="somebody_%d" % i, description="desc_%d" % i, installed_on=42 + i, - installed_by="user_%d" % i)) + installed_by="user_%d" % i, + ) + ) converter = software_package.SoftwarePackagesConverter() converted = list(converter.Convert(self.metadata, result)) @@ -66,11 +70,12 @@ def testConvertsCorrectly(self): version="ver_%d" % i, architecture="i386_%d" % i, publisher="somebody_%d" % i, - install_state=software_package.ExportedSoftwarePackage - .InstallState.PENDING, + 
install_state=software_package.ExportedSoftwarePackage.InstallState.PENDING, description="desc_%d" % i, installed_on=42 + i, - installed_by="user_%d" % i)) + installed_by="user_%d" % i, + ), + ) def main(argv): diff --git a/grr/server/grr_response_server/export_converters/windows_service_info.py b/grr/server/grr_response_server/export_converters/windows_service_info.py index c92b8785f4..e63cc2d757 100644 --- a/grr/server/grr_response_server/export_converters/windows_service_info.py +++ b/grr/server/grr_response_server/export_converters/windows_service_info.py @@ -20,8 +20,9 @@ class WindowsServiceInformationConverter(base.ExportConverter): input_rdf_type = rdf_client.WindowsServiceInformation def Convert( - self, metadata: base.ExportedMetadata, - i: rdf_client.WindowsServiceInformation + self, + metadata: base.ExportedMetadata, + i: rdf_client.WindowsServiceInformation, ) -> Iterator[ExportedWindowsServiceInformation]: wmi_components = [] for key in sorted(i.wmi_information.keys()): diff --git a/grr/server/grr_response_server/export_converters/windows_service_info_test.py b/grr/server/grr_response_server/export_converters/windows_service_info_test.py index 554f8a32b3..617b1a0c36 100644 --- a/grr/server/grr_response_server/export_converters/windows_service_info_test.py +++ b/grr/server/grr_response_server/export_converters/windows_service_info_test.py @@ -15,19 +15,14 @@ def testExportsValueCoorrectly(self): name="foo", description="bar", state="somestate", - wmi_information={ - "c": "d", - "a": "b" - }, + wmi_information={"c": "d", "a": "b"}, display_name="some name", driver_package_id="1234", error_control=rdf_client.WindowsServiceInformation.ErrorControl.NORMAL, image_path="/foo/bar", object_name="an object", - startup_type=rdf_client.WindowsServiceInformation.ServiceMode - .SERVICE_AUTO_START, - service_type=rdf_client.WindowsServiceInformation.ServiceType - .SERVICE_FILE_SYSTEM_DRIVER, + startup_type=rdf_client.WindowsServiceInformation.ServiceMode.SERVICE_AUTO_START, + service_type=rdf_client.WindowsServiceInformation.ServiceType.SERVICE_FILE_SYSTEM_DRIVER, group_name="somegroup", service_dll="somedll", registry_key="somekey", @@ -39,7 +34,8 @@ def testExportsValueCoorrectly(self): c = converted[0] self.assertIsInstance( - c, windows_service_info.ExportedWindowsServiceInformation) + c, windows_service_info.ExportedWindowsServiceInformation + ) self.assertEqual(c.metadata, self.metadata) self.assertEqual(c.name, "foo") diff --git a/grr/server/grr_response_server/export_test.py b/grr/server/grr_response_server/export_test.py index 2b482c820c..66c5b3aeae 100644 --- a/grr/server/grr_response_server/export_test.py +++ b/grr/server/grr_response_server/export_test.py @@ -9,7 +9,6 @@ from grr_response_server import export from grr_response_server.export_converters import base from grr_response_server.rdfvalues import mig_objects -from grr_response_server.rdfvalues import objects as rdf_objects from grr.test_lib import export_test_lib from grr.test_lib import fixture_test_lib from grr.test_lib import test_lib @@ -229,15 +228,6 @@ def testGetMetadataWithAmazonCloudInstanceID(self): metadata.CloudInstanceType.AMAZON) self.assertEqual(metadata.cloud_instance_id, "foo/bar") - def testGetMetadataUname(self): - info = rdf_objects.ClientFullInfo() - info.last_snapshot.knowledge_base.os = "Linux" - info.last_snapshot.os_release = "1.0RC4" - info.last_snapshot.os_version = "13.37" - - metadata = export.GetMetadata(self.client_id, info) - self.assertEqual(metadata.uname, "Linux-1.0RC4-13.37") - def 
main(argv): test_lib.main(argv) diff --git a/grr/server/grr_response_server/file_store.py b/grr/server/grr_response_server/file_store.py index d41551585b..122f80674c 100644 --- a/grr/server/grr_response_server/file_store.py +++ b/grr/server/grr_response_server/file_store.py @@ -369,8 +369,13 @@ def GetLastCollectionPathInfos(client_paths, max_timestamp=None): timestamp lower or equal then max_timestamp). """ - return data_store.REL_DB.ReadLatestPathInfosWithHashBlobReferences( - client_paths, max_timestamp=max_timestamp) + proto_dict = data_store.REL_DB.ReadLatestPathInfosWithHashBlobReferences( + client_paths, max_timestamp=max_timestamp + ) + rdf_dict = {} + for k, v in proto_dict.items(): + rdf_dict[k] = mig_objects.ToRDFPathInfo(v) if v is not None else None + return rdf_dict def GetLastCollectionPathInfo(client_path, max_timestamp=None): @@ -411,8 +416,12 @@ def OpenFile( MissingBlobReferencesError: if one of the blobs was not found. """ - path_info = data_store.REL_DB.ReadLatestPathInfosWithHashBlobReferences( - [client_path], max_timestamp=max_timestamp)[client_path] + proto_path_info = data_store.REL_DB.ReadLatestPathInfosWithHashBlobReferences( + [client_path], max_timestamp=max_timestamp + )[client_path] + path_info = None + if proto_path_info: + path_info = mig_objects.ToRDFPathInfo(proto_path_info) if path_info is None: # If path_info returned by ReadLatestPathInfosWithHashBlobReferences @@ -491,9 +500,16 @@ def StreamFilesChunks(client_paths, max_timestamp=None, max_size=None): BlobNotFoundError: if one of the blobs wasn't found while streaming. """ - path_infos_by_cp = ( + proto_path_infos_by_cp = ( data_store.REL_DB.ReadLatestPathInfosWithHashBlobReferences( - client_paths, max_timestamp=max_timestamp)) + client_paths, max_timestamp=max_timestamp + ) + ) + path_infos_by_cp = {} + for k, v in proto_path_infos_by_cp.items(): + path_infos_by_cp[k] = None + if v is not None: + path_infos_by_cp[k] = mig_objects.ToRDFPathInfo(v) hash_ids_by_cp = {} for cp, pi in path_infos_by_cp.items(): diff --git a/grr/server/grr_response_server/file_store_test.py b/grr/server/grr_response_server/file_store_test.py index e27c7941b4..4810b5f1ac 100644 --- a/grr/server/grr_response_server/file_store_test.py +++ b/grr/server/grr_response_server/file_store_test.py @@ -361,8 +361,10 @@ def _PathInfo(self, hash_id=None): return pi def testOpensFileWithSinglePathInfoWithHash(self): - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.hash_id))], + ) fd = file_store.OpenFile(self.client_path) self.assertEqual(fd.read(), self.data) @@ -371,61 +373,85 @@ def testRaisesForNonExistentFile(self): file_store.OpenFile(self.client_path) def testRaisesForFileWithSinglePathInfoWithoutHash(self): - data_store.REL_DB.WritePathInfos(self.client_id, [self._PathInfo()]) + data_store.REL_DB.WritePathInfos( + self.client_id, [mig_objects.ToProtoPathInfo(self._PathInfo())] + ) with self.assertRaises(file_store.FileHasNoContentError): file_store.OpenFile(self.client_path) def testRaisesForFileWithSinglePathInfoWithUnknownHash(self): - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.invalid_hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.invalid_hash_id))], + ) with self.assertRaises(file_store.FileHasNoContentError): file_store.OpenFile(self.client_path) def 
testOpensFileWithTwoPathInfosWhereOldestHasHash(self): # Oldest. - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.hash_id))], + ) # Newest. - data_store.REL_DB.WritePathInfos(self.client_id, [self._PathInfo()]) + data_store.REL_DB.WritePathInfos( + self.client_id, [mig_objects.ToProtoPathInfo(self._PathInfo())] + ) fd = file_store.OpenFile(self.client_path) self.assertEqual(fd.read(), self.data) def testOpensFileWithTwoPathInfosWhereNewestHasHash(self): # Oldest. - data_store.REL_DB.WritePathInfos(self.client_id, [self._PathInfo()]) + data_store.REL_DB.WritePathInfos( + self.client_id, [mig_objects.ToProtoPathInfo(self._PathInfo())] + ) # Newest. - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.hash_id))], + ) fd = file_store.OpenFile(self.client_path) self.assertEqual(fd.read(), self.data) def testOpensFileWithTwoPathInfosWhereOldestHashIsUnknown(self): # Oldest. - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.invalid_hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.invalid_hash_id))], + ) # Newest. - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.hash_id))], + ) fd = file_store.OpenFile(self.client_path) self.assertEqual(fd.read(), self.data) def testOpensFileWithTwoPathInfosWhereNewestHashIsUnknown(self): # Oldest. - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.hash_id))], + ) # Newest. - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.invalid_hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.invalid_hash_id))], + ) fd = file_store.OpenFile(self.client_path) self.assertEqual(fd.read(), self.data) def testOpensLatestVersionForPathWithTwoPathInfosWithHashes(self): # Oldest. - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.other_hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.other_hash_id))], + ) # Newest. - data_store.REL_DB.WritePathInfos(self.client_id, - [self._PathInfo(self.hash_id)]) + data_store.REL_DB.WritePathInfos( + self.client_id, + [mig_objects.ToProtoPathInfo(self._PathInfo(self.hash_id))], + ) fd = file_store.OpenFile(self.client_path) self.assertEqual(fd.read(), self.data) @@ -476,7 +502,9 @@ def testRaisesIfSingleFileChunkIsMissing(self): client_path = db.ClientPath.OS(self.client_id, ("foo", "bar")) path_info = rdf_objects.PathInfo.OS(components=client_path.components) path_info.hash_entry.sha256 = hash_id.AsBytes() - data_store.REL_DB.WritePathInfos(client_path.client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + client_path.client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) # Just getting the generator doesn't raise. 
chunks = file_store.StreamFilesChunks([client_path]) diff --git a/grr/server/grr_response_server/flow.py b/grr/server/grr_response_server/flow.py index 3f71e81237..206257b615 100644 --- a/grr/server/grr_response_server/flow.py +++ b/grr/server/grr_response_server/flow.py @@ -305,7 +305,8 @@ def StartFlow(client_id=None, # database doesn't raise consistency errors due to missing parent keys when # writing logs / errors / results which might happen in Start(). try: - data_store.REL_DB.WriteFlowObject(flow_obj.rdf_flow, allow_update=False) + proto_flow = mig_flow_objects.ToProtoFlow(rdf_flow) + data_store.REL_DB.WriteFlowObject(proto_flow, allow_update=False) except db.FlowExistsError: raise CanNotStartFlowWithExistingIdError(client_id, rdf_flow.flow_id) @@ -337,8 +338,8 @@ def StartFlow(client_id=None, flow_obj.PersistState() try: - data_store.REL_DB.WriteFlowObject( - flow_obj.rdf_flow, allow_update=allow_update) + proto_flow = mig_flow_objects.ToProtoFlow(rdf_flow) + data_store.REL_DB.WriteFlowObject(proto_flow, allow_update=allow_update) except db.FlowExistsError: raise CanNotStartFlowWithExistingIdError(client_id, rdf_flow.flow_id) diff --git a/grr/server/grr_response_server/flow_base.py b/grr/server/grr_response_server/flow_base.py index 9449ae1243..c2b2c7fca4 100644 --- a/grr/server/grr_response_server/flow_base.py +++ b/grr/server/grr_response_server/flow_base.py @@ -5,18 +5,20 @@ import logging import re import traceback -from typing import Any, Callable, Collection, Iterator, Mapping, NamedTuple, Optional, Sequence, Tuple, Type +from typing import Any, Callable, Collection, Iterator, Mapping, NamedTuple, Optional, Sequence, Tuple, Type, Union from google.protobuf import any_pb2 from google.protobuf import message as pb_message from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import flows as rdf_flows +from grr_response_core.lib.rdfvalues import mig_protodict from grr_response_core.lib.rdfvalues import protodict as rdf_protodict from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_core.lib.registry import FlowRegistry from grr_response_core.stats import metrics from grr_response_proto import flows_pb2 +from grr_response_proto import jobs_pb2 from grr_response_server import access_control from grr_response_server import action_registry from grr_response_server import data_store @@ -30,8 +32,9 @@ from grr_response_server import server_stubs from grr_response_server.databases import db from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects -from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner from grr_response_server.rdfvalues import mig_flow_objects +from grr_response_server.rdfvalues import mig_flow_runner +from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr_response_proto import rrg_pb2 @@ -249,18 +252,8 @@ def CallState( ) ) nr_responses_expected = len(responses) + 1 - # No need to set needs_processing to True as we write the status - # message above. WriteFlowResponses implementation on all DBs - # will automatically issue a FlowProcessingRequest upon receiving a - # status message and seeing that all needed messages are in the - # database. 
Setting needs_processing to True here might lead to a race - # condition when a FlowProcessingRequest will be written and processed - # before FlowResponses corresponding to the FlowRequest are going to - # be written to the DB. - needs_processing = False else: nr_responses_expected = 0 - needs_processing = True flow_request = rdf_flow_objects.FlowRequest( client_id=self.rdf_flow.client_id, @@ -269,7 +262,7 @@ def CallState( next_state=next_state, start_time=start_time, nr_responses_expected=nr_responses_expected, - needs_processing=needs_processing, + needs_processing=True, ) self.flow_requests.append(flow_request) @@ -738,15 +731,24 @@ def Log(self, format_str: str, *args: object) -> None: def RunStateMethod( self, method_name: str, - request: Optional[rdf_flow_runner.RequestState] = None, - responses: Optional[Sequence[rdf_flow_objects.FlowMessage]] = None + request: Optional[rdf_flow_objects.FlowRequest] = None, + responses: Optional[ + Sequence[ + Union[ + rdf_flow_objects.FlowResponse, + rdf_flow_objects.FlowStatus, + rdf_flow_objects.FlowIterator, + ] + ] + ] = None, ) -> None: """Completes the request by calling the state method. Args: method_name: The name of the state method to call. request: A RequestState protobuf. - responses: A list of FlowMessages responding to the request. + responses: A list of FlowResponses, FlowStatuses, and FlowIterators + responding to the request. Raises: FlowError: Processing time for the flow has expired. @@ -823,7 +825,8 @@ def ProcessAllReadyRequests(self) -> Tuple[int, int]: request_dict = data_store.REL_DB.ReadFlowRequestsReadyForProcessing( self.rdf_flow.client_id, self.rdf_flow.flow_id, - next_needed_request=self.rdf_flow.next_request_to_process) + next_needed_request=self.rdf_flow.next_request_to_process, + ) if not request_dict: return (0, 0) @@ -836,6 +839,7 @@ def ProcessAllReadyRequests(self) -> Tuple[int, int]: # response is kept in request's 'next_response_id' attribute to guarantee # that responses are going to be processed in the right order. 
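# Illustrative sketch (helper name is hypothetical) of the conversion pattern
# applied in the hunks below: the database now hands back flows_pb2 response
# protos, and each one is mapped onto its RDF counterpart before the state
# method runs. Only classes and converters already referenced in this change
# are assumed.
from typing import Union

from grr_response_proto import flows_pb2
from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects
from grr_response_server.rdfvalues import mig_flow_objects


def _ToRDFFlowMessage(
    response: Union[
        flows_pb2.FlowResponse, flows_pb2.FlowStatus, flows_pb2.FlowIterator
    ],
) -> rdf_flow_objects.FlowMessage:
  """Maps a single flow response proto onto its RDF counterpart."""
  if isinstance(response, flows_pb2.FlowResponse):
    return mig_flow_objects.ToRDFFlowResponse(response)
  if isinstance(response, flows_pb2.FlowStatus):
    return mig_flow_objects.ToRDFFlowStatus(response)
  if isinstance(response, flows_pb2.FlowIterator):
    return mig_flow_objects.ToRDFFlowIterator(response)
  raise TypeError(f"Unexpected flow response type: {type(response)}")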
for request_id, (request, responses) in list(request_dict.items()): + request = mig_flow_objects.ToRDFFlowRequest(request) if not self.IsRunning(): break @@ -848,7 +852,15 @@ def ProcessAllReadyRequests(self) -> Tuple[int, int]: to_process = [] for r in responses: if r.response_id == next_response_id: - to_process.append(r) + if isinstance(r, flows_pb2.FlowResponse): + r = mig_flow_objects.ToRDFFlowResponse(r) + to_process.append(r) + if isinstance(r, flows_pb2.FlowStatus): + r = mig_flow_objects.ToRDFFlowStatus(r) + to_process.append(r) + if isinstance(r, flows_pb2.FlowIterator): + r = mig_flow_objects.ToRDFFlowIterator(r) + to_process.append(r) else: break next_response_id += 1 @@ -884,7 +896,21 @@ def ProcessAllReadyRequests(self) -> Tuple[int, int]: while (self.IsRunning() and self.rdf_flow.next_request_to_process in request_dict): - request, responses = request_dict[self.rdf_flow.next_request_to_process] + request, responses_proto = request_dict[ + self.rdf_flow.next_request_to_process + ] + request = mig_flow_objects.ToRDFFlowRequest(request) + responses = [] + for r in responses_proto: + if isinstance(r, flows_pb2.FlowResponse): + r = mig_flow_objects.ToRDFFlowResponse(r) + responses.append(r) + if isinstance(r, flows_pb2.FlowStatus): + r = mig_flow_objects.ToRDFFlowStatus(r) + responses.append(r) + if isinstance(r, flows_pb2.FlowIterator): + r = mig_flow_objects.ToRDFFlowIterator(r) + responses.append(r) if request.needs_processing: self.RunStateMethod(request.next_state, request, responses) self.rdf_flow.next_request_to_process += 1 @@ -935,11 +961,22 @@ def FlushQueuedMessages(self) -> None: # optimizing. if self.flow_requests: - data_store.REL_DB.WriteFlowRequests(self.flow_requests) + flow_requests_proto = [ + mig_flow_objects.ToProtoFlowRequest(r) for r in self.flow_requests + ] + data_store.REL_DB.WriteFlowRequests(flow_requests_proto) self.flow_requests = [] if self.flow_responses: - data_store.REL_DB.WriteFlowResponses(self.flow_responses) + flow_responses_proto = [] + for r in self.flow_responses: + if isinstance(r, rdf_flow_objects.FlowResponse): + flow_responses_proto.append(mig_flow_objects.ToProtoFlowResponse(r)) + if isinstance(r, rdf_flow_objects.FlowStatus): + flow_responses_proto.append(mig_flow_objects.ToProtoFlowStatus(r)) + if isinstance(r, rdf_flow_objects.FlowIterator): + flow_responses_proto.append(mig_flow_objects.ToProtoFlowIterator(r)) + data_store.REL_DB.WriteFlowResponses(flow_responses_proto) self.flow_responses = [] if self.client_action_requests: @@ -955,7 +992,21 @@ def FlushQueuedMessages(self) -> None: self.rrg_requests = [] if self.completed_requests: - data_store.REL_DB.DeleteFlowRequests(self.completed_requests) + completed_requests_protos = [] + for r in self.completed_requests: + if isinstance(r, flows_pb2.FlowResponse): + completed_requests_protos.append( + mig_flow_objects.ToProtoFlowResponse(r) + ) + if isinstance(r, flows_pb2.FlowStatus): + completed_requests_protos.append( + mig_flow_objects.ToProtoFlowStatus(r) + ) + if isinstance(r, flows_pb2.FlowIterator): + completed_requests_protos.append( + mig_flow_objects.ToProtoFlowIterator(r) + ) + data_store.REL_DB.DeleteFlowRequests(completed_requests_protos) self.completed_requests = [] if self.replies_to_write: @@ -983,9 +1034,15 @@ def _ProcessRepliesWithHuntOutputPlugins( ) -> None: """Applies output plugins to hunt results.""" hunt_obj = data_store.REL_DB.ReadHuntObject(self.rdf_flow.parent_hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) self.rdf_flow.output_plugins = 
hunt_obj.output_plugins hunt_output_plugins_states = data_store.REL_DB.ReadHuntOutputPluginsStates( self.rdf_flow.parent_hunt_id) + + hunt_output_plugins_states = [ + mig_flow_runner.ToRDFOutputPluginState(s) + for s in hunt_output_plugins_states + ] self.rdf_flow.output_plugins_states = hunt_output_plugins_states created_plugins = self._ProcessRepliesWithFlowOutputPlugins(replies) @@ -1000,8 +1057,12 @@ def _ProcessRepliesWithHuntOutputPlugins( plugin.UpdateState(s) if s != state.plugin_state: - def UpdateFn(plugin_state): - plugin.UpdateState(plugin_state) # pylint: disable=cell-var-from-loop + def UpdateFn( + plugin_state: jobs_pb2.AttributedDict, + ) -> jobs_pb2.AttributedDict: + plugin_state_rdf = mig_protodict.ToRDFAttributedDict(plugin_state) + plugin.UpdateState(plugin_state_rdf) # pylint: disable=cell-var-from-loop + plugin_state = mig_protodict.ToProtoAttributedDict(plugin_state_rdf) return plugin_state data_store.REL_DB.UpdateHuntOutputPluginState(hunt_obj.hunt_id, index, @@ -1264,13 +1325,13 @@ def Wrapper(self, responses: flow_responses.Responses) -> None: def _TerminateFlow( - rdf_flow: rdf_flow_objects.Flow, + proto_flow: flows_pb2.Flow, reason: Optional[str] = None, - flow_state: rdf_structs.EnumNamedValue = rdf_flow_objects.Flow.FlowState - .ERROR + flow_state: rdf_structs.EnumNamedValue = rdf_flow_objects.Flow.FlowState.ERROR, ) -> None: """Does the actual termination.""" - flow_cls = FlowRegistry.FlowClassByName(rdf_flow.flow_class_name) + flow_cls = FlowRegistry.FlowClassByName(proto_flow.flow_class_name) + rdf_flow = mig_flow_objects.ToRDFFlow(proto_flow) flow_obj = flow_cls(rdf_flow) if not flow_obj.IsRunning(): @@ -1283,16 +1344,18 @@ def _TerminateFlow( rdf_flow.flow_state = flow_state rdf_flow.error_message = reason flow_obj.NotifyCreatorOfError() - + proto_flow = mig_flow_objects.ToProtoFlow(rdf_flow) data_store.REL_DB.UpdateFlow( - rdf_flow.client_id, - rdf_flow.flow_id, - flow_obj=rdf_flow, + proto_flow.client_id, + proto_flow.flow_id, + flow_obj=proto_flow, processing_on=None, processing_since=None, - processing_deadline=None) - data_store.REL_DB.DeleteAllFlowRequestsAndResponses(rdf_flow.client_id, - rdf_flow.flow_id) + processing_deadline=None, + ) + data_store.REL_DB.DeleteAllFlowRequestsAndResponses( + proto_flow.client_id, proto_flow.flow_id + ) def TerminateFlow( @@ -1316,9 +1379,11 @@ def TerminateFlow( while to_terminate: next_to_terminate = [] - for rdf_flow in to_terminate: - _TerminateFlow(rdf_flow, reason=reason, flow_state=flow_state) + for proto_flow in to_terminate: + _TerminateFlow(proto_flow, reason=reason, flow_state=flow_state) next_to_terminate.extend( - data_store.REL_DB.ReadChildFlowObjects(rdf_flow.client_id, - rdf_flow.flow_id)) + data_store.REL_DB.ReadChildFlowObjects( + proto_flow.client_id, proto_flow.flow_id + ) + ) to_terminate = next_to_terminate diff --git a/grr/server/grr_response_server/flow_base_test.py b/grr/server/grr_response_server/flow_base_test.py index 536ece4dff..4556fe2736 100644 --- a/grr/server/grr_response_server/flow_base_test.py +++ b/grr/server/grr_response_server/flow_base_test.py @@ -17,6 +17,7 @@ from grr_response_server.databases import db as abstract_db from grr_response_server.databases import db_test_utils from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr.test_lib import db_test_lib from grr.test_lib import stats_test_lib from grr_response_proto import rrg_pb2 @@ -36,7 +37,7 @@ def testLogWithFormatArgs(self, db: 
abstract_db.Database) -> None: flow = rdf_flow_objects.Flow() flow.client_id = client_id flow.flow_id = self._FLOW_ID - db.WriteFlowObject(flow) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow)) flow = FlowBaseTest.Flow(flow) flow.Log("foo %s %s", "bar", 42) @@ -52,7 +53,7 @@ def testLogWithoutFormatArgs(self, db: abstract_db.Database) -> None: flow = rdf_flow_objects.Flow() flow.client_id = client_id flow.flow_id = self._FLOW_ID - db.WriteFlowObject(flow) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow)) flow = FlowBaseTest.Flow(flow) flow.Log("foo %s %s") @@ -136,7 +137,7 @@ def testReturnsDefaultFlowProgressForEmptyFlow(self, flow = rdf_flow_objects.Flow() flow.client_id = client_id flow.flow_id = self._FLOW_ID - db.WriteFlowObject(flow) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow)) flow_obj = FlowBaseTest.Flow(flow) progress = flow_obj.GetProgress() @@ -150,7 +151,7 @@ def testReturnsEmptyResultMetadataForEmptyFlow(self, flow = rdf_flow_objects.Flow() flow.client_id = client_id flow.flow_id = self._FLOW_ID - db.WriteFlowObject(flow) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow)) flow_obj = FlowBaseTest.Flow(flow) result_metadata = flow_obj.GetResultMetadata() @@ -167,7 +168,7 @@ def testReturnsEmptyResultMetadataWithFlagSetForPersistedEmptyFlow( flow = rdf_flow_objects.Flow() flow.client_id = client_id flow.flow_id = self._FLOW_ID - db.WriteFlowObject(flow) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow)) flow_obj = FlowBaseTest.Flow(flow) flow_obj.PersistState() @@ -185,7 +186,7 @@ def testResultMetadataHasGroupedNumberOfReplies(self, flow = rdf_flow_objects.Flow() flow.client_id = client_id flow.flow_id = self._FLOW_ID - db.WriteFlowObject(flow) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow)) flow_obj = FlowBaseTest.Flow(flow) flow_obj.SendReply(rdf_client.ClientInformation()) @@ -193,9 +194,10 @@ def testResultMetadataHasGroupedNumberOfReplies(self, flow_obj.SendReply(rdf_client.StartupInfo()) flow_obj.SendReply(rdf_client.StartupInfo(), tag="foo") flow_obj.PersistState() - db.WriteFlowObject(flow_obj.rdf_flow) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj.rdf_flow)) flow_2 = db.ReadFlowObject(client_id, self._FLOW_ID) + flow_2 = mig_flow_objects.ToRDFFlow(flow_2) flow_obj_2 = FlowBaseTest.Flow(flow_2) result_metadata = flow_obj_2.GetResultMetadata() @@ -222,15 +224,16 @@ def testResultMetadataAreCorrectlyUpdatedAfterMultiplePersistStateCalls( flow = rdf_flow_objects.Flow() flow.client_id = client_id flow.flow_id = self._FLOW_ID - db.WriteFlowObject(flow) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow)) flow_obj = FlowBaseTest.Flow(flow) flow_obj.SendReply(rdf_client.ClientInformation()) flow_obj.PersistState() flow_obj.PersistState() - db.WriteFlowObject(flow_obj.rdf_flow) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj.rdf_flow)) flow_2 = db.ReadFlowObject(client_id, self._FLOW_ID) + flow_2 = mig_flow_objects.ToRDFFlow(flow_2) flow_obj_2 = FlowBaseTest.Flow(flow_2) result_metadata = flow_obj_2.GetResultMetadata() diff --git a/grr/server/grr_response_server/flow_responses.py b/grr/server/grr_response_server/flow_responses.py index 9e1c65500e..eede550298 100644 --- a/grr/server/grr_response_server/flow_responses.py +++ b/grr/server/grr_response_server/flow_responses.py @@ -1,7 +1,7 @@ #!/usr/bin/env python """The class encapsulating flow responses.""" -from typing import Any, Iterable, Iterator, Optional, Sequence, TypeVar +from typing import Any, Iterable, Iterator, Optional, Sequence, TypeVar, 
Union from google.protobuf import any_pb2 from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects @@ -46,7 +46,13 @@ def FromResponses(cls, request=None, responses=None) -> "Responses": @classmethod def FromResponsesProto2Any( cls, - responses: Sequence[rdf_flow_objects.FlowMessage], + responses: Sequence[ + Union[ + rdf_flow_objects.FlowResponse, + rdf_flow_objects.FlowStatus, + rdf_flow_objects.FlowIterator, + ], + ], ) -> "Responses[any_pb2.Any]": # pytype: enable=name-error """Creates a `Response` object from raw flow responses. diff --git a/grr/server/grr_response_server/flow_test.py b/grr/server/grr_response_server/flow_test.py index d5d01ae1ad..9f4bac92fe 100644 --- a/grr/server/grr_response_server/flow_test.py +++ b/grr/server/grr_response_server/flow_test.py @@ -13,6 +13,7 @@ from grr_response_core.lib.rdfvalues import file_finder as rdf_file_finder from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_core.lib.rdfvalues import paths as rdf_paths +from grr_response_proto import flows_pb2 from grr_response_server import action_registry from grr_response_server import data_store from grr_response_server import flow @@ -25,6 +26,7 @@ from grr_response_server.flows.general import file_finder from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner +from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import output_plugin as rdf_output_plugin from grr.test_lib import acl_test_lib from grr.test_lib import action_mocks @@ -85,6 +87,13 @@ def Start(self) -> None: self.CallState( next_state=self.ReceiveHello.__name__, responses=responses, + # Calling the state a little in the future to avoid inline processing + # done by the flow test library. Inline processing will break the + # CallState logic: responses are written after requests, but the + # inline processing is triggered already when requests are written. + # Inline processing doesn't happen if flow requests are scheduled in + # the future. + start_time=rdfvalue.RDFDatetime.Now() + rdfvalue.Duration("1s"), ) def ReceiveHello(self, responses: flow_responses.Responses) -> None: @@ -251,8 +260,10 @@ def testChildTermination(self): client_flow_obj = data_store.REL_DB.ReadChildFlowObjects( self.client_id, flow_id)[0] - self.assertEqual(flow_obj.flow_state, "RUNNING") - self.assertEqual(client_flow_obj.flow_state, "RUNNING") + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.RUNNING) + self.assertEqual( + client_flow_obj.flow_state, flows_pb2.Flow.FlowState.RUNNING + ) # Terminate the parent flow. 
flow_base.TerminateFlow(self.client_id, flow_id, reason="Testing") @@ -261,15 +272,15 @@ def testChildTermination(self): client_flow_obj = data_store.REL_DB.ReadChildFlowObjects( self.client_id, flow_id)[0] - self.assertEqual(flow_obj.flow_state, "ERROR") - self.assertEqual(client_flow_obj.flow_state, "ERROR") + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.ERROR) + self.assertEqual(client_flow_obj.flow_state, flows_pb2.Flow.FlowState.ERROR) def testExceptionInStart(self): flow_id = flow.StartFlow( flow_cls=FlowWithBrokenStart, client_id=self.client_id) flow_obj = data_store.REL_DB.ReadFlowObject(self.client_id, flow_id) - self.assertEqual(flow_obj.flow_state, flow_obj.FlowState.ERROR) + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.ERROR) self.assertEqual(flow_obj.error_message, "boo") self.assertIsNotNone(flow_obj.backtrace) @@ -381,7 +392,7 @@ def testCPULimitExceeded(self): check_flow_errors=False) rdf_flow = data_store.REL_DB.ReadFlowObject(self.client_id, flow_id) - self.assertEqual(rdf_flow.flow_state, "ERROR") + self.assertEqual(rdf_flow.flow_state, flows_pb2.Flow.FlowState.ERROR) self.assertIn("CPU limit exceeded", rdf_flow.error_message) def testNetworkLimitExceeded(self): @@ -399,7 +410,7 @@ def testNetworkLimitExceeded(self): check_flow_errors=False) rdf_flow = data_store.REL_DB.ReadFlowObject(self.client_id, flow_id) - self.assertEqual(rdf_flow.flow_state, "ERROR") + self.assertEqual(rdf_flow.flow_state, flows_pb2.Flow.FlowState.ERROR) self.assertIn("bytes limit exceeded", rdf_flow.error_message) def testRuntimeLimitExceeded(self): @@ -417,9 +428,9 @@ def testRuntimeLimitExceeded(self): runtime_limit=rdfvalue.Duration.From(9, rdfvalue.SECONDS), check_flow_errors=False) - rdf_flow = data_store.REL_DB.ReadFlowObject(self.client_id, flow_id) - self.assertEqual(rdf_flow.flow_state, "ERROR") - self.assertIn("Runtime limit exceeded", rdf_flow.error_message) + flow_obj = data_store.REL_DB.ReadFlowObject(self.client_id, flow_id) + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.ERROR) + self.assertIn("Runtime limit exceeded", flow_obj.error_message) def testUserGetsNotificationWithNumberOfResults(self): username = "notification_test_user" @@ -548,7 +559,7 @@ def testFlowDoesNotFailWhenOutputPluginFails(self): plugin_name="FailingDummyFlowOutputPlugin") ]) flow_obj = data_store.REL_DB.ReadFlowObject(self.client_id, flow_id) - self.assertEqual(flow_obj.flow_state, "FINISHED") + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.FINISHED) def testFailingPluginDoesNotImpactOtherPlugins(self): self.RunFlow(output_plugins=[ @@ -660,7 +671,8 @@ def testStartScheduledFlowsCreatesFlow(self): self.assertEqual( flows[0].flow_class_name, file.CollectFilesByKnownPath.__name__ ) - self.assertEqual(flows[0].args.paths, ["/foo"]) + rdf_flow = mig_flow_objects.ToRDFFlow(flows[0]) + self.assertEqual(rdf_flow.args.paths, ["/foo"]) self.assertEqual( flows[0].flow_state, rdf_flow_objects.Flow.FlowState.RUNNING ) @@ -993,8 +1005,9 @@ def testRaisesIfFlowProcessingRequestDoesNotTriggerAnyProcessing(self): with flow_test_lib.TestWorker() as worker: flow_id = flow.StartFlow( flow_cls=CallClientParentFlow, client_id=self.client_id) - fpr = rdf_flows.FlowProcessingRequest( - client_id=self.client_id, flow_id=flow_id) + fpr = flows_pb2.FlowProcessingRequest( + client_id=self.client_id, flow_id=flow_id + ) with self.assertRaises(worker_lib.FlowHasNothingToProcessError): worker.ProcessFlow(fpr) diff --git a/grr/server/grr_response_server/flow_utils.py 
b/grr/server/grr_response_server/flow_utils.py deleted file mode 100644 index b9848c32da..0000000000 --- a/grr/server/grr_response_server/flow_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python -"""Utils for flow related tasks.""" - - -def GetUserInfo(knowledge_base, user): - r"""Get a User protobuf for a specific user. - - Args: - knowledge_base: An rdf_client.KnowledgeBase object. - user: Username as string. May contain domain like DOMAIN\user. - - Returns: - A User rdfvalue or None - """ - if "\\" in user: - domain, user = user.split("\\", 1) - users = [ - u for u in knowledge_base.users - if u.username == user and u.userdomain == domain - ] - else: - users = [u for u in knowledge_base.users if u.username == user] - - if not users: - return - else: - return users[0] diff --git a/grr/server/grr_response_server/flows/file_test.py b/grr/server/grr_response_server/flows/file_test.py index aeec10ab5e..830143e004 100644 --- a/grr/server/grr_response_server/flows/file_test.py +++ b/grr/server/grr_response_server/flows/file_test.py @@ -21,6 +21,7 @@ from grr_response_server import flow_base from grr_response_server.flows import file from grr_response_server.flows.general import file_finder +from grr_response_server.rdfvalues import mig_flow_objects from grr.test_lib import action_mocks from grr.test_lib import flow_test_lib from grr.test_lib import test_lib @@ -593,7 +594,8 @@ def testPassesNoConditionsToClientFileFinderWhenNoConditionsSpecified(self): self.assertEqual( child.flow_class_name, file_finder.ClientFileFinder.__name__ ) - self.assertEmpty(child.args.conditions) + conditions = mig_flow_objects.ToRDFFlow(child).args.conditions + self.assertEmpty(conditions) def testPassesAllConditionsToClientFileFinderWhenAllConditionsSpecified(self): modification_time = rdf_file_finder.FileFinderModificationTimeCondition( @@ -650,10 +652,11 @@ def testPassesAllConditionsToClientFileFinderWhenAllConditionsSpecified(self): ) # We expect 7 condition-attributes to be converted # to 7 FileFinderConditions. - self.assertLen(child.args.conditions, 7) + conditions = mig_flow_objects.ToRDFFlow(child).args.conditions + self.assertLen(conditions, 7) def _GetCondition(condition_type): - for c in child.args.conditions: + for c in conditions: if c.condition_type == condition_type: return c.UnionCast() @@ -858,7 +861,8 @@ def testPassesNoConditionsToClientFileFinderWhenNoConditionsSpecified(self): self.assertEqual( child.flow_class_name, file_finder.ClientFileFinder.__name__ ) - self.assertEmpty(child.args.conditions) + conditions = mig_flow_objects.ToRDFFlow(child).args.conditions + self.assertEmpty(conditions) def testPassesAllConditionsToClientFileFinderWhenAllConditionsSpecified(self): modification_time = rdf_file_finder.FileFinderModificationTimeCondition( @@ -915,10 +919,11 @@ def testPassesAllConditionsToClientFileFinderWhenAllConditionsSpecified(self): ) # We expect 7 condition-attributes to be converted # to 7 FileFinderConditions. 
- self.assertLen(child.args.conditions, 7) + conditions = mig_flow_objects.ToRDFFlow(child).args.conditions + self.assertLen(conditions, 7) def _GetCondition(condition_type): - for c in child.args.conditions: + for c in conditions: if c.condition_type == condition_type: return c.UnionCast() @@ -1179,7 +1184,8 @@ def testPassesNoConditionsToClientFileFinderWhenNoConditionsSpecified(self): self.assertEqual( child.flow_class_name, file_finder.ClientFileFinder.__name__ ) - self.assertEmpty(child.args.conditions) + conditions = mig_flow_objects.ToRDFFlow(child).args.conditions + self.assertEmpty(conditions) def testPassesAllConditionsToClientFileFinderWhenAllConditionsSpecified(self): modification_time = rdf_file_finder.FileFinderModificationTimeCondition( @@ -1236,10 +1242,11 @@ def testPassesAllConditionsToClientFileFinderWhenAllConditionsSpecified(self): ) # We expect 7 condition-attributes to be converted # to 7 FileFinderConditions. - self.assertLen(child.args.conditions, 7) + conditions = mig_flow_objects.ToRDFFlow(child).args.conditions + self.assertLen(conditions, 7) def _GetCondition(condition_type): - for c in child.args.conditions: + for c in conditions: if c.condition_type == condition_type: return c.UnionCast() diff --git a/grr/server/grr_response_server/flows/general/administrative.py b/grr/server/grr_response_server/flows/general/administrative.py index 5af3ac3135..551971bb2d 100644 --- a/grr/server/grr_response_server/flows/general/administrative.py +++ b/grr/server/grr_response_server/flows/general/administrative.py @@ -57,7 +57,7 @@ def WriteAllCrashDetails(client_id, crash_details, flow_session_id=None): data_store.REL_DB.UpdateFlow( client_id, flow_id, - client_crash_info=mig_client.ToRDFClientCrash(crash_details), + client_crash_info=crash_details, ) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) @@ -195,7 +195,7 @@ def StartBlobsUpload( # Fail early if the file is not there or empty. try: file_size = signed_binary_utils.FetchSizeOfSignedBinary(binary_id) - except signed_binary_utils.UnknownSignedBinaryError as ex: + except signed_binary_utils.SignedBinaryNotFoundError as ex: raise flow_base.FlowError(f"File {binary_id} not found.") from ex if file_size == 0: @@ -506,18 +506,92 @@ def SendMail(self, responses): self.args.email, "grr-noreply", subject, body, is_html=True) -class NannyMessageHandlerMixin(object): - """A listener for nanny messages.""" +class UpdateClientArgs(rdf_structs.RDFProtoStruct): + protobuf = flows_pb2.UpdateClientArgs + rdf_deps = [] + + +class UpdateClient(RecursiveBlobUploadMixin, flow_base.FlowBase): + """Updates the GRR client to a new version replacing the current client. + + This will execute the specified installer on the client and then run + an Interrogate flow. + + The new installer's binary has to be uploaded to GRR (as a binary, not as + a Python hack) and must be signed using the exec signing key. + + Signing and upload of the file is done with grr_config_updater or through + the API. 
+ """ + + category = "/Administrative/" + + args_type = UpdateClientArgs + result_types = (rdf_client_action.ExecuteBinaryResponse,) + + def GenerateUploadRequest( + self, offset: int, file_size: int, blob: rdf_crypto.SignedBlob + ) -> Tuple[rdf_structs.RDFProtoStruct, Type[server_stubs.ClientActionStub]]: + request = rdf_client_action.ExecuteBinaryRequest( + executable=blob, + offset=offset, + write_path=self.state.write_path, + more_data=(offset + len(blob.data) < file_size), + use_client_env=False, + ) + + return request, server_stubs.UpdateAgent + + @property + def _binary_id(self): + return objects_pb2.SignedBinaryID( + binary_type=rdf_objects.SignedBinaryID.BinaryType.EXECUTABLE, + path=self.args.binary_path, + ) + + def Start(self): + """Start.""" + if not self.args.binary_path: + raise flow_base.FlowError("Installer binary path is not specified.") + + binary_urn = rdfvalue.RDFURN(self.args.binary_path) + self.state.write_path = "%d_%s" % (time.time(), binary_urn.Basename()) + + self.StartBlobsUpload(self._binary_id, self.End.__name__) + + def End(self, responses): + if not responses.success: + raise flow_base.FlowError( + "Installer reported an error: %s" % responses.status + ) + response = responses.First() + if not response: + return + + if response.exit_status != 0: + raise flow_base.FlowError( + f"Installer process failed with exit code {response.exit_status}" + f"\nstdout: {response.stdout}" + f"\nstderr: {response.stderr}" + ) + self.Log("Installer finished running.") + self.SendReply(response) + + +class ClientAlertHandler(message_handlers.MessageHandler): + """A listener for client messages.""" + + handler_name = "ClientAlertHandler" mail_template = jinja2.Template( """ -

-GRR nanny message received.
+GRR client message received.

-The nanny for client {{ client_id }} ({{ hostname }}) just sent a message:
+The client {{ client_id }} ({{ hostname }}) just sent a message:

 {{ message }}

-Click here to access this machine.
+Click here to access this machine.

 {{ signature }}
@@ -525,9 +599,9 @@ class NannyMessageHandlerMixin(object): autoescape=True, ) - subject = "GRR nanny message received from %s." + subject = "GRR client message received from %s." - logline = "Nanny for client %s sent: %s" + logline = "Client message from %s: %s" def SendEmail(self, client_id, message): """Processes this event.""" @@ -571,46 +645,6 @@ def SendEmail(self, client_id, message): body, is_html=True) - -class NannyMessageHandler(NannyMessageHandlerMixin, - message_handlers.MessageHandler): - - handler_name = "NannyMessageHandler" - - def ProcessMessages(self, msgs): - for message in msgs: - self.SendEmail(message.client_id, message.request.payload.string) - - -class ClientAlertHandlerMixin(NannyMessageHandlerMixin): - """A listener for client messages.""" - - mail_template = jinja2.Template( - """ -

-GRR client message received.
-
-The client {{ client_id }} ({{ hostname }}) just sent a message:
-
-{{ message }}
-
-Click here to access this machine.
-
-{{ signature }}
- -""", - autoescape=True, - ) - - subject = "GRR client message received from %s." - - logline = "Client message from %s: %s" - - -class ClientAlertHandler(ClientAlertHandlerMixin, - message_handlers.MessageHandler): - - handler_name = "ClientAlertHandler" - def ProcessMessages(self, msgs): for message in msgs: self.SendEmail(message.client_id, message.request.payload.string) diff --git a/grr/server/grr_response_server/flows/general/administrative_test.py b/grr/server/grr_response_server/flows/general/administrative_test.py index b46fe5cb3f..acc2734daa 100644 --- a/grr/server/grr_response_server/flows/general/administrative_test.py +++ b/grr/server/grr_response_server/flows/general/administrative_test.py @@ -2,6 +2,7 @@ """Tests for administrative flows.""" import datetime +import os import subprocess import sys import tempfile @@ -10,6 +11,7 @@ from absl import app import psutil +from grr_response_client import actions from grr_response_client.client_actions import admin from grr_response_client.client_actions import standard from grr_response_core import config @@ -19,6 +21,7 @@ from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_core.lib.rdfvalues import protodict as rdf_protodict from grr_response_core.lib.rdfvalues import structs as rdf_structs +from grr_response_proto import flows_pb2 from grr_response_proto import tests_pb2 from grr_response_server import action_registry from grr_response_server import client_index @@ -32,6 +35,7 @@ from grr_response_server.flows.general import administrative from grr_response_server.flows.general import discovery from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr.test_lib import acl_test_lib from grr.test_lib import action_mocks from grr.test_lib import client_test_lib @@ -53,6 +57,32 @@ def Start(self): action_registry.ACTION_STUB_BY_ID[self.args.action], next_state="End") +class UpdateClientErrorAction(actions.ActionPlugin): + in_rdfvalue = rdf_client_action.ExecuteBinaryRequest + out_rdfvalues = [rdf_client_action.ExecuteBinaryResponse] + + def Run(self, args: rdf_client_action.ExecuteBinaryRequest): + if not args.more_data: + self.SendReply( + rdf_client_action.ExecuteBinaryResponse( + exit_status=1, stdout=b"\xff\xff\xff\xff", stderr=b"foobar" + ) + ) + + +class UpdateClientNoCrashAction(actions.ActionPlugin): + in_rdfvalue = rdf_client_action.ExecuteBinaryRequest + out_rdfvalues = [rdf_client_action.ExecuteBinaryResponse] + + def Run(self, args: rdf_client_action.ExecuteBinaryRequest): + if not args.more_data: + self.SendReply( + rdf_client_action.ExecuteBinaryResponse( + exit_status=0, stdout=b"foobar", stderr=b"\xff\xff\xff\xff" + ) + ) + + class TestAdministrativeFlows(flow_test_lib.FlowTestsBaseclass, hunt_test_lib.StandardHuntTestMixin): """Tests the administrative flows.""" @@ -161,7 +191,7 @@ def SendEmail(address, sender, title, message, **_): self.assertIn(client_id, email_message["title"]) rel_flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - self.assertEqual(rel_flow_obj.flow_state, rel_flow_obj.FlowState.CRASHED) + self.assertEqual(rel_flow_obj.flow_state, flows_pb2.Flow.FlowState.CRASHED) # Make sure client object is updated with the last crash. 
crash = data_store.REL_DB.ReadClientCrashInfo(client_id) @@ -414,7 +444,82 @@ def Run(self, args): binary=binary_path, client_id=client_id, command_line="--bar --baz", - creator=self.test_username) + creator=self.test_username, + ) + + def testUpdateClient(self): + client_mock = action_mocks.ActionMock.With( + {"UpdateAgent": UpdateClientNoCrashAction} + ) + fake_installer = b"FakeGRRDebInstaller" * 20 + upload_path = ( + signed_binary_utils.GetAFF4ExecutablesRoot() + .Add(config.CONFIG["Client.platform"]) + .Add("test.deb") + ) + maintenance_utils.UploadSignedConfigBlob( + fake_installer, aff4_path=upload_path, limit=100 + ) + + blob_list, _ = signed_binary_utils.FetchBlobsForSignedBinaryByURN( + upload_path + ) + self.assertLen(list(blob_list), 4) + + acl_test_lib.CreateAdminUser(self.test_username) + + client_id = self.SetupClient(0, system="") + flow_id = flow_test_lib.TestFlowHelper( + administrative.UpdateClient.__name__, + client_mock, + client_id=client_id, + binary_path=os.path.join(config.CONFIG["Client.platform"], "test.deb"), + creator=self.test_username, + ) + results = flow_test_lib.GetFlowResults(client_id, flow_id) + self.assertLen(results, 1) + self.assertEqual(0, results[0].exit_status) + self.assertEqual(results[0].stdout, b"foobar") + + def testUpdateClientFailure(self): + client_mock = action_mocks.ActionMock.With( + {"UpdateAgent": UpdateClientErrorAction} + ) + fake_installer = b"FakeGRRDebInstaller" * 20 + upload_path = ( + signed_binary_utils.GetAFF4ExecutablesRoot() + .Add(config.CONFIG["Client.platform"]) + .Add("test.deb") + ) + maintenance_utils.UploadSignedConfigBlob( + fake_installer, aff4_path=upload_path, limit=100 + ) + + blob_list, _ = signed_binary_utils.FetchBlobsForSignedBinaryByURN( + upload_path + ) + self.assertLen(list(blob_list), 4) + + acl_test_lib.CreateAdminUser(self.test_username) + + client_id = self.SetupClient(0, system="") + flow_id = flow_test_lib.TestFlowHelper( + administrative.UpdateClient.__name__, + client_mock, + client_id=client_id, + binary_path=os.path.join(config.CONFIG["Client.platform"], "test.deb"), + creator=self.test_username, + check_flow_errors=False, + ) + + rel_flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + self.assertEqual(rel_flow_obj.flow_state, flows_pb2.Flow.FlowState.ERROR) + results = flow_test_lib.GetFlowResults(client_id, flow_id) + self.assertEmpty(results) + self.assertContainsInOrder( + ["stdout: b'\\xff\\xff\\xff\\xff'", "stderr: b'foobar'"], + rel_flow_obj.error_message, + ) def testOnlineNotificationEmail(self): """Tests that the mail is sent in the OnlineNotification flow.""" @@ -599,11 +704,15 @@ def testStartupDoesNotTriggerInterrogateIfRecentInterrogateIsRunning(self): self._RunSendStartupInfo(client_id) data_store.REL_DB.WriteFlowObject( - rdf_flow_objects.Flow( - flow_id=flow.RandomFlowId(), - client_id=client_id, - flow_class_name=discovery.Interrogate.__name__, - flow_state=rdf_flow_objects.Flow.FlowState.RUNNING)) + mig_flow_objects.ToProtoFlow( + rdf_flow_objects.Flow( + flow_id=flow.RandomFlowId(), + client_id=client_id, + flow_class_name=discovery.Interrogate.__name__, + flow_state=rdf_flow_objects.Flow.FlowState.RUNNING, + ) + ) + ) flows = data_store.REL_DB.ReadAllFlowObjects( client_id, include_child_flows=False) @@ -628,11 +737,15 @@ def testStartupTriggersInterrogateWhenPreviousInterrogateIsDone(self): self._RunSendStartupInfo(client_id) data_store.REL_DB.WriteFlowObject( - rdf_flow_objects.Flow( - flow_id=flow.RandomFlowId(), - client_id=client_id, - 
flow_class_name=discovery.Interrogate.__name__, - flow_state=rdf_flow_objects.Flow.FlowState.FINISHED)) + mig_flow_objects.ToProtoFlow( + rdf_flow_objects.Flow( + flow_id=flow.RandomFlowId(), + client_id=client_id, + flow_class_name=discovery.Interrogate.__name__, + flow_state=rdf_flow_objects.Flow.FlowState.FINISHED, + ) + ) + ) flows = data_store.REL_DB.ReadAllFlowObjects( client_id, include_child_flows=False) @@ -652,55 +765,6 @@ def testStartupTriggersInterrogateWhenPreviousInterrogateIsDone(self): ] self.assertLen(interrogates, orig_count + 1) - def testNannyMessageHandler(self): - client_id = self.SetupClient(0) - nanny_message = "Oh no!" - email_dict = {} - - def SendEmail(address, sender, title, message, **_): - email_dict.update( - dict(address=address, sender=sender, title=title, message=message)) - - with mock.patch.object(email_alerts.EMAIL_ALERTER, "SendEmail", SendEmail): - flow_test_lib.MockClient(client_id, None)._PushHandlerMessage( - rdf_flows.GrrMessage( - source=client_id, - session_id=rdfvalue.SessionID(flow_name="NannyMessage"), - payload=rdf_protodict.DataBlob(string=nanny_message), - request_id=0, - auth_state="AUTHENTICATED", - response_id=123)) - - self._CheckNannyEmail(client_id, nanny_message, email_dict) - - def testNannyMessageHandlerForUnknownClient(self): - client_id = self.SetupClient(0) - nanny_message = "Oh no!" - email_dict = {} - - def SendEmail(address, sender, title, message, **_): - email_dict.update( - dict(address=address, sender=sender, title=title, message=message)) - - with mock.patch.object(email_alerts.EMAIL_ALERTER, "SendEmail", SendEmail): - flow_test_lib.MockClient(client_id, None)._PushHandlerMessage( - rdf_flows.GrrMessage( - source=client_id, - session_id=rdfvalue.SessionID(flow_name="NannyMessage"), - payload=rdf_protodict.DataBlob(string=nanny_message), - request_id=0, - auth_state="AUTHENTICATED", - response_id=123)) - - # We expect the email to be sent. - self.assertEqual( - email_dict.get("address"), config.CONFIG["Monitoring.alert_email"]) - - # Make sure the message is included in the email message. - self.assertIn(nanny_message, email_dict["message"]) - - self.assertIn(client_id, email_dict["title"]) - def testClientAlertHandler(self): client_id = self.SetupClient(0) client_message = "Oh no!" 
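The test changes above follow one pattern: flow objects are still built as RDF values, converted with mig_flow_objects at the database boundary, and reads now return protos that are converted back whenever RDF-level accessors are needed. A minimal sketch of that round trip, assuming a running REL_DB and an already-registered client (the ids below are placeholders):

from grr_response_server import data_store
from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects
from grr_response_server.rdfvalues import mig_flow_objects

client_id = "C.1234567890abcdef"  # Placeholder; the client must exist in REL_DB.
flow_id = "ABCDEF12"  # Placeholder flow id.

rdf_flow = rdf_flow_objects.Flow(client_id=client_id, flow_id=flow_id)
# WriteFlowObject now expects a flows_pb2.Flow, so convert at the boundary.
data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(rdf_flow))

# ReadFlowObject returns a flows_pb2.Flow; convert back before touching
# RDF-level helpers such as `.args`.
proto_flow = data_store.REL_DB.ReadFlowObject(client_id, flow_id)
rdf_flow = mig_flow_objects.ToRDFFlow(proto_flow)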
diff --git a/grr/server/grr_response_server/flows/general/collectors.py b/grr/server/grr_response_server/flows/general/collectors.py index c58e5401d3..da5cb426ee 100644 --- a/grr/server/grr_response_server/flows/general/collectors.py +++ b/grr/server/grr_response_server/flows/general/collectors.py @@ -457,6 +457,7 @@ def _StartSubArtifactCollector(self, artifact_list, source, next_state): artifact_list=artifact_list, use_raw_filesystem_access=self.args.use_raw_filesystem_access, apply_parsers=self.args.apply_parsers, + implementation_type=self.args.implementation_type, max_file_size=self.args.max_file_size, ignore_interpolation_errors=self.args.ignore_interpolation_errors, dependencies=self.args.dependencies, diff --git a/grr/server/grr_response_server/flows/general/collectors_core_artifacts_test.py b/grr/server/grr_response_server/flows/general/collectors_core_artifacts_test.py index ce8d644043..2112b40780 100644 --- a/grr/server/grr_response_server/flows/general/collectors_core_artifacts_test.py +++ b/grr/server/grr_response_server/flows/general/collectors_core_artifacts_test.py @@ -172,35 +172,6 @@ def WmiQuery(self, _): self.assertEqual(result.Name(), "homefileshare$") self.assertAlmostEqual(result.FreeSpacePercent(), 58.823, delta=0.001) - @parser_test_lib.WithParser("WmiActiveScriptEventConsumer", - wmi_parser.WMIActiveScriptEventConsumerParser) - def testWMIBaseObject(self): - client_id = self.SetupClient(0, system="Windows", os_version="6.2") - - class WMIActionMock(action_mocks.ActionMock): - - base_objects = [] - - def WmiQuery(self, args): - self.base_objects.append(args.base_object) - return client_fixture.WMI_SAMPLE - - client_mock = WMIActionMock() - flow_test_lib.TestFlowHelper( - collectors.ArtifactCollectorFlow.__name__, - client_mock, - artifact_list=["WMIActiveScriptEventConsumer"], - creator=self.test_username, - client_id=client_id, - dependencies=( - rdf_artifacts.ArtifactCollectorFlowArgs.Dependency.IGNORE_DEPS)) - - # Make sure the artifact's base_object made it into the WmiQuery call. - artifact_obj = artifact_registry.REGISTRY.GetArtifact( - "WMIActiveScriptEventConsumer") - self.assertCountEqual(WMIActionMock.base_objects, - [artifact_obj.sources[0].attributes["base_object"]]) - def main(argv): # Run the full test suite diff --git a/grr/server/grr_response_server/flows/general/collectors_interactions_test.py b/grr/server/grr_response_server/flows/general/collectors_interactions_test.py index c84756ad17..5ac8b3caa5 100644 --- a/grr/server/grr_response_server/flows/general/collectors_interactions_test.py +++ b/grr/server/grr_response_server/flows/general/collectors_interactions_test.py @@ -124,6 +124,9 @@ def testNewArtifactLoaded(self): @parser_test_lib.WithAllParsers def testProcessCollectedArtifacts(self): """Tests downloading files from artifacts.""" + self.skipTest("Deeply nested protobufs") + # TODO: Test disabled because of restriction of proto nesting + # depth. Enable open source test again when fixed. self.client_id = self.SetupClient(0, system="Windows", os_version="6.2") client_mock = action_mocks.FileFinderClientMock() @@ -151,6 +154,9 @@ def testProcessCollectedArtifacts(self): def testBrokenArtifact(self): """Tests a broken artifact.""" + self.skipTest("Deeply nested protobufs") + # TODO: Test disabled because of restriction of proto nesting + # depth. Enable open source test again when fixed. 
self.client_id = self.SetupClient(0, system="Windows", os_version="6.2") client_mock = action_mocks.FileFinderClientMock() diff --git a/grr/server/grr_response_server/flows/general/collectors_test.py b/grr/server/grr_response_server/flows/general/collectors_test.py index 2b60b5fdb1..08f55b79c6 100644 --- a/grr/server/grr_response_server/flows/general/collectors_test.py +++ b/grr/server/grr_response_server/flows/general/collectors_test.py @@ -14,6 +14,7 @@ from absl import app import psutil +from grr_response_client import actions from grr_response_client.client_actions import standard from grr_response_core import config from grr_response_core.lib import factory @@ -29,12 +30,15 @@ from grr_response_core.lib.util import temp from grr_response_proto import knowledge_base_pb2 from grr_response_proto import objects_pb2 +from grr_response_server import action_registry from grr_response_server import artifact_registry from grr_response_server import data_store from grr_response_server import file_store from grr_response_server.databases import db +from grr_response_server.databases import db_test_utils from grr_response_server.flows.general import collectors from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr.test_lib import action_mocks from grr.test_lib import artifact_test_lib @@ -538,6 +542,72 @@ def _RunListProcesses(self, args): self.assertLen(results, 1) self.assertEqual(results[0].pid, 123) + def testArtifactGroupGetsParentArgs(self): + client_id = db_test_utils.InitializeClient(data_store.REL_DB) + # Client metadata is not enough, we need some KB data to be present. + client = objects_pb2.ClientSnapshot(client_id=client_id) + client.knowledge_base.fqdn = "hidrogenesse.example.com" + client.knowledge_base.os = "linux" + data_store.REL_DB.WriteClientSnapshot(client) + + artifact_registry.REGISTRY.RegisterArtifact( + rdf_artifacts.Artifact( + name="Planta", + doc="Animalito", + sources=[ + rdf_artifacts.ArtifactSource( + type=rdf_artifacts.ArtifactSource.SourceType.ARTIFACT_GROUP, + attributes={"names": ["Máquina"]}, + ) + ], + ) + ) + artifact_registry.REGISTRY.RegisterArtifact( + rdf_artifacts.Artifact( + name="Máquina", + doc="Piedra", + sources=[ + rdf_artifacts.ArtifactSource( + type=rdf_artifacts.ArtifactSource.SourceType.GRR_CLIENT_ACTION, + attributes={"client_action": "DoesNothingActionMock"}, + ) + ], + ) + ) + + class DoesNothingActionMock(actions.ActionPlugin): + + def Run(self, args: any) -> None: + del args + pass + + # TODO: Start using the annotation (w/cleanup). 
+ action_registry.RegisterAdditionalTestClientAction(DoesNothingActionMock) + + flow_id = flow_test_lib.TestFlowHelper( + collectors.ArtifactCollectorFlow.__name__, + action_mocks.ActionMock(), + artifact_list=["Planta"], + client_id=client_id, + apply_parsers=False, + use_raw_filesystem_access=True, + implementation_type=rdf_paths.PathSpec.ImplementationType.DIRECT, + max_file_size=1, + ignore_interpolation_errors=True, + ) + + child_flows = data_store.REL_DB.ReadChildFlowObjects(client_id, flow_id) + self.assertLen(child_flows, 1) + args = mig_flow_objects.ToRDFFlow(child_flows[0]).args + self.assertEqual(args.apply_parsers, False) + self.assertEqual(args.use_raw_filesystem_access, True) + self.assertEqual( + args.implementation_type, + rdf_paths.PathSpec.ImplementationType.DIRECT, + ) + self.assertEqual(args.max_file_size, 1) + self.assertEqual(args.ignore_interpolation_errors, True) + def testGrep2(self): client_id = self.SetupClient(0, system="Linux") client_mock = action_mocks.ClientFileFinderClientMock() diff --git a/grr/server/grr_response_server/flows/general/discovery.py b/grr/server/grr_response_server/flows/general/discovery.py index 1991139fe4..8fe49f9c3f 100644 --- a/grr/server/grr_response_server/flows/general/discovery.py +++ b/grr/server/grr_response_server/flows/general/discovery.py @@ -30,6 +30,7 @@ from grr_response_server.databases import db from grr_response_server.flows.general import collectors from grr_response_server.flows.general import crowdstrike +from grr_response_server.flows.general import hardware from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr_response_proto import rrg_pb2 @@ -347,6 +348,11 @@ def ProcessKnowledgeBase(self, responses): next_state=self.ProcessPasswdCacheUsers.__name__, ) + self.CallFlow( + hardware.CollectHardwareInfo.__name__, + next_state=self.ProcessHardwareInfo.__name__, + ) + non_kb_artifacts = config.CONFIG["Artifacts.non_kb_interrogate_artifacts"] if non_kb_artifacts: self.CallFlow( @@ -361,6 +367,17 @@ def ProcessKnowledgeBase(self, responses): except db.UnknownClientError: pass + def ProcessHardwareInfo( + self, + responses: flow_responses.Responses[rdf_client.HardwareInfo], + ) -> None: + if not responses.success: + self.Log("Failed to collect hardware information: %s", responses.status) + return + + for response in responses: + self.state.client.hardware_info = response + def ProcessArtifactResponses(self, responses): if not responses.success: self.Log("Error collecting artifacts: %s", responses.status) @@ -370,8 +387,6 @@ def ProcessArtifactResponses(self, responses): for response in responses: if isinstance(response, rdf_client_fs.Volume): self.state.client.volumes.append(response) - elif isinstance(response, rdf_client.HardwareInfo): - self.state.client.hardware_info = response else: raise ValueError("Unexpected response type: %s" % type(response)) diff --git a/grr/server/grr_response_server/flows/general/discovery_test.py b/grr/server/grr_response_server/flows/general/discovery_test.py index 8cfb94346f..445dafa2da 100644 --- a/grr/server/grr_response_server/flows/general/discovery_test.py +++ b/grr/server/grr_response_server/flows/general/discovery_test.py @@ -212,27 +212,20 @@ def _SetupMinimalClient(self): return client_id - @parser_test_lib.WithAllParsers def testInterrogateCloudMetadataLinux(self): """Check google cloud metadata on linux.""" client_id = self._SetupMinimalClient() with vfs_test_lib.VFSOverrider(rdf_paths.PathSpec.PathType.OS, 
vfs_test_lib.FakeTestDataVFSHandler): - with test_lib.ConfigOverrider({ - "Artifacts.knowledge_base": [ - "LinuxWtmp", - "NetgroupConfiguration", - ], - "Artifacts.netgroup_filter_regexes": [r"^login$"], - }): - client_mock = action_mocks.InterrogatedClient() - client_mock.InitializeClient() - with test_lib.SuppressLogs(): - flow_test_lib.TestFlowHelper( - discovery.Interrogate.__name__, - client_mock, - creator=self.test_username, - client_id=client_id) + client_mock = action_mocks.InterrogatedClient() + client_mock.InitializeClient() + + flow_test_lib.TestFlowHelper( + discovery.Interrogate.__name__, + client_mock, + creator=self.test_username, + client_id=client_id, + ) client = self._OpenClient(client_id) self._CheckCloudMetadata(client) diff --git a/grr/server/grr_response_server/flows/general/file_finder.py b/grr/server/grr_response_server/flows/general/file_finder.py index e6369f833c..f8ee37ef7d 100644 --- a/grr/server/grr_response_server/flows/general/file_finder.py +++ b/grr/server/grr_response_server/flows/general/file_finder.py @@ -22,6 +22,7 @@ from grr_response_server.flows.general import fingerprint from grr_response_server.flows.general import transfer from grr_response_server.models import blobs +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects @@ -477,6 +478,17 @@ def Start(self): else: stub = server_stubs.VfsFileFinder + # TODO: Remove this workaround once sandboxing issues are + # resolved and NTFS paths work it again. + if ( + self.args.pathtype == rdf_paths.PathSpec.PathType.NTFS + and not self.args.HasField("implementation_type") + ): + self.Log("Using unsandboxed NTFS access") + self.args.implementation_type = ( + rdf_paths.PathSpec.ImplementationType.DIRECT + ) + if (paths := self._InterpolatePaths(self.args.paths)) is not None: interpolated_args = self.args.Copy() interpolated_args.paths = paths @@ -678,7 +690,8 @@ def _WriteFilesContent( path_info.hash_entry.num_bytes = client_path_sizes[client_path] path_infos = list(client_path_path_info.values()) - data_store.REL_DB.WritePathInfos(self.client_id, path_infos) + proto_path_infos = [mig_objects.ToProtoPathInfo(pi) for pi in path_infos] + data_store.REL_DB.WritePathInfos(self.client_id, proto_path_infos) return client_path_hash_id diff --git a/grr/server/grr_response_server/flows/general/file_finder_test.py b/grr/server/grr_response_server/flows/general/file_finder_test.py index 7805b2136d..15a1f52cb4 100644 --- a/grr/server/grr_response_server/flows/general/file_finder_test.py +++ b/grr/server/grr_response_server/flows/general/file_finder_test.py @@ -30,6 +30,7 @@ from grr_response_server.databases import db_test_utils from grr_response_server.flows.general import file_finder from grr_response_server.flows.general import transfer +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr.test_lib import action_mocks from grr.test_lib import filesystem_test_lib @@ -74,10 +75,12 @@ def CheckFilesHashed(self, fnames): raise RuntimeError("Can't check unexpected result for correct " "hashes: %s" % fname) - path_info = data_store.REL_DB.ReadPathInfo( + proto_path_info = data_store.REL_DB.ReadPathInfo( self.client_id, rdf_objects.PathInfo.PathType.OS, - components=self.FilenameToPathComponents(fname)) + components=self.FilenameToPathComponents(fname), + ) + path_info = mig_objects.ToRDFPathInfo(proto_path_info) hash_obj = path_info.hash_entry self.assertEqual(str(hash_obj.sha256), 
file_hash) @@ -917,8 +920,11 @@ def _ReadTestPathInfo(self, path_type=rdf_objects.PathInfo.PathType.TSK): components = self.base_path.strip("/").split("/") components += path_components - return data_store.REL_DB.ReadPathInfo(self.client_id, path_type, - tuple(components)) + return mig_objects.ToRDFPathInfo( + data_store.REL_DB.ReadPathInfo( + self.client_id, path_type, tuple(components) + ) + ) def _ReadTestFile(self, path_components, @@ -1431,7 +1437,7 @@ def testInterpolationMissingAttributes(self): flow_args=flow_args) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.ERROR) + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.ERROR) self.assertIn("Missing knowledgebase attributes", flow_obj.error_message) log_entries = data_store.REL_DB.ReadFlowLogEntries( @@ -1461,7 +1467,7 @@ def testInterpolationUnknownAttributes(self): flow_args=flow_args) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.ERROR) + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.ERROR) self.assertIn("Unknown knowledgebase attributes", flow_obj.error_message) log_entries = data_store.REL_DB.ReadFlowLogEntries( @@ -1488,7 +1494,7 @@ def testSkipsGlobsWithInterpolationWhenNoKnowledgeBase(self): ) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.RUNNING) + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.RUNNING) log_entries = data_store.REL_DB.ReadFlowLogEntries( client_id=client_id, flow_id=flow_id, offset=0, count=1024 ) @@ -1516,7 +1522,7 @@ def testFailsIfAllGlobsWithAreSkippedDueToNoKnowledgeBase(self): flow_args=flow_args) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.ERROR) + self.assertEqual(flow_obj.flow_state, flows_pb2.Flow.FlowState.ERROR) self.assertIn( "All globs skipped, as there's no knowledgebase available for" " interpolation", diff --git a/grr/server/grr_response_server/flows/general/filesystem.py b/grr/server/grr_response_server/flows/general/filesystem.py index 9d5736b7f5..729656d0c4 100644 --- a/grr/server/grr_response_server/flows/general/filesystem.py +++ b/grr/server/grr_response_server/flows/general/filesystem.py @@ -19,6 +19,7 @@ from grr_response_server import flow_responses from grr_response_server import notification from grr_response_server import server_stubs +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr_response_proto import rrg_pb2 from grr_response_proto.rrg import fs_pb2 as rrg_fs_pb2 @@ -82,7 +83,10 @@ def WriteStatEntries(stat_entries, client_id): stat_response.st_mode &= ~stat_type_mask stat_response.st_mode |= stat.S_IFREG - path_infos = [rdf_objects.PathInfo.FromStatEntry(s) for s in stat_entries] + path_infos = _FilterOutPathInfoDuplicates( + [rdf_objects.PathInfo.FromStatEntry(s) for s in stat_entries] + ) + proto_path_infos = [mig_objects.ToProtoPathInfo(pi) for pi in path_infos] # NOTE: TSK may return duplicate entries. This is may be either due to # a bug in TSK implementation, or due to the fact that TSK is capable # of returning deleted files information. Our VFS data model only supports @@ -93,8 +97,7 @@ def WriteStatEntries(stat_entries, client_id): # Current behaviour is to simply drop excessive version before the # WritePathInfo call. 
This way files returned by TSK will still make it # into the flow's results, but not into the VFS data. - data_store.REL_DB.WritePathInfos(client_id, - _FilterOutPathInfoDuplicates(path_infos)) + data_store.REL_DB.WritePathInfos(client_id, proto_path_infos) def WriteFileFinderResults( @@ -123,6 +126,8 @@ def WriteFileFinderResults( path_info.hash_entry = r.hash_entry path_infos.append(path_info) + path_infos = _FilterOutPathInfoDuplicates(path_infos) + proto_path_infos = [mig_objects.ToProtoPathInfo(pi) for pi in path_infos] # NOTE: TSK may return duplicate entries. This is may be either due to # a bug in TSK implementation, or due to the fact that TSK is capable # of returning deleted files information. Our VFS data model only supports @@ -133,9 +138,7 @@ def WriteFileFinderResults( # Current behaviour is to simply drop excessive version before the # WritePathInfo call. This way files returned by TSK will still make it # into the flow's results, but not into the VFS data. - data_store.REL_DB.WritePathInfos( - client_id, _FilterOutPathInfoDuplicates(path_infos) - ) + data_store.REL_DB.WritePathInfos(client_id, proto_path_infos) class ListDirectoryArgs(rdf_structs.RDFProtoStruct): @@ -253,7 +256,9 @@ def List(self, responses): self.Log("Listed %s", self.state.urn) path_info = rdf_objects.PathInfo.FromStatEntry(self.state.stat) - data_store.REL_DB.WritePathInfos(self.client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + self.client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) stat_entries = list(map(rdf_client_fs.StatEntry, responses)) WriteStatEntries(stat_entries, client_id=self.client_id) diff --git a/grr/server/grr_response_server/flows/general/filesystem_test.py b/grr/server/grr_response_server/flows/general/filesystem_test.py index ee0187030e..0e446254d5 100644 --- a/grr/server/grr_response_server/flows/general/filesystem_test.py +++ b/grr/server/grr_response_server/flows/general/filesystem_test.py @@ -938,7 +938,7 @@ def testListingRegistryDirectoryDoesNotYieldMtimes(self): ["HKEY_LOCAL_MACHINE", "SOFTWARE", "ListingTest"]) self.assertLen(children, 2) for child in children: - self.assertIsNone(child.stat_entry.st_mtime) + self.assertFalse(child.stat_entry.st_mtime) def testNotificationWhenListingRegistry(self): # Change the username so notifications get written. 
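For reference, a minimal sketch of the PathInfo write pattern the hunks above converge on: build RDF PathInfo values from collected stat entries, convert them with mig_objects.ToProtoPathInfo, and pass the protos to data_store.REL_DB.WritePathInfos. It assumes a configured GRR server environment with REL_DB initialized; the helper name is illustrative and not part of the patch.

from grr_response_server import data_store
from grr_response_server.rdfvalues import mig_objects
from grr_response_server.rdfvalues import objects as rdf_objects


def write_stat_entries_as_protos(client_id, stat_entries):
  """Illustrative helper showing the RDF-to-proto PathInfo write pattern."""
  # Build RDF PathInfo values from the collected StatEntry results.
  path_infos = [rdf_objects.PathInfo.FromStatEntry(s) for s in stat_entries]
  # Convert to protos before writing, mirroring the mig_objects usage above.
  proto_path_infos = [mig_objects.ToProtoPathInfo(pi) for pi in path_infos]
  data_store.REL_DB.WritePathInfos(client_id, proto_path_infos)

The same conversion recurs in the fingerprint.py, read_low_level.py, and transfer.py hunks later in this patch, so the sketch applies to those call sites as well.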
diff --git a/grr/server/grr_response_server/flows/general/fingerprint.py b/grr/server/grr_response_server/flows/general/fingerprint.py index f543c4491e..8db24d7578 100644 --- a/grr/server/grr_response_server/flows/general/fingerprint.py +++ b/grr/server/grr_response_server/flows/general/fingerprint.py @@ -5,6 +5,7 @@ from grr_response_server import data_store from grr_response_server import flow_base from grr_response_server import server_stubs +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects @@ -65,7 +66,9 @@ def _ProcessFingerprint(self, responses): path_info = rdf_objects.PathInfo.FromPathSpec(pathspec) path_info.hash_entry = response.hash - data_store.REL_DB.WritePathInfos(self.client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + self.client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) self.ReceiveFileFingerprint( self.state.urn, hash_obj, request_data=responses.request_data) diff --git a/grr/server/grr_response_server/flows/general/hardware.py b/grr/server/grr_response_server/flows/general/hardware.py new file mode 100644 index 0000000000..9bacb94c93 --- /dev/null +++ b/grr/server/grr_response_server/flows/general/hardware.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python +"""Flows for collecting hardware information.""" + +import plistlib +import re + +from grr_response_core.lib.rdfvalues import client as rdf_client +from grr_response_core.lib.rdfvalues import client_action as rdf_client_action +from grr_response_core.lib.rdfvalues import mig_client +from grr_response_core.lib.rdfvalues import protodict as rdf_protodict +from grr_response_proto import sysinfo_pb2 +from grr_response_server import flow_base +from grr_response_server import flow_responses +from grr_response_server import server_stubs + + +class CollectHardwareInfo(flow_base.FlowBase): + """Flow that collects information about the hardware of the endpoint.""" + + category = "/Collectors/" + behaviours = flow_base.BEHAVIOUR_DEBUG + + result_types = [rdf_client.HardwareInfo] + + def Start(self) -> None: + if self.client_os == "Linux": + dmidecode_args = rdf_client_action.ExecuteRequest() + dmidecode_args.cmd = "/usr/sbin/dmidecode" + dmidecode_args.args.append("-q") + + self.CallClient( + server_stubs.ExecuteCommand, + dmidecode_args, + next_state=self._ProcessDmidecodeResults.__name__, + ) + elif self.client_os == "Windows": + win32_computer_system_product_args = rdf_client_action.WMIRequest() + win32_computer_system_product_args.query = """ + SELECT * + FROM Win32_ComputerSystemProduct + """.strip() + + self.CallClient( + server_stubs.WmiQuery, + win32_computer_system_product_args, + next_state=self._ProcessWin32ComputerSystemProductResults.__name__, + ) + elif self.client_os == "Darwin": + system_profiler_args = rdf_client_action.ExecuteRequest() + system_profiler_args.cmd = "/usr/sbin/system_profiler" + system_profiler_args.args.append("-xml") + system_profiler_args.args.append("SPHardwareDataType") + + self.CallClient( + server_stubs.ExecuteCommand, + system_profiler_args, + next_state=self._ProcessSystemProfilerResults.__name__, + ) + else: + message = f"Unsupported operating system: {self.client_os}" + raise flow_base.FlowError(message) + + def _ProcessDmidecodeResults( + self, + responses: flow_responses.Responses[rdf_client_action.ExecuteResponse], + ) -> None: + if not responses.success: + raise flow_base.FlowError( + f"Failed to run dmidecode: {responses.status}", + ) + + for response in responses: + if response.exit_status != 0: + 
raise flow_base.FlowError( + f"dmidecode quit abnormally (status: {response.exit_status}, " + f"stdout: {response.stdout}, stderr: {response.stderr})", + ) + + result = sysinfo_pb2.HardwareInfo() + + stdout = response.stdout.decode("utf-8", "backslashreplace") + lines = iter(stdout.splitlines()) + + for line in lines: + line = line.strip() + + if line == "System Information": + for line in lines: + if not line.strip(): + # Blank line ends system information section. + break + elif match := re.fullmatch(r"\s*Serial Number:\s*(.*)", line): + result.serial_number = match[1] + elif match := re.fullmatch(r"\s*Manufacturer:\s*(.*)", line): + result.system_manufacturer = match[1] + elif match := re.fullmatch(r"\s*Product Name:\s*(.*)", line): + result.system_product_name = match[1] + elif match := re.fullmatch(r"\s*UUID:\s*(.*)", line): + result.system_uuid = match[1] + elif match := re.fullmatch(r"\s*SKU Number:\s*(.*)", line): + result.system_sku_number = match[1] + elif match := re.fullmatch(r"\s*Family:\s*(.*)", line): + result.system_family = match[1] + elif match := re.fullmatch(r"\s*Asset Tag:\s*(.*)", line): + result.system_assettag = match[1] + + elif line == "BIOS Information": + for line in lines: + if not line.strip(): + # Blank link ends BIOS information section. + break + elif match := re.fullmatch(r"^\s*Vendor:\s*(.*)", line): + result.bios_vendor = match[1] + elif match := re.fullmatch(r"^\s*Version:\s*(.*)", line): + result.bios_version = match[1] + elif match := re.fullmatch(r"^\s*Release Date:\s*(.*)", line): + result.bios_release_date = match[1] + elif match := re.fullmatch(r"^\s*ROM Size:\s*(.*)", line): + result.bios_rom_size = match[1] + elif match := re.fullmatch(r"^\s*BIOS Revision:\s*(.*)", line): + result.bios_revision = match[1] + + self.SendReply(mig_client.ToRDFHardwareInfo(result)) + + def _ProcessWin32ComputerSystemProductResults( + self, + responses: flow_responses.Responses[rdf_protodict.Dict], + ) -> None: + if not responses.success: + raise flow_base.FlowError( + f"Failed to run WMI query: {responses.status}", + ) + + responses = list(responses) + + if len(responses) != 1: + raise flow_base.FlowError( + f"Unexpected number of WMI query results: {len(responses)}", + ) + + response = responses[0] + + result = sysinfo_pb2.HardwareInfo() + + if identifying_number := response.get("IdentifyingNumber"): + result.serial_number = identifying_number + if vendor := response.get("Vendor"): + result.system_manufacturer = vendor + + self.SendReply(mig_client.ToRDFHardwareInfo(result)) + + def _ProcessSystemProfilerResults( + self, + responses: flow_responses.Responses[rdf_client_action.ExecuteResponse], + ) -> None: + if not responses.success: + raise flow_base.FlowError( + f"Failed to run system profiler: {responses.status}", + ) + + for response in responses: + if response.exit_status != 0: + raise flow_base.FlowError( + f"system profiler quit abnormally (status: {response.exit_status}, " + f"stdout: {response.stdout}, stderr: {response.stderr})", + ) + + try: + plist = plistlib.loads(response.stdout) + except plistlib.InvalidFileException as error: + raise flow_base.FlowError( + f"Failed to parse system profiler output: {error}", + ) + + if not isinstance(plist, list): + raise flow_base.FlowError( + f"Unexpected type of system profiler output: {type(plist)}", + ) + + if len(plist) != 1: + raise flow_base.FlowError( + f"Unexpected length of system profiler output: {len(plist)}", + ) + + if not (items := plist[0].get("_items")): + raise flow_base.FlowError( + "`_items` 
property missing in system profiler output", + ) + + if not isinstance(items, list): + raise flow_base.FlowError( + f"Unexpected type of system profiler items: {type(items)}", + ) + + if len(items) != 1: + raise flow_base.FlowError( + f"Unexpected number of system profiler items: {len(items)}", + ) + + item = items[0] + + if not isinstance(item, dict): + raise flow_base.FlowError( + f"Unexpected type of system profiler item: {type(item)}", + ) + + result = sysinfo_pb2.HardwareInfo() + + if serial_number := item.get("serial_number"): + result.serial_number = serial_number + if machine_model := item.get("machine_model"): + result.system_product_name = machine_model + if boot_rom_version := item.get("boot_rom_version"): + result.bios_version = boot_rom_version + if platform_uuid := item.get("platform_UUID"): + result.system_uuid = platform_uuid + + self.SendReply(mig_client.ToRDFHardwareInfo(result)) diff --git a/grr/server/grr_response_server/flows/general/hardware_test.py b/grr/server/grr_response_server/flows/general/hardware_test.py new file mode 100644 index 0000000000..8814d51045 --- /dev/null +++ b/grr/server/grr_response_server/flows/general/hardware_test.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python +from typing import Iterator + +from absl.testing import absltest + +from grr_response_core.lib.rdfvalues import client_action as rdf_client_action +from grr_response_core.lib.rdfvalues import mig_client_action +from grr_response_core.lib.rdfvalues import mig_protodict +from grr_response_core.lib.rdfvalues import protodict as rdf_protodict +from grr_response_proto import jobs_pb2 +from grr_response_proto import objects_pb2 +from grr_response_server import data_store +from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import db_test_utils +from grr_response_server.flows.general import hardware +from grr_response_server.models import protodicts +from grr.test_lib import action_mocks +from grr.test_lib import artifact_test_lib +from grr.test_lib import flow_test_lib +from grr.test_lib import parser_test_lib +from grr.test_lib import testing_startup + + +class CollectHardwareInfoTest(flow_test_lib.FlowTestsBaseclass): + + @classmethod + def setUpClass(cls): + super().setUpClass() + testing_startup.TestInit() + + @parser_test_lib.WithAllParsers + def testLinux(self): + assert data_store.REL_DB is not None + db: abstract_db.Database = data_store.REL_DB + + creator = db_test_utils.InitializeUser(db) + client_id = db_test_utils.InitializeClient(db) + + snapshot = objects_pb2.ClientSnapshot() + snapshot.client_id = client_id + snapshot.knowledge_base.os = "Linux" + db.WriteClientSnapshot(snapshot) + + class ActionMock(action_mocks.ActionMock): + + def ExecuteCommand( + self, + args: rdf_client_action.ExecuteRequest, + ) -> Iterator[rdf_client_action.ExecuteResponse]: + args = mig_client_action.ToProtoExecuteRequest(args) + + if args.cmd != "/usr/sbin/dmidecode": + raise RuntimeError(f"Unexpected command: {args.cmd}") + + result = jobs_pb2.ExecuteResponse() + result.request.MergeFrom(args) + result.exit_status = 0 + result.stdout = """\ +BIOS Information + Vendor: Google + Version: Google + Release Date: 01/25/2024 + Address: 0xE8000 + Runtime Size: 96 kB + ROM Size: 64 kB + Characteristics: + BIOS characteristics not supported + Targeted content distribution is supported + BIOS Revision: 1.0 + +System Information + Manufacturer: Google + Product Name: Google Compute Engine + Version: Not Specified + Serial Number: 
GoogleCloud-ABCDEF1234567890ABCDEF1234567890 + UUID: 78fc848d-b909-4b53-a917-50d5203d88ac + Wake-up Type: Power Switch + SKU Number: Not Specified + Family: Not Specified + +Base Board Information + Manufacturer: Google + Product Name: Google Compute Engine + Version: Not Specified + Serial Number: Board-GoogleCloud-ABCDEF1234567890ABCDEF1234567890 + Asset Tag: 78FC848D-B909-4B53-A917-50D5203D88AC + Features: + Board is a hosting board + Location In Chassis: Not Specified + Type: Motherboard + +System Boot Information + Status: No errors detected + +""".encode("utf-8") + + yield mig_client_action.ToRDFExecuteResponse(result) + + with artifact_test_lib.PatchDefaultArtifactRegistry(): + flow_id = flow_test_lib.StartAndRunFlow( + hardware.CollectHardwareInfo, + ActionMock(), + client_id=client_id, + creator=creator, + ) + + results = flow_test_lib.GetFlowResults(client_id, flow_id) + + self.assertLen(results, 1) + + result = results[0] + + self.assertEqual( + result.serial_number, + "GoogleCloud-ABCDEF1234567890ABCDEF1234567890", + ) + + self.assertEqual(result.system_manufacturer, "Google") + self.assertEqual(result.system_product_name, "Google Compute Engine") + self.assertEqual(result.system_uuid, "78fc848d-b909-4b53-a917-50d5203d88ac") + self.assertEqual(result.system_sku_number, "Not Specified") + self.assertEqual(result.system_family, "Not Specified") + + self.assertEqual(result.bios_vendor, "Google") + self.assertEqual(result.bios_version, "Google") + self.assertEqual(result.bios_release_date, "01/25/2024") + self.assertEqual(result.bios_rom_size, "64 kB") + self.assertEqual(result.bios_revision, "1.0") + + @parser_test_lib.WithAllParsers + def testMacos(self): + assert data_store.REL_DB is not None + db: abstract_db.Database = data_store.REL_DB + + creator = db_test_utils.InitializeUser(db) + client_id = db_test_utils.InitializeClient(db) + + snapshot = objects_pb2.ClientSnapshot() + snapshot.client_id = client_id + snapshot.knowledge_base.os = "Darwin" + db.WriteClientSnapshot(snapshot) + + class ActionMock(action_mocks.ActionMock): + + def ExecuteCommand( + self, + args: rdf_client_action.ExecuteRequest, + ) -> Iterator[rdf_client_action.ExecuteResponse]: + args = mig_client_action.ToProtoExecuteRequest(args) + + if args.cmd != "/usr/sbin/system_profiler": + raise RuntimeError(f"Unexpected command: {args.cmd}") + + result = jobs_pb2.ExecuteResponse() + result.request.MergeFrom(args) + result.exit_status = 0 + result.stdout = """\ + + + + + + _SPCommandLineArguments + + /usr/sbin/system_profiler + -nospawn + -xml + SPHardwareDataType + -detailLevel + full + + _SPCompletionInterval + 0.044379949569702148 + _SPResponseTime + 0.19805097579956055 + _dataType + SPHardwareDataType + _detailLevel + -2 + _items + + + _name + hardware_overview + activation_lock_status + activation_lock_disabled + boot_rom_version + 10151.101.3 + chip_type + Apple M1 Pro + machine_model + MacBookPro18,3 + machine_name + MacBook Pro + model_number + Z15G000PCB/A + number_processors + proc 8:6:2 + os_loader_version + 10151.101.3 + physical_memory + 16 GB + platform_UUID + 48F1516D-23AB-4242-BB81-6F32D193D3F2 + provisioning_UDID + 00008000-0001022E3FD6901A + serial_number + XY42EDVYNN + + + _parentDataType + SPRootDataType + _timeStamp + 2024-04-12T15:26:32Z + _versionInfo + + com.apple.SystemProfiler.SPPlatformReporter + 1500 + + + + +""".encode("utf-8") + + yield mig_client_action.ToRDFExecuteResponse(result) + + with artifact_test_lib.PatchDefaultArtifactRegistry(): + flow_id = flow_test_lib.StartAndRunFlow( + 
hardware.CollectHardwareInfo, + ActionMock(), + client_id=client_id, + creator=creator, + ) + + results = flow_test_lib.GetFlowResults(client_id, flow_id) + + self.assertLen(results, 1) + + result = results[0] + self.assertEqual(result.serial_number, "XY42EDVYNN") + self.assertEqual(result.system_product_name, "MacBookPro18,3") + self.assertEqual(result.system_uuid, "48F1516D-23AB-4242-BB81-6F32D193D3F2") + self.assertEqual(result.bios_version, "10151.101.3") + + @parser_test_lib.WithAllParsers + def testWindows(self) -> None: + assert data_store.REL_DB is not None + db: abstract_db.Database = data_store.REL_DB + + creator = db_test_utils.InitializeUser(db) + client_id = db_test_utils.InitializeClient(db) + + snapshot = objects_pb2.ClientSnapshot() + snapshot.client_id = client_id + snapshot.knowledge_base.os = "Windows" + db.WriteClientSnapshot(snapshot) + + class ActionMock(action_mocks.ActionMock): + + def WmiQuery( + self, + args: rdf_client_action.WMIRequest, + ) -> Iterator[rdf_protodict.Dict]: + args = mig_client_action.ToProtoWMIRequest(args) + + if not args.query.upper().startswith("SELECT "): + raise RuntimeError("Non-`SELECT` WMI query") + + if "Win32_ComputerSystemProduct" not in args.query: + raise RuntimeError(f"Unexpected WMI query: {args.query!r}") + + result = { + "IdentifyingNumber": "2S42F1S3320HFN2179FV", + "Name": "42F1S3320H", + "Vendor": "LEVELHO", + "Version": "NumbBox Y1337", + "Caption": "Computer System Product", + } + + yield mig_protodict.ToRDFDict(protodicts.Dict(result)) + + with artifact_test_lib.PatchDefaultArtifactRegistry(): + flow_id = flow_test_lib.StartAndRunFlow( + hardware.CollectHardwareInfo, + ActionMock(), + client_id=client_id, + creator=creator, + ) + + results = flow_test_lib.GetFlowResults(client_id, flow_id) + + self.assertLen(results, 1) + self.assertEqual(results[0].serial_number, "2S42F1S3320HFN2179FV") + self.assertEqual(results[0].system_manufacturer, "LEVELHO") + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/flows/general/memory.py b/grr/server/grr_response_server/flows/general/memory.py index ac2c02cbe7..2d9b93ea07 100644 --- a/grr/server/grr_response_server/flows/general/memory.py +++ b/grr/server/grr_response_server/flows/general/memory.py @@ -70,13 +70,6 @@ def _ValidateFlowArgs(self): def Start(self): """See base class.""" self._ValidateFlowArgs() - if self.client_version < 3306: - # TODO(user): Remove when support ends for old clients (Jan 1 2022). - self.CallClient( - server_stubs.YaraProcessScan, - request=self.args, - next_state=self.ProcessScanResults.__name__) - return if self.args.scan_runtime_limit_us: # Back up original runtime limit. Override it for YaraProcessScan action @@ -141,6 +134,8 @@ def ProcessScanResults( regions_to_dump[match.process.pid].add(string_match.offset) for error in response.errors: + # TODO: Remove server side filtering for errors after + # clients adopted to the new version. 
if self._ShouldIncludeError(error): self.SendReply(error) diff --git a/grr/server/grr_response_server/flows/general/mig_webhistory.py b/grr/server/grr_response_server/flows/general/mig_webhistory.py index f761631e3b..62cb27d7e0 100644 --- a/grr/server/grr_response_server/flows/general/mig_webhistory.py +++ b/grr/server/grr_response_server/flows/general/mig_webhistory.py @@ -4,34 +4,6 @@ from grr_response_server.flows.general import webhistory -def ToProtoChromeHistoryArgs( - rdf: webhistory.ChromeHistoryArgs, -) -> flows_pb2.ChromeHistoryArgs: - return rdf.AsPrimitiveProto() - - -def ToRDFChromeHistoryArgs( - proto: flows_pb2.ChromeHistoryArgs, -) -> webhistory.ChromeHistoryArgs: - return webhistory.ChromeHistoryArgs.FromSerializedBytes( - proto.SerializeToString() - ) - - -def ToProtoFirefoxHistoryArgs( - rdf: webhistory.FirefoxHistoryArgs, -) -> flows_pb2.FirefoxHistoryArgs: - return rdf.AsPrimitiveProto() - - -def ToRDFFirefoxHistoryArgs( - proto: flows_pb2.FirefoxHistoryArgs, -) -> webhistory.FirefoxHistoryArgs: - return webhistory.FirefoxHistoryArgs.FromSerializedBytes( - proto.SerializeToString() - ) - - def ToProtoCollectBrowserHistoryArgs( rdf: webhistory.CollectBrowserHistoryArgs, ) -> flows_pb2.CollectBrowserHistoryArgs: diff --git a/grr/server/grr_response_server/flows/general/network.py b/grr/server/grr_response_server/flows/general/network.py index 55020e16ba..02620efcd5 100644 --- a/grr/server/grr_response_server/flows/general/network.py +++ b/grr/server/grr_response_server/flows/general/network.py @@ -24,24 +24,13 @@ def Start(self): self.CallClient( server_stubs.ListNetworkConnections, listening_only=self.args.listening_only, - next_state=self.ValidateListNetworkConnections.__name__) + next_state=self.StoreNetstat.__name__, + ) - def ValidateListNetworkConnections( + def StoreNetstat( self, responses: flow_responses.Responses, ) -> None: - if not responses.success: - # Most likely the client is old and doesn't have ListNetworkConnections. - self.Log("%s", responses.status) - - # Fallback to Netstat. - self.CallClient( - server_stubs.Netstat, next_state=self.StoreNetstat.__name__) - else: - self.CallStateInline( - next_state=self.StoreNetstat.__name__, responses=responses) - - def StoreNetstat(self, responses): """Collects the connections. 
Args: diff --git a/grr/server/grr_response_server/flows/general/osquery_test.py b/grr/server/grr_response_server/flows/general/osquery_test.py index 3ab75a7d44..64eb0573c4 100644 --- a/grr/server/grr_response_server/flows/general/osquery_test.py +++ b/grr/server/grr_response_server/flows/general/osquery_test.py @@ -15,12 +15,12 @@ from grr_response_core.lib.rdfvalues import osquery as rdf_osquery from grr_response_core.lib.util import temp from grr_response_core.lib.util import text +from grr_response_proto import flows_pb2 from grr_response_server import data_store from grr_response_server import file_store from grr_response_server import flow_base from grr_response_server.databases import db from grr_response_server.flows.general import osquery as osquery_flow -from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects from grr.test_lib import action_mocks from grr.test_lib import flow_test_lib from grr.test_lib import osquery_test_lib @@ -257,7 +257,7 @@ def testFailure(self): client_mock=action_mocks.OsqueryClientMock()) flow = data_store.REL_DB.ReadFlowObject(self.client_id, flow_id) - self.assertEqual(flow.flow_state, rdf_flow_objects.Flow.FlowState.ERROR) + self.assertEqual(flow.flow_state, flows_pb2.Flow.FlowState.ERROR) self.assertIn(stderr, flow.error_message) def testSmallerTruncationLimit(self): diff --git a/grr/server/grr_response_server/flows/general/read_low_level.py b/grr/server/grr_response_server/flows/general/read_low_level.py index 99422e6a86..c3b9086d55 100644 --- a/grr/server/grr_response_server/flows/general/read_low_level.py +++ b/grr/server/grr_response_server/flows/general/read_low_level.py @@ -8,6 +8,7 @@ from grr_response_server import flow_base from grr_response_server import server_stubs from grr_response_server.databases import db +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects @@ -105,7 +106,9 @@ def StoreBlobsAsTmpFile(self, responses): path_info.hash_entry.source_offset = smallest_offset # Store file reference for this client in data_store. 
- data_store.REL_DB.WritePathInfos(self.client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + self.client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) result = rdf_read_low_level.ReadLowLevelFlowResult(path=tmp_filename) self.SendReply(result) diff --git a/grr/server/grr_response_server/flows/general/registry_init.py b/grr/server/grr_response_server/flows/general/registry_init.py index 0a44312885..3c5a9d0d9b 100644 --- a/grr/server/grr_response_server/flows/general/registry_init.py +++ b/grr/server/grr_response_server/flows/general/registry_init.py @@ -13,6 +13,7 @@ from grr_response_server.flows.general import export from grr_response_server.flows.general import file_finder from grr_response_server.flows.general import filesystem +from grr_response_server.flows.general import hardware from grr_response_server.flows.general import large_file from grr_response_server.flows.general import memory from grr_response_server.flows.general import network diff --git a/grr/server/grr_response_server/flows/general/timeline_test.py b/grr/server/grr_response_server/flows/general/timeline_test.py index 00f175f146..cc750aca60 100644 --- a/grr/server/grr_response_server/flows/general/timeline_test.py +++ b/grr/server/grr_response_server/flows/general/timeline_test.py @@ -259,8 +259,7 @@ def testFlowWithNoResult(self, db: abstract_db.Database) -> None: flow_obj.client_id = client_id flow_obj.flow_id = flow_id flow_obj.flow_class_name = timeline_flow.TimelineFlow.__name__ - flow_obj.create_time = rdfvalue.RDFDatetime.Now() - db.WriteFlowObject(flow_obj) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) self.assertIsNone(timeline_flow.FilesystemType(client_id, flow_id)) @@ -275,8 +274,7 @@ def testFlowWithResult(self, db: abstract_db.Database) -> None: flow_obj.client_id = client_id flow_obj.flow_id = flow_id flow_obj.flow_class_name = timeline_flow.TimelineFlow.__name__ - flow_obj.create_time = rdfvalue.RDFDatetime.Now() - db.WriteFlowObject(flow_obj) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) flow_result = rdf_flow_objects.FlowResult() flow_result.client_id = client_id diff --git a/grr/server/grr_response_server/flows/general/transfer.py b/grr/server/grr_response_server/flows/general/transfer.py index 09b25b955d..85abe18133 100644 --- a/grr/server/grr_response_server/flows/general/transfer.py +++ b/grr/server/grr_response_server/flows/general/transfer.py @@ -30,6 +30,7 @@ from grr_response_server.databases import db from grr_response_server.models import blobs as blob_models from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr_response_proto import rrg_pb2 from grr_response_proto.rrg import fs_pb2 as rrg_fs_pb2 @@ -157,7 +158,8 @@ def HandleGetFileContents( path_info.hash_entry.sha256 = hash_id.AsBytes() path_info.hash_entry.num_bytes = sum(_.size for _ in blob_refs) - data_store.REL_DB.WritePathInfos(self.client_id, [path_info]) + proto_path_info = mig_objects.ToProtoPathInfo(path_info) + data_store.REL_DB.WritePathInfos(self.client_id, [proto_path_info]) self.SendReply(path_info.stat_entry) @@ -310,7 +312,8 @@ def _AddFileToFileStore(self): path_info.hash_entry.sha256 = hash_id.AsBytes() path_info.hash_entry.num_bytes = offset - data_store.REL_DB.WritePathInfos(self.client_id, [path_info]) + proto_path_info = mig_objects.ToProtoPathInfo(path_info) + data_store.REL_DB.WritePathInfos(self.client_id, [proto_path_info]) 
# Save some space. del self.state["blobs"] @@ -818,7 +821,8 @@ def _CheckHashesWithFileStore(self): stat_entry = file_tracker["stat_entry"] path_info = rdf_objects.PathInfo.FromStatEntry(stat_entry) path_info.hash_entry = file_tracker["hash_obj"] - data_store.REL_DB.WritePathInfos(self.client_id, [path_info]) + proto_path_info = mig_objects.ToProtoPathInfo(path_info) + data_store.REL_DB.WritePathInfos(self.client_id, [proto_path_info]) # Report this hit to the flow's caller. self._ReceiveFetchedFile(file_tracker, is_duplicate=True) @@ -988,7 +992,8 @@ def _WriteBuffer(self, responses): path_info.hash_entry.sha256 = hash_id.AsBytes() path_info.hash_entry.num_bytes = offset - data_store.REL_DB.WritePathInfos(self.client_id, [path_info]) + proto_path_info = mig_objects.ToProtoPathInfo(path_info) + data_store.REL_DB.WritePathInfos(self.client_id, [proto_path_info]) # Save some space. del file_tracker["blobs"] diff --git a/grr/server/grr_response_server/flows/general/transfer_test.py b/grr/server/grr_response_server/flows/general/transfer_test.py index 9309afb193..e50c56e052 100644 --- a/grr/server/grr_response_server/flows/general/transfer_test.py +++ b/grr/server/grr_response_server/flows/general/transfer_test.py @@ -165,8 +165,8 @@ def testGetFile(self): history = data_store.REL_DB.ReadPathInfoHistory(cp.client_id, cp.path_type, cp.components) self.assertEqual(history[-1].hash_entry.sha256, fd_rel_db.hash_id.AsBytes()) - self.assertIsNone(history[-1].hash_entry.sha1) - self.assertIsNone(history[-1].hash_entry.md5) + self.assertFalse(history[-1].hash_entry.HasField("sha1")) + self.assertFalse(history[-1].hash_entry.HasField("md5")) def testGetFilePathCorrection(self): """Tests that the pathspec returned is used for the aff4path.""" @@ -204,8 +204,8 @@ def testGetFilePathCorrection(self): cp.components) self.assertEqual(history[-1].hash_entry.sha256, fd_rel_db.hash_id.AsBytes()) self.assertEqual(history[-1].hash_entry.num_bytes, expected_size) - self.assertIsNone(history[-1].hash_entry.sha1) - self.assertIsNone(history[-1].hash_entry.md5) + self.assertFalse(history[-1].hash_entry.HasField("sha1")) + self.assertFalse(history[-1].hash_entry.HasField("md5")) def testGetFileIsDirectory(self): """Tests that the flow raises when called on directory.""" diff --git a/grr/server/grr_response_server/flows/general/webhistory.py b/grr/server/grr_response_server/flows/general/webhistory.py index 863b8a4599..c62b9ddddf 100644 --- a/grr/server/grr_response_server/flows/general/webhistory.py +++ b/grr/server/grr_response_server/flows/general/webhistory.py @@ -4,281 +4,19 @@ # DISABLED for now until it gets converted to artifacts. 
import collections -import datetime import os from typing import Iterator, cast -from grr_response_core.lib import rdfvalue -from grr_response_core.lib.parsers import chrome_history -from grr_response_core.lib.parsers import firefox3_history from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs -from grr_response_core.lib.rdfvalues import file_finder as rdf_file_finder -from grr_response_core.lib.rdfvalues import mig_client from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import flows_pb2 -from grr_response_server import data_store -from grr_response_server import file_store from grr_response_server import flow_base from grr_response_server import flow_responses -from grr_response_server import flow_utils from grr_response_server.databases import db from grr_response_server.flows.general import collectors -from grr_response_server.flows.general import file_finder from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects -class ChromeHistoryArgs(rdf_structs.RDFProtoStruct): - protobuf = flows_pb2.ChromeHistoryArgs - - -class ChromeHistory(flow_base.FlowBase): - r"""Retrieve and analyze the chrome history for a machine. - - Default directories as per: - http://www.chromium.org/user-experience/user-data-directory - - Windows XP - Google Chrome: - c:\\Documents and Settings\\\\Local Settings\\Application Data\\ - Google\\Chrome\\User Data\\Default - - Windows 7 or Vista - c:\\Users\\\\AppData\\Local\\Google\\Chrome\\User Data\\Default - - Mac OS X - /Users//Library/Application Support/Google/Chrome/Default - - Linux - /home//.config/google-chrome/Default - """ - - category = "/Browser/" - args_type = ChromeHistoryArgs - behaviours = flow_base.BEHAVIOUR_DEBUG # Deprecated. - - def Start(self): - """Determine the Chrome directory.""" - self.state.hist_count = 0 - # List of paths where history files are located - self.state.history_paths = [] - if self.args.history_path: - self.state.history_paths.append(self.args.history_path) - - if not self.state.history_paths: - self.state.history_paths = self.GuessHistoryPaths(self.args.username) - - if not self.state.history_paths: - raise flow_base.FlowError("Could not find valid History paths.") - - filenames = ["History"] - if self.args.get_archive: - filenames.append("Archived History") - - for path in self.state.history_paths: - for fname in filenames: - self.CallFlow( - file_finder.FileFinder.__name__, - paths=[os.path.join(path, fname)], - pathtype=self.args.pathtype, - action=rdf_file_finder.FileFinderAction.Download(), - next_state=self.ParseFiles.__name__) - - def ParseFiles(self, responses): - """Take each file we retrieved and get the history from it.""" - # Note that some of these Find requests will fail because some paths don't - # exist, e.g. Chromium on most machines, so we don't check for success. 
- if responses: - for response in responses: - client_path = db.ClientPath.FromPathSpec(self.client_id, - response.stat_entry.pathspec) - filepath = response.stat_entry.pathspec.CollapsePath() - fd = file_store.OpenFile(client_path) - hist = chrome_history.ChromeParser() - count = 0 - for epoch64, dtype, url, dat1, dat2, dat3 in hist.Parse(filepath, fd): - count += 1 - str_entry = "%s %s %s %s %s %s" % (datetime.datetime.utcfromtimestamp( - epoch64 / 1e6), url, dat1, dat2, dat3, dtype) - self.SendReply(rdfvalue.RDFString(str_entry)) - - self.Log("Wrote %d Chrome History entries for user %s from %s", count, - self.args.username, response.stat_entry.pathspec.Basename()) - self.state.hist_count += count - - def GuessHistoryPaths(self, username): - """Take a user and return guessed full paths to History files. - - Args: - username: Username as string. - - Returns: - A list of strings containing paths to look for history files in. - - Raises: - OSError: On invalid system in the Schema - """ - client = data_store.REL_DB.ReadClientSnapshot(self.client_id) - system = client.knowledge_base.os - user_info = flow_utils.GetUserInfo( - mig_client.ToRDFKnowledgeBase(client.knowledge_base), username - ) - - if not user_info: - self.Error("Could not find homedir for user {0}".format(username)) - return - - paths = [] - if system == "Windows": - path = ("{app_data}\\{sw}\\User Data\\Default\\") - for sw_path in ["Google\\Chrome", "Chromium"]: - paths.append(path.format(app_data=user_info.localappdata, sw=sw_path)) - elif system == "Linux": - path = "{homedir}/.config/{sw}/Default/" - for sw_path in ["google-chrome", "chromium"]: - paths.append(path.format(homedir=user_info.homedir, sw=sw_path)) - elif system == "Darwin": - path = "{homedir}/Library/Application Support/{sw}/Default/" - for sw_path in ["Google/Chrome", "Chromium"]: - paths.append(path.format(homedir=user_info.homedir, sw=sw_path)) - else: - raise OSError("Invalid OS for Chrome History") - return paths - - -class FirefoxHistoryArgs(rdf_structs.RDFProtoStruct): - protobuf = flows_pb2.FirefoxHistoryArgs - - -class FirefoxHistory(flow_base.FlowBase): - r"""Retrieve and analyze the Firefox history for a machine. - - Default directories as per: - http://www.forensicswiki.org/wiki/Mozilla_Firefox_3_History_File_Format - - Windows XP - C:\\Documents and Settings\\\\Application Data\\Mozilla\\ - Firefox\\Profiles\\\\places.sqlite - - Windows Vista - C:\\Users\\\\AppData\\Roaming\\Mozilla\\Firefox\\Profiles\\ - \\places.sqlite - - GNU/Linux - /home//.mozilla/firefox//places.sqlite - - Mac OS X - /Users//Library/Application Support/Firefox/Profiles/ - /places.sqlite - """ - - category = "/Browser/" - args_type = FirefoxHistoryArgs - behaviours = flow_base.BEHAVIOUR_DEBUG # Deprecated. 
- - def Start(self): - """Determine the Firefox history directory.""" - self.state.hist_count = 0 - self.state.history_paths = [] - - if self.args.history_path: - self.state.history_paths.append(self.args.history_path) - else: - self.state.history_paths = self.GuessHistoryPaths(self.args.username) - - if not self.state.history_paths: - raise flow_base.FlowError("Could not find valid History paths.") - - filename = "places.sqlite" - for path in self.state.history_paths: - self.CallFlow( - file_finder.FileFinder.__name__, - paths=[os.path.join(path, "**2", filename)], - pathtype=self.args.pathtype, - action=rdf_file_finder.FileFinderAction.Download(), - next_state=self.ParseFiles.__name__) - - def ParseFiles(self, responses): - """Take each file we retrieved and get the history from it.""" - if responses: - for response in responses: - client_path = db.ClientPath.FromPathSpec(self.client_id, - response.stat_entry.pathspec) - fd = file_store.OpenFile(client_path) - hist = firefox3_history.Firefox3History() - count = 0 - for epoch64, dtype, url, dat1, in hist.Parse(fd): - count += 1 - str_entry = "%s %s %s %s" % (datetime.datetime.utcfromtimestamp( - epoch64 / 1e6), url, dat1, dtype) - self.SendReply(rdfvalue.RDFString(str_entry)) - self.Log("Wrote %d Firefox History entries for user %s from %s", count, - self.args.username, response.stat_entry.pathspec.Basename()) - self.state.hist_count += count - - def GuessHistoryPaths(self, username): - """Take a user and return guessed full paths to History files. - - Args: - username: Username as string. - - Returns: - A list of strings containing paths to look for history files in. - - Raises: - OSError: On invalid system in the Schema - """ - client = data_store.REL_DB.ReadClientSnapshot(self.client_id) - system = client.knowledge_base.os - user_info = flow_utils.GetUserInfo( - mig_client.ToRDFKnowledgeBase(client.knowledge_base), username - ) - - if not user_info: - self.Error("Could not find homedir for user {0}".format(username)) - return - - paths = [] - if system == "Windows": - path = "{app_data}\\Mozilla\\Firefox\\Profiles/" - paths.append(path.format(app_data=user_info.appdata)) - elif system == "Linux": - path = "{homedir}/.mozilla/firefox/" - paths.append(path.format(homedir=user_info.homedir)) - elif system == "Darwin": - path = ("{homedir}/Library/Application Support/" "Firefox/Profiles/") - paths.append(path.format(homedir=user_info.homedir)) - else: - raise OSError("Invalid OS for Chrome History") - return paths - - -BROWSER_PATHS = { - "Linux": { - "Firefox": ["/home/{username}/.mozilla/firefox/"], - "Chrome": [ - "{homedir}/.config/google-chrome/", "{homedir}/.config/chromium/" - ] - }, - "Windows": { - "Chrome": [ - "{local_app_data}\\Google\\Chrome\\User Data\\", - "{local_app_data}\\Chromium\\User Data\\" - ], - "Firefox": ["{local_app_data}\\Mozilla\\Firefox\\Profiles\\"], - "IE": [ - "{cache}\\", "{cache}\\Low\\", "{app_data}\\Microsoft\\Windows\\" - ] - }, - "Darwin": { - "Firefox": ["{homedir}/Library/Application Support/Firefox/Profiles/"], - "Chrome": [ - "{homedir}/Library/Application Support/Google/Chrome/", - "{homedir}/Library/Application Support/Chromium/" - ] - } -} - - class CollectBrowserHistoryArgs(rdf_structs.RDFProtoStruct): """Arguments for CollectBrowserHistory.""" protobuf = flows_pb2.CollectBrowserHistoryArgs diff --git a/grr/server/grr_response_server/flows/general/webhistory_test.py b/grr/server/grr_response_server/flows/general/webhistory_test.py index b35d7cd2e0..f1fc92e663 100644 --- 
a/grr/server/grr_response_server/flows/general/webhistory_test.py +++ b/grr/server/grr_response_server/flows/general/webhistory_test.py @@ -7,23 +7,19 @@ from absl import app from grr_response_client import client_utils -from grr_response_core.lib.parsers import chrome_history -from grr_response_core.lib.parsers import firefox3_history from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import paths as rdf_paths -from grr_response_server import file_store from grr_response_server import flow_base from grr_response_server.databases import db from grr_response_server.flows.general import collectors from grr_response_server.flows.general import webhistory from grr.test_lib import action_mocks from grr.test_lib import flow_test_lib -from grr.test_lib import parser_test_lib from grr.test_lib import test_lib -class WebHistoryFlowTestMixin(flow_test_lib.FlowTestsBaseclass): +class TestWebHistoryWithArtifacts(flow_test_lib.FlowTestsBaseclass): def MockClientRawDevWithImage(self): """Mock the client to run off a test image. @@ -41,92 +37,6 @@ def MockGetRawdevice(path): return mock.patch.object(client_utils, "GetRawDevice", MockGetRawdevice) - -class TestWebHistory(WebHistoryFlowTestMixin): - """Test the browser history flows.""" - - def setUp(self): - super().setUp() - # Set up client info - users = [ - rdf_client.User( - username="test", - full_name="test user", - homedir="/home/test", - last_logon=250, - ) - ] - self.client_id = self.SetupClient(0, system="Linux", users=users) - - self.client_mock = action_mocks.ClientFileFinderWithVFS() - - def testChromeHistoryFetch(self): - """Test that downloading the Chrome history works.""" - with self.MockClientRawDevWithImage(): - # Run the flow in the simulated way - session_id = flow_test_lib.TestFlowHelper( - webhistory.ChromeHistory.__name__, - self.client_mock, - check_flow_errors=False, - client_id=self.client_id, - username="test", - creator=self.test_username, - pathtype=rdf_paths.PathSpec.PathType.TSK) - - # Now check that the right files were downloaded. - fs_path = "/home/test/.config/google-chrome/Default/History" - - components = list(filter(bool, self.base_path.split(os.path.sep))) - components.append("test_img.dd") - components.extend(filter(bool, fs_path.split(os.path.sep))) - - # Check if the History file is created. - cp = db.ClientPath.TSK(self.client_id, tuple(components)) - fd = file_store.OpenFile(cp) - self.assertGreater(len(fd.read()), 20000) - - # Check for analysis file. - results = flow_test_lib.GetFlowResults(self.client_id, session_id) - self.assertGreater(len(results), 50) - self.assertIn("funnycats.exe", "\n".join(map(str, results))) - - def testFirefoxHistoryFetch(self): - """Test that downloading the Firefox history works.""" - with self.MockClientRawDevWithImage(): - # Run the flow in the simulated way - session_id = flow_test_lib.TestFlowHelper( - webhistory.FirefoxHistory.__name__, - self.client_mock, - check_flow_errors=False, - client_id=self.client_id, - username="test", - creator=self.test_username, - # This has to be TSK, since test_img.dd is an EXT3 file system. - pathtype=rdf_paths.PathSpec.PathType.TSK) - - # Now check that the right files were downloaded. 
- fs_path = "/home/test/.mozilla/firefox/adts404t.default/places.sqlite" - - components = list(filter(bool, self.base_path.split(os.path.sep))) - components.append("test_img.dd") - components.extend(filter(bool, fs_path.split(os.path.sep))) - - # Check if the History file is created. - cp = db.ClientPath.TSK(self.client_id, tuple(components)) - rel_fd = file_store.OpenFile(cp) - self.assertEqual(rel_fd.read(15), b"SQLite format 3") - - # Check for analysis file. - results = flow_test_lib.GetFlowResults(self.client_id, session_id) - self.assertGreater(len(results), 3) - data = "\n".join(map(str, results)) - self.assertNotEqual(data.find("Welcome to Firefox"), -1) - self.assertNotEqual(data.find("sport.orf.at"), -1) - - -class TestWebHistoryWithArtifacts(WebHistoryFlowTestMixin): - """Test the browser history flows.""" - def setUp(self): super().setUp() users = [ @@ -160,38 +70,6 @@ def RunCollectorAndGetCollection(self, artifact_list, client_mock=None, **kw): return flow_test_lib.GetFlowResults(self.client_id, session_id) - @parser_test_lib.WithParser("Chrome", chrome_history.ChromeHistoryParser) - def testChrome(self): - """Check we can run WMI based artifacts.""" - with self.MockClientRawDevWithImage(): - - fd = self.RunCollectorAndGetCollection(["ChromiumBasedBrowsersHistory"], - client_mock=self.client_mock, - use_raw_filesystem_access=True) - - self.assertLen(fd, 71) - self.assertIn("/home/john/Downloads/funcats_scr.exe", - [d.download_path for d in fd]) - self.assertIn("http://www.java.com/", [d.url for d in fd]) - self.assertEndsWith(fd[0].source_path, - "/home/test/.config/google-chrome/Default/History") - - @parser_test_lib.WithParser("Firefox", firefox3_history.FirefoxHistoryParser) - def testFirefox(self): - """Check we can run WMI based artifacts.""" - with self.MockClientRawDevWithImage(): - fd = self.RunCollectorAndGetCollection( - [webhistory.FirefoxHistory.__name__], - client_mock=self.client_mock, - use_raw_filesystem_access=True) - - self.assertLen(fd, 5) - self.assertEqual(fd[0].access_time.AsSecondsSinceEpoch(), 1340623334) - self.assertIn("http://sport.orf.at/", [d.url for d in fd]) - self.assertEndsWith( - fd[0].source_path, - "/home/test/.mozilla/firefox/adts404t.default/places.sqlite") - class MockArtifactCollectorFlow(collectors.ArtifactCollectorFlow): diff --git a/grr/server/grr_response_server/frontend_lib.py b/grr/server/grr_response_server/frontend_lib.py index 6547280532..cf070f8f35 100644 --- a/grr/server/grr_response_server/frontend_lib.py +++ b/grr/server/grr_response_server/frontend_lib.py @@ -2,13 +2,12 @@ """The GRR frontend server.""" import logging import time -from typing import Optional, Sequence +from typing import Optional, Sequence, Union from grr_response_core.lib import queues from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import flows as rdf_flows -from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_core.lib.util import collection from grr_response_core.lib.util import random from grr_response_core.stats import metrics @@ -21,6 +20,7 @@ from grr_response_server.databases import db from grr_response_server.flows.general import transfer from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr_response_proto import 
rrg_pb2 @@ -132,6 +132,34 @@ def ReceiveMessages( frontend_message_handler_requests = [] dropped_count = 0 + # TODO: Remove `fixed_messages` once old clients + # have been migrated. + fixed_messages = [] + for message in messages: + if message.type != rdf_flows.GrrMessage.Type.STATUS: + fixed_messages.append(message) + continue + + stat = rdf_flows.GrrStatus(message.payload) + if not stat.HasField("cpu_time_used"): + fixed_messages.append(message) + continue + + if stat.cpu_time_used.HasField("deprecated_user_cpu_time"): + stat.cpu_time_used.user_cpu_time = ( + stat.cpu_time_used.deprecated_user_cpu_time + ) + stat.cpu_time_used.deprecated_user_cpu_time = None + if stat.cpu_time_used.HasField("deprecated_system_cpu_time"): + stat.cpu_time_used.system_cpu_time = ( + stat.cpu_time_used.deprecated_system_cpu_time + ) + stat.cpu_time_used.deprecated_system_cpu_time = None + message.payload = stat + fixed_messages.append(message) + + messages = fixed_messages + msgs_by_session_id = collection.Group(messages, lambda m: m.session_id) for session_id, msgs in msgs_by_session_id.items(): try: @@ -172,11 +200,19 @@ def ReceiveMessages( flow_responses = [] for message in unprocessed_msgs: try: - flow_responses.append( - rdf_flow_objects.FlowResponseForLegacyResponse(message)) + response = rdf_flow_objects.FlowResponseForLegacyResponse(message) except ValueError as e: - logging.warning("Failed to parse legacy FlowResponse:\n%s\n%s", e, - message) + logging.warning( + "Failed to parse legacy FlowResponse:\n%s\n%s", e, message + ) + else: + if isinstance(response, rdf_flow_objects.FlowStatus): + response = mig_flow_objects.ToProtoFlowStatus(response) + if isinstance(response, rdf_flow_objects.FlowIterator): + response = mig_flow_objects.ToProtoFlowIterator(response) + if isinstance(response, rdf_flow_objects.FlowResponse): + response = mig_flow_objects.ToProtoFlowResponse(response) + flow_responses.append(response) data_store.REL_DB.WriteFlowResponses(flow_responses) @@ -191,21 +227,36 @@ def ReceiveMessages( backtrace=stat.backtrace, crash_message=stat.error_message, nanny_status=stat.nanny_status, - timestamp=rdfvalue.RDFDatetime.Now()) + timestamp=rdfvalue.RDFDatetime.Now(), + ) events.Events.PublishEvent( - "ClientCrash", crash_details, username=FRONTEND_USERNAME) + "ClientCrash", crash_details, username=FRONTEND_USERNAME + ) if worker_message_handler_requests: + worker_message_handler_requests = [ + mig_objects.ToProtoMessageHandlerRequest(r) + for r in worker_message_handler_requests + ] data_store.REL_DB.WriteMessageHandlerRequests( - worker_message_handler_requests) + worker_message_handler_requests + ) if frontend_message_handler_requests: + frontend_message_handler_requests = [ + mig_objects.ToProtoMessageHandlerRequest(r) + for r in frontend_message_handler_requests + ] worker_lib.ProcessMessageHandlerRequests( - frontend_message_handler_requests) + frontend_message_handler_requests + ) - logging.debug("Received %s messages from %s in %s sec", len(messages), - client_id, - time.time() - now) + logging.debug( + "Received %s messages from %s in %s sec", + len(messages), + client_id, + time.time() - now, + ) def ReceiveRRGResponse( self, @@ -218,25 +269,27 @@ def ReceiveRRGResponse( client_id: An identifier of the client for which we process the response. response: A response to process. 
""" - flow_response: rdf_flow_objects.FlowMessage + flow_response: Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ] if response.HasField("status"): - flow_response = rdf_flow_objects.FlowStatus() + flow_response = flows_pb2.FlowStatus() flow_response.network_bytes_sent = response.status.network_bytes_sent # TODO: Populate `cpu_time_used` and `runtime_us` if response.status.HasField("error"): # TODO: Convert RRG error types to GRR error types. - flow_response.status = rdf_flow_objects.FlowStatus.Status.ERROR + flow_response.status = flows_pb2.FlowStatus.Status.ERROR flow_response.error_message = response.status.error.message else: - flow_response.status = rdf_flow_objects.FlowStatus.Status.OK + flow_response.status = flows_pb2.FlowStatus.Status.OK elif response.HasField("result"): - packed_result = rdf_structs.AnyValue.FromProto2(response.result) - - flow_response = rdf_flow_objects.FlowResponse() - flow_response.any_payload = packed_result + flow_response = flows_pb2.FlowResponse() + flow_response.any_payload.CopyFrom(response.result) elif response.HasField("log"): log = response.log diff --git a/grr/server/grr_response_server/frontend_lib_test.py b/grr/server/grr_response_server/frontend_lib_test.py index 276d30f3cc..841134dbf8 100644 --- a/grr/server/grr_response_server/frontend_lib_test.py +++ b/grr/server/grr_response_server/frontend_lib_test.py @@ -11,6 +11,7 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_core.lib.rdfvalues import protodict as rdf_protodict +from grr_response_proto import flows_pb2 from grr_response_server import data_store from grr_response_server import fleetspeak_connector from grr_response_server import frontend_lib @@ -44,15 +45,16 @@ class GRRFEServerTestRelational(flow_test_lib.FlowTestsBaseclass): """Tests the GRRFEServer with relational flows enabled.""" def _FlowSetup(self, client_id, flow_id): - rdf_flow = rdf_flow_objects.Flow( + rdf_flow = flows_pb2.Flow( flow_class_name=administrative.OnlineNotification.__name__, client_id=client_id, flow_id=flow_id, - create_time=rdfvalue.RDFDatetime.Now()) + ) data_store.REL_DB.WriteFlowObject(rdf_flow) - req = rdf_flow_objects.FlowRequest( - client_id=client_id, flow_id=flow_id, request_id=1) + req = flows_pb2.FlowRequest( + client_id=client_id, flow_id=flow_id, request_id=1 + ) data_store.REL_DB.WriteFlowRequests([req]) @@ -63,7 +65,9 @@ def testReceiveMessages(self): client_id = "C.1234567890123456" flow_id = "12345678" data_store.REL_DB.WriteClientMetadata(client_id) + before_flow_create = data_store.REL_DB.Now() _, req = self._FlowSetup(client_id, flow_id) + after_flow_create = data_store.REL_DB.Now() session_id = "%s/%s" % (client_id, flow_id) messages = [ @@ -79,7 +83,13 @@ def testReceiveMessages(self): received = data_store.REL_DB.ReadAllFlowRequestsAndResponses( client_id, flow_id) self.assertLen(received, 1) - self.assertEqual(received[0][0], req) + received_request = received[0][0] + self.assertEqual(received_request.client_id, req.client_id) + self.assertEqual(received_request.flow_id, req.flow_id) + self.assertEqual(received_request.request_id, req.request_id) + self.assertBetween( + received_request.timestamp, before_flow_create, after_flow_create + ) self.assertLen(received[0][1], 9) def testBlobHandlerMessagesAreHandledOnTheFrontend(self): @@ -138,6 +148,41 @@ def testCrashReport(self): self.assertTrue(crash_details_rel) self.assertEqual(crash_details_rel.session_id, session_id) + def 
testReceiveStatusMessage(self): + client_id = "C.1234567890123456" + flow_id = "12345678" + data_store.REL_DB.WriteClientMetadata(client_id) + self._FlowSetup(client_id, flow_id) + + session_id = rdfvalue.FlowSessionID(f"{client_id}/{flow_id}") + status = rdf_flows.GrrStatus(status=rdf_flows.GrrStatus.ReturnedStatus.OK) + status.cpu_time_used.deprecated_user_cpu_time = 1.1 + status.cpu_time_used.deprecated_system_cpu_time = 2.2 + + messages = [ + rdf_flows.GrrMessage( + source=client_id, + request_id=1, + response_id=1, + session_id=session_id, + payload=status, + auth_state="AUTHENTICATED", + type=rdf_flows.GrrMessage.Type.STATUS, + ) + ] + + ReceiveMessages(client_id, messages) + + received = data_store.REL_DB.ReadAllFlowRequestsAndResponses( + client_id, flow_id + ) + self.assertLen(received, 1) + self.assertNotEqual(received[0][1][1], status) + self.assertAlmostEqual(received[0][1][1].cpu_time_used.user_cpu_time, 1.1) + self.assertFalse(received[0][1][1].cpu_time_used.deprecated_user_cpu_time) + self.assertAlmostEqual(received[0][1][1].cpu_time_used.system_cpu_time, 2.2) + self.assertFalse(received[0][1][1].cpu_time_used.deprecated_system_cpu_time) + class FleetspeakFrontendTests(flow_test_lib.FlowTestsBaseclass): @@ -164,7 +209,7 @@ def testReceiveRRGResponseStatusOK(self, db: abstract_db.Database): client_id = db_test_utils.InitializeClient(db) flow_id = db_test_utils.InitializeFlow(db, client_id) - flow_request = rdf_flow_objects.FlowRequest() + flow_request = flows_pb2.FlowRequest() flow_request.client_id = client_id flow_request.flow_id = flow_id flow_request.request_id = 1337 @@ -182,7 +227,7 @@ def testReceiveRRGResponseStatusOK(self, db: abstract_db.Database): self.assertLen(flow_responses, 1) flow_response = flow_responses[0][1][response.response_id] - self.assertIsInstance(flow_response, rdf_flow_objects.FlowStatus) + self.assertIsInstance(flow_response, flows_pb2.FlowStatus) self.assertEqual(flow_response.client_id, client_id) self.assertEqual(flow_response.flow_id, flow_id) self.assertEqual(flow_response.request_id, 1337) @@ -199,7 +244,7 @@ def testReceiveRRGResponseStatusError(self, db: abstract_db.Database): client_id = db_test_utils.InitializeClient(db) flow_id = db_test_utils.InitializeFlow(db, client_id) - flow_request = rdf_flow_objects.FlowRequest() + flow_request = flows_pb2.FlowRequest() flow_request.client_id = client_id flow_request.flow_id = flow_id flow_request.request_id = 1337 @@ -218,7 +263,7 @@ def testReceiveRRGResponseStatusError(self, db: abstract_db.Database): self.assertLen(flow_responses, 1) flow_response = flow_responses[0][1][response.response_id] - self.assertIsInstance(flow_response, rdf_flow_objects.FlowStatus) + self.assertIsInstance(flow_response, flows_pb2.FlowStatus) self.assertEqual(flow_response.client_id, client_id) self.assertEqual(flow_response.flow_id, flow_id) self.assertEqual(flow_response.request_id, 1337) @@ -235,7 +280,7 @@ def testReceiveRRGResponseResult(self, db: abstract_db.Database): client_id = db_test_utils.InitializeClient(db) flow_id = db_test_utils.InitializeFlow(db, client_id) - flow_request = rdf_flow_objects.FlowRequest() + flow_request = flows_pb2.FlowRequest() flow_request.client_id = client_id flow_request.flow_id = flow_id flow_request.request_id = 1337 @@ -253,7 +298,7 @@ def testReceiveRRGResponseResult(self, db: abstract_db.Database): self.assertLen(flow_responses, 1) flow_response = flow_responses[0][1][response.response_id] - self.assertIsInstance(flow_response, rdf_flow_objects.FlowResponse) + 
self.assertIsInstance(flow_response, flows_pb2.FlowResponse) self.assertEqual(flow_response.client_id, client_id) self.assertEqual(flow_response.flow_id, flow_id) self.assertEqual(flow_response.request_id, 1337) @@ -268,7 +313,7 @@ def testReceiveRRGResponseLog(self, db: abstract_db.Database): client_id = db_test_utils.InitializeClient(db) flow_id = db_test_utils.InitializeFlow(db, client_id) - flow_request = rdf_flow_objects.FlowRequest() + flow_request = flows_pb2.FlowRequest() flow_request.client_id = client_id flow_request.flow_id = flow_id flow_request.request_id = 1337 @@ -293,7 +338,7 @@ def testReceiveRRGResponseUnexpected(self, db: abstract_db.Database): client_id = db_test_utils.InitializeClient(db) flow_id = db_test_utils.InitializeFlow(db, client_id) - flow_request = rdf_flow_objects.FlowRequest() + flow_request = flows_pb2.FlowRequest() flow_request.client_id = client_id flow_request.flow_id = flow_id flow_request.request_id = 1337 diff --git a/grr/server/grr_response_server/gui/api_call_handler_base.py b/grr/server/grr_response_server/gui/api_call_handler_base.py index 5616941694..6a48c154fa 100644 --- a/grr/server/grr_response_server/gui/api_call_handler_base.py +++ b/grr/server/grr_response_server/gui/api_call_handler_base.py @@ -67,6 +67,12 @@ class ApiCallHandler: # that implement Handle() method. result_type = None + # Proto type used as input to the handler Handle() method. + proto_args_type = None + + # Proto type returned by the handler Handle() method. + proto_result_type = None + # This is a maximum time in seconds the renderer is allowed to run. Renderers # exceeding this time are killed softly (i.e. the time is not a guaranteed # maximum, but will be used as a guide). diff --git a/grr/server/grr_response_server/gui/api_call_router_with_approval_checks.py b/grr/server/grr_response_server/gui/api_call_router_with_approval_checks.py index d81dd465e7..152f5863e1 100644 --- a/grr/server/grr_response_server/gui/api_call_router_with_approval_checks.py +++ b/grr/server/grr_response_server/gui/api_call_router_with_approval_checks.py @@ -32,6 +32,8 @@ from grr_response_server.gui.api_plugins import user as api_user from grr_response_server.gui.api_plugins import vfs as api_vfs from grr_response_server.gui.api_plugins import yara as api_yara +from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects +from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import mig_objects @@ -42,6 +44,7 @@ administrative.ExecuteCommand, administrative.ExecutePythonHack, administrative.LaunchBinary, + administrative.UpdateClient, administrative.UpdateConfiguration, ] @@ -749,9 +752,11 @@ def ModifyHunt(self, args, context=None): return self.delegate.ModifyHunt(args, context=context) - def _GetHuntObj(self, hunt_id, context=None): + def _GetHuntObj(self, hunt_id, context=None) -> rdf_hunt_objects.Hunt: try: - return data_store.REL_DB.ReadHuntObject(str(hunt_id)) + hunt_obj = data_store.REL_DB.ReadHuntObject(str(hunt_id)) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) + return hunt_obj except db.UnknownHuntError: raise api_call_handler_base.ResourceNotFoundError( "Hunt with id %s could not be found" % hunt_id) diff --git a/grr/server/grr_response_server/gui/api_integration_tests/audit_test.py b/grr/server/grr_response_server/gui/api_integration_tests/audit_test.py index 41e10fd5f1..e2f9cc13ed 100644 --- a/grr/server/grr_response_server/gui/api_integration_tests/audit_test.py +++ 
b/grr/server/grr_response_server/gui/api_integration_tests/audit_test.py @@ -3,6 +3,7 @@ from absl import app +from grr_response_proto import objects_pb2 from grr_response_server import data_store from grr_response_server.gui import api_integration_test_lib from grr.test_lib import test_lib @@ -20,7 +21,7 @@ def testFlowIsAudited(self): self.assertEqual(entry.http_request_path, "/api/v2/clients?count=50&offset=0&query=.") - self.assertEqual(entry.response_code, "OK") + self.assertEqual(entry.response_code, objects_pb2.APIAuditEntry.Code.OK) self.assertEqual(entry.router_method_name, "SearchClients") self.assertEqual(entry.username, "api_test_robot_user") diff --git a/grr/server/grr_response_server/gui/api_integration_tests/flow_test.py b/grr/server/grr_response_server/gui/api_integration_tests/flow_test.py index 6657b542c7..4002d75e70 100644 --- a/grr/server/grr_response_server/gui/api_integration_tests/flow_test.py +++ b/grr/server/grr_response_server/gui/api_integration_tests/flow_test.py @@ -105,7 +105,8 @@ def testCreateFlowFromClientRef(self): flows = data_store.REL_DB.ReadAllFlowObjects(client_id) self.assertLen(flows, 1) - self.assertEqual(flows[0].args, args) + flow = mig_flow_objects.ToRDFFlow(flows[0]) + self.assertEqual(flow.args, args) def testCreateFlowFromClientObject(self): client_id = self.SetupClient(0) @@ -121,7 +122,8 @@ def testCreateFlowFromClientObject(self): flows = data_store.REL_DB.ReadAllFlowObjects(client_id) self.assertLen(flows, 1) - self.assertEqual(flows[0].args, args) + flow = mig_flow_objects.ToRDFFlow(flows[0]) + self.assertEqual(flow.args, args) def testRunInterrogateFlow(self): client_id = self.SetupClient(0) @@ -167,7 +169,7 @@ def testListParsedFlowResults(self): flow.flow_class_name = collectors.ArtifactCollectorFlow.__name__ flow.args = rdf_artifacts.ArtifactCollectorFlowArgs(apply_parsers=False) flow.persistent_data = {"knowledge_base": rdf_client.KnowledgeBase()} - data_store.REL_DB.WriteFlowObject(flow) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow)) result = rdf_flow_objects.FlowResult() result.client_id = client_id @@ -227,7 +229,7 @@ def testListFlowApplicableParsers(self): flow.flow_id = flow_id flow.flow_class_name = collectors.ArtifactCollectorFlow.__name__ flow.args = rdf_artifacts.ArtifactCollectorFlowArgs(apply_parsers=False) - data_store.REL_DB.WriteFlowObject(flow) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow)) result = rdf_flow_objects.FlowResult() result.client_id = client_id diff --git a/grr/server/grr_response_server/gui/api_integration_tests/hunt_test.py b/grr/server/grr_response_server/gui/api_integration_tests/hunt_test.py index 264d64ad46..772163dd05 100644 --- a/grr/server/grr_response_server/gui/api_integration_tests/hunt_test.py +++ b/grr/server/grr_response_server/gui/api_integration_tests/hunt_test.py @@ -9,7 +9,6 @@ from absl import app -from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import timeline as rdf_timeline from grr_response_core.lib.util import chunked from grr_response_proto import flows_pb2 @@ -166,7 +165,7 @@ def testListErrors(self): client_id=client_ids[0], parent=flow.FlowParent.FromHuntID(hunt_id)) flow_obj = data_store.REL_DB.ReadFlowObject(client_ids[0], flow_id) - flow_obj.flow_state = flow_obj.FlowState.ERROR + flow_obj.flow_state = flows_pb2.Flow.FlowState.ERROR flow_obj.error_message = "Error foo." 
data_store.REL_DB.UpdateFlow(client_ids[0], flow_id, flow_obj=flow_obj) @@ -176,7 +175,7 @@ def testListErrors(self): client_id=client_ids[1], parent=flow.FlowParent.FromHuntID(hunt_id)) flow_obj = data_store.REL_DB.ReadFlowObject(client_ids[1], flow_id) - flow_obj.flow_state = flow_obj.FlowState.ERROR + flow_obj.flow_state = flows_pb2.Flow.FlowState.ERROR flow_obj.error_message = "Error bar." flow_obj.backtrace = "" data_store.REL_DB.UpdateFlow(client_ids[1], flow_id, flow_obj=flow_obj) @@ -288,9 +287,8 @@ def testGetCollectedTimelinesBody(self): flow_obj.client_id = client_id flow_obj.flow_id = hunt_id flow_obj.flow_class_name = timeline.TimelineFlow.__name__ - flow_obj.create_time = rdfvalue.RDFDatetime.Now() flow_obj.parent_hunt_id = hunt_id - data_store.REL_DB.WriteFlowObject(flow_obj) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) entry_1 = rdf_timeline.TimelineEntry() entry_1.path = "/bar/baz/quux".encode("utf-8") @@ -377,9 +375,8 @@ def testGetCollectedTimelinesGzchunked(self): flow_obj.client_id = client_id flow_obj.flow_id = hunt_id flow_obj.flow_class_name = timeline.TimelineFlow.__name__ - flow_obj.create_time = rdfvalue.RDFDatetime.Now() flow_obj.parent_hunt_id = hunt_id - data_store.REL_DB.WriteFlowObject(flow_obj) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) entry_1 = rdf_timeline.TimelineEntry() entry_1.path = "/foo/bar".encode("utf-8") diff --git a/grr/server/grr_response_server/gui/api_plugins/client.py b/grr/server/grr_response_server/gui/api_plugins/client.py index e34c8aa75c..f5173ffe0f 100644 --- a/grr/server/grr_response_server/gui/api_plugins/client.py +++ b/grr/server/grr_response_server/gui/api_plugins/client.py @@ -16,6 +16,7 @@ from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_core.lib.util import collection from grr_response_core.lib.util import precondition +from grr_response_proto import flows_pb2 from grr_response_proto.api import client_pb2 from grr_response_server import client_index from grr_response_server import data_store @@ -552,7 +553,7 @@ def Handle(self, args, context=None): raise InterrogateOperationNotFoundError("Operation with id %s not found" % args.operation_id) - complete = flow_obj.flow_state != flow_obj.FlowState.RUNNING + complete = flow_obj.flow_state != flows_pb2.Flow.FlowState.RUNNING result = ApiGetInterrogateOperationStateResult() if complete: diff --git a/grr/server/grr_response_server/gui/api_plugins/client_regression_test.py b/grr/server/grr_response_server/gui/api_plugins/client_regression_test.py index d77fb3e2e0..585d8d7c4d 100644 --- a/grr/server/grr_response_server/gui/api_plugins/client_regression_test.py +++ b/grr/server/grr_response_server/gui/api_plugins/client_regression_test.py @@ -40,7 +40,7 @@ class ApiGetClientHandlerRegressionTest( def Run(self): # Fix the time to avoid regressions. 
with test_lib.FakeTime(42): - client_id = self.SetupClient(0, memory_size=4294967296, add_cert=False) + client_id = self.SetupClient(0, memory_size=4294967296) self.Check( "GetClient", args=client_plugin.ApiGetClientArgs(client_id=client_id)) @@ -56,14 +56,12 @@ class ApiGetClientVersionsRegressionTest( def _SetupTestClient(self): with test_lib.FakeTime(42): - client_id = self.SetupClient(0, memory_size=4294967296, add_cert=False) + client_id = self.SetupClient(0, memory_size=4294967296) with test_lib.FakeTime(45): self.SetupClient( - 0, - fqdn="some-other-hostname.org", - memory_size=4294967296, - add_cert=False) + 0, fqdn="some-other-hostname.org", memory_size=4294967296 + ) return client_id diff --git a/grr/server/grr_response_server/gui/api_plugins/flow.py b/grr/server/grr_response_server/gui/api_plugins/flow.py index 452c398be1..55f472f212 100644 --- a/grr/server/grr_response_server/gui/api_plugins/flow.py +++ b/grr/server/grr_response_server/gui/api_plugins/flow.py @@ -372,6 +372,7 @@ class ApiGetFlowHandler(api_call_handler_base.ApiCallHandler): def Handle(self, args, context=None): flow_obj = data_store.REL_DB.ReadFlowObject( str(args.client_id), str(args.flow_id)) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) return ApiFlow().InitFromFlowObject( flow_obj, with_state_and_context=True, with_progress=True) @@ -417,10 +418,17 @@ def Handle(self, args, context=None): api_request = ApiFlowRequest( request_id=str(request.request_id), request_state=request_state) + responses = [] if response_dict: - responses = [ - response_dict[i].AsLegacyGrrMessage() for i in sorted(response_dict) - ] + for _, response in sorted(response_dict.items()): + if isinstance(response, flows_pb2.FlowResponse): + response = mig_flow_objects.ToRDFFlowResponse(response) + if isinstance(response, flows_pb2.FlowStatus): + response = mig_flow_objects.ToRDFFlowStatus(response) + if isinstance(response, flows_pb2.FlowIterator): + response = mig_flow_objects.ToRDFFlowIterator(response) + responses.append(response.AsLegacyGrrMessage()) + for r in responses: r.ClearPayload() @@ -518,6 +526,7 @@ def Handle( flow_id = str(args.flow_id) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) if flow_obj.flow_class_name != collectors.ArtifactCollectorFlow.__name__: message = "Not an artifact-collector flow: {}" raise ValueError(message.format(flow_obj.flow_class_name)) @@ -612,6 +621,7 @@ def Handle( flow_id = str(args.flow_id) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) if flow_obj.flow_class_name != collectors.ArtifactCollectorFlow.__name__: message = "Not an artifact-collector flow: {}" raise ValueError(message.format(flow_obj.flow_class_name)) @@ -850,6 +860,7 @@ def _GetFlow(self, args, context=None): client_id = str(args.client_id) flow_id = str(args.flow_id) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) flow_results = data_store.REL_DB.ReadFlowResults(client_id, flow_id, 0, db.MAX_COUNT) flow_results = [mig_flow_objects.ToRDFFlowResult(r) for r in flow_results] @@ -929,6 +940,7 @@ class ApiListFlowOutputPluginsHandler(api_call_handler_base.ApiCallHandler): def Handle(self, args, context=None): flow_obj = data_store.REL_DB.ReadFlowObject( str(args.client_id), str(args.flow_id)) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) output_plugins_states = flow_obj.output_plugins_states type_indices = {} @@ -1011,6 +1023,7 @@ 
class ApiListFlowOutputPluginLogsHandlerBase( def Handle(self, args, context=None): flow_obj = data_store.REL_DB.ReadFlowObject( str(args.client_id), str(args.flow_id)) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) index = GetOutputPluginIndex(flow_obj.output_plugins, args.plugin_id) output_plugin_id = "%d" % index @@ -1118,6 +1131,7 @@ def _HandleTopFlowsOnly(self, args, context=None): include_child_flows=False, not_created_by=access_control.SYSTEM_USERS if args.human_flows_only else None) + top_flows = [mig_flow_objects.ToRDFFlow(f) for f in top_flows] result = [ ApiFlow().InitFromFlowObject( f_data, with_args=True, with_progress=True) for f_data in top_flows @@ -1138,6 +1152,7 @@ def _HandleAllFlows(self, args, context=None): include_child_flows=True, not_created_by=access_control.SYSTEM_USERS if args.human_flows_only else None) + all_flows = [mig_flow_objects.ToRDFFlow(f) for f in all_flows] api_flow_dict = { rdf_flow.flow_id: ApiFlow().InitFromFlowObject(rdf_flow, with_args=False) @@ -1245,6 +1260,7 @@ def Handle(self, args, context=None): output_plugins=runner_args.output_plugins, ) flow_obj = data_store.REL_DB.ReadFlowObject(str(args.client_id), flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) res = ApiFlow().InitFromFlowObject(flow_obj) res.context = None @@ -1271,6 +1287,7 @@ def Handle(self, args, context=None): ) flow_obj = data_store.REL_DB.ReadFlowObject( str(args.client_id), str(args.flow_id)) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) return ApiFlow().InitFromFlowObject(flow_obj) diff --git a/grr/server/grr_response_server/gui/api_plugins/flow_test.py b/grr/server/grr_response_server/gui/api_plugins/flow_test.py index 10e28d975d..30f4b4ad19 100644 --- a/grr/server/grr_response_server/gui/api_plugins/flow_test.py +++ b/grr/server/grr_response_server/gui/api_plugins/flow_test.py @@ -78,6 +78,7 @@ def testInitializesClientIdForClientBasedFlows(self): flow_id = flow.StartFlow( client_id=client_id, flow_cls=processes.ListProcesses) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) flow_api_obj = flow_plugin.ApiFlow().InitFromFlowObject(flow_obj) self.assertEqual(flow_api_obj.client_id, @@ -88,6 +89,7 @@ def testFlowWithoutFlowProgressTypeReportsDefaultFlowProgress(self): flow_id = flow.StartFlow( client_id=client_id, flow_cls=flow_test_lib.DummyFlow) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) flow_api_obj = flow_plugin.ApiFlow().InitFromFlowObject(flow_obj) self.assertIsNotNone(flow_api_obj.progress) @@ -99,6 +101,7 @@ def testFlowWithoutResultsCorrectlyReportsEmptyResultMetadata(self): flow_id = flow.StartFlow( client_id=client_id, flow_cls=flow_test_lib.DummyFlow) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) flow_api_obj = flow_plugin.ApiFlow().InitFromFlowObject(flow_obj) self.assertIsNotNone(flow_api_obj.result_metadata) @@ -109,6 +112,7 @@ def testWithFlowProgressTypeReportsProgressCorrectly(self): flow_id = flow.StartFlow( client_id=client_id, flow_cls=flow_test_lib.DummyFlowWithProgress) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) flow_api_obj = flow_plugin.ApiFlow().InitFromFlowObject(flow_obj) self.assertIsNotNone(flow_api_obj.progress) @@ -548,7 +552,7 @@ def testIncorrectFlowType(self, db: abstract_db.Database) -> None: flow_obj.client_id = client_id flow_obj.flow_id = flow_id 
flow_obj.flow_class_name = "NotArtifactCollector" - db.WriteFlowObject(flow_obj) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) args = flow_plugin.ApiListFlowApplicableParsersArgs() args.client_id = client_id @@ -571,7 +575,7 @@ def testAlreadyAppliedParsers(self, db: abstract_db.Database) -> None: flow_obj.flow_id = flow_id flow_obj.flow_class_name = collectors.ArtifactCollectorFlow.__name__ flow_obj.args = rdf_artifacts.ArtifactCollectorFlowArgs(apply_parsers=True) - db.WriteFlowObject(flow_obj) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) flow_result = rdf_flow_objects.FlowResult() flow_result.client_id = client_id @@ -600,7 +604,7 @@ def testNotAppliedParsers(self, db: abstract_db.Database) -> None: flow_obj.flow_id = flow_id flow_obj.flow_class_name = collectors.ArtifactCollectorFlow.__name__ flow_obj.args = rdf_artifacts.ArtifactCollectorFlowArgs(apply_parsers=False) - db.WriteFlowObject(flow_obj) + db.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) flow_result = rdf_flow_objects.FlowResult() flow_result.client_id = client_id diff --git a/grr/server/grr_response_server/gui/api_plugins/hunt.py b/grr/server/grr_response_server/gui/api_plugins/hunt.py index 2d5b98f30a..83b5678a8b 100644 --- a/grr/server/grr_response_server/gui/api_plugins/hunt.py +++ b/grr/server/grr_response_server/gui/api_plugins/hunt.py @@ -15,10 +15,13 @@ from grr_response_core.lib import registry from grr_response_core.lib import utils from grr_response_core.lib.rdfvalues import client as rdf_client +from grr_response_core.lib.rdfvalues import mig_client +from grr_response_core.lib.rdfvalues import mig_stats from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr_response_core.lib.rdfvalues import stats as rdf_stats from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_proto import flows_pb2 +from grr_response_proto import hunts_pb2 from grr_response_proto.api import hunt_pb2 from grr_response_server import access_control from grr_response_server import data_store @@ -43,6 +46,7 @@ from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects from grr_response_server.rdfvalues import hunts as rdf_hunts from grr_response_server.rdfvalues import mig_flow_objects +from grr_response_server.rdfvalues import mig_flow_runner from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import objects as rdf_objects @@ -407,16 +411,21 @@ class ApiHuntError(rdf_structs.RDFProtoStruct): rdfvalue.RDFDatetime, ] - def InitFromFlowObject(self, fo): - """Initialize from rdf_flow_objects.Flow corresponding to a failed flow.""" + @classmethod + def FromFlowErrorInfo( + cls, + client_id: str, + info: db.FlowErrorInfo, + ) -> "ApiHuntError": + result = cls() + result.client_id = client_id + result.log_message = info.message + result.timestamp = info.time - self.client_id = fo.client_id - if fo.HasField("backtrace"): - self.backtrace = fo.backtrace - self.log_message = fo.error_message - self.timestamp = fo.last_update_time + if info.backtrace is not None: + result.backtrace = info.backtrace - return self + return result class ApiListHuntsArgs(rdf_structs.RDFProtoStruct): @@ -433,6 +442,22 @@ class ApiListHuntsResult(rdf_structs.RDFProtoStruct): ] +def _ApiToObjectHuntStateProto( + state: ApiHunt.State, +) -> hunts_pb2.Hunt.HuntState: + """Converts ApiHunt.State to hunts_pb2.Hunt.HuntState.""" + if state == ApiHunt.State.PAUSED: + return hunts_pb2.Hunt.HuntState.PAUSED + elif state == 
ApiHunt.State.STARTED: + return hunts_pb2.Hunt.HuntState.STARTED + elif state == ApiHunt.State.STOPPED: + return hunts_pb2.Hunt.HuntState.STOPPED + elif state == ApiHunt.State.COMPLETED: + return hunts_pb2.Hunt.HuntState.COMPLETED + else: + return hunts_pb2.Hunt.HuntState.UNKNOWN + + class ApiListHuntsHandler(api_call_handler_base.ApiCallHandler): """Renders list of available hunts.""" @@ -510,18 +535,6 @@ def Filter(x): else: return None - def _ApiToObjectHuntState(self, state): - if state == ApiHunt.State.PAUSED: - return rdf_hunt_objects.Hunt.HuntState.PAUSED - elif state == ApiHunt.State.STARTED: - return rdf_hunt_objects.Hunt.HuntState.STARTED - elif state == ApiHunt.State.STOPPED: - return rdf_hunt_objects.Hunt.HuntState.STOPPED - elif state == ApiHunt.State.COMPLETED: - return rdf_hunt_objects.Hunt.HuntState.COMPLETED - else: - return rdf_hunt_objects.Hunt.HuntState.UNKNOWN - def Handle(self, args, context=None): if args.description_contains and not args.active_within: raise ValueError( @@ -542,7 +555,7 @@ def Handle(self, args, context=None): if args.active_within: kw_args["created_after"] = rdfvalue.RDFDatetime.Now() - args.active_within if args.with_state: - kw_args["with_states"] = [self._ApiToObjectHuntState(args.with_state)] + kw_args["with_states"] = [_ApiToObjectHuntStateProto(args.with_state)] # TODO(user): total_count is not returned by the current implementation. # It's not clear, if it's needed - GRR UI doesn't show total number of @@ -553,18 +566,27 @@ def Handle(self, args, context=None): hunt_objects = data_store.REL_DB.ReadHuntObjects( args.offset, args.count or db.MAX_COUNT, **kw_args ) + hunt_ids = [h.hunt_id for h in hunt_objects] + hunt_counters = data_store.REL_DB.ReadHuntsCounters(hunt_ids) + items = [] for hunt_obj in hunt_objects: - hunt_counters = data_store.REL_DB.ReadHuntCounters(hunt_obj.hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) items.append( ApiHunt().InitFromHuntObject( - hunt_obj, hunt_counters=hunt_counters, with_full_summary=True + hunt_obj, + hunt_counters=hunt_counters[hunt_obj.hunt_id], + with_full_summary=True, ) ) + else: hunt_objects = data_store.REL_DB.ListHuntObjects( args.offset, args.count or db.MAX_COUNT, **kw_args ) + hunt_objects = [ + mig_hunt_objects.ToRDFHuntMetadata(h) for h in hunt_objects + ] items = [ApiHunt().InitFromHuntMetadata(h) for h in hunt_objects] return ApiListHuntsResult(items=items) @@ -608,6 +630,7 @@ def Handle(self, args, context=None): try: hunt_id = str(args.hunt_id) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) hunt_counters = data_store.REL_DB.ReadHuntCounters(hunt_id) return ApiHunt().InitFromHuntObject( @@ -691,6 +714,7 @@ def Handle(self, args, context=None): str(args.hunt_id), with_type=args.with_type or None ) + results = [mig_flow_objects.ToRDFFlowResult(r) for r in results] return ApiListHuntResultsResult( items=[ApiHuntResult().InitFromFlowResult(r) for r in results], total_count=total_count, @@ -728,10 +752,9 @@ def Handle(self, args, context=None): str(args.hunt_id), filter_condition=db.HuntFlowsCondition.CRASHED_FLOWS_ONLY, ) - - return ApiListHuntCrashesResult( - items=[f.client_crash_info for f in flows], total_count=total_count - ) + crash_infos = [f.client_crash_info for f in flows] + crash_infos = [mig_client.ToRDFClientCrash(info) for info in crash_infos] + return ApiListHuntCrashesResult(items=crash_infos, total_count=total_count) class ApiGetHuntResultsExportCommandArgs(rdf_structs.RDFProtoStruct): @@ -802,6 +825,7 @@ 
def Handle(self, args, context=None): ) from ex for s in plugin_states: + s = mig_flow_runner.ToRDFOutputPluginState(s) name = s.plugin_descriptor.plugin_name plugin_id = "%s_%d" % (name, used_names[name]) used_names[name] += 1 @@ -820,8 +844,9 @@ def Handle(self, args, context=None): if "success_count" in state and not state["success_count"]: del state["success_count"] + plugin_descriptor = s.plugin_descriptor api_plugin = api_output_plugin.ApiOutputPlugin( - id=plugin_id, plugin_descriptor=s.plugin_descriptor, state=state + id=plugin_id, plugin_descriptor=plugin_descriptor, state=state ) result.append(api_plugin) @@ -841,6 +866,7 @@ class ApiListHuntOutputPluginLogsHandlerBase( def Handle(self, args, context=None): h = data_store.REL_DB.ReadHuntObject(str(args.hunt_id)) + h = mig_hunt_objects.ToRDFHunt(h) if self.__class__.log_entry_type is None: raise ValueError( @@ -984,38 +1010,44 @@ class ApiListHuntErrorsHandler(api_call_handler_base.ApiCallHandler): args_type = ApiListHuntErrorsArgs result_type = ApiListHuntErrorsResult - _FLOW_ATTRS_TO_MATCH = ["flow_id", "client_id", "error_message", "backtrace"] - - def _MatchFlowAgainstFilter(self, flow_obj, filter_str): - for attr in self._FLOW_ATTRS_TO_MATCH: - if filter_str in flow_obj.Get(attr): - return True - - return False - def Handle(self, args, context=None): total_count = data_store.REL_DB.CountHuntFlows( str(args.hunt_id), filter_condition=db.HuntFlowsCondition.FAILED_FLOWS_ONLY, ) + + errors = data_store.REL_DB.ReadHuntFlowErrors( + str(args.hunt_id), + args.offset, + args.count or db.MAX_COUNT, + ) + if args.filter: - flows = data_store.REL_DB.ReadHuntFlows( - str(args.hunt_id), - args.offset, - total_count, - filter_condition=db.HuntFlowsCondition.FAILED_FLOWS_ONLY, - ) - flows = [f for f in flows if self._MatchFlowAgainstFilter(f, args.filter)] - else: - flows = data_store.REL_DB.ReadHuntFlows( - str(args.hunt_id), - args.offset, - args.count or db.MAX_COUNT, - filter_condition=db.HuntFlowsCondition.FAILED_FLOWS_ONLY, - ) + + def MatchesFilter( + client_id: str, + info: db.FlowErrorInfo, + ) -> bool: + if args.filter in client_id: + return True + if args.filter in info.message: + return True + if info.backtrace is not None and args.filter in info.backtrace: + return True + + return False + + errors = { + client_id: info + for client_id, info in errors.items() + if MatchesFilter(client_id, info) + } return ApiListHuntErrorsResult( - items=[ApiHuntError().InitFromFlowObject(f) for f in flows], + items=[ + ApiHuntError.FromFlowErrorInfo(client_id, info) + for client_id, info in errors.items() + ], total_count=total_count, ) @@ -1229,6 +1261,7 @@ def _LoadData( ) -> Tuple[Iterable[rdf_flow_objects.FlowResult], str]: hunt_id = str(args.hunt_id) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) hunt_api_object = ApiHunt().InitFromHuntObject(hunt_obj) description = ( "Files downloaded by hunt %s (%s, '%s') created by user %s on %s" @@ -1244,6 +1277,7 @@ def _LoadData( results = data_store.REL_DB.ReadHuntResults( hunt_id, offset=0, count=db.MAX_COUNT ) + results = [mig_flow_objects.ToRDFFlowResult(res) for res in results] return results, description def Handle( @@ -1337,6 +1371,7 @@ def Handle(self, args, context=None): # get filled automatically from the hunt results and we check # later that the aff4_path we get is the same as the one that # was requested. 
+ item = mig_flow_objects.ToRDFFlowResult(item) client_path = export.CollectionItemToClientPath(item, client_id=None) except export.ItemNotExportableError: continue @@ -1388,6 +1423,7 @@ class ApiGetHuntStatsHandler(api_call_handler_base.ApiCallHandler): def Handle(self, args, context=None): del context # Unused. stats = data_store.REL_DB.ReadHuntClientResourcesStats(str(args.hunt_id)) + stats = mig_stats.ToRDFClientResourcesStats(stats) return ApiGetHuntStatsResult(stats=stats) @@ -1461,6 +1497,7 @@ class ApiGetHuntContextHandler(api_call_handler_base.ApiCallHandler): def Handle(self, args, context=None): hunt_id = str(args.hunt_id) h = data_store.REL_DB.ReadHuntObject(hunt_id) + h = mig_hunt_objects.ToRDFHunt(h) h_counters = data_store.REL_DB.ReadHuntCounters(hunt_id) context = rdf_hunts.HuntContext( session_id=rdfvalue.RDFURN("hunts").Add(h.hunt_id), @@ -1585,6 +1622,7 @@ def Handle(self, args, context=None): try: hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) if has_change: kw_args = {} if hunt_obj.hunt_state != hunt_obj.HuntState.PAUSED: @@ -1607,6 +1645,7 @@ def Handle(self, args, context=None): data_store.REL_DB.UpdateHuntObject(hunt_id, **kw_args) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) except db.UnknownHuntError: raise HuntNotFoundError( "Hunt with id %s could not be found" % args.hunt_id @@ -1630,7 +1669,7 @@ def Handle(self, args, context=None): ) hunt_obj = hunt.StopHunt( hunt_obj.hunt_id, - hunt_state_reason=rdf_hunt_objects.Hunt.HuntStateReason.TRIGGERED_BY_USER, + hunt_state_reason=hunts_pb2.Hunt.HuntStateReason.TRIGGERED_BY_USER, reason_comment=CANCELLED_BY_USER, ) @@ -1714,6 +1753,7 @@ def FetchFn(type_name): break for r in results: + r = mig_flow_objects.ToRDFFlowResult(r) msg = r.AsLegacyGrrMessage() msg.source_urn = source_urn yield msg diff --git a/grr/server/grr_response_server/gui/api_plugins/hunt_regression_test.py b/grr/server/grr_response_server/gui/api_plugins/hunt_regression_test.py index 262b2714cd..99ffb1927e 100644 --- a/grr/server/grr_response_server/gui/api_plugins/hunt_regression_test.py +++ b/grr/server/grr_response_server/gui/api_plugins/hunt_regression_test.py @@ -348,7 +348,7 @@ def Run(self): client_id=client_id_1, parent=flow.FlowParent.FromHuntID(hunt_id)) flow_obj = data_store.REL_DB.ReadFlowObject(client_id_1, flow_id) - flow_obj.flow_state = flow_obj.FlowState.ERROR + flow_obj.flow_state = flows_pb2.Flow.FlowState.ERROR flow_obj.error_message = "Error foo." data_store.REL_DB.UpdateFlow(client_id_1, flow_id, flow_obj=flow_obj) @@ -614,10 +614,6 @@ def Run(self): # Create replace dictionary. 
replace = {hunt_id: "H:123456"} - stats = data_store.REL_DB.ReadHuntClientResourcesStats(hunt_id) - for performance in stats.worst_performers: - session_id = str(performance.session_id) - replace[session_id] = "" self.Check( "GetHuntStats", diff --git a/grr/server/grr_response_server/gui/api_plugins/hunt_test.py b/grr/server/grr_response_server/gui/api_plugins/hunt_test.py index 8d21532eec..653e15e5cc 100644 --- a/grr/server/grr_response_server/gui/api_plugins/hunt_test.py +++ b/grr/server/grr_response_server/gui/api_plugins/hunt_test.py @@ -15,6 +15,8 @@ from grr_response_core.lib.rdfvalues import file_finder as rdf_file_finder from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_core.lib.rdfvalues import test_base as rdf_test_base +from grr_response_proto import flows_pb2 +from grr_response_proto import hunts_pb2 from grr_response_server import data_store from grr_response_server import hunt from grr_response_server.databases import db @@ -25,12 +27,13 @@ from grr_response_server.gui import api_test_lib from grr_response_server.gui.api_plugins import hunt as hunt_plugin from grr_response_server.output_plugins import test_plugins +from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects -from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import output_plugin as rdf_output_plugin from grr.test_lib import action_mocks +from grr.test_lib import db_test_lib from grr.test_lib import flow_test_lib from grr.test_lib import hunt_test_lib from grr.test_lib import test_lib @@ -493,7 +496,9 @@ def testRaisesIfResultIsBeforeTimestamp(self): hunt_id=self.hunt_id, client_id=self.client_id, vfs_path=self.vfs_file_path, - timestamp=results[0].timestamp + timestamp=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + results[0].timestamp + ) + rdfvalue.Duration.From(1, rdfvalue.SECONDS), ) with self.assertRaises(hunt_plugin.HuntFileNotFoundError): @@ -504,16 +509,21 @@ def testRaisesIfResultFileDoesNotExist(self): original_result = results[0] with test_lib.FakeTime( - original_result.timestamp - rdfvalue.Duration.From(1, rdfvalue.SECONDS) + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + original_result.timestamp + ) + - rdfvalue.Duration.From(1, rdfvalue.SECONDS) ): - wrong_result = original_result.Copy() - payload = wrong_result.payload + wrong_result = flows_pb2.FlowResult() + wrong_result.CopyFrom(original_result) + payload = flows_pb2.FileFinderResult() + wrong_result.payload.Unpack(payload) payload.stat_entry.pathspec.path += "blah" - data_store.REL_DB.WriteFlowResults( - [mig_flow_objects.ToProtoFlowResult(wrong_result)] - ) + data_store.REL_DB.WriteFlowResults([wrong_result]) - wrong_result_timestamp = wrong_result.timestamp + wrong_result_timestamp = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + wrong_result.timestamp + ) args = hunt_plugin.ApiGetHuntFileArgs( hunt_id=self.hunt_id, @@ -533,14 +543,14 @@ def testReturnsBinaryStreamIfResultFound(self): hunt_id=self.hunt_id, client_id=self.client_id, vfs_path=self.vfs_file_path, - timestamp=timestamp, + timestamp=rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(timestamp), ) result = self.handler.Handle(args, context=self.context) self.assertTrue(hasattr(result, "GenerateContent")) - self.assertEqual( - result.content_length, 
results[0].payload.stat_entry.st_size - ) + payload = flows_pb2.FileFinderResult() + results[0].payload.Unpack(payload) + self.assertEqual(result.content_length, payload.stat_entry.st_size) class ApiListHuntResultsHandlerTest( @@ -799,15 +809,15 @@ def setUp(self): self.args = hunt_plugin.ApiModifyHuntArgs(hunt_id=self.hunt_id) def testDoesNothingIfArgsHaveNoChanges(self): - before = hunt_plugin.ApiHunt().InitFromHuntObject( - data_store.REL_DB.ReadHuntObject(self.hunt_id) - ) + hunt_obj = data_store.REL_DB.ReadHuntObject(self.hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) + before = hunt_plugin.ApiHunt().InitFromHuntObject(hunt_obj) self.handler.Handle(self.args, context=self.context) - after = hunt_plugin.ApiHunt().InitFromHuntObject( - data_store.REL_DB.ReadHuntObject(self.hunt_id) - ) + hunt_obj = data_store.REL_DB.ReadHuntObject(self.hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) + after = hunt_plugin.ApiHunt().InitFromHuntObject(hunt_obj) self.assertEqual(before, after) @@ -847,7 +857,7 @@ def testStopsHuntCorrectly(self): self.assertEqual(h.hunt_state, h.HuntState.STOPPED) self.assertEqual( h.hunt_state_reason, - rdf_hunt_objects.Hunt.HuntStateReason.TRIGGERED_BY_USER, + hunts_pb2.Hunt.HuntStateReason.TRIGGERED_BY_USER, ) self.assertEqual(h.hunt_state_comment, "Cancelled by user") @@ -865,9 +875,9 @@ def testModifiesHuntCorrectly(self): self.handler.Handle(self.args, context=self.context) - after = hunt_plugin.ApiHunt().InitFromHuntObject( - data_store.REL_DB.ReadHuntObject(self.hunt_id) - ) + hunt_obj = data_store.REL_DB.ReadHuntObject(self.hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) + after = hunt_plugin.ApiHunt().InitFromHuntObject(hunt_obj) self.assertEqual(after.client_rate, 100) self.assertEqual(after.client_limit, 42) @@ -1022,6 +1032,181 @@ def testFailsWhenMoreThan250ClientsScheduledForCollection(self): self.handler.Handle(args, self.context) +class ListHuntErrorsHandlerTest(absltest.TestCase): + + @db_test_lib.WithDatabase + def testWithoutFilter(self, rel_db: db.Database): + hunt_id = db_test_utils.InitializeHunt(rel_db) + + client_id_1 = db_test_utils.InitializeClient(rel_db) + client_id_2 = db_test_utils.InitializeClient(rel_db) + + flow_id_1 = db_test_utils.InitializeFlow( + rel_db, + client_id=client_id_1, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + flow_id_2 = db_test_utils.InitializeFlow( + rel_db, + client_id=client_id_2, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + + flow_obj_1 = rel_db.ReadFlowObject(client_id_1, flow_id_1) + flow_obj_1.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_1.error_message = "ERROR_1" + rel_db.UpdateFlow(client_id_1, flow_id_1, flow_obj_1) + + flow_obj_2 = rel_db.ReadFlowObject(client_id_2, flow_id_2) + flow_obj_2.flow_state = rdf_flow_objects.Flow.FlowState.ERROR + flow_obj_2.error_message = "ERROR_2" + rel_db.UpdateFlow(client_id_2, flow_id_2, flow_obj_2) + + args = hunt_plugin.ApiListHuntErrorsArgs() + args.hunt_id = hunt_id + + handler = hunt_plugin.ApiListHuntErrorsHandler() + + results = handler.Handle(args) + self.assertLen(results.items, 2) + + self.assertEqual(results.items[0].client_id, client_id_1) + self.assertEqual(results.items[0].log_message, "ERROR_1") + + self.assertEqual(results.items[1].client_id, client_id_2) + self.assertEqual(results.items[1].log_message, "ERROR_2") + + @db_test_lib.WithDatabase + def testWithFilterByClientID(self, rel_db: db.Database): + hunt_id = db_test_utils.InitializeHunt(rel_db) + + client_id_1 = db_test_utils.InitializeClient(rel_db) 
+ client_id_2 = db_test_utils.InitializeClient(rel_db) + + flow_id_1 = db_test_utils.InitializeFlow( + rel_db, + client_id=client_id_1, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + flow_id_2 = db_test_utils.InitializeFlow( + rel_db, + client_id=client_id_2, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + + flow_obj_1 = rel_db.ReadFlowObject(client_id_1, flow_id_1) + flow_obj_1.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_1.error_message = "ERROR_1" + rel_db.UpdateFlow(client_id_1, flow_id_1, flow_obj_1) + + flow_obj_2 = rel_db.ReadFlowObject(client_id_2, flow_id_2) + flow_obj_2.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_2.error_message = "ERROR_2" + rel_db.UpdateFlow(client_id_2, flow_id_2, flow_obj_2) + + args = hunt_plugin.ApiListHuntErrorsArgs() + args.hunt_id = hunt_id + args.filter = client_id_2 + + handler = hunt_plugin.ApiListHuntErrorsHandler() + + results = handler.Handle(args) + self.assertLen(results.items, 1) + + self.assertEqual(results.items[0].client_id, client_id_2) + self.assertEqual(results.items[0].log_message, "ERROR_2") + + @db_test_lib.WithDatabase + def testWithFilterByMessage(self, rel_db: db.Database): + hunt_id = db_test_utils.InitializeHunt(rel_db) + + client_id_1 = db_test_utils.InitializeClient(rel_db) + client_id_2 = db_test_utils.InitializeClient(rel_db) + + flow_id_1 = db_test_utils.InitializeFlow( + rel_db, + client_id=client_id_1, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + flow_id_2 = db_test_utils.InitializeFlow( + rel_db, + client_id=client_id_2, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + + flow_obj_1 = rel_db.ReadFlowObject(client_id_1, flow_id_1) + flow_obj_1.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_1.error_message = "ERROR_1" + rel_db.UpdateFlow(client_id_1, flow_id_1, flow_obj_1) + + flow_obj_2 = rel_db.ReadFlowObject(client_id_2, flow_id_2) + flow_obj_2.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_2.error_message = "ERROR_2" + rel_db.UpdateFlow(client_id_2, flow_id_2, flow_obj_2) + + args = hunt_plugin.ApiListHuntErrorsArgs() + args.hunt_id = hunt_id + args.filter = "_1" + + handler = hunt_plugin.ApiListHuntErrorsHandler() + + results = handler.Handle(args) + self.assertLen(results.items, 1) + + self.assertEqual(results.items[0].client_id, client_id_1) + self.assertEqual(results.items[0].log_message, "ERROR_1") + + @db_test_lib.WithDatabase + def testWithFilterByBacktrace(self, rel_db: db.Database): + hunt_id = db_test_utils.InitializeHunt(rel_db) + + client_id_1 = db_test_utils.InitializeClient(rel_db) + client_id_2 = db_test_utils.InitializeClient(rel_db) + + flow_id_1 = db_test_utils.InitializeFlow( + rel_db, + client_id=client_id_1, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + flow_id_2 = db_test_utils.InitializeFlow( + rel_db, + client_id=client_id_2, + flow_id=hunt_id, + parent_hunt_id=hunt_id, + ) + + flow_obj_1 = rel_db.ReadFlowObject(client_id_1, flow_id_1) + flow_obj_1.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_1.error_message = "ERROR_1" + flow_obj_1.backtrace = "File 'foo_1.py', line 1, in 'foo'" + rel_db.UpdateFlow(client_id_1, flow_id_1, flow_obj_1) + + flow_obj_2 = rel_db.ReadFlowObject(client_id_2, flow_id_2) + flow_obj_2.flow_state = flows_pb2.Flow.FlowState.ERROR + flow_obj_2.error_message = "ERROR_2" + flow_obj_2.backtrace = "File 'foo_2.py', line 1, in 'foo'" + rel_db.UpdateFlow(client_id_2, flow_id_2, flow_obj_2) + + args = hunt_plugin.ApiListHuntErrorsArgs() + args.hunt_id = hunt_id + args.filter = "foo_2.py" + + handler = 
hunt_plugin.ApiListHuntErrorsHandler() + + results = handler.Handle(args) + self.assertLen(results.items, 1) + + self.assertEqual(results.items[0].client_id, client_id_2) + self.assertEqual(results.items[0].log_message, "ERROR_2") + + def main(argv): test_lib.main(argv) diff --git a/grr/server/grr_response_server/gui/api_plugins/report_plugins/report_plugins_test.py b/grr/server/grr_response_server/gui/api_plugins/report_plugins/report_plugins_test.py index d5570f2154..93ce1089d0 100644 --- a/grr/server/grr_response_server/gui/api_plugins/report_plugins/report_plugins_test.py +++ b/grr/server/grr_response_server/gui/api_plugins/report_plugins/report_plugins_test.py @@ -5,13 +5,13 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import events as rdf_events +from grr_response_proto import objects_pb2 from grr_response_server import data_store from grr_response_server.gui.api_plugins import stats as stats_api from grr_response_server.gui.api_plugins.report_plugins import rdf_report_plugins from grr_response_server.gui.api_plugins.report_plugins import report_plugins from grr_response_server.gui.api_plugins.report_plugins import report_plugins_test_mocks from grr_response_server.gui.api_plugins.report_plugins import server_report_plugins -from grr_response_server.rdfvalues import objects as rdf_objects from grr.test_lib import test_lib RepresentationType = rdf_report_plugins.ApiReportData.RepresentationType @@ -52,11 +52,12 @@ def testGetReportDescriptor(self): def AddFakeAuditLog(user=None, router_method_name=None, http_request_path=None): data_store.REL_DB.WriteAPIAuditEntry( - rdf_objects.APIAuditEntry( + objects_pb2.APIAuditEntry( username=user, router_method_name=router_method_name, http_request_path=http_request_path, - )) + ) + ) class ServerReportPluginsTest(test_lib.GRRBaseTest): diff --git a/grr/server/grr_response_server/gui/api_plugins/timeline_test.py b/grr/server/grr_response_server/gui/api_plugins/timeline_test.py index 7314c77a59..7f24dd5919 100644 --- a/grr/server/grr_response_server/gui/api_plugins/timeline_test.py +++ b/grr/server/grr_response_server/gui/api_plugins/timeline_test.py @@ -7,7 +7,6 @@ from absl.testing import absltest -from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import timeline as rdf_timeline from grr_response_core.lib.util import chunked from grr_response_proto import objects_pb2 @@ -43,8 +42,7 @@ def testRaisesOnIncorrectFlowType(self): flow_obj.client_id = client_id flow_obj.flow_id = flow_id flow_obj.flow_class_name = "NotTimelineFlow" - flow_obj.create_time = rdfvalue.RDFDatetime.Now() - data_store.REL_DB.WriteFlowObject(flow_obj) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) args = api_timeline.ApiGetCollectedTimelineArgs() args.client_id = client_id @@ -193,8 +191,7 @@ def testNtfsFileReferenceFormatInference(self): flow_obj.client_id = client_id flow_obj.flow_id = flow_id flow_obj.flow_class_name = timeline.TimelineFlow.__name__ - flow_obj.create_time = rdfvalue.RDFDatetime.Now() - data_store.REL_DB.WriteFlowObject(flow_obj) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) blobs = list(rdf_timeline.TimelineEntry.SerializeStream(iter([entry]))) blob_ids = data_store.BLOBS.WriteBlobsWithUnknownHashes(blobs) @@ -286,8 +283,7 @@ def testBodyMultipleResults(self): flow_obj.client_id = client_id flow_obj.flow_id = flow_id flow_obj.flow_class_name = timeline.TimelineFlow.__name__ - flow_obj.create_time = rdfvalue.RDFDatetime.Now() - 
data_store.REL_DB.WriteFlowObject(flow_obj) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) entry_1 = rdf_timeline.TimelineEntry() entry_1.path = "/foo".encode("utf-8") diff --git a/grr/server/grr_response_server/gui/api_plugins/user.py b/grr/server/grr_response_server/gui/api_plugins/user.py index b31eaec6d5..d4a0f1346f 100644 --- a/grr/server/grr_response_server/gui/api_plugins/user.py +++ b/grr/server/grr_response_server/gui/api_plugins/user.py @@ -31,6 +31,8 @@ from grr_response_server.gui.api_plugins import flow as api_flow from grr_response_server.gui.api_plugins import hunt as api_hunt from grr_response_server.models import users +from grr_response_server.rdfvalues import mig_flow_objects +from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects @@ -457,6 +459,7 @@ def InitFromDatabaseObject(self, db_obj, approval_subject_obj=None): if not approval_subject_obj: approval_subject_obj = data_store.REL_DB.ReadHuntObject(db_obj.subject_id) + approval_subject_obj = mig_hunt_objects.ToRDFHunt(approval_subject_obj) approval_subject_counters = data_store.REL_DB.ReadHuntCounters( db_obj.subject_id) self.subject = api_hunt.ApiHunt().InitFromHuntObject( @@ -469,11 +472,13 @@ def InitFromDatabaseObject(self, db_obj, approval_subject_obj=None): original_flow = data_store.REL_DB.ReadFlowObject( original_object.flow_reference.client_id, original_object.flow_reference.flow_id) + original_flow = mig_flow_objects.ToRDFFlow(original_flow) self.copied_from_flow = api_flow.ApiFlow().InitFromFlowObject( original_flow) elif original_object.object_type == "HUNT_REFERENCE": original_hunt = data_store.REL_DB.ReadHuntObject( original_object.hunt_reference.hunt_id) + original_hunt = mig_hunt_objects.ToRDFHunt(original_hunt) original_hunt_counters = data_store.REL_DB.ReadHuntCounters( original_object.hunt_reference.hunt_id) self.copied_from_hunt = api_hunt.ApiHunt().InitFromHuntObject( diff --git a/grr/server/grr_response_server/gui/api_plugins/vfs.py b/grr/server/grr_response_server/gui/api_plugins/vfs.py index 993939a201..5f2930425f 100644 --- a/grr/server/grr_response_server/gui/api_plugins/vfs.py +++ b/grr/server/grr_response_server/gui/api_plugins/vfs.py @@ -34,6 +34,8 @@ from grr_response_server.gui import api_call_context from grr_response_server.gui import api_call_handler_base from grr_response_server.gui.api_plugins import client +from grr_response_server.rdfvalues import mig_flow_objects +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects # Files can only be accessed if their first path component is from this list. 
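The vfs.py hunks that follow repeat the conversion pattern applied throughout this change: REL_DB read methods now hand back protobuf messages, and each call site converts them to the legacy RDF types with the mig_* helpers (and converts RDF values back to protos before writing). A minimal sketch of both directions, using only calls that appear in this diff and assuming a configured REL_DB data store; the helper names read_flow_as_rdf and write_path_info are illustrative, not part of the change:

```python
from grr_response_server import data_store
from grr_response_server.rdfvalues import mig_flow_objects
from grr_response_server.rdfvalues import mig_objects


def read_flow_as_rdf(client_id: str, flow_id: str):
  # ReadFlowObject now returns a flows_pb2.Flow proto; convert it before
  # using RDF-only conveniences such as ApiFlow().InitFromFlowObject().
  proto_flow = data_store.REL_DB.ReadFlowObject(client_id, flow_id)
  return mig_flow_objects.ToRDFFlow(proto_flow)


def write_path_info(client_id: str, rdf_path_info):
  # WritePathInfos now expects objects_pb2.PathInfo messages, so RDF values
  # built by callers are converted on the way in.
  data_store.REL_DB.WritePathInfos(
      client_id, [mig_objects.ToProtoPathInfo(rdf_path_info)]
  )
```

The same shape applies to hunts (ReadHuntObject followed by mig_hunt_objects.ToRDFHunt) and to path-info reads (ReadPathInfo followed by mig_objects.ToRDFPathInfo), as the hunks below show.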
@@ -303,24 +305,30 @@ def Handle(self, args, context=None): client_id = str(args.client_id) try: - path_info = data_store.REL_DB.ReadPathInfo( + proto_path_info = data_store.REL_DB.ReadPathInfo( client_id=client_id, path_type=path_type, components=components, - timestamp=args.timestamp) + timestamp=args.timestamp, + ) except db.UnknownPathError: raise FileNotFoundError( client_id=client_id, path_type=path_type, components=components) + path_info = None + if proto_path_info is not None: + path_info = mig_objects.ToRDFPathInfo(proto_path_info) last_collection_pi = file_store.GetLastCollectionPathInfo( db.ClientPath.FromPathInfo(client_id, path_info), max_timestamp=args.timestamp) - history = data_store.REL_DB.ReadPathInfoHistory( + proto_history = data_store.REL_DB.ReadPathInfoHistory( client_id=client_id, path_type=path_type, components=components, - cutoff=args.timestamp) + cutoff=args.timestamp, + ) + history = [mig_objects.ToRDFPathInfo(pi) for pi in proto_history] history.reverse() # It might be the case that we do not have any history about the file, but @@ -471,18 +479,21 @@ def Handle(self, args, context=None): # TODO: This API handler should return a 404 response if the # path is not found. Currently, 500 is returned. - child_path_infos = data_store.REL_DB.ListChildPathInfos( + proto_child_path_infos = data_store.REL_DB.ListChildPathInfos( client_id=args.client_id.ToString(), path_type=path_type, components=components, - timestamp=args.timestamp) + timestamp=args.timestamp, + ) items = [] - for child_path_info in child_path_infos: + for child_path_info in proto_child_path_infos: if args.directories_only and not child_path_info.directory: continue - items.append(_PathInfoToApiFile(child_path_info)) + items.append( + _PathInfoToApiFile(mig_objects.ToRDFPathInfo(child_path_info)) + ) # TODO(hanuszczak): Instead of getting the whole list from the database and # then filtering the results we should do the filtering directly in the @@ -577,7 +588,8 @@ def _MergePathInfos(self, path_infos: Dict[str, rdf_objects.PathInfo], cur_path_infos: Collection[rdf_objects.PathInfo]) -> None: """Merges PathInfos from different PathTypes (OS, TSK, NTFS).""" - for pi in cur_path_infos: + for proto_pi in cur_path_infos: + pi = mig_objects.ToRDFPathInfo(proto_pi) existing = path_infos.get(pi.basename) # If the VFS has the same file in two PathTypes, use the latest collected # version. @@ -774,8 +786,10 @@ def Handle(self, args, context=None): # empty response. 
return ApiGetFileVersionTimesResult(times=[]) - history = data_store.REL_DB.ReadPathInfoHistory( - str(args.client_id), path_type, components) + proto_history = data_store.REL_DB.ReadPathInfoHistory( + str(args.client_id), path_type, components + ) + history = [mig_objects.ToRDFPathInfo(pi) for pi in proto_history] times = reversed([pi.timestamp for pi in history]) return ApiGetFileVersionTimesResult(times=times) @@ -872,6 +886,7 @@ def _FindPathspec(self, args): path_info = res[k] if path_info is None: raise FileNotFoundError(args.client_id, path_type, components) + path_info = mig_objects.ToRDFPathInfo(path_info) if path_info.stat_entry and path_info.stat_entry.pathspec: ps = path_info.stat_entry.pathspec @@ -945,16 +960,19 @@ def _RaiseOperationNotFoundError(self, args): def Handle(self, args, context=None): try: - rdf_flow = data_store.REL_DB.ReadFlowObject( - str(args.client_id), str(args.operation_id)) + flow_obj = data_store.REL_DB.ReadFlowObject( + str(args.client_id), str(args.operation_id) + ) except db.UnknownFlowError: self._RaiseOperationNotFoundError(args) - if rdf_flow.flow_class_name not in [ - "RecursiveListDirectory", "ListDirectory" + if flow_obj.flow_class_name not in [ + "RecursiveListDirectory", + "ListDirectory", ]: self._RaiseOperationNotFoundError(args) + rdf_flow = mig_flow_objects.ToRDFFlow(flow_obj) complete = rdf_flow.flow_state != "RUNNING" result = ApiGetVfsRefreshOperationStateResult() if complete: @@ -972,35 +990,45 @@ def _GetTimelineStatEntries(api_client_id, file_path, with_history=True): client_id = str(api_client_id) try: - root_path_info = data_store.REL_DB.ReadPathInfo(client_id, path_type, - components) + proto_root_path_info = data_store.REL_DB.ReadPathInfo( + client_id, path_type, components + ) except db.UnknownPathError: return path_infos = [] - for path_info in itertools.chain( - [root_path_info], - data_store.REL_DB.ListDescendantPathInfos(client_id, path_type, - components), + for proto_path_info in itertools.chain( + [proto_root_path_info], + data_store.REL_DB.ListDescendantPathInfos( + client_id, path_type, components + ), ): + path_info = mig_objects.ToRDFPathInfo(proto_path_info) + # TODO(user): this is to keep the compatibility with current # AFF4 implementation. Check if this check is needed. 
if path_info.directory: continue - categorized_path = rdf_objects.ToCategorizedPath(path_info.path_type, - path_info.components) + categorized_path = rdf_objects.ToCategorizedPath( + path_info.path_type, path_info.components + ) if with_history: path_infos.append(path_info) else: yield categorized_path, path_info.stat_entry, path_info.hash_entry if with_history: - hist_path_infos = data_store.REL_DB.ReadPathInfosHistories( - client_id, path_type, [tuple(pi.components) for pi in path_infos]) - for path_info in itertools.chain.from_iterable(hist_path_infos.values()): - categorized_path = rdf_objects.ToCategorizedPath(path_info.path_type, - path_info.components) + proto_hist_path_infos = data_store.REL_DB.ReadPathInfosHistories( + client_id, path_type, [tuple(pi.components) for pi in path_infos] + ) + for proto_path_info in itertools.chain.from_iterable( + proto_hist_path_infos.values() + ): + path_info = mig_objects.ToRDFPathInfo(proto_path_info) + categorized_path = rdf_objects.ToCategorizedPath( + path_info.path_type, path_info.components + ) yield categorized_path, path_info.stat_entry, path_info.hash_entry @@ -1191,8 +1219,10 @@ class ApiUpdateVfsFileContentHandler(api_call_handler_base.ApiCallHandler): def Handle(self, args, context=None): path_type, components = rdf_objects.ParseCategorizedPath(args.file_path) - path_info = data_store.REL_DB.ReadPathInfo( - str(args.client_id), path_type, components) + proto_path_info = data_store.REL_DB.ReadPathInfo( + str(args.client_id), path_type, components + ) + path_info = mig_objects.ToRDFPathInfo(proto_path_info) if (not path_info or not path_info.stat_entry or not path_info.stat_entry.pathspec): @@ -1234,17 +1264,19 @@ class ApiGetVfsFileContentUpdateStateHandler( def Handle(self, args, context=None): try: - rdf_flow = data_store.REL_DB.ReadFlowObject( - str(args.client_id), str(args.operation_id)) + proto_flow = data_store.REL_DB.ReadFlowObject( + str(args.client_id), str(args.operation_id) + ) except db.UnknownFlowError: raise VfsFileContentUpdateNotFoundError("Operation with id %s not found" % args.operation_id) - if rdf_flow.flow_class_name != "MultiGetFile": + if proto_flow.flow_class_name != "MultiGetFile": raise VfsFileContentUpdateNotFoundError("Operation with id %s not found" % args.operation_id) result = ApiGetVfsFileContentUpdateStateResult() + rdf_flow = mig_flow_objects.ToRDFFlow(proto_flow) if rdf_flow.flow_state == "RUNNING": result.state = ApiGetVfsFileContentUpdateStateResult.State.RUNNING else: diff --git a/grr/server/grr_response_server/gui/api_plugins/vfs_regression_test.py b/grr/server/grr_response_server/gui/api_plugins/vfs_regression_test.py index 6cab3925a2..87ad8bd53f 100644 --- a/grr/server/grr_response_server/gui/api_plugins/vfs_regression_test.py +++ b/grr/server/grr_response_server/gui/api_plugins/vfs_regression_test.py @@ -13,6 +13,7 @@ from grr_response_server.gui import api_regression_test_lib from grr_response_server.gui.api_plugins import vfs as vfs_plugin from grr_response_server.gui.api_plugins import vfs_test as vfs_plugin_test +from grr_response_server.rdfvalues import mig_flow_objects from grr.test_lib import acl_test_lib from grr.test_lib import fixture_test_lib from grr.test_lib import flow_test_lib @@ -165,13 +166,15 @@ def Run(self): creator=self.test_username) # Kill flow. 
- rdf_flow = data_store.REL_DB.LeaseFlowForProcessing( - client_id, finished_flow_id, - rdfvalue.Duration.From(5, rdfvalue.MINUTES)) + proto_flow = data_store.REL_DB.LeaseFlowForProcessing( + client_id, finished_flow_id, rdfvalue.Duration.From(5, rdfvalue.MINUTES) + ) + rdf_flow = mig_flow_objects.ToRDFFlow(proto_flow) flow_cls = registry.FlowRegistry.FlowClassByName(rdf_flow.flow_class_name) flow_obj = flow_cls(rdf_flow) flow_obj.Error("Fake error") - data_store.REL_DB.ReleaseProcessedFlow(rdf_flow) + proto_flow = mig_flow_objects.ToProtoFlow(rdf_flow) + data_store.REL_DB.ReleaseProcessedFlow(proto_flow) # Create an arbitrary flow to check on 404s. non_refresh_flow_id = flow_test_lib.StartFlow( diff --git a/grr/server/grr_response_server/gui/api_plugins/vfs_test.py b/grr/server/grr_response_server/gui/api_plugins/vfs_test.py index ea5b0d55ed..9393091ddb 100644 --- a/grr/server/grr_response_server/gui/api_plugins/vfs_test.py +++ b/grr/server/grr_response_server/gui/api_plugins/vfs_test.py @@ -14,7 +14,6 @@ from grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import crypto as rdf_crypto from grr_response_core.lib.rdfvalues import paths as rdf_paths - from grr_response_server import data_store from grr_response_server import decoders from grr_response_server import flow @@ -25,8 +24,9 @@ from grr_response_server.flows.general import transfer from grr_response_server.gui import api_test_lib from grr_response_server.gui.api_plugins import vfs as vfs_plugin +from grr_response_server.rdfvalues import mig_flow_objects +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects - from grr.test_lib import fixture_test_lib from grr.test_lib import flow_test_lib from grr.test_lib import notification_test_lib @@ -503,10 +503,13 @@ def testRaisesOnNonExistentPath(self): def testRaisesOnExistingPathWithoutContent(self): path_info = rdf_objects.PathInfo.OS(components=["foo", "bar"]) - data_store.REL_DB.WritePathInfos(self.client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + self.client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) args = vfs_plugin.ApiGetFileBlobArgs( - client_id=self.client_id, file_path="fs/os/foo/bar") + client_id=self.client_id, file_path="fs/os/foo/bar" + ) with self.assertRaises(vfs_plugin.FileContentNotFoundError) as context: self.handler.Handle(args, context=self.context) @@ -759,6 +762,7 @@ def _testPathTranslation(self, directory: str, result = self.handler.Handle(args, context=self.context) flow_obj = data_store.REL_DB.ReadFlowObject(self.client_id, result.operation_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) self.assertEqual(flow_obj.args.pathspec, expected_pathspec) @@ -938,11 +942,14 @@ def SetupFileMetadata(self, client_id, vfs_path, stat_entry, hash_entry): else: path_info = rdf_objects.PathInfo.OS(components=vfs_path.split("/")) path_info.hash_entry = hash_entry - data_store.REL_DB.WritePathInfos(client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) -class ApiGetVfsTimelineAsCsvHandlerTest(api_test_lib.ApiCallHandlerTest, - VfsTimelineTestMixin): +class ApiGetVfsTimelineAsCsvHandlerTest( + api_test_lib.ApiCallHandlerTest, VfsTimelineTestMixin +): def setUp(self): super().setUp() diff --git a/grr/server/grr_response_server/gui/archive_generator_test.py b/grr/server/grr_response_server/gui/archive_generator_test.py index aa3ee31fba..5adce7d679 100644 --- 
a/grr/server/grr_response_server/gui/archive_generator_test.py +++ b/grr/server/grr_response_server/gui/archive_generator_test.py @@ -18,6 +18,8 @@ from grr_response_server.databases import db from grr_response_server.gui import archive_generator from grr_response_server.models import blobs +from grr_response_server.rdfvalues import mig_flow_objects +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr.test_lib import flow_test_lib from grr.test_lib import test_lib @@ -48,7 +50,9 @@ def _CreateFile(self, client_id, vfs_path, content): db.ClientPath.FromPathInfo(client_id, path_info), [blob_ref]) path_info.hash_entry.sha256 = hash_id.AsBytes() - data_store.REL_DB.WritePathInfos(client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) def _InitializeFiles(self): path1 = "fs/os/foo/bar/hello1.txt" @@ -263,6 +267,7 @@ def setUp(self): self.flow_id = flow_test_lib.StartFlow( flow_test_lib.DummyFlow, client_id=self.client_id) self.flow = data_store.REL_DB.ReadFlowObject(self.client_id, self.flow_id) + self.flow = mig_flow_objects.ToRDFFlow(self.flow) self.path1 = db.ClientPath.OS(self.client_id, ["foo", "bar", "hello1.txt"]) self.path1_content = "hello1".encode("utf-8") diff --git a/grr/server/grr_response_server/gui/gui_test_lib.py b/grr/server/grr_response_server/gui/gui_test_lib.py index 1553882529..906388322e 100644 --- a/grr/server/grr_response_server/gui/gui_test_lib.py +++ b/grr/server/grr_response_server/gui/gui_test_lib.py @@ -41,6 +41,7 @@ from grr_response_server.gui import webauth from grr_response_server.gui import wsgiapp_testlib from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner +from grr_response_server.rdfvalues import mig_objects from grr_response_server.rdfvalues import objects as rdf_objects from grr.test_lib import acl_test_lib from grr.test_lib import action_mocks @@ -118,7 +119,9 @@ def CreateFolder(client_id, path, timestamp): path_info.components = components path_info.directory = True - data_store.REL_DB.WritePathInfos(client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) def SeleniumAction(f): diff --git a/grr/server/grr_response_server/gui/http_api.py b/grr/server/grr_response_server/gui/http_api.py index 199f0bd8ad..f680d431df 100644 --- a/grr/server/grr_response_server/gui/http_api.py +++ b/grr/server/grr_response_server/gui/http_api.py @@ -16,6 +16,7 @@ from werkzeug import routing from google.protobuf import json_format +from google.protobuf import message from grr_response_core import config from grr_response_core.lib import rdfvalue from grr_response_core.lib import serialization @@ -303,7 +304,18 @@ def _FormatResultAsJson(self, result, format_mode=None): def CallApiHandler(handler, args, context=None): """Handles API call to a given handler with given args and context.""" - result = handler.Handle(args, context=context) + if handler.proto_args_type: + result = handler.Handle(args.AsPrimitiveProto(), context=context) + else: + result = handler.Handle(args, context=context) + + # TODO: Once all ApiCallHandlers are migrated this code + # can likely be moved to the result json formatter (legacy UI formatting + # currently relies on RDFValue). 
+ if isinstance(result, message.Message): + rdf_cls = handler.result_type + proto_bytes = result.SerializeToString() + result = rdf_cls.FromSerializedBytes(proto_bytes) expected_type = handler.result_type if expected_type is None: diff --git a/grr/server/grr_response_server/gui/http_request.py b/grr/server/grr_response_server/gui/http_request.py new file mode 100644 index 0000000000..6b30cfa51d --- /dev/null +++ b/grr/server/grr_response_server/gui/http_request.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +"""HTTP request wrapper.""" + +from werkzeug import wrappers as werkzeug_wrappers + +from grr_response_core.lib import rdfvalue + + +class RequestHasNoUserError(AttributeError): + """Error raised when accessing a user of an unauthenticated request.""" + + +class HttpRequest(werkzeug_wrappers.Request): + """HTTP request object to be used in GRR.""" + + charset = "utf-8" + encoding_errors = "strict" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._user = None + self.email = None + + self.timestamp = rdfvalue.RDFDatetime.Now().AsMicrosecondsSinceEpoch() + + self.method_metadata = None + self.parsed_args = None + + @property + def user(self): + if self._user is None: + raise RequestHasNoUserError( + "Trying to access Request.user while user is unset." + ) + + if not self._user: + raise RequestHasNoUserError( + "Trying to access Request.user while user is empty." + ) + + return self._user + + @user.setter + def user(self, value): + if not isinstance(value, str): + message = "Expected instance of '%s' but got value '%s' of type '%s'" + message %= (str, value, type(value)) + raise TypeError(message) + + self._user = value diff --git a/grr/server/grr_response_server/gui/selenium_tests/flow_copy_test.py b/grr/server/grr_response_server/gui/selenium_tests/flow_copy_test.py index 52f6b8a2ac..5f0688b10f 100644 --- a/grr/server/grr_response_server/gui/selenium_tests/flow_copy_test.py +++ b/grr/server/grr_response_server/gui/selenium_tests/flow_copy_test.py @@ -13,6 +13,7 @@ from grr_response_server.gui import api_call_router_with_approval_checks from grr_response_server.gui import gui_test_lib from grr_response_server.output_plugins import email_plugin +from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import output_plugin as rdf_output_plugin from grr.test_lib import fixture_test_lib from grr.test_lib import flow_test_lib @@ -157,6 +158,7 @@ def testUserChangesToCopiedFlowAreRespected(self): # Now open the last flow and check that it has the changes we made.
flows = data_store.REL_DB.ReadAllFlowObjects(client_id=self.client_id) + flows = [mig_flow_objects.ToRDFFlow(f) for f in flows] flows.sort(key=lambda f: f.create_time) fobj = flows[-1] diff --git a/grr/server/grr_response_server/gui/selenium_tests/hunt_acls_test.py b/grr/server/grr_response_server/gui/selenium_tests/hunt_acls_test.py index 8a4cec8c67..a38c0c7871 100644 --- a/grr/server/grr_response_server/gui/selenium_tests/hunt_acls_test.py +++ b/grr/server/grr_response_server/gui/selenium_tests/hunt_acls_test.py @@ -362,8 +362,9 @@ def testFlowDiffIsShownIfHuntCreatedFromFlow(self): "css=tr.diff-changed:contains('Action'):contains('DOWNLOAD')") self.WaitUntil( self.IsElementPresent, - "css=tr:not(:contains('Args')):contains('Conditions')" - ":has('.diff-added'):contains('Size'):contains('42')") + "css=tr.diff-added:not(:contains('Args')):contains('Conditions')" + ":contains('Size'):contains('42')", + ) def testOriginalFlowLinkIsShownIfHuntCreatedFromFlow(self): h_id, flow_id = self._CreateHuntFromFlow() diff --git a/grr/server/grr_response_server/gui/selenium_tests/hunt_copy_test.py b/grr/server/grr_response_server/gui/selenium_tests/hunt_copy_test.py index ad3cc26df3..030ee65d67 100644 --- a/grr/server/grr_response_server/gui/selenium_tests/hunt_copy_test.py +++ b/grr/server/grr_response_server/gui/selenium_tests/hunt_copy_test.py @@ -11,6 +11,7 @@ from grr_response_server.flows.general import transfer from grr_response_server.gui import gui_test_lib from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner +from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import output_plugin as rdf_output_plugin from grr.test_lib import test_lib @@ -250,6 +251,7 @@ def testCopyHuntRespectsUserChanges(self): self.assertLen(hunts_list, 2) last_hunt = hunts_list[-1] + last_hunt = mig_hunt_objects.ToRDFHunt(last_hunt) args = last_hunt.args.standard.flow_args.Unpack(transfer.GetFileArgs) self.assertEqual(args.pathspec.path, "/tmp/very-evil.txt") @@ -340,6 +342,7 @@ def testCopyHuntHandlesLiteralExpressionCorrectly(self): self.assertLen(hunts_list, 2) last_hunt = hunts_list[-1] + last_hunt = mig_hunt_objects.ToRDFHunt(last_hunt) # Check that the hunt was created with a correct literal value. 
self.assertEqual(last_hunt.args.standard.flow_name, @@ -445,6 +448,7 @@ def testRuleTypeChangeClearsItsProto(self): self.assertLen(hunts_list, 1) hunt = hunts_list[0] + hunt = mig_hunt_objects.ToRDFHunt(hunt) # Check that the hunt was created with correct rules rules = hunt.client_rule_set.rules diff --git a/grr/server/grr_response_server/gui/selenium_tests/hunt_create_test.py b/grr/server/grr_response_server/gui/selenium_tests/hunt_create_test.py index 712dd8f2ea..64580069a7 100644 --- a/grr/server/grr_response_server/gui/selenium_tests/hunt_create_test.py +++ b/grr/server/grr_response_server/gui/selenium_tests/hunt_create_test.py @@ -16,6 +16,7 @@ from grr_response_server.flows.general import transfer from grr_response_server.gui import gui_test_lib from grr_response_server.rdfvalues import flow_runner as rdf_flow_runner +from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import output_plugin as rdf_output_plugin from grr.test_lib import test_lib @@ -262,7 +263,7 @@ def testNewHuntWizard(self): # Check that the hunt was created with a correct flow hunt = hunts_list[0] - + hunt = mig_hunt_objects.ToRDFHunt(hunt) self.assertEqual(hunt.args.standard.flow_name, file_finder.FileFinder.__name__) @@ -396,6 +397,7 @@ def testLiteralExpressionIsProcessedCorrectly(self): # Check that the hunt was created with a correct literal value. hunt = hunts_list[0] + hunt = mig_hunt_objects.ToRDFHunt(hunt) self.assertEqual(hunt.args.standard.flow_name, file_finder.FileFinder.__name__) diff --git a/grr/server/grr_response_server/gui/selenium_tests/hunt_view_test.py b/grr/server/grr_response_server/gui/selenium_tests/hunt_view_test.py index 287ec7d81f..ac01a49a3b 100644 --- a/grr/server/grr_response_server/gui/selenium_tests/hunt_view_test.py +++ b/grr/server/grr_response_server/gui/selenium_tests/hunt_view_test.py @@ -11,6 +11,7 @@ from grr_response_server import hunt from grr_response_server.gui import gui_test_lib from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr.test_lib import test_lib @@ -123,11 +124,10 @@ def testHuntOverviewShowsStats(self): client_id=client_id, flow_id=hunt_id, parent_hunt_id=hunt_id, - create_time=rdfvalue.RDFDatetime.Now(), ) rdf_flow.cpu_time_used.user_cpu_time = 5000 rdf_flow.network_bytes_sent = 1000000 - data_store.REL_DB.WriteFlowObject(rdf_flow) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(rdf_flow)) # Open up and click on View Hunts then the first Hunt. self.Open("/legacy") @@ -150,11 +150,10 @@ def testHuntOverviewGetsUpdatedWhenHuntChanges(self): client_id=client_id, flow_id=hunt_id, parent_hunt_id=hunt_id, - create_time=rdfvalue.RDFDatetime.Now(), ) rdf_flow.cpu_time_used.user_cpu_time = 5000 rdf_flow.network_bytes_sent = 1000000 - data_store.REL_DB.WriteFlowObject(rdf_flow) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(rdf_flow)) self.Open("/legacy") # Ensure auto-refresh updates happen every second. 
@@ -173,11 +172,10 @@ def testHuntOverviewGetsUpdatedWhenHuntChanges(self): client_id=client_id, flow_id=hunt_id, parent_hunt_id=hunt_id, - create_time=rdfvalue.RDFDatetime.Now(), ) rdf_flow.cpu_time_used.user_cpu_time = 1000 rdf_flow.network_bytes_sent = 10000000 - data_store.REL_DB.WriteFlowObject(rdf_flow) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(rdf_flow)) self.WaitUntil(self.IsTextPresent, "1h 40m") self.WaitUntil(self.IsTextPresent, "10.5MiB") diff --git a/grr/server/grr_response_server/gui/static/angular-components/docs/api-docs-examples.json b/grr/server/grr_response_server/gui/static/angular-components/docs/api-docs-examples.json index 700c245866..447e8db684 100644 --- a/grr/server/grr_response_server/gui/static/angular-components/docs/api-docs-examples.json +++ b/grr/server/grr_response_server/gui/static/angular-components/docs/api-docs-examples.json @@ -8867,7 +8867,7 @@ }, "session_id": { "type": "SessionID", - "value": "" + "value": "aff4:/C.1000000000000000/H:123456" } } } @@ -9090,7 +9090,7 @@ "user_cpu_time": 1.0 }, "network_bytes_sent": 3, - "session_id": "" + "session_id": "aff4:/C.1000000000000000/H:123456" } ] } @@ -11454,7 +11454,6 @@ "type": "ArtifactName", "value": "TestDrivers" }, - "provides": [], "sources": [ { "type": "ArtifactSource", @@ -11512,7 +11511,6 @@ "artifact": { "doc": "Extract the installed drivers on Windows via WMI.", "name": "TestDrivers", - "provides": [], "sources": [ { "attributes": { diff --git a/grr/server/grr_response_server/gui/static/angular-components/docs/api-v2-docs-examples.json b/grr/server/grr_response_server/gui/static/angular-components/docs/api-v2-docs-examples.json index 750a1830b2..ee8bdb3f8d 100644 --- a/grr/server/grr_response_server/gui/static/angular-components/docs/api-v2-docs-examples.json +++ b/grr/server/grr_response_server/gui/static/angular-components/docs/api-v2-docs-examples.json @@ -2685,7 +2685,7 @@ "userCpuTime": 1.0 }, "networkBytesSent": "3", - "sessionId": "" + "sessionId": "aff4:/C.1000000000000000/H:123456" } ] } diff --git a/grr/server/grr_response_server/gui/ui/components/approval_card/approval_card.ng.html b/grr/server/grr_response_server/gui/ui/components/approval_card/approval_card.ng.html index c5f5a7c732..71b8fd1297 100644 --- a/grr/server/grr_response_server/gui/ui/components/approval_card/approval_card.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/approval_card/approval_card.ng.html @@ -205,6 +205,7 @@

Access approval

[disabled]="submitDisabled$ | async" class="progress-spinner-button" [matTooltip]="'CTRL/⌘ + ENTER'" + name="request-access" >
diff --git a/grr/server/grr_response_server/gui/ui/components/approval_card/approval_card_test.ts b/grr/server/grr_response_server/gui/ui/components/approval_card/approval_card_test.ts index 8888332b05..33571dbea6 100644 --- a/grr/server/grr_response_server/gui/ui/components/approval_card/approval_card_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/approval_card/approval_card_test.ts @@ -5,7 +5,12 @@ import {By} from '@angular/platform-browser'; import {NoopAnimationsModule} from '@angular/platform-browser/animations'; import {Router} from '@angular/router'; import {RouterTestingModule} from '@angular/router/testing'; +import { + RequestStatusType, + type RequestStatus, +} from '../../lib/api/track_request'; +import {newApproval} from '../../lib/models/model_test_util'; import {Approval} from '../../lib/models/user'; import {ApprovalCardLocalStore} from '../../store/approval_card_local_store'; import { @@ -70,9 +75,11 @@ describe('ApprovalCard Component', () => { validateOnStart = false, showDuration = false, editableDuration = false, + requestApprovalStatus: RequestStatus | null = null, ): ComponentFixture { const fixture = TestBed.createComponent(ApprovalCard); fixture.componentInstance.latestApproval = latestApproval; + fixture.componentInstance.requestApprovalStatus = requestApprovalStatus; fixture.componentInstance.urlTree = urlTree; fixture.componentInstance.validateOnStart = validateOnStart; fixture.componentInstance.showDuration = showDuration; @@ -453,4 +460,66 @@ describe('ApprovalCard Component', () => { const text = fixture.debugElement.nativeElement.textContent; expect(text).toContain('sample reason http://example.com'); }); + + it('displays spinner when request is pending', () => { + const fixture = createComponent(null, [], false, false, false, { + status: RequestStatusType.SENT, + }); + fixture.detectChanges(); + + const button = fixture.debugElement.query( + By.css("[name='request-access']"), + ); + expect(button.nativeElement.disabled).toBeTrue(); + + const buttonSpinner = fixture.debugElement.query( + By.css("[name='request-access'] mat-spinner"), + ); + expect(buttonSpinner).toBeTruthy(); + + const errorMsg = fixture.debugElement.query(By.css('mat-error')); + expect(errorMsg).toBeFalsy(); + }); + + it('displays error from status', () => { + const fixture = createComponent(null, [], false, false, false, { + status: RequestStatusType.ERROR, + error: 'bad request', + }); + fixture.detectChanges(); + + const button = fixture.debugElement.query( + By.css("[name='request-access']"), + ); + expect(button.nativeElement.disabled).toBeFalse(); // Enabled + + const buttonSpinner = fixture.debugElement.query( + By.css("[name='request-access'] mat-spinner"), + ); + expect(buttonSpinner).toBeFalsy(); + + const errorMsg = fixture.debugElement.query(By.css('mat-error')); + expect(errorMsg.nativeElement.innerText).toContain('bad request'); + }); + + it('displays success from status', () => { + const fixture = createComponent(null, [], false, false, false, { + status: RequestStatusType.SUCCESS, + data: newApproval(), + }); + fixture.detectChanges(); + + const button = fixture.debugElement.query( + By.css("[name='request-access']"), + ); + expect(button.nativeElement.disabled).toBeTrue(); + + const buttonSpinner = fixture.debugElement.query( + By.css("[name='request-access'] mat-spinner"), + ); + expect(buttonSpinner).toBeFalsy(); + + const errorMsg = fixture.debugElement.query(By.css('mat-error')); + expect(errorMsg).toBeFalsy(); + }); }); diff --git 
a/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page.ng.html b/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page.ng.html index f79229913d..ef4afc5eb2 100644 --- a/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page.ng.html @@ -52,7 +52,7 @@ > warning - This duration is longer then the default of {{defaultAccessDurationDays}} days. + This duration is longer than the default of {{defaultAccessDurationDays}} days. diff --git a/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page.ts b/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page.ts index 16db5decb9..56e31b0dca 100644 --- a/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page.ts +++ b/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page.ts @@ -110,6 +110,9 @@ export class ApprovalPage implements OnDestroy { watermark.setDate( watermark.getDate() + this.defaultAccessDurationDays, ); + // TODO: Use a daylight saving time resistant time + // measurement. + watermark.setHours(watermark.getHours() + 1); this.longExpiration = approval.expirationTime > watermark; }); diff --git a/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page_module.ts b/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page_module.ts index 685bc7c4d6..7a68f668db 100644 --- a/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page_module.ts +++ b/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page_module.ts @@ -10,7 +10,6 @@ import {RouterModule} from '@angular/router'; import {ClientDetailsModule} from '../client_details/module'; import {ClientOverviewModule} from '../client_overview/module'; -import {DrawerLinkModule} from '../helpers/drawer_link/drawer_link_module'; import {TextWithLinksModule} from '../helpers/text_with_links/text_with_links_module'; import {ScheduledFlowListModule} from '../scheduled_flow_list/module'; import {TimestampModule} from '../timestamp/module'; @@ -26,7 +25,6 @@ import {ApprovalPage} from './approval_page'; ClientDetailsModule, ClientOverviewModule, CommonModule, - DrawerLinkModule, MatButtonModule, MatCardModule, MatChipsModule, diff --git a/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page_test.ts b/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page_test.ts index 2ec2fdaddf..5cdf2d794a 100644 --- a/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/approval_page/approval_page_test.ts @@ -65,7 +65,8 @@ describe('ApprovalPage Component', () => { fixture.detectChanges(); const twentyEightDaysFromNow = new Date( - Date.now() + 28 * 24 * 60 * 60 * 1000 - 1000, + // 28 days minus 1 hour in ms. + Date.now() + 28 * 24 * 60 * 60 * 1000 - 1000 * 60 * 60, ); injectMockStore(ApprovalPageGlobalStore).mockedObservables.approval$.next( @@ -145,7 +146,7 @@ describe('ApprovalPage Component', () => { ); expect(timestampChip).not.toBeNull(); expect(timestampChip.textContent).toEqual( - 'warning This duration is longer then the default of 28 days. ', + 'warning This duration is longer than the default of 28 days. 
', ); }); diff --git a/grr/server/grr_response_server/gui/ui/components/client_overview/client_overview.ts b/grr/server/grr_response_server/gui/ui/components/client_overview/client_overview.ts index ef09e472c4..66641c22ae 100644 --- a/grr/server/grr_response_server/gui/ui/components/client_overview/client_overview.ts +++ b/grr/server/grr_response_server/gui/ui/components/client_overview/client_overview.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import { ChangeDetectionStrategy, Component, @@ -73,13 +72,14 @@ export class ClientOverview implements OnInit, OnDestroy { readonly activeOnlineNotificationArgs$ = this.clientPageGlobalStore.flowListEntries$.pipe( withLatestFrom(this.userGlobalStore.currentUser$), - map(([data, user]) => - data.flows?.find( - (f) => - f.name === 'OnlineNotification' && - f.creator === user.name && - f.state === FlowState.RUNNING, - ), + map( + ([data, user]) => + data.flows?.find( + (f) => + f.name === 'OnlineNotification' && + f.creator === user.name && + f.state === FlowState.RUNNING, + ), ), map((flow) => flow?.args as OnlineNotificationArgs | undefined), ); diff --git a/grr/server/grr_response_server/gui/ui/components/client_page/client_page_module.ts b/grr/server/grr_response_server/gui/ui/components/client_page/client_page_module.ts index 494013a484..5d0f1aadba 100644 --- a/grr/server/grr_response_server/gui/ui/components/client_page/client_page_module.ts +++ b/grr/server/grr_response_server/gui/ui/components/client_page/client_page_module.ts @@ -16,7 +16,6 @@ import {FlowListModule} from '../../components/flow_list/module'; import {ApprovalCardModule} from '../approval_card/module'; import {ClientOverviewModule} from '../client_overview/module'; import {FileDetailsModule} from '../file_details/file_details_module'; -import {DrawerLinkModule} from '../helpers/drawer_link/drawer_link_module'; import {HumanReadableSizeModule} from '../human_readable_size/module'; import {ScheduledFlowListModule} from '../scheduled_flow_list/module'; import {TimestampModule} from '../timestamp/module'; @@ -37,7 +36,6 @@ import {VfsSection} from './vfs_section'; ApprovalCardModule, BrowserAnimationsModule, ClientOverviewModule, - DrawerLinkModule, FileDetailsModule, FlowFormModule, FlowListModule, diff --git a/grr/server/grr_response_server/gui/ui/components/client_page/flow_section.ng.html b/grr/server/grr_response_server/gui/ui/components/client_page/flow_section.ng.html index 4f8c2213d7..5b8ce97901 100644 --- a/grr/server/grr_response_server/gui/ui/components/client_page/flow_section.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/client_page/flow_section.ng.html @@ -10,6 +10,7 @@ { fixture.detectChanges(); const harnessLoader = TestbedHarnessEnvironment.loader(fixture); - const paginationHarness = await harnessLoader.getHarness( - MatPaginatorHarness, - ); + const paginationHarness = + await harnessLoader.getHarness(MatPaginatorHarness); const pageSize = await paginationHarness.getPageSize(); expect(resultSource.loadResults).toHaveBeenCalledOnceWith({ @@ -150,9 +148,8 @@ describe('FlowDataTableView component', () => { fixture.detectChanges(); const harnessLoader = TestbedHarnessEnvironment.loader(fixture); - const paginationHarness = await harnessLoader.getHarness( - MatPaginatorHarness, - ); + const paginationHarness = + await harnessLoader.getHarness(MatPaginatorHarness); await paginationHarness.goToNextPage(); const pageSize = await paginationHarness.getPageSize(); diff --git 
a/grr/server/grr_response_server/gui/ui/components/data_renderers/table/table_test.ts b/grr/server/grr_response_server/gui/ui/components/data_renderers/table/table_test.ts index 217d4f0f77..bc0916a545 100644 --- a/grr/server/grr_response_server/gui/ui/components/data_renderers/table/table_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/data_renderers/table/table_test.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {TestbedHarnessEnvironment} from '@angular/cdk/testing/testbed'; import {TestBed, waitForAsync} from '@angular/core/testing'; import {MatPaginatorHarness} from '@angular/material/paginator/testing'; @@ -133,9 +132,8 @@ describe('DataTableView component', () => { fixture.detectChanges(); const harnessLoader = TestbedHarnessEnvironment.loader(fixture); - const paginationHarness = await harnessLoader.getHarness( - MatPaginatorHarness, - ); + const paginationHarness = + await harnessLoader.getHarness(MatPaginatorHarness); const pageSize = await paginationHarness.getPageSize(); expect(resultSource.loadResults).toHaveBeenCalledOnceWith({ @@ -150,9 +148,8 @@ describe('DataTableView component', () => { fixture.detectChanges(); const harnessLoader = TestbedHarnessEnvironment.loader(fixture); - const paginationHarness = await harnessLoader.getHarness( - MatPaginatorHarness, - ); + const paginationHarness = + await harnessLoader.getHarness(MatPaginatorHarness); await paginationHarness.goToNextPage(); const pageSize = await paginationHarness.getPageSize(); diff --git a/grr/server/grr_response_server/gui/ui/components/file_details/file_details_module.ts b/grr/server/grr_response_server/gui/ui/components/file_details/file_details_module.ts index d800a3e5a1..5baca71a2a 100644 --- a/grr/server/grr_response_server/gui/ui/components/file_details/file_details_module.ts +++ b/grr/server/grr_response_server/gui/ui/components/file_details/file_details_module.ts @@ -8,7 +8,6 @@ import {BrowserAnimationsModule} from '@angular/platform-browser/animations'; import {RouterModule} from '@angular/router'; import {HexViewModule} from '../data_renderers/hex_view/hex_view_module'; -import {DrawerLinkModule} from '../helpers/drawer_link/drawer_link_module'; import {HumanReadableSizeModule} from '../human_readable_size/module'; import {TimestampModule} from '../timestamp/module'; @@ -21,7 +20,6 @@ import {FileDetailsPage} from './file_details_page'; // prettier-ignore // keep-sorted start block=yes BrowserAnimationsModule, - DrawerLinkModule, HexViewModule, HumanReadableSizeModule, MatButtonModule, diff --git a/grr/server/grr_response_server/gui/ui/components/flow_args_form/form_interface.ts b/grr/server/grr_response_server/gui/ui/components/flow_args_form/form_interface.ts index 41fdc089c6..1ad846a6ff 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_args_form/form_interface.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_args_form/form_interface.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {Component, OnDestroy} from '@angular/core'; import {AbstractControl, FormGroup} from '@angular/forms'; import {map} from 'rxjs/operators'; @@ -45,10 +44,10 @@ export declare type ControlValues< ? // For basic {key: FormControl()} mappings, the type is {key: X}. T[K]['value'] : T[K] extends undefined | infer C extends AbstractControl - ? // For optional {key?: FormControl()} mappings, the type is - // {key: X|undefined}. - C['value'] | undefined - : never; + ? 
// For optional {key?: FormControl()} mappings, the type is + // {key: X|undefined}. + C['value'] | undefined + : never; }; /** Form component to configure arguments for a Flow. */ diff --git a/grr/server/grr_response_server/gui/ui/components/flow_args_form/osquery_form_test.ts b/grr/server/grr_response_server/gui/ui/components/flow_args_form/osquery_form_test.ts index 75fc0018ff..6608be3672 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_args_form/osquery_form_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_args_form/osquery_form_test.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {TestbedHarnessEnvironment} from '@angular/cdk/testing/testbed'; import {TestBed, waitForAsync} from '@angular/core/testing'; import {ReactiveFormsModule} from '@angular/forms'; @@ -120,9 +119,8 @@ describe('OsqueryForm', () => { ); await expandButtonHarness.click(); - const collectionListHarness = await harnessLoader.getHarness( - MatChipGridHarness, - ); + const collectionListHarness = + await harnessLoader.getHarness(MatChipGridHarness); const inputHarness = await collectionListHarness.getInput(); await inputHarness?.setValue('column1'); @@ -148,9 +146,8 @@ describe('OsqueryForm', () => { ); await expandButtonHarness.click(); - const collectionListHarness = await harnessLoader.getHarness( - MatChipGridHarness, - ); + const collectionListHarness = + await harnessLoader.getHarness(MatChipGridHarness); const inputHarness = await collectionListHarness.getInput(); await inputHarness?.setValue('column1'); diff --git a/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form.ng.html b/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form.ng.html index d223fae3ad..39696173f8 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form.ng.html @@ -108,6 +108,26 @@
+
+ + Match context capture window + + + Context window must be a non-negative integer. Invalid: {{error.value}} + +
+
Skip memory regions: diff --git a/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form.ts b/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form.ts index 9aec196adf..67610dc45e 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form.ts @@ -1,5 +1,5 @@ import {ChangeDetectionStrategy, Component, ViewChild} from '@angular/core'; -import {FormControl} from '@angular/forms'; +import {FormControl, Validators} from '@angular/forms'; import { ControlValues, @@ -22,6 +22,10 @@ function makeControls() { processRegex: new FormControl('', {nonNullable: true}), cmdlineRegex: new FormControl('', {nonNullable: true}), pids: new FormControl([], {nonNullable: true}), + contextWindow: new FormControl(50, { + nonNullable: true, + validators: [Validators.min(0)], + }), skipSpecialRegions: new FormControl(false, {nonNullable: true}), skipMappedFiles: new FormControl(false, {nonNullable: true}), skipSharedRegions: new FormControl(false, {nonNullable: true}), @@ -79,6 +83,8 @@ export class YaraProcessScanForm extends FlowArgumentForm< flowArgs.yaraSignature ?? this.controls.yaraSignature.defaultValue, filterMode, pids: flowArgs.pids?.map(Number) ?? this.controls.pids.defaultValue, + contextWindow: + flowArgs.contextWindow ?? this.controls.contextWindow.defaultValue, processRegex: flowArgs.processRegex ?? this.controls.processRegex.defaultValue, cmdlineRegex: @@ -116,6 +122,7 @@ export class YaraProcessScanForm extends FlowArgumentForm< ? formState.cmdlineRegex : undefined, + contextWindow: formState.contextWindow, skipSpecialRegions: formState.skipSpecialRegions, skipMappedFiles: formState.skipMappedFiles, skipSharedRegions: formState.skipSharedRegions, diff --git a/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form_test.ts b/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form_test.ts index 1a0a69732f..1ae8170a7e 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_args_form/yara_process_scan_form_test.ts @@ -98,4 +98,24 @@ describe('YaraProcessScanForm', () => { }), ); }); + + it('shows a context window input field', async () => { + const fixture = TestBed.createComponent(YaraProcessScanForm); + fixture.detectChanges(); + + const latestValue = latestValueFrom(fixture.componentInstance.flowArgs$); + + const input = fixture.debugElement.query( + By.css('input[name=contextWindow]'), + ); + input.nativeElement.value = 999; + input.triggerEventHandler('input', {target: input.nativeElement}); + fixture.detectChanges(); + + expect(latestValue.get()).toEqual( + jasmine.objectContaining({ + contextWindow: 999, + }), + ); + }); }); diff --git a/grr/server/grr_response_server/gui/ui/components/flow_details/helpers/file_results_table.ng.html b/grr/server/grr_response_server/gui/ui/components/flow_details/helpers/file_results_table.ng.html index efe663a1cc..6b834786c4 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_details/helpers/file_results_table.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/flow_details/helpers/file_results_table.ng.html @@ -17,7 +17,7 @@ Path - {{ r.path }} + {{ r.path }} @@ -60,28 +60,28 @@ A-time - + M-time - + C-time - + B-time - + diff --git 
a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_files_by_known_path_details.ts b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_files_by_known_path_details.ts index 0a8a53b746..97c95736fb 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_files_by_known_path_details.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_files_by_known_path_details.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {ChangeDetectionStrategy, Component} from '@angular/core'; import {BehaviorSubject, Observable} from 'rxjs'; import {map, takeUntil} from 'rxjs/operators'; @@ -130,8 +129,9 @@ export class CollectFilesByKnownPathDetails extends Plugin { readonly fileResults$: Observable = this.flowResultsLocalStore.results$.pipe( - map((results) => - results?.map((data) => data.payload as CollectFilesByKnownPathResult), + map( + (results) => + results?.map((data) => data.payload as CollectFilesByKnownPathResult), ), ); diff --git a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_large_file_flow_details.ng.html b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_large_file_flow_details.ng.html index c1ee091445..9489e11332 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_large_file_flow_details.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_large_file_flow_details.ng.html @@ -38,4 +38,8 @@ + +
+ See the Collect Large File documentation for further details. +
diff --git a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_large_file_flow_details.scss b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_large_file_flow_details.scss index 88a007556a..a627b383b7 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_large_file_flow_details.scss +++ b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/collect_large_file_flow_details.scss @@ -22,3 +22,7 @@ overflow: hidden; white-space: nowrap; } + +.doc-reference-container { + padding: 16px 12px; +} diff --git a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/execute_python_hack_details.ts b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/execute_python_hack_details.ts index 8a633bae14..1bd2c651fd 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/execute_python_hack_details.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/execute_python_hack_details.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {ChangeDetectionStrategy, Component} from '@angular/core'; import {Observable} from 'rxjs'; import {filter, map} from 'rxjs/operators'; @@ -44,10 +43,11 @@ export class ExecutePythonHackDetails extends Plugin { ); readonly textContent$ = this.flowResultsLocalStore.results$.pipe( - map((results) => - ( - results[0]?.payload as ExecutePythonHackResult | undefined - )?.resultString?.split('\n'), + map( + (results) => + ( + results[0]?.payload as ExecutePythonHackResult | undefined + )?.resultString?.split('\n'), ), filter(isNonNull), ); diff --git a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/netstat_details.ts b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/netstat_details.ts index 2d3bddbad7..9ee71c3222 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/netstat_details.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/netstat_details.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import { ChangeDetectionStrategy, Component, @@ -66,11 +65,12 @@ export class NetstatDetails extends Plugin implements OnDestroy { readonly netstatResults$: Observable = this.flowResultsLocalStore.results$.pipe( - map((results) => - results?.map((data) => data.payload as NetworkConnection), + map( + (results) => results?.map((data) => data.payload as NetworkConnection), ), - map((connections) => - connections?.map((connection) => asConnectionRow(connection)), + map( + (connections) => + connections?.map((connection) => asConnectionRow(connection)), ), ); diff --git a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details.ng.html b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details.ng.html index f98f65ee71..c55a97d0bd 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details.ng.html @@ -37,6 +37,13 @@ + + Context (Base64) + + {{ r.context || 'empty' }} + + + diff --git a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details.ts b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details.ts index 3f02a4e898..1bd8a9ba61 
100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details.ts @@ -71,6 +71,7 @@ export class YaraProcessScanDetails extends Plugin { 'matchOffset', 'matchId', 'matchData', + 'context', ]; } @@ -104,6 +105,7 @@ function toRow( return { process: response.process, match: yaraMatch, + context: decodeBase64ToString(stringMatch.context ?? ''), stringMatch, data: decodeBase64ToString(stringMatch.data ?? ''), }; diff --git a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details_test.ts b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details_test.ts index fe34c8aa75..c817233705 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_details/plugins/yara_process_scan_details_test.ts @@ -79,6 +79,7 @@ describe('app-yara-process-scan-details component', () => { data: encodeStringToBase64('ExampleData'), stringId: 'ExampleStringId', offset: '456', + context: encodeStringToBase64('ExampleContext'), }, ], }, @@ -98,5 +99,6 @@ describe('app-yara-process-scan-details component', () => { expect(fixture.nativeElement.textContent).toContain('ExampleRuleName'); expect(fixture.nativeElement.textContent).toContain('ExampleData'); expect(fixture.nativeElement.textContent).toContain('ExampleStringId'); + expect(fixture.nativeElement.textContent).toContain('ExampleContext'); }); }); diff --git a/grr/server/grr_response_server/gui/ui/components/flow_picker/flow_picker_test.ts b/grr/server/grr_response_server/gui/ui/components/flow_picker/flow_picker_test.ts index c84e722843..27fc7187b5 100644 --- a/grr/server/grr_response_server/gui/ui/components/flow_picker/flow_picker_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/flow_picker/flow_picker_test.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {OverlayContainer} from '@angular/cdk/overlay'; import {TestbedHarnessEnvironment} from '@angular/cdk/testing/testbed'; import { @@ -213,8 +212,8 @@ describe('FlowPicker Component', () => { await autocompleteHarness.focus(); const links = overlayContainerElement.querySelectorAll('flows-overview a'); - const link = Array.from(links).find((l) => - l.textContent?.includes('Forensic artifacts'), + const link = Array.from(links).find( + (l) => l.textContent?.includes('Forensic artifacts'), ); assertNonNull(link); link.dispatchEvent(new MouseEvent('click')); diff --git a/grr/server/grr_response_server/gui/ui/components/form/date_time_input/date_time_input_test.ts b/grr/server/grr_response_server/gui/ui/components/form/date_time_input/date_time_input_test.ts index 8fcbeedd8c..d0f7f7e13a 100644 --- a/grr/server/grr_response_server/gui/ui/components/form/date_time_input/date_time_input_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/form/date_time_input/date_time_input_test.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {TestbedHarnessEnvironment} from '@angular/cdk/testing/testbed'; import {Component} from '@angular/core'; import { @@ -119,9 +118,8 @@ describe('DateTimeInput Component', () => { calendarButton.nativeElement.click(); fixture.detectChanges(); - const calendar = await loader.getHarness( - MatCalendarHarness, - ); + const calendar = + await 
loader.getHarness(MatCalendarHarness); const cells = await calendar.getCells(); await calendar.selectCell({text: await cells[0].getText()}); fixture.detectChanges(); @@ -168,9 +166,8 @@ describe('DateTimeInput Component', () => { calendarButton.nativeElement.click(); fixture.detectChanges(); - const calendar = await loader.getHarness( - MatCalendarHarness, - ); + const calendar = + await loader.getHarness(MatCalendarHarness); const cells = await calendar.getCells(); await calendar.selectCell({text: await cells[0].getText()}); fixture.detectChanges(); diff --git a/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.ng.html b/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.ng.html index 114adb1be7..fc65accbcc 100644 --- a/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.ng.html @@ -1,6 +1,6 @@ -
+
diff --git a/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.scss b/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.scss index 440b9186ba..16390d54c9 100644 --- a/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.scss +++ b/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.scss @@ -31,6 +31,10 @@ $transition-time: 200ms; } } + .multiline { + white-space: normal; + } + .icon-container { display: flex; align-items: center; diff --git a/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.ts b/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.ts index 1f40d2a3f0..d02d3d9aef 100644 --- a/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.ts +++ b/grr/server/grr_response_server/gui/ui/components/helpers/copy_button/copy_button.ts @@ -23,6 +23,8 @@ export class CopyButton { */ @Input() overrideCopyText: string | null | undefined = undefined; + @Input() multiline: boolean | null | undefined = undefined; + copied = false; constructor(private readonly clipboard: Clipboard) {} diff --git a/grr/server/grr_response_server/gui/ui/components/home/home.ts b/grr/server/grr_response_server/gui/ui/components/home/home.ts index ada639ccc2..b946fb2a18 100644 --- a/grr/server/grr_response_server/gui/ui/components/home/home.ts +++ b/grr/server/grr_response_server/gui/ui/components/home/home.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {ChangeDetectionStrategy, Component} from '@angular/core'; import {Title} from '@angular/platform-browser'; import {Router} from '@angular/router'; @@ -14,7 +13,10 @@ import {isClientId} from '../../lib/models/client'; changeDetection: ChangeDetectionStrategy.OnPush, }) export class Home { - constructor(private readonly router: Router, title: Title) { + constructor( + private readonly router: Router, + title: Title, + ) { title.setTitle('GRR'); } diff --git a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page.ng.html b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page.ng.html index 3ea3534a09..1e66208365 100644 --- a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page.ng.html @@ -12,58 +12,89 @@

{{ getHuntTitle(hunt$ | async) }}

- - - - - - - - - - - - - -
ID: - - {{ (hunt$ | async)?.huntId }} - -
Creator: - - - {{ (hunt$ | async)?.creator }} - -
Flow: - - {{ flowTitle$ | async }} - -
- - - - - - - - - - - - - - -
Resource usage -
Total CPU time: - {{ huntTotalCPU$ | async }} - -
Total network traffic: - - - -
+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
ID: + + {{ (hunt$ | async)?.huntId }} + +
Creator: + + + {{ (hunt$ | async)?.creator }} + +
Created: + +
Initial start time: + + never started +
Last start time: + + never started +
Flow: + + {{ flowTitle$ | async }} + +
+
+
+ + + + + + + + + + + + + +
Resource usage +
Total CPU time: + {{ huntTotalCPU$ | async }} + +
Total network traffic: + + + +
+
@@ -145,6 +176,7 @@

{{ getHuntTitle(hunt$ | async) }}

*ngIf="huntApprovalRequired$ | async" [urlTree]="(huntApprovalRoute$ | async) ?? []" [latestApproval]="latestApproval$ | async" + [requestApprovalStatus]="requestApprovalStatus$ | async" [hideContent]="(hideApprovalCardContentByDefault$ | async) === true" (approvalParams)="requestHuntApproval($event)"> diff --git a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page.ts b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page.ts index 2582ae07d5..8690a6b9ec 100644 --- a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page.ts +++ b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page.ts @@ -54,6 +54,8 @@ export class HuntPage implements OnDestroy { null, ); readonly latestApproval$ = this.huntApprovalLocalStore.latestApproval$; + readonly requestApprovalStatus$ = + this.huntApprovalLocalStore.requestApprovalStatus$; readonly hideApprovalCardContentByDefault$ = this.latestApproval$.pipe( map((approval) => isNull(approval)), // we are only interested in the first emission, as we don't want diff --git a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page_test.ts b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page_test.ts index b9eaa40a84..331d660386 100644 --- a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_page_test.ts @@ -1,5 +1,11 @@ import {Location} from '@angular/common'; -import {fakeAsync, TestBed, tick, waitForAsync} from '@angular/core/testing'; +import { + TestBed, + discardPeriodicTasks, + fakeAsync, + tick, + waitForAsync, +} from '@angular/core/testing'; import {By} from '@angular/platform-browser'; import {NoopAnimationsModule} from '@angular/platform-browser/animations'; import {ActivatedRoute, Router} from '@angular/router'; @@ -15,7 +21,7 @@ import {HttpApiService} from '../../../lib/api/http_api_service'; import {mockHttpApiService} from '../../../lib/api/http_api_service_test_util'; import {translateHuntApproval} from '../../../lib/api_translation/hunt'; import {getFlowTitleFromFlowName} from '../../../lib/models/flow'; -import {getHuntTitle, HuntState} from '../../../lib/models/hunt'; +import {HuntState, getHuntTitle} from '../../../lib/models/hunt'; import { newFlowDescriptorMap, newHunt, @@ -33,8 +39,8 @@ import { mockHuntResultDetailsGlobalStore, } from '../../../store/hunt_result_details_global_store_test_util'; import { - injectMockStore, STORE_PROVIDERS, + injectMockStore, } from '../../../store/store_test_providers'; import {UserGlobalStore} from '../../../store/user_global_store'; import {mockUserGlobalStore} from '../../../store/user_global_store_test_util'; @@ -50,6 +56,9 @@ const TEST_HUNT = newHunt({ huntId: '1984', description: 'Ghost', creator: 'buster', + created: new Date('1970-01-12 13:46:39 UTC'), + initStartTime: new Date('1980-01-12 13:46:39 UTC'), + lastStartTime: undefined, flowName: 'MadeUpFlow', resourceUsage: { totalCPUTime: 0.7999999821186066, @@ -138,6 +147,9 @@ describe('hunt page test', () => { const text = overviewSection.nativeElement.textContent; expect(text).toContain('1984'); expect(text).toContain('buster'); + expect(text).toContain('1970-01-12 13:46:39 UTC'); + expect(text).toContain('1980-01-12 13:46:39 UTC'); + expect(text).toContain('never started'); expect(text).toContain('MadeUpFlow'); expect(text).toContain('View flow arguments'); expect(text).toContain('1 s'); @@ -441,6 +453,7 @@ describe('hunt 
page test', () => { tick(); // after tick(), URL changes will have taken into effect. expect(location.path()).toBe('/hunts/1984(drawer:modify-hunt)'); + discardPeriodicTasks(); })); it('Copy button navigates to new hunt page with correct param', fakeAsync(async () => { @@ -458,6 +471,8 @@ describe('hunt page test', () => { tick(); // after tick(), URL changes will have taken into effect. expect(location.path()).toBe('/new-hunt?huntId=1984'); + + discardPeriodicTasks(); })); it('does not display approval component if disabled', async () => { diff --git a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_results/hunt_results_test.ts b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_results/hunt_results_test.ts index 1104fee267..213096bdf1 100644 --- a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_results/hunt_results_test.ts +++ b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_results/hunt_results_test.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {TestbedHarnessEnvironment} from '@angular/cdk/testing/testbed'; import {Component, ViewChild} from '@angular/core'; import {TestBed, fakeAsync, tick, waitForAsync} from '@angular/core/testing'; @@ -377,9 +376,8 @@ describe('HuntResults', () => { fixture.detectChanges(); const harnessLoader = TestbedHarnessEnvironment.loader(fixture); - const tabGroupHarness = await harnessLoader.getHarness( - MatTabGroupHarness, - ); + const tabGroupHarness = + await harnessLoader.getHarness(MatTabGroupHarness); expect((await tabGroupHarness.getTabs()).length).toEqual(2); const fileFinderTab = ( diff --git a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_results/module.ts b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_results/module.ts index d90e622de8..d5ec020ad4 100644 --- a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_results/module.ts +++ b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/hunt_results/module.ts @@ -18,7 +18,6 @@ import {FileModeModule} from '../../../data_renderers/file_mode/file_mode_module import {ExpandableHashModule} from '../../../expandable_hash/module'; import {HelpersModule} from '../../../flow_details/helpers/module'; import {CopyButtonModule} from '../../../helpers/copy_button/copy_button_module'; -import {DrawerLinkModule} from '../../../helpers/drawer_link/drawer_link_module'; import {FilterPaginate} from '../../../helpers/filter_paginate/filter_paginate'; import {HumanReadableSizeModule} from '../../../human_readable_size/module'; import {TimestampModule} from '../../../timestamp/module'; @@ -34,7 +33,6 @@ import {HuntResults} from './hunt_results'; // keep-sorted start block=yes CommonModule, CopyButtonModule, - DrawerLinkModule, ExpandableHashModule, FileModeModule, FilterPaginate, diff --git a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/module.ts b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/module.ts index 832652ae66..38fbb2337b 100644 --- a/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/module.ts +++ b/grr/server/grr_response_server/gui/ui/components/hunt/hunt_page/module.ts @@ -11,6 +11,7 @@ import {TitleEditorModule} from '../../form/title_editor/module'; import {CopyButtonModule} from '../../helpers/copy_button/copy_button_module'; import {DrawerLinkModule} from '../../helpers/drawer_link/drawer_link_module'; import {HumanReadableSizeModule} from 
'../../human_readable_size/module'; +import {TimestampModule} from '../../timestamp/module'; import {UserImageModule} from '../../user_image/module'; import {HuntArguments} from '../hunt_arguments/hunt_arguments'; import {HuntFlowArguments} from '../hunt_flow_arguments/hunt_flow_arguments'; @@ -47,6 +48,7 @@ import {HuntResultsModule} from './hunt_results/module'; MatIconModule, MatTooltipModule, ModifyHuntModule, + TimestampModule, TitleEditorModule, UserImageModule, // keep-sorted end diff --git a/grr/server/grr_response_server/gui/ui/components/hunt/new_hunt/new_hunt.ng.html b/grr/server/grr_response_server/gui/ui/components/hunt/new_hunt/new_hunt.ng.html index f98736daea..34b88f4182 100644 --- a/grr/server/grr_response_server/gui/ui/components/hunt/new_hunt/new_hunt.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/hunt/new_hunt/new_hunt.ng.html @@ -60,6 +60,7 @@ *ngIf="huntApprovalRequired$ | async" [urlTree]="(huntApprovalRoute$ | async) ?? []" [latestApproval]="latestApproval$ | async" + [requestApprovalStatus]="requestApprovalStatus$ | async" [showSubmitButton]="false" [validateOnStart]="true" (approvalParams)="requestHuntApproval($event)"> diff --git a/grr/server/grr_response_server/gui/ui/components/hunt/new_hunt/new_hunt.ts b/grr/server/grr_response_server/gui/ui/components/hunt/new_hunt/new_hunt.ts index 1a14bf8d61..7345bc043e 100644 --- a/grr/server/grr_response_server/gui/ui/components/hunt/new_hunt/new_hunt.ts +++ b/grr/server/grr_response_server/gui/ui/components/hunt/new_hunt/new_hunt.ts @@ -62,7 +62,7 @@ export class NewHunt { private readonly approvalParams$ = new BehaviorSubject( null, ); - private readonly requestApprovalStatus$ = + protected readonly requestApprovalStatus$ = this.huntApprovalLocalStore.requestApprovalStatus$; protected hasOriginInput: boolean | undefined = undefined; diff --git a/grr/server/grr_response_server/gui/ui/components/timestamp/timestamp.ng.html b/grr/server/grr_response_server/gui/ui/components/timestamp/timestamp.ng.html index 8b6cd2501c..ea9513e73d 100644 --- a/grr/server/grr_response_server/gui/ui/components/timestamp/timestamp.ng.html +++ b/grr/server/grr_response_server/gui/ui/components/timestamp/timestamp.ng.html @@ -4,7 +4,7 @@ [matTooltip]="relativeTimestamp === 'tooltip' ? ((relativeTimestampString$ | async) ?? 
'') : ''" [matTooltipDisabled]="relativeTimestamp !== 'tooltip'"> - {{date | date:'yyyy-MM-dd HH:mm:ss':timezone}} {{timezone}} + {{date | date:'yyyy-MM-dd HH:mm:ss':timezone}} {{timezone}} {{ relativeTimestampString$ | async }} diff --git a/grr/server/grr_response_server/gui/ui/components/timestamp/timestamp.ts b/grr/server/grr_response_server/gui/ui/components/timestamp/timestamp.ts index cff7e6479e..791117ffe3 100644 --- a/grr/server/grr_response_server/gui/ui/components/timestamp/timestamp.ts +++ b/grr/server/grr_response_server/gui/ui/components/timestamp/timestamp.ts @@ -60,6 +60,9 @@ export class Timestamp implements OnDestroy { get date() { return this.date$.value; } + + @Input() multiline: boolean | null | undefined = undefined; + @Input() relativeTimestamp: RelativeTimestampVisibility = 'tooltip'; readonly timezone: string = 'UTC'; diff --git a/grr/server/grr_response_server/gui/ui/lib/api/api_interfaces.ts b/grr/server/grr_response_server/gui/ui/lib/api/api_interfaces.ts index 9f5e7f957c..6ab7cda2d1 100644 --- a/grr/server/grr_response_server/gui/ui/lib/api/api_interfaces.ts +++ b/grr/server/grr_response_server/gui/ui/lib/api/api_interfaces.ts @@ -2087,7 +2087,6 @@ export declare interface Artifact { readonly doc?: string; readonly supportedOs?: readonly string[]; readonly urls?: readonly string[]; - readonly provides?: readonly string[]; readonly sources?: readonly ArtifactSource[]; readonly errorMessage?: string; readonly aliases?: readonly string[]; @@ -2295,14 +2294,6 @@ export declare interface BufferReference { readonly pathspec?: PathSpec; } -/** ChromeHistoryArgs proto mapping. */ -export declare interface ChromeHistoryArgs { - readonly pathtype?: PathSpecPathType; - readonly getArchive?: boolean; - readonly username?: string; - readonly historyPath?: string; -} - /** ClientCrash proto mapping. */ export declare interface ClientCrash { readonly clientId?: string; @@ -2505,8 +2496,10 @@ export enum ConditionExpressionConditionType { /** CpuSeconds proto mapping. */ export declare interface CpuSeconds { - readonly userCpuTime?: ProtoFloat; - readonly systemCpuTime?: ProtoFloat; + readonly deprecatedUserCpuTime?: ProtoFloat; + readonly deprecatedSystemCpuTime?: ProtoFloat; + readonly userCpuTime?: ProtoDouble; + readonly systemCpuTime?: ProtoDouble; } /** CronJobAction proto mapping. */ @@ -2824,14 +2817,6 @@ export declare interface FileFinderStatActionOptions { readonly collectExtAttrs?: boolean; } -/** FirefoxHistoryArgs proto mapping. */ -export declare interface FirefoxHistoryArgs { - readonly pathtype?: PathSpecPathType; - readonly getArchive?: boolean; - readonly username?: string; - readonly historyPath?: string; -} - /** FleetspeakValidationInfo proto mapping. */ export declare interface FleetspeakValidationInfo { readonly tags?: readonly FleetspeakValidationInfoTag[]; @@ -3929,6 +3914,11 @@ export declare interface UnixVolume { readonly options?: string; } +/** UpdateClientArgs proto mapping. */ +export declare interface UpdateClientArgs { + readonly binaryPath?: string; +} + /** UpdateConfigurationArgs proto mapping. */ export declare interface UpdateConfigurationArgs { readonly config?: Dict; @@ -4070,6 +4060,7 @@ export declare interface YaraProcessDumpArgs { readonly skipExecutableRegions?: boolean; readonly skipReadonlyRegions?: boolean; readonly prioritizeOffsets?: readonly ProtoUint64[]; + readonly ignoreParentProcesses?: boolean; } /** YaraProcessDumpInformation proto mapping. 
*/ @@ -4113,6 +4104,7 @@ export declare interface YaraProcessScanRequest { readonly includeErrorsInResults?: YaraProcessScanRequestErrorPolicy; readonly includeMissesInResults?: boolean; readonly ignoreGrrProcess?: boolean; + readonly ignoreParentProcesses?: boolean; readonly perProcessTimeout?: ProtoUint32; readonly chunkSize?: ProtoUint64; readonly overlapSize?: ProtoUint64; @@ -4125,6 +4117,7 @@ export declare interface YaraProcessScanRequest { readonly maxResultsPerProcess?: ProtoUint32; readonly processDumpSizeLimit?: ByteSize; readonly scanRuntimeLimitUs?: Duration; + readonly contextWindow?: ProtoUint32; readonly implementationType?: YaraProcessScanRequestImplementationType; } @@ -4153,6 +4146,7 @@ export declare interface YaraStringMatch { readonly stringId?: string; readonly offset?: ProtoUint64; readonly data?: ProtoBytes; + readonly context?: ProtoBytes; } /** protobuf2.TYPE_BOOL proto mapping. */ diff --git a/grr/server/grr_response_server/gui/ui/lib/api_translation/artifact.ts b/grr/server/grr_response_server/gui/ui/lib/api_translation/artifact.ts index a6c9819126..cb94e2693b 100644 --- a/grr/server/grr_response_server/gui/ui/lib/api_translation/artifact.ts +++ b/grr/server/grr_response_server/gui/ui/lib/api_translation/artifact.ts @@ -30,7 +30,6 @@ export function translateArtifactDescriptor( .filter(isNonNull), ), urls: [...(artifact.urls ?? [])], - provides: [...(artifact.provides ?? [])], dependencies: [...(ad.dependencies ?? [])], pathDependencies: [...(ad.pathDependencies ?? [])], isCustom: ad.isCustom ?? false, diff --git a/grr/server/grr_response_server/gui/ui/lib/api_translation/artifact_test.ts b/grr/server/grr_response_server/gui/ui/lib/api_translation/artifact_test.ts index 9b03f7970d..68fe9edf2c 100644 --- a/grr/server/grr_response_server/gui/ui/lib/api_translation/artifact_test.ts +++ b/grr/server/grr_response_server/gui/ui/lib/api_translation/artifact_test.ts @@ -69,7 +69,6 @@ describe('translateArtifactDescriptor', () => { doc: 'Chrome browser history.', supportedOs: new Set([OperatingSystem.WINDOWS, OperatingSystem.LINUX]), urls: ['artifactUrl'], - provides: [], dependencies: [], sources: [ { diff --git a/grr/server/grr_response_server/gui/ui/lib/dataviz/line_chart.ts b/grr/server/grr_response_server/gui/ui/lib/dataviz/line_chart.ts index 458267d200..ff097325a9 100644 --- a/grr/server/grr_response_server/gui/ui/lib/dataviz/line_chart.ts +++ b/grr/server/grr_response_server/gui/ui/lib/dataviz/line_chart.ts @@ -524,7 +524,7 @@ export class LineChart { private setupEventListeners(): void { if (this.configuration?.sizing?.rerenderOnResize) { - this.resizeObserver = new ResizeObserver((e) => { + this.resizeObserver = new ResizeObserver(() => { const currentSizeConfig = this.configuration?.sizing || {}; const containerWidth = this.getElementWidthPx(this.parentNode); @@ -538,11 +538,13 @@ export class LineChart { widthPx: this.containerWidthPx, }; - this.setChartSize(this.parentNode, newChartSizeConfiguration); - this.resetScalesRange(); - this.setAxisScales(); - this.setAxisTicks(); - this.redrawChart(); + requestAnimationFrame(() => { + this.setChartSize(this.parentNode, newChartSizeConfiguration); + this.resetScalesRange(); + this.setAxisScales(); + this.setAxisTicks(); + this.redrawChart(); + }); }); // We listen to size changes of the chart's container element: diff --git a/grr/server/grr_response_server/gui/ui/lib/models/flow.ts b/grr/server/grr_response_server/gui/ui/lib/models/flow.ts index 0967bde842..78134c950c 100644 --- 
a/grr/server/grr_response_server/gui/ui/lib/models/flow.ts +++ b/grr/server/grr_response_server/gui/ui/lib/models/flow.ts @@ -316,7 +316,6 @@ export interface ArtifactDescriptor { readonly doc?: string; readonly supportedOs: ReadonlySet; readonly urls: readonly string[]; - readonly provides: readonly string[]; readonly sources: readonly ArtifactSource[]; readonly dependencies: readonly string[]; readonly pathDependencies: readonly string[]; diff --git a/grr/server/grr_response_server/gui/ui/lib/models/model_test_util.ts b/grr/server/grr_response_server/gui/ui/lib/models/model_test_util.ts index c8e141d9eb..15068c4678 100644 --- a/grr/server/grr_response_server/gui/ui/lib/models/model_test_util.ts +++ b/grr/server/grr_response_server/gui/ui/lib/models/model_test_util.ts @@ -171,7 +171,6 @@ export function newArtifactDescriptor( isCustom: false, name: 'TestAritfact', pathDependencies: [], - provides: [], sources: [], supportedOs: new Set([OperatingSystem.LINUX]), urls: [], diff --git a/grr/server/grr_response_server/gui/ui/lib/queued_exhaust_map.ts b/grr/server/grr_response_server/gui/ui/lib/queued_exhaust_map.ts index 72ef941cd0..67ea343d86 100644 --- a/grr/server/grr_response_server/gui/ui/lib/queued_exhaust_map.ts +++ b/grr/server/grr_response_server/gui/ui/lib/queued_exhaust_map.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import { EMPTY, ObservableInput, @@ -28,7 +27,10 @@ class SingleIndexTracker implements RequestTracker { } class ValueWithTag { - constructor(readonly value: V, readonly tag: T) {} + constructor( + readonly value: V, + readonly tag: T, + ) {} } function tagByIndex(indexTracker: SingleIndexTracker) { diff --git a/grr/server/grr_response_server/gui/ui/store/client_page_global_store.ts b/grr/server/grr_response_server/gui/ui/store/client_page_global_store.ts index 99006a95c9..c0cfe2abe8 100644 --- a/grr/server/grr_response_server/gui/ui/store/client_page_global_store.ts +++ b/grr/server/grr_response_server/gui/ui/store/client_page_global_store.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {Injectable} from '@angular/core'; import {ComponentStore} from '@ngrx/component-store'; import {combineLatest, merge, Observable, of, throwError} from 'rxjs'; @@ -281,7 +280,7 @@ class ClientPageComponentStore extends ComponentStore { flows, isLoading: false, hasMore: flows.length >= count, - } as FlowListState), + }) as FlowListState, ), catchError>((err) => err instanceof MissingApprovalError diff --git a/grr/server/grr_response_server/gui/ui/store/hunt_approval_local_store.ts b/grr/server/grr_response_server/gui/ui/store/hunt_approval_local_store.ts index 54338e5687..e954ecb40b 100644 --- a/grr/server/grr_response_server/gui/ui/store/hunt_approval_local_store.ts +++ b/grr/server/grr_response_server/gui/ui/store/hunt_approval_local_store.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {Injectable} from '@angular/core'; import {ComponentStore} from '@ngrx/component-store'; import {Observable, of} from 'rxjs'; @@ -62,8 +61,9 @@ class HuntApprovalComponentStore extends ComponentStore { } return this.httpApiService.subscribeToListHuntApprovals(huntId).pipe( - map((approvals: readonly ApiHuntApproval[]): HuntApproval[] => - approvals?.map(translateHuntApproval), + map( + (approvals: readonly ApiHuntApproval[]): HuntApproval[] => + approvals?.map(translateHuntApproval), ), takeWhile( (approvals: HuntApproval[]) => diff --git 
a/grr/server/grr_response_server/gui/ui/store/hunt_overview_page_local_store.ts b/grr/server/grr_response_server/gui/ui/store/hunt_overview_page_local_store.ts index d35a027a6e..06fa1b3c77 100644 --- a/grr/server/grr_response_server/gui/ui/store/hunt_overview_page_local_store.ts +++ b/grr/server/grr_response_server/gui/ui/store/hunt_overview_page_local_store.ts @@ -15,7 +15,7 @@ export class HuntOverviewPageLocalStore extends ApiCollectionStore< Hunt, ApiListHuntsArgs > { - override readonly INITIAL_LOAD_COUNT = 10; + override readonly INITIAL_LOAD_COUNT = 100; protected loadResults( args: ApiListHuntsArgs, diff --git a/grr/server/grr_response_server/gui/ui/store/recent_client_flows_local_store.ts b/grr/server/grr_response_server/gui/ui/store/recent_client_flows_local_store.ts index f8c475a865..ea5a9909f0 100644 --- a/grr/server/grr_response_server/gui/ui/store/recent_client_flows_local_store.ts +++ b/grr/server/grr_response_server/gui/ui/store/recent_client_flows_local_store.ts @@ -1,4 +1,3 @@ -// g3-format-changed-lines-during-prettier-version-upgrade import {Injectable} from '@angular/core'; import {ComponentStore} from '@ngrx/component-store'; import {combineLatest, Observable, of, throwError} from 'rxjs'; @@ -80,7 +79,7 @@ class RecentClientFlowsComponentStore extends ComponentStore f.startedAt)), ), - map((flows) => ({flows} as FlowListState)), + map((flows) => ({flows}) as FlowListState), catchError>((err) => err instanceof MissingApprovalError ? of({flows: []} as FlowListState) diff --git a/grr/server/grr_response_server/gui/webauth_test.py b/grr/server/grr_response_server/gui/webauth_test.py index 8dfe212239..c7ac40eed1 100644 --- a/grr/server/grr_response_server/gui/webauth_test.py +++ b/grr/server/grr_response_server/gui/webauth_test.py @@ -12,10 +12,10 @@ from grr_response_core.lib.rdfvalues import crypto as rdf_crypto from grr_response_core.lib.rdfvalues import mig_crypto from grr_response_server import data_store +from grr_response_server.gui import http_request from grr_response_server.gui import http_response from grr_response_server.gui import validate_iap from grr_response_server.gui import webauth -from grr_response_server.gui import wsgiapp from grr.test_lib import test_lib @@ -36,7 +36,7 @@ def testRejectsRequestWithoutRemoteUserHeader(self): environ = werkzeug_test.EnvironBuilder(environ_base={ "REMOTE_ADDR": "127.0.0.1" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual( @@ -46,7 +46,7 @@ def testRejectsRequestFromUntrustedIp(self): environ = werkzeug_test.EnvironBuilder(environ_base={ "REMOTE_ADDR": "127.0.0.2" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertRegex( @@ -59,7 +59,7 @@ def testRejectsRequestWithEmptyUsername(self): "REMOTE_ADDR": "127.0.0.1", "HTTP_X_REMOTE_USER": "" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual( @@ -70,7 +70,7 @@ def testProcessesRequestWithUsernameFromTrustedIp(self): "REMOTE_ADDR": "127.0.0.1", "HTTP_X_REMOTE_USER": "foo" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual(response, 
self.success_response) @@ -82,7 +82,7 @@ def testProcessesRequestWithEmail_configDisabled(self): "HTTP_X_REMOTE_USER": "foo", "HTTP_X_REMOTE_EXTRA_EMAIL": "foo@bar.org", }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertIsNone(request.email) @@ -95,7 +95,7 @@ def testProcessesRequestWithEmail_configEnabled(self): "HTTP_X_REMOTE_USER": "foo", "HTTP_X_REMOTE_EXTRA_EMAIL": "foo@bar.org", }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) with test_lib.ConfigOverrider({"Email.enable_custom_email_address": True}): response = self.manager.SecurityCheck(self.HandlerStub, request) @@ -131,14 +131,14 @@ def HandlerStub(self, request, *args, **kwargs): def testPassesThroughHomepageWhenAuthorizationHeaderIsMissing(self): environ = werkzeug_test.EnvironBuilder().get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual(response, self.success_response) def testReportsErrorOnNonHomepagesWhenAuthorizationHeaderIsMissing(self): environ = werkzeug_test.EnvironBuilder(path="/foo").get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual( @@ -150,7 +150,7 @@ def testReportsErrorWhenBearerPrefixIsMissing(self): path="/foo", headers={ "Authorization": "blah" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual( @@ -165,7 +165,7 @@ def testPassesThroughHomepageOnVerificationFailure(self, mock_method): environ = werkzeug_test.EnvironBuilder(headers={ "Authorization": "Bearer blah" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual(response, self.success_response) @@ -179,7 +179,7 @@ def testReportsErrorOnVerificationFailureOnNonHomepage(self, mock_method): path="/foo", headers={ "Authorization": "Bearer blah" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual( @@ -191,7 +191,7 @@ def testVerifiesTokenWithProjectIdFromDomain(self, mock_method): environ = werkzeug_test.EnvironBuilder(headers={ "Authorization": "Bearer blah" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual(mock_method.call_count, 1) @@ -206,7 +206,7 @@ def testReportsErrorIfIssuerIsWrong(self, mock_method): path="/foo", headers={ "Authorization": "Bearer blah" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) response = self.manager.SecurityCheck(self.HandlerStub, request) self.assertEqual( @@ -225,7 +225,7 @@ def testFillsRequestUserFromTokenEmailOnSuccess(self, mock_method): environ = werkzeug_test.EnvironBuilder(headers={ "Authorization": "Bearer blah" }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) self.manager.SecurityCheck(self.HandlerStub, request) @@ -239,7 +239,7 @@ def 
testNoHeader(self): """Test requests sent to the Admin UI without an IAP Header.""" environ = werkzeug_test.EnvironBuilder(path="/").get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) def Handler(request, *args, **kwargs): del request, args, kwargs # Unused. @@ -281,7 +281,7 @@ def testFailedSignatureKey(self, mock_get): "X-Goog-IAP-JWT-Assertion": assertion_header }, ).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) def Handler(request, *args, **kwargs): del request, args, kwargs # Unused. @@ -306,7 +306,7 @@ def testSuccessfulKey(self, mock_method): path="/", headers={ "X-Goog-IAP-JWT-Assertion": ("valid_key") }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) def Handler(request, *args, **kwargs): del args, kwargs # Unused. @@ -344,7 +344,7 @@ def testSecurityCheckUnicode(self): path="/foo", headers={ "Authorization": "Basic %s" % token, }).get_environ() - request = wsgiapp.HttpRequest(environ) + request = http_request.HttpRequest(environ) def Handler(request, *args, **kwargs): del args, kwargs # Unused. diff --git a/grr/server/grr_response_server/gui/wsgiapp.py b/grr/server/grr_response_server/gui/wsgiapp.py index 15dc907ba1..2c0f6d00a4 100644 --- a/grr/server/grr_response_server/gui/wsgiapp.py +++ b/grr/server/grr_response_server/gui/wsgiapp.py @@ -19,7 +19,6 @@ import psutil from werkzeug import exceptions as werkzeug_exceptions from werkzeug import routing as werkzeug_routing -from werkzeug import wrappers as werkzeug_wrappers from werkzeug import wsgi as werkzeug_wsgi from grr_response_core import config @@ -30,6 +29,7 @@ from grr_response_server.gui import admin_ui_metrics from grr_response_server.gui import csp from grr_response_server.gui import http_api +from grr_response_server.gui import http_request from grr_response_server.gui import http_response from grr_response_server.gui import webauth @@ -127,49 +127,6 @@ def ValidateCSRFTokenOrRaise(request): raise werkzeug_exceptions.Forbidden("Expired CSRF token") -class RequestHasNoUser(AttributeError): - """Error raised when accessing a user of an unautenticated request.""" - - -class HttpRequest(werkzeug_wrappers.Request): - """HTTP request object to be used in GRR.""" - - charset = "utf-8" - encoding_errors = "strict" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._user = None - self.email = None - - self.timestamp = rdfvalue.RDFDatetime.Now().AsMicrosecondsSinceEpoch() - - self.method_metadata = None - self.parsed_args = None - - @property - def user(self): - if self._user is None: - raise RequestHasNoUser( - "Trying to access Request.user while user is unset.") - - if not self._user: - raise RequestHasNoUser( - "Trying to access Request.user while user is empty.") - - return self._user - - @user.setter - def user(self, value): - if not isinstance(value, Text): - message = "Expected instance of '%s' but got value '%s' of type '%s'" - message %= (Text, value, type(value)) - raise TypeError(message) - - self._user = value - - def LogAccessWrapper(func): """Decorator that ensures that HTTP access is logged.""" @@ -238,7 +195,7 @@ def __init__(self): ) def _BuildRequest(self, environ): - return HttpRequest(environ) + return http_request.HttpRequest(environ) def _HandleLegacyHomepage(self, request): admin_ui_metrics.WSGI_ROUTE.Increment(fields=["legacy"]) @@ -292,7 +249,7 @@ def _HandleHomepageV1(self, request): # present. 
try: StoreCSRFCookie(request.user, response) - except RequestHasNoUser: + except http_request.RequestHasNoUserError: pass return response @@ -319,7 +276,7 @@ def _HandleHomepageV2(self, request): try: StoreCSRFCookie(request.user, response) - except RequestHasNoUser: + except http_request.RequestHasNoUserError: pass return response diff --git a/grr/server/grr_response_server/handler_registry.py b/grr/server/grr_response_server/handler_registry.py index b2afac177c..2f4d5c956c 100644 --- a/grr/server/grr_response_server/handler_registry.py +++ b/grr/server/grr_response_server/handler_registry.py @@ -9,7 +9,6 @@ administrative.ClientAlertHandler, administrative.ClientStartupHandler, administrative.ClientStatsHandler, - administrative.NannyMessageHandler, foreman.ForemanMessageHandler, transfer.BlobHandler, ] diff --git a/grr/server/grr_response_server/hunt.py b/grr/server/grr_response_server/hunt.py index cc81856773..94e6f7fd67 100644 --- a/grr/server/grr_response_server/hunt.py +++ b/grr/server/grr_response_server/hunt.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """REL_DB implementation of hunts.""" +from typing import Optional + from grr_response_core.lib import rdfvalue from grr_response_core.lib import registry from grr_response_core.lib.rdfvalues import structs as rdf_structs @@ -14,6 +16,7 @@ from grr_response_server import mig_foreman_rules from grr_response_server import notification from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects +from grr_response_server.rdfvalues import mig_flow_runner from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import objects as rdf_objects @@ -101,6 +104,7 @@ def HuntURNFromID(hunt_id): def StopHuntIfCrashLimitExceeded(hunt_id): """Stops the hunt if number of crashes exceeds the limit.""" hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) # Do nothing if the hunt is already stopped. if hunt_obj.hunt_state == rdf_hunt_objects.Hunt.HuntState.STOPPED: @@ -115,9 +119,7 @@ def StopHuntIfCrashLimitExceeded(hunt_id): f"Hunt {hunt_obj.hunt_id} reached the crashes limit of" f" {hunt_obj.crash_limit} and was stopped." ) - hunt_state_reason = ( - rdf_hunt_objects.Hunt.HuntStateReason.TOTAL_CRASHES_EXCEEDED - ) + hunt_state_reason = hunts_pb2.Hunt.HuntStateReason.TOTAL_CRASHES_EXCEEDED StopHunt( hunt_obj.hunt_id, hunt_state_reason=hunt_state_reason, @@ -134,6 +136,7 @@ def StopHuntIfCrashLimitExceeded(hunt_id): def StopHuntIfCPUOrNetworkLimitsExceeded(hunt_id): """Stops the hunt if average limites are exceeded.""" hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) # Do nothing if the hunt is already stopped. if hunt_obj.hunt_state == rdf_hunt_objects.Hunt.HuntState.STOPPED: @@ -151,9 +154,7 @@ def StopHuntIfCPUOrNetworkLimitsExceeded(hunt_id): f"Hunt {hunt_obj.hunt_id} reached the total network bytes sent limit of" f" {hunt_obj.total_network_bytes_limit} and was stopped." ) - hunt_state_reason = ( - rdf_hunt_objects.Hunt.HuntStateReason.TOTAL_NETWORK_EXCEEDED - ) + hunt_state_reason = hunts_pb2.Hunt.HuntStateReason.TOTAL_NETWORK_EXCEEDED StopHunt( hunt_obj.hunt_id, hunt_state_reason=hunt_state_reason, @@ -175,9 +176,7 @@ def StopHuntIfCPUOrNetworkLimitsExceeded(hunt_id): f"Hunt {hunt_obj.hunt_id} reached the average results per client " f"limit of {hunt_obj.avg_results_per_client_limit} and was stopped." 
) - hunt_state_reason = ( - rdf_hunt_objects.Hunt.HuntStateReason.AVG_RESULTS_EXCEEDED - ) + hunt_state_reason = hunts_pb2.Hunt.HuntStateReason.AVG_RESULTS_EXCEEDED StopHunt( hunt_obj.hunt_id, hunt_state_reason=hunt_state_reason, @@ -196,7 +195,7 @@ def StopHuntIfCPUOrNetworkLimitsExceeded(hunt_id): f" limit of {hunt_obj.avg_cpu_seconds_per_client_limit} and was" " stopped." ) - hunt_state_reason = rdf_hunt_objects.Hunt.HuntStateReason.AVG_CPU_EXCEEDED + hunt_state_reason = hunts_pb2.Hunt.HuntStateReason.AVG_CPU_EXCEEDED StopHunt( hunt_obj.hunt_id, hunt_state_reason=hunt_state_reason, @@ -219,9 +218,7 @@ def StopHuntIfCPUOrNetworkLimitsExceeded(hunt_id): f" client limit of {hunt_obj.avg_network_bytes_per_client_limit} and" " was stopped." ) - hunt_state_reason = ( - rdf_hunt_objects.Hunt.HuntStateReason.AVG_NETWORK_EXCEEDED - ) + hunt_state_reason = hunts_pb2.Hunt.HuntStateReason.AVG_NETWORK_EXCEEDED StopHunt( hunt_obj.hunt_id, hunt_state_reason=hunt_state_reason, @@ -231,11 +228,12 @@ def StopHuntIfCPUOrNetworkLimitsExceeded(hunt_id): return hunt_obj -def CompleteHuntIfExpirationTimeReached(hunt_id): +def CompleteHuntIfExpirationTimeReached(hunt_id: str) -> rdf_hunt_objects.Hunt: """Marks the hunt as complete if it's past its expiry time.""" # TODO(hanuszczak): This should not set the hunt state to `COMPLETED` but we # should have a separate `EXPIRED` state instead and set that. hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) if ( hunt_obj.hunt_state not in [ @@ -246,14 +244,15 @@ def CompleteHuntIfExpirationTimeReached(hunt_id): ): StopHunt( hunt_obj.hunt_id, - rdf_hunt_objects.Hunt.HuntStateReason.DEADLINE_REACHED, + hunts_pb2.Hunt.HuntStateReason.DEADLINE_REACHED, reason_comment="Hunt completed.", ) data_store.REL_DB.UpdateHuntObject( - hunt_obj.hunt_id, hunt_state=hunt_obj.HuntState.COMPLETED + hunt_obj.hunt_id, hunt_state=hunts_pb2.Hunt.HuntState.COMPLETED ) - return data_store.REL_DB.ReadHuntObject(hunt_obj.hunt_id) + hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_obj.hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) return hunt_obj @@ -267,6 +266,10 @@ def CreateHunt(hunt_obj: hunts_pb2.Hunt): output_plugins_states = flow.GetOutputPluginStates( hunt_obj.output_plugins, source=f"hunts/{hunt_obj.hunt_id}" ) + output_plugins_states = [ + mig_flow_runner.ToProtoOutputPluginState(state) + for state in output_plugins_states + ] data_store.REL_DB.WriteHuntOutputPluginsStates( hunt_obj.hunt_id, output_plugins_states ) @@ -299,7 +302,7 @@ def CreateAndStartHunt(flow_name, flow_args, creator, **kwargs): return hunt_obj.hunt_id -def _ScheduleGenericHunt(hunt_obj): +def _ScheduleGenericHunt(hunt_obj: rdf_hunt_objects.Hunt): """Adds foreman rules for a generic hunt.""" # TODO: Migrate foreman conditions to use relation expiration # durations instead of absolute timestamps. 
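
For context only (not part of the patch), a minimal sketch of the read/convert/update pattern the hunt.py hunks above follow: REL_DB now returns hunts_pb2.Hunt protos, mig_hunt_objects.ToRDFHunt converts them for the remaining RDF-based logic, and UpdateHuntObject takes hunts_pb2 enums directly. ReadHuntAsRDF and MarkHuntStopped are hypothetical helper names used purely for illustration; every call they make appears verbatim in the hunks above.

from grr_response_proto import hunts_pb2
from grr_response_server import data_store
from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects
from grr_response_server.rdfvalues import mig_hunt_objects


def ReadHuntAsRDF(hunt_id: str) -> rdf_hunt_objects.Hunt:
  """Reads a hunt as a proto and converts it to its RDF wrapper."""
  proto_hunt = data_store.REL_DB.ReadHuntObject(hunt_id)  # hunts_pb2.Hunt
  return mig_hunt_objects.ToRDFHunt(proto_hunt)


def MarkHuntStopped(hunt_id: str) -> rdf_hunt_objects.Hunt:
  """Updates the hunt state using proto enums, then re-reads it as RDF."""
  data_store.REL_DB.UpdateHuntObject(
      hunt_id,
      hunt_state=hunts_pb2.Hunt.HuntState.STOPPED,
      hunt_state_reason=hunts_pb2.Hunt.HuntStateReason.TRIGGERED_BY_USER,
  )
  return ReadHuntAsRDF(hunt_id)
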
@@ -320,7 +323,7 @@ def _ScheduleGenericHunt(hunt_obj): data_store.REL_DB.WriteForemanRule(proto_foreman_condition) -def _ScheduleVariableHunt(hunt_obj): +def _ScheduleVariableHunt(hunt_obj: rdf_hunt_objects.Hunt): """Schedules flows for a variable hunt.""" if hunt_obj.client_rate != 0: raise VariableHuntCanNotHaveClientRateError( @@ -359,10 +362,11 @@ def _ScheduleVariableHunt(hunt_obj): ) -def StartHunt(hunt_id): +def StartHunt(hunt_id) -> rdf_hunt_objects.Hunt: """Starts a hunt with a given id.""" hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) num_hunt_clients = data_store.REL_DB.CountHuntFlows(hunt_id) if hunt_obj.hunt_state != hunt_obj.HuntState.PAUSED: @@ -370,11 +374,12 @@ def StartHunt(hunt_id): data_store.REL_DB.UpdateHuntObject( hunt_id, - hunt_state=hunt_obj.HuntState.STARTED, + hunt_state=hunts_pb2.Hunt.HuntState.STARTED, start_time=rdfvalue.RDFDatetime.Now(), num_clients_at_start_time=num_hunt_clients, ) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) if hunt_obj.args.hunt_type == hunt_obj.args.HuntType.STANDARD: _ScheduleGenericHunt(hunt_obj) @@ -388,28 +393,42 @@ def StartHunt(hunt_id): return hunt_obj -def PauseHunt(hunt_id, hunt_state_reason=None, reason=None): +def PauseHunt( + hunt_id, + hunt_state_reason=None, + reason=None, +) -> rdf_hunt_objects.Hunt: """Pauses a hunt with a given id.""" hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) if hunt_obj.hunt_state != hunt_obj.HuntState.STARTED: raise OnlyStartedHuntCanBePausedError(hunt_obj) data_store.REL_DB.UpdateHuntObject( hunt_id, - hunt_state=hunt_obj.HuntState.PAUSED, + hunt_state=hunts_pb2.Hunt.HuntState.PAUSED, hunt_state_reason=hunt_state_reason, hunt_state_comment=reason, ) data_store.REL_DB.RemoveForemanRule(hunt_id=hunt_obj.hunt_id) - return data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) + return hunt_obj -def StopHunt(hunt_id, hunt_state_reason=None, reason_comment=None): +def StopHunt( + hunt_id: str, + hunt_state_reason: Optional[ + hunts_pb2.Hunt.HuntStateReason.ValueType + ] = None, + reason_comment: Optional[str] = None, +) -> rdf_hunt_objects.Hunt: """Stops a hunt with a given id.""" hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) if hunt_obj.hunt_state not in [ hunt_obj.HuntState.STARTED, hunt_obj.HuntState.PAUSED, @@ -418,7 +437,7 @@ def StopHunt(hunt_id, hunt_state_reason=None, reason_comment=None): data_store.REL_DB.UpdateHuntObject( hunt_id, - hunt_state=hunt_obj.HuntState.STOPPED, + hunt_state=hunts_pb2.Hunt.HuntState.STOPPED, hunt_state_reason=hunt_state_reason, hunt_state_comment=reason_comment, ) @@ -426,8 +445,7 @@ def StopHunt(hunt_id, hunt_state_reason=None, reason_comment=None): # TODO: Stop matching on string (comment). 
if ( - hunt_state_reason - != rdf_hunt_objects.Hunt.HuntStateReason.TRIGGERED_BY_USER + hunt_state_reason != hunts_pb2.Hunt.HuntStateReason.TRIGGERED_BY_USER and reason_comment is not None and reason_comment != CANCELLED_BY_USER and hunt_obj.creator not in access_control.SYSTEM_USERS @@ -442,13 +460,21 @@ def StopHunt(hunt_id, hunt_state_reason=None, reason_comment=None): ), ) - return data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) + return hunt_obj -def UpdateHunt(hunt_id, client_limit=None, client_rate=None, duration=None): +def UpdateHunt( + hunt_id, + client_limit=None, + client_rate=None, + duration=None, +) -> rdf_hunt_objects.Hunt: """Updates a hunt (it must be paused to be updated).""" hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) if hunt_obj.hunt_state != hunt_obj.HuntState.PAUSED: raise OnlyPausedHuntCanBeModifiedError(hunt_obj) @@ -458,7 +484,9 @@ def UpdateHunt(hunt_id, client_limit=None, client_rate=None, duration=None): client_rate=client_rate, duration=duration, ) - return data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) + return hunt_obj _TIME_BETWEEN_PAUSE_CHECKS = rdfvalue.Duration.From(5, rdfvalue.SECONDS) @@ -473,6 +501,7 @@ def StartHuntFlowOnClient(client_id, hunt_id): """Starts a flow corresponding to a given hunt on a given client.""" hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) + # There may be a little race between foreman rules being removed and # foreman scheduling a client on an (already) paused hunt. Making sure # we don't lose clients in such a race by accepting clients for paused @@ -480,6 +509,7 @@ def StartHuntFlowOnClient(client_id, hunt_id): if not rdf_hunt_objects.IsHuntSuitableForFlowProcessing(hunt_obj.hunt_state): return + hunt_obj = mig_hunt_objects.ToRDFHunt(hunt_obj) if hunt_obj.args.hunt_type == hunt_obj.args.HuntType.STANDARD: hunt_args = hunt_obj.args.standard diff --git a/grr/server/grr_response_server/hunt_test.py b/grr/server/grr_response_server/hunt_test.py index c573e8533e..e0f8bb7d90 100644 --- a/grr/server/grr_response_server/hunt_test.py +++ b/grr/server/grr_response_server/hunt_test.py @@ -15,6 +15,7 @@ from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_core.lib.util import collection from grr_response_proto import flows_pb2 +from grr_response_proto import hunts_pb2 from grr_response_server import data_store from grr_response_server import flow_base from grr_response_server import foreman @@ -25,6 +26,7 @@ from grr_response_server.flows.general import processes from grr_response_server.flows.general import transfer from grr_response_server.rdfvalues import hunt_objects as rdf_hunt_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import mig_hunt_objects from grr_response_server.rdfvalues import output_plugin as rdf_output_plugin from grr.test_lib import acl_test_lib @@ -189,12 +191,10 @@ def testStopHuntWithReason(self): ) hunt_obj2 = data_store.REL_DB.ReadHuntObject(hunt_obj.hunt_id) - self.assertEqual( - hunt_obj2.hunt_state, rdf_hunt_objects.Hunt.HuntState.STOPPED - ) + self.assertEqual(hunt_obj2.hunt_state, hunts_pb2.Hunt.HuntState.STOPPED) self.assertEqual( hunt_obj2.hunt_state_reason, - rdf_hunt_objects.Hunt.HuntStateReason.AVG_NETWORK_EXCEEDED, + 
hunts_pb2.Hunt.HuntStateReason.AVG_NETWORK_EXCEEDED, ) self.assertEqual(hunt_obj2.hunt_state_comment, "not working") @@ -320,12 +320,10 @@ def testHuntIsPausedOnReachingClientLimit(self): ) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) - self.assertEqual( - hunt_obj.hunt_state, rdf_hunt_objects.Hunt.HuntState.PAUSED - ) + self.assertEqual(hunt_obj.hunt_state, hunts_pb2.Hunt.HuntState.PAUSED) self.assertEqual( hunt_obj.hunt_state_reason, - rdf_hunt_objects.Hunt.HuntStateReason.TOTAL_CLIENTS_EXCEEDED, + hunts_pb2.Hunt.HuntStateReason.TOTAL_CLIENTS_EXCEEDED, ) hunt_counters = data_store.REL_DB.ReadHuntCounters(hunt_id) @@ -349,7 +347,10 @@ def testHuntClientRateIsAppliedCorrectly(self): self.assertLen(requests, 9) for i, (r, client_id) in enumerate(zip(requests, client_ids[1:])): self.assertEqual(r.client_id, client_id) - time_diff = r.delivery_time - ( + delivery_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + r.delivery_time + ) + time_diff = delivery_time - ( now + rdfvalue.Duration.From(1, rdfvalue.MINUTES) * (i + 1) ) self.assertLess(time_diff, rdfvalue.Duration.From(5, rdfvalue.SECONDS)) @@ -423,10 +424,10 @@ def testResultsAreCorrectlyWrittenAndAreFilterable(self): results = data_store.REL_DB.ReadHuntResults(hunt_id, 0, sys.maxsize) self.assertLen(results, 5) for r in results: - self.assertIsInstance(r.payload, rdf_file_finder.FileFinderResult) - self.assertEqual( - r.payload.stat_entry.pathspec.CollapsePath(), "/tmp/evil.txt" - ) + self.assertTrue(r.payload.Is(flows_pb2.FileFinderResult.DESCRIPTOR)) + ff_result = flows_pb2.FileFinderResult() + r.payload.Unpack(ff_result) + self.assertEqual(ff_result.stat_entry.pathspec.path, "/tmp/evil.txt") def testOutputPluginsAreCorrectlyAppliedAndTheirStatusCanBeRead(self): hunt_test_lib.StatefulDummyHuntOutputPlugin.data = [] @@ -662,16 +663,12 @@ def testHuntIsStoppedIfCrashNumberOverThreshold(self): self._RunHunt(client_ids[:2], client_mock=client_mock) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) - self.assertEqual( - hunt_obj.hunt_state, rdf_hunt_objects.Hunt.HuntState.STARTED - ) + self.assertEqual(hunt_obj.hunt_state, hunts_pb2.Hunt.HuntState.STARTED) self._RunHunt(client_ids[2:], client_mock=client_mock) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) - self.assertEqual( - hunt_obj.hunt_state, rdf_hunt_objects.Hunt.HuntState.STOPPED - ) + self.assertEqual(hunt_obj.hunt_state, hunts_pb2.Hunt.HuntState.STOPPED) self._CheckHuntStoppedNotification("reached the crashes limit") @@ -705,7 +702,7 @@ def CheckState(hunt_state, num_results): # Hunt should still be running: we got 1 response from 2 clients. We need # at least 3 clients to start calculating the average. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 2) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 2) self._RunHunt( [client_ids[2]], @@ -715,7 +712,7 @@ def CheckState(hunt_state, num_results): # Hunt should still be running: we got 1 response for first 2 clients and # 2 responses for the third. This is over the limit but we need at least 4 # clients to start applying thresholds. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 4) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 4) self._RunHunt( [client_ids[3]], client_mock=action_mocks.ListProcessesMock([]) @@ -724,7 +721,7 @@ def CheckState(hunt_state, num_results): # Hunt should still be running: we got 1 response for first 2 clients, # 2 responses for the third and zero for the 4th. This makes it 1 result # per client on average. This is within the limit of 1. 
- CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 4) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 4) self._RunHunt( client_ids[4:5], @@ -735,7 +732,7 @@ def CheckState(hunt_state, num_results): # That's more than the allowed average of 1. # Note that this check also implicitly checks that the 6th client didn't # run at all (otherwise total number of results would be 8, not 6). - CheckState(rdf_hunt_objects.Hunt.HuntState.STOPPED, 6) + CheckState(hunts_pb2.Hunt.HuntState.STOPPED, 6) self._CheckHuntStoppedNotification( "reached the average results per client" @@ -770,7 +767,7 @@ def CheckState(hunt_state, user_cpu_time, system_cpu_time): # Hunt should still be running: we need at least 3 clients to start # calculating the average. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 2, 4) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 2, 4) self._RunHunt( [client_ids[2]], @@ -781,7 +778,7 @@ def CheckState(hunt_state, user_cpu_time, system_cpu_time): # Hunt should still be running: even though the average is higher than the # limit, number of clients is not enough. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 4, 8) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 4, 8) self._RunHunt( [client_ids[3]], @@ -794,7 +791,7 @@ def CheckState(hunt_state, user_cpu_time, system_cpu_time): # average per-client CPU usage. But 4 user cpu + 8 system cpu seconds for # 4 clients make an average of 3 seconds per client - this is within the # limit. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 4, 8) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 4, 8) self._RunHunt( [client_ids[4]], @@ -841,7 +838,7 @@ def CheckState(hunt_state, network_bytes_sent): # Hunt should still be running: we need at least 3 clients to start # calculating the average. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 2) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 2) self._RunHunt( [client_ids[2]], @@ -852,7 +849,7 @@ def CheckState(hunt_state, network_bytes_sent): # Hunt should still be running: even though the average is higher than the # limit, number of clients is not enough. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 4) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 4) self._RunHunt( [client_ids[3]], @@ -864,7 +861,7 @@ def CheckState(hunt_state, network_bytes_sent): # Hunt should still be running: we got 4 clients, which is enough to check # average per-client network bytes usage, but 4 bytes for 4 clients is # within the limit of 1 byte per client on average. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 4) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 4) self._RunHunt( [client_ids[4]], @@ -906,7 +903,7 @@ def CheckState(hunt_state, network_bytes_sent): ) # 4 is lower than the total limit. The hunt should still be running. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 4) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 4) self._RunHunt( [client_ids[2]], @@ -915,7 +912,7 @@ def CheckState(hunt_state, network_bytes_sent): # 5 is equal to the total limit. Total network bytes sent should # go over the limit in order for the hunt to be stopped. - CheckState(rdf_hunt_objects.Hunt.HuntState.STARTED, 5) + CheckState(hunts_pb2.Hunt.HuntState.STARTED, 5) self._RunHunt( [client_ids[3]], @@ -926,7 +923,7 @@ def CheckState(hunt_state, network_bytes_sent): # run with approximate limits (flow not persisted in the DB every time). # # 6 is greater than the total limit. The hunt should be stopped now. 
- # CheckState(rdf_hunt_objects.Hunt.HuntState.STOPPED, 6) + # CheckState(hunts_pb2.Hunt.HuntState.STOPPED, 6) # self._RunHunt([client_ids[4]], # client_mock=hunt_test_lib.SampleHuntMock( @@ -958,18 +955,14 @@ def testHuntIsStoppedWhenExpirationTimeIsReached(self): foreman_obj.AssignTasksToClient(client_ids[0]) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) - self.assertEqual( - hunt_obj.hunt_state, rdf_hunt_objects.Hunt.HuntState.STARTED - ) + self.assertEqual(hunt_obj.hunt_state, hunts_pb2.Hunt.HuntState.STARTED) with test_lib.FakeTime( expiry_time - rdfvalue.Duration.From(1, rdfvalue.SECONDS) ): foreman_obj.AssignTasksToClient(client_ids[1]) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) - self.assertEqual( - hunt_obj.hunt_state, rdf_hunt_objects.Hunt.HuntState.STARTED - ) + self.assertEqual(hunt_obj.hunt_state, hunts_pb2.Hunt.HuntState.STARTED) hunt_counters = data_store.REL_DB.ReadHuntCounters(hunt_id) self.assertEqual(hunt_counters.num_clients, 2) @@ -996,9 +989,7 @@ def testPausingTheHuntChangingParametersAndStartingAgainWorksAsExpected(self): self._RunHunt(client_ids[:2]) hunt_obj = data_store.REL_DB.ReadHuntObject(hunt_id) - self.assertEqual( - hunt_obj.hunt_state, rdf_hunt_objects.Hunt.HuntState.PAUSED - ) + self.assertEqual(hunt_obj.hunt_state, hunts_pb2.Hunt.HuntState.PAUSED) # There should be only one client, due to the limit hunt_counters = data_store.REL_DB.ReadHuntCounters(hunt_id) self.assertEqual(hunt_counters.num_clients, 1) @@ -1023,15 +1014,15 @@ def testResourceUsageStatsAreReportedCorrectly(self): # Values below are calculated based on SampleHuntMock's behavior. self.assertEqual(usage_stats.user_cpu_stats.num, 10) - self.assertAlmostEqual(usage_stats.user_cpu_stats.mean, 5.5) + self.assertAlmostEqual(usage_stats.user_cpu_stats.sum, 55) self.assertAlmostEqual(usage_stats.user_cpu_stats.stddev, 2.8722813) self.assertEqual(usage_stats.system_cpu_stats.num, 10) - self.assertAlmostEqual(usage_stats.system_cpu_stats.mean, 11) + self.assertAlmostEqual(usage_stats.system_cpu_stats.sum, 110) self.assertAlmostEqual(usage_stats.system_cpu_stats.stddev, 5.7445626) self.assertEqual(usage_stats.network_bytes_sent_stats.num, 10) - self.assertAlmostEqual(usage_stats.network_bytes_sent_stats.mean, 16.5) + self.assertAlmostEqual(usage_stats.network_bytes_sent_stats.sum, 165) self.assertAlmostEqual( usage_stats.network_bytes_sent_stats.stddev, 8.61684396 ) @@ -1193,8 +1184,9 @@ def testVariableHuntSchedulesAllFlowsOnStart(self): self.assertEqual( all_flows[0].flow_class_name, transfer.GetFile.__name__ ) + rdf_flow = mig_flow_objects.ToRDFFlow(all_flows[0]) self.assertEqual( - all_flows[0].args.pathspec.path, "/tmp/evil_%d.txt" % index + rdf_flow.args.pathspec.path, "/tmp/evil_%d.txt" % index ) def testHuntIDFromURN(self): diff --git a/grr/server/grr_response_server/key_utils.py b/grr/server/grr_response_server/key_utils.py deleted file mode 100644 index 00fb14d1bb..0000000000 --- a/grr/server/grr_response_server/key_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python -"""This file abstracts the loading of the private key.""" - -from cryptography import x509 -from cryptography.hazmat.backends import openssl -from cryptography.hazmat.primitives import hashes -from cryptography.x509 import oid - -from grr_response_core.lib import rdfvalue -from grr_response_core.lib.rdfvalues import crypto as rdf_crypto - - -def MakeCASignedCert(common_name, - private_key, - ca_cert, - ca_private_key, - serial_number=2): - """Make a cert and sign it with the CA's private key.""" - 
public_key = private_key.GetPublicKey() - - builder = x509.CertificateBuilder() - - builder = builder.issuer_name(ca_cert.GetIssuer()) - - subject = x509.Name( - [x509.NameAttribute(oid.NameOID.COMMON_NAME, common_name)]) - builder = builder.subject_name(subject) - - valid_from = rdfvalue.RDFDatetime.Now() - rdfvalue.Duration.From( - 1, rdfvalue.DAYS) - valid_until = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( - 3650, rdfvalue.DAYS) - builder = builder.not_valid_before(valid_from.AsDatetime()) - builder = builder.not_valid_after(valid_until.AsDatetime()) - - builder = builder.serial_number(serial_number) - builder = builder.public_key(public_key.GetRawPublicKey()) - - builder = builder.add_extension( - x509.BasicConstraints(ca=False, path_length=None), critical=True) - certificate = builder.sign( - private_key=ca_private_key.GetRawPrivateKey(), - algorithm=hashes.SHA256(), - backend=openssl.backend) - return rdf_crypto.RDFX509Cert(certificate) - - -def MakeCACert(private_key, - common_name=u"grr", - issuer_cn=u"grr_test", - issuer_c=u"US"): - """Generate a CA certificate. - - Args: - private_key: The private key to use. - common_name: Name for cert. - issuer_cn: Name for issuer. - issuer_c: Country for issuer. - - Returns: - The certificate. - """ - public_key = private_key.GetPublicKey() - - builder = x509.CertificateBuilder() - - issuer = x509.Name([ - x509.NameAttribute(oid.NameOID.COMMON_NAME, issuer_cn), - x509.NameAttribute(oid.NameOID.COUNTRY_NAME, issuer_c) - ]) - subject = x509.Name( - [x509.NameAttribute(oid.NameOID.COMMON_NAME, common_name)]) - builder = builder.subject_name(subject) - builder = builder.issuer_name(issuer) - - valid_from = rdfvalue.RDFDatetime.Now() - rdfvalue.Duration.From( - 1, rdfvalue.DAYS) - valid_until = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration.From( - 3650, rdfvalue.DAYS) - builder = builder.not_valid_before(valid_from.AsDatetime()) - builder = builder.not_valid_after(valid_until.AsDatetime()) - - builder = builder.serial_number(1) - builder = builder.public_key(public_key.GetRawPublicKey()) - - builder = builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), critical=True) - builder = builder.add_extension( - x509.SubjectKeyIdentifier.from_public_key(public_key.GetRawPublicKey()), - critical=False) - certificate = builder.sign( - private_key=private_key.GetRawPrivateKey(), - algorithm=hashes.SHA256(), - backend=openssl.backend) - return rdf_crypto.RDFX509Cert(certificate) diff --git a/grr/server/grr_response_server/keystore/abstract.py b/grr/server/grr_response_server/keystore/abstract.py index 76d0129004..112f089067 100644 --- a/grr/server/grr_response_server/keystore/abstract.py +++ b/grr/server/grr_response_server/keystore/abstract.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with the definition of the abstract keystore.""" + import abc diff --git a/grr/server/grr_response_server/keystore/abstract_test_lib.py b/grr/server/grr_response_server/keystore/abstract_test_lib.py index ef338bc428..a1b2b2d6ab 100644 --- a/grr/server/grr_response_server/keystore/abstract_test_lib.py +++ b/grr/server/grr_response_server/keystore/abstract_test_lib.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Common utilities for testing keystore implementations.""" + import abc from typing import Sequence diff --git a/grr/server/grr_response_server/keystore/cached.py b/grr/server/grr_response_server/keystore/cached.py index ffc14a063c..278796a614 100644 --- a/grr/server/grr_response_server/keystore/cached.py +++ 
b/grr/server/grr_response_server/keystore/cached.py @@ -1,10 +1,9 @@ #!/usr/bin/env python """A module with implementation of the cached keystore.""" + import dataclasses import datetime -from typing import Generic -from typing import Optional -from typing import TypeVar +from typing import Generic, Optional, TypeVar from grr_response_server.keystore import abstract @@ -43,7 +42,8 @@ def Crypter(self, name: str) -> abstract.Crypter: except KeyError: entry = CachedKeystore._CacheEntry( crypter=self._delegate.Crypter(name), - expiration_time=datetime.datetime.now() + self._validity_duration) + expiration_time=datetime.datetime.now() + self._validity_duration, + ) self._cache[name] = entry return entry.crypter @@ -57,6 +57,7 @@ def Crypter(self, name: str) -> abstract.Crypter: @dataclasses.dataclass(frozen=True) class _CacheEntry(Generic[_T]): """An entry of the cache dictionary.""" + crypter: abstract.Crypter expiration_time: datetime.datetime diff --git a/grr/server/grr_response_server/keystore/cached_test.py b/grr/server/grr_response_server/keystore/cached_test.py index 8654f7c788..fada20dc57 100644 --- a/grr/server/grr_response_server/keystore/cached_test.py +++ b/grr/server/grr_response_server/keystore/cached_test.py @@ -24,7 +24,8 @@ def testCrypterCached(self): # We create a keystore where all the keys expire after 128 weeks (so, enough # for the test to execute without expiring anything. cached_ks = cached.CachedKeystore( - mem_ks, validity_duration=datetime.timedelta(weeks=128)) + mem_ks, validity_duration=datetime.timedelta(weeks=128) + ) crypter_1 = cached_ks.Crypter("foo") crypter_2 = cached_ks.Crypter("foo") @@ -36,7 +37,8 @@ def testCrypterExpired(self): # We create a keystore where all the keys have no validity duration (meaning # the keystore should expire them all the time). 
cached_ks = cached.CachedKeystore( - mem_ks, validity_duration=datetime.timedelta(0)) + mem_ks, validity_duration=datetime.timedelta(0) + ) crypter_1 = cached_ks.Crypter("foo") crypter_2 = cached_ks.Crypter("foo") diff --git a/grr/server/grr_response_server/keystore/mem.py b/grr/server/grr_response_server/keystore/mem.py index 789fb8e7c7..862d42a32a 100644 --- a/grr/server/grr_response_server/keystore/mem.py +++ b/grr/server/grr_response_server/keystore/mem.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with implementation of the in-memory keystore.""" + import itertools import os from typing import Sequence @@ -56,7 +57,7 @@ def Decrypt(self, data: bytes, assoc_data: bytes) -> bytes: key = itertools.cycle(self._key) unencrypted_data = bytes(db ^ kb for db, kb in zip(data, key)) - if unencrypted_data[-len(assoc_data):] != assoc_data: + if unencrypted_data[-len(assoc_data) :] != assoc_data: raise abstract.DecryptionError("Incorrect associated data") - return unencrypted_data[:-len(assoc_data)] + return unencrypted_data[: -len(assoc_data)] diff --git a/grr/server/grr_response_server/maintenance_utils.py b/grr/server/grr_response_server/maintenance_utils.py index 01574ae5fa..42fc1aae3a 100644 --- a/grr/server/grr_response_server/maintenance_utils.py +++ b/grr/server/grr_response_server/maintenance_utils.py @@ -8,8 +8,6 @@ from grr_api_client import api from grr_response_core import config from grr_response_core.lib import rdfvalue -from grr_response_core.lib.rdfvalues import crypto as rdf_crypto -from grr_response_server import key_utils from grr_response_server import signed_binary_utils from grr_response_server.bin import api_shell_raw_access_lib from grr_response_server.gui import api_call_context @@ -71,54 +69,3 @@ def UploadSignedConfigBlob(content, aff4_path, client_context=None, limit=None): chunk_size=limit) logging.info("Uploaded to %s", aff4_path) - - -def RotateServerKey(cn=u"grr", keylength=4096): - """This function creates and installs a new server key. - - Note that - - - Clients might experience intermittent connection problems after - the server keys rotated. - - - It's not possible to go back to an earlier key. Clients that see a - new certificate will remember the cert's serial number and refuse - to accept any certificate with a smaller serial number from that - point on. - - Args: - cn: The common name for the server to use. - keylength: Length in bits for the new server key. - - Raises: - ValueError: There is no CA cert in the config. Probably the server - still needs to be initialized. 
- """ - ca_certificate = config.CONFIG["CA.certificate"] - ca_private_key = config.CONFIG["PrivateKeys.ca_key"] - - if not ca_certificate or not ca_private_key: - raise ValueError("No existing CA certificate found.") - - # Check the current certificate serial number - existing_cert = config.CONFIG["Frontend.certificate"] - - serial_number = existing_cert.GetSerialNumber() + 1 - EPrint("Generating new server key (%d bits, cn '%s', serial # %d)" % - (keylength, cn, serial_number)) - - server_private_key = rdf_crypto.RSAPrivateKey.GenerateKey(bits=keylength) - server_cert = key_utils.MakeCASignedCert( - str(cn), - server_private_key, - ca_certificate, - ca_private_key, - serial_number=serial_number) - - EPrint("Updating configuration.") - config.CONFIG.Set("Frontend.certificate", server_cert.AsPEM().decode("ascii")) - config.CONFIG.Set("PrivateKeys.server_key", - server_private_key.AsPEM().decode("ascii")) - config.CONFIG.Write() - - EPrint("Server key rotated, please restart the GRR Frontends.") diff --git a/grr/server/grr_response_server/message_handlers.py b/grr/server/grr_response_server/message_handlers.py index 8f8a812ef3..53988aca9e 100644 --- a/grr/server/grr_response_server/message_handlers.py +++ b/grr/server/grr_response_server/message_handlers.py @@ -13,8 +13,6 @@ "ClientAlertHandler", str(rdfvalue.SessionID(flow_name="Foreman")): "ForemanHandler", - str(rdfvalue.SessionID(flow_name="NannyMessage")): - "NannyMessageHandler", str(rdfvalue.SessionID(flow_name="Startup")): "ClientStartupHandler", str(rdfvalue.SessionID(flow_name="TransferStore")): diff --git a/grr/server/grr_response_server/models/blobs.py b/grr/server/grr_response_server/models/blobs.py index e4bf6040f9..bcde037845 100644 --- a/grr/server/grr_response_server/models/blobs.py +++ b/grr/server/grr_response_server/models/blobs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Module with data models and helpers related to blobs.""" + import binascii import hashlib diff --git a/grr/server/grr_response_server/models/events.py b/grr/server/grr_response_server/models/events.py new file mode 100644 index 0000000000..1bc1f5506d --- /dev/null +++ b/grr/server/grr_response_server/models/events.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +"""Event related helpers.""" + +from grr_response_proto import objects_pb2 +from grr_response_server.gui import http_request +from grr_response_server.gui import http_response + +_HTTP_STATUS_TO_CODE = { + 200: objects_pb2.APIAuditEntry.Code.OK, + 403: objects_pb2.APIAuditEntry.Code.FORBIDDEN, + 404: objects_pb2.APIAuditEntry.Code.NOT_FOUND, + 500: objects_pb2.APIAuditEntry.Code.ERROR, + 501: objects_pb2.APIAuditEntry.Code.NOT_IMPLEMENTED, +} + + +def APIAuditEntryFromHttpRequestResponse( + request: http_request.HttpRequest, + response: http_response.HttpResponse, +) -> objects_pb2.APIAuditEntry: + response_code = _HTTP_STATUS_TO_CODE.get( + response.status_code, objects_pb2.APIAuditEntry.Code.ERROR + ) + + return objects_pb2.APIAuditEntry( + http_request_path=request.full_path, # Includes query string. 
+ router_method_name=response.headers.get("X-API-Method", ""), + username=request.user, + response_code=response_code, + ) diff --git a/grr/server/grr_response_server/models/events_test.py b/grr/server/grr_response_server/models/events_test.py new file mode 100644 index 0000000000..5a5fc04bee --- /dev/null +++ b/grr/server/grr_response_server/models/events_test.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +from absl.testing import absltest + +from grr_response_proto import objects_pb2 +from grr_response_server.gui import http_request +from grr_response_server.gui import http_response +from grr_response_server.models import events + + +class APIAuditEntryFromHttpRequestResponseTest(absltest.TestCase): + + def testBaseFields(self): + request = http_request.HttpRequest.from_values( + "/bar?foo=baz", "http://example.com/test" + ) + request.user = "testuser" + + response = http_response.HttpResponse( + status=42, + headers={"X-API-Method": "TestMethod"}, + ) + + expected = objects_pb2.APIAuditEntry( + http_request_path="/bar?foo=baz", # Includes query string. + router_method_name="TestMethod", + username="testuser", + response_code=objects_pb2.APIAuditEntry.Code.ERROR, + ) + + result = events.APIAuditEntryFromHttpRequestResponse(request, response) + self.assertEqual(expected, result) + + def testStatus(self): + request = http_request.HttpRequest({}) + request.user = "needs_to_be_set" + + # Make sure we always test everything in the dict + for status, want_code in events._HTTP_STATUS_TO_CODE.items(): + response = http_response.HttpResponse(status=status) + result = events.APIAuditEntryFromHttpRequestResponse(request, response) + self.assertEqual(want_code, result.response_code) + + def testStatusDefault(self): + request = http_request.HttpRequest({}) + request.user = "needs_to_be_set" + response = http_response.HttpResponse(status=42) + result = events.APIAuditEntryFromHttpRequestResponse(request, response) + self.assertEqual(objects_pb2.APIAuditEntry.Code.ERROR, result.response_code) + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/models/protodicts.py b/grr/server/grr_response_server/models/protodicts.py new file mode 100644 index 0000000000..277e6d95e0 --- /dev/null +++ b/grr/server/grr_response_server/models/protodicts.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +"""Module with helpers related to `Dict` messages.""" + +from typing import Union + +from grr_response_proto import jobs_pb2 + + +DataBlobValue = Union[ + bool, + int, + float, + bytes, + str, + list["DataBlobValue"], + dict["DataBlobValue", "DataBlobValue"], +] + + +def DataBlob(value: DataBlobValue) -> jobs_pb2.DataBlob: + """Creates a `DataBlob` message from the given Python value. + + Args: + value: A Python value to convert to a `DataBlob` message. + + Returns: + A `DataBlob` message corresponding to the given value. 
+ """ + if isinstance(value, bool): + return jobs_pb2.DataBlob(boolean=value) + if isinstance(value, int): + return jobs_pb2.DataBlob(integer=value) + if isinstance(value, float): + return jobs_pb2.DataBlob(float=value) + if isinstance(value, str): + return jobs_pb2.DataBlob(string=value) + if isinstance(value, bytes): + return jobs_pb2.DataBlob(data=value) + if isinstance(value, list): + result = jobs_pb2.DataBlob() + result.list.CopyFrom(BlobArray(value)) + return result + if isinstance(value, dict): + result = jobs_pb2.DataBlob() + result.dict.CopyFrom(Dict(value)) + return result + + raise TypeError(f"Unexpected type: {type(value)}") + + +def BlobArray(values: list[DataBlobValue]) -> jobs_pb2.BlobArray: + """Creates a `BlobArray` message from the given list of Python values. + + Args: + values: A list of Python values to convert to a `BlobArray` message. + + Returns: + A `BlobArray` message corresponding to the given Python list. + """ + result = jobs_pb2.BlobArray() + + for value in values: + result.content.append(DataBlob(value)) + + return result + + +def Dict(dikt: dict[DataBlobValue, DataBlobValue]) -> jobs_pb2.Dict: + """Creates a `Dict` message from the given Python dictionary. + + Args: + dikt: A dictionary of Python values to convert to a `Dict` message. + + Returns: + A `Dict` message corresponding to the given Python dictionary. + """ + result = jobs_pb2.Dict() + + for key, value in dikt.items(): + result.dat.add(k=DataBlob(key), v=DataBlob(value)) + + return result diff --git a/grr/server/grr_response_server/models/protodicts_test.py b/grr/server/grr_response_server/models/protodicts_test.py new file mode 100644 index 0000000000..0329f8b9a5 --- /dev/null +++ b/grr/server/grr_response_server/models/protodicts_test.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +from absl.testing import absltest + +from grr_response_server.models import protodicts + + +class DataBlobTest(absltest.TestCase): + + def testBool(self): + self.assertEqual(protodicts.DataBlob(True).boolean, True) + + def testInt(self): + self.assertEqual(protodicts.DataBlob(1337).integer, 1337) + + def testFloat(self): + self.assertEqual(protodicts.DataBlob(0.5).float, 0.5) + + def testBytes(self): + self.assertEqual(protodicts.DataBlob(b"\x00\xff\x00").data, b"\x00\xff\x00") + + def testStr(self): + self.assertEqual(protodicts.DataBlob("foobar").string, "foobar") + + def testList(self): + proto = protodicts.DataBlob([1, 3, 3, 7]) + + self.assertLen(proto.list.content, 4) + self.assertEqual(proto.list.content[0].integer, 1) + self.assertEqual(proto.list.content[1].integer, 3) + self.assertEqual(proto.list.content[2].integer, 3) + self.assertEqual(proto.list.content[3].integer, 7) + + def testDict(self): + proto = protodicts.DataBlob({ + "foo": 42, + "bar": 1337, + }) + + self.assertLen(proto.dict.dat, 2) + self.assertEqual(proto.dict.dat[0].k.string, "foo") + self.assertEqual(proto.dict.dat[0].v.integer, 42) + self.assertEqual(proto.dict.dat[1].k.string, "bar") + self.assertEqual(proto.dict.dat[1].v.integer, 1337) + + +class BlobArrayTest(absltest.TestCase): + + def testEmpty(self): + proto = protodicts.BlobArray([]) + + self.assertEmpty(proto.content) + + def testSingleton(self): + proto = protodicts.BlobArray(["foo"]) + + self.assertLen(proto.content, 1) + self.assertEqual(proto.content[0].string, "foo") + + def testHomogeneous(self): + proto = protodicts.BlobArray(["foo", "bar", "baz"]) + + self.assertLen(proto.content, 3) + self.assertEqual(proto.content[0].string, "foo") + self.assertEqual(proto.content[1].string, 
"bar") + self.assertEqual(proto.content[2].string, "baz") + + def testHeterogeneous(self): + proto = protodicts.BlobArray([42, "foo", 0.5]) + + self.assertLen(proto.content, 3) + self.assertEqual(proto.content[0].integer, 42) + self.assertEqual(proto.content[1].string, "foo") + self.assertEqual(proto.content[2].float, 0.5) + + def testRepeated(self): + proto = protodicts.BlobArray([1, 3, 3, 7]) + + self.assertLen(proto.content, 4) + self.assertEqual(proto.content[0].integer, 1) + self.assertEqual(proto.content[1].integer, 3) + self.assertEqual(proto.content[2].integer, 3) + self.assertEqual(proto.content[3].integer, 7) + + def testNested(self): + proto = protodicts.BlobArray([["foo", "bar"], ["quux"]]) + + self.assertLen(proto.content, 2) + self.assertLen(proto.content[0].list.content, 2) + self.assertLen(proto.content[1].list.content, 1) + self.assertEqual(proto.content[0].list.content[0].string, "foo") + self.assertEqual(proto.content[0].list.content[1].string, "bar") + self.assertEqual(proto.content[1].list.content[0].string, "quux") + + +class DictTest(absltest.TestCase): + + def testEmpty(self): + proto = protodicts.Dict({}) + + self.assertEmpty(proto.dat) + + def testSingleton(self): + proto = protodicts.Dict({"foo": 42}) + + self.assertLen(proto.dat, 1) + self.assertEqual(proto.dat[0].k.string, "foo") + self.assertEqual(proto.dat[0].v.integer, 42) + + def testHomogeneous(self): + proto = protodicts.Dict({ + "foo": 0xC0DE, + "bar": 0xBEEF, + "quux": 0xC0FE, + }) + + self.assertLen(proto.dat, 3) + self.assertEqual(proto.dat[0].k.string, "foo") + self.assertEqual(proto.dat[0].v.integer, 0xC0DE) + self.assertEqual(proto.dat[1].k.string, "bar") + self.assertEqual(proto.dat[1].v.integer, 0xBEEF) + self.assertEqual(proto.dat[2].k.string, "quux") + self.assertEqual(proto.dat[2].v.integer, 0xC0FE) + + def testHeterogeneous(self): + proto = protodicts.Dict({ + "foo": 0.5, + 1337: b"\x00\xFF\x00", + }) + + self.assertLen(proto.dat, 2) + self.assertEqual(proto.dat[0].k.string, "foo") + self.assertEqual(proto.dat[0].v.float, 0.5) + self.assertEqual(proto.dat[1].k.integer, 1337) + self.assertEqual(proto.dat[1].v.data, b"\x00\xFF\x00") + + def testNested(self): + proto = protodicts.Dict({ + "foo": { + "bar": "baz", + }, + "quux": { + "norf": "thud", + }, + }) + + self.assertLen(proto.dat, 2) + self.assertLen(proto.dat[0].v.dict.dat, 1) + self.assertLen(proto.dat[1].v.dict.dat, 1) + self.assertEqual(proto.dat[0].k.string, "foo") + self.assertEqual(proto.dat[0].v.dict.dat[0].k.string, "bar") + self.assertEqual(proto.dat[0].v.dict.dat[0].v.string, "baz") + self.assertEqual(proto.dat[1].k.string, "quux") + self.assertEqual(proto.dat[1].v.dict.dat[0].k.string, "norf") + self.assertEqual(proto.dat[1].v.dict.dat[0].v.string, "thud") + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/models/users.py b/grr/server/grr_response_server/models/users.py index c439794079..356c5e2cf2 100644 --- a/grr/server/grr_response_server/models/users.py +++ b/grr/server/grr_response_server/models/users.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides user-related data models and helpers.""" + from grr_response_core import config from grr_response_proto import objects_pb2 diff --git a/grr/server/grr_response_server/output_plugins/bigquery_plugin.py b/grr/server/grr_response_server/output_plugins/bigquery_plugin.py index a3b6ea2250..6cf81499f0 100644 --- a/grr/server/grr_response_server/output_plugins/bigquery_plugin.py +++ 
b/grr/server/grr_response_server/output_plugins/bigquery_plugin.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """BigQuery output plugin.""" + import base64 import gzip import json @@ -18,14 +19,16 @@ from grr_response_server.export_converters import base -class TempOutputTracker(object): +class TempOutputTracker: """Track temp output files for BigQuery JSON data and schema.""" - def __init__(self, - output_type=None, - gzip_filehandle=None, - gzip_filehandle_parent=None, - schema=None): + def __init__( + self, + output_type=None, + gzip_filehandle=None, + gzip_filehandle_parent=None, + schema=None, + ): """Create tracker. This class is used to track a gzipped filehandle for each type of output @@ -75,7 +78,7 @@ class BigQueryOutputPlugin(output_plugin.OutputPlugin): "bool": "BOOLEAN", "float": "FLOAT", "uint32": "INTEGER", - "uint64": "INTEGER" + "uint64": "INTEGER", } def __init__(self, *args, **kwargs): @@ -93,15 +96,17 @@ def UpdateState(self, state): def ProcessResponses(self, state, responses): default_metadata = base.ExportedMetadata( - annotations=u",".join(self.args.export_options.annotations), - source_urn=self.source_urn) + annotations=",".join(self.args.export_options.annotations), + source_urn=self.source_urn, + ) responses = [r.AsLegacyGrrMessage() for r in responses] if self.args.convert_values: # This is thread-safe - we just convert the values. converted_responses = export.ConvertValues( - default_metadata, responses, options=self.args.export_options) + default_metadata, responses, options=self.args.export_options + ) else: converted_responses = responses @@ -173,13 +178,17 @@ def _CreateOutputFileHandles(self, output_type): A TempOutputTracker object """ gzip_filehandle_parent = tempfile.NamedTemporaryFile(suffix=output_type) - gzip_filehandle = gzip.GzipFile(gzip_filehandle_parent.name, "wb", - self.GZIP_COMPRESSION_LEVEL, - gzip_filehandle_parent) + gzip_filehandle = gzip.GzipFile( + gzip_filehandle_parent.name, + "wb", + self.GZIP_COMPRESSION_LEVEL, + gzip_filehandle_parent, + ) self.temp_output_trackers[output_type] = TempOutputTracker( output_type=output_type, gzip_filehandle=gzip_filehandle, - gzip_filehandle_parent=gzip_filehandle_parent) + gzip_filehandle_parent=gzip_filehandle_parent, + ) return self.temp_output_trackers[output_type] def _GetTempOutputFileHandles(self, value_type): @@ -193,8 +202,13 @@ def Flush(self, state): """Finish writing JSON files, upload to cloudstorage and bigquery.""" self.bigquery = bigquery.GetBigQueryClient() # BigQuery job ids must be alphanum plus dash and underscore. - urn_str = rdfvalue.RDFURN(self.source_urn).RelativeName("aff4:/").replace( - "/", "_").replace(":", "").replace(".", "-") + urn_str = ( + rdfvalue.RDFURN(self.source_urn) + .RelativeName("aff4:/") + .replace("/", "_") + .replace(":", "") + .replace(".", "-") + ) for tracker in self.temp_output_trackers.values(): # Close out the gzip handle and pass the original file handle to the @@ -205,8 +219,10 @@ def Flush(self, state): # e.g. job_id: hunts_HFFE1D044_Results_ExportedFile_1446056474 job_id = "{0}_{1}_{2}".format( - urn_str, tracker.output_type, - rdfvalue.RDFDatetime.Now().AsSecondsSinceEpoch()) + urn_str, + tracker.output_type, + rdfvalue.RDFDatetime.Now().AsSecondsSinceEpoch(), + ) # If we have a job id stored, that means we failed last time. Re-use the # job id and append to the same file if it continues to fail. 
This avoids @@ -216,16 +232,22 @@ def Flush(self, state): else: self.output_jobids[tracker.output_type] = job_id - if (state.failure_count + self.failure_count >= - config.CONFIG["BigQuery.max_upload_failures"]): + if ( + state.failure_count + self.failure_count + >= config.CONFIG["BigQuery.max_upload_failures"] + ): logging.error( "Exceeded BigQuery.max_upload_failures for %s, giving up.", - self.source_urn) + self.source_urn, + ) else: try: - self.bigquery.InsertData(tracker.output_type, - tracker.gzip_filehandle_parent, - tracker.schema, job_id) + self.bigquery.InsertData( + tracker.output_type, + tracker.gzip_filehandle_parent, + tracker.schema, + job_id, + ) self.failure_count = max(0, self.failure_count - 1) del self.output_jobids[tracker.output_type] except bigquery.BigQueryJobUploadError: @@ -244,18 +266,21 @@ def RDFValueToBigQuerySchema(self, value): "name": type_info.name, "type": "RECORD", "description": type_info.description, - "fields": self.RDFValueToBigQuerySchema(value.Get(type_info.name)) + "fields": self.RDFValueToBigQuerySchema(value.Get(type_info.name)), }) else: # If we don't have a specific map use string. - bq_type = self.RDF_BIGQUERY_TYPE_MAP.get(type_info.proto_type_name, - None) or "STRING" + bq_type = ( + self.RDF_BIGQUERY_TYPE_MAP.get(type_info.proto_type_name, None) + or "STRING" + ) # For protos with RDF types we need to do some more checking to properly # covert types. if hasattr(type_info, "original_proto_type_name"): if type_info.original_proto_type_name in [ - "RDFDatetime", "RDFDatetimeSeconds" + "RDFDatetime", + "RDFDatetimeSeconds", ]: bq_type = "TIMESTAMP" elif type_info.proto_type_name == "uint64": @@ -268,7 +293,7 @@ def RDFValueToBigQuerySchema(self, value): fields_array.append({ "name": type_info.name, "type": bq_type, - "description": type_info.description + "description": type_info.description, }) return fields_array @@ -312,7 +337,8 @@ def WriteValuesToJSONFile(self, state, values): self._WriteJSONValue(output_tracker.gzip_filehandle, value) else: self._WriteJSONValue( - output_tracker.gzip_filehandle, value, delimiter="\n") + output_tracker.gzip_filehandle, value, delimiter="\n" + ) for output_tracker in self.temp_output_trackers.values(): output_tracker.gzip_filehandle.flush() diff --git a/grr/server/grr_response_server/output_plugins/bigquery_plugin_test.py b/grr/server/grr_response_server/output_plugins/bigquery_plugin_test.py index eece0e4bf9..d341203ac7 100644 --- a/grr/server/grr_response_server/output_plugins/bigquery_plugin_test.py +++ b/grr/server/grr_response_server/output_plugins/bigquery_plugin_test.py @@ -29,16 +29,19 @@ class BigQueryOutputPluginTest(flow_test_lib.FlowTestsBaseclass): def setUp(self): super().setUp() self.client_id = self.SetupClient(0) - self.source_id = rdf_client.ClientURN( - self.client_id).Add("Results").RelativeName("aff4:/") - - def ProcessResponses(self, - plugin_args=None, - responses=None, - process_responses_separately=False): + self.source_id = ( + rdf_client.ClientURN(self.client_id) + .Add("Results") + .RelativeName("aff4:/") + ) + + def ProcessResponses( + self, plugin_args=None, responses=None, process_responses_separately=False + ): plugin_cls = bigquery_plugin.BigQueryOutputPlugin plugin, plugin_state = plugin_cls.CreatePluginAndDefaultState( - source_urn=self.source_id, args=plugin_args) + source_urn=self.source_id, args=plugin_args + ) messages = [] for response in responses: @@ -62,8 +65,9 @@ def ProcessResponses(self, return [x[0] for x in 
mock_bigquery.return_value.InsertData.call_args_list] def CompareSchemaToKnownGood(self, schema): - expected_schema_path = os.path.join(config.CONFIG["Test.data_dir"], - "bigquery", "ExportedFile.schema") + expected_schema_path = os.path.join( + config.CONFIG["Test.data_dir"], "bigquery", "ExportedFile.schema" + ) with open(expected_schema_path, mode="rt", encoding="utf-8") as file: expected_schema_data = json.load(file) @@ -88,7 +92,8 @@ def testBigQueryPluginWithValuesOfSameType(self): responses.append( rdf_client_fs.StatEntry( pathspec=rdf_paths.PathSpec( - path="/foo/bar/%d" % i, pathtype="OS"), + path="/foo/bar/%d" % i, pathtype="OS" + ), st_mode=33184, # octal = 100640 => u=rw,g=r,o= => -rw-r----- st_ino=1063090, st_dev=64512, @@ -99,28 +104,38 @@ def testBigQueryPluginWithValuesOfSameType(self): st_atime=1336469177, st_mtime=1336129892, st_ctime=1336129892, - st_btime=1338111338)) + st_btime=1338111338, + ) + ) output = self.ProcessResponses( plugin_args=bigquery_plugin.BigQueryOutputPluginArgs(), - responses=responses) + responses=responses, + ) self.assertLen(output, 1) _, stream, schema, job_id = output[0] - self.assertEqual(job_id, - "C-1000000000000000_Results_ExportedFile_1445995873") + self.assertEqual( + job_id, "C-1000000000000000_Results_ExportedFile_1445995873" + ) self.CompareSchemaToKnownGood(schema) actual_fd = gzip.GzipFile( - None, "r", bigquery_plugin.BigQueryOutputPlugin.GZIP_COMPRESSION_LEVEL, - stream) + None, + "r", + bigquery_plugin.BigQueryOutputPlugin.GZIP_COMPRESSION_LEVEL, + stream, + ) # Compare to our stored data. expected_fd = open( - os.path.join(config.CONFIG["Test.data_dir"], "bigquery", - "ExportedFile.jsonlines"), "rb") + os.path.join( + config.CONFIG["Test.data_dir"], "bigquery", "ExportedFile.jsonlines" + ), + "rb", + ) # Bigquery expects a newline separarted list of JSON dicts, but this isn't # valid JSON so we can't just load the whole thing and compare. 
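The comment above refers to newline-delimited JSON ("JSON Lines"): each line of the upload is a standalone JSON document, so the whole stream cannot be handed to a single json.loads() call. A minimal sketch of line-by-line parsing, assuming UTF-8 input; the parse_json_lines helper below is hypothetical and not part of this patch:

    import json


    def parse_json_lines(data: bytes) -> list:
      """Parses newline-delimited JSON into a list of Python objects."""
      return [
          json.loads(line)
          for line in data.decode("utf-8").splitlines()
          if line.strip()  # Skip blank lines; every other line is one record.
      ]

This mirrors what the assertions that follow do: they split the decoded stream on newlines and compare the records individually rather than parsing the file as one JSON value.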
@@ -203,19 +218,25 @@ def _parseOutput(self, name, stream): row = json.loads(item.decode("utf-8")) if name == "ExportedFile": - self.assertEqual(row["metadata"]["client_urn"], - "aff4:/%s" % self.client_id) + self.assertEqual( + row["metadata"]["client_urn"], "aff4:/%s" % self.client_id + ) self.assertEqual(row["metadata"]["hostname"], "Host-0.example.com") - self.assertEqual(row["metadata"]["mac_address"], - "aabbccddee00\nbbccddeeff00") + self.assertEqual( + row["metadata"]["mac_address"], "aabbccddee00\nbbccddeeff00" + ) self.assertEqual(row["metadata"]["source_urn"], source_urn) - self.assertEqual(row["urn"], "aff4:/%s/fs/os/中国新闻网新闻中" % self.client_id) + self.assertEqual( + row["urn"], "aff4:/%s/fs/os/中国新闻网新闻中" % self.client_id + ) else: - self.assertEqual(row["metadata"]["client_urn"], - "aff4:/%s" % self.client_id) + self.assertEqual( + row["metadata"]["client_urn"], "aff4:/%s" % self.client_id + ) self.assertEqual(row["metadata"]["hostname"], "Host-0.example.com") - self.assertEqual(row["metadata"]["mac_address"], - "aabbccddee00\nbbccddeeff00") + self.assertEqual( + row["metadata"]["mac_address"], "aabbccddee00\nbbccddeeff00" + ) self.assertEqual(row["metadata"]["source_urn"], source_urn) self.assertEqual(row["pid"], "42") @@ -227,19 +248,26 @@ def testBigQueryPluginWithValuesOfMultipleTypes(self): plugin_args=bigquery_plugin.BigQueryOutputPluginArgs(), responses=[ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path="/中国新闻网新闻中", pathtype="OS")), - rdf_client.Process(pid=42) + pathspec=rdf_paths.PathSpec( + path="/中国新闻网新闻中", pathtype="OS" + ) + ), + rdf_client.Process(pid=42), ], - process_responses_separately=True) + process_responses_separately=True, + ) # Should have two separate output streams for the two types self.assertLen(output, 2) for name, stream, _, job_id in output: - self.assertIn(job_id, [ - "C-1000000000000000_Results_ExportedFile_1445995873", - "C-1000000000000000_Results_ExportedProcess_1445995873" - ]) + self.assertIn( + job_id, + [ + "C-1000000000000000_Results_ExportedFile_1445995873", + "C-1000000000000000_Results_ExportedProcess_1445995873", + ], + ) self._parseOutput(name, stream) @export_test_lib.WithAllExportConverters @@ -249,7 +277,8 @@ def testBigQueryPluginWithEarlyFlush(self): responses.append( rdf_client_fs.StatEntry( pathspec=rdf_paths.PathSpec( - path="/foo/bar/%d" % i, pathtype="OS"), + path="/foo/bar/%d" % i, pathtype="OS" + ), st_mode=33184, # octal = 100640 => u=rw,g=r,o= => -rw-r----- st_ino=1063090, st_dev=64512, @@ -260,7 +289,9 @@ def testBigQueryPluginWithEarlyFlush(self): st_atime=1336469177, st_mtime=1336129892, st_ctime=1336129892, - st_btime=1338111338)) + st_btime=1338111338, + ) + ) sizes = [37, 687, 722, 755, 788, 821, 684, 719, 752, 785] @@ -274,7 +305,8 @@ def GetSize(unused_path): with mock.patch.object(os.path, "getsize", GetSize): output = self.ProcessResponses( plugin_args=bigquery_plugin.BigQueryOutputPluginArgs(), - responses=responses) + responses=responses, + ) self.assertLen(output, 2) # Check that the output is still consistent @@ -287,8 +319,11 @@ def GetSize(unused_path): # TODO(user): there needs to be a better way to generate these files on # change than breaking into the debugger. expected_fd = open( - os.path.join(config.CONFIG["Test.data_dir"], "bigquery", - "ExportedFile.jsonlines"), "rb") + os.path.join( + config.CONFIG["Test.data_dir"], "bigquery", "ExportedFile.jsonlines" + ), + "rb", + ) # Check that the same entries we expect are spread across the two files. 
counter = 0 diff --git a/grr/server/grr_response_server/output_plugins/csv_plugin.py b/grr/server/grr_response_server/output_plugins/csv_plugin.py index e30d324cf8..a48eb6c4ee 100644 --- a/grr/server/grr_response_server/output_plugins/csv_plugin.py +++ b/grr/server/grr_response_server/output_plugins/csv_plugin.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """CSV single-pass output plugin.""" + import csv import io import os @@ -15,7 +16,8 @@ class CSVInstantOutputPlugin( - instant_output_plugin.InstantOutputPluginWithExportConversion): + instant_output_plugin.InstantOutputPluginWithExportConversion +): """Instant Output plugin that writes results to an archive of CSV files.""" plugin_name = "csv-zip" @@ -25,13 +27,15 @@ class CSVInstantOutputPlugin( ROW_BATCH = 100 - def _GetCSVHeader(self, value_class, prefix=u""): + def _GetCSVHeader(self, value_class, prefix=""): header = [] for type_info in value_class.type_infos: if isinstance(type_info, rdf_structs.ProtoEmbedded): header.extend( self._GetCSVHeader( - type_info.type, prefix=prefix + type_info.name + u".")) + type_info.type, prefix=prefix + type_info.name + "." + ) + ) else: header.append(prefix + type_info.name) @@ -56,19 +60,26 @@ def path_prefix(self): def Start(self): self.archive_generator = utils.StreamingZipGenerator( - compression=zipfile.ZIP_DEFLATED) + compression=zipfile.ZIP_DEFLATED + ) self.export_counts = {} return [] - def ProcessSingleTypeExportedValues(self, original_value_type, - exported_values): + def ProcessSingleTypeExportedValues( + self, original_value_type, exported_values + ): first_value = next(exported_values, None) if not first_value: return yield self.archive_generator.WriteFileHeader( - "%s/%s/from_%s.csv" % (self.path_prefix, first_value.__class__.__name__, - original_value_type.__name__)) + "%s/%s/from_%s.csv" + % ( + self.path_prefix, + first_value.__class__.__name__, + original_value_type.__name__, + ) + ) buffer = io.StringIO() writer = csv.writer(buffer) @@ -96,9 +107,9 @@ def ProcessSingleTypeExportedValues(self, original_value_type, yield self.archive_generator.WriteFileFooter() - self.export_counts.setdefault( - original_value_type.__name__, - dict())[first_value.__class__.__name__] = counter + self.export_counts.setdefault(original_value_type.__name__, dict())[ + first_value.__class__.__name__ + ] = counter def Finish(self): manifest = {"export_stats": self.export_counts} diff --git a/grr/server/grr_response_server/output_plugins/csv_plugin_test.py b/grr/server/grr_response_server/output_plugins/csv_plugin_test.py index 77910df7a4..3c2c65a74a 100644 --- a/grr/server/grr_response_server/output_plugins/csv_plugin_test.py +++ b/grr/server/grr_response_server/output_plugins/csv_plugin_test.py @@ -35,7 +35,8 @@ def testCSVPluginWithValuesOfSameType(self): responses.append( rdf_client_fs.StatEntry( pathspec=rdf_paths.PathSpec( - path="/foo/bar/%d" % i, pathtype="OS"), + path="/foo/bar/%d" % i, pathtype="OS" + ), st_mode=33184, # octal = 100640 => u=rw,g=r,o= => -rw-r----- st_ino=1063090, st_dev=64512, @@ -45,24 +46,25 @@ def testCSVPluginWithValuesOfSameType(self): st_size=0, st_atime=1336469177, st_mtime=1336129892, - st_ctime=1336129892)) + st_ctime=1336129892, + ) + ) zip_fd, prefix = self.ProcessValuesToZip( - {rdf_client_fs.StatEntry: responses}) + {rdf_client_fs.StatEntry: responses} + ) self.assertEqual( set(zip_fd.namelist()), set([ "%s/MANIFEST" % prefix, - "%s/ExportedFile/from_StatEntry.csv" % prefix - ])) + "%s/ExportedFile/from_StatEntry.csv" % prefix, + ]), + ) parsed_manifest = 
yaml.safe_load(zip_fd.read("%s/MANIFEST" % prefix)) - self.assertEqual(parsed_manifest, - {"export_stats": { - "StatEntry": { - "ExportedFile": 10 - } - }}) + self.assertEqual( + parsed_manifest, {"export_stats": {"StatEntry": {"ExportedFile": 10}}} + ) with zip_fd.open("%s/ExportedFile/from_StatEntry.csv" % prefix) as filedesc: content = filedesc.read().decode("utf-8") @@ -71,19 +73,27 @@ def testCSVPluginWithValuesOfSameType(self): self.assertLen(parsed_output, 10) for i in range(10): # Make sure metadata is filled in. - self.assertEqual(parsed_output[i]["metadata.client_urn"], - "aff4:/%s" % self.client_id) - self.assertEqual(parsed_output[i]["metadata.hostname"], - "Host-0.example.com") - self.assertEqual(parsed_output[i]["metadata.mac_address"], - "aabbccddee00\nbbccddeeff00") - self.assertEqual(parsed_output[i]["metadata.source_urn"], - self.results_urn) - self.assertEqual(parsed_output[i]["metadata.hardware_info.bios_version"], - "Bios-Version-0") - - self.assertEqual(parsed_output[i]["urn"], - "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i)) + self.assertEqual( + parsed_output[i]["metadata.client_urn"], "aff4:/%s" % self.client_id + ) + self.assertEqual( + parsed_output[i]["metadata.hostname"], "Host-0.example.com" + ) + self.assertEqual( + parsed_output[i]["metadata.mac_address"], "aabbccddee00\nbbccddeeff00" + ) + self.assertEqual( + parsed_output[i]["metadata.source_urn"], self.results_urn + ) + self.assertEqual( + parsed_output[i]["metadata.hardware_info.bios_version"], + "Bios-Version-0", + ) + + self.assertEqual( + parsed_output[i]["urn"], + "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i), + ) self.assertEqual(parsed_output[i]["st_mode"], "-rw-r-----") self.assertEqual(parsed_output[i]["st_ino"], "1063090") self.assertEqual(parsed_output[i]["st_dev"], "64512") @@ -103,30 +113,30 @@ def testCSVPluginWithValuesOfMultipleTypes(self): zip_fd, prefix = self.ProcessValuesToZip({ rdf_client_fs.StatEntry: [ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path="/foo/bar", pathtype="OS")) + pathspec=rdf_paths.PathSpec(path="/foo/bar", pathtype="OS") + ) ], - rdf_client.Process: [rdf_client.Process(pid=42)] + rdf_client.Process: [rdf_client.Process(pid=42)], }) self.assertEqual( set(zip_fd.namelist()), set([ "%s/MANIFEST" % prefix, "%s/ExportedFile/from_StatEntry.csv" % prefix, - "%s/ExportedProcess/from_Process.csv" % prefix - ])) + "%s/ExportedProcess/from_Process.csv" % prefix, + ]), + ) parsed_manifest = yaml.safe_load(zip_fd.read("%s/MANIFEST" % prefix)) self.assertEqual( - parsed_manifest, { + parsed_manifest, + { "export_stats": { - "StatEntry": { - "ExportedFile": 1 - }, - "Process": { - "ExportedProcess": 1 - } + "StatEntry": {"ExportedFile": 1}, + "Process": {"ExportedProcess": 1}, } - }) + }, + ) with zip_fd.open("%s/ExportedFile/from_StatEntry.csv" % prefix) as filedesc: content = filedesc.read().decode("utf-8") @@ -135,15 +145,19 @@ def testCSVPluginWithValuesOfMultipleTypes(self): self.assertLen(parsed_output, 1) # Make sure metadata is filled in. 
- self.assertEqual(parsed_output[0]["metadata.client_urn"], - "aff4:/%s" % self.client_id) - self.assertEqual(parsed_output[0]["metadata.hostname"], - "Host-0.example.com") - self.assertEqual(parsed_output[0]["metadata.mac_address"], - "aabbccddee00\nbbccddeeff00") + self.assertEqual( + parsed_output[0]["metadata.client_urn"], "aff4:/%s" % self.client_id + ) + self.assertEqual( + parsed_output[0]["metadata.hostname"], "Host-0.example.com" + ) + self.assertEqual( + parsed_output[0]["metadata.mac_address"], "aabbccddee00\nbbccddeeff00" + ) self.assertEqual(parsed_output[0]["metadata.source_urn"], self.results_urn) - self.assertEqual(parsed_output[0]["urn"], - "aff4:/%s/fs/os/foo/bar" % self.client_id) + self.assertEqual( + parsed_output[0]["urn"], "aff4:/%s/fs/os/foo/bar" % self.client_id + ) filepath = "%s/ExportedProcess/from_Process.csv" % prefix with zip_fd.open(filepath) as filedesc: @@ -152,12 +166,15 @@ def testCSVPluginWithValuesOfMultipleTypes(self): parsed_output = list(csv.DictReader(io.StringIO(content))) self.assertLen(parsed_output, 1) - self.assertEqual(parsed_output[0]["metadata.client_urn"], - "aff4:/%s" % self.client_id) - self.assertEqual(parsed_output[0]["metadata.hostname"], - "Host-0.example.com") - self.assertEqual(parsed_output[0]["metadata.mac_address"], - "aabbccddee00\nbbccddeeff00") + self.assertEqual( + parsed_output[0]["metadata.client_urn"], "aff4:/%s" % self.client_id + ) + self.assertEqual( + parsed_output[0]["metadata.hostname"], "Host-0.example.com" + ) + self.assertEqual( + parsed_output[0]["metadata.mac_address"], "aabbccddee00\nbbccddeeff00" + ) self.assertEqual(parsed_output[0]["metadata.source_urn"], self.results_urn) self.assertEqual(parsed_output[0]["pid"], "42") @@ -166,15 +183,19 @@ def testCSVPluginWritesUnicodeValuesCorrectly(self): zip_fd, prefix = self.ProcessValuesToZip({ rdf_client_fs.StatEntry: [ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path="/中国新闻网新闻中", pathtype="OS")) + pathspec=rdf_paths.PathSpec( + path="/中国新闻网新闻中", pathtype="OS" + ) + ) ] }) self.assertEqual( set(zip_fd.namelist()), set([ "%s/MANIFEST" % prefix, - "%s/ExportedFile/from_StatEntry.csv" % prefix - ])) + "%s/ExportedFile/from_StatEntry.csv" % prefix, + ]), + ) data = zip_fd.open("%s/ExportedFile/from_StatEntry.csv" % prefix) data = io.TextIOWrapper(data, encoding="utf-8") @@ -220,11 +241,13 @@ def testCSVPluginWritesMoreThanOneBatchOfRowsCorrectly(self): for i in range(num_rows): responses.append( rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec( - path="/foo/bar/%d" % i, pathtype="OS"))) + pathspec=rdf_paths.PathSpec(path="/foo/bar/%d" % i, pathtype="OS") + ) + ) zip_fd, prefix = self.ProcessValuesToZip( - {rdf_client_fs.StatEntry: responses}) + {rdf_client_fs.StatEntry: responses} + ) with zip_fd.open("%s/ExportedFile/from_StatEntry.csv" % prefix) as filedesc: content = filedesc.read().decode("utf-8") @@ -232,8 +255,10 @@ def testCSVPluginWritesMoreThanOneBatchOfRowsCorrectly(self): parsed_output = list(csv.DictReader(io.StringIO(content))) self.assertLen(parsed_output, num_rows) for i in range(num_rows): - self.assertEqual(parsed_output[i]["urn"], - "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i)) + self.assertEqual( + parsed_output[i]["urn"], + "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i), + ) def main(argv): diff --git a/grr/server/grr_response_server/output_plugins/elasticsearch_plugin.py b/grr/server/grr_response_server/output_plugins/elasticsearch_plugin.py index a1c6b464e6..94466f0f60 100644 --- 
a/grr/server/grr_response_server/output_plugins/elasticsearch_plugin.py +++ b/grr/server/grr_response_server/output_plugins/elasticsearch_plugin.py @@ -7,11 +7,9 @@ The specification for the indexing of documents is https://www.elastic.co/guide/en/elasticsearch/reference/7.1/docs-index_.html """ + import json -from typing import Any -from typing import Dict -from typing import List -from typing import Text +from typing import Any, Dict, List from urllib import parse as urlparse import requests @@ -27,21 +25,22 @@ from grr_response_server.export_converters import base from grr_response_server.gui.api_plugins import flow as api_flow from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import mig_objects BULK_OPERATIONS_PATH = "_bulk" # TODO(user): Use the JSON type. -JsonDict = Dict[Text, Any] +JsonDict = Dict[str, Any] class ElasticsearchConfigurationError(Exception): """Error indicating a wrong or missing Elasticsearch configuration.""" - pass class ElasticsearchOutputPluginArgs(rdf_structs.RDFProtoStruct): """An RDF wrapper class for the arguments of ElasticsearchOutputPlugin.""" + protobuf = output_plugin_pb2.ElasticsearchOutputPluginArgs rdf_deps = [] @@ -66,7 +65,8 @@ def __init__(self, *args, **kwargs): raise ElasticsearchConfigurationError( "Cannot start ElasticsearchOutputPlugin, because Elasticsearch.url" "is not configured. Set it to the base URL of your Elasticsearch" - "installation, e.g. 'https://myelasticsearch.example.com:9200'.") + "installation, e.g. 'https://myelasticsearch.example.com:9200'." + ) self._verify_https = config.CONFIG["Elasticsearch.verify_https"] self._token = config.CONFIG["Elasticsearch.token"] @@ -90,32 +90,38 @@ def ProcessResponses( events = [self._MakeEvent(response, client, flow) for response in responses] self._SendEvents(events) - def _GetClientId(self, responses: List[rdf_flow_objects.FlowResult]) -> Text: + def _GetClientId(self, responses: List[rdf_flow_objects.FlowResult]) -> str: client_ids = {msg.client_id for msg in responses} if len(client_ids) > 1: - raise AssertionError(( - "ProcessResponses received messages from different Clients {}, which " - "violates OutputPlugin constraints.").format(client_ids)) + raise AssertionError( + ( + "ProcessResponses received messages from different Clients {}," + " which violates OutputPlugin constraints." + ).format(client_ids) + ) return client_ids.pop() - def _GetFlowId(self, responses: List[rdf_flow_objects.FlowResult]) -> Text: + def _GetFlowId(self, responses: List[rdf_flow_objects.FlowResult]) -> str: flow_ids = {msg.flow_id for msg in responses} if len(flow_ids) > 1: raise AssertionError( - ("ProcessResponses received messages from different Flows {}, which " - "violates OutputPlugin constraints.").format(flow_ids)) + ( + "ProcessResponses received messages from different Flows {}," + " which violates OutputPlugin constraints." + ).format(flow_ids) + ) return flow_ids.pop() - def _GetClientMetadata(self, client_id: Text) -> base.ExportedMetadata: + def _GetClientMetadata(self, client_id: str) -> base.ExportedMetadata: info = data_store.REL_DB.ReadClientFullInfo(client_id) info = mig_objects.ToRDFClientFullInfo(info) metadata = export.GetMetadata(client_id, info) metadata.timestamp = None # timestamp is sent outside of metadata. 
return metadata - def _GetFlowMetadata(self, client_id: Text, - flow_id: Text) -> api_flow.ApiFlow: + def _GetFlowMetadata(self, client_id: str, flow_id: str) -> api_flow.ApiFlow: flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) return api_flow.ApiFlow().InitFromFlowObject(flow_obj) def _MakeEvent( @@ -138,14 +144,13 @@ def _MakeEvent( return event def _SendEvents(self, events: List[JsonDict]) -> None: - """Uses the Elasticsearch bulk API to index all events in a single request. - """ + """Uses the Elasticsearch bulk API to index all events in a single request.""" # https://www.elastic.co/guide/en/elasticsearch/reference/7.1/docs-bulk.html if self._token: headers = { "Authorization": "Basic {}".format(self._token), - "Content-Type": "application/json" + "Content-Type": "application/json", } else: headers = {"Content-Type": "application/json"} @@ -155,15 +160,14 @@ def _SendEvents(self, events: List[JsonDict]) -> None: # Each index operation is two lines, the first defining the index settings, # the second is the actual document to be indexed data = ( - "\n".join( - [ - "{}\n{}".format(index_command, json.dumps(event, indent=None)) - for event in events - ] - ) + "\n".join([ + "{}\n{}".format(index_command, json.dumps(event, indent=None)) + for event in events + ]) + "\n" ) response = requests.post( - url=self._url, verify=self._verify_https, data=data, headers=headers) + url=self._url, verify=self._verify_https, data=data, headers=headers + ) response.raise_for_status() diff --git a/grr/server/grr_response_server/output_plugins/elasticsearch_plugin_test.py b/grr/server/grr_response_server/output_plugins/elasticsearch_plugin_test.py index ec2fe858ed..764798ff85 100644 --- a/grr/server/grr_response_server/output_plugins/elasticsearch_plugin_test.py +++ b/grr/server/grr_response_server/output_plugins/elasticsearch_plugin_test.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Tests for Elasticsearch output plugin.""" + import json from unittest import mock @@ -13,6 +14,7 @@ from grr_response_server import data_store from grr_response_server.output_plugins import elasticsearch_plugin from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr.test_lib import flow_test_lib from grr.test_lib import test_lib @@ -29,16 +31,21 @@ def setUp(self): self.client_id = self.SetupClient(0) self.flow_id = '12345678' data_store.REL_DB.WriteFlowObject( - rdf_flow_objects.Flow( - flow_id=self.flow_id, - client_id=self.client_id, - flow_class_name='ClientFileFinder', - create_time=rdfvalue.RDFDatetime.Now(), - )) + mig_flow_objects.ToProtoFlow( + rdf_flow_objects.Flow( + flow_id=self.flow_id, + client_id=self.client_id, + flow_class_name='ClientFileFinder', + ) + ) + ) def _CallPlugin(self, plugin_args=None, responses=None, patcher=None): - source_id = rdf_client.ClientURN( - self.client_id).Add('Results').RelativeName('aff4:/') + source_id = ( + rdf_client.ClientURN(self.client_id) + .Add('Results') + .RelativeName('aff4:/') + ) messages = [] for response in responses: @@ -50,7 +57,8 @@ def _CallPlugin(self, plugin_args=None, responses=None, patcher=None): plugin_cls = elasticsearch_plugin.ElasticsearchOutputPlugin plugin, plugin_state = plugin_cls.CreatePluginAndDefaultState( - source_urn=source_id, args=plugin_args) + source_urn=source_id, args=plugin_args + ) if patcher is None: patcher = mock.patch.object(requests, 'post') @@ -87,27 +95,34 @@ def 
testPopulatesEventCorrectly(self): with test_lib.FakeTime(rdfvalue.RDFDatetime.FromSecondsSinceEpoch(15)): mock_post = self._CallPlugin( plugin_args=elasticsearch_plugin.ElasticsearchOutputPluginArgs( - index='idx', tags=['a', 'b', 'c']), + index='idx', tags=['a', 'b', 'c'] + ), responses=[ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path='/中国', pathtype='OS')) - ]) + pathspec=rdf_paths.PathSpec(path='/中国', pathtype='OS') + ) + ], + ) bulk_pairs = self._ParseEvents(mock_post) self.assertLen(bulk_pairs, 1) event_pair = bulk_pairs[0] self.assertEqual(event_pair[0]['index']['_index'], 'idx') - self.assertEqual(event_pair[1]['client']['clientUrn'], - 'aff4:/C.1000000000000000') + self.assertEqual( + event_pair[1]['client']['clientUrn'], 'aff4:/C.1000000000000000' + ) self.assertEqual(event_pair[1]['flow']['flowId'], '12345678') self.assertEqual(event_pair[1]['tags'], ['a', 'b', 'c']) self.assertEqual(event_pair[1]['resultType'], 'StatEntry') - self.assertEqual(event_pair[1]['result'], { - 'pathspec': { - 'pathtype': 'OS', - 'path': '/中国', + self.assertEqual( + event_pair[1]['result'], + { + 'pathspec': { + 'pathtype': 'OS', + 'path': '/中国', + }, }, - }) + ) def testPopulatesBatchCorrectly(self): with test_lib.ConfigOverrider({ @@ -118,48 +133,61 @@ def testPopulatesBatchCorrectly(self): plugin_args=elasticsearch_plugin.ElasticsearchOutputPluginArgs(), responses=[ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path='/中国', pathtype='OS')), + pathspec=rdf_paths.PathSpec(path='/中国', pathtype='OS') + ), rdf_client.Process(pid=42), - ]) + ], + ) bulk_pairs = self._ParseEvents(mock_post) self.assertLen(bulk_pairs, 2) for event_pair in bulk_pairs: - self.assertEqual(event_pair[1]['client']['clientUrn'], - 'aff4:/C.1000000000000000') + self.assertEqual( + event_pair[1]['client']['clientUrn'], 'aff4:/C.1000000000000000' + ) self.assertEqual(bulk_pairs[0][1]['resultType'], 'StatEntry') - self.assertEqual(bulk_pairs[0][1]['result'], { - 'pathspec': { - 'pathtype': 'OS', - 'path': '/中国', + self.assertEqual( + bulk_pairs[0][1]['result'], + { + 'pathspec': { + 'pathtype': 'OS', + 'path': '/中国', + }, }, - }) + ) self.assertEqual(bulk_pairs[1][1]['resultType'], 'Process') - self.assertEqual(bulk_pairs[1][1]['result'], { - 'pid': 42, - }) + self.assertEqual( + bulk_pairs[1][1]['result'], + { + 'pid': 42, + }, + ) def testReadsConfigurationValuesCorrectly(self): with test_lib.ConfigOverrider({ 'Elasticsearch.url': 'http://a', 'Elasticsearch.token': 'b', 'Elasticsearch.verify_https': False, - 'Elasticsearch.index': 'e' + 'Elasticsearch.index': 'e', }): mock_post = self._CallPlugin( plugin_args=elasticsearch_plugin.ElasticsearchOutputPluginArgs(), - responses=[rdf_client.Process(pid=42)]) + responses=[rdf_client.Process(pid=42)], + ) self.assertEqual(mock_post.call_args[KWARGS]['url'], 'http://a/_bulk') self.assertFalse(mock_post.call_args[KWARGS]['verify']) - self.assertEqual(mock_post.call_args[KWARGS]['headers']['Authorization'], - 'Basic b') + self.assertEqual( + mock_post.call_args[KWARGS]['headers']['Authorization'], 'Basic b' + ) - self.assertIn(mock_post.call_args[KWARGS]['headers']['Content-Type'], - ('application/json', 'application/x-ndjson')) + self.assertIn( + mock_post.call_args[KWARGS]['headers']['Content-Type'], + ('application/json', 'application/x-ndjson'), + ) bulk_pairs = self._ParseEvents(mock_post) self.assertEqual(bulk_pairs[0][0]['index']['_index'], 'e') @@ -168,21 +196,25 @@ def testFailsWhenUrlIsNotConfigured(self): with test_lib.ConfigOverrider({'Elasticsearch.token': 
'b'}): with self.assertRaisesRegex( elasticsearch_plugin.ElasticsearchConfigurationError, - 'Elasticsearch.url'): + 'Elasticsearch.url', + ): self._CallPlugin( plugin_args=elasticsearch_plugin.ElasticsearchOutputPluginArgs(), - responses=[rdf_client.Process(pid=42)]) + responses=[rdf_client.Process(pid=42)], + ) def testArgsOverrideConfiguration(self): with test_lib.ConfigOverrider({ 'Elasticsearch.url': 'http://a', 'Elasticsearch.token': 'b', - 'Elasticsearch.index': 'e' + 'Elasticsearch.index': 'e', }): mock_post = self._CallPlugin( plugin_args=elasticsearch_plugin.ElasticsearchOutputPluginArgs( - index='f'), - responses=[rdf_client.Process(pid=42)]) + index='f' + ), + responses=[rdf_client.Process(pid=42)], + ) bulk_pairs = self._ParseEvents(mock_post) self.assertEqual(bulk_pairs[0][0]['index']['_index'], 'f') @@ -190,7 +222,8 @@ def testArgsOverrideConfiguration(self): def testRaisesForHttpError(self): post = mock.MagicMock() post.return_value.raise_for_status.side_effect = ( - requests.exceptions.HTTPError()) + requests.exceptions.HTTPError() + ) with test_lib.ConfigOverrider({ 'Elasticsearch.url': 'http://a', @@ -200,7 +233,8 @@ def testRaisesForHttpError(self): self._CallPlugin( plugin_args=elasticsearch_plugin.ElasticsearchOutputPluginArgs(), responses=[rdf_client.Process(pid=42)], - patcher=mock.patch.object(requests, 'post', post)) + patcher=mock.patch.object(requests, 'post', post), + ) def testPostDataTerminatingNewline(self): with test_lib.ConfigOverrider({ @@ -209,7 +243,8 @@ def testPostDataTerminatingNewline(self): }): mock_post = self._CallPlugin( plugin_args=elasticsearch_plugin.ElasticsearchOutputPluginArgs(), - responses=[rdf_client.Process(pid=42)]) + responses=[rdf_client.Process(pid=42)], + ) self.assertEndsWith(mock_post.call_args[KWARGS]['data'], '\n') diff --git a/grr/server/grr_response_server/output_plugins/email_plugin.py b/grr/server/grr_response_server/output_plugins/email_plugin.py index 9a85cb4eb1..45ec240773 100644 --- a/grr/server/grr_response_server/output_plugins/email_plugin.py +++ b/grr/server/grr_response_server/output_plugins/email_plugin.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Email live output plugin.""" - import jinja2 from grr_response_core import config @@ -30,7 +29,8 @@ class EmailOutputPlugin(output_plugin.OutputPlugin): produces_output_streams = False subject_template = jinja2.Template( - "GRR got a new result in {{ source_urn }}.", autoescape=True) + "GRR got a new result in {{ source_urn }}.", autoescape=True + ) template = jinja2.Template( """

GRR got a new result in {{ source_urn }}.

@@ -49,8 +49,10 @@ class EmailOutputPlugin(output_plugin.OutputPlugin): autoescape=True, ) - too_many_mails_msg = ("

This hunt has now produced %d results so the " - "sending of emails will be disabled now.

") + too_many_mails_msg = ( + "

This hunt has now produced %d results so the " + "sending of emails will be disabled now.

" + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -76,12 +78,13 @@ def ProcessResponse(self, state, response): client_fragment_id = "/clients/%s" % client_id if emails_left == 0: - additional_message = (self.too_many_mails_msg % self.args.emails_limit) + additional_message = self.too_many_mails_msg % self.args.emails_limit else: additional_message = "" subject = self.__class__.subject_template.render( - source_urn=str(self.source_urn)) + source_urn=str(self.source_urn) + ) body = self.__class__.template.render( client_id=client_id, client_fragment_id=client_fragment_id, @@ -89,10 +92,12 @@ def ProcessResponse(self, state, response): source_urn=self.source_urn, additional_message=additional_message, signature=config.CONFIG["Email.signature"], - hostname=hostname) + hostname=hostname, + ) email_alerts.EMAIL_ALERTER.SendEmail( - self.args.email_address, "grr-noreply", subject, body, is_html=True) + self.args.email_address, "grr-noreply", subject, body, is_html=True + ) def ProcessResponses(self, state, responses): for response in responses: diff --git a/grr/server/grr_response_server/output_plugins/email_plugin_test.py b/grr/server/grr_response_server/output_plugins/email_plugin_test.py index 7abe47a3b3..695bb5b067 100644 --- a/grr/server/grr_response_server/output_plugins/email_plugin_test.py +++ b/grr/server/grr_response_server/output_plugins/email_plugin_test.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Tests for email output plugin.""" + from unittest import mock from absl import app @@ -25,13 +26,13 @@ def setUp(self): self.email_messages = [] self.email_address = "notify@%s" % config.CONFIG["Logging.domain"] - def ProcessResponses(self, - plugin_args=None, - responses=None, - process_responses_separately=False): + def ProcessResponses( + self, plugin_args=None, responses=None, process_responses_separately=False + ): plugin_cls = email_plugin.EmailOutputPlugin plugin, plugin_state = plugin_cls.CreatePluginAndDefaultState( - source_urn=self.results_urn, args=plugin_args) + source_urn=self.results_urn, args=plugin_args + ) messages = [] for response in responses: @@ -43,7 +44,8 @@ def ProcessResponses(self, def SendEmail(address, sender, title, message, **_): self.email_messages.append( - dict(address=address, sender=sender, title=title, message=message)) + dict(address=address, sender=sender, title=title, message=message) + ) with mock.patch.object(email_alerts.EMAIL_ALERTER, "SendEmail", SendEmail): if process_responses_separately: @@ -58,8 +60,10 @@ def SendEmail(address, sender, title, message, **_): def testEmailPluginSendsEmailPerEveyBatchOfResponses(self): self.ProcessResponses( plugin_args=email_plugin.EmailOutputPluginArgs( - email_address=self.email_address), - responses=[rdf_client.Process(pid=42)]) + email_address=self.email_address + ), + responses=[rdf_client.Process(pid=42)], + ) self.assertLen(self.email_messages, 1) @@ -73,9 +77,11 @@ def testEmailPluginStopsSendingEmailsAfterLimitIsReached(self): responses = [rdf_client.Process(pid=i) for i in range(11)] self.ProcessResponses( plugin_args=email_plugin.EmailOutputPluginArgs( - email_address=self.email_address, emails_limit=10), + email_address=self.email_address, emails_limit=10 + ), responses=responses, - process_responses_separately=True) + process_responses_separately=True, + ) self.assertLen(self.email_messages, 10) @@ -88,8 +94,10 @@ def testEmailPluginStopsSendingEmailsAfterLimitIsReached(self): for msg in self.email_messages[:10]: self.assertNotIn("sending of emails will be disabled 
now", msg) - self.assertIn("sending of emails will be disabled now", - self.email_messages[9]["message"]) + self.assertIn( + "sending of emails will be disabled now", + self.email_messages[9]["message"], + ) def main(argv): diff --git a/grr/server/grr_response_server/output_plugins/mig_bigquery_plugin.py b/grr/server/grr_response_server/output_plugins/mig_bigquery_plugin.py index 3073bd31af..cac1812b59 100644 --- a/grr/server/grr_response_server/output_plugins/mig_bigquery_plugin.py +++ b/grr/server/grr_response_server/output_plugins/mig_bigquery_plugin.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import output_plugin_pb2 from grr_response_server.output_plugins import bigquery_plugin diff --git a/grr/server/grr_response_server/output_plugins/mig_elasticsearch_plugin.py b/grr/server/grr_response_server/output_plugins/mig_elasticsearch_plugin.py index 27e23e23a0..a7aaba99a4 100644 --- a/grr/server/grr_response_server/output_plugins/mig_elasticsearch_plugin.py +++ b/grr/server/grr_response_server/output_plugins/mig_elasticsearch_plugin.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import output_plugin_pb2 from grr_response_server.output_plugins import elasticsearch_plugin diff --git a/grr/server/grr_response_server/output_plugins/mig_email_plugin.py b/grr/server/grr_response_server/output_plugins/mig_email_plugin.py index 62e08b97d0..00875bb5f8 100644 --- a/grr/server/grr_response_server/output_plugins/mig_email_plugin.py +++ b/grr/server/grr_response_server/output_plugins/mig_email_plugin.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import output_plugin_pb2 from grr_response_server.output_plugins import email_plugin diff --git a/grr/server/grr_response_server/output_plugins/mig_splunk_plugin.py b/grr/server/grr_response_server/output_plugins/mig_splunk_plugin.py index 4b37804797..ce361b23ce 100644 --- a/grr/server/grr_response_server/output_plugins/mig_splunk_plugin.py +++ b/grr/server/grr_response_server/output_plugins/mig_splunk_plugin.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Provides conversion functions to be used during RDFProtoStruct migration.""" + from grr_response_proto import output_plugin_pb2 from grr_response_server.output_plugins import splunk_plugin diff --git a/grr/server/grr_response_server/output_plugins/splunk_plugin.py b/grr/server/grr_response_server/output_plugins/splunk_plugin.py index 07bd54546a..6db5d4e5a8 100644 --- a/grr/server/grr_response_server/output_plugins/splunk_plugin.py +++ b/grr/server/grr_response_server/output_plugins/splunk_plugin.py @@ -7,11 +7,9 @@ The spec for HTTP Event Collector is taken from https://docs.splunk.com /Documentation/Splunk/8.0.1/Data/FormateventsforHTTPEventCollector """ + import json -from typing import Any -from typing import Dict -from typing import List -from typing import Text +from typing import Any, Dict, List from urllib import parse as urlparse import requests @@ -28,20 +26,21 @@ from grr_response_server.export_converters import base from grr_response_server.gui.api_plugins import flow as api_flow from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr_response_server.rdfvalues import mig_objects HTTP_EVENT_COLLECTOR_PATH = 
"services/collector/event" -JsonDict = Dict[Text, Any] +JsonDict = Dict[str, Any] class SplunkConfigurationError(Exception): """Error indicating a wrong or missing Splunk configuration.""" - pass class SplunkOutputPluginArgs(rdf_structs.RDFProtoStruct): """An RDF wrapper class for the arguments of SplunkOutputPlugin.""" + protobuf = output_plugin_pb2.SplunkOutputPluginArgs rdf_deps = [] @@ -73,14 +72,16 @@ def __init__(self, *args, **kwargs): raise SplunkConfigurationError( "Cannot start SplunkOutputPlugin, because Splunk.url is not " "configured. Set it to the base URL of your Splunk installation, " - "e.g. 'https://mysplunkserver.example.com:8088'.") + "e.g. 'https://mysplunkserver.example.com:8088'." + ) if not self._token: raise SplunkConfigurationError( "Cannot start SplunkOutputPlugin, because Splunk.token " "is not configured. You can get this authentication " "token when configuring a new HEC input in your Splunk " - "installation.") + "installation." + ) self._url = urlparse.urljoin(url, HTTP_EVENT_COLLECTOR_PATH) @@ -99,32 +100,38 @@ def ProcessResponses( events = [self._MakeEvent(response, client, flow) for response in responses] self._SendEvents(events) - def _GetClientId(self, responses: List[rdf_flow_objects.FlowResult]) -> Text: + def _GetClientId(self, responses: List[rdf_flow_objects.FlowResult]) -> str: client_ids = {msg.client_id for msg in responses} if len(client_ids) > 1: - raise AssertionError(( - "ProcessResponses received messages from different Clients {}, which " - "violates OutputPlugin constraints.").format(client_ids)) + raise AssertionError( + ( + "ProcessResponses received messages from different Clients {}," + " which violates OutputPlugin constraints." + ).format(client_ids) + ) return client_ids.pop() - def _GetFlowId(self, responses: List[rdf_flow_objects.FlowResult]) -> Text: + def _GetFlowId(self, responses: List[rdf_flow_objects.FlowResult]) -> str: flow_ids = {msg.flow_id for msg in responses} if len(flow_ids) > 1: raise AssertionError( - ("ProcessResponses received messages from different Flows {}, which " - "violates OutputPlugin constraints.").format(flow_ids)) + ( + "ProcessResponses received messages from different Flows {}," + " which violates OutputPlugin constraints." + ).format(flow_ids) + ) return flow_ids.pop() - def _GetClientMetadata(self, client_id: Text) -> base.ExportedMetadata: + def _GetClientMetadata(self, client_id: str) -> base.ExportedMetadata: info = data_store.REL_DB.ReadClientFullInfo(client_id) info = mig_objects.ToRDFClientFullInfo(info) metadata = export.GetMetadata(client_id, info) metadata.timestamp = None # timestamp is sent outside of metadata. 
return metadata - def _GetFlowMetadata(self, client_id: Text, - flow_id: Text) -> api_flow.ApiFlow: + def _GetFlowMetadata(self, client_id: str, flow_id: str) -> api_flow.ApiFlow: flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + flow_obj = mig_flow_objects.ToRDFFlow(flow_obj) return api_flow.ApiFlow().InitFromFlowObject(flow_obj) def _MakeEvent( @@ -167,5 +174,6 @@ def _SendEvents(self, events: List[JsonDict]) -> None: data = "\n\n".join(json.dumps(event) for event in events) response = requests.post( - url=self._url, verify=self._verify_https, data=data, headers=headers) + url=self._url, verify=self._verify_https, data=data, headers=headers + ) response.raise_for_status() diff --git a/grr/server/grr_response_server/output_plugins/splunk_plugin_test.py b/grr/server/grr_response_server/output_plugins/splunk_plugin_test.py index 1c7d559840..309c0522c8 100644 --- a/grr/server/grr_response_server/output_plugins/splunk_plugin_test.py +++ b/grr/server/grr_response_server/output_plugins/splunk_plugin_test.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Tests for Splunk output plugin.""" + import json from unittest import mock @@ -13,6 +14,7 @@ from grr_response_server import data_store from grr_response_server.output_plugins import splunk_plugin from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects +from grr_response_server.rdfvalues import mig_flow_objects from grr.test_lib import flow_test_lib from grr.test_lib import test_lib @@ -28,16 +30,21 @@ def setUp(self): self.client_id = self.SetupClient(0) self.flow_id = '12345678' data_store.REL_DB.WriteFlowObject( - rdf_flow_objects.Flow( - flow_id=self.flow_id, - client_id=self.client_id, - flow_class_name='ClientFileFinder', - create_time=rdfvalue.RDFDatetime.Now(), - )) + mig_flow_objects.ToProtoFlow( + rdf_flow_objects.Flow( + flow_id=self.flow_id, + client_id=self.client_id, + flow_class_name='ClientFileFinder', + ) + ) + ) def _CallPlugin(self, plugin_args=None, responses=None, patcher=None): - source_id = rdf_client.ClientURN( - self.client_id).Add('Results').RelativeName('aff4:/') + source_id = ( + rdf_client.ClientURN(self.client_id) + .Add('Results') + .RelativeName('aff4:/') + ) messages = [] for response in responses: @@ -49,7 +56,8 @@ def _CallPlugin(self, plugin_args=None, responses=None, patcher=None): plugin_cls = splunk_plugin.SplunkOutputPlugin plugin, plugin_state = plugin_cls.CreatePluginAndDefaultState( - source_urn=source_id, args=plugin_args) + source_urn=source_id, args=plugin_args + ) if patcher is None: patcher = mock.patch.object(requests, 'post') @@ -73,11 +81,14 @@ def testPopulatesEventCorrectly(self): with test_lib.FakeTime(rdfvalue.RDFDatetime.FromSecondsSinceEpoch(15)): mock_post = self._CallPlugin( plugin_args=splunk_plugin.SplunkOutputPluginArgs( - index='idx', annotations=['a', 'b', 'c']), + index='idx', annotations=['a', 'b', 'c'] + ), responses=[ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path='/中国', pathtype='OS')) - ]) + pathspec=rdf_paths.PathSpec(path='/中国', pathtype='OS') + ) + ], + ) events = self._ParseEvents(mock_post) self.assertLen(events, 1) @@ -85,17 +96,21 @@ def testPopulatesEventCorrectly(self): self.assertEqual(events[0]['sourcetype'], 'grr_flow_result') self.assertEqual(events[0]['source'], 'grr') self.assertEqual(events[0]['time'], 15) - self.assertEqual(events[0]['event']['client']['clientUrn'], - 'aff4:/C.1000000000000000') + self.assertEqual( + events[0]['event']['client']['clientUrn'], 'aff4:/C.1000000000000000' + ) 
self.assertEqual(events[0]['event']['annotations'], ['a', 'b', 'c']) self.assertEqual(events[0]['event']['flow']['flowId'], '12345678') self.assertEqual(events[0]['event']['resultType'], 'StatEntry') - self.assertEqual(events[0]['event']['result'], { - 'pathspec': { - 'pathtype': 'OS', - 'path': '/中国', + self.assertEqual( + events[0]['event']['result'], + { + 'pathspec': { + 'pathtype': 'OS', + 'path': '/中国', + }, }, - }) + ) def testPopulatesBatchCorrectly(self): with test_lib.ConfigOverrider({ @@ -106,9 +121,11 @@ def testPopulatesBatchCorrectly(self): plugin_args=splunk_plugin.SplunkOutputPluginArgs(), responses=[ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path='/中国', pathtype='OS')), + pathspec=rdf_paths.PathSpec(path='/中国', pathtype='OS') + ), rdf_client.Process(pid=42), - ]) + ], + ) events = self._ParseEvents(mock_post) self.assertLen(events, 2) @@ -116,21 +133,28 @@ def testPopulatesBatchCorrectly(self): self.assertEqual(event['sourcetype'], 'grr_flow_result') self.assertEqual(event['source'], 'grr') self.assertEqual(event['host'], 'Host-0.example.com') - self.assertEqual(event['event']['client']['clientUrn'], - 'aff4:/C.1000000000000000') + self.assertEqual( + event['event']['client']['clientUrn'], 'aff4:/C.1000000000000000' + ) self.assertEqual(events[0]['event']['resultType'], 'StatEntry') - self.assertEqual(events[0]['event']['result'], { - 'pathspec': { - 'pathtype': 'OS', - 'path': '/中国', + self.assertEqual( + events[0]['event']['result'], + { + 'pathspec': { + 'pathtype': 'OS', + 'path': '/中国', + }, }, - }) + ) self.assertEqual(events[1]['event']['resultType'], 'Process') - self.assertEqual(events[1]['event']['result'], { - 'pid': 42, - }) + self.assertEqual( + events[1]['event']['result'], + { + 'pid': 42, + }, + ) def testReadsConfigurationValuesCorrectly(self): with test_lib.ConfigOverrider({ @@ -139,17 +163,20 @@ def testReadsConfigurationValuesCorrectly(self): 'Splunk.verify_https': False, 'Splunk.source': 'c', 'Splunk.sourcetype': 'd', - 'Splunk.index': 'e' + 'Splunk.index': 'e', }): mock_post = self._CallPlugin( plugin_args=splunk_plugin.SplunkOutputPluginArgs(), - responses=[rdf_client.Process(pid=42)]) + responses=[rdf_client.Process(pid=42)], + ) - self.assertEqual(mock_post.call_args[KWARGS]['url'], - 'http://a/services/collector/event') + self.assertEqual( + mock_post.call_args[KWARGS]['url'], 'http://a/services/collector/event' + ) self.assertFalse(mock_post.call_args[KWARGS]['verify']) - self.assertEqual(mock_post.call_args[KWARGS]['headers']['Authorization'], - 'Splunk b') + self.assertEqual( + mock_post.call_args[KWARGS]['headers']['Authorization'], 'Splunk b' + ) events = self._ParseEvents(mock_post) self.assertEqual(events[0]['source'], 'c') @@ -158,29 +185,32 @@ def testReadsConfigurationValuesCorrectly(self): def testFailsWhenUrlIsNotConfigured(self): with test_lib.ConfigOverrider({'Splunk.token': 'b'}): - with self.assertRaisesRegex(splunk_plugin.SplunkConfigurationError, - 'Splunk.url'): + with self.assertRaisesRegex( + splunk_plugin.SplunkConfigurationError, 'Splunk.url' + ): self._CallPlugin( plugin_args=splunk_plugin.SplunkOutputPluginArgs(), - responses=[rdf_client.Process(pid=42)]) + responses=[rdf_client.Process(pid=42)], + ) def testFailsWhenTokenIsNotConfigured(self): with test_lib.ConfigOverrider({'Splunk.url': 'a'}): - with self.assertRaisesRegex(splunk_plugin.SplunkConfigurationError, - 'Splunk.token'): + with self.assertRaisesRegex( + splunk_plugin.SplunkConfigurationError, 'Splunk.token' + ): self._CallPlugin( 
plugin_args=splunk_plugin.SplunkOutputPluginArgs(), - responses=[rdf_client.Process(pid=42)]) + responses=[rdf_client.Process(pid=42)], + ) def testArgsOverrideConfiguration(self): - with test_lib.ConfigOverrider({ - 'Splunk.url': 'http://a', - 'Splunk.token': 'b', - 'Splunk.index': 'e' - }): + with test_lib.ConfigOverrider( + {'Splunk.url': 'http://a', 'Splunk.token': 'b', 'Splunk.index': 'e'} + ): mock_post = self._CallPlugin( plugin_args=splunk_plugin.SplunkOutputPluginArgs(index='f'), - responses=[rdf_client.Process(pid=42)]) + responses=[rdf_client.Process(pid=42)], + ) events = self._ParseEvents(mock_post) self.assertEqual(events[0]['index'], 'f') @@ -188,7 +218,8 @@ def testArgsOverrideConfiguration(self): def testRaisesForHttpError(self): post = mock.MagicMock() post.return_value.raise_for_status.side_effect = ( - requests.exceptions.HTTPError()) + requests.exceptions.HTTPError() + ) with test_lib.ConfigOverrider({ 'Splunk.url': 'http://a', @@ -198,7 +229,8 @@ def testRaisesForHttpError(self): self._CallPlugin( plugin_args=splunk_plugin.SplunkOutputPluginArgs(), responses=[rdf_client.Process(pid=42)], - patcher=mock.patch.object(requests, 'post', post)) + patcher=mock.patch.object(requests, 'post', post), + ) def main(argv): diff --git a/grr/server/grr_response_server/output_plugins/sqlite_plugin.py b/grr/server/grr_response_server/output_plugins/sqlite_plugin.py index 6f69eb140f..92fd2a03bf 100644 --- a/grr/server/grr_response_server/output_plugins/sqlite_plugin.py +++ b/grr/server/grr_response_server/output_plugins/sqlite_plugin.py @@ -1,10 +1,11 @@ #!/usr/bin/env python """Plugin that exports results as SQLite db scripts.""" + import io import os +import sqlite3 import zipfile -import sqlite3 import yaml from grr_response_core.lib import rdfvalue @@ -14,10 +15,10 @@ from grr_response_server import instant_output_plugin -class Rdf2SqliteAdapter(object): +class Rdf2SqliteAdapter: """An adapter for converting RDF values to a SQLite-friendly form.""" - class Converter(object): + class Converter: def __init__(self, sqlite_type, convert_fn): self.sqlite_type = sqlite_type @@ -33,20 +34,17 @@ def __init__(self, sqlite_type, convert_fn): # Converters for fields that have a semantic type annotation in their # protobuf definition. SEMANTIC_CONVERTERS = { - rdfvalue.RDFString: - STR_CONVERTER, - rdfvalue.RDFBytes: - BYTES_CONVERTER, - rdfvalue.RDFInteger: - INT_CONVERTER, - bool: - INT_CONVERTER, # Sqlite does not have a bool type. - rdfvalue.RDFDatetime: - Converter("INTEGER", lambda x: x.AsMicrosecondsSinceEpoch()), - rdfvalue.RDFDatetimeSeconds: - Converter("INTEGER", lambda x: x.AsSecondsSinceEpoch() * 1000000), - rdfvalue.DurationSeconds: - Converter("INTEGER", lambda x: x.microseconds), + rdfvalue.RDFString: STR_CONVERTER, + rdfvalue.RDFBytes: BYTES_CONVERTER, + rdfvalue.RDFInteger: INT_CONVERTER, + bool: INT_CONVERTER, # Sqlite does not have a bool type. 
+ rdfvalue.RDFDatetime: Converter( + "INTEGER", lambda x: x.AsMicrosecondsSinceEpoch() + ), + rdfvalue.RDFDatetimeSeconds: Converter( + "INTEGER", lambda x: x.AsSecondsSinceEpoch() * 1000000 + ), + rdfvalue.DurationSeconds: Converter("INTEGER", lambda x: x.microseconds), } # Converters for fields that do not have a semantic type annotation in their @@ -68,14 +66,17 @@ def __init__(self, sqlite_type, convert_fn): def GetConverter(type_info): if type_info.__class__ is rdf_structs.ProtoRDFValue: return Rdf2SqliteAdapter.SEMANTIC_CONVERTERS.get( - type_info.type, Rdf2SqliteAdapter.DEFAULT_CONVERTER) + type_info.type, Rdf2SqliteAdapter.DEFAULT_CONVERTER + ) else: return Rdf2SqliteAdapter.NON_SEMANTIC_CONVERTERS.get( - type_info.__class__, Rdf2SqliteAdapter.DEFAULT_CONVERTER) + type_info.__class__, Rdf2SqliteAdapter.DEFAULT_CONVERTER + ) class SqliteInstantOutputPlugin( - instant_output_plugin.InstantOutputPluginWithExportConversion): + instant_output_plugin.InstantOutputPluginWithExportConversion +): """Instant output plugin that converts results into SQLite db commands.""" plugin_name = "sqlite-zip" @@ -97,12 +98,14 @@ def path_prefix(self): def Start(self): self.archive_generator = utils.StreamingZipGenerator( - compression=zipfile.ZIP_DEFLATED) + compression=zipfile.ZIP_DEFLATED + ) self.export_counts = {} return [] - def ProcessSingleTypeExportedValues(self, original_value_type, - exported_values): + def ProcessSingleTypeExportedValues( + self, original_value_type, exported_values + ): first_value = next(exported_values, None) if not first_value: return @@ -110,10 +113,17 @@ def ProcessSingleTypeExportedValues(self, original_value_type, if not isinstance(first_value, rdf_structs.RDFProtoStruct): raise ValueError("The SQLite plugin only supports export-protos") yield self.archive_generator.WriteFileHeader( - "%s/%s_from_%s.sql" % (self.path_prefix, first_value.__class__.__name__, - original_value_type.__name__)) - table_name = "%s.from_%s" % (first_value.__class__.__name__, - original_value_type.__name__) + "%s/%s_from_%s.sql" + % ( + self.path_prefix, + first_value.__class__.__name__, + original_value_type.__name__, + ) + ) + table_name = "%s.from_%s" % ( + first_value.__class__.__name__, + original_value_type.__name__, + ) schema = self._GetSqliteSchema(first_value.__class__) # We will buffer the sql statements into an in-memory sql database before @@ -123,14 +133,15 @@ def ProcessSingleTypeExportedValues(self, original_value_type, db_cursor = db_connection.cursor() yield self.archive_generator.WriteFileChunk( - "BEGIN TRANSACTION;\n".encode("utf-8")) + "BEGIN TRANSACTION;\n".encode("utf-8") + ) with db_connection: buf = io.StringIO() - buf.write(u"CREATE TABLE \"%s\" (\n " % table_name) + buf.write('CREATE TABLE "%s" (\n ' % table_name) column_types = [(k, v.sqlite_type) for k, v in schema.items()] - buf.write(u",\n ".join([u"\"%s\" %s" % (k, v) for k, v in column_types])) - buf.write(u"\n);") + buf.write(",\n ".join(['"%s" %s' % (k, v) for k, v in column_types])) + buf.write("\n);") db_cursor.execute(buf.getvalue()) chunk = (buf.getvalue() + "\n").encode("utf-8") @@ -154,7 +165,8 @@ def ProcessSingleTypeExportedValues(self, original_value_type, yield self.archive_generator.WriteFileFooter() counts_for_original_type = self.export_counts.setdefault( - original_value_type.__name__, dict()) + original_value_type.__name__, dict() + ) counts_for_original_type[first_value.__class__.__name__] = counter def _GetSqliteSchema(self, proto_struct_class, prefix=""): @@ -164,7 +176,9 @@ def 
_GetSqliteSchema(self, proto_struct_class, prefix=""): if type_info.__class__ is rdf_structs.ProtoEmbedded: schema.update( self._GetSqliteSchema( - type_info.type, prefix="%s%s." % (prefix, type_info.name))) + type_info.type, prefix="%s%s." % (prefix, type_info.name) + ) + ) else: field_name = prefix + type_info.name schema[field_name] = Rdf2SqliteAdapter.GetConverter(type_info) @@ -173,10 +187,10 @@ def _GetSqliteSchema(self, proto_struct_class, prefix=""): def _InsertValueIntoDb(self, table_name, schema, value, db_cursor): sql_dict = self._ConvertToCanonicalSqlDict(schema, value.ToPrimitiveDict()) buf = io.StringIO() - buf.write(u"INSERT INTO \"%s\" (\n " % table_name) - buf.write(u",\n ".join(["\"%s\"" % k for k in sql_dict.keys()])) - buf.write(u"\n)") - buf.write(u"VALUES (%s);" % u",".join([u"?"] * len(sql_dict))) + buf.write('INSERT INTO "%s" (\n ' % table_name) + buf.write(",\n ".join(['"%s"' % k for k in sql_dict.keys()])) + buf.write("\n)") + buf.write("VALUES (%s);" % ",".join(["?"] * len(sql_dict))) db_cursor.execute(buf.getvalue(), list(sql_dict.values())) def _ConvertToCanonicalSqlDict(self, schema, raw_dict, prefix=""): @@ -186,7 +200,9 @@ def _ConvertToCanonicalSqlDict(self, schema, raw_dict, prefix=""): if isinstance(v, dict): flattened_dict.update( self._ConvertToCanonicalSqlDict( - schema, v, prefix="%s%s." % (prefix, k))) + schema, v, prefix="%s%s." % (prefix, k) + ) + ) else: field_name = prefix + k flattened_dict[field_name] = schema[field_name].convert_fn(v) @@ -195,15 +211,18 @@ def _ConvertToCanonicalSqlDict(self, schema, raw_dict, prefix=""): def _FlushAllRows(self, db_connection, table_name): """Copies rows from the given db into the output file then deletes them.""" for sql in db_connection.iterdump(): - if (sql.startswith("CREATE TABLE") or - sql.startswith("BEGIN TRANSACTION") or sql.startswith("COMMIT")): + if ( + sql.startswith("CREATE TABLE") + or sql.startswith("BEGIN TRANSACTION") + or sql.startswith("COMMIT") + ): # These statements only need to be written once. continue # The archive generator expects strings (not Unicode objects returned by # the pysqlite library). 
yield self.archive_generator.WriteFileChunk((sql + "\n").encode("utf-8")) with db_connection: - db_connection.cursor().execute("DELETE FROM \"%s\";" % table_name) + db_connection.cursor().execute('DELETE FROM "%s";' % table_name) def Finish(self): manifest = {"export_stats": self.export_counts} diff --git a/grr/server/grr_response_server/output_plugins/sqlite_plugin_test.py b/grr/server/grr_response_server/output_plugins/sqlite_plugin_test.py index 61db1aaed6..a0ab6cc0b9 100644 --- a/grr/server/grr_response_server/output_plugins/sqlite_plugin_test.py +++ b/grr/server/grr_response_server/output_plugins/sqlite_plugin_test.py @@ -3,10 +3,10 @@ import datetime import os +import sqlite3 import zipfile from absl import app -import sqlite3 import yaml from grr_response_core.lib import rdfvalue @@ -27,7 +27,8 @@ class TestEmbeddedStruct(rdf_structs.RDFProtoStruct): type_description = type_info.TypeDescriptorSet( rdf_structs.ProtoString(name="e_string_field", field_number=1), - rdf_structs.ProtoDouble(name="e_double_field", field_number=2)) + rdf_structs.ProtoDouble(name="e_double_field", field_number=2), + ) class SqliteTestStruct(rdf_structs.RDFProtoStruct): @@ -44,22 +45,27 @@ class SqliteTestStruct(rdf_structs.RDFProtoStruct): name="enum_field", field_number=7, enum_name="EnumField", - enum={ - "FIRST": 1, - "SECOND": 2 - }), rdf_structs.ProtoBoolean(name="bool_field", field_number=8), + enum={"FIRST": 1, "SECOND": 2}, + ), + rdf_structs.ProtoBoolean(name="bool_field", field_number=8), rdf_structs.ProtoRDFValue( - name="urn_field", field_number=9, rdf_type="RDFURN"), + name="urn_field", field_number=9, rdf_type="RDFURN" + ), rdf_structs.ProtoRDFValue( - name="time_field", field_number=10, rdf_type="RDFDatetime"), + name="time_field", field_number=10, rdf_type="RDFDatetime" + ), rdf_structs.ProtoRDFValue( name="time_field_seconds", field_number=11, - rdf_type="RDFDatetimeSeconds"), + rdf_type="RDFDatetimeSeconds", + ), rdf_structs.ProtoRDFValue( - name="duration_field", field_number=12, rdf_type="DurationSeconds"), + name="duration_field", field_number=12, rdf_type="DurationSeconds" + ), rdf_structs.ProtoEmbedded( - name="embedded_field", field_number=13, nested=TestEmbeddedStruct)) + name="embedded_field", field_number=13, nested=TestEmbeddedStruct + ), + ) class SqliteInstantOutputPluginTest(test_plugins.InstantOutputPluginTestBase): @@ -79,7 +85,9 @@ class SqliteInstantOutputPluginTest(test_plugins.InstantOutputPluginTestBase): st_size=0, st_atime=1493596800, # Midnight, 01.05.2017 UTC in seconds st_mtime=1493683200, # Midnight, 01.05.2017 UTC in seconds - st_ctime=1493683200) for i in range(10) + st_ctime=1493683200, + ) + for i in range(10) ] def setUp(self): @@ -98,7 +106,8 @@ def testColumnTypeInference(self): schema = self.plugin._GetSqliteSchema(SqliteTestStruct) column_types = {k: v.sqlite_type for k, v in schema.items()} self.assertEqual( - column_types, { + column_types, + { "string_field": "TEXT", "bytes_field": "BLOB", "uint_field": "INTEGER", @@ -112,8 +121,9 @@ def testColumnTypeInference(self): "time_field_seconds": "INTEGER", "duration_field": "INTEGER", "embedded_field.e_string_field": "TEXT", - "embedded_field.e_double_field": "REAL" - }) + "embedded_field.e_double_field": "REAL", + }, + ) def testConversionToCanonicalSqlDict(self): schema = self.plugin._GetSqliteSchema(SqliteTestStruct) @@ -128,14 +138,19 @@ def testConversionToCanonicalSqlDict(self): bool_field=True, urn_field=rdfvalue.RDFURN("www.test.com"), time_field=rdfvalue.RDFDatetime.FromDatetime( - 
datetime.datetime(2017, 5, 1)), + datetime.datetime(2017, 5, 1) + ), time_field_seconds=rdfvalue.RDFDatetimeSeconds.FromDatetime( - datetime.datetime(2017, 5, 2)), + datetime.datetime(2017, 5, 2) + ), duration_field=rdfvalue.Duration.From(123, rdfvalue.SECONDS), embedded_field=TestEmbeddedStruct( - e_string_field="e_string_value", e_double_field=0.789)) + e_string_field="e_string_value", e_double_field=0.789 + ), + ) sql_dict = self.plugin._ConvertToCanonicalSqlDict( - schema, test_struct.ToPrimitiveDict()) + schema, test_struct.ToPrimitiveDict() + ) self.assertEqual( sql_dict, { @@ -152,29 +167,29 @@ def testConversionToCanonicalSqlDict(self): "time_field_seconds": 1493683200000000, # Midnight, May 2 "duration_field": 123000000, "embedded_field.e_string_field": "e_string_value", - "embedded_field.e_double_field": 0.789 - }) + "embedded_field.e_double_field": 0.789, + }, + ) @export_test_lib.WithAllExportConverters def testExportedFilenamesAndManifestForValuesOfSameType(self): zip_fd, prefix = self.ProcessValuesToZip( - {rdf_client_fs.StatEntry: self.STAT_ENTRY_RESPONSES}) + {rdf_client_fs.StatEntry: self.STAT_ENTRY_RESPONSES} + ) self.assertEqual( set(zip_fd.namelist()), - {"%s/MANIFEST" % prefix, - "%s/ExportedFile_from_StatEntry.sql" % prefix}) + {"%s/MANIFEST" % prefix, "%s/ExportedFile_from_StatEntry.sql" % prefix}, + ) parsed_manifest = yaml.safe_load(zip_fd.read("%s/MANIFEST" % prefix)) - self.assertEqual(parsed_manifest, - {"export_stats": { - "StatEntry": { - "ExportedFile": 10 - } - }}) + self.assertEqual( + parsed_manifest, {"export_stats": {"StatEntry": {"ExportedFile": 10}}} + ) @export_test_lib.WithAllExportConverters def testExportedTableStructureForValuesOfSameType(self): zip_fd, prefix = self.ProcessValuesToZip( - {rdf_client_fs.StatEntry: self.STAT_ENTRY_RESPONSES}) + {rdf_client_fs.StatEntry: self.STAT_ENTRY_RESPONSES} + ) sqlite_dump_path = "%s/ExportedFile_from_StatEntry.sql" % prefix sqlite_dump = zip_fd.read(sqlite_dump_path).decode("utf-8") @@ -202,7 +217,8 @@ def testExportedTableStructureForValuesOfSameType(self): @export_test_lib.WithAllExportConverters def testExportedRowsForValuesOfSameType(self): zip_fd, prefix = self.ProcessValuesToZip( - {rdf_client_fs.StatEntry: self.STAT_ENTRY_RESPONSES}) + {rdf_client_fs.StatEntry: self.STAT_ENTRY_RESPONSES} + ) sqlite_dump_path = "%s/ExportedFile_from_StatEntry.sql" % prefix sqlite_dump = zip_fd.read(sqlite_dump_path).decode("utf-8") @@ -212,14 +228,28 @@ def testExportedRowsForValuesOfSameType(self): self.db_cursor.executescript(sqlite_dump) select_columns = [ - "metadata.client_urn", "metadata.source_urn", "urn", "st_mode", - "st_ino", "st_dev", "st_nlink", "st_uid", "st_gid", "st_size", - "st_atime", "st_mtime", "st_ctime", "st_blksize", "st_rdev", "symlink" + "metadata.client_urn", + "metadata.source_urn", + "urn", + "st_mode", + "st_ino", + "st_dev", + "st_nlink", + "st_uid", + "st_gid", + "st_size", + "st_atime", + "st_mtime", + "st_ctime", + "st_blksize", + "st_rdev", + "symlink", ] - escaped_column_names = ["\"%s\"" % c for c in select_columns] - self.db_cursor.execute("SELECT %s FROM " - "\"ExportedFile.from_StatEntry\";" % - ",".join(escaped_column_names)) + escaped_column_names = ['"%s"' % c for c in select_columns] + self.db_cursor.execute( + 'SELECT %s FROM "ExportedFile.from_StatEntry";' + % ",".join(escaped_column_names) + ) rows = self.db_cursor.fetchall() self.assertLen(rows, 10) for i, row in enumerate(rows): @@ -240,7 +270,7 @@ def testExportedRowsForValuesOfSameType(self): "st_ctime": 
1493683200000000, "st_blksize": 0, "st_rdev": 0, - "symlink": "" + "symlink": "", } self.assertEqual(results, expected_results) @@ -249,38 +279,40 @@ def testExportedFilenamesAndManifestForValuesOfMultipleTypes(self): zip_fd, prefix = self.ProcessValuesToZip({ rdf_client_fs.StatEntry: [ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path="/foo/bar", pathtype="OS")) + pathspec=rdf_paths.PathSpec(path="/foo/bar", pathtype="OS") + ) ], - rdf_client.Process: [rdf_client.Process(pid=42)] + rdf_client.Process: [rdf_client.Process(pid=42)], }) self.assertEqual( - set(zip_fd.namelist()), { + set(zip_fd.namelist()), + { "%s/MANIFEST" % prefix, "%s/ExportedFile_from_StatEntry.sql" % prefix, - "%s/ExportedProcess_from_Process.sql" % prefix - }) + "%s/ExportedProcess_from_Process.sql" % prefix, + }, + ) parsed_manifest = yaml.safe_load(zip_fd.read("%s/MANIFEST" % prefix)) self.assertEqual( - parsed_manifest, { + parsed_manifest, + { "export_stats": { - "StatEntry": { - "ExportedFile": 1 - }, - "Process": { - "ExportedProcess": 1 - } + "StatEntry": {"ExportedFile": 1}, + "Process": {"ExportedProcess": 1}, } - }) + }, + ) @export_test_lib.WithAllExportConverters def testExportedRowsForValuesOfMultipleTypes(self): zip_fd, prefix = self.ProcessValuesToZip({ rdf_client_fs.StatEntry: [ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path="/foo/bar", pathtype="OS")) + pathspec=rdf_paths.PathSpec(path="/foo/bar", pathtype="OS") + ) ], - rdf_client.Process: [rdf_client.Process(pid=42)] + rdf_client.Process: [rdf_client.Process(pid=42)], }) with self.db_connection: stat_entry_script_path = "%s/ExportedFile_from_StatEntry.sql" % prefix @@ -293,8 +325,9 @@ def testExportedRowsForValuesOfMultipleTypes(self): self.db_cursor.executescript(process_script) self.db_cursor.execute( - "SELECT \"metadata.client_urn\", \"metadata.source_urn\", urn " - "FROM \"ExportedFile.from_StatEntry\";") + 'SELECT "metadata.client_urn", "metadata.source_urn", urn ' + 'FROM "ExportedFile.from_StatEntry";' + ) stat_entry_results = self.db_cursor.fetchall() self.assertLen(stat_entry_results, 1) # Client URN @@ -302,12 +335,14 @@ def testExportedRowsForValuesOfMultipleTypes(self): # Source URN self.assertEqual(stat_entry_results[0][1], str(self.results_urn)) # URN - self.assertEqual(stat_entry_results[0][2], - "aff4:/%s/fs/os/foo/bar" % self.client_id) + self.assertEqual( + stat_entry_results[0][2], "aff4:/%s/fs/os/foo/bar" % self.client_id + ) self.db_cursor.execute( - "SELECT \"metadata.client_urn\", \"metadata.source_urn\", pid " - "FROM \"ExportedProcess.from_Process\";") + 'SELECT "metadata.client_urn", "metadata.source_urn", pid ' + 'FROM "ExportedProcess.from_Process";' + ) process_results = self.db_cursor.fetchall() self.assertLen(process_results, 1) # Client URN @@ -322,23 +357,28 @@ def testHandlingOfNonAsciiCharacters(self): zip_fd, prefix = self.ProcessValuesToZip({ rdf_client_fs.StatEntry: [ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path="/中国新闻网新闻中", pathtype="OS")) + pathspec=rdf_paths.PathSpec( + path="/中国新闻网新闻中", pathtype="OS" + ) + ) ] }) self.assertEqual( set(zip_fd.namelist()), - {"%s/MANIFEST" % prefix, - "%s/ExportedFile_from_StatEntry.sql" % prefix}) + {"%s/MANIFEST" % prefix, "%s/ExportedFile_from_StatEntry.sql" % prefix}, + ) with self.db_connection: sqlite_dump_path = "%s/ExportedFile_from_StatEntry.sql" % prefix sqlite_dump = zip_fd.read(sqlite_dump_path).decode("utf-8") self.db_cursor.executescript(sqlite_dump) - self.db_cursor.execute("SELECT urn FROM 
\"ExportedFile.from_StatEntry\";") + self.db_cursor.execute('SELECT urn FROM "ExportedFile.from_StatEntry";') results = self.db_cursor.fetchall() self.assertLen(results, 1) - self.assertEqual(results[0][0], "aff4:/%s/fs/os/中国新闻网新闻中" % self.client_id) + self.assertEqual( + results[0][0], "aff4:/%s/fs/os/中国新闻网新闻中" % self.client_id + ) @export_test_lib.WithAllExportConverters def testHandlingOfMultipleRowBatches(self): @@ -348,21 +388,24 @@ def testHandlingOfMultipleRowBatches(self): for i in range(num_rows): responses.append( rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec( - path="/foo/bar/%d" % i, pathtype="OS"))) + pathspec=rdf_paths.PathSpec(path="/foo/bar/%d" % i, pathtype="OS") + ) + ) zip_fd, prefix = self.ProcessValuesToZip( - {rdf_client_fs.StatEntry: responses}) + {rdf_client_fs.StatEntry: responses} + ) with self.db_connection: sqlite_dump_path = "%s/ExportedFile_from_StatEntry.sql" % prefix sqlite_dump = zip_fd.read(sqlite_dump_path).decode("utf-8") self.db_cursor.executescript(sqlite_dump) - self.db_cursor.execute("SELECT urn FROM \"ExportedFile.from_StatEntry\";") + self.db_cursor.execute('SELECT urn FROM "ExportedFile.from_StatEntry";') results = self.db_cursor.fetchall() self.assertLen(results, num_rows) for i in range(num_rows): - self.assertEqual(results[i][0], - "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i)) + self.assertEqual( + results[i][0], "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i) + ) def main(argv): diff --git a/grr/server/grr_response_server/output_plugins/test_plugins.py b/grr/server/grr_response_server/output_plugins/test_plugins.py index 45889a09ff..b5009f5004 100644 --- a/grr/server/grr_response_server/output_plugins/test_plugins.py +++ b/grr/server/grr_response_server/output_plugins/test_plugins.py @@ -47,11 +47,13 @@ def ProcessValues(self, values_by_cls): messages = [] for value in values: messages.append( - rdf_flows.GrrMessage(source=self.client_id, payload=value)) + rdf_flows.GrrMessage(source=self.client_id, payload=value) + ) # pylint: disable=cell-var-from-loop chunks.extend( - list(self.plugin.ProcessValues(value_cls, lambda: messages))) + list(self.plugin.ProcessValues(value_cls, lambda: messages)) + ) # pylint: enable=cell-var-from-loop chunks.extend(list(self.plugin.Finish())) @@ -86,7 +88,8 @@ def Finish(self): class TestInstantOutputPluginWithExportConverstion( - instant_output_plugin.InstantOutputPluginWithExportConversion): + instant_output_plugin.InstantOutputPluginWithExportConversion +): """Test plugin with export conversion.""" def Start(self): diff --git a/grr/server/grr_response_server/output_plugins/yaml_plugin.py b/grr/server/grr_response_server/output_plugins/yaml_plugin.py index 4da506275c..a7c3d9bc4c 100644 --- a/grr/server/grr_response_server/output_plugins/yaml_plugin.py +++ b/grr/server/grr_response_server/output_plugins/yaml_plugin.py @@ -26,7 +26,8 @@ def _SerializeToYaml(value): class YamlInstantOutputPluginWithExportConversion( - instant_output_plugin.InstantOutputPluginWithExportConversion): + instant_output_plugin.InstantOutputPluginWithExportConversion +): """Instant output plugin that flattens results into YAML.""" plugin_name = "flattened-yaml-zip" @@ -48,20 +49,26 @@ def path_prefix(self): def Start(self): self.archive_generator = utils.StreamingZipGenerator( - compression=zipfile.ZIP_DEFLATED) + compression=zipfile.ZIP_DEFLATED + ) self.export_counts = {} return [] - def ProcessSingleTypeExportedValues(self, original_value_type, - exported_values): + def ProcessSingleTypeExportedValues( + self, 
original_value_type, exported_values + ): first_value = next(exported_values, None) if not first_value: return yield self.archive_generator.WriteFileHeader( - "%s/%s/from_%s.yaml" % (self.path_prefix, - first_value.__class__.__name__, - original_value_type.__name__)) + "%s/%s/from_%s.yaml" + % ( + self.path_prefix, + first_value.__class__.__name__, + original_value_type.__name__, + ) + ) serialized_value_bytes = _SerializeToYaml(first_value).encode("utf-8") yield self.archive_generator.WriteFileChunk(serialized_value_bytes) @@ -79,7 +86,8 @@ def ProcessSingleTypeExportedValues(self, original_value_type, yield self.archive_generator.WriteFileFooter() counts_for_original_type = self.export_counts.setdefault( - original_value_type.__name__, dict()) + original_value_type.__name__, dict() + ) counts_for_original_type[first_value.__class__.__name__] = counter def Finish(self): diff --git a/grr/server/grr_response_server/output_plugins/yaml_plugin_test.py b/grr/server/grr_response_server/output_plugins/yaml_plugin_test.py index 2f0538a26b..b9dff897d7 100644 --- a/grr/server/grr_response_server/output_plugins/yaml_plugin_test.py +++ b/grr/server/grr_response_server/output_plugins/yaml_plugin_test.py @@ -35,7 +35,8 @@ def testYamlPluginWithValuesOfSameType(self): responses.append( rdf_client_fs.StatEntry( pathspec=rdf_paths.PathSpec( - path="/foo/bar/%d" % i, pathtype="OS"), + path="/foo/bar/%d" % i, pathtype="OS" + ), st_mode=33184, # octal = 100640 => u=rw,g=r,o= => -rw-r----- st_ino=1063090, st_dev=64512, @@ -45,37 +46,45 @@ def testYamlPluginWithValuesOfSameType(self): st_size=0, st_atime=1336469177, st_mtime=1336129892, - st_ctime=1336129892)) + st_ctime=1336129892, + ) + ) zip_fd, prefix = self.ProcessValuesToZip( - {rdf_client_fs.StatEntry: responses}) + {rdf_client_fs.StatEntry: responses} + ) self.assertEqual( - set(zip_fd.namelist()), { + set(zip_fd.namelist()), + { "%s/MANIFEST" % prefix, - "%s/ExportedFile/from_StatEntry.yaml" % prefix - }) + "%s/ExportedFile/from_StatEntry.yaml" % prefix, + }, + ) parsed_manifest = yaml.safe_load(zip_fd.read("%s/MANIFEST" % prefix)) - self.assertEqual(parsed_manifest, - {"export_stats": { - "StatEntry": { - "ExportedFile": 10 - } - }}) + self.assertEqual( + parsed_manifest, {"export_stats": {"StatEntry": {"ExportedFile": 10}}} + ) parsed_output = yaml.safe_load( - zip_fd.read("%s/ExportedFile/from_StatEntry.yaml" % prefix)) + zip_fd.read("%s/ExportedFile/from_StatEntry.yaml" % prefix) + ) self.assertLen(parsed_output, 10) for i in range(10): # Only the client_urn is filled in by the plugin. Doing lookups for # all the clients metadata is possible but expensive. It doesn't seem to # be worth it. 
- self.assertEqual(parsed_output[i]["metadata"]["client_urn"], - "aff4:/%s" % self.client_id) - self.assertEqual(parsed_output[i]["metadata"]["source_urn"], - str(self.results_urn)) - self.assertEqual(parsed_output[i]["urn"], - "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i)) + self.assertEqual( + parsed_output[i]["metadata"]["client_urn"], + "aff4:/%s" % self.client_id, + ) + self.assertEqual( + parsed_output[i]["metadata"]["source_urn"], str(self.results_urn) + ) + self.assertEqual( + parsed_output[i]["urn"], + "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i), + ) self.assertEqual(parsed_output[i]["st_mode"], "-rw-r-----") self.assertEqual(parsed_output[i]["st_ino"], "1063090") self.assertEqual(parsed_output[i]["st_dev"], "64512") @@ -95,46 +104,52 @@ def testYamlPluginWithValuesOfMultipleTypes(self): zip_fd, prefix = self.ProcessValuesToZip({ rdf_client_fs.StatEntry: [ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path="/foo/bar", pathtype="OS")) + pathspec=rdf_paths.PathSpec(path="/foo/bar", pathtype="OS") + ) ], - rdf_client.Process: [rdf_client.Process(pid=42)] + rdf_client.Process: [rdf_client.Process(pid=42)], }) self.assertEqual( - set(zip_fd.namelist()), { + set(zip_fd.namelist()), + { "%s/MANIFEST" % prefix, "%s/ExportedFile/from_StatEntry.yaml" % prefix, - "%s/ExportedProcess/from_Process.yaml" % prefix - }) + "%s/ExportedProcess/from_Process.yaml" % prefix, + }, + ) parsed_manifest = yaml.safe_load(zip_fd.read("%s/MANIFEST" % prefix)) self.assertEqual( - parsed_manifest, { + parsed_manifest, + { "export_stats": { - "StatEntry": { - "ExportedFile": 1 - }, - "Process": { - "ExportedProcess": 1 - } + "StatEntry": {"ExportedFile": 1}, + "Process": {"ExportedProcess": 1}, } - }) + }, + ) parsed_output = yaml.safe_load( - zip_fd.read("%s/ExportedFile/from_StatEntry.yaml" % prefix)) + zip_fd.read("%s/ExportedFile/from_StatEntry.yaml" % prefix) + ) self.assertLen(parsed_output, 1) # Only the client_urn is filled in by the plugin. Doing lookups for # all the clients metadata is possible but expensive. It doesn't seem to # be worth it. 
- self.assertEqual(parsed_output[0]["metadata"]["client_urn"], - "aff4:/%s" % self.client_id) - self.assertEqual(parsed_output[0]["metadata"]["source_urn"], - str(self.results_urn)) - self.assertEqual(parsed_output[0]["urn"], - "aff4:/%s/fs/os/foo/bar" % self.client_id) + self.assertEqual( + parsed_output[0]["metadata"]["client_urn"], "aff4:/%s" % self.client_id + ) + self.assertEqual( + parsed_output[0]["metadata"]["source_urn"], str(self.results_urn) + ) + self.assertEqual( + parsed_output[0]["urn"], "aff4:/%s/fs/os/foo/bar" % self.client_id + ) parsed_output = yaml.safe_load( - zip_fd.read("%s/ExportedProcess/from_Process.yaml" % prefix)) + zip_fd.read("%s/ExportedProcess/from_Process.yaml" % prefix) + ) self.assertLen(parsed_output, 1) self.assertEqual(parsed_output[0]["pid"], "42") @@ -143,21 +158,29 @@ def testYamlPluginWritesUnicodeValuesCorrectly(self): zip_fd, prefix = self.ProcessValuesToZip({ rdf_client_fs.StatEntry: [ rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec(path="/中国新闻网新闻中", pathtype="OS")) + pathspec=rdf_paths.PathSpec( + path="/中国新闻网新闻中", pathtype="OS" + ) + ) ] }) self.assertEqual( - set(zip_fd.namelist()), { + set(zip_fd.namelist()), + { "%s/MANIFEST" % prefix, - "%s/ExportedFile/from_StatEntry.yaml" % prefix - }) + "%s/ExportedFile/from_StatEntry.yaml" % prefix, + }, + ) parsed_output = yaml.safe_load( - zip_fd.open("%s/ExportedFile/from_StatEntry.yaml" % prefix)) + zip_fd.open("%s/ExportedFile/from_StatEntry.yaml" % prefix) + ) self.assertLen(parsed_output, 1) - self.assertEqual(parsed_output[0]["urn"], - "aff4:/%s/fs/os/中国新闻网新闻中" % self.client_id) + self.assertEqual( + parsed_output[0]["urn"], + "aff4:/%s/fs/os/中国新闻网新闻中" % self.client_id, + ) @export_test_lib.WithAllExportConverters def testYamlPluginWritesMoreThanOneBatchOfRowsCorrectly(self): @@ -167,17 +190,22 @@ def testYamlPluginWritesMoreThanOneBatchOfRowsCorrectly(self): for i in range(num_rows): responses.append( rdf_client_fs.StatEntry( - pathspec=rdf_paths.PathSpec( - path="/foo/bar/%d" % i, pathtype="OS"))) + pathspec=rdf_paths.PathSpec(path="/foo/bar/%d" % i, pathtype="OS") + ) + ) zip_fd, prefix = self.ProcessValuesToZip( - {rdf_client_fs.StatEntry: responses}) + {rdf_client_fs.StatEntry: responses} + ) parsed_output = yaml.safe_load( - zip_fd.open("%s/ExportedFile/from_StatEntry.yaml" % prefix)) + zip_fd.open("%s/ExportedFile/from_StatEntry.yaml" % prefix) + ) self.assertLen(parsed_output, num_rows) for i in range(num_rows): - self.assertEqual(parsed_output[i]["urn"], - "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i)) + self.assertEqual( + parsed_output[i]["urn"], + "aff4:/%s/fs/os/foo/bar/%d" % (self.client_id, i), + ) def main(argv): diff --git a/grr/server/grr_response_server/rdfvalues/flow_objects.py b/grr/server/grr_response_server/rdfvalues/flow_objects.py index 6e360a528f..a7ecd59263 100644 --- a/grr/server/grr_response_server/rdfvalues/flow_objects.py +++ b/grr/server/grr_response_server/rdfvalues/flow_objects.py @@ -30,45 +30,11 @@ def __init__(self, *args, **kwargs): self.next_response_id = 1 -class FlowMessage(object): - """Base class for all messages flows can receive.""" - - @property - def client_id(self) -> str: - return self.Get("client_id") - - @client_id.setter - def client_id(self, value: str) -> None: - self.Set("client_id", value) - - @property - def flow_id(self) -> str: - return self.Get("flow_id") - - @flow_id.setter - def flow_id(self, value: str) -> None: - self.Set("flow_id", value) - - @property - def request_id(self) -> int: - return self.Get("request_id") - - 
@request_id.setter - def request_id(self, value: int) -> None: - self.Set("request_id", value) - - @property - def response_id(self) -> int: - return self.Get("response_id") - - @response_id.setter - def response_id(self, value: int) -> None: - self.Set("response_id", value) - - -class FlowResponse(FlowMessage, rdf_structs.RDFProtoStruct): +class FlowResponse(rdf_structs.RDFProtoStruct): protobuf = flows_pb2.FlowResponse - rdf_deps = [] + rdf_deps = [ + rdfvalue.RDFDatetime, + ] def AsLegacyGrrMessage(self): return rdf_flows.GrrMessage( @@ -80,9 +46,11 @@ def AsLegacyGrrMessage(self): payload=self.payload) -class FlowIterator(FlowMessage, rdf_structs.RDFProtoStruct): +class FlowIterator(rdf_structs.RDFProtoStruct): protobuf = flows_pb2.FlowIterator - rdf_deps = [] + rdf_deps = [ + rdfvalue.RDFDatetime, + ] def AsLegacyGrrMessage(self): return rdf_flows.GrrMessage( @@ -93,13 +61,14 @@ def AsLegacyGrrMessage(self): timestamp=self.timestamp) -class FlowStatus(FlowMessage, rdf_structs.RDFProtoStruct): +class FlowStatus(rdf_structs.RDFProtoStruct): """The flow status object.""" protobuf = flows_pb2.FlowStatus rdf_deps = [ rdf_client_stats.CpuSeconds, rdfvalue.Duration, + rdfvalue.RDFDatetime, ] def AsLegacyGrrMessage(self): diff --git a/grr/server/grr_response_server/rdfvalues/hunt_objects.py b/grr/server/grr_response_server/rdfvalues/hunt_objects.py index 716ecfc158..e983a50605 100644 --- a/grr/server/grr_response_server/rdfvalues/hunt_objects.py +++ b/grr/server/grr_response_server/rdfvalues/hunt_objects.py @@ -133,8 +133,11 @@ def expired(self) -> bool: return False -def IsHuntSuitableForFlowProcessing(hunt_state): - return hunt_state in [Hunt.HuntState.PAUSED, Hunt.HuntState.STARTED] +def IsHuntSuitableForFlowProcessing(hunt_state: int) -> bool: + return hunt_state in [ + hunts_pb2.Hunt.HuntState.PAUSED, + hunts_pb2.Hunt.HuntState.STARTED, + ] class HuntMetadata(rdf_structs.RDFProtoStruct): diff --git a/grr/server/grr_response_server/rdfvalues/hunts.py b/grr/server/grr_response_server/rdfvalues/hunts.py index 27a27afc83..efd117db6d 100644 --- a/grr/server/grr_response_server/rdfvalues/hunts.py +++ b/grr/server/grr_response_server/rdfvalues/hunts.py @@ -120,8 +120,9 @@ def Validate(self): def GetFlowArgsClass(self): if self.flow_runner_args.flow_name: - flow_cls = registry.AFF4FlowRegistry.FlowClassByName( - self.flow_runner_args.flow_name) + flow_cls = registry.FlowRegistry.FlowClassByName( + self.flow_runner_args.flow_name + ) # The required protobuf for this class is in args_type. 
return flow_cls.args_type diff --git a/grr/server/grr_response_server/server_logging.py b/grr/server/grr_response_server/server_logging.py index d3a1094054..0f1309faa9 100644 --- a/grr/server/grr_response_server/server_logging.py +++ b/grr/server/grr_response_server/server_logging.py @@ -12,7 +12,7 @@ from grr_response_core import config from grr_response_core.stats import metrics from grr_response_server import data_store -from grr_response_server.rdfvalues import objects as rdf_objects +from grr_response_server.models import events try: # pylint: disable=g-import-not-at-top @@ -72,8 +72,7 @@ def LogHttpAdminUIAccess(self, request, response): logging.info(log_msg) if response.headers.get("X-No-Log") != "True": - entry = rdf_objects.APIAuditEntry.FromHttpRequestResponse( - request, response) + entry = events.APIAuditEntryFromHttpRequestResponse(request, response) data_store.REL_DB.WriteAPIAuditEntry(entry) def LogHttpFrontendAccess(self, request, source=None, message_count=None): diff --git a/grr/server/grr_response_server/server_logging_test.py b/grr/server/grr_response_server/server_logging_test.py index be65f04157..473a6bbaa2 100644 --- a/grr/server/grr_response_server/server_logging_test.py +++ b/grr/server/grr_response_server/server_logging_test.py @@ -12,8 +12,8 @@ from grr_response_proto import jobs_pb2 from grr_response_server import server_logging from grr_response_server.gui import api_call_context +from grr_response_server.gui import http_request from grr_response_server.gui import http_response -from grr_response_server.gui import wsgiapp from grr.test_lib import acl_test_lib from grr.test_lib import stats_test_lib from grr.test_lib import test_lib @@ -46,10 +46,10 @@ def testGetEventId(self): "Invalid event ID generated") def testLogHttpAdminUIAccess(self): - request = wsgiapp.HttpRequest({ + request = http_request.HttpRequest({ "wsgi.url_scheme": "http", "SERVER_NAME": "foo.bar", - "SERVER_PORT": "1234" + "SERVER_PORT": "1234", }) request.user = "testuser" diff --git a/grr/server/grr_response_server/server_stubs.py b/grr/server/grr_response_server/server_stubs.py index cb2a8a66c6..a65fbc7b3a 100644 --- a/grr/server/grr_response_server/server_stubs.py +++ b/grr/server/grr_response_server/server_stubs.py @@ -272,15 +272,6 @@ class Grep(ClientActionStub): out_rdfvalues = [rdf_client.BufferReference] -# from network.py -# Deprecated action, kept for outdated clients. -class Netstat(ClientActionStub): - """Gather open network connection stats.""" - - in_rdfvalue = None - out_rdfvalues = [rdf_client_network.NetworkConnection] - - class ListNetworkConnections(ClientActionStub): """Gather open network connection stats.""" diff --git a/grr/server/grr_response_server/signed_binary_utils_test.py b/grr/server/grr_response_server/signed_binary_utils_test.py index 78605d44fa..f7950955e2 100644 --- a/grr/server/grr_response_server/signed_binary_utils_test.py +++ b/grr/server/grr_response_server/signed_binary_utils_test.py @@ -64,14 +64,16 @@ def testWriteSignedBinary(self): test_urn) self.assertGreater(timestamp.AsMicrosecondsSinceEpoch(), 0) self.assertIsInstance(blobs_iter, collections.abc.Iterator) - # We expect blobs to have at most 3 contiguous bytes of data. 
- expected_blobs = [ - rdf_crypto.SignedBlob().Sign(b"\x00\x11\x22", self._private_key), - rdf_crypto.SignedBlob().Sign(b"\x33\x44\x55", self._private_key), - rdf_crypto.SignedBlob().Sign(b"\x66\x77\x88", self._private_key), - rdf_crypto.SignedBlob().Sign(b"\x99", self._private_key) - ] - self.assertCountEqual(list(blobs_iter), expected_blobs) + + blobs_list = list(blobs_iter) + blobs_list[0].Verify(self._public_key) + self.assertContainsSubset(blobs_list[0].data, binary_data) + blobs_list[1].Verify(self._public_key) + self.assertContainsSubset(blobs_list[1].data, binary_data) + blobs_list[2].Verify(self._public_key) + self.assertContainsSubset(blobs_list[2].data, binary_data) + blobs_list[3].Verify(self._public_key) + self.assertContainsSubset(blobs_list[3].data, binary_data) def testWriteSignedBinaryBlobs(self): test_urn = rdfvalue.RDFURN("aff4:/config/executables/foo") diff --git a/grr/server/grr_response_server/sinks/__init__.py b/grr/server/grr_response_server/sinks/__init__.py index 628477c2d6..bb98eba8a0 100644 --- a/grr/server/grr_response_server/sinks/__init__.py +++ b/grr/server/grr_response_server/sinks/__init__.py @@ -19,6 +19,7 @@ do that. This is where the startup sink comes into play and allows the agent to send information about its startup. """ + from typing import Mapping from grr_response_server.sinks import abstract diff --git a/grr/server/grr_response_server/sinks/abstract.py b/grr/server/grr_response_server/sinks/abstract.py index b8af7ac0a3..d20ad4308c 100644 --- a/grr/server/grr_response_server/sinks/abstract.py +++ b/grr/server/grr_response_server/sinks/abstract.py @@ -3,6 +3,7 @@ See documentation for the root `sinks` module for more details. """ + import abc from grr_response_proto import rrg_pb2 diff --git a/grr/server/grr_response_server/sinks/blob.py b/grr/server/grr_response_server/sinks/blob.py index 1324e9eaa2..21eda224ad 100644 --- a/grr/server/grr_response_server/sinks/blob.py +++ b/grr/server/grr_response_server/sinks/blob.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with the blob sink.""" + import logging from grr_response_server import data_store diff --git a/grr/server/grr_response_server/sinks/startup.py b/grr/server/grr_response_server/sinks/startup.py index c7dd3f5ef1..73dc6badd0 100644 --- a/grr/server/grr_response_server/sinks/startup.py +++ b/grr/server/grr_response_server/sinks/startup.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A module with the startup sink.""" + from grr_response_server import data_store from grr_response_server.sinks import abstract from grr_response_proto import rrg_pb2 diff --git a/grr/server/grr_response_server/sinks/test_lib.py b/grr/server/grr_response_server/sinks/test_lib.py index 7c274e9a27..c40981b033 100644 --- a/grr/server/grr_response_server/sinks/test_lib.py +++ b/grr/server/grr_response_server/sinks/test_lib.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """A testing utilities to simplify test code working with sinks.""" + import collections from typing import Sequence diff --git a/grr/server/grr_response_server/throttle.py b/grr/server/grr_response_server/throttle.py index 45339efbd7..c12ef14e29 100644 --- a/grr/server/grr_response_server/throttle.py +++ b/grr/server/grr_response_server/throttle.py @@ -4,6 +4,7 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_server import data_store +from grr_response_server.rdfvalues import mig_flow_objects class Error(Exception): @@ -51,6 +52,7 @@ def _LoadFlows(self, client_id, 
min_create_time): client_id=client_id, min_create_time=min_create_time, include_child_flows=False) + flow_list = [mig_flow_objects.ToRDFFlow(flow) for flow in flow_list] for flow_obj in flow_list: yield flow_obj diff --git a/grr/server/grr_response_server/worker_lib.py b/grr/server/grr_response_server/worker_lib.py index ce23c9c1fd..0e16612922 100644 --- a/grr/server/grr_response_server/worker_lib.py +++ b/grr/server/grr_response_server/worker_lib.py @@ -7,9 +7,10 @@ from grr_response_core.lib import rdfvalue from grr_response_core.lib import registry -from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_core.lib.util import collection from grr_response_core.stats import metrics +from grr_response_proto import flows_pb2 +from grr_response_proto import objects_pb2 from grr_response_server import data_store from grr_response_server import flow_base from grr_response_server import handler_registry @@ -18,7 +19,9 @@ # pylint: enable=unused-import from grr_response_server.databases import db from grr_response_server.rdfvalues import flow_objects as rdf_flow_objects -from grr_response_server.rdfvalues import objects as rdf_objects +from grr_response_server.rdfvalues import mig_flow_objects +from grr_response_server.rdfvalues import mig_objects + WELL_KNOWN_FLOW_REQUESTS = metrics.Counter( "well_known_flow_requests", fields=[("flow", str)]) @@ -33,12 +36,18 @@ class FlowHasNothingToProcessError(Error): def ProcessMessageHandlerRequests( - requests: Sequence[rdf_objects.MessageHandlerRequest]): + requests: Sequence[objects_pb2.MessageHandlerRequest], +) -> None: """Processes message handler requests.""" - logging.info("Leased message handler request ids: %s", - ",".join(str(r.request_id) for r in requests)) + logging.info( + "Leased message handler request ids: %s", + ",".join(str(r.request_id) for r in requests), + ) grouped_requests = collection.Group(requests, lambda r: r.handler_name) for handler_name, requests_for_handler in grouped_requests.items(): + requests_for_handler = [ + mig_objects.ToRDFMessageHandlerRequest(r) for r in requests_for_handler + ] handler_cls = handler_registry.handler_name_map.get(handler_name) if not handler_cls: logging.error("Unknown message handler: %s", handler_name) @@ -92,34 +101,43 @@ def Run(self) -> None: self.Shutdown() def _ReleaseProcessedFlow(self, flow_obj: rdf_flow_objects.Flow) -> bool: + """Release a processed flow if the processing deadline is not exceeded.""" rdf_flow = flow_obj.rdf_flow if rdf_flow.processing_deadline < rdfvalue.RDFDatetime.Now(): raise flow_base.FlowError( - "Lease expired for flow %s on %s (%s)." % - (rdf_flow.flow_id, rdf_flow.client_id, rdf_flow.processing_deadline)) - + "Lease expired for flow %s on %s (%s)." 
+ % ( + rdf_flow.flow_id, + rdf_flow.client_id, + rdf_flow.processing_deadline, + ), + ) flow_obj.FlushQueuedMessages() - return data_store.REL_DB.ReleaseProcessedFlow(rdf_flow) + proto_flow = mig_flow_objects.ToProtoFlow(rdf_flow) + return data_store.REL_DB.ReleaseProcessedFlow(proto_flow) def ProcessFlow( - self, flow_processing_request: rdf_flows.FlowProcessingRequest) -> None: + self, flow_processing_request: flows_pb2.FlowProcessingRequest + ) -> None: """The callback for the flow processing queue.""" - client_id = flow_processing_request.client_id flow_id = flow_processing_request.flow_id data_store.REL_DB.AckFlowProcessingRequests([flow_processing_request]) try: - rdf_flow = data_store.REL_DB.LeaseFlowForProcessing( + flow = data_store.REL_DB.LeaseFlowForProcessing( client_id, flow_id, - processing_time=rdfvalue.Duration.From(6, rdfvalue.HOURS)) + processing_time=rdfvalue.Duration.From(6, rdfvalue.HOURS), + ) except db.ParentHuntIsNotRunningError: flow_base.TerminateFlow(client_id, flow_id, "Parent hunt stopped.") return + rdf_flow = mig_flow_objects.ToRDFFlow(flow) + first_request_to_process = rdf_flow.next_request_to_process logging.info("Processing Flow %s/%s/%d (%s).", client_id, flow_id, first_request_to_process, rdf_flow.flow_class_name) diff --git a/grr/test/grr_response_test/end_to_end_tests/tests/artifacts.py b/grr/test/grr_response_test/end_to_end_tests/tests/artifacts.py index 25b1a37987..60f1f06260 100644 --- a/grr/test/grr_response_test/end_to_end_tests/tests/artifacts.py +++ b/grr/test/grr_response_test/end_to_end_tests/tests/artifacts.py @@ -116,12 +116,7 @@ def testWinEnvVariableWinDir(self): def testWinUserShellFolder(self): results = self._CollectArtifact("WindowsUserShellFolders") - # Results should be of type User. Check that each user has - # a temp folder and at least one has an appdata folder. - for r in results: - self.assertTrue(r.payload.temp) - - self.assertNotEmpty([r for r in results if r.payload.appdata]) + self.assertNotEmpty(results) class TestWindowsRegistryCollector(test_base.EndToEndTest): @@ -141,26 +136,6 @@ def runTest(self): self.assertIn("namespace", statentry.pathspec.path.lower()) -class TestWindowsUninstallKeysCollection(test_base.EndToEndTest): - """Tests the WindowsUninstallKeys artifact collection.""" - - platforms = [ - test_base.EndToEndTest.Platform.WINDOWS, - ] - - def runTest(self): - args = self.grr_api.types.CreateFlowArgs("ArtifactCollectorFlow") - args.artifact_list.append("WindowsUninstallKeys") - f = self.RunFlowAndWait("ArtifactCollectorFlow", args=args) - - # The result should contain a single SoftwarePackages proto with - # multiple entries in its 'packages' attribute. 
- results = list(f.ListResults()) - self.assertLen(results, 1) - self.assertTrue(hasattr(results[0].payload, "packages")) - self.assertNotEmpty(results[0].payload.packages) - - class TestKnowledgeBaseInitializationFlow(test_base.EndToEndTest): """Test knowledge base initialization flow.""" diff --git a/grr/test/grr_response_test/test_data/bigquery/ExportedFile.jsonlines b/grr/test/grr_response_test/test_data/bigquery/ExportedFile.jsonlines index 00de464ce0..9ce241de48 100644 --- a/grr/test/grr_response_test/test_data/bigquery/ExportedFile.jsonlines +++ b/grr/test/grr_response_test/test_data/bigquery/ExportedFile.jsonlines @@ -1,10 +1,10 @@ -{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "1", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "0", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/0"} -{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "2", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "1", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", 
"deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/1"} -{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "3", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "2", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/2"} -{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "4", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "3", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/3"} -{"cert_hasher_name": "", 
"st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "5", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "4", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/4"} -{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "6", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "5", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/5"} -{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "7", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", 
"st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "6", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/6"} -{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "8", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "7", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/7"} -{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "9", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "8", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", 
"serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/8"} -{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "10", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "9", "st_atime": "2012-05-08 09:26:17", "metadata": {"uname": "Linux--buster/sid", "os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/9"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "1", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "0", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", 
"annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/0"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "2", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "1", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/1"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "3", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "2", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", 
"cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/2"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "4", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "3", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/3"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "5", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "4", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/4"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "6", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": 
"-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "5", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/5"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "7", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "6", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/6"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "8", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "7", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", 
"serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/7"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "9", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "8", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": "4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/8"} +{"cert_hasher_name": "", "st_blocks": "0", "hash_sha256": "", "hash_sha1": "", "st_nlink": "10", "pecoff_hash_sha1": "", "st_gid": "5000", "symlink": "", "cert_program_url": "", "st_ctime": "2012-05-04 11:11:32", "cert_certificates": "", "st_blksize": "0", "st_dev": "64512", "st_mode": "-rw-r-----", "content_sha256": "", "st_rdev": "0", "st_btime": "2012-05-27 09:35:38", "st_mtime": "2012-05-04 11:11:32", "pecoff_hash_md5": "", "basename": "9", "st_atime": "2012-05-08 09:26:17", "metadata": {"os": "Linux", "os_version": "buster/sid", "timestamp": "2015-10-28 01:31:13", "hardware_info": {"system_product_name": "", "bios_vendor": "", "bios_release_date": "", "system_sku_number": "", "serial_number": "", "system_manufacturer": "System-Manufacturer-0", "bios_revision": "", "bios_rom_size": "", "system_uuid": "", "system_assettag": "", "bios_version": "Bios-Version-0", "system_family": ""}, "hostname": "Host-0.example.com", "os_release": "", "client_urn": "aff4:/C.1000000000000000", "user_labels": "", "mac_address": "aabbccddee00\nbbccddeeff00", "annotations": "", "kernel_version": 
"4.0.0", "cloud_instance_type": "UNSET","cloud_instance_id": "", "source_urn": "aff4:/C.1000000000000000/Results", "client_age": null, "system_labels": "", "deprecated_session_id": null, "usernames": "user1,user2", "labels": ""}, "cert_signing_id": "", "cert_countersignature_chain_head_issuer": "", "cert_program_name": "", "st_size": "0", "st_ino": "1063090", "cert_chain_head_issuer": "", "st_uid": "139592", "hash_md5": "", "urn": "aff4:/C.1000000000000000/fs/os/foo/bar/9"} diff --git a/grr/test/grr_response_test/test_data/bigquery/ExportedFile.schema b/grr/test/grr_response_test/test_data/bigquery/ExportedFile.schema index f5700f317c..b42993c2c4 100644 --- a/grr/test/grr_response_test/test_data/bigquery/ExportedFile.schema +++ b/grr/test/grr_response_test/test_data/bigquery/ExportedFile.schema @@ -21,11 +21,6 @@ "description": "Age of the client.", "name": "client_age" }, - { - "type": "STRING", - "description": "Uname string.", - "name": "uname" - }, { "type": "STRING", "description": "The OS release identifier e.g. 7, OSX, debian.", diff --git a/grr/test/grr_response_test/test_data/dummyconfig.yaml b/grr/test/grr_response_test/test_data/dummyconfig.yaml index 2195a8302c..cc79242514 100644 --- a/grr/test/grr_response_test/test_data/dummyconfig.yaml +++ b/grr/test/grr_response_test/test_data/dummyconfig.yaml @@ -13,4 +13,3 @@ Client.executable_signing_public_key: | g9+hgrcgt7JRECXmho+WFrKM1H6z1CNe6uEALMfogNl9Mm2jCgavl1wbV92pIKT1 pQIDAQAB -----END PUBLIC KEY----- -CA.certificate: "-----BEGIN CERTIFICATE" diff --git a/grr/test/grr_response_test/test_data/grr_test.yaml b/grr/test/grr_response_test/test_data/grr_test.yaml index 10d3ecdfc0..336b048389 100644 --- a/grr/test/grr_response_test/test_data/grr_test.yaml +++ b/grr/test/grr_response_test/test_data/grr_test.yaml @@ -3,23 +3,6 @@ # configuration. 
AdminUI.csrf_secret_key: TESTKEYTESTKEYTESTKEY -Client.private_key: | - -----BEGIN PRIVATE KEY----- - MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBAMdgLNxyvDnQsuqp - sgoxcDksO0SkTbw4N07tU9LqASsXV3JzUsAg3jwUxuHlJ1aEIgdm+ATSc+Mi9yYu - dYZonACGAwx4fO9b2X36gFEOLmw8zZFjJRVqbT6ZxohUFmaReeLgVx56U24CiDqb - +hBjSItDX1oGQg8lCHrZOBQ1PFLrAgMBAAECgYBuur1vmdFsErqrlZ+qBZccSbKJ - Bu+db2NYVHrPkuMfOZF5oQ7+YcLLf4aNgyV76Vq03b01gGSYa5zz+a2P/p1UuRIC - DskgLApXXmP+WemAgAV22QyRQQYSUvSr5JIR46GW9JWBgs98f/6IUNUUne6and9M - k5GlHlX5z7TRZCw6IQJBAPfe5qWdUl1xB9QuHOcAnkYYaybFUIzWYSk0IJ3NiNmQ - MrpljWdNx0/dC2bSFjhTCNkf4sc8aSfZToi5Y7ze4bkCQQDN6h1ZmJkIZ+crqCj/ - sdiH+Ykba3+eOTo+oMoV+w9VcxZpdefV5S0NfXrAFUzt329WVH4MYV9fz/Rb+6Lc - VvvDAkEAj7+GNW+6T5R4fNXNTy5tm6sXoSF3KGY/bLzdWYbUIZBdyvmP+uQBfdBs - h1G5LysAi6LRSsg/F6wPvnz9WZBMiQJAdCIP/5Ii7Sy8olCrHtrNBpNkEoTkavZX - tS62CwOXuFe6UixfXrFsYWldq6vXwWj8wDHTDWR1h/IfHSmkxqSARQJBANEyt/lQ - Decnn0QintT3zNAV26lnf8vABdrnp/IaqgFhfjW8NlBEHLpcKY0Cow+qNiCh7pxj - jzITFeE6mjs3k1I= - -----END PRIVATE KEY----- Client.executable_signing_public_key: | -----BEGIN PUBLIC KEY----- @@ -32,186 +15,12 @@ Client.executable_signing_public_key: | pQIDAQAB -----END PUBLIC KEY----- -CA.certificate: | - -----BEGIN CERTIFICATE----- - MIIGCzCCA/OgAwIBAgIJAIayxnA7Bp+3MA0GCSqGSIb3DQEBBQUAMD4xCzAJBgNV - BAYTAlVTMQwwCgYDVQQIEwNDQUwxCzAJBgNVBAcTAlNGMRQwEgYDVQQDEwtHUlIg - VGVzdCBDQTAeFw0xMTA1MjcxMjE0MDlaFw0yMTA1MjQxMjE0MDlaMD4xCzAJBgNV - BAYTAlVTMQwwCgYDVQQIEwNDQUwxCzAJBgNVBAcTAlNGMRQwEgYDVQQDEwtHUlIg - VGVzdCBDQTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBANI1Xr3HZdkM - g8Eqa4BgnrlZbh01kLHkq+kUGlcoyuNns9BqWS2drd8ITU1Tk788Gu7uQPVMZV2t - nQlol/0IWpq5hdMBFOb6AnMs0L02nLKEOsdXXwm5E1MFePl67SPdB3lUgDUwEemp - P5nPYe2yFoWQQQdFWJ75Ky+NSmE6yy+bsqUFP2cAkpgvRTe1aXwVLFQdjXNgm02z - uG1TGoKc3dnlwe+fAOtuA8eD7dPARflCCh8yBNiIddTpV+oxsZ2wwn+QjvRgj+ZM - 8zxjZPALEPdFHGo3LFHO3IBA9/RF69BwlogCG0b1L9VUPlTThYWia9VN5u07NoyN - 9MGOR32CpIRG+DB4bpU3kGDZnl+RFxBMVgcMtr7/7cNvsQ0oSJ8nNyuc9muceylq - 8h1h2cXQwBpsqxAxuwuu55tR+oJtWhCfhB116ipsI2CglBhzENfX1PUv/argtlx8 - 0Ct5Pb/3DbtHIdolxNTAp6FfhvkDWLIHXGuZJosRcOQjnjYAEo8C5vs9f4fgvKJ0 - Ffh8aOMIiKwyi6VXdz5GJtGPZl5mUKT3XpFmk+BCHxty4hJORB8zusc0Yz31T2cQ - xwTdFUwbVW/sdkTtBG5KzcJ7aGcVqrjaFTkQ/e2xU4HP6hhE2u8lJhAkUzpKVxdf - 4VqPzV2koi7D5xpojoyL+5oYXh7rxGM1AgMBAAGjggEKMIIBBjAdBgNVHQ4EFgQU - O4+Xefeqvq3W6/eaPxaNv8IHpcswbgYDVR0jBGcwZYAUO4+Xefeqvq3W6/eaPxaN - v8IHpcuhQqRAMD4xCzAJBgNVBAYTAlVTMQwwCgYDVQQIEwNDQUwxCzAJBgNVBAcT - AlNGMRQwEgYDVQQDEwtHUlIgVGVzdCBDQYIJAIayxnA7Bp+3MA8GA1UdEwEB/wQF - MAMBAf8wEQYJYIZIAYb4QgEBBAQDAgEGMAkGA1UdEgQCMAAwKwYJYIZIAYb4QgEN - BB4WHFRpbnlDQSBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwCQYDVR0RBAIwADAOBgNV - HQ8BAf8EBAMCAQYwDQYJKoZIhvcNAQEFBQADggIBAACRLafixRV4JcwND0eOqZ+r - J8ma3LAa8apbWNLgAa9xJUTKEqofxCF9FmegYCWSTRUv43W7lDCIByuKl5Uwtyzh - DzOB2Z3+q1KWPGn7ao+wHfoS3b4uXOaGFHxpR2YSyLLhAFOS/HV4dM2hdHisaz9Z - Fz2aQRTq70iHlbUAoVY4Gw8zfN+JCLp93fz30dtRats5e9OPtf3WTcERHpzBI7qD - XjSexd/XxlZYFPVyN5dUTYCC8mAdsawrEv5U70fVcNfILCUY2wI+1XSARPSC94H7 - +WqZg6pVdyu12wkSexlwneSBa2nQKFLhAZOzXpi2Af2tUI31332knSP8ZUNuQ3un - 3qi9qXtcQVXjWkVYvkjfkZiymaGS6bRml5AC2G2vhaDi4PWml79gCHQcN0Lm9Epb - ObHvoRNuPU9YkbrVBwNzGHUfEdSN433OVLNp+9CAFcfYaJyMJiV4YAiutITQQkBM - 3zT4U/FDjnojGp6nZQl9pxpK6iq2l1cpo0ZcfQJ870CLnBjWMkvEa6Mp+7rMZUEB - yKIpQoCislf1ODyl0s037u2kip7iby5CyWDe2EUhcZxByE10s2pnBPsKsT0TdZbm - Cq6toF4BeLtlB2flxNLgGa63yuWRWqb6Cq7RbDlPlRXpaXAUnigQGYvmFl4M03i5 - ImKbVCFIXYW/vECT2R/v - -----END CERTIFICATE----- - -Frontend.certificate: | - -----BEGIN CERTIFICATE----- - MIIE4zCCAsugAwIBAgIBATANBgkqhkiG9w0BAQsFADA+MQswCQYDVQQGEwJVUzEM - 
MAoGA1UECAwDQ0FMMQswCQYDVQQHDAJTRjEUMBIGA1UEAwwLR1JSIFRlc3QgQ0Ew - HhcNMTYwNDEyMTU1MjUwWhcNMjYwNDExMTU1MjUwWjAaMRgwFgYDVQQDDA9HUlIg - VGVzdCBTZXJ2ZXIwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDBRcE3 - NZKgSvzx0ZCplZLCAn6f8D9BDI2f9kznteyojogw7qmUWYD4Nm/YWRNN2eI0+aG4 - v2aNZgDnKDgpgxvPmKA6IGgX1EF4ibb+/ZKDrlUWjIKAk2CXdXfiamgMIcFhOLYD - 85+gp5JAhJtR8wexiE8d0Z+GOTx0S4yjszGOu18A7djhC2NOITTt412hFCQW3MXS - b2hJrK0n3lwEB7V5D/bNy/E++cXnJq9Hh/ApC2TyjWFzbVdUbQm+ont1rv0LbtxB - A/cT/oRjOk2mSH2fht6s8EdX+jGUHaenjMqBcJDjMN2LDW4RRj36IWZUS6/M7Z7k - 0lx+e7viwZuhf7Zrb+mHrmwxnb80XyVGB628iu1hSXXnLKleYfdWhsFi+l8alpvW - su3QnUehVSFXk2yt1q/SfwRD6GWf+wfP4bms17hlIltnzDxvVM7BMQ13M5goBAf8 - 030iuxtc2ATxEdId8/CJtaeF8pL9/LsY7ypqH+7Yc9DlLfeBqjjT3NvBLpziyfSc - BCGvkfqkdUOIwtwE8VD7DrMiTbyE6/cYxrldC2bw41v9CE4tNu0H5kIcoLDdtcBJ - UO3sTC+cz9txJhLnzAqZRCEfe6eLNfKKEuHELxyT8oxOQae4YELQ/PZ3piBx0vgH - EArq51KNJBAMHyxZgokmuaUKdo3YTV4Fs9Xk8wIDAQABoxAwDjAMBgNVHRMBAf8E - AjAAMA0GCSqGSIb3DQEBCwUAA4ICAQAMZdYBVxVhke3Ntfx2e2VFjYCfZ2RcMiL8 - fOUipPdtwaJ0GGwPtcLz5S/o8lOKtDlltlJWklUgoqFtV4E/W89OQcuORQ3lqmSd - 3Is19g05IsbRz9O2hb9b/QKqi+D2L+bT7Mw4mG6KCs5458kHFs5EwsQrJP8H/Sql - mHXpR6HztjXwGyMr4N5RtKCswWoKFBWPksb0B7uvgGVohspDVhs3QkJrWUBehdXu - anFVJDHbaKyRAdYpujcR84VWBRgN28qQ2f34XpBORmOQOtyMmHEK4XWqXTkMj8RA - yLU3oE0DUPiQgx5lNuQLDWCmyhG6OoMOgfNnncS/FAxveBPITvf/fKOrwmO2DKua - r65CiHSmEfThRTfU+u8TGEkMw/YuKyeyUEiJMEGL8k1w+SSav10jtTDZC8vIEMEZ - A6Ii1swONwfK8kHdAih3guHeLb+gxBdXNwUDuc0mwrSPc/qeZD31QvUz1p9fKmEv - wnswB+k9PRAEff/1u5C/Nnel0rH+A+Nr934Po3PDdUfvpyGcNVyc4NHyNRE9Ddrj - Qj/JKF4l6lIhyUGxKh5HMr49ye3MERvVSq4ET5ecCrVSMdBxGvG9uU3sms2SrvQY - YcSRmLwfWm2Eh4sC8GjMn0Ol0+AJFqvwFMbbSYaWPivQlKuyGKaP0M4wFZvGJ1zW - 2nGnFAHAzw== - -----END CERTIFICATE----- - Frontend.bind_address: 127.0.0.1 Frontend.bind_port: 8080 HTTPServer Context: Logging.filename: "%(Logging.path)/grr-http-server.log" -PrivateKeys.ca_key: | - -----BEGIN RSA PRIVATE KEY----- - MIIJJwIBAAKCAgEA0jVevcdl2QyDwSprgGCeuVluHTWQseSr6RQaVyjK42ez0GpZ - LZ2t3whNTVOTvzwa7u5A9UxlXa2dCWiX/QhamrmF0wEU5voCcyzQvTacsoQ6x1df - CbkTUwV4+XrtI90HeVSANTAR6ak/mc9h7bIWhZBBB0VYnvkrL41KYTrLL5uypQU/ - ZwCSmC9FN7VpfBUsVB2Nc2CbTbO4bVMagpzd2eXB758A624Dx4Pt08BF+UIKHzIE - 2Ih11OlX6jGxnbDCf5CO9GCP5kzzPGNk8AsQ90UcajcsUc7cgED39EXr0HCWiAIb - RvUv1VQ+VNOFhaJr1U3m7Ts2jI30wY5HfYKkhEb4MHhulTeQYNmeX5EXEExWBwy2 - vv/tw2+xDShInyc3K5z2a5x7KWryHWHZxdDAGmyrEDG7C67nm1H6gm1aEJ+EHXXq - KmwjYKCUGHMQ19fU9S/9quC2XHzQK3k9v/cNu0ch2iXE1MCnoV+G+QNYsgdca5km - ixFw5COeNgASjwLm+z1/h+C8onQV+Hxo4wiIrDKLpVd3PkYm0Y9mXmZQpPdekWaT - 4EIfG3LiEk5EHzO6xzRjPfVPZxDHBN0VTBtVb+x2RO0EbkrNwntoZxWquNoVORD9 - 7bFTgc/qGETa7yUmECRTOkpXF1/hWo/NXaSiLsPnGmiOjIv7mhheHuvEYzUCAwEA - AQKCAgBwUcgfy42mHkPAUTRD0ly9WQW3YcnA5BjfX7h7Xfaa7+xqroicESa4h60W - ZlQJ2MnjZTccWwfGuF+yiUq9D1uqVPsmtes/R9NLS2T11VqBIJpvrUXA4j1rHP94 - /q/7e7zknbwrr1XC7oZnXyJKaeAS2fOFOQ6TUzw6Glrl/Q5Yj+8ysc0g8nNiEdAA - ZlTI0l9vSqMsRTB2olMnR0JhDASWy2eG7AUHxy8JynqnrJM3DoxuAsIIGVIsw8oP - /yGSysICe3GHLpl1SySk3c8vXBpipXD7aCOsSsYTWaOjyECqjZ5Bai69CYHXkT4F - AUjVEOZhgVCk7gDFtYxUZ/vXGplghMpkb5de9uTGptraMmypEjUon4bOlpd62CDI - wHva4bry6H+O4kp8pVHOV0XO7N1964rA2+N2/wT9AIAFIq5th2OAO9xZbX1sQTeL - 6a6CsNrWnd2KkpwOs+RQqSd/iOpb/BSbxMS7nsFBP7N4mW+MLAbUCYvdIT4ixuVp - GZvn8SvYtNcr4H0DgEXcg3JkDAgWOIfd664zfGy8hxUg8Ob0huKK0IgVRryptpjJ - fiEWcXQ6RAaaww9di2Iimu7dfh7u2llUyYNy9Gw6dGtHfCWOe1COkMG8QdkwIxyJ - RFH7eq33bioFC+F2ZPl0fD0R+LILKK/sCjvL37SEJP8AAMR4+QKCAQEA6VXgu0qM - mjAfQ7jPoky3vTb1Jj7CmhzU7gk/Uczlw5lDsCKf+qyryFYJJlq0/U4JHZ1qehOF - cG0JPh31gg5IbFYdTvFDnDHxwmDVyzty9bjffXdb1BpB045S81x79kxOQHKUO0bY - 
4m9Pw9+QieziRokjhWlBEWYsmvQEIG44UqHSRMY/j1HRA2a2dlrU7FlfZW0FKblH - gA4dFac70SNTq8TeaFeSFt8eBZowKjRBX7x/6HKDKCVX5EaHDIN74EYw2XHB+DPd - 8YS8CCogDboIinBsQYscmufz5E3gNPwmZAvkon8YM34EnEaLAB7NZJYCh0ynMSpo - AosywHiH7xweZwKCAQEA5qBqmQr0WRj7XQKE03jWVm8R+6TmpV+UKnVPphzHmpQe - q50wuR3+QgyIMt5sO1q2SQKIS47/mGI9tAcV9QWv6mr1YZV0VKM2WllaxXekI3gg - zeElbAlqiU2CZOxFLynfQdZMPKEoxeYXB3Me+3WvAVN6fN2RQU0NcInpXD/0Kf3x - GXnOqKPTYBgogBcYOYr43zSKEzB3rtytuNMIv0HVxyA3DqYAmSgngv1JPehir4YB - /bJdekqOtTJMVEtlMohJdeJCv8SOemwq9Ean8Yw5DjE8cnTXdyR3iIh2k0otYtvl - 6PR3UOtGRVajA13tRWJV8fvz2euErIbZHQJuCPq4AwKCAQBErexZ9FVyRNvO+STU - ZrRmUzjRKwPojLf34GzszNyMOB5+R5LDG/PsIbbLvUMsk72HJABlMj3Cm7VuvS64 - OzACA07ZH5aA8Qpx5kLHcRYjUWkm6uzyf1AEzw2HaB9snYUi3xbWY8IO1CisRK+s - iTcI//Ceoh5u2p4iddSJHygg2lSjvZ6TtsdIswDd0Vp+vsefePleEJUFEiJpbzHi - Gv3TvzyfhbcQWFfj2kB3C656WWdkqeAE/wjhvgieHE1n9AEI37zyK4IWRrV5ybxp - jepZpUGYATRIPCHDf1CRB+7c38tKMRKUhXEh9nmPbYGTK7xOyrcjd5HpvcMQd+m6 - 7ZHdAoIBAEpJZczCQI7qgx0kkJBlnfTddhduuHSQYvOCMhO5tXnPOEnjKiyWwq2X - x89K4eYEhC7kR+6+swnsqx9wINx55n5F0aCLOZuBryJPIfP1Y8OhEEAAw8MXDWAI - vXWwvWBBxo44bvoglzeXs9dMd7Bb6fEtaIkL/ZvhK/ESGz8Bwq48BAtGtxCPJKkR - XwpTZhQy9ZNAIzGnLQYAQ10DbQ/eLvQjJljrk8nBq2iAGbV9Qzxyl/WWHJIlre4j - s357gq9SQwdbyFBpCdPZP8TLdZFSr5YoueXMSRMxhedOvZMYE6KCXn45MK25+zqe - e4e8G49761803wlU2bmQ/iJgA/2UdO0CggEAVK2NSNYQKm7sZ113LYA3YvBF46wU - kW/il7acKjE7Eo0EMLQ3YUYT1iuUI8LeupM0c1htjVdaC5QXeVLDZYNvlOjiaLo/ - R3ZyMxDzkHOnBGIy9e9xXtW3tpQSlOy8rpb3az2XrjwUMEv77PO5pG+rHZYLJGeI - 1iZnvNz85PRoA284B4Cxs8oin+mip4TKdp/2noF+vrWAG3HTOVxAqk3GCKUipXwI - 2lKE9xGFDbjr99noxkJKRDdVd8UTZg8ZPnr7VtDEutDIzDi7ikkBFYurGjTvSpF+ - rWUPICL/s+ez5PMfHkmkd+/uf3lym0BW/erSWrqzJvlA2MMlEpHx43kUVw== - -----END RSA PRIVATE KEY----- - -PrivateKeys.server_key: | - -----BEGIN RSA PRIVATE KEY----- - MIIJKAIBAAKCAgEAwUXBNzWSoEr88dGQqZWSwgJ+n/A/QQyNn/ZM57XsqI6IMO6p - lFmA+DZv2FkTTdniNPmhuL9mjWYA5yg4KYMbz5igOiBoF9RBeIm2/v2Sg65VFoyC - gJNgl3V34mpoDCHBYTi2A/OfoKeSQISbUfMHsYhPHdGfhjk8dEuMo7MxjrtfAO3Y - 4QtjTiE07eNdoRQkFtzF0m9oSaytJ95cBAe1eQ/2zcvxPvnF5yavR4fwKQtk8o1h - c21XVG0JvqJ7da79C27cQQP3E/6EYzpNpkh9n4berPBHV/oxlB2np4zKgXCQ4zDd - iw1uEUY9+iFmVEuvzO2e5NJcfnu74sGboX+2a2/ph65sMZ2/NF8lRgetvIrtYUl1 - 5yypXmH3VobBYvpfGpab1rLt0J1HoVUhV5Nsrdav0n8EQ+hln/sHz+G5rNe4ZSJb - Z8w8b1TOwTENdzOYKAQH/NN9IrsbXNgE8RHSHfPwibWnhfKS/fy7GO8qah/u2HPQ - 5S33gao409zbwS6c4sn0nAQhr5H6pHVDiMLcBPFQ+w6zIk28hOv3GMa5XQtm8ONb - /QhOLTbtB+ZCHKCw3bXASVDt7EwvnM/bcSYS58wKmUQhH3unizXyihLhxC8ck/KM - TkGnuGBC0Pz2d6YgcdL4BxAK6udSjSQQDB8sWYKJJrmlCnaN2E1eBbPV5PMCAwEA - AQKCAgAeAna929Ooj/w2kBOmQVNITJrcurEXqJtU+yl10QmuInODJYuvPTaJU+qJ - 7UrSC8LT9u7lgNKroesB+Xy+9VycH1bBr8Z57Ls9vCRt83GMgMU1exvIWxnkapjy - zxLYz2T3c5bPhkSC7YIIAo8bamEHb+LY/nOGo9x/MjvkLy7CutVFj6jdSKdiukU9 - qtAe8sGnyx/sTOAkkGtShXREK+5Bnfj0e3Y6EQ5pldghgzoHJX3HK6y1/4RP155r - u07wNvuTiuMoNTVoJVzpC2SDLT5URtzMfYKWZoLMPM6LdZD8CD53CF9d9/ffNsjo - zcbVz+q8JylE3mT+PrgtvWIs8WehgXv8Xvtvo9oDApBJ/0Zsm5Vo+mGrPSTn57M6 - 7z4gKBBVHvEzyExpZsu/H78ZUegI/Ma5JPKcBK0a+VHA5+FLiRCj/H7+wPyX2x2Y - fr2LB1OTFpGnqCsN/cdVPuctdmnAbuCuduGdpEY8/5UcmjltwdzwaVw2Djn4pKTi - 8HepsL5TZ0r6sAeW+uRtGicMDdGbBIRER+lQm/FQnNHDM7j02ERXYRC/6HnKEvDH - foL+Ja2Q4Yn3ys9+Su0JxG3j3wFFT7WAOL+SBarLJ+uFXJANI3vZHz47vRinfLlU - z+wp9X06xzhS8bg2XD77TbjNgRvxEinlezwDWHL3hxg2fExOAQKCAQEA+NbU+8bw - ID0udVfCFGCNFtwkfvLZ9Zw/3pzLyw96UzEIuSUMK9Lx5bjzSCNTNp0rQeYrY2ke - LvtJrAHLiUPx8xRGUUoJZlpFs1so0CYBwDDMxbkWaNvgYAmToBfBnJIRyi2VfQKD - CGT8GRSpQpko6yftqFYy2/4WvPDJlaR5MVb5kY5a2u3PlLQaZ/5Oy1hfRcrWE9xm - peswTqX82OVycbkO/qxTVx/5gC20/DEPLTin+2g4sdNqwH8pPJfmig0LnsIaJKt2 - 
SGQiUmL02MBGwV90qYtSjioP3d0VQctaLvspIU8ACuMk6BTXNDmt8S6km0mPB97F - 4hXfOHHCUGQwPQKCAQEAxtWR0XFdry1SejkS2rzJKfQkPispa8b0VZbaXnxjqspR - FAwOUFuh20BxYl7LlSDyXApPENhOMfDJRDznhkqTL5rvy1uNFxt/+kU3D9i483Ba - LuSuOD9K/SY8/HG0WUuMWTr4dWwkk/GZFrdwF+MqyKI7I7cWr8KAspOFt0BIlTjt - Gqdpib/GKnkOHN04owt5pEv1+01CxMAb6CAAa4JkiCPUEl5Uh3o2nR2nDtWQc/rc - y2EWif1Lv7szidA2j9Fsp8dKb/fOGZUZhmlpjVC+JQ9SICoqgq0KL3Vkjg/l4GyF - w7IHko9rUZ9VA3UURRls5pHqe8y+bhIjF78+CtIM7wKCAQBsAtxaGWUbqVLkLl6Y - 97vmQ1I4JHPFb7gticPP9Xz0ZWFS9CjPUPYc7+Xx1xuEpj4jkaQdt4AZhovY9MKD - Z5G7IH7RRCDGY8LDcntJtmWmJciMvqViiKZhKWcB463vp1u/dX/gzllQKH1g6Z36 - wX6IbEF1g0z6PtVh1+a4ZLcSWp1jt/Xp97XV50NbBoDvNQHNypiX8GBB/s9uJBCX - mSjPAjPnCaf3NnLXV3+qxCTBTUllED0juxAoVEny+kBghf8YP0qXxjFGhOh8+GUt - PHC5+RQrj8Ua3lkaxZ83euw/XlfhFGiBUU0wy/MJwUumV8etfAVwthsQ4suMtZxR - xttFAoIBABTrPUcqxS979jRzsr8eo49tZy6/PcFgEi67C0hrj9TVKkiQqCTeLx4x - Hny5+nM7HyR91SmxiDCK47HxMm6Xg/q7M0VS4Xov43wCMjPRmkvKY0KRvp6eUhZm - In5wvAe1AhQVNzSrZwHFplSUgg+RT9wB7XTpe1KMhRvEl4nbEofYkGGAgYMDkSbA - y7JPt/i3aVnWwA9rZn3qtETssP6enlMQaexwzjXersZC62ONJoB9QSOImGV2J7UJ - TffO1x60atkQB43WJXHdlOzmRDug5hBiF0LZDNXovKyXjjfABnBhGoWnQlKyEhlp - SPlvJO3MKf/sFB5oQRS1hcmmzjE3PjMCggEBALmcsjPJx4CCJVrwdg/1C/rm9tgS - 7GwAdyDvRRwoAiay9rdsi7z8HPu7FOqBEakCTSGcm1sxnYstvBJnxzXHbVfNKSg1 - bWfZ3v/TQT1DqWhchvh3F969kVXvtsvu9MzZfiG6dYdIE+9Xah93itbjYIJGqk6O - g8a1Sjjbw2I4rG6Fd/LkOVM52yblU9XIK2kNNHyR+gtf8XEouSODHMRQghSZThJV - BCsoAki3kAdM9LkAJWvymS8gobZri5nW8rMQioCVnIB3XXQziw3uT1a1MtqjMSdd - wkPliOT+lNsuHlTcnoW5O9P56XCNH68JXWYpVFTkeMKppNCNkyBqN5hbshE= - -----END RSA PRIVATE KEY----- - PrivateKeys.executable_signing_private_key: | -----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAzbuwYnKTTleU2F4zu1gI/BolzR74470j6wK7QrQp5b5Qkmfe diff --git a/grr/test/grr_response_test/test_data/win_hello.exe b/grr/test/grr_response_test/test_data/win_hello.exe old mode 100644 new mode 100755 index 5f70435c52..eb53381d42 Binary files a/grr/test/grr_response_test/test_data/win_hello.exe and b/grr/test/grr_response_test/test_data/win_hello.exe differ diff --git a/grr/test/grr_response_test/test_utils.py b/grr/test/grr_response_test/test_utils.py deleted file mode 100644 index b38c627d29..0000000000 --- a/grr/test/grr_response_test/test_utils.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python -"""Helper functions and classes for use by tests.""" - -import yaml - -from grr_response_client import comms -from grr_response_core.lib.rdfvalues import crypto as rdf_crypto - - -class PrivateKeyNotFoundException(Exception): - - def __init__(self): - super(PrivateKeyNotFoundException, - self).__init__("Private key not found in config file.") - - -def GetClientId(writeback_file): - """Given the path to a client's writeback file, returns its client id.""" - with open(writeback_file) as f: - parsed_yaml = yaml.safe_load(f.read()) or {} - serialized_pkey = parsed_yaml.get("Client.private_key", None) - if serialized_pkey is None: - raise PrivateKeyNotFoundException - pkey = rdf_crypto.RSAPrivateKey(serialized_pkey) - client_urn = comms.ClientCommunicator(private_key=pkey).common_name - return client_urn.Basename() diff --git a/grr/test_lib/action_mocks.py b/grr/test_lib/action_mocks.py index a3dd734c7d..fd902692da 100644 --- a/grr/test_lib/action_mocks.py +++ b/grr/test_lib/action_mocks.py @@ -23,13 +23,16 @@ from grr_response_core import config from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client as rdf_client +from grr_response_core.lib.rdfvalues import client_action as rdf_client_action from 
grr_response_core.lib.rdfvalues import client_fs as rdf_client_fs from grr_response_core.lib.rdfvalues import client_network as rdf_client_network from grr_response_core.lib.rdfvalues import client_stats as rdf_client_stats from grr_response_core.lib.rdfvalues import cloud as rdf_cloud from grr_response_core.lib.rdfvalues import flows as rdf_flows +from grr_response_core.lib.rdfvalues import mig_client_action from grr_response_core.lib.rdfvalues import paths as rdf_paths from grr_response_core.lib.rdfvalues import protodict as rdf_protodict +from grr_response_proto import jobs_pb2 from grr_response_server import action_registry from grr_response_server import client_fixture from grr_response_server import server_stubs @@ -575,6 +578,61 @@ def EnumerateUsers( yield rdf_client.User(username="user2") yield rdf_client.User(username="user3") + def ExecuteCommand( + self, + args: rdf_client_action.ExecuteRequest, + ) -> Iterator[rdf_client_action.ExecuteResponse]: + """Returns fake replies for the `ExecuteCommand` action.""" + args = mig_client_action.ToProtoExecuteRequest(args) + + response = jobs_pb2.ExecuteResponse() + response.request.MergeFrom(args) + response.exit_status = 0 + + if args.cmd == "/usr/sbin/dmidecode": + # We only provide minimal output so that the parser does not fail. + response.stdout = """\ +BIOS Information + Vendor: Google + Version: Google + Release Date: 01/25/2024 + +System Information + Manufacturer: Google + Product Name: Google Compute Engine +""".encode("utf-8") + elif args.cmd == "/usr/sbin/system_profiler": + # We only provide minimal output so that the parser does not fail. + response.stdout = """\ + + + + + + _items + + + boot_rom_version + 10151.101.3 + chip_type + Apple M1 Pro + machine_model + MacBookPro18,3 + machine_name + MacBook Pro + model_number + Z15G000PCB/A + + + + + +""".encode("utf-8") + else: + raise RuntimeError(f"Unknown command: {args.cmd}") + + yield mig_client_action.ToRDFExecuteResponse(response) + def GetClientInfo(self, _): self.response_count += 1 return [ @@ -605,6 +663,17 @@ def WmiQuery(self, query): if query.query == u"SELECT * FROM Win32_LogicalDisk": self.response_count += 1 return client_fixture.WMI_SAMPLE + elif "FROM Win32_ComputerSystemProduct" in query.query: + self.response_count += 1 + return [ + rdf_protodict.Dict({ + "IdentifyingNumber": "2S42F1S3320HFN2179FV", + "Name": "42F1S3320H", + "Vendor": "LEVELHO", + "Version": "NumbBox Y1337", + "Caption": "Computer System Product", + }) + ] elif query.query.startswith("Select * " "from Win32_NetworkAdapterConfiguration"): self.response_count += 1 diff --git a/grr/test_lib/fixture_test_lib.py b/grr/test_lib/fixture_test_lib.py index be6b5bbff4..4919fa9b8c 100644 --- a/grr/test_lib/fixture_test_lib.py +++ b/grr/test_lib/fixture_test_lib.py @@ -105,10 +105,13 @@ def CreateClientObject(self, vfs_fixture): offset=0, size=len(content), blob_id=bytes(blob_id) ) hash_id = file_store.AddFileWithUnknownHash( - db.ClientPath.FromPathInfo(self.client_id, path_info), [blob_ref]) + db.ClientPath.FromPathInfo(self.client_id, path_info), [blob_ref] + ) path_info.hash_entry.num_bytes = len(content) path_info.hash_entry.sha256 = hash_id.AsBytes() if path_info is not None: data_store.REL_DB.WritePathInfos( - client_id=self.client_id, path_infos=[path_info]) + client_id=self.client_id, + path_infos=[mig_objects.ToProtoPathInfo(path_info)], + ) diff --git a/grr/test_lib/flow_test_lib.py b/grr/test_lib/flow_test_lib.py index bdf911265b..bc16bd64f1 100644 --- a/grr/test_lib/flow_test_lib.py +++ 
b/grr/test_lib/flow_test_lib.py @@ -4,7 +4,7 @@ import logging import re import sys -from typing import ContextManager, Iterable, List, Optional, Pattern, Text, Type, Union +from typing import Callable, ContextManager, Iterable, List, Optional, Pattern, Text, Type, Union from unittest import mock from grr_response_client import actions @@ -15,6 +15,7 @@ from grr_response_core.lib.rdfvalues import flows as rdf_flows from grr_response_core.lib.rdfvalues import structs as rdf_structs from grr_response_core.lib.util import precondition +from grr_response_proto import flows_pb2 from grr_response_proto import tests_pb2 from grr_response_server import action_registry from grr_response_server import data_store @@ -317,7 +318,7 @@ def _PushHandlerMessage(self, message): handler_cls().ProcessMessages([handler_request]) - def PushToStateQueue(self, message, **kw): + def PushToStateQueue(self, message: rdf_flows.GrrMessage, **kw): """Push given message to the state queue.""" # Assume the client is authorized message.auth_state = rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED @@ -331,8 +332,14 @@ def PushToStateQueue(self, message, **kw): self._PushHandlerMessage(message) return - data_store.REL_DB.WriteFlowResponses( - [rdf_flow_objects.FlowResponseForLegacyResponse(message)]) + message = rdf_flow_objects.FlowResponseForLegacyResponse(message) + if isinstance(message, rdf_flow_objects.FlowResponse): + message = mig_flow_objects.ToProtoFlowResponse(message) + if isinstance(message, rdf_flow_objects.FlowStatus): + message = mig_flow_objects.ToProtoFlowStatus(message) + if isinstance(message, rdf_flow_objects.FlowIterator): + message = mig_flow_objects.ToProtoFlowIterator(message) + data_store.REL_DB.WriteFlowResponses([message]) def Next(self): """Emulates execution of a single client action request. @@ -434,12 +441,18 @@ def StartAndRunFlow(flow_cls, flow_id = flow.StartFlow(flow_cls=flow_cls, client_id=client_id, **kwargs) if check_flow_errors: - rdf_flow = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - if rdf_flow.flow_state == rdf_flow.FlowState.ERROR: + flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + if flow_obj.flow_state == flows_pb2.Flow.FlowState.ERROR: raise RuntimeError( "Flow %s on %s raised an error in state %s. \nError message: %s\n%s" - % (flow_id, client_id, rdf_flow.flow_state, rdf_flow.error_message, - rdf_flow.backtrace)) + % ( + flow_id, + client_id, + flow_obj.flow_state, + flow_obj.error_message, + flow_obj.backtrace, + ) + ) RunFlow( client_id, @@ -457,7 +470,12 @@ def __init__(self, *args, **kw): super().__init__(*args, **kw) self.processed_flows = [] - def ProcessFlow(self, flow_processing_request): + def ProcessFlow( + self, + flow_processing_request: Callable[ + [flows_pb2.FlowProcessingRequest], None + ], + ) -> None: key = (flow_processing_request.client_id, flow_processing_request.flow_id) self.processed_flows.append(key) super().ProcessFlow(flow_processing_request) @@ -496,20 +514,33 @@ def RunFlow(client_id, # Run the client and worker until nothing changes any more. while True: client_processed = client_mock.Next() + data_store.REL_DB.delegate.WaitUntilNoFlowsToProcess(timeout=10) worker_processed = test_worker.ResetProcessedFlows() all_processed_flows.update(worker_processed) - if client_processed == 0 and not worker_processed: + # Exit the loop if no client actions were processed, nothing was processed + # on the worker and there are no pending flow processing requests. 
+ if ( + client_processed == 0 + and not worker_processed + and not data_store.REL_DB.ReadFlowProcessingRequests() + ): break if check_flow_errors: for client_id, flow_id in all_processed_flows: - rdf_flow = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - if rdf_flow.flow_state != rdf_flow.FlowState.FINISHED: + flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + if flow_obj.flow_state != flows_pb2.Flow.FlowState.FINISHED: raise RuntimeError( - "Flow %s on %s completed in state %s (error message: %s %s)" % - (flow_id, client_id, rdf_flow.flow_state, rdf_flow.error_message, - rdf_flow.backtrace)) + "Flow %s on %s completed in state %s (error message: %s %s)" + % ( + flow_id, + client_id, + flow_obj.flow_state, + flow_obj.error_message, + flow_obj.backtrace, + ) + ) return flow_id finally: @@ -577,12 +608,14 @@ def FinishAllFlowsOnClient(client_id, **kwargs): def GetFlowState(client_id, flow_id): - rdf_flow = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + proto_flow = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + rdf_flow = mig_flow_objects.ToRDFFlow(proto_flow) return rdf_flow.persistent_data def GetFlowObj(client_id, flow_id): - rdf_flow = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + proto_flow = data_store.REL_DB.ReadFlowObject(client_id, flow_id) + rdf_flow = mig_flow_objects.ToRDFFlow(proto_flow) return rdf_flow @@ -628,9 +661,9 @@ def FlowResultMetadataOverride( def MarkFlowAsFinished(client_id: str, flow_id: str) -> None: """Marks the given flow as finished without executing it.""" - flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - flow_obj.flow_state = flow_obj.FlowState.FINISHED - data_store.REL_DB.WriteFlowObject(flow_obj) + data_store.REL_DB.UpdateFlow( + client_id, flow_id, flow_state=flows_pb2.Flow.FlowState.FINISHED + ) def MarkFlowAsFailed(client_id: str, @@ -638,10 +671,14 @@ def MarkFlowAsFailed(client_id: str, error_message: Optional[str] = None) -> None: """Marks the given flow as finished without executing it.""" flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - flow_obj.flow_state = flow_obj.FlowState.ERROR - if error_message is not None: + flow_obj.flow_state = flows_pb2.Flow.FlowState.ERROR + if not flow_obj.HasField("error_message") and error_message: flow_obj.error_message = error_message - data_store.REL_DB.WriteFlowObject(flow_obj) + data_store.REL_DB.UpdateFlow( + client_id, + flow_id, + flow_obj, + ) def ListAllFlows(client_id: str) -> List[rdf_flow_objects.Flow]: diff --git a/grr/test_lib/hunt_test_lib.py b/grr/test_lib/hunt_test_lib.py index 74742efa1d..1b5092aedf 100644 --- a/grr/test_lib/hunt_test_lib.py +++ b/grr/test_lib/hunt_test_lib.py @@ -352,7 +352,7 @@ def AddErrorToHunt(self, hunt_id, client_id, message, backtrace): flow_id = self._EnsureClientHasHunt(client_id, hunt_id) flow_obj = data_store.REL_DB.ReadFlowObject(client_id, flow_id) - flow_obj.flow_state = flow_obj.FlowState.ERROR + flow_obj.flow_state = flows_pb2.Flow.FlowState.ERROR flow_obj.error_message = message flow_obj.backtrace = backtrace data_store.REL_DB.UpdateFlow(client_id, flow_id, flow_obj=flow_obj) @@ -366,7 +366,10 @@ def GetHuntResults(self, hunt_id): Returns: List with hunt results payloads. 
""" - return data_store.REL_DB.ReadHuntResults(hunt_id, 0, sys.maxsize) + return [ + mig_flow_objects.ToRDFFlowResult(r) + for r in data_store.REL_DB.ReadHuntResults(hunt_id, 0, sys.maxsize) + ] class DummyHuntOutputPlugin(output_plugin.OutputPlugin): diff --git a/grr/test_lib/test_lib.py b/grr/test_lib/test_lib.py index bc501f20f9..3832a2a886 100644 --- a/grr/test_lib/test_lib.py +++ b/grr/test_lib/test_lib.py @@ -22,7 +22,6 @@ from grr_response_core.lib import utils from grr_response_core.lib.rdfvalues import client as rdf_client from grr_response_core.lib.rdfvalues import client_network as rdf_client_network -from grr_response_core.lib.rdfvalues import crypto as rdf_crypto from grr_response_core.lib.util import cache from grr_response_core.lib.util import precondition from grr_response_core.lib.util import temp @@ -168,7 +167,6 @@ def SetupClient( description=None, users=None, memory_size=None, - add_cert=True, ): """Prepares a test client mock to be used. @@ -188,14 +186,12 @@ def SetupClient( description: string users: list of rdf_client.User objects. memory_size: bytes - add_cert: boolean Returns: the client_id: string """ client = self._SetupTestClientObject( client_nr, - add_cert=add_cert, arch=arch, fqdn=fqdn, install_time=install_time, @@ -256,7 +252,6 @@ def _TestInterfaces(self, client_nr): def _SetupTestClientObject( self, client_nr, - add_cert=True, arch="x86_64", fqdn=None, install_time=None, @@ -304,14 +299,8 @@ def _SetupTestClientObject( client.memory_size = memory_size ping = ping or rdfvalue.RDFDatetime.Now() - if add_cert: - cert = self.ClientCertFromPrivateKey(config.CONFIG["Client.private_key"]) - else: - cert = None - data_store.REL_DB.WriteClientMetadata( - client_id, last_ping=ping, certificate=cert - ) + data_store.REL_DB.WriteClientMetadata(client_id, last_ping=ping) proto_client = mig_objects.ToProtoClientSnapshot(client) data_store.REL_DB.WriteClientSnapshot(proto_client) @@ -328,12 +317,6 @@ def AddClientLabel(self, client_id, owner, name): data_store.REL_DB.AddClientLabels(client_id, owner, [name]) client_index.ClientIndex().AddClientLabels(client_id, [name]) - def ClientCertFromPrivateKey(self, private_key): - common_name = rdf_client.ClientURN.FromPrivateKey(private_key) - csr = rdf_crypto.CertificateSigningRequest( - common_name=common_name, private_key=private_key) - return rdf_crypto.RDFX509Cert.ClientCertFromCSR(csr) - class ConfigOverrider(object): """A context to temporarily change config options.""" diff --git a/grr/test_lib/timeline_test_lib.py b/grr/test_lib/timeline_test_lib.py index d631c50a08..b5f7f0fa24 100644 --- a/grr/test_lib/timeline_test_lib.py +++ b/grr/test_lib/timeline_test_lib.py @@ -5,7 +5,6 @@ from typing import Sequence from typing import Text -from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import timeline as rdf_timeline from grr_response_server import data_store from grr_response_server.flows.general import timeline @@ -37,9 +36,8 @@ def WriteTimeline( flow_obj.flow_id = flow_id flow_obj.client_id = client_id flow_obj.flow_class_name = timeline.TimelineFlow.__name__ - flow_obj.create_time = rdfvalue.RDFDatetime.Now() flow_obj.parent_hunt_id = hunt_id - data_store.REL_DB.WriteFlowObject(flow_obj) + data_store.REL_DB.WriteFlowObject(mig_flow_objects.ToProtoFlow(flow_obj)) blobs = list(rdf_timeline.TimelineEntry.SerializeStream(iter(entries))) blob_ids = data_store.BLOBS.WriteBlobsWithUnknownHashes(blobs) diff --git a/grr/test_lib/vfs_test_lib.py b/grr/test_lib/vfs_test_lib.py index 
c929a0d394..724b6fb7de 100644 --- a/grr/test_lib/vfs_test_lib.py +++ b/grr/test_lib/vfs_test_lib.py @@ -569,7 +569,9 @@ def CreateFile(client_path, content=b""): path_info.hash_entry.sha256 = hash_id.AsBytes() path_info.stat_entry = stat_entry - data_store.REL_DB.WritePathInfos(client_path.client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + client_path.client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) def CreateDirectory(client_path): @@ -593,7 +595,9 @@ def CreateDirectory(client_path): path_info.stat_entry = stat_entry path_info.directory = True - data_store.REL_DB.WritePathInfos(client_path.client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + client_path.client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) def GenerateBlobRefs( @@ -653,7 +657,9 @@ def CreateFileWithBlobRefsAndData( path_type=client_path.path_type, components=client_path.components ) path_info.hash_entry.sha256 = hash_id.AsBytes() - data_store.REL_DB.WritePathInfos(client_path.client_id, [path_info]) + data_store.REL_DB.WritePathInfos( + client_path.client_id, [mig_objects.ToProtoPathInfo(path_info)] + ) class VfsTestCase(absltest.TestCase): diff --git a/travis/build_server_deb.sh b/travis/build_server_deb.sh deleted file mode 100755 index cf4c551e69..0000000000 --- a/travis/build_server_deb.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -set -ex - -function create_changelog() { - if [[ -f debian/changelog ]]; then - echo "Replacing debian/changelog with new changelog." - rm debian/changelog - fi - pyscript=" -import configparser -config = configparser.ConfigParser() -config.read('version.ini') -print('%s.%s.%s-%s' % ( - config.get('Version', 'major'), - config.get('Version', 'minor'), - config.get('Version', 'revision'), - config.get('Version', 'release'))) -" - deb_version="$(python3 -c "${pyscript}")" - debchange --create \ - --newversion "${deb_version}" \ - --package grr-server \ - --urgency low \ - --controlmaint \ - --distribution unstable \ - "Built by GitHub Actions at ${GITHUB_SHA}" -} - -# Sets environment variables to be used by debhelper. -function export_build_vars() { - # Note that versions for the packages listed here can differ. - export LOCAL_DEB_PYINDEX="${PWD}/local_pypi" - export API_SDIST="$(ls local_pypi | grep -e 'grr-api-client-.*\.zip')" - export CLIENT_BUILDER_SDIST="$(ls local_pypi | grep -e 'grr-response-client-builder.*\.zip')" - export TEMPLATES_SDIST="$(ls local_pypi | grep -e 'grr-response-templates-.*\.zip')" - export SERVER_SDIST="$(ls local_pypi | grep -e 'grr-response-server-.*\.zip')" -} - -create_changelog -export_build_vars -rm -f ../grr-server_*.tar.gz -rm -rf gcs_upload_dir -dpkg-buildpackage -us -uc -mkdir gcs_upload_dir && cp ../grr-server_* gcs_upload_dir diff --git a/travis/install_client_builder.sh b/travis/install_client_builder.sh index bd19319435..22695ef7f8 100755 --- a/travis/install_client_builder.sh +++ b/travis/install_client_builder.sh @@ -6,7 +6,8 @@ set -e source "${HOME}/INSTALL/bin/activate" -pip install --upgrade pip wheel six setuptools +pip install -r build_requirements.txt +pip install --upgrade six # Get around a Travis bug: https://github.com/travis-ci/travis-ci/issues/8315#issuecomment-327951718 unset _JAVA_OPTIONS @@ -18,14 +19,14 @@ unset _JAVA_OPTIONS # Proto package. cd grr/proto python setup.py sdist -pip install ./dist/grr-response-proto-*.tar.gz +pip install ./dist/grr_response_proto-*.tar.gz cd - # Base package, grr-response-core, depends on grr-response-proto. 
# Running sdist first since it accepts --no-sync-artifacts flag. cd grr/core python setup.py sdist --no-sync-artifacts -pip install ./dist/grr-response-core-*.tar.gz +pip install ./dist/grr_response_core-*.tar.gz cd - # Depends on grr-response-core. @@ -34,7 +35,7 @@ cd - # only gets copied during sdist step. cd grr/client python setup.py sdist -pip install ./dist/grr-response-client-*.tar.gz +pip install ./dist/grr_response_client-*.tar.gz cd - # Depends on grr-response-client. @@ -43,5 +44,5 @@ cd - # only gets copied during sdist step. cd grr/client_builder python setup.py sdist -pip install ./dist/grr-response-client-builder-*.tar.gz +pip install ./dist/grr_response_client_builder-*.tar.gz cd -
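Editor's note on the sdist filename changes in the install script above: recent setuptools releases write sdist archives using the PEP 625 normalized project name, so dashed names such as grr-response-proto now produce grr_response_proto-<version>.tar.gz, which is why the pip install globs switched from dashes to underscores. Below is a minimal, illustrative sketch of that normalization; it is not part of the patch, and the helper name sdist_base_name is made up for illustration.

#!/usr/bin/env python
"""Illustrative only: PEP 503/625-style name normalization for sdist filenames."""
import re


def sdist_base_name(project_name: str) -> str:
  # Runs of '-', '_' and '.' collapse to a single '_' and the result is
  # lowercased, matching how modern setuptools names sdist archives.
  return re.sub(r"[-_.]+", "_", project_name).lower()


assert sdist_base_name("grr-response-proto") == "grr_response_proto"
assert sdist_base_name("grr-response-client-builder") == "grr_response_client_builder"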