Skip to content

Commit

Permalink
feat(bigquery): Add support for ML model export
Browse files Browse the repository at this point in the history
* Add model support to Project#extract and #extract_job
* Add ExtractJob#model?
* Add ExtractJob#ml_tf_saved_model?
* Add ExtractJob#ml_xgboost_booster?
* Add Model#extract and #extract_job

closes: googleapis#7061
  • Loading branch information
quartzmo committed Aug 12, 2020
1 parent 4fc030b commit eabcbfe
Show file tree
Hide file tree
Showing 14 changed files with 802 additions and 102 deletions.
54 changes: 36 additions & 18 deletions google-cloud-bigquery/acceptance/bigquery/bigquery_test.rb
Expand Up @@ -246,27 +246,45 @@
_(downloaded_file.size).must_be :>, 0
end
end
it "extracts a model to a GCS url with extract_job" do
  model = nil
  begin
    # Create the model via a CREATE MODEL query.
    query_job = dataset.query_job model_sql
    query_job.wait_until_done!
    _(query_job).wont_be :failed?

    model = dataset.model model_id
    _(model).must_be_kind_of Google::Cloud::Bigquery::Model

    Tempfile.open "temp_extract_model" do |tmp|
      extract_url = "gs://#{bucket.name}/#{model_id}"

      # sut
      extract_job = bigquery.extract_job model, extract_url

      extract_job.wait_until_done!
      _(extract_job).wont_be :failed?
      _(extract_job.ml_tf_saved_model?).must_equal true
      _(extract_job.ml_xgboost_booster?).must_equal false
      _(extract_job.model?).must_equal true
      _(extract_job.table?).must_equal false

      source = extract_job.source
      _(source).must_be_kind_of Google::Cloud::Bigquery::Model
      _(source.model_id).must_equal model_id

      # The export writes a TensorFlow SavedModel under the model_id prefix.
      extract_files = bucket.files prefix: model_id
      _(extract_files).wont_be :nil?
      _(extract_files).wont_be :empty?
      extract_file = extract_files.find { |f| f.name == "#{model_id}/saved_model.pb" }
      _(extract_file).wont_be :nil?
      downloaded_file = extract_file.download tmp.path
      _(downloaded_file.size).must_be :>, 0
    end
  ensure
    # cleanup
    model.delete if model
  end
end

Expand Down
78 changes: 75 additions & 3 deletions google-cloud-bigquery/acceptance/bigquery/model_test.rb
Expand Up @@ -41,9 +41,9 @@
end

it "can create, list, read, update, and delete a model" do
job = dataset.query_job model_sql
job.wait_until_done!
_(job).wont_be :failed?
query_job = dataset.query_job model_sql
query_job.wait_until_done!
_(query_job).wont_be :failed?

# can find the model in the list of models
_(dataset.models.all.map(&:model_id)).must_include model_id
Expand All @@ -64,4 +64,76 @@

_(dataset.model(model_id)).must_be_nil
end

it "extracts itself to a GCS url with extract" do
  model = nil
  begin
    # Create the model via a CREATE MODEL query.
    query_job = dataset.query_job model_sql
    query_job.wait_until_done!
    _(query_job).wont_be :failed?

    model = dataset.model model_id
    _(model).must_be_kind_of Google::Cloud::Bigquery::Model

    Tempfile.open "temp_extract_model" do |local_file|
      destination_url = "gs://#{bucket.name}/#{model_id}"

      # sut
      extracted = model.extract destination_url
      _(extracted).must_equal true

      # The export writes a TensorFlow SavedModel under the model_id prefix.
      exported_files = bucket.files prefix: model_id
      _(exported_files).wont_be :nil?
      _(exported_files).wont_be :empty?
      saved_model_file = exported_files.find { |f| f.name == "#{model_id}/saved_model.pb" }
      _(saved_model_file).wont_be :nil?
      download = saved_model_file.download local_file.path
      _(download.size).must_be :>, 0
    end
  ensure
    # cleanup
    model.delete if model
  end
end

it "extracts itself to a GCS url with extract_job" do
  model = nil
  begin
    # Create the model via a CREATE MODEL query.
    query_job = dataset.query_job model_sql
    query_job.wait_until_done!
    _(query_job).wont_be :failed?

    model = dataset.model model_id
    _(model).must_be_kind_of Google::Cloud::Bigquery::Model

    Tempfile.open "temp_extract_model" do |local_file|
      destination_url = "gs://#{bucket.name}/#{model_id}"

      # sut
      export_job = model.extract_job destination_url

      export_job.wait_until_done!
      _(export_job).wont_be :failed?
      # Models default to the TensorFlow SavedModel export format.
      _(export_job.ml_tf_saved_model?).must_equal true
      _(export_job.ml_xgboost_booster?).must_equal false
      _(export_job.model?).must_equal true
      _(export_job.table?).must_equal false

      job_source = export_job.source
      _(job_source).must_be_kind_of Google::Cloud::Bigquery::Model
      _(job_source.model_id).must_equal model_id

      # The export writes a TensorFlow SavedModel under the model_id prefix.
      exported_files = bucket.files prefix: model_id
      _(exported_files).wont_be :nil?
      _(exported_files).wont_be :empty?
      saved_model_file = exported_files.find { |f| f.name == "#{model_id}/saved_model.pb" }
      _(saved_model_file).wont_be :nil?
      download = saved_model_file.download local_file.path
      _(download.size).must_be :>, 0
    end
  ensure
    # cleanup
    model.delete if model
  end
end
end
4 changes: 0 additions & 4 deletions google-cloud-bigquery/acceptance/bigquery_helper.rb
Expand Up @@ -125,10 +125,6 @@ def random_file_destination_name
"kitten-test-data-#{SecureRandom.hex}.json"
end

# Generates a unique GCS object name prefix for model extract tests.
def random_file_destination_name_model
  format "my-test-extract-model-%s", SecureRandom.hex
end

def assert_data data
assert_equal Google::Cloud::Bigquery::Data, data.class
refute_nil data.kind
Expand Down
4 changes: 3 additions & 1 deletion google-cloud-bigquery/lib/google/cloud/bigquery/convert.rb
Expand Up @@ -318,7 +318,9 @@ def self.source_format format
"parquet" => "PARQUET",
"datastore" => "DATASTORE_BACKUP",
"backup" => "DATASTORE_BACKUP",
"datastore_backup" => "DATASTORE_BACKUP"
"datastore_backup" => "DATASTORE_BACKUP",
"ml_tf_saved_model" => "ML_TF_SAVED_MODEL",
"ml_xgboost_booster" => "ML_XGBOOST_BOOSTER"
}[format.to_s.downcase]
return val unless val.nil?
format
Expand Down
95 changes: 77 additions & 18 deletions google-cloud-bigquery/lib/google/cloud/bigquery/extract_job.rb
Expand Up @@ -20,8 +20,8 @@ module Bigquery
# # ExtractJob
#
# A {Job} subclass representing an export operation that may be performed
# on a {Table}. A ExtractJob instance is created when you call
# {Table#extract_job}.
# on a {Table} or {Model}. An ExtractJob instance is returned when you call
# {Project#extract_job}, {Table#extract_job} or {Model#extract_job}.
#
# @see https://cloud.google.com/bigquery/docs/exporting-data
# Exporting Data From BigQuery
Expand Down Expand Up @@ -49,15 +49,36 @@ def destinations
end

##
# The table or model which is exported. This is the source upon
# which the extract job was created.
#
# @return [Table, Model, nil] A table or model instance, or `nil` if
#   neither source reference is present on the job.
#
def source
  if (table = @gapi.configuration.extract.source_table)
    retrieve_table table.project_id, table.dataset_id, table.table_id
  elsif (model = @gapi.configuration.extract.source_model)
    retrieve_model model.project_id, model.dataset_id, model.model_id
  end
end

##
# Whether the source of the export job is a table. See {#source}.
#
# @return [Boolean] `true` when the source is a table, `false`
#   otherwise.
#
def table?
  source_table = @gapi.configuration.extract.source_table
  !source_table.nil?
end

##
# Whether the source of the export job is a model. See {#source}.
#
# @return [Boolean] `true` when the source is a model, `false`
#   otherwise.
#
def model?
  source_model = @gapi.configuration.extract.source_model
  !source_model.nil?
end

##
Expand All @@ -72,7 +93,7 @@ def compression?
end

##
# Checks if the destination format for the data is [newline-delimited
# Checks if the destination format for the table data is [newline-delimited
# JSON](http://jsonlines.org/). The default is `false`.
#
# @return [Boolean] `true` when `NEWLINE_DELIMITED_JSON`, `false`
Expand All @@ -84,20 +105,20 @@ def json?
end

##
# Checks if the destination format for the table data is CSV. Tables with
# nested or repeated fields cannot be exported as CSV. The default is
# `true` for tables; model exports never default to CSV.
#
# @return [Boolean] `true` when `CSV`, `false` otherwise.
#
def csv?
  val = @gapi.configuration.extract.destination_format
  # CSV is the default only for table sources; an unset format on a
  # model export means ML_TF_SAVED_MODEL, not CSV.
  return true if table? && val.nil?
  val == "CSV"
end

##
# Checks if the destination format for the data is
# Checks if the destination format for the table data is
# [Avro](http://avro.apache.org/). The default is `false`.
#
# @return [Boolean] `true` when `AVRO`, `false` otherwise.
Expand All @@ -107,6 +128,29 @@ def avro?
val == "AVRO"
end

##
# Checks if the destination format for the model is TensorFlow
# SavedModel. The default is `true` for models.
#
# @return [Boolean] `true` when `ML_TF_SAVED_MODEL`, `false` otherwise.
#
def ml_tf_saved_model?
  destination_format = @gapi.configuration.extract.destination_format
  # An unset format on a model export defaults to ML_TF_SAVED_MODEL.
  return true if destination_format.nil? && model?
  destination_format == "ML_TF_SAVED_MODEL"
end

##
# Checks if the destination format for the model is XGBoost. The default
# is `false`.
#
# @return [Boolean] `true` when `ML_XGBOOST_BOOSTER`, `false` otherwise.
#
def ml_xgboost_booster?
  @gapi.configuration.extract.destination_format == "ML_XGBOOST_BOOSTER"
end

##
# The character or symbol the operation uses to delimit fields in the
# exported data. The default is a comma (,).
Expand Down Expand Up @@ -182,19 +226,24 @@ def initialize gapi
#
# @return [Google::Cloud::Bigquery::ExtractJob::Updater] A job
# configuration object for setting query options.
def self.from_options service, table, storage_files, options
def self.from_options service, source, storage_files, options
job_ref = service.job_ref_from options[:job_id], options[:prefix]
storage_urls = Array(storage_files).map do |url|
url.respond_to?(:to_gs_url) ? url.to_gs_url : url
end
options[:format] ||= Convert.derive_source_format storage_urls.first
extract_config = Google::Apis::BigqueryV2::JobConfigurationExtract.new(
destination_uris: Array(storage_urls)
)
if source.is_a? Google::Apis::BigqueryV2::TableReference
extract_config.source_table = source
elsif source.is_a? Google::Apis::BigqueryV2::ModelReference
extract_config.source_model = source
end
job = Google::Apis::BigqueryV2::Job.new(
job_reference: job_ref,
configuration: Google::Apis::BigqueryV2::JobConfiguration.new(
extract: Google::Apis::BigqueryV2::JobConfigurationExtract.new(
destination_uris: Array(storage_urls),
source_table: table
),
extract: extract_config,
dry_run: options[:dryrun]
)
)
Expand Down Expand Up @@ -362,6 +411,16 @@ def to_gapi
@gapi
end
end

protected

# Fetches a Model by reference via the service, returning `nil` when the
# model no longer exists instead of raising.
def retrieve_model project_id, dataset_id, model_id
  ensure_service!
  model_gapi = service.get_project_model project_id, dataset_id, model_id
  Model.from_gapi_json model_gapi, service
rescue Google::Cloud::NotFoundError
  # A deleted model is reported as nil rather than as an error.
  nil
end
end
end
end
Expand Down

0 comments on commit eabcbfe

Please sign in to comment.