Load model from resources #249

Open
charvey2718 opened this issue Aug 19, 2023 · 0 comments

@charvey2718

Many thanks for this extremely helpful library.

Currently, models are loaded from a file by the `model` constructor, which (for frozen graphs) calls `readGraph`.
Sometimes it is helpful to load a model from embedded resources instead, for example a `.pb` file compiled into a Windows executable. I have solved this for frozen graphs and am posting it here as a suggestion, in case you want to add it to master.

(I'm not very familiar with GitHub, so please excuse my not branching master and pushing, or whatever the terms are!)

I added a new constructor to model.h that takes a pointer to a `std::vector<uchar>` as its only parameter. Instead of obtaining the `TF_Buffer` from `readGraph(filename)`, as the existing constructor does, it calls `TF_NewBufferFromString(bufferModel->data(), bufferModel->size())` directly.

```cpp
// model.h (cppflow): constructor for a frozen graph that is already in memory.
// It also needs a matching declaration inside class model, e.g.
//     explicit model(const std::vector<uchar>* bufferModel);
// uchar is assumed to be a typedef for unsigned char.
inline model::model(const std::vector<uchar>* bufferModel)
{
	if (bufferModel == nullptr || bufferModel->empty())
	{
		throw std::invalid_argument("Model buffer is null or empty");
	}

	this->status = {TF_NewStatus(), &TF_DeleteStatus};
	this->graph = {TF_NewGraph(), TF_DeleteGraph};
	
	// Create the session.
	std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>
		session_options = {TF_NewSessionOptions(), TF_DeleteSessionOptions};
	
	auto session_deleter = [this](TF_Session* sess) {
		TF_DeleteSession(sess, this->status.get());
		status_check(this->status.get());
	};
	
	this->session = {TF_NewSession(this->graph.get(),
			session_options.get(),
			this->status.get()),
		session_deleter};
	status_check(this->status.get());
	
	// Import the graph definition from the in-memory bytes.
	// TF_NewBufferFromString copies its input, so *bufferModel need not outlive this call.
	TF_Buffer* def = TF_NewBufferFromString(bufferModel->data(), bufferModel->size());
	if (def == nullptr)
	{
		throw std::runtime_error("Failed to import graph def from buffer");
	}
	
	std::unique_ptr<TF_ImportGraphDefOptions, decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {
			TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions};
	TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), this->status.get());
	TF_DeleteBuffer(def);
	
	status_check(this->status.get());
}
```
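Because this constructor only needs the model bytes in memory, the file-based case can be reduced to it as well. As a minimal sketch, a hypothetical helper `ReadModelBytes` (not part of cppflow) reads a whole `.pb` file into a `std::vector<unsigned char>`. Assuming `uchar` is `unsigned char`, the result could then be passed by address to the constructor above; the TensorFlow side is not exercised here.

```cpp
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not part of cppflow): read an entire binary file,
// such as a frozen-graph .pb, into a byte vector.
std::vector<unsigned char> ReadModelBytes(const std::string& path)
{
    std::ifstream file(path, std::ios::binary);
    if (!file)
    {
        throw std::runtime_error("Cannot open model file: " + path);
    }
    // istreambuf_iterator reads raw bytes and does not skip whitespace,
    // so every byte of the protobuf is preserved.
    return std::vector<unsigned char>(std::istreambuf_iterator<char>(file),
                                      std::istreambuf_iterator<char>());
}
```

For example, `std::vector<uchar> bytes = ReadModelBytes("frozen_graph.pb");` followed by `cppflow::model m(&bytes);`, where the file name is a placeholder.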

In my own project, a `LoadModel` function reads the PB model from resources into a `std::vector<uchar>`, which I then pass to the new `model` constructor. My project happens to use wxWidgets, so this is conveniently done as follows (Windows-only, since it uses the Win32 resource API). I include it here only in case it helps someone in the future; it is not itself a suggestion for cppflow.

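```cpp
// Windows-only: copies an RT_RCDATA resource (e.g. a .pb file listed in the
// .rc script) into a byte vector. wxGetInstance() comes from wxWidgets' MSW port.
// Returns false, leaving model unchanged, if the resource cannot be loaded.
bool LoadModel(const wxString& resName, std::vector<uchar>& model)
{
	HRSRC hrsrc = FindResource(wxGetInstance(), resName, RT_RCDATA);
	if(hrsrc == NULL) return false;

	HGLOBAL hglobal = LoadResource(wxGetInstance(), hrsrc);
	if(hglobal == NULL) return false;
	
	// Resource memory stays valid while the module is loaded; no unlock/free is needed.
	const void* data = LockResource(hglobal);
	if(data == NULL) return false;

	DWORD datalen = SizeofResource(wxGetInstance(), hrsrc);
	if(datalen == 0) return false;
	
	const uchar* charBuf = static_cast<const uchar*>(data);
	model.assign(charBuf, charBuf + datalen);
	return true;
}
```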
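On platforms without Win32 resources, one common alternative (an assumption here, not something tested with cppflow) is to compile the `.pb` file into the binary with `xxd -i frozen_graph.pb > frozen_graph_pb.h`. That command emits an `unsigned char frozen_graph_pb[]` array and an `unsigned int frozen_graph_pb_len` length. The sketch below uses a tiny placeholder array in their place, and `LoadEmbeddedModel` is a hypothetical name.

```cpp
#include <cstddef>
#include <vector>

// Placeholder for the header generated by `xxd -i frozen_graph.pb`.
// These four bytes are illustrative only, not a real model.
unsigned char frozen_graph_pb[] = {0x0a, 0x02, 0x08, 0x01};
unsigned int frozen_graph_pb_len = 4;

// Hypothetical portable counterpart to LoadModel(): copy an embedded
// byte array into the vector expected by the buffer-based constructor.
std::vector<unsigned char> LoadEmbeddedModel(const unsigned char* data,
                                             std::size_t len)
{
    return std::vector<unsigned char>(data, data + len);
}
```

The returned vector can then be passed by address to the buffer-based constructor, just like the vector filled by `LoadModel`.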