Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py-tx] Align tx match --hash format with the output of tx hash #1289

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
124 changes: 112 additions & 12 deletions python-threatexchange/threatexchange/cli/match_cmd.py
Expand Up @@ -167,19 +167,21 @@ def execute(self, settings: CLISettings) -> None:
if self.as_hashes:
types = (BytesHasher, TextHasher, FileHasher)
signal_types = [s for s in signal_types if issubclass(s, types)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(In reference to the comments below) This list of signal types has already been narrowed down to the valid set given the input.

It might be that we should consider moving --hashes to a top-level argument (`tx match photo|video|hash`), and allowing cross-matching might make more sense, but now we've exceeded the scope of the PR.

if self.as_hashes and len(signal_types) > 1:
raise CommandError(
f"Error: '{self.content_type.get_name()}' supports more than one SignalType."
" for '--hashes' also use '--only-signal' to specify one of "
f"{[s.get_name() for s in signal_types]}",
2,
)

logging.info(
"Signal types that apply: %s",
", ".join(s.get_name() for s in signal_types) or "None!",
)

if self.as_hashes:
hashes_grouped_by_prefix: t.Dict[
t.Optional[t.Type[SignalType]], t.Set[str]
] = dict()
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
# Infer the signal types from the prefixes (None is used as key for hashes with no prefix)
for path in self.files:
_group_hashes_by_prefix(path, settings, hashes_grouped_by_prefix)
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
# Validate the SignalType and append the None prefixes to the correct SignalType
self.validate_hashes_signal_type(hashes_grouped_by_prefix, signal_types)
signal_types = [key for key in hashes_grouped_by_prefix.keys() if key]
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
indices: t.List[t.Tuple[t.Type[SignalType], SignalTypeIndex]] = []
for s_type in signal_types:
index = settings.index.load(s_type)
Expand All @@ -196,11 +198,14 @@ def execute(self, settings: CLISettings) -> None:
for s_type, index in indices:
seen = set() # TODO - maybe take the highest certainty?
if self.as_hashes:
results = _match_hashes(path, s_type, index)
results = _match_hashes(
hashes_grouped_by_prefix[s_type], s_type, index
)
else:
results = _match_file(path, s_type, index)

for r in results:
# TODO Improve visualisation of a single multiple hash query
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata
for collab, fetched_data in metadatas:
if not self.all and collab in seen:
Expand All @@ -215,6 +220,52 @@ def execute(self, settings: CLISettings) -> None:
fetched_data,
)

def validate_hashes_signal_type(
    self,
    hashes_grouped_by_prefix: t.Dict[t.Optional[t.Type[SignalType]], t.Set[str]],
    signal_types: t.List[t.Type[SignalType]],
) -> None:
    """
    Validate the prefix-grouped hashes and fold unprefixed hashes into one SignalType.

    hashes_grouped_by_prefix maps each SignalType inferred from a hash prefix
    to the hashes carrying that prefix; hashes with no prefix live under the
    None key. The mapping is modified in place: when a None key is present,
    all hashes are merged under a single SignalType key and the None key is
    removed.

    signal_types is the list of candidate SignalTypes; its first entry is the
    fallback key when neither --only-signal nor a prefix identifies a type.

    Raises CommandError (exit code 2) when:
      * hashes without a prefix are mixed with more than one prefixed type,
      * an explicit only_signal matches none of the prefixes and there are
        no unprefixed hashes,
      * every hash is unprefixed and more than one SignalType applies,
      * unprefixed hashes exist but no SignalType is available for them.
    """
    has_unprefixed = None in hashes_grouped_by_prefix

    # None plus two or more real prefixes: the unprefixed hashes are ambiguous.
    if has_unprefixed and len(hashes_grouped_by_prefix) > 2:
        raise CommandError(
            "Error: Provided more than one SignalType and some hashes are missing a prefix",
            2,
        )

    if (
        self.only_signal
        and not has_unprefixed
        and self.only_signal not in hashes_grouped_by_prefix
    ):
        inferred_names = ", ".join(
            s_type.get_name() for s_type in hashes_grouped_by_prefix if s_type
        )
        raise CommandError(
            f"Error: SignalType '{self.only_signal.get_name()}' was provided, "
            "but inferred more from provided hashes. "
            f"Inferred signal types: {inferred_names}",
            2,
        )

    # Only unprefixed hashes, and nothing narrows down which type they belong to.
    if (
        has_unprefixed
        and len(hashes_grouped_by_prefix) == 1
        and len(signal_types) > 1
    ):
        raise CommandError(
            f"Error: '{self.content_type.get_name()}' supports more than one SignalType. "
            "No prefix applied to the hashes, cannot infer correct SignalType",
            2,
        )

    if not has_unprefixed:
        return

    # Fold the unprefixed hashes into the one SignalType that is left after
    # the validations above: the explicit only_signal wins, then the single
    # inferred prefix, then the first applicable SignalType.
    merged_hashes: t.Set[str] = set().union(*hashes_grouped_by_prefix.values())
    prefixed_types = [
        s_type for s_type in hashes_grouped_by_prefix if s_type is not None
    ]
    # NOTE(review): with only_signal set, hashes under a different prefix are
    # merged under only_signal as well - confirm this is intended.
    if self.only_signal:
        target = self.only_signal
    elif prefixed_types:
        target = prefixed_types[0]
    elif signal_types:
        target = signal_types[0]
    else:
        # Previously an IndexError; report it as a usage error instead.
        raise CommandError(
            "Error: no SignalType available for hashes without a prefix",
            2,
        )
    hashes_grouped_by_prefix.clear()
    hashes_grouped_by_prefix[target] = merged_hashes
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved


def _match_file(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
Expand All @@ -225,23 +276,72 @@ def _match_file(
return index.query(s_type.hash_from_file(path))


def _group_hashes_by_prefix(
path: pathlib.Path,
settings: CLISettings,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: unused - also, I prefer to avoid passing this around because it is somewhat of a god object

hashes_grouped_by_prefix: t.Dict[t.Optional[t.Type[SignalType]], t.Set[str]],
) -> None:
for line in path.read_text().splitlines():
line = line.strip()
if not line:
continue
components = line.split()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm torn over what to do when len(component) > 2 - this must mean that there are spaces in the hash, which some future type could allow, but it means our naive parsing here will fail oddly.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely agree, and had a similar thought -- on re-evaluation I'm thinking for when len(component) > 1 it's probably better to assume that [0] is the prefix and concat [1:] into the hash.

The only issue here is that if there is a future hash with a space, without a prefix the parsing will fail still. I'll have a think on this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your intuition is right - if there are hashes with a space, providing the prefix allows for a non-ambiguous parse. Check out str.partition or split(maxsplit=1)

signal_type = None
if len(components) > 1:
# Assume it has a prefix
possible_type = components[0]
hash = components[1].strip()
try:
signal_type = settings.get_signal_type(possible_type)
hash = signal_type.validate_signal_str(hash)
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
except KeyError:
logging.exception("Signal type '%s' is invalid", possible_type)
raise CommandError(
f"Error attempting to infer Signal Type: '{possible_type}' is not a valid Signal Type.",
2,
)
except Exception as e:
logging.exception(
"%s failed verification on %s",
signal_type.get_name() if signal_type else "None!",
hash,
)
hash_repr = repr(hash)
if len(hash_repr) > 50:
hash_repr = hash_repr[:47] + "..."
raise CommandError(
f"{hash} from {path} is not a valid hash for {signal_type.get_name() if signal_type else 'None!'}",
2,
)
else:
# Assume it doesn't have a prefix and is a raw hash
hash = components[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: What about len(component) > 2?

# We can't validate it at this point as we have no context on which signal type
hashes = hashes_grouped_by_prefix.get(signal_type, set())
hashes.add(hash)
hashes_grouped_by_prefix[signal_type] = hashes
Comment on lines +335 to +337
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use setdefault again here to simplify (some folks also like defaultdict, but I have been burned before!)

hashes_grouped_by_prefix.setdefault(signal_type, set()).add(hash)  # you are done



def _match_hashes(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
hashes: t.Set[str],
s_type: t.Type[SignalType],
index: SignalTypeIndex,
) -> t.Sequence[IndexMatch]:
ret: t.List[IndexMatch] = []
for hash in path.read_text().splitlines():
for hash in hashes:
hash = hash.strip()
if not hash:
continue
try:
# Need to keep this final validation as we are yet to have validated the hashes without a prefix
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, forgot to remove this additional validation

hash = s_type.validate_signal_str(hash)
except Exception:
logging.exception("%s failed verification on %s", s_type.get_name(), hash)
hash_repr = repr(hash)
if len(hash_repr) > 50:
hash_repr = hash_repr[:47] + "..."
raise CommandError(
f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}",
f"{hash_repr} is not a valid hash for {s_type.get_name()}",
2,
)
ret.extend(index.query(hash))
Expand Down
Expand Up @@ -3,6 +3,7 @@
import tempfile
from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eTest
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.pdq import PdqSignal


class MatchCommandTest(ThreatExchangeCLIE2eTest):
Expand All @@ -29,5 +30,61 @@ def test_invalid_hash(self):
not_hash = "this is not an md5"
self.assert_cli_usage_error(
("-H", "video", "--", not_hash),
f"{not_hash!r} from .* is not a valid hash for video_md5",
f"Error attempting to infer Signal Type: '{not_hash.split()[0]}' is not a valid Signal Type.",
)

def test_valid_hash_with_prefix(self):
    """A single hash given with an explicit ``pdq`` prefix prints the pdq match line."""
    prefixed_hash = f"pdq {PdqSignal.get_examples()[0]}"
    expected_output = "pdq 16 (Sample Signals) INVESTIGATION_SEED"
    self.assert_cli_output(("-H", "photo", "--", prefixed_hash), expected_output)

def test_no_prefix_specific_signal_type(self):
    """A bare hash prints the pdq match line when ``-S pdq`` is passed."""
    bare_hash = PdqSignal.get_examples()[0]
    cli_args = ("-H", "-S", "pdq", "photo", "--", bare_hash)
    self.assert_cli_output(cli_args, "pdq 16 (Sample Signals) INVESTIGATION_SEED")

def test_multiple_prefixes(self):
    """Two ``pdq``-prefixed hashes read from one file print the pdq match line."""
    prefixed_lines = [f"pdq {PdqSignal.get_examples()[i]}" for i in (0, 1)]
    with tempfile.NamedTemporaryFile("a+") as hash_file:
        # One hash per line, no trailing newline after the last one.
        hash_file.write("\n".join(prefixed_lines))
        hash_file.seek(0)
        # CLI is currently showing only one match for multiple hashes
        # TODO Improve the handling of multiple hashes in one match query
        expected_output = "pdq 16 (Sample Signals) INVESTIGATION_SEED"
        self.assert_cli_output(("-H", "photo", hash_file.name), expected_output)

def test_incorrect_valid_and_no_prefixes(self):
    """A file mixing a ``pdq`` prefix with an unknown prefix is a usage error."""
    bad_prefix = "fakesignal"
    line_prefixes = ("pdq", bad_prefix, bad_prefix)
    file_lines = [
        f"{prefix} {PdqSignal.get_examples()[position]}"
        for position, prefix in enumerate(line_prefixes)
    ]
    with tempfile.NamedTemporaryFile("a+") as hash_file:
        # One hash per line, no trailing newline after the last one.
        hash_file.write("\n".join(file_lines))
        hash_file.seek(0)
        expected_error = f"Error attempting to infer Signal Type: '{bad_prefix}' is not a valid Signal Type."
        self.assert_cli_usage_error(("-H", "photo", hash_file.name), expected_error)

def test_prefix_and_no_prefixes(self):
    """Prefixed and unprefixed pdq hashes in one file print the pdq match line."""
    first_prefixed = f"pdq {PdqSignal.get_examples()[0]}"
    second_prefixed = f"pdq {PdqSignal.get_examples()[1]}"
    unprefixed = PdqSignal.get_examples()[1]
    file_body = "\n".join((first_prefixed, second_prefixed, unprefixed))
    with tempfile.NamedTemporaryFile("a+") as hash_file:
        hash_file.write(file_body)
        hash_file.seek(0)
        # CLI is currently showing only one match for multiple hashes
        # TODO Improve the handling of multiple hashes in one match query
        expected_output = "pdq 16 (Sample Signals) INVESTIGATION_SEED"
        self.assert_cli_output(("-H", "photo", hash_file.name), expected_output)