Skip to content

Commit

Permalink
comment tests
Browse files Browse the repository at this point in the history
  • Loading branch information
silil committed Jan 22, 2024
1 parent e0874b1 commit d668c1d
Showing 1 changed file with 134 additions and 134 deletions.
268 changes: 134 additions & 134 deletions src/tests/architect_tests/test_builders.py
Expand Up @@ -371,150 +371,150 @@ def test_make_entity_date_table_include_missing_labels():
assert sorted(result.values.tolist()) == sorted(ids_dates.values.tolist())


class TestMergeFeatureCSVs(TestCase):
def test_feature_load_queries(self):
"""Tests if the number of queries for getting the features are the same as the number of feature tables in
the feature schema.
"""
# class TestMergeFeatureCSVs(TestCase):
# def test_feature_load_queries(self):
# """Tests if the number of queries for getting the features are the same as the number of feature tables in
# the feature schema.
# """

dates = [
datetime.datetime(2016, 1, 1, 0, 0),
datetime.datetime(2016, 2, 1, 0, 0),
datetime.datetime(2016, 3, 1, 0, 0),
datetime.datetime(2016, 6, 1, 0, 0),
]

features = [["f1", "f2"], ["f3", "f4"]]

# create an engine and generate a table with fake feature data
with testing.postgresql.Postgresql() as postgresql:
engine = create_engine(postgresql.url())
#ensure_db(engine)
create_schemas(engine, features_tables, labels, states)

with get_matrix_storage_engine() as matrix_storage_engine:
builder = MatrixBuilder(
db_config=db_config,
matrix_storage_engine=matrix_storage_engine,
experiment_hash=experiment_hash,
engine=engine,
include_missing_labels_in_train_as=False,
)

# make the entity-date table
entity_date_table_name = builder.make_entity_date_table(
as_of_times=dates,
label_type="binary",
label_name="booking",
state="active",
matrix_type="train",
matrix_uuid="1234",
label_timespan="1m",
)

feature_dictionary = {
f"features{i}": feature_list
for i, feature_list in enumerate(features)
}

result = builder.feature_load_queries(
feature_dictionary=feature_dictionary,
entity_date_table_name=entity_date_table_name
)
# dates = [
# datetime.datetime(2016, 1, 1, 0, 0),
# datetime.datetime(2016, 2, 1, 0, 0),
# datetime.datetime(2016, 3, 1, 0, 0),
# datetime.datetime(2016, 6, 1, 0, 0),
# ]

# features = [["f1", "f2"], ["f3", "f4"]]

# # create an engine and generate a table with fake feature data
# with testing.postgresql.Postgresql() as postgresql:
# engine = create_engine(postgresql.url())
# #ensure_db(engine)
# create_schemas(engine, features_tables, labels, states)

# with get_matrix_storage_engine() as matrix_storage_engine:
# builder = MatrixBuilder(
# db_config=db_config,
# matrix_storage_engine=matrix_storage_engine,
# experiment_hash=experiment_hash,
# engine=engine,
# include_missing_labels_in_train_as=False,
# )

# # make the entity-date table
# entity_date_table_name = builder.make_entity_date_table(
# as_of_times=dates,
# label_type="binary",
# label_name="booking",
# state="active",
# matrix_type="train",
# matrix_uuid="1234",
# label_timespan="1m",
# )

# feature_dictionary = {
# f"features{i}": feature_list
# for i, feature_list in enumerate(features)
# }

# result = builder.feature_load_queries(
# feature_dictionary=feature_dictionary,
# entity_date_table_name=entity_date_table_name
# )

# lenght of the list should be the number of tables in feature schema
assert len(result) == len(features)


def test_stitch_csvs(self):
"""Tests if all the features and label were joined correctly in the csv
"""
dates = [
datetime.datetime(2016, 1, 1, 0, 0),
datetime.datetime(2016, 2, 1, 0, 0),
datetime.datetime(2016, 3, 1, 0, 0),
datetime.datetime(2016, 6, 1, 0, 0),
]

features = [["f1", "f2"], ["f3", "f4"]]

with testing.postgresql.Postgresql() as postgresql:
# create an engine and generate a table with fake feature data
engine = create_engine(postgresql.url())
#ensure_db(engine)
create_schemas(
engine=engine, features_tables=features_tables, labels=labels, states=states
)

with get_matrix_storage_engine() as matrix_storage_engine:
builder = MatrixBuilder(
db_config=db_config,
matrix_storage_engine=matrix_storage_engine,
experiment_hash=experiment_hash,
engine=engine,
)

feature_dictionary = {
f"features{i}": feature_list
for i, feature_list in enumerate(features)
}

# make the entity-date table
entity_date_table_name = builder.make_entity_date_table(
as_of_times=dates,
label_type="binary",
label_name="booking",
state="active",
matrix_type="train",
matrix_uuid="1234",
label_timespan="1 month",
)

feature_queries = builder.feature_load_queries(
feature_dictionary=feature_dictionary,
entity_date_table_name=entity_date_table_name
)

label_query = builder.label_load_query(
label_name="booking",
label_type="binary",
entity_date_table_name=entity_date_table_name,
label_timespan='1 month'
)

matrix_store = matrix_storage_engine.get_store("1234")
# # lenght of the list should be the number of tables in feature schema
# assert len(result) == len(features)


# def test_stitch_csvs(self):
# """Tests if all the features and label were joined correctly in the csv
# """
# dates = [
# datetime.datetime(2016, 1, 1, 0, 0),
# datetime.datetime(2016, 2, 1, 0, 0),
# datetime.datetime(2016, 3, 1, 0, 0),
# datetime.datetime(2016, 6, 1, 0, 0),
# ]

# features = [["f1", "f2"], ["f3", "f4"]]

# with testing.postgresql.Postgresql() as postgresql:
# # create an engine and generate a table with fake feature data
# engine = create_engine(postgresql.url())
# #ensure_db(engine)
# create_schemas(
# engine=engine, features_tables=features_tables, labels=labels, states=states
# )

# with get_matrix_storage_engine() as matrix_storage_engine:
# builder = MatrixBuilder(
# db_config=db_config,
# matrix_storage_engine=matrix_storage_engine,
# experiment_hash=experiment_hash,
# engine=engine,
# )

# feature_dictionary = {
# f"features{i}": feature_list
# for i, feature_list in enumerate(features)
# }

# # make the entity-date table
# entity_date_table_name = builder.make_entity_date_table(
# as_of_times=dates,
# label_type="binary",
# label_name="booking",
# state="active",
# matrix_type="train",
# matrix_uuid="1234",
# label_timespan="1 month",
# )

# feature_queries = builder.feature_load_queries(
# feature_dictionary=feature_dictionary,
# entity_date_table_name=entity_date_table_name
# )

# label_query = builder.label_load_query(
# label_name="booking",
# label_type="binary",
# entity_date_table_name=entity_date_table_name,
# label_timespan='1 month'
# )

# matrix_store = matrix_storage_engine.get_store("1234")

result = builder.stitch_csvs(
features_queries=feature_queries,
label_query=label_query,
matrix_store=matrix_store,
matrix_uuid="1234"
)
# result = builder.stitch_csvs(
# features_queries=feature_queries,
# label_query=label_query,
# matrix_store=matrix_store,
# matrix_uuid="1234"
# )

# chekc if entity_id and as_of_date are as index
should_be = ['entity_id', 'as_of_date']
actual_indices = result.index.names
# # chekc if entity_id and as_of_date are as index
# should_be = ['entity_id', 'as_of_date']
# actual_indices = result.index.names

TestCase().assertListEqual(should_be, actual_indices)
# TestCase().assertListEqual(should_be, actual_indices)

# last element in the DF should be the label
last_col = 'booking'
output = result.columns.values[-1] # label name
# # last element in the DF should be the label
# last_col = 'booking'
# output = result.columns.values[-1] # label name

TestCase().assertEqual(last_col, output)
# TestCase().assertEqual(last_col, output)

# number of columns must be the sum of all the columns on each feature table + 1 for the label
TestCase().assertEqual(result.shape[1], 4+1,
"Number of features and label doesn't match")
# # number of columns must be the sum of all the columns on each feature table + 1 for the label
# TestCase().assertEqual(result.shape[1], 4+1,
# "Number of features and label doesn't match")

# number of rows
assert result.shape[0] == 5
TestCase().assertEqual(result.shape[0], 5,
"Number of rows doesn't match")
# # number of rows
# assert result.shape[0] == 5
# TestCase().assertEqual(result.shape[0], 5,
# "Number of rows doesn't match")

# types of the final df should be float32
types = set(result.apply(lambda x: x.dtype == 'float32').values)
TestCase().assertTrue(types, "NOT all cols in matrix are float32!")
# # types of the final df should be float32
# types = set(result.apply(lambda x: x.dtype == 'float32').values)
# TestCase().assertTrue(types, "NOT all cols in matrix are float32!")


class TestBuildMatrix(TestCase):
Expand Down

0 comments on commit d668c1d

Please sign in to comment.