Pass grid search params to TensorFlowEstimator custom model #2030

Closed
waleedka opened this issue Apr 20, 2016 · 3 comments
Labels
type:feature Feature requests

Comments

@waleedka
Contributor

waleedka commented Apr 20, 2016

GridSearchCV is a great way to test and optimize hyper-parameters automatically. I use it with TensorFlowEstimator to optimize learning_rate, batch_size, etc. It would be a great addition if I could also use it to tune other parameters in my custom model.

For example, say I have a custom convnet model and I want to optimize the stride value. The pseudocode below shows what I'm trying to achieve.

I added a custom "params" input to the model function only as an example; I don't mean to imply that this is necessarily the right way to implement the feature.

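```python
# Imports as of the 2016-era APIs (tf.contrib.learn, scikit-learn < 0.18).
from tensorflow.contrib import learn
from sklearn.grid_search import GridSearchCV

# My custom model.
# Feature request: a new params dict whose values are filled in by GridSearchCV.
def cnn_model(X, Y, params):
    stride = params['stride']
    # ... custom model definition here ...

# Create the convnet classifier.
cnn_classifier = learn.TensorFlowEstimator(model_fn=cnn_model)

# Grid search over different stride values (X, Y: training data defined elsewhere).
parameters = {'stride': [1, 2, 3]}
grid_searcher = GridSearchCV(cnn_classifier, parameters)
grid_searcher.fit(X, Y)
```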
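For context on what the request involves: scikit-learn's GridSearchCV sets each candidate combination on an estimator through its `set_params()` method, and reads `get_params()` when cloning it. A minimal sketch, assuming a hypothetical `ParamsEstimator` wrapper (not part of TensorFlow or scikit-learn) that routes every key other than `model_fn` into a `params` dict and forwards that dict to the model function; `cnn_model` here is a stand-in that builds nothing:

```python
# Hypothetical sketch: not TensorFlow's or scikit-learn's actual code.
class ParamsEstimator:
    """Toy estimator that forwards extra hyper-parameters to model_fn."""

    def __init__(self, model_fn, **params):
        self.model_fn = model_fn
        self.params = dict(params)  # model-specific hyper-parameters, e.g. stride

    def get_params(self, deep=True):
        # A search tool reads these to see what it can tune and to clone the estimator.
        return {'model_fn': self.model_fn, **self.params}

    def set_params(self, **new_params):
        # Route 'model_fn' to its attribute and everything else into self.params.
        for key, value in new_params.items():
            if key == 'model_fn':
                self.model_fn = value
            else:
                self.params[key] = value
        return self

    def fit(self, X, Y):
        # The model function receives the current hyper-parameters as a dict.
        self.result_ = self.model_fn(X, Y, self.params)
        return self


def cnn_model(X, Y, params):
    # Stand-in for a real convnet: just report which stride would be used.
    return 'built convnet with stride=%d' % params['stride']


est = ParamsEstimator(model_fn=cnn_model, stride=1)
for stride in [1, 2, 3]:
    est.set_params(stride=stride).fit(X=None, Y=None)
    print(est.result_)
```

Because the constructor accepts the extra keys as keyword arguments, the dict returned by `get_params()` can rebuild an equivalent instance; this sketch has not been checked against every requirement of scikit-learn's `clone()`.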
@ilblackdragon
Contributor

It's on our TODO list. We're just trying to figure out a clean, general way to pass hyper-parameters into models.

@suharshs

@ilblackdragon Any update on this?

@ilblackdragon
Contributor

The model function now has a `params` argument. `TensorFlowEstimator` is deprecated; please use `Estimator`, which takes a `params` argument and passes it to the model function. This should work now; please re-open if it doesn't.
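With `params` as a constructor argument that the estimator hands to the model function, a search can simply build one fresh estimator per candidate combination. A stdlib-only sketch of that loop, where `ToyEstimator`, `grid_search` and `toy_model` are hypothetical stand-ins (not TensorFlow's `Estimator`) and the toy score is made up purely for illustration:

```python
import itertools


class ToyEstimator:
    """Hypothetical stand-in for an estimator that takes a params dict."""

    def __init__(self, model_fn, params=None):
        self.model_fn = model_fn
        self.params = dict(params or {})

    def fit(self, X, Y):
        # Forward the hyper-parameters unchanged to the model function.
        self.score_ = self.model_fn(X, Y, self.params)
        return self


def grid_search(model_fn, X, Y, param_grid):
    """Try every combination in param_grid; return (best_params, best_score)."""
    names = sorted(param_grid)
    best_params, best_score = None, float('-inf')
    for values in itertools.product(*(param_grid[n] for n in names)):
        candidate = dict(zip(names, values))
        # A fresh estimator per candidate, so no state leaks between runs.
        score = ToyEstimator(model_fn, params=candidate).fit(X, Y).score_
        if score > best_score:
            best_params, best_score = candidate, score
    return best_params, best_score


def toy_model(X, Y, params):
    # Made-up "validation score" that peaks at stride=2 and learning_rate=0.01.
    return -abs(params['stride'] - 2) - abs(params['learning_rate'] - 0.01)


best, score = grid_search(toy_model, None, None,
                          {'stride': [1, 2, 3], 'learning_rate': [0.1, 0.01]})
print(best)  # {'learning_rate': 0.01, 'stride': 2}
```

scikit-learn's `ParameterGrid` performs the same Cartesian-product expansion that `itertools.product` does here.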

aselle added the type:feature (Feature requests) label and removed the enhancement label on Feb 9, 2017
5 participants