Я изучаю нейронные сети и нашёл следующий код, который хорошо работает на наборе данных Iris (ирисы Фишера), достигая 99% точности. Как повысить точность нейронной сети для случайного массива X, где y — это modulus(X, 2) (остаток от деления)?
Однако когда я запускаю ту же сеть на искусственном наборе данных, случайно сгенерированном функцией get_modulus_data(),
в котором выход зависит от входа так, как показано ниже, она работает заметно хуже. Может ли кто-нибудь объяснить, почему нейронная сеть плохо справляется с задачей такого типа?
# Implementation of a simple MLP network with two hidden layers. Tested on the iris data set.
# Requires: numpy, sklearn, tensorflow
# NOTE: In order to make the code simple, we rewrite x * W_1 + b_1 = x' * W_1'
# where x' = [x | 1] and W_1' is the matrix W_1 appended with a new row with elements b_1's.
# Similarly, for h * W_2 + b_2
import tensorflow as tf
import numpy as np
from sklearn import datasets
from sklearn.cross_validation import train_test_split
# Single seed shared by TensorFlow's random ops and the train/test split
# (it is passed as random_state to train_test_split in the data loaders).
RANDOM_SEED = 42
# NOTE: only TensorFlow is seeded here; NumPy's global generator is not.
tf.set_random_seed(RANDOM_SEED)
def init_weights(shape):
    """Create a trainable weight variable of the requested shape.

    The initial values are drawn from ``tf.random_normal`` with a standard
    deviation of 0.1.

    Args:
        shape: Shape of the weight matrix, e.g. ``(n_inputs, n_outputs)``.

    Returns:
        A ``tf.Variable`` holding the randomly initialised weights.
    """
    return tf.Variable(tf.random_normal(shape, stddev=0.1))
def forwardprop(X, w_1, w_2, w_3, keep_prob=0.5):
    """Build the forward pass of the two-hidden-layer MLP.

    IMPORTANT: yhat is not softmax since TensorFlow's
    softmax_cross_entropy_with_logits() does that internally.

    Args:
        X: Input tensor of shape ``[batch, n_inputs]``.
        w_1: Weights of the first hidden layer, ``[n_inputs, h]``.
        w_2: Weights of the second hidden layer, ``[h, h2]``.
        w_3: Weights of the output layer, ``[h2, n_classes]``.
        keep_prob: Probability of keeping each hidden activation in the
            dropout layers. May be a Python float or a scalar tensor /
            placeholder. The default of 0.5 reproduces the previous
            hard-coded behaviour; pass 1.0 (or feed 1.0 through a
            placeholder) to disable dropout when evaluating the network.

    Returns:
        The un-normalised class scores (logits), ``[batch, n_classes]``.
    """
    h_1 = tf.nn.elu(tf.matmul(X, w_1))  # The \sigma function
    h_1 = tf.nn.dropout(h_1, keep_prob)
    h_2 = tf.nn.elu(tf.matmul(h_1, w_2))  # The \sigma function
    h_2 = tf.nn.dropout(h_2, keep_prob)
    yhat = tf.matmul(h_2, w_3)  # The \varphi function
    return yhat
def get_iris_data():
    """Load the iris data set and return a train/test split of it.

    The feature matrix gets a leading column of ones so that the bias can be
    folded into the first weight matrix, and the integer class labels are
    expanded into one-hot rows.

    Returns:
        The result of ``train_test_split`` on the features and one-hot
        labels (33% held out for testing, split seeded with RANDOM_SEED);
        main() unpacks it as ``train_X, test_X, train_y, test_y``.
    """
    dataset = datasets.load_iris()
    features, labels = dataset["data"], dataset["target"]

    # Bias trick: put a constant-1 column in front of the measurements.
    n_rows = features.shape[0]
    all_X = np.hstack((np.ones((n_rows, 1)), features))

    # One-hot encode by picking rows of an identity matrix.
    identity = np.eye(len(np.unique(labels)))
    all_Y = identity[labels]

    return train_test_split(all_X, all_Y, test_size=0.33, random_state=RANDOM_SEED)
def get_modulus_data(n_samples=150, n_features=5, modulus=3):
    """Generate a synthetic "sum of bits modulo k" classification data set.

    Every sample is a row of random 0/1 features; its class is the sum of
    those features modulo ``modulus``.

    Args:
        n_samples: Number of rows to generate.
        n_features: Number of random binary features per row.
        modulus: The modulus applied to the row sum; also the number of
            classes in the one-hot encoding.

    Returns:
        The result of ``train_test_split`` on the features and one-hot
        labels (33% held out for testing, split seeded with RANDOM_SEED);
        main() unpacks it as ``train_X, test_X, train_y, test_y``.
        The feature matrices have ``n_features + 1`` columns because of the
        prepended bias column.
    """
    bits = np.round(np.random.random((n_samples, n_features)))
    target = np.int_(bits.sum(axis=1) % modulus)

    # Prepend the column of 1s for bias, exactly as get_iris_data() does.
    # The network has no explicit bias terms (see the NOTE at the top of the
    # file), so without this column the model cannot learn any offset.
    all_X = np.ones((n_samples, n_features + 1))
    all_X[:, 1:] = bits

    # Convert into one-hot vectors. The width is fixed to `modulus` instead of
    # the number of classes that happen to occur in the sample, so a missing
    # class can neither shrink the encoding nor cause an out-of-range index.
    all_Y = np.eye(modulus)[target]
    return train_test_split(all_X, all_Y, test_size=0.33, random_state=RANDOM_SEED)
def main():
    """Train the MLP on the modulus data set and print accuracy after each epoch.

    Builds the graph (placeholders, three weight matrices, forward pass,
    softmax cross-entropy loss, plain gradient descent), then runs 300 epochs
    of sliding-window mini-batch updates and reports train/test accuracy.
    """
    # Swap the two lines below to train on the iris data set instead.
    # train_X, test_X, train_y, test_y = get_iris_data()
    train_X, test_X, train_y, test_y = get_modulus_data()

    # Layer's sizes
    x_size = train_X.shape[1]  # Number of input nodes (columns of the feature matrix)
    h_size = 256               # Number of hidden nodes in the first hidden layer
    y_size = train_y.shape[1]  # Number of outcomes (columns of the one-hot labels)
    batch_size = 10            # Rows fed per gradient step

    # Symbols
    X = tf.placeholder("float", shape=[None, x_size])
    y = tf.placeholder("float", shape=[None, y_size])

    # Weight initializations
    w_1 = init_weights((x_size, h_size))
    w_2 = init_weights((h_size, h_size // 2))
    w_3 = init_weights((h_size // 2, y_size))

    # Forward propagation
    yhat = forwardprop(X, w_1, w_2, w_3)
    # The axis is passed positionally: its keyword was renamed from
    # `dimension` to `axis` between TensorFlow releases.
    predict = tf.argmax(yhat, 1)

    # Backward propagation
    # Keyword arguments are required by TensorFlow >= 1.0 and also guard
    # against the logits/labels order being swapped.
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=yhat, labels=y))
    updates = tf.train.GradientDescentOptimizer(0.01).minimize(cost)

    # Run SGD. The context manager closes the session even if training raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(300):
            # Slide a window of `batch_size` rows over the training set one
            # row at a time. The "+ 1" makes the final window end on the last
            # row, which the previous bound silently skipped.
            for i in range(len(train_X) - batch_size + 1):
                sess.run(updates, feed_dict={X: train_X[i: i + batch_size],
                                             y: train_y[i: i + batch_size]})

            # NOTE: forwardprop() as called here applies dropout, so these
            # accuracies are measured with dropout active and fluctuate.
            train_accuracy = np.mean(
                np.argmax(train_y, axis=1) ==
                sess.run(predict, feed_dict={X: train_X, y: train_y}))
            test_accuracy = np.mean(
                np.argmax(test_y, axis=1) ==
                sess.run(predict, feed_dict={X: test_X, y: test_y}))

            print("Epoch = %d, train accuracy = %.2f%%, test accuracy = %.2f%%"
                  % (epoch + 1, 100. * train_accuracy, 100. * test_accuracy))
# Run the training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Вот вывод программы:
Epoch = 1, train accuracy = 44.00%, test accuracy = 50.00%
Epoch = 2, train accuracy = 48.00%, test accuracy = 40.00%
Epoch = 3, train accuracy = 50.00%, test accuracy = 46.00%
Epoch = 4, train accuracy = 44.00%, test accuracy = 40.00%
Epoch = 5, train accuracy = 46.00%, test accuracy = 48.00%
Epoch = 6, train accuracy = 45.00%, test accuracy = 38.00%
Epoch = 7, train accuracy = 45.00%, test accuracy = 44.00%
Epoch = 8, train accuracy = 53.00%, test accuracy = 40.00%
Epoch = 9, train accuracy = 48.00%, test accuracy = 46.00%
Epoch = 10, train accuracy = 44.00%, test accuracy = 38.00%
Epoch = 11, train accuracy = 48.00%, test accuracy = 48.00%
Epoch = 12, train accuracy = 44.00%, test accuracy = 32.00%
Epoch = 13, train accuracy = 42.00%, test accuracy = 48.00%
Epoch = 14, train accuracy = 47.00%, test accuracy = 46.00%
Epoch = 15, train accuracy = 50.00%, test accuracy = 44.00%
Epoch = 16, train accuracy = 48.00%, test accuracy = 44.00%
Epoch = 17, train accuracy = 49.00%, test accuracy = 40.00%
Epoch = 18, train accuracy = 48.00%, test accuracy = 36.00%
Epoch = 19, train accuracy = 48.00%, test accuracy = 40.00%
Epoch = 20, train accuracy = 47.00%, test accuracy = 32.00%
Epoch = 21, train accuracy = 48.00%, test accuracy = 38.00%
Epoch = 22, train accuracy = 43.00%, test accuracy = 38.00%
Epoch = 23, train accuracy = 48.00%, test accuracy = 46.00%
Epoch = 24, train accuracy = 50.00%, test accuracy = 44.00%
Epoch = 25, train accuracy = 51.00%, test accuracy = 34.00%
Epoch = 26, train accuracy = 52.00%, test accuracy = 36.00%
Epoch = 27, train accuracy = 54.00%, test accuracy = 42.00%
Epoch = 28, train accuracy = 50.00%, test accuracy = 34.00%
Epoch = 29, train accuracy = 47.00%, test accuracy = 38.00%
Epoch = 30, train accuracy = 48.00%, test accuracy = 40.00%
Epoch = 31, train accuracy = 52.00%, test accuracy = 42.00%
Epoch = 32, train accuracy = 51.00%, test accuracy = 34.00%
Epoch = 33, train accuracy = 49.00%, test accuracy = 42.00%
Epoch = 34, train accuracy = 49.00%, test accuracy = 42.00%
Epoch = 35, train accuracy = 50.00%, test accuracy = 46.00%
Epoch = 36, train accuracy = 48.00%, test accuracy = 40.00%
Epoch = 37, train accuracy = 50.00%, test accuracy = 44.00%
Epoch = 38, train accuracy = 46.00%, test accuracy = 38.00%
Epoch = 39, train accuracy = 46.00%, test accuracy = 46.00%
Epoch = 40, train accuracy = 50.00%, test accuracy = 40.00%
Epoch = 41, train accuracy = 53.00%, test accuracy = 38.00%
Epoch = 42, train accuracy = 50.00%, test accuracy = 42.00%
Epoch = 43, train accuracy = 49.00%, test accuracy = 42.00%
Epoch = 44, train accuracy = 48.00%, test accuracy = 38.00%
Epoch = 45, train accuracy = 51.00%, test accuracy = 40.00%
Epoch = 46, train accuracy = 50.00%, test accuracy = 46.00%
Epoch = 47, train accuracy = 49.00%, test accuracy = 46.00%
Epoch = 48, train accuracy = 48.00%, test accuracy = 38.00%
Epoch = 49, train accuracy = 52.00%, test accuracy = 46.00%
Epoch = 50, train accuracy = 47.00%, test accuracy = 52.00%
Epoch = 51, train accuracy = 44.00%, test accuracy = 40.00%
Epoch = 52, train accuracy = 51.00%, test accuracy = 44.00%
Epoch = 53, train accuracy = 48.00%, test accuracy = 40.00%
Epoch = 54, train accuracy = 49.00%, test accuracy = 38.00%
Epoch = 55, train accuracy = 48.00%, test accuracy = 38.00%
Epoch = 56, train accuracy = 49.00%, test accuracy = 42.00%
Epoch = 57, train accuracy = 50.00%, test accuracy = 38.00%
Epoch = 58, train accuracy = 48.00%, test accuracy = 44.00%
Epoch = 59, train accuracy = 51.00%, test accuracy = 42.00%
Epoch = 60, train accuracy = 48.00%, test accuracy = 34.00%
Epoch = 61, train accuracy = 47.00%, test accuracy = 42.00%
Epoch = 62, train accuracy = 48.00%, test accuracy = 42.00%
Epoch = 63, train accuracy = 49.00%, test accuracy = 42.00%
Epoch = 64, train accuracy = 53.00%, test accuracy = 48.00%
Epoch = 65, train accuracy = 51.00%, test accuracy = 42.00%
Epoch = 66, train accuracy = 48.00%, test accuracy = 36.00%
Epoch = 67, train accuracy = 49.00%, test accuracy = 46.00%
Epoch = 68, train accuracy = 52.00%, test accuracy = 42.00%
Epoch = 69, train accuracy = 50.00%, test accuracy = 38.00%
Epoch = 70, train accuracy = 49.00%, test accuracy = 42.00%
Epoch = 71, train accuracy = 50.00%, test accuracy = 44.00%
Epoch = 72, train accuracy = 50.00%, test accuracy = 38.00%
Epoch = 73, train accuracy = 48.00%, test accuracy = 46.00%
Epoch = 74, train accuracy = 52.00%, test accuracy = 48.00%
Epoch = 75, train accuracy = 48.00%, test accuracy = 40.00%
Epoch = 76, train accuracy = 48.00%, test accuracy = 38.00%
Epoch = 77, train accuracy = 51.00%, test accuracy = 42.00%
Epoch = 78, train accuracy = 45.00%, test accuracy = 40.00%
Epoch = 79, train accuracy = 46.00%, test accuracy = 38.00%
Epoch = 80, train accuracy = 51.00%, test accuracy = 42.00%
Epoch = 81, train accuracy = 47.00%, test accuracy = 42.00%
Epoch = 82, train accuracy = 53.00%, test accuracy = 44.00%
Epoch = 83, train accuracy = 49.00%, test accuracy = 38.00%
Epoch = 84, train accuracy = 49.00%, test accuracy = 38.00%
Epoch = 85, train accuracy = 52.00%, test accuracy = 30.00%
Epoch = 86, train accuracy = 49.00%, test accuracy = 36.00%
Epoch = 87, train accuracy = 48.00%, test accuracy = 44.00%
Epoch = 88, train accuracy = 46.00%, test accuracy = 40.00%
Epoch = 89, train accuracy = 48.00%, test accuracy = 44.00%
Epoch = 90, train accuracy = 50.00%, test accuracy = 34.00%
Epoch = 91, train accuracy = 53.00%, test accuracy = 32.00%
Epoch = 92, train accuracy = 51.00%, test accuracy = 40.00%
Epoch = 93, train accuracy = 43.00%, test accuracy = 44.00%
Epoch = 94, train accuracy = 48.00%, test accuracy = 40.00%
Epoch = 95, train accuracy = 50.00%, test accuracy = 44.00%
Epoch = 96, train accuracy = 48.00%, test accuracy = 38.00%
Epoch = 97, train accuracy = 50.00%, test accuracy = 50.00%
Epoch = 98, train accuracy = 47.00%, test accuracy = 46.00%
Epoch = 99, train accuracy = 52.00%, test accuracy = 40.00%
Epoch = 100, train accuracy = 50.00%, test accuracy = 36.00%
Для сравнения, вот вывод для набора данных Iris:
Epoch = 1, train accuracy = 71.00%, test accuracy = 68.00%
Epoch = 2, train accuracy = 81.00%, test accuracy = 80.00%
Epoch = 3, train accuracy = 80.00%, test accuracy = 90.00%
Epoch = 4, train accuracy = 87.00%, test accuracy = 92.00%
Epoch = 5, train accuracy = 91.00%, test accuracy = 86.00%
Epoch = 6, train accuracy = 90.00%, test accuracy = 94.00%
Epoch = 7, train accuracy = 96.00%, test accuracy = 92.00%
Epoch = 8, train accuracy = 89.00%, test accuracy = 86.00%
Epoch = 9, train accuracy = 92.00%, test accuracy = 92.00%
Epoch = 10, train accuracy = 92.00%, test accuracy = 94.00%
Epoch = 11, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 12, train accuracy = 94.00%, test accuracy = 96.00%
Epoch = 13, train accuracy = 93.00%, test accuracy = 94.00%
Epoch = 14, train accuracy = 95.00%, test accuracy = 90.00%
Epoch = 15, train accuracy = 94.00%, test accuracy = 94.00%
Epoch = 16, train accuracy = 98.00%, test accuracy = 96.00%
Epoch = 17, train accuracy = 93.00%, test accuracy = 94.00%
Epoch = 18, train accuracy = 92.00%, test accuracy = 98.00%
Epoch = 19, train accuracy = 94.00%, test accuracy = 100.00%
Epoch = 20, train accuracy = 94.00%, test accuracy = 96.00%
Epoch = 21, train accuracy = 96.00%, test accuracy = 96.00%
Epoch = 22, train accuracy = 97.00%, test accuracy = 100.00%
Epoch = 23, train accuracy = 95.00%, test accuracy = 92.00%
Epoch = 24, train accuracy = 97.00%, test accuracy = 100.00%
Epoch = 25, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 26, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 27, train accuracy = 96.00%, test accuracy = 94.00%
Epoch = 28, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 29, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 30, train accuracy = 95.00%, test accuracy = 98.00%
Epoch = 31, train accuracy = 97.00%, test accuracy = 100.00%
Epoch = 32, train accuracy = 96.00%, test accuracy = 100.00%
Epoch = 33, train accuracy = 95.00%, test accuracy = 98.00%
Epoch = 34, train accuracy = 95.00%, test accuracy = 96.00%
Epoch = 35, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 36, train accuracy = 97.00%, test accuracy = 100.00%
Epoch = 37, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 38, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 39, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 40, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 41, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 42, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 43, train accuracy = 98.00%, test accuracy = 96.00%
Epoch = 44, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 45, train accuracy = 94.00%, test accuracy = 100.00%
Epoch = 46, train accuracy = 98.00%, test accuracy = 98.00%
Epoch = 47, train accuracy = 97.00%, test accuracy = 100.00%
Epoch = 48, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 49, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 50, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 51, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 52, train accuracy = 97.00%, test accuracy = 94.00%
Epoch = 53, train accuracy = 96.00%, test accuracy = 100.00%
Epoch = 54, train accuracy = 93.00%, test accuracy = 98.00%
Epoch = 55, train accuracy = 94.00%, test accuracy = 98.00%
Epoch = 56, train accuracy = 96.00%, test accuracy = 96.00%
Epoch = 57, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 58, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 59, train accuracy = 96.00%, test accuracy = 94.00%
Epoch = 60, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 61, train accuracy = 97.00%, test accuracy = 96.00%
Epoch = 62, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 63, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 64, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 65, train accuracy = 96.00%, test accuracy = 100.00%
Epoch = 66, train accuracy = 95.00%, test accuracy = 98.00%
Epoch = 67, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 68, train accuracy = 97.00%, test accuracy = 96.00%
Epoch = 69, train accuracy = 93.00%, test accuracy = 96.00%
Epoch = 70, train accuracy = 95.00%, test accuracy = 98.00%
Epoch = 71, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 72, train accuracy = 97.00%, test accuracy = 96.00%
Epoch = 73, train accuracy = 98.00%, test accuracy = 98.00%
Epoch = 74, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 75, train accuracy = 99.00%, test accuracy = 98.00%
Epoch = 76, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 77, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 78, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 79, train accuracy = 98.00%, test accuracy = 96.00%
Epoch = 80, train accuracy = 95.00%, test accuracy = 98.00%
Epoch = 81, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 82, train accuracy = 96.00%, test accuracy = 96.00%
Epoch = 83, train accuracy = 95.00%, test accuracy = 98.00%
Epoch = 84, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 85, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 86, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 87, train accuracy = 97.00%, test accuracy = 100.00%
Epoch = 88, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 89, train accuracy = 97.00%, test accuracy = 96.00%
Epoch = 90, train accuracy = 98.00%, test accuracy = 98.00%
Epoch = 91, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 92, train accuracy = 97.00%, test accuracy = 96.00%
Epoch = 93, train accuracy = 97.00%, test accuracy = 100.00%
Epoch = 94, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 95, train accuracy = 97.00%, test accuracy = 96.00%
Epoch = 96, train accuracy = 96.00%, test accuracy = 100.00%
Epoch = 97, train accuracy = 96.00%, test accuracy = 98.00%
Epoch = 98, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 99, train accuracy = 97.00%, test accuracy = 98.00%
Epoch = 100, train accuracy = 96.00%, test accuracy = 98.00%
То, что сеть работает на одном наборе данных, не означает, что она будет работать на другом. Вам необходимо настроить ёмкость (обучающую способность) модели, добавляя или удаляя слои. –
Спасибо, я увеличил количество слоёв и теперь получаю гораздо лучшие результаты. Я бы дал вам баллы за ваш комментарий, но не могу. – brtknr
Я оформлю это как ответ. –