2017-02-09 13 views
0

Я узнал о нейронных сетях, и я нашел следующий код, который хорошо работает с набором данных диафрагмы, достигающим 99% точности.Как повысить точность нейронной сети для случайного массива X, где y является модулем (X, 2)?

Однако, когда я пытаюсь запустить ту же сеть на фиктивном наборе данных, случайно сгенерированном с использованием get_modulus_data(), где есть корреляция между входом и выходом следующим образом, он не работает так хорошо. Кто-нибудь сможет пролить свет, почему нейронная сеть борется с этим типом обучения?

# Implementation of a simple MLP network with one hidden layer. Tested on the iris data set. 
# Requires: numpy, sklearn, tensorflow 

# NOTE: In order to make the code simple, we rewrite x * W_1 + b_1 = x' * W_1' 
# where x' = [x | 1] and W_1' is the matrix W_1 appended with a new row with elements b_1's. 
# Similarly, for h * W_2 + b_2 
import tensorflow as tf 
import numpy as np 
from sklearn import datasets 
from sklearn.cross_validation import train_test_split 

RANDOM_SEED = 42 
tf.set_random_seed(RANDOM_SEED) 


def init_weights(shape): 
    """ Weight initialization """ 
    weights = tf.random_normal(shape, stddev=0.1) 
    return tf.Variable(weights) 

def forwardprop(X, w_1, w_2, w_3): 
    """ 
    Forward-propagation. 
    IMPORTANT: yhat is not softmax since TensorFlow's softmax_cross_entropy_with_logits() does that internally. 
    """ 
    h_1 = tf.nn.elu(tf.matmul(X, w_1)) # The \sigma function 
    h_1 = tf.nn.dropout(h_1,0.5) 
    h_2 = tf.nn.elu(tf.matmul(h_1, w_2)) # The \sigma function 
    h_2 = tf.nn.dropout(h_2,0.5) 
    yhat = tf.matmul(h_2, w_3) # The \varphi function 
    return yhat 

def get_iris_data(): 
    """ Read the iris data set and split them into training and test sets """ 
    iris = datasets.load_iris() 
    data = iris["data"] 
    target = iris["target"] 

    # Prepend the column of 1s for bias 
    N, M = data.shape 
    all_X = np.ones((N, M + 1)) 
    all_X[:, 1:] = data 

    # Convert into one-hot vectors 
    num_labels = len(np.unique(target)) 
    all_Y = np.eye(num_labels)[target] # One liner trick! 
    return train_test_split(all_X, all_Y, test_size=0.33, random_state=RANDOM_SEED) 

def get_modulus_data(): 

    all_X = np.round(np.random.random((150,5))) 
    target = np.int_(all_X.sum(axis=1)%3) 

    # Convert into one-hot vectors 
    num_labels = len(np.unique(target)) 
    all_Y = np.eye(num_labels)[target] # One liner trick! 

    return train_test_split(all_X, all_Y, test_size=0.33, random_state=RANDOM_SEED) 

def main(): 
    # train_X, test_X, train_y, test_y = get_iris_data() 
    train_X, test_X, train_y, test_y = get_modulus_data() 

    # Layer's sizes 
    x_size = train_X.shape[1] # Number of input nodes: 4 features and 1 bias 
    h_size = 256    # Number of hidden nodes 
    y_size = train_y.shape[1] # Number of outcomes (3 iris flowers) 

    # Symbols 
    X = tf.placeholder("float", shape=[None, x_size]) 
    y = tf.placeholder("float", shape=[None, y_size]) 

    # Weight initializations 
    w_1 = init_weights((x_size, h_size)) 
    w_2 = init_weights((h_size, h_size//2)) 
    w_3 = init_weights((h_size//2, y_size)) 

    # Forward propagation 
    yhat = forwardprop(X, w_1, w_2, w_3) 
    predict = tf.argmax(yhat, dimension=1) 

    # Backward propagation 
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(yhat, y)) 
    updates = tf.train.GradientDescentOptimizer(0.01).minimize(cost) 

    # Run SGD 
    sess = tf.Session() 
    init = tf.global_variables_initializer() 
    sess.run(init) 

    for epoch in range(300): 
     # Train with each example 
     for i in range(len(train_X)-10): 
      sess.run(updates, feed_dict={X: train_X[i: i + 10], y: train_y[i: i + 10]}) 

     train_accuracy = np.mean(np.argmax(train_y, axis=1) == sess.run(predict, feed_dict={X: train_X, y: train_y})) 
     test_accuracy = np.mean(np.argmax(test_y, axis=1) == sess.run(predict, feed_dict={X: test_X, y: test_y})) 

     print("Epoch = %d, train accuracy = %.2f%%, test accuracy = %.2f%%" 
       % (epoch + 1, 100. * train_accuracy, 100. * test_accuracy)) 

    sess.close() 

if __name__ == '__main__': 
    main() 

Вот выход из программы:

Epoch = 1, train accuracy = 44.00%, test accuracy = 50.00% 
Epoch = 2, train accuracy = 48.00%, test accuracy = 40.00% 
Epoch = 3, train accuracy = 50.00%, test accuracy = 46.00% 
Epoch = 4, train accuracy = 44.00%, test accuracy = 40.00% 
Epoch = 5, train accuracy = 46.00%, test accuracy = 48.00% 
Epoch = 6, train accuracy = 45.00%, test accuracy = 38.00% 
Epoch = 7, train accuracy = 45.00%, test accuracy = 44.00% 
Epoch = 8, train accuracy = 53.00%, test accuracy = 40.00% 
Epoch = 9, train accuracy = 48.00%, test accuracy = 46.00% 
Epoch = 10, train accuracy = 44.00%, test accuracy = 38.00% 
Epoch = 11, train accuracy = 48.00%, test accuracy = 48.00% 
Epoch = 12, train accuracy = 44.00%, test accuracy = 32.00% 
Epoch = 13, train accuracy = 42.00%, test accuracy = 48.00% 
Epoch = 14, train accuracy = 47.00%, test accuracy = 46.00% 
Epoch = 15, train accuracy = 50.00%, test accuracy = 44.00% 
Epoch = 16, train accuracy = 48.00%, test accuracy = 44.00% 
Epoch = 17, train accuracy = 49.00%, test accuracy = 40.00% 
Epoch = 18, train accuracy = 48.00%, test accuracy = 36.00% 
Epoch = 19, train accuracy = 48.00%, test accuracy = 40.00% 
Epoch = 20, train accuracy = 47.00%, test accuracy = 32.00% 
Epoch = 21, train accuracy = 48.00%, test accuracy = 38.00% 
Epoch = 22, train accuracy = 43.00%, test accuracy = 38.00% 
Epoch = 23, train accuracy = 48.00%, test accuracy = 46.00% 
Epoch = 24, train accuracy = 50.00%, test accuracy = 44.00% 
Epoch = 25, train accuracy = 51.00%, test accuracy = 34.00% 
Epoch = 26, train accuracy = 52.00%, test accuracy = 36.00% 
Epoch = 27, train accuracy = 54.00%, test accuracy = 42.00% 
Epoch = 28, train accuracy = 50.00%, test accuracy = 34.00% 
Epoch = 29, train accuracy = 47.00%, test accuracy = 38.00% 
Epoch = 30, train accuracy = 48.00%, test accuracy = 40.00% 
Epoch = 31, train accuracy = 52.00%, test accuracy = 42.00% 
Epoch = 32, train accuracy = 51.00%, test accuracy = 34.00% 
Epoch = 33, train accuracy = 49.00%, test accuracy = 42.00% 
Epoch = 34, train accuracy = 49.00%, test accuracy = 42.00% 
Epoch = 35, train accuracy = 50.00%, test accuracy = 46.00% 
Epoch = 36, train accuracy = 48.00%, test accuracy = 40.00% 
Epoch = 37, train accuracy = 50.00%, test accuracy = 44.00% 
Epoch = 38, train accuracy = 46.00%, test accuracy = 38.00% 
Epoch = 39, train accuracy = 46.00%, test accuracy = 46.00% 
Epoch = 40, train accuracy = 50.00%, test accuracy = 40.00% 
Epoch = 41, train accuracy = 53.00%, test accuracy = 38.00% 
Epoch = 42, train accuracy = 50.00%, test accuracy = 42.00% 
Epoch = 43, train accuracy = 49.00%, test accuracy = 42.00% 
Epoch = 44, train accuracy = 48.00%, test accuracy = 38.00% 
Epoch = 45, train accuracy = 51.00%, test accuracy = 40.00% 
Epoch = 46, train accuracy = 50.00%, test accuracy = 46.00% 
Epoch = 47, train accuracy = 49.00%, test accuracy = 46.00% 
Epoch = 48, train accuracy = 48.00%, test accuracy = 38.00% 
Epoch = 49, train accuracy = 52.00%, test accuracy = 46.00% 
Epoch = 50, train accuracy = 47.00%, test accuracy = 52.00% 
Epoch = 51, train accuracy = 44.00%, test accuracy = 40.00% 
Epoch = 52, train accuracy = 51.00%, test accuracy = 44.00% 
Epoch = 53, train accuracy = 48.00%, test accuracy = 40.00% 
Epoch = 54, train accuracy = 49.00%, test accuracy = 38.00% 
Epoch = 55, train accuracy = 48.00%, test accuracy = 38.00% 
Epoch = 56, train accuracy = 49.00%, test accuracy = 42.00% 
Epoch = 57, train accuracy = 50.00%, test accuracy = 38.00% 
Epoch = 58, train accuracy = 48.00%, test accuracy = 44.00% 
Epoch = 59, train accuracy = 51.00%, test accuracy = 42.00% 
Epoch = 60, train accuracy = 48.00%, test accuracy = 34.00% 
Epoch = 61, train accuracy = 47.00%, test accuracy = 42.00% 
Epoch = 62, train accuracy = 48.00%, test accuracy = 42.00% 
Epoch = 63, train accuracy = 49.00%, test accuracy = 42.00% 
Epoch = 64, train accuracy = 53.00%, test accuracy = 48.00% 
Epoch = 65, train accuracy = 51.00%, test accuracy = 42.00% 
Epoch = 66, train accuracy = 48.00%, test accuracy = 36.00% 
Epoch = 67, train accuracy = 49.00%, test accuracy = 46.00% 
Epoch = 68, train accuracy = 52.00%, test accuracy = 42.00% 
Epoch = 69, train accuracy = 50.00%, test accuracy = 38.00% 
Epoch = 70, train accuracy = 49.00%, test accuracy = 42.00% 
Epoch = 71, train accuracy = 50.00%, test accuracy = 44.00% 
Epoch = 72, train accuracy = 50.00%, test accuracy = 38.00% 
Epoch = 73, train accuracy = 48.00%, test accuracy = 46.00% 
Epoch = 74, train accuracy = 52.00%, test accuracy = 48.00% 
Epoch = 75, train accuracy = 48.00%, test accuracy = 40.00% 
Epoch = 76, train accuracy = 48.00%, test accuracy = 38.00% 
Epoch = 77, train accuracy = 51.00%, test accuracy = 42.00% 
Epoch = 78, train accuracy = 45.00%, test accuracy = 40.00% 
Epoch = 79, train accuracy = 46.00%, test accuracy = 38.00% 
Epoch = 80, train accuracy = 51.00%, test accuracy = 42.00% 
Epoch = 81, train accuracy = 47.00%, test accuracy = 42.00% 
Epoch = 82, train accuracy = 53.00%, test accuracy = 44.00% 
Epoch = 83, train accuracy = 49.00%, test accuracy = 38.00% 
Epoch = 84, train accuracy = 49.00%, test accuracy = 38.00% 
Epoch = 85, train accuracy = 52.00%, test accuracy = 30.00% 
Epoch = 86, train accuracy = 49.00%, test accuracy = 36.00% 
Epoch = 87, train accuracy = 48.00%, test accuracy = 44.00% 
Epoch = 88, train accuracy = 46.00%, test accuracy = 40.00% 
Epoch = 89, train accuracy = 48.00%, test accuracy = 44.00% 
Epoch = 90, train accuracy = 50.00%, test accuracy = 34.00% 
Epoch = 91, train accuracy = 53.00%, test accuracy = 32.00% 
Epoch = 92, train accuracy = 51.00%, test accuracy = 40.00% 
Epoch = 93, train accuracy = 43.00%, test accuracy = 44.00% 
Epoch = 94, train accuracy = 48.00%, test accuracy = 40.00% 
Epoch = 95, train accuracy = 50.00%, test accuracy = 44.00% 
Epoch = 96, train accuracy = 48.00%, test accuracy = 38.00% 
Epoch = 97, train accuracy = 50.00%, test accuracy = 50.00% 
Epoch = 98, train accuracy = 47.00%, test accuracy = 46.00% 
Epoch = 99, train accuracy = 52.00%, test accuracy = 40.00% 
Epoch = 100, train accuracy = 50.00%, test accuracy = 36.00% 

Для сравнения, вот выход из набора данных радужки:

Epoch = 1, train accuracy = 71.00%, test accuracy = 68.00% 
Epoch = 2, train accuracy = 81.00%, test accuracy = 80.00% 
Epoch = 3, train accuracy = 80.00%, test accuracy = 90.00% 
Epoch = 4, train accuracy = 87.00%, test accuracy = 92.00% 
Epoch = 5, train accuracy = 91.00%, test accuracy = 86.00% 
Epoch = 6, train accuracy = 90.00%, test accuracy = 94.00% 
Epoch = 7, train accuracy = 96.00%, test accuracy = 92.00% 
Epoch = 8, train accuracy = 89.00%, test accuracy = 86.00% 
Epoch = 9, train accuracy = 92.00%, test accuracy = 92.00% 
Epoch = 10, train accuracy = 92.00%, test accuracy = 94.00% 
Epoch = 11, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 12, train accuracy = 94.00%, test accuracy = 96.00% 
Epoch = 13, train accuracy = 93.00%, test accuracy = 94.00% 
Epoch = 14, train accuracy = 95.00%, test accuracy = 90.00% 
Epoch = 15, train accuracy = 94.00%, test accuracy = 94.00% 
Epoch = 16, train accuracy = 98.00%, test accuracy = 96.00% 
Epoch = 17, train accuracy = 93.00%, test accuracy = 94.00% 
Epoch = 18, train accuracy = 92.00%, test accuracy = 98.00% 
Epoch = 19, train accuracy = 94.00%, test accuracy = 100.00% 
Epoch = 20, train accuracy = 94.00%, test accuracy = 96.00% 
Epoch = 21, train accuracy = 96.00%, test accuracy = 96.00% 
Epoch = 22, train accuracy = 97.00%, test accuracy = 100.00% 
Epoch = 23, train accuracy = 95.00%, test accuracy = 92.00% 
Epoch = 24, train accuracy = 97.00%, test accuracy = 100.00% 
Epoch = 25, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 26, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 27, train accuracy = 96.00%, test accuracy = 94.00% 
Epoch = 28, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 29, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 30, train accuracy = 95.00%, test accuracy = 98.00% 
Epoch = 31, train accuracy = 97.00%, test accuracy = 100.00% 
Epoch = 32, train accuracy = 96.00%, test accuracy = 100.00% 
Epoch = 33, train accuracy = 95.00%, test accuracy = 98.00% 
Epoch = 34, train accuracy = 95.00%, test accuracy = 96.00% 
Epoch = 35, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 36, train accuracy = 97.00%, test accuracy = 100.00% 
Epoch = 37, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 38, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 39, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 40, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 41, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 42, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 43, train accuracy = 98.00%, test accuracy = 96.00% 
Epoch = 44, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 45, train accuracy = 94.00%, test accuracy = 100.00% 
Epoch = 46, train accuracy = 98.00%, test accuracy = 98.00% 
Epoch = 47, train accuracy = 97.00%, test accuracy = 100.00% 
Epoch = 48, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 49, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 50, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 51, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 52, train accuracy = 97.00%, test accuracy = 94.00% 
Epoch = 53, train accuracy = 96.00%, test accuracy = 100.00% 
Epoch = 54, train accuracy = 93.00%, test accuracy = 98.00% 
Epoch = 55, train accuracy = 94.00%, test accuracy = 98.00% 
Epoch = 56, train accuracy = 96.00%, test accuracy = 96.00% 
Epoch = 57, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 58, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 59, train accuracy = 96.00%, test accuracy = 94.00% 
Epoch = 60, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 61, train accuracy = 97.00%, test accuracy = 96.00% 
Epoch = 62, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 63, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 64, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 65, train accuracy = 96.00%, test accuracy = 100.00% 
Epoch = 66, train accuracy = 95.00%, test accuracy = 98.00% 
Epoch = 67, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 68, train accuracy = 97.00%, test accuracy = 96.00% 
Epoch = 69, train accuracy = 93.00%, test accuracy = 96.00% 
Epoch = 70, train accuracy = 95.00%, test accuracy = 98.00% 
Epoch = 71, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 72, train accuracy = 97.00%, test accuracy = 96.00% 
Epoch = 73, train accuracy = 98.00%, test accuracy = 98.00% 
Epoch = 74, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 75, train accuracy = 99.00%, test accuracy = 98.00% 
Epoch = 76, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 77, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 78, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 79, train accuracy = 98.00%, test accuracy = 96.00% 
Epoch = 80, train accuracy = 95.00%, test accuracy = 98.00% 
Epoch = 81, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 82, train accuracy = 96.00%, test accuracy = 96.00% 
Epoch = 83, train accuracy = 95.00%, test accuracy = 98.00% 
Epoch = 84, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 85, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 86, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 87, train accuracy = 97.00%, test accuracy = 100.00% 
Epoch = 88, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 89, train accuracy = 97.00%, test accuracy = 96.00% 
Epoch = 90, train accuracy = 98.00%, test accuracy = 98.00% 
Epoch = 91, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 92, train accuracy = 97.00%, test accuracy = 96.00% 
Epoch = 93, train accuracy = 97.00%, test accuracy = 100.00% 
Epoch = 94, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 95, train accuracy = 97.00%, test accuracy = 96.00% 
Epoch = 96, train accuracy = 96.00%, test accuracy = 100.00% 
Epoch = 97, train accuracy = 96.00%, test accuracy = 98.00% 
Epoch = 98, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 99, train accuracy = 97.00%, test accuracy = 98.00% 
Epoch = 100, train accuracy = 96.00%, test accuracy = 98.00% 
+0

Сеть, которая работает с одним набором данных, не означает, что она будет работать с другим набором данных. Вам необходимо настроить учебную способность модели (путем добавления или удаления слоев). –

+0

Спасибо, я увеличил количество слоев, и теперь получаю гораздо лучшие результаты. Я бы дал вам баллы за ваш комментарий, но я не могу. – brtknr

+0

Я обращусь к ответу. –

ответ

0

сеть, которая работает с одним набором данных Безразлично» t означает, что будет работать с другим набором данных. Каждая сеть спроектирована с определенной учебной способностью (приблизительно с количеством параметров в сети) для соответствия заданной задаче или набору данных.

Если образовательная способность в модели слишком низкая, тогда модель будет недостаточно, и это то, что вы ищете. Слишком большой учебный потенциал и модель будут перекуплены.

Вам необходимо настроить учебную способность модели (путем добавления или удаления слоев). В этом случае кажется, что вы недофинансируете, так что добавление слоев поможет.