Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py-tx] Align tx match --hash format with the output of tx hash #1289

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
139 changes: 127 additions & 12 deletions python-threatexchange/threatexchange/cli/match_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,22 @@ def execute(self, settings: CLISettings) -> None:
if self.as_hashes:
types = (BytesHasher, TextHasher, FileHasher)
signal_types = [s for s in signal_types if issubclass(s, types)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(In reference to below comments) This list of signal types have already been narrowed down the valid set given the input.

It might be that we should consider moving --hashes to a toplevel argument (tx match photo|video|hash)` and allowing cross matching might make more sense, but now we've exceeded the scope of the pr.

if self.as_hashes and len(signal_types) > 1:
raise CommandError(
f"Error: '{self.content_type.get_name()}' supports more than one SignalType."
" for '--hashes' also use '--only-signal' to specify one of "
f"{[s.get_name() for s in signal_types]}",
2,
)

logging.info(
"Signal types that apply: %s",
", ".join(s.get_name() for s in signal_types) or "None!",
)

if self.as_hashes:
hashes_grouped_by_prefix: t.Dict[t.Optional[str], t.Set[str]] = {}
# Infer the signal types from the prefixes (None is used as key for hashes with no prefix)
for path in self.files:
_group_hashes_by_prefix(path, settings, hashes_grouped_by_prefix)
Sam-Freeman marked this conversation as resolved.
Show resolved Hide resolved
# Validate the SignalType and append the None prefixes to the correct SignalType
hashes_grouped_by_signal = self.validate(
settings, hashes_grouped_by_prefix, signal_types
)
signal_types = list(hashes_grouped_by_signal)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: danger! I think this will allow you jump the tracks of the signal type validation done already above. Unless this was intentional, since I actually can't think of strong benefits of narrowing it when the user clearly is passing you hashes and we enforce uniqueness on all the names. However, this will come as a surprise to the reader of the code.

cat pdq faceface... | tx match video -  # Surprising?

^ Is the above an intentional behavior change?


indices: t.List[t.Tuple[t.Type[SignalType], SignalTypeIndex]] = []
for s_type in signal_types:
index = settings.index.load(s_type)
Expand All @@ -196,11 +199,14 @@ def execute(self, settings: CLISettings) -> None:
for s_type, index in indices:
seen = set() # TODO - maybe take the highest certainty?
if self.as_hashes:
results = _match_hashes(path, s_type, index)
results = _match_hashes(
hashes_grouped_by_signal[s_type], s_type, index
)
else:
results = _match_file(path, s_type, index)

for r in results:
# TODO Improve output of a single multiple hash query
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata
for collab, fetched_data in metadatas:
if not self.all and collab in seen:
Expand All @@ -215,6 +221,90 @@ def execute(self, settings: CLISettings) -> None:
fetched_data,
)

def validate(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This is still only called with --hashes, not in all calls will use it

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, will update -- should we also be validating the non --hashes calls as well? I can update either in this pr or a follow up one.

self,
settings: CLISettings,
hashes_grouped_by_prefix: t.Dict[t.Optional[str], t.Set[str]],
signal_types: t.List[t.Type[SignalType]],
) -> t.Dict[t.Type[SignalType], t.Set[str]]:
"""
Takes the hashes grouped by optional string prefix, performs all required validations.
Command line arguments, SignalType, hashes, and ambiguous SignalTypes are validated.
Returns a dict of validated SignalType, and validated hashes of each SignalType.
Comment on lines +231 to +233
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding a docstring!

One detail (similar to above comment) is that this validates when the input is hashes

"""

# There is a signal without a prefix and it's ambiguous (fix: -S)
if None in hashes_grouped_by_prefix:
if len(hashes_grouped_by_prefix) > 2 or len(signal_types) > 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ignorable: I think this is being overly specific about the input. This will throw if you do

cat 'pdq facefaceface...
facefaceface...' > file.txt
$ tx match -S pdq photo file.txt

but I think it's unlikely that you'll get a non-prefixed value here that will pass the signal type validation on L247

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohhh, this is a good catch. I see what you mean. Thanks!

# Ambiguous -- throw
raise CommandError(
f"Error: The SignalType is ambiguous for '{self.content_type.get_name()}'"
"For '--hashes' please use '-S' or '--only-signal' to specify one of"
f"{[s.get_name() for s in signal_types]}",
2,
)
# unambiguous -- set
hashes_grouped_by_prefix.setdefault(
signal_types[0].get_name(), set()
).update(hashes_grouped_by_prefix.pop(None))

# Validate that the SignalTypes and hashes are valid (fix: remove or update hash)
hashes_grouped_by_signal = self.validate_signal_type_and_hash(
settings, hashes_grouped_by_prefix
)

# There is a signal not valid for your arguments (fix: remove or change args)
if self.only_signal and (
self.only_signal not in hashes_grouped_by_signal
or len(hashes_grouped_by_signal) > 1
):
Comment on lines +257 to +260
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need this block if you pass in the expected signal types into 252 (rather than all settings)

raise CommandError(
f"Error: SignalType '{self.only_signal} was specified, but attempting to query with additional/different SignalTypes"
f"Please remove or change your arguments, or remove the additional SignalTypes"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: f-string not needed here

)

return hashes_grouped_by_signal

def validate_signal_type_and_hash(
self,
settings: CLISettings,
hashes_grouped_by_prefix: t.Dict[t.Optional[str], t.Set[str]],
) -> t.Dict[t.Type[SignalType], t.Set[str]]:
hashes_grouped_by_signal: t.Dict[t.Type[SignalType], t.Set[str]] = {}
for possible_signal, hashes in hashes_grouped_by_prefix.items():
if not possible_signal:
# This shouldn't be hit as we filter out the None key before
continue
Comment on lines +275 to +277
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: Convert this to assert if you are sure

# Validate the signal
try:
signal_type = settings.get_signal_type(possible_signal)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing you pass the settings god-object in for is the signal types. Why not pass in the signal types and use those? {st.get_name(): st for st in signal_types}

Another reason I don't like using the full set of signal types is that this will accept types that are not valid for the command (pdq for video for example).

except:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: This will catch CTRL+C interrupt and out of memory, which I don't think you intended. use except Exception to narrow down to the usual set.

logging.exception("Signal type %s is invalid.", possible_signal)
raise CommandError(
f"Error: '{possible_signal}' is not a valid Signal Type."
"Please remove or update this hash from your query",
2,
)
# Validate the hashes
for hash in hashes:
try:
hash = signal_type.validate_signal_str(hash.strip())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may actually change the string for some types (i.e. for some binary types we accept octal, hex, etc). You probably want to save this back.

except:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: This will catch CTRL+C interrupt and out of memory, which I don't think you intended. use except Exception to narrow down to the usual set.

logging.exception(
"%s failed verification on %s",
signal_type.get_name() if signal_type else "None!",
hash,
)
hash_repr = repr(hash)
if len(hash_repr) > 50:
hash_repr = hash_repr[:47] + "..."
raise CommandError(
f"{hash_repr} is not a valid hash for {signal_type.get_name() if signal_type else 'None!'}",
2,
)
hashes_grouped_by_signal[signal_type] = hashes
return hashes_grouped_by_signal


def _match_file(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
Expand All @@ -225,23 +315,48 @@ def _match_file(
return index.query(s_type.hash_from_file(path))


def _group_hashes_by_prefix(
path: pathlib.Path,
settings: CLISettings,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: unused - also, I prefer to avoid passing this around because it is somewhat of a god object

hashes_grouped_by_prefix: t.Dict[t.Optional[str], t.Set[str]],
) -> None:
for line in path.read_text().splitlines():
line = line.strip()
if not line:
continue
components = line.split()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm torn over what to do when len(component) > 2 - this must mean that there are spaces in the hash, which some future type could allow, but it means our naive parsing here will fail oddly.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely agree, and had a similar thought -- on re-evaluation I'm thinking for when len(component) > 1 it's probably better to assume that [0] is the prefix and concat [1:] into the hash.

The only issue here is that if there is a future hash with a space, without a prefix the parsing will fail still. I'll have a think on this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your intuition is right - if there are hashes with a space, providing the prefix allows for a non-ambiguous parse. Check out lpartition or split(limit=2)

signal_type = None
if len(components) == 2:
# Assume it has a prefix
signal_type, hash = components
else:
# Assume it doesn't have a prefix and is a raw hash
hash = components[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: What about len(component) > 2?

hashes = hashes_grouped_by_prefix.get(signal_type, set())
hashes.add(hash)
hashes_grouped_by_prefix[signal_type] = hashes
Comment on lines +335 to +337
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use setdefault again here to simplify (some folks also like defaultdict, but I have burned before!)

hashes_grouped_by_prefix.setdefault(signal_type, set()).add(hash)  # you are done



def _match_hashes(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
hashes: t.Set[str],
s_type: t.Type[SignalType],
index: SignalTypeIndex,
) -> t.Sequence[IndexMatch]:
ret: t.List[IndexMatch] = []
for hash in path.read_text().splitlines():
for hash in hashes:
hash = hash.strip()
if not hash:
continue
try:
# Need to keep this final validation as we are yet to have validated the hashes without a prefix
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, forgot to remove this additional validation

hash = s_type.validate_signal_str(hash)
except Exception:
logging.exception("%s failed verification on %s", s_type.get_name(), hash)
hash_repr = repr(hash)
if len(hash_repr) > 50:
hash_repr = hash_repr[:47] + "..."
raise CommandError(
f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}",
f"{hash_repr} is not a valid hash for {s_type.get_name()}",
2,
)
ret.extend(index.query(hash))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
from threatexchange.cli.tests.e2e_test_helper import ThreatExchangeCLIE2eTest
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.pdq import PdqSignal


class MatchCommandTest(ThreatExchangeCLIE2eTest):
Expand All @@ -29,5 +30,61 @@ def test_invalid_hash(self):
not_hash = "this is not an md5"
self.assert_cli_usage_error(
("-H", "video", "--", not_hash),
f"{not_hash!r} from .* is not a valid hash for video_md5",
f"'{not_hash.split()[0]}' is not a valid hash for video_md5",
)

def test_valid_hash_with_prefix(self):
hash = "pdq " + PdqSignal.get_examples()[0]
self.assert_cli_output(
("-H", "photo", "--", hash), "pdq 16 (Sample Signals) INVESTIGATION_SEED"
)

def test_no_prefix_specific_signal_type(self):
hash = PdqSignal.get_examples()[0]
self.assert_cli_output(
("-H", "-S", "pdq", "photo", "--", hash),
"pdq 16 (Sample Signals) INVESTIGATION_SEED",
)

def test_multiple_prefixes(self):
hash1 = "pdq " + PdqSignal.get_examples()[0]
hash2 = "pdq " + PdqSignal.get_examples()[1]
with tempfile.NamedTemporaryFile("a+") as fp:
fp.write(hash1 + "\n")
fp.write(hash2)
fp.seek(0)
# CLI is currently showing only one match for multiple hashes
# TODO Improve the handling of multiple hashes in one match query
self.assert_cli_output(
("-H", "photo", fp.name), "pdq 16 (Sample Signals) INVESTIGATION_SEED"
)

def test_incorrect_valid_and_no_prefixes(self):
fakeprefix = "fakesignal"
hash1 = "pdq " + PdqSignal.get_examples()[0]
hash2 = fakeprefix + " " + PdqSignal.get_examples()[1]
hash3 = fakeprefix + " " + PdqSignal.get_examples()[2]
with tempfile.NamedTemporaryFile("a+") as fp:
fp.write(hash1 + "\n")
fp.write(hash2 + "\n")
fp.write(hash3)
fp.seek(0)
self.assert_cli_usage_error(
("-H", "photo", fp.name),
f"Error: '{fakeprefix}' is not a valid Signal Type.*",
)

def test_prefix_and_no_prefixes(self):
hash1 = "pdq " + PdqSignal.get_examples()[0]
hash2 = "pdq " + PdqSignal.get_examples()[1]
hash3 = PdqSignal.get_examples()[1]
with tempfile.NamedTemporaryFile("a+") as fp:
fp.write(hash1 + "\n")
fp.write(hash2 + "\n")
fp.write(hash3)
fp.seek(0)
# CLI is currently showing only one match for multiple hashes
# TODO Improve the handling of multiple hashes in one match query
self.assert_cli_output(
("-H", "photo", fp.name), "pdq 16 (Sample Signals) INVESTIGATION_SEED"
)