Commit d027124
Source Dataframes in Client (Hosted) (#770)
Co-authored-by: sdreyer <sterling@featureform.com>
epps and sdreyer committed May 1, 2023
1 parent 285d719 commit d027124
Showing 19 changed files with 544 additions and 61 deletions.
3 changes: 2 additions & 1 deletion Makefile
@@ -254,7 +254,7 @@ test_pandas:
test_offline: gen_grpc ## Run offline tests. Run with `make test_offline provider=(memory | postgres | snowflake | redshift | spark )`
@echo "These tests require a .env file. Please Check .env-template for possible variables"
-mkdir coverage
go test -v -parallel 1000 -timeout 60m -coverpkg=./... -coverprofile coverage/cover.out.tmp ./provider --tags=offline --provider=$(provider)
go test -v -parallel 1000 -timeout 60m -coverpkg=./... -coverprofile coverage/cover.out.tmp ./provider --tags=offline,filepath --provider=$(provider)

test_offline_spark: gen_grpc ## Run spark tests.
@echo "These tests require a .env file. Please Check .env-template for possible variables"
@@ -390,6 +390,7 @@ test_e2e: update_python ## Runs End-to-End tests on minikube
pytest -m 'hosted' client/tests/test_getting_model.py
pytest -m 'hosted' client/tests/test_updating_provider.py
pytest -m 'hosted' client/tests/test_class_api.py
pytest -m 'hosted' client/tests/test_source_dataframe.py
# pytest -m 'hosted' client/tests/test_search.py

echo "Starting end to end tests"
29 changes: 29 additions & 0 deletions api/main.go
@@ -678,6 +678,35 @@ func (serv *OnlineServer) TrainingData(req *srv.TrainingDataRequest, stream srv.
}
}

func (serv *OnlineServer) SourceData(req *srv.SourceDataRequest, stream srv.Feature_SourceDataServer) error {
serv.Logger.Infow("Serving Source Data", "id", req.Id.String())
if req.Limit == 0 {
return fmt.Errorf("limit must be greater than 0")
}
client, err := serv.client.SourceData(context.Background(), req)
if err != nil {
return fmt.Errorf("could not serve source data: %w", err)
}
for {
row, err := client.Recv()
if err != nil {
if err == io.EOF {
return nil
}
return fmt.Errorf("receive error: %w", err)
}
if err := stream.Send(row); err != nil {
serv.Logger.Errorf("failed to write to source data stream: %w", err)
return fmt.Errorf("source send row: %w", err)
}
}
}

func (serv *OnlineServer) SourceColumns(ctx context.Context, req *srv.SourceColumnRequest) (*srv.SourceDataColumns, error) {
serv.Logger.Infow("Serving Source Columns", "id", req.Id.String())
return serv.client.SourceColumns(ctx, req)
}

func (serv *ApiServer) Serve() error {
if serv.grpcServer != nil {
return fmt.Errorf("server already running")
5 changes: 4 additions & 1 deletion client/src/featureform/client.py
@@ -6,6 +6,7 @@
SubscriptableTransformation,
)
from .serving import ServingClient
from .constants import NO_RECORD_LIMIT


class Client(ResourceClient, ServingClient):
@@ -48,13 +49,15 @@ def dataframe(
self,
source: Union[SourceRegistrar, LocalSource, SubscriptableTransformation, str],
variant="default",
limit=NO_RECORD_LIMIT,
):
"""
Compute a dataframe from a registered source or transformation
Args:
source (Union[SourceRegistrar, LocalSource, SubscriptableTransformation, str]): The source or transformation to compute the dataframe from
variant (str): The source variant; defaults to "default" and is ignored if source argument is not a string
limit (int): The maximum number of records to return; defaults to NO_RECORD_LIMIT
**Example:**
```py title="definitions.py"
@@ -72,4 +75,4 @@ def dataframe(
raise ValueError(
f"source must be of type SourceRegistrar, LocalSource, SubscriptableTransformation or str, not {type(source)}"
)
return self.impl.get_source_as_df(name, variant)
return self.impl._get_source_as_df(name, variant, limit)
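
With this change, `Client.dataframe` accepts an optional `limit` in hosted mode as well as local mode. A minimal usage sketch follows; the host address, source name, and variant are illustrative assumptions, not values defined in this commit.

```py
import featureform as ff

# Hypothetical connection details; substitute your own deployment's host.
client = ff.Client(host="localhost:7878", insecure=True)

# Fetch only the first 100 records of a registered source as a pandas DataFrame.
df = client.dataframe("transactions", "quickstart", limit=100)

# Omitting `limit` falls back to NO_RECORD_LIMIT (-1), i.e. every record is returned.
full_df = client.dataframe("transactions", "quickstart")
print(df.head())
```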
3 changes: 3 additions & 0 deletions client/src/featureform/constants.py
@@ -0,0 +1,3 @@
# For fetching sources as dataframes, the default limit is set to -1
# to denote that all records should be fetched.
NO_RECORD_LIMIT = -1
36 changes: 29 additions & 7 deletions client/src/featureform/serving.py
@@ -36,6 +36,7 @@
from .tls import insecure_channel, secure_channel
from .version import check_up_to_date
from .enums import FileFormat
from .constants import NO_RECORD_LIMIT


def check_feature_type(features):
@@ -190,11 +191,26 @@ def features(

return feature_values

def get_source_as_df(self, name, variant):
warnings.warn(
"Computing dataframes on sources in hosted mode is not yet supported."
)
return pd.DataFrame()
def _get_source_as_df(self, name, variant, limit):
columns = self._get_source_columns(name, variant)
data = self._get_source_data(name, variant, limit)
return pd.DataFrame(data=data, columns=columns)

def _get_source_data(self, name, variant, limit):
id = serving_pb2.SourceID(name=name, version=variant)
req = serving_pb2.SourceDataRequest(id=id, limit=limit)
resp = self._stub.SourceData(req)
data = []
for rows in resp:
row = [getattr(r, r.WhichOneof("value")) for r in rows.rows]
data.append(row)
return data

def _get_source_columns(self, name, variant):
id = serving_pb2.SourceID(name=name, version=variant)
req = serving_pb2.SourceDataRequest(id=id)
resp = self._stub.SourceColumns(req)
return resp.columns


class LocalClientImpl:
Expand Down Expand Up @@ -653,8 +669,14 @@ def _register_model(
self.db.insert("models", name, type)
self.db.insert(look_up_table, name, association_name, association_variant)

def get_source_as_df(self, name, variant):
return self.get_input_df(name, variant)
def _get_source_as_df(self, name, variant, limit):
if limit == 0:
raise ValueError("limit must be greater than 0")
df = self.get_input_df(name, variant)
if limit != NO_RECORD_LIMIT:
return df[:limit]
else:
return df


class Stream:
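
The limit semantics are the same in the hosted and local implementations above: 0 is rejected, the NO_RECORD_LIMIT sentinel (-1) means "all records", and any positive value truncates the result. A standalone sketch of that behavior (not Featureform code) is below.

```py
import pandas as pd

NO_RECORD_LIMIT = -1  # mirrors client/src/featureform/constants.py

def apply_limit(df: pd.DataFrame, limit: int) -> pd.DataFrame:
    # 0 is an invalid limit; -1 means "return everything"; otherwise slice.
    if limit == 0:
        raise ValueError("limit must be greater than 0")
    return df if limit == NO_RECORD_LIMIT else df[:limit]

frame = pd.DataFrame({"user_id": ["a", "b", "c"], "amount": [10, 20, 30]})
print(apply_limit(frame, 2))                # first two rows
print(apply_limit(frame, NO_RECORD_LIMIT))  # all rows
```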
11 changes: 1 addition & 10 deletions client/tests/conftest.py
@@ -327,18 +327,9 @@ def del_rw(action, name, exc):

@pytest.fixture(scope="module")
def hosted_sql_provider_and_source():
def get_hosted(custom_marks):
def get_hosted(custom_marks, file_format=FileFormat.CSV.value):
ff.register_user("test_user").make_default_owner()

postgres_host = (
"host.docker.internal"
if "docker" in custom_marks
else "quickstart-postgres"
)
redis_host = (
"host.docker.internal" if "docker" in custom_marks else "quickstart-redis"
)

provider = ff.register_postgres(
name="postgres-quickstart",
# The host name for postgres is different between Docker and Minikube
27 changes: 23 additions & 4 deletions client/tests/test_source_dataframe.py
@@ -13,6 +13,12 @@
True,
marks=pytest.mark.local,
),
pytest.param(
"hosted_sql_provider_and_source", False, False, marks=pytest.mark.hosted
),
pytest.param(
"hosted_sql_provider_and_source", False, True, marks=pytest.mark.docker
),
],
)
def test_dataframe_for_name_variant_args(
@@ -28,9 +34,14 @@ def test_dataframe_for_name_variant_args(
transformation = arrange_transformation(provider, is_local)

client = ff.Client(local=is_local, insecure=is_insecure)
client.apply(asynchronous=True)
# If we're running in a hosted context, `apply` needs to be synchronous
# to ensure resources are ready to test.
client.apply(asynchronous=is_local)

source_df = client.dataframe(source.name, source.variant)
if is_local:
source_df = client.dataframe(source.name, source.variant)
else:
source_df = client.dataframe(*source.name_variant())
transformation_df = client.dataframe(*transformation.name_variant())

assert isinstance(source_df, pd.DataFrame) and isinstance(
@@ -50,6 +61,12 @@ def test_dataframe_for_name_variant_args(
True,
marks=pytest.mark.local,
),
pytest.param(
"hosted_sql_provider_and_source", False, False, marks=pytest.mark.hosted
),
pytest.param(
"hosted_sql_provider_and_source", False, True, marks=pytest.mark.docker
),
],
)
def test_dataframe_for_source_args(provider_source_fxt, is_local, is_insecure, request):
@@ -63,7 +80,9 @@ def test_dataframe_for_source_args(provider_source_fxt, is_local, is_insecure, r
transformation = arrange_transformation(provider, is_local)

client = ff.Client(local=is_local, insecure=is_insecure)
client.apply(asynchronous=True)
# If we're running in a hosted context, `apply` needs to be synchronous
# to ensure resources are ready to test.
client.apply(asynchronous=is_local)

source_df = client.dataframe(source)
transformation_df = client.dataframe(transformation)
Expand Down Expand Up @@ -131,6 +150,6 @@ def average_user_transaction(transactions):

@provider.sql_transformation(variant="quickstart")
def average_user_transaction():
return "SELECT customerid as user_id, avg(transactionamount) as avg_transaction_amt from {{transactions.quickstart}} GROUP BY user_id"
return "SELECT customerid as user_id, avg(transactionamount) AS avg_transaction_amt FROM {{transactions.quickstart}} GROUP BY user_id"

return average_user_transaction
4 changes: 2 additions & 2 deletions coordinator/Dockerfile.old
@@ -61,7 +61,7 @@ RUN echo "PATH=${PATH}" > "${ENV}"

RUN curl https://pyenv.run | bash

# Install Python versions
## Install Python versions
ARG TESTING
RUN if [ "$TESTING" = "True" ]; then \
pyenv install 3.7.16 && pyenv global 3.7.16 && pyenv exec pip install --upgrade pip && pyenv exec pip install -r /app/provider/scripts/spark/requirements.txt ; \
@@ -76,7 +76,7 @@ RUN if [ "$TESTING" = "True" ]; then \
ENV SPARK_SCRIPT_PATH="/app/provider/scripts/spark/offline_store_spark_runner.py"
ENV PYTHON_INIT_PATH="/app/provider/scripts/spark/python_packages.sh"

# Download Shaded Jar
## Download Shaded Jar
RUN wget https://repo1.maven.org/maven2/com/google/cloud/bigdataoss/gcs-connector/hadoop2-2.2.11/gcs-connector-hadoop2-2.2.11-shaded.jar -P /app/provider/scripts/spark/jars/

EXPOSE 8080
24 changes: 24 additions & 0 deletions proto/serving.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ package featureform.serving.proto;
service Feature {
rpc TrainingData(TrainingDataRequest) returns (stream TrainingDataRow) {}
rpc FeatureServe(FeatureServeRequest) returns (FeatureRow) {}
rpc SourceData(SourceDataRequest) returns (stream SourceDataRow) {}
rpc SourceColumns(SourceColumnRequest) returns (SourceDataColumns) {}
}

message Model {
@@ -64,3 +66,25 @@ message Value {
bytes on_demand_function = 8;
}
}

message SourceID {
string name = 1;
string version = 2;
}

message SourceDataRequest {
SourceID id = 1;
int64 limit = 2;
}

message SourceColumnRequest {
SourceID id = 1;
}

message SourceDataRow {
repeated Value rows = 1;
}

message SourceDataColumns {
repeated string columns = 1;
}
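
For reference, the new RPCs can be exercised directly against the generated stubs. The sketch below assumes the stubs are importable as `serving_pb2`/`serving_pb2_grpc` and that a Featureform serving endpoint is reachable at the given address; both are assumptions for illustration, not part of this commit.

```py
import grpc
from featureform.proto import serving_pb2, serving_pb2_grpc  # assumed import path

channel = grpc.insecure_channel("localhost:7878")  # hypothetical address
stub = serving_pb2_grpc.FeatureStub(channel)

source_id = serving_pb2.SourceID(name="transactions", version="quickstart")

# SourceColumns returns the column names; SourceData streams SourceDataRow messages.
columns = stub.SourceColumns(serving_pb2.SourceColumnRequest(id=source_id)).columns
for row in stub.SourceData(serving_pb2.SourceDataRequest(id=source_id, limit=10)):
    values = [getattr(v, v.WhichOneof("value")) for v in row.rows]
    print(dict(zip(columns, values)))
```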
8 changes: 6 additions & 2 deletions provider/bigquery.go
@@ -97,10 +97,14 @@ func (pt *bqPrimaryTable) GetName() string {

func (pt *bqPrimaryTable) IterateSegment(n int64) (GenericTableIterator, error) {
tableName := pt.query.getTableName(pt.name)
query := fmt.Sprintf("SELECT * FROM `%s` LIMIT %d", tableName, n)
var query string
if n == -1 {
query = fmt.Sprintf("SELECT * FROM `%s`", tableName)
} else {
query = fmt.Sprintf("SELECT * FROM `%s` LIMIT %d", tableName, n)
}
bqQ := pt.client.Query(query)
it, err := bqQ.Read(pt.query.getContext())

if err != nil {
return nil, err
}
62 changes: 62 additions & 0 deletions provider/filepath.go
@@ -7,6 +7,7 @@ package provider
import (
"fmt"
"net/url"
"regexp"
"strings"

pc "github.com/featureform/provider/provider_config"
@@ -36,6 +37,27 @@ func NewFilepath(storeType pc.FileStoreType, bucket string, prefix string, path
}
}

func NewEmptyFilepath(storeType pc.FileStoreType) (Filepath, error) {
switch storeType {
case S3:
return &S3Filepath{}, nil
case Azure:
return &AzureFilepath{}, nil
case GCS:
return &GCSFilepath{}, nil
case Memory:
return nil, fmt.Errorf("currently unsupported file store type '%s'", storeType)
case FileSystem:
return nil, fmt.Errorf("currently unsupported file store type '%s'", storeType)
case pc.DB:
return nil, fmt.Errorf("currently unsupported file store type '%s'", storeType)
case HDFS:
return nil, fmt.Errorf("currently unsupported file store type '%s'", storeType)
default:
return nil, fmt.Errorf("unknown store type '%s'", storeType)
}
}

type filePath struct {
bucket string
prefix string
@@ -99,3 +121,43 @@ func (s3 *S3Filepath) FullPathWithBucket() string {

return fmt.Sprintf("s3://%s%s/%s", s3.bucket, prefix, s3.path)
}

type AzureFilepath struct {
storageAccount string
filePath
}

func (azure *AzureFilepath) FullPathWithBucket() string {
return fmt.Sprintf("abfss://%s@%s.dfs.core.windows.net/%s", azure.filePath.bucket, azure.storageAccount, azure.filePath.path)
}

func (azure *AzureFilepath) ParseFullPath(fullPath string) error {
abfssRegex := regexp.MustCompile(`abfss://(.+?)@(.+?)\.dfs.core.windows.net/(.+)`)
matches := abfssRegex.FindStringSubmatch(fullPath)

// If the regex matches all parts of the ABFS path, then we can parse the
// bucket, storage account, and path components. Otherwise, we can just set
// the path for standard Azure Blob Storage paths.
if len(matches) == 4 {
azure.filePath.bucket = matches[1]
azure.storageAccount = matches[2]
azure.filePath.path = matches[3]
} else {
azure.filePath.path = fullPath
}

return nil
}

type GCSFilepath struct {
filePath
}

func (gcs *GCSFilepath) FullPathWithBucket() string {
prefix := ""
if gcs.prefix != "" {
prefix = fmt.Sprintf("/%s", gcs.prefix)
}

return fmt.Sprintf("gs://%s%s/%s", gcs.bucket, prefix, gcs.path)
}
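
The ABFSS regex in `AzureFilepath.ParseFullPath` splits a fully qualified path into container, storage account, and object path. A quick illustration of the same pattern (in Python, with a made-up path) shows what each capture group yields.

```py
import re

# Same capture groups as the Go regex above: container, storage account, object path.
abfss = re.compile(r"abfss://(.+?)@(.+?)\.dfs\.core\.windows\.net/(.+)")

m = abfss.match("abfss://my-container@myaccount.dfs.core.windows.net/featureform/part-0000.parquet")
if m:
    container, storage_account, path = m.groups()
    print(container)        # my-container
    print(storage_account)  # myaccount
    print(path)             # featureform/part-0000.parquet
```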
