2016-12-31 9 views
0

Я не знаком с машинным обучением, и я работаю на простом примере, как практика. Я генерирую пары чисел и присваиваю метку «1», если оба значения четные или нечетные, а метка «0», если одна четная, а другая - нечетная.Использование Keras для прогнозирования того, имеют ли два числа одинаковые «странности» с использованием вложения, я нахожусь на правильном пути?

Моя модель иногда приближается к ~ 75%, но я уверен, что это просто запоминание, какие пары чисел приводят к 1 и которые приводят к 0. Я хочу, чтобы модель узнала, что существуют две категории чисел, например, зная, что [1, 4] = 0 и [1, 6] = 0, поэтому 4 и 6 находятся в одной и той же категории.

Есть ли я на правильном пути? Это даже разумная проблема для решения проблемы с ML?

Вот мой код:

num_examples = 500000 
input_dim = 300 

# Generate the randomized training data 
data = (input_dim * numpy.random.random((num_examples, 2))).astype(int) 
labels = [] 

# Generate the correct labels for the training data 
for example in data: 
    if example[0] % 2 == example[1] % 2: 
     labels.append([1, 0]) 
    else: 
     labels.append([0, 1]) 


model = Sequential() 
model.add(Embedding(input_dim, 70, input_length=2)) 
model.add(Flatten()) 
model.add(Dense(20)) 
model.add(Dense(2, activation='sigmoid')) 

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 
model.fit(data, labels, nb_epoch=900, batch_size=10000) 
+0

Я запустил ваш код, и я получаю в основном 50-процентную точность. Вы уверены, что видите 75%? – gobrewers14

+0

Да, примерно в половине случаев, когда он никогда не достигает 50%, а в других случаях он достигает уровня 75%. – TheN33k

ответ

0

Эта модель не будет иметь возможности обобщать невидимые числа, так как вложение слой назначает для каждого номера функции вектора. если число не находится в каком-либо учебном примере, оно никогда не будет обновляться и будет сохранять свое первоначальное случайное значение. Я думаю, что обеспечение чисел в двоичной форме (например, поставка «9» в виде 8 входов значения 00001001) поможет. во всяком случае, это грязный хак, на мой взгляд, так как теперь все сети должны учиться определять, когда последний бит каждого входа равен.