То, что создает ваша заставка, называется «Checkpoint V2» и было введено в TF 0.12.
У меня это работает очень хорошо (хотя документы на C++-странице ужасны, поэтому мне потребовался день, чтобы решить). Некоторые люди предлагают converting all variables to constants или freezing the graph, но ни один из них на самом деле не нужен.
Python часть (экономия)
with tf.Session() as sess:
tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')
Если вы создаете Saver
с tf.trainable_variables()
, вы можете сэкономить некоторую головную боль и место для хранения. Но, возможно, некоторым более сложным моделям нужны все данные для сохранения, затем удалите этот аргумент до Saver
, просто убедитесь, что вы создаете Saver
после, ваш график создан. Также очень важно дать всем переменным/слоям уникальные имена, иначе вы можете запускать разные задачи.
C++ часть (умозаключение)
Обратите внимание, что checkpointPath
не путь к какой-либо из существующих файлов, только их общий префикс. Если вы по ошибке поместите туда путь к файлу .index
, TF не скажет вам, что это было неправильно, но он умрет во время вывода из-за неинициализированных переменных.
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
using namespace std;
using namespace tensorflow;
...
// set up your input paths
const string pathToGraph = "models/my-model.meta"
const string checkpointPath = "models/my-model";
...
auto session = NewSession(SessionOptions());
if (session == nullptr) {
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
{{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
{},
{graph_def.saver_def().restore_op_name()},
nullptr);
if (!status.ok()) {
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
Для полноты, вот эквивалент Python:
Умозаключение в Python
with tf.Session() as sess:
saver = tf.train.import_meta_graph('models/my-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('models/'))
outputTensors = sess.run(outputOps, feed_dict=feedDict)
Спасибо, @Ian. Я также нашел это: https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc#.ay0m5hj9k – mhaghighat
Хорошая находка. Кажется, что они делают то, что мы пытаемся сделать, но только с Python, а не с C++. В настоящее время я просматриваю эту проблему: https://github.com/tensorflow/tensorflow/issues/615 Что породил этот вопрос: http://stackoverflow.com/questions/35508866/tensorflow-different-ways-to-export -and-run-graph-in-c –
Вы, ребята, наконец-то добились успеха? Я также борюсь с этим, я пробовал много разных методов, но большинство из них не могут сохранить значения переменных, другие сбой или дать мне постоянный вывод в C++ ... http://stackoverflow.com/questions/43515671/tensorflow-freeze-graph-does-not-store-the-variable-values – gdelab