У меня есть настройка и обучение (точно так же, как в уроке this), и теперь я хочу сохранить его, чтобы повторно использовать нейронную сеть в следующий раз, когда мне нужно классифицировать некоторые данные. Модель имеет load
и save
методы, которые необходимо сохранить и восстановить в файле. Но есть ли способ сохранить (и позже - загрузить) модель в базе данных? (в моем случае это CassandraDB).Как сохранить модель Spark MLlib в базе данных?
0
A
ответ
1
Хорошо, я нашел ответ сам. Не уверен, что это лучшее решение, но оно отлично работает для меня.
MultilayerPerceptronClassificationModel
(и, насколько я вижу, каждая модель MLlib
упаковки) реализует Serializable
интерфейс. Поэтому он может быть сериализован/десериализован как ByteArray
.
Давайте создадим таблицу для хранения модели в Cassandra БД:
CREATE TABLE models (
uid TEXT,
name TEXT,
model BLOB,
PRIMARY KEY (uid)
);
Теперь мы можем записать модель к БД:
def saveModel(model: MultilayerPerceptronClassificationModel) = {
val baos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(baos)
oos.writeObject(model)
oos.flush()
oos.close()
sc.parallelize(Seq((model.uid, "my-neural-network-model", baos.toByteArray)))
.saveToCassandra("mykeyspace", "models", SomeColumns("uid", "name", "model"))
}
и читать модель обратно:
def loadModel(): MultilayerPerceptronClassificationModel = {
sc.cassandraTable("mykeyspace", "models").map { r =>
val bis = new ByteArrayInputStream(r.getBytes("model").array())
val ois = new ObjectInputStream(bis)
ois.readObject.asInstanceOf[MultilayerPerceptronClassificationModel]
}.first()
}