2016-12-28 6 views
1

Я просто пытаюсь построить двух гауссиан и найти точку пересечения. У меня есть следующий код. Это не означает точное пересечение, и я действительно не могу понять, почему. Это всего лишь чуть-чуть, но я работал через производное решение, если мы взяли журнал вычитаемых гауссианцев, и да, похоже, что это должно быть правильно. Может ли кто-нибудь помочь? Спасибо огромное!Пересечение между гауссовскими

import numpy as np 
import matplotlib.pyplot as plt 

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

# found online 
def solve_gasussians(m1, s1, m2, s2): 
    a = 1.0/(2.0*s1**2) - 1.0/(2.0*s2**2) 
    b = m2/(s2**2) - m1/(s1**2) 
    c = m1**2 /(2*s1**2) - m2**2/(2.0*s2**2) - np.log(s2/s1) 
    return np.roots([a,b,c]) 

s1 = np.linspace(0, 10,300) 
s2 = np.linspace(0, 14, 300) 

solved_val = solve_gasussians(5.0, 0.5, 7.0, 1.0) 
print solved_val 
solved_val = solved_val[0] 
plt.figure('Baseline Distributions') 
plt.title('Baseline Distributions') 
plt.xlabel('Response Rate') 
plt.ylabel('Probability') 
plt.plot(s1, plot_normal(s1, 5.0, 0.5),'r', label='s1') 
plt.plot(s2, plot_normal(s2, 7.0, 1.0),'b', label='s2') 
plt.plot(solved_val, plot_normal(solved_val, 7.0, 1.0), 'mo') 
plt.legend() 
plt.show() 
+0

Можете ли вы указать нам на решение, которое вы нашли в Интернете, так что мы не нужно пытаться извлечь его для себя? –

+0

Я думаю, что решение, на которое они ссылаются, может быть этим [SO question] (http://stackoverflow.com/a/22579904/752843). Поэтому мы не можем полностью обвинить их в ужасном отсутствии комментариев в коде. – Richard

+0

@ Рихард, вот что я думаю тоже. –

ответ

0

У вас есть небольшая ошибка в plot_normal функции - вы упускаете квадратный корень в знаменателе. Правильный вариант:

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

дает ожидаемый результат: enter image description here

И два замечания.

  1. Помните, что вы можете иметь 2 корня уравнения в целом (две точки пересечения), и это имеет место с параметрами, которые вы предоставили.
  2. Насколько я знаю np.roots дает приблизительный результат, но кот получить точный результат легко, переписав solve_gasussians функцию:

    def solve_gasussians(m1, s1, m2, s2): 
        # coefficients of quadratic equation ax^2 + bx + c = 0 
        a = (s1**2.0) - (s2**2.0) 
        b = 2 * (m1 * s2**2.0 - m2 * s1**2.0) 
        c = m2**2.0 * s1**2.0 - m1**2.0 * s2**2.0 - 2 * s1**2.0 * s2**2.0 * np.log(s1/s2) 
        x1 = (-b + np.sqrt(b**2.0 - 4.0 * a * c))/(2.0 * a) 
        x2 = (-b - np.sqrt(b**2.0 - 4.0 * a * c))/(2.0 * a) 
        return x1, x2 
    
0

Я не знаю, где ошибка в вашем коде. Но я думаю, что нашел код, заимствованный у вас, и сделал часть необходимой вам корректировки.

import numpy as np 
import matplotlib.pyplot as plt 
from scipy.stats import norm 

def solve(m1,m2,std1,std2): 
    a = 1/(2*std1**2) - 1/(2*std2**2) 
    b = m2/(std2**2) - m1/(std1**2) 
    c = m1**2 /(2*std1**2) - m2**2/(2*std2**2) - np.log(std2/std1) 
    return np.roots([a,b,c]) 

m1 = 5 
std1 = 0.5 
m2 = 7 
std2 = 1 

result = solve(m1,m2,std1,std2) 

x = np.linspace(-5,9,10000) 
plot1=plt.plot(x,[norm.pdf(_,m1,std1) for _ in x]) 
plot2=plt.plot(x,[norm.pdf(_,m2,std2) for _ in x]) 
plot3=plt.plot(result[0],norm.pdf(result[0],m1,std1) ,'o') 

plt.show() 

я предложу две части нежелательных советов, которые могли бы сделать жизнь проще для вас (так, как они делают для меня):

  • При адаптации кода попытаться сделать небольшие, постепенные изменения и убедитесь, что код все еще работает на каждом шаге.
  • Ищите существующие бесплатные библиотеки. В этом случае норма от scipy является хорошей заменой тому, что использовалось в исходном коде.
0

Ошибка здесь. Эта линия:

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

Должно быть так:

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

Вы забыли sqrt.

Было бы разумнее использовать уже существующий обычный PDF, если это доступно, например:

import scipy.stats 
def plot_normal(x, mean = 0, sigma = 1): 
    return scipy.stats.norm.pdf(x,loc=mean,scale=sigma) 

Это также можно решить для пересечения точно. This answer дает квадратичное уравнение для корней гауссовых пересечений. Использование максимумов для решения для x дает следующее выражение. Который, хотя и сложный, не полагается на итеративные методы и может автоматически генерироваться из более простых выражений.

def solve_gaussians(m1,s1,m2,s2): 
    x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2) 
    x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2) 
    return x1,x2 

Собираем в целом дает:

import numpy as np 
import matplotlib.pyplot as plt 
import scipy.stats 

def plot_normal(x, mean = 0, sigma = 1): 
    return scipy.stats.norm.pdf(x,loc=mean,scale=sigma) 

#Use the equation from [this answer](https://stats.stackexchange.com/a/12213/12116) solved for x 
def solve_gaussians(m1,s1,m2,s2): 
    x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2) 
    x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2) 
    return x1,x2 

s = np.linspace(0, 14,300) 
x = solve_gaussians(5.0,0.5,7.0,1.0) 

plt.figure('Baseline Distributions') 
plt.title('Baseline Distributions') 
plt.xlabel('Response Rate') 
plt.ylabel('Probability') 
plt.plot(s, plot_normal(s, 5.0, 0.5),'r', label='s1') 
plt.plot(s, plot_normal(s, 7.0, 1.0),'b', label='s2') 
plt.plot(x[0],plot_normal(x[0],5.,0.5),'mo') 
plt.plot(x[1],plot_normal(x[1],5.,0.5),'mo') 
plt.legend() 
plt.show() 

Отдает:

Intersection of Gaussians

 Смежные вопросы

  • Нет связанных вопросов^_^