2016-12-02 3 views
0

У меня есть настройка и обучение (точно так же, как в уроке this), и теперь я хочу сохранить его, чтобы повторно использовать нейронную сеть в следующий раз, когда мне нужно классифицировать некоторые данные. Модель имеет load и save методы, которые необходимо сохранить и восстановить в файле. Но есть ли способ сохранить (и позже - загрузить) модель в базе данных? (в моем случае это CassandraDB).Как сохранить модель Spark MLlib в базе данных?

ответ

1

Хорошо, я нашел ответ сам. Не уверен, что это лучшее решение, но оно отлично работает для меня.

MultilayerPerceptronClassificationModel (и, насколько я вижу, каждая модель MLlib упаковки) реализует Serializable интерфейс. Поэтому он может быть сериализован/десериализован как ByteArray.

Давайте создадим таблицу для хранения модели в Cassandra БД:

CREATE TABLE models (
    uid TEXT, 
    name TEXT, 
    model BLOB, 

    PRIMARY KEY (uid) 
); 

Теперь мы можем записать модель к БД:

def saveModel(model: MultilayerPerceptronClassificationModel) = { 
    val baos = new ByteArrayOutputStream() 
    val oos = new ObjectOutputStream(baos) 

    oos.writeObject(model) 
    oos.flush() 
    oos.close() 

    sc.parallelize(Seq((model.uid, "my-neural-network-model", baos.toByteArray))) 
    .saveToCassandra("mykeyspace", "models", SomeColumns("uid", "name", "model")) 
} 

и читать модель обратно:

def loadModel(): MultilayerPerceptronClassificationModel = { 
    sc.cassandraTable("mykeyspace", "models").map { r => 
    val bis = new ByteArrayInputStream(r.getBytes("model").array()) 
    val ois = new ObjectInputStream(bis) 

    ois.readObject.asInstanceOf[MultilayerPerceptronClassificationModel] 
    }.first() 
}