Skip to content

Commit

Permalink
A number of changes:
Browse files Browse the repository at this point in the history
* Bumped the version to 3.4.3.1
* ListParsedResults is now exposed in the API client.
* Proper CPU accounting in the sandboxing code.
* More work on artifact collection (UIv2).
* Approval process fixes (UIv2).
  • Loading branch information
mbushkov committed May 10, 2021
1 parent 118a633 commit 6567156
Show file tree
Hide file tree
Showing 32 changed files with 732 additions and 105 deletions.
6 changes: 6 additions & 0 deletions api_client/python/grr_api_client/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def ListResults(self) -> utils.ItemsIterator[FlowResult]:
items = self._context.SendIteratorRequest("ListFlowResults", args)
return utils.MapItemsIterator(lambda data: FlowResult(data=data), items)

def ListParsedResults(self) -> utils.ItemsIterator[FlowResult]:
args = flow_pb2.ApiListParsedFlowResultsArgs(
client_id=self.client_id, flow_id=self.flow_id)
items = self._context.SendIteratorRequest("ListParsedFlowResults", args)
return utils.MapItemsIterator(lambda data: FlowResult(data=data), items)

def GetExportedResultsArchive(self, plugin_name) -> utils.BinaryChunkIterator:
args = flow_pb2.ApiGetExportedFlowResultsArgs(
client_id=self.client_id, flow_id=self.flow_id, plugin_name=plugin_name)
Expand Down
48 changes: 35 additions & 13 deletions grr/client/grr_response_client/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import logging
import pdb
import traceback
from typing import NamedTuple

from absl import flags

import psutil

from grr_response_client import client_utils
from grr_response_client.unprivileged import communication
from grr_response_core import config
from grr_response_core.lib import rdfvalue
from grr_response_core.lib import utils
Expand All @@ -40,6 +42,35 @@ class RuntimeExceededError(Error):
"""Exceeded the maximum allowed runtime."""


class _CpuUsed(NamedTuple):
cpu_time: float
sys_time: float


class _CpuTimes:
"""Accounting of used CPU time."""

def __init__(self):
self.proc = psutil.Process()
self.cpu_start = self.proc.cpu_times()
self.unprivileged_cpu_start = communication.TotalServerCpuTime()
self.unprivileged_sys_start = communication.TotalServerSysTime()

@property
def cpu_used(self) -> _CpuUsed:
end = self.proc.cpu_times()
unprivileged_cpu_end = communication.TotalServerCpuTime()
unprivileged_sys_end = communication.TotalServerSysTime()
return _CpuUsed((end.user - self.cpu_start.user + unprivileged_cpu_end -
self.unprivileged_cpu_start),
(end.system - self.cpu_start.system + unprivileged_sys_end -
self.unprivileged_sys_start))

@property
def total_cpu_used(self) -> float:
return sum(self.cpu_used)


class ActionPlugin(object):
"""Baseclass for plugins.
Expand Down Expand Up @@ -93,8 +124,7 @@ def __init__(self, grr_worker=None):
self._last_gc_run = rdfvalue.RDFDatetime.Now()
self._gc_frequency = rdfvalue.Duration.From(
config.CONFIG["Client.gc_frequency"], rdfvalue.SECONDS)
self.proc = psutil.Process()
self.cpu_start = self.proc.cpu_times()
self.cpu_times = _CpuTimes()
self.cpu_limit = rdf_flows.GrrMessage().cpu_limit
self.start_time = None
self.runtime_limit = None
Expand Down Expand Up @@ -139,7 +169,7 @@ def Execute(self, message):
raise RuntimeError("Message for %s was not Authenticated." %
self.message.name)

self.cpu_start = self.proc.cpu_times()
self.cpu_times = _CpuTimes()
self.cpu_limit = self.message.cpu_limit

if getattr(flags.FLAGS, "debug_client_actions", False):
Expand All @@ -153,9 +183,7 @@ def Execute(self, message):

# Ensure we always add CPU usage even if an exception occurred.
finally:
used = self.proc.cpu_times()
self.cpu_used = (used.user - self.cpu_start.user,
used.system - self.cpu_start.system)
self.cpu_used = self.cpu_times.cpu_used
self.status.runtime_us = rdfvalue.RDFDatetime.Now() - self.start_time

except NetworkBytesExceededError as e:
Expand Down Expand Up @@ -306,13 +334,7 @@ def Progress(self):

self.grr_worker.Heartbeat()

user_start = self.cpu_start.user
system_start = self.cpu_start.system
cpu_times = self.proc.cpu_times()
user_end = cpu_times.user
system_end = cpu_times.system

used_cpu = user_end - user_start + system_end - system_start
used_cpu = self.cpu_times.total_cpu_used

if used_cpu > self.cpu_limit:
raise CPUExceededError("Action exceeded cpu limit.")
Expand Down
91 changes: 65 additions & 26 deletions grr/client/grr_response_client/client_actions/action_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import unicode_literals

import collections
import contextlib
import os
import platform
import stat
Expand All @@ -18,6 +19,7 @@
from grr_response_client import actions
from grr_response_client import client_utils
from grr_response_client.client_actions import standard
from grr_response_client.unprivileged import communication
from grr_response_core.lib import rdfvalue
from grr_response_core.lib import utils
from grr_response_core.lib.rdfvalues import client as rdf_client
Expand Down Expand Up @@ -160,16 +162,57 @@ def testDoesNotRaiseForZeroRuntimeLimit(self):
self.assertEqual(action.SendReply.call_count, 1)
self.assertEqual(action.SendReply.call_args[0][0].status, "OK")

def testCPULimit(self):
received_messages = []

class MockWorker(object):

def Heartbeat(self):
pass
def testCPUAccounting(self):
with contextlib.ExitStack() as stack:
server_cpu_time = 1.0
server_sys_time = 1.1
stack.enter_context(
mock.patch.object(communication, "TotalServerCpuTime",
lambda: server_cpu_time))
stack.enter_context(
mock.patch.object(communication, "TotalServerSysTime",
lambda: server_sys_time))

process_cpu_time = 1.2
process_sys_time = 1.3

class FakeProcess(object):

def __init__(self, pid=None):
pass

def cpu_times(self): # pylint: disable=invalid-name
return collections.namedtuple("pcputimes",
["user", "system"])(process_cpu_time,
process_sys_time)

stack.enter_context(mock.patch.object(psutil, "Process", FakeProcess))

class _ProgressAction(ProgressAction):

def Run(self, *args):
super().Run(*args)
nonlocal server_cpu_time, server_sys_time
server_cpu_time = 42.0
server_sys_time = 43.0
nonlocal process_cpu_time, process_sys_time
process_cpu_time = 10.0
process_sys_time = 11.0

message = rdf_flows.GrrMessage(name="ProgressAction", runtime_limit_us=0)
worker = mock.MagicMock()
action = _ProgressAction(worker)
action.SendReply = mock.MagicMock()
action.Execute(message)
self.assertEqual(action.SendReply.call_count, 1)
self.assertAlmostEqual(
action.SendReply.call_args[0][0].cpu_time_used.user_cpu_time,
42.0 - 1.0 + 10.0 - 1.2)
self.assertAlmostEqual(
action.SendReply.call_args[0][0].cpu_time_used.system_cpu_time,
43.0 - 1.1 + 11.0 - 1.3)

def SendClientAlert(self, msg):
received_messages.append(msg)
def testCPULimit(self):

class FakeProcess(object):

Expand All @@ -181,26 +224,22 @@ def __init__(self, unused_pid=None):
def cpu_times(self): # pylint: disable=g-bad-name
return self.pcputimes(*self.times.pop(0))

results = []

def MockSendReply(unused_self, reply=None, **kwargs):
results.append(reply or rdf_client.LogMessage(**kwargs))

message = rdf_flows.GrrMessage(name="ProgressAction", cpu_limit=3600)

action_cls = ProgressAction
with utils.MultiStubber((psutil, "Process", FakeProcess),
(action_cls, "SendReply", MockSendReply)):

action = action_cls(grr_worker=MockWorker())
with mock.patch.object(psutil, "Process", FakeProcess):
worker = mock.MagicMock()
action = ProgressAction(grr_worker=worker)
message = rdf_flows.GrrMessage(name="ProgressAction", cpu_limit=3600)
action.Execute(message)

self.assertIn("Action exceeded cpu limit.", results[0].error_message)
self.assertIn("CPUExceededError", results[0].error_message)
self.assertEqual("CPU_LIMIT_EXCEEDED", results[0].status)
self.assertEqual(worker.SendReply.call_count, 1)
reply = worker.SendReply.call_args[0][0]

self.assertIn("Action exceeded cpu limit.", reply.error_message)
self.assertIn("CPUExceededError", reply.error_message)
self.assertEqual("CPU_LIMIT_EXCEEDED", reply.status)

self.assertLen(received_messages, 1)
self.assertEqual(received_messages[0], "Cpu limit exceeded.")
self.assertEqual(worker.SendClientAlert.call_count, 1)
self.assertEqual(worker.SendClientAlert.call_args[0][0],
"Cpu limit exceeded.")

@unittest.skipIf(platform.system() == "Windows",
"os.statvfs is not available on Windows")
Expand Down
51 changes: 50 additions & 1 deletion grr/client/grr_response_client/unprivileged/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import platform
import struct
import subprocess
from typing import NamedTuple, Callable, Optional, List, BinaryIO
from typing import NamedTuple, Callable, Optional, List, BinaryIO, Set
import psutil


class Transport(abc.ABC):
Expand Down Expand Up @@ -236,6 +237,11 @@ class SubprocessServer(Server):
channel.
"""

_past_instances_total_cpu_time = 0.0
_past_instances_total_sys_time = 0.0

_started_instances: Set["SubprocessServer"] = set()

def __init__(self,
args_factory: ArgsFactory,
extra_file_descriptors: Optional[List[FileDescriptor]] = None):
Expand Down Expand Up @@ -291,7 +297,15 @@ def Start(self) -> None:
pass_fds=[input_r_fd, output_w_fd] + extra_fds,
)

SubprocessServer._started_instances.add(self)

def Stop(self) -> None:

if self in self._started_instances:
SubprocessServer._started_instances.remove(self)
SubprocessServer._past_instances_total_cpu_time += self.cpu_time
SubprocessServer._past_instances_total_sys_time += self.sys_time

if self._process is not None:
self._process.kill()
self._process.wait()
Expand All @@ -306,6 +320,33 @@ def Connect(self) -> Connection:
transport = PipeTransport(self._output_r, self._input_w)
return Connection(transport)

@classmethod
def TotalCpuTime(cls) -> float:
return SubprocessServer._past_instances_total_cpu_time + sum(
[instance.cpu_time for instance in cls._started_instances])

@classmethod
def TotalSysTime(cls) -> float:
return SubprocessServer._past_instances_total_sys_time + sum(
[instance.sys_time for instance in cls._started_instances])

@property
def cpu_time(self) -> float:
return self._psutil_process.cpu_times().user

@property
def sys_time(self) -> float:
return self._psutil_process.cpu_times().system

@property
def _psutil_process(self) -> psutil.Process:
if self._process_win is not None:
return psutil.Process(pid=self._process_win.pid)
elif self._process is not None:
return psutil.Process(pid=self._process.pid)
else:
raise ValueError("Can't determine process.")


def _EnterSandbox(user: str, group: str) -> None:
if platform.system() == "Linux" or platform.system() == "Darwin":
Expand Down Expand Up @@ -335,3 +376,11 @@ def Main(channel: Channel, connection_handler: ConnectionHandler, user: str,
transport = PipeTransport(pipe_input, pipe_output)
connection = Connection(transport)
connection_handler(connection)


def TotalServerCpuTime() -> float:
return SubprocessServer.TotalCpuTime()


def TotalServerSysTime() -> float:
return SubprocessServer.TotalSysTime()

0 comments on commit 6567156

Please sign in to comment.