-
Notifications
You must be signed in to change notification settings - Fork 297
/
app.py
275 lines (229 loc) · 9.69 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# Copyright (c) Meta Platforms, Inc. and affiliates.
# NO IMPORTS ABOVE ME
# Import pdq first with its hash order warning squelched, it's before our time
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from threatexchange.signal_type.pdq import signal as _
## Resume regularly scheduled imports
import logging
import os
import datetime
import sys
import typing as t
import click
import flask
from flask.logging import default_handler
from flask_apscheduler import APScheduler
from threatexchange.exchanges import auth
from threatexchange.exchanges.signal_exchange_api import TSignalExchangeAPICls
from OpenMediaMatch.storage import interface
from OpenMediaMatch.storage.postgres.impl import DefaultOMMStore
from OpenMediaMatch.background_tasks import (
build_index,
fetcher,
development as dev_apscheduler,
)
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.blueprints import development, hashing, matching, curation, ui
from OpenMediaMatch.utils import dev_utils
def _is_debug_mode():
    """Does it look like the app is being run in debug mode?"""
    flag = os.environ.get("FLASK_DEBUG")
    if flag:
        # Anything other than an explicit "off" spelling counts as debug on.
        return flag.lower() not in ("0", "false", "no")
    # FLASK_DEBUG unset/empty: fall back to the legacy FLASK_ENV convention.
    return os.environ.get("FLASK_ENV") == "development"
def _is_werkzeug_reloaded_process():
    """If in debug mode, are we in the reloaded process?"""
    # Werkzeug's reloader sets this env var in the child process that
    # actually serves requests.
    flag = os.environ.get("WERKZEUG_RUN_MAIN")
    return flag == "true"
def _setup_task_logging(app_logger: logging.Logger):
    """Clownily replace module loggers with our own"""
    # Point each background-task module's logger at a child of the app
    # logger so their output inherits the app's handlers/formatting.
    for module, child_name in ((fetcher, "Fetcher"), (build_index, "Indexer")):
        module.logger = app_logger.getChild(child_name)
def create_app() -> flask.Flask:
    """
    Create and configure the Flask app.

    Configuration is loaded from the file named by the OMM_CONFIG env var
    (with a devcontainer fallback when invoked via the flask CLI), then
    overridden by OMM_-prefixed environment variables. Blueprints,
    apscheduler background tasks, and CLI commands are wired up according
    to the ROLE_* / TASK_* config flags.

    Raises:
        RuntimeError: if no config source can be located.
    """
    # We like the flask logging format, so lets have it everywhere
    root = logging.getLogger()
    if not root.handlers:
        root.addHandler(default_handler)
    app = flask.Flask(__name__)
    if "OMM_CONFIG" in os.environ:
        app.config.from_envvar("OMM_CONFIG")
    elif sys.argv[0].endswith("/flask"):  # Default for flask CLI
        # The devcontainer settings. If you are using the CLI outside
        # the devcontainer and getting an error, just override the env
        app.config.from_pyfile(
            "/workspace/reference_omm_configs/development_omm_config.py"
        )
    else:
        raise RuntimeError("No omm_config given - try populating OMM_CONFIG env")
    # Override fields with environment variables
    app.config.from_prefixed_env("OMM")
    app.config.update(
        SQLALCHEMY_DATABASE_URI=app.config.get("DATABASE_URI"),
        SQLALCHEMY_TRACK_MODIFICATIONS=False,
    )
    logging_config = app.config.get("FLASK_LOGGING_CONFIG")
    if logging_config:
        # logging.config is a submodule that a bare `import logging` is not
        # guaranteed to load - import it explicitly before use.
        import logging.config

        logging.config.dictConfig(logging_config)
    running_migrations = os.getenv("MIGRATION_COMMAND") == "1"
    engine_logging = app.config.get("SQLALCHEMY_ENGINE_LOG_LEVEL")
    if engine_logging is not None:
        logging.getLogger("sqlalchemy.engine").setLevel(engine_logging)
    if "STORAGE_IFACE_INSTANCE" not in app.config:
        app.logger.warning("No storage class provided, using the default")
        app.config["STORAGE_IFACE_INSTANCE"] = DefaultOMMStore()
    storage = app.config["STORAGE_IFACE_INSTANCE"]
    assert isinstance(
        storage, interface.IUnifiedStore
    ), "STORAGE_IFACE_INSTANCE is not an instance of IUnifiedStore"
    _setup_task_logging(app.logger)
    scheduler: APScheduler | None = None
    with app.app_context():
        # We only run apscheduler in the "outer" reloader process, else we'll
        # have multiple executions of the the scheduler in debug mode
        if _is_werkzeug_reloaded_process() and not running_migrations:
            now = datetime.datetime.now()
            scheduler = dev_apscheduler.get_apscheduler()
            scheduler.init_app(app)
            tasks = []
            if app.config.get("TASK_FETCHER", False):
                tasks.append("Fetcher")
                # Fetch from exchanges every 4 minutes, first run in 30s
                scheduler.add_job(
                    "Fetcher",
                    fetcher.apscheduler_fetch_all,
                    trigger="interval",
                    seconds=60 * 4,
                    start_date=now + datetime.timedelta(seconds=30),
                )
            if app.config.get("TASK_INDEXER", False):
                tasks.append("Indexer")
                # Rebuild indices every minute, first run in 15s
                scheduler.add_job(
                    "Indexer",
                    build_index.apscheduler_build_all_indices,
                    trigger="interval",
                    seconds=60,
                    start_date=now + datetime.timedelta(seconds=15),
                )
            app.logger.info("Started Apscheduler, initial tasks: %s", tasks)
            scheduler.start()
        storage.init_flask(app)
    is_production = app.config.get("PRODUCTION", True)
    # Register Flask blueprints for whichever server roles are enabled...
    # URL prefixing facilitates easy Layer 7 routing :)
    # Dev/UI endpoints need both hasher and matcher roles, and never run
    # in production.
    if (
        not is_production
        and app.config.get("ROLE_HASHER", False)
        and app.config.get("ROLE_MATCHER", False)
    ):
        app.register_blueprint(development.bp, url_prefix="/dev")
        app.register_blueprint(ui.bp, url_prefix="/ui")
    if app.config.get("ROLE_HASHER", False):
        app.register_blueprint(hashing.bp, url_prefix="/h")
    if app.config.get("ROLE_MATCHER", False):
        app.register_blueprint(matching.bp, url_prefix="/m")
        if app.config.get("TASK_INDEX_CACHE", False) and not running_migrations:
            matching.initiate_index_cache(app, scheduler)
    if app.config.get("ROLE_CURATOR", False):
        app.register_blueprint(curation.bp, url_prefix="/c")

    @app.route("/")
    def home():
        # In production land on the health endpoint; in dev, the UI.
        dst = "status" if is_production else "ui"
        return flask.redirect(f"/{dst}")

    @app.route("/status")
    def status():
        """
        Liveness/readiness check endpoint for your favourite Layer 7 load balancer
        """
        if app.config.get("ROLE_MATCHER", False):
            if matching.index_cache_is_stale():
                # Fixed: removed stray f-prefix on a placeholder-free literal
                return "INDEX-STALE", 503
        return "I-AM-ALIVE", 200

    @app.route("/site-map")
    def site_map():
        # Use a set to avoid duplicates (e.g. same path, multiple methods)
        routes = set()
        for rule in app.url_map.iter_rules():
            routes.add(rule.rule)
        # Convert set to a list so we can sort it.
        routes = list(routes)
        routes.sort()
        return routes

    @app.cli.command("seed")
    def seed_data() -> None:
        """Add sample data API connection"""
        dev_utils.seed_sample()

    @app.cli.command("big-seed")
    @click.option("-b", "--banks", default=100, show_default=True)
    @click.option("-s", "--seeds", default=10000, show_default=True)
    def seed_enourmous(banks: int, seeds: int) -> None:
        """
        Seed the database with a large number of banks and hashes
        It will generate n banks and put n/m hashes on each bank
        """
        dev_utils.seed_banks_random(banks, seeds)

    @app.cli.command("fetch")
    def fetch():
        """Run the 'background task' to fetch from 3p data and sync to local banks"""
        app.logger.setLevel(logging.DEBUG)
        storage = get_storage()
        fetcher.fetch_all(
            storage,
            storage.get_signal_type_configs(),
        )

    @app.cli.command("build_indices")
    def build_indices():
        """Run the 'background task' to rebuild indices from bank contents"""
        app.logger.setLevel(logging.DEBUG)
        storage = get_storage()
        build_index.build_all_indices(storage, storage, storage)

    @app.cli.command("auth")
    @click.argument("api_name", callback=_get_api_cfg)
    @click.option(
        "--from-str",
        help="attempt to use the private _from_str method to auth",
    )
    @click.option("--unset", is_flag=True, help="clear credentials")
    def set_credentials(
        api_name: interface.SignalExchangeAPIConfig, from_str: str | None, unset: bool
    ) -> None:
        """
        Persist credentials for apis.
        Using the lookup mechanisms built into threatexchange.exchange.auth
        attempt to find credentials in the local environment.
        The easiest way is usually via an environment variable.
        Example, for fb_threatexchange:
        TX_ACCESS_TOKEN='12345678|facefaceface' flask auth
        """
        api_cfg = api_name  # Can't rename arguments, so we rename variable :/
        storage = get_storage()
        api_cls = api_cfg.api_cls
        cred_cls: auth.CredentialHelper = api_cls.get_credential_cls()  # type: ignore
        if unset:
            api_cfg.credentials = None
        else:
            if from_str is not None:
                creds = cred_cls._from_str(from_str)
                if creds is None or not creds._are_valid():
                    raise click.UsageError("Invalid 'from-str'")
            else:
                # Fall back to the library's credential discovery (env vars, files)
                try:
                    creds = cred_cls.get(api_cls)
                except auth.SignalExchangeAPIMissingAuthException as e:
                    raise click.UsageError(e.pretty_str())
                except auth.SignalExchangeAPIInvalidAuthException as e:
                    raise click.UsageError(e.message)
            api_cfg.credentials = creds
        storage.exchange_api_config_update(api_cfg)

    return app
def _get_api_cfg(ctx: click.Context, param: click.Parameter, value: str):
    """
    Click argument callback: resolve an exchange API name to its stored config.

    Raises click.BadParameter if the named API is unknown, or if its
    implementation class does not support authentication.
    """
    storage = get_storage()
    config = storage.exchange_apis_get_configs().get(value)
    if config is None:
        raise click.BadParameter("No such api")
    api_cls = config.api_cls
    if not issubclass(api_cls, auth.SignalExchangeWithAuth):
        # Fixed typo in user-facing message: "authentification" -> "authentication"
        raise click.BadParameter("api doesn't take authentication")
    return config