2016-01-13 1 views
4

я сделал следующую реализацию медианы в C++ и и использовал его в R через Rcpp:Почему стандартная медианная функция R настолько медленнее, чем простая альтернатива C++?

// [[Rcpp::export]] 
double median2(std::vector<double> x){ 
    double median; 
    size_t size = x.size(); 
    sort(x.begin(), x.end()); 
    if (size % 2 == 0){ 
     median = (x[size/2 - 1] + x[size/2])/2.0; 
    } 
    else { 
     median = x[size/2]; 
    } 
    return median; 
} 

Если я впоследствии сравнить производительность с стандартным встроенным в R срединной функции, я получаю следующие результаты через microbenchmark

> x = rnorm(100) 
> microbenchmark(median(x),median2(x)) 
Unit: microseconds 
     expr min  lq  mean median  uq  max neval 
    median(x) 25.469 26.990 34.96888 28.130 29.081 518.126 100 
median2(x) 1.140 1.521 2.47486 1.901 2.281 47.897 100 

Почему стандартная медианная функция настолько медленнее? Это не то, что я ожидал бы ...

+3

Для начала рассмотрим все, что на самом деле выполняет 'median.default', а затем попробуйте проверить что-то более справедливое. – joran

+0

Итак, я думаю, что это из-за всего, но на самом деле вычисление медианы не требует времени. – Ruben

+3

Как в стороне, сортировка вектора излишне. Вы не заботитесь о упорядочении первых элементов n/2 - вам просто интересно, что такое n/2-й элемент. Алгоритм 'std :: nth_element' будет делать это быстрее, чем сортировка. Он может быть легко и эффективно реализован с использованием рекурсивного медиана медиана 5 и раздела с альтернативным алгоритмом короткой длины, если вы хотите его в r. Во-вторых, используйте явный 'std :: sort' на итераторах' std :: vector' (нет гарантии, что они определены в 'namespace std', на который опирается ваш код). – Yakk

ответ

11

Как отметил @joran, ваш код очень специализирован и, вообще говоря, менее обобщенные функции, алгоритмы и т. Д. Часто более эффективны. Посмотрите на median.default:

median.default 
# function (x, na.rm = FALSE) 
# { 
# if (is.factor(x) || is.data.frame(x)) 
#  stop("need numeric data") 
# if (length(names(x))) 
#  names(x) <- NULL 
# if (na.rm) 
#  x <- x[!is.na(x)] 
# else if (any(is.na(x))) 
#  return(x[FALSE][NA]) 
# n <- length(x) 
# if (n == 0L) 
#  return(x[FALSE][NA]) 
# half <- (n + 1L)%/%2L 
# if (n%%2L == 1L) 
#  sort(x, partial = half)[half] 
# else mean(sort(x, partial = half + 0L:1L)[half + 0L:1L]) 
# } 

Есть несколько операций в месте размещения возможности пропущенных значений, и это, безусловно, влияет на общее время выполнения функции. Поскольку ваша функция не повторить это поведение может устранить кучу вычислений, но, следовательно, не будет предоставлять один и тот же результат для векторов с отсутствующими значениями:

median(c(1, 2, NA)) 
#[1] NA 

median2(c(1, 2, NA)) 
#[1] 2 

несколько других факторов, которые, вероятно, не имеют столько от эффекта, как при обработке NA с, но стоит отметить:

  • median, наряду с несколькими функциями, которые она использует, являются S3 дженерики, так что есть небольшой количество времени, затраченного на отправку метода
  • median будет работать только с целыми и числовыми векторами; он также будет обрабатывать Date, POSIXt, и, вероятно, куча других классов, и правильно сохранять атрибуты:

median(Sys.Date() + 0:4) 
#[1] "2016-01-15" 

median(Sys.time() + (0:4) * 3600 * 24) 
#[1] "2016-01-15 11:14:31 EST" 

Edit: Следует отметить, что функция ниже будет вызвать сортировку исходного вектора с NumericVector s являются прокси-объектами. Если вы хотите этого избежать, вы можете ввести Rcpp::clone входной вектор и работать с клоном или использовать свою оригинальную подпись (с std::vector<double>), которая неявно требует копирования при преобразовании от SEXP до std::vector.

Также обратите внимание, что вы можете сбрить немного больше времени с помощью NumericVector вместо std::vector<double>:

#include <Rcpp.h> 

// [[Rcpp::export]] 
double cpp_med(Rcpp::NumericVector x){ 
    std::size_t size = x.size(); 
    std::sort(x.begin(), x.end()); 
    if (size % 2 == 0) return (x[size/2 - 1] + x[size/2])/2.0; 
    return x[size/2]; 
} 

microbenchmark::microbenchmark(
    median(x), 
    median2(x), 
    cpp_med(x), 
    times = 200L 
) 
# Unit: microseconds 
#  expr min  lq  mean median  uq  max neval 
# median(x) 74.787 81.6485 110.09870 92.5665 129.757 293.810 200 
# median2(x) 6.474 7.9665 13.90126 11.0570 14.844 151.817 200 
# cpp_med(x) 5.737 7.4285 11.25318 9.0270 13.405 52.184 200 

Yakk воспитал большую точку в комментариях выше - также разработанный Джерри Коффином - о неэффективности делать полный сорт. Вот переписан с использованием std::nth_element, протестированные на гораздо больший вектор: [. Это более длительный комментарий, чем ответ на вопрос, который вы на самом деле спросил]

#include <Rcpp.h> 

// [[Rcpp::export]] 
double cpp_med2(Rcpp::NumericVector xx) { 
    Rcpp::NumericVector x = Rcpp::clone(xx); 
    std::size_t n = x.size()/2; 
    std::nth_element(x.begin(), x.begin() + n, x.end()); 

    if (x.size() % 2) return x[n]; 
    return (x[n] + *std::max_element(x.begin(), x.begin() + n))/2.; 
} 

set.seed(123) 
xx <- rnorm(10e5) 

all.equal(cpp_med2(xx), median(xx)) 
all.equal(median2(xx), median(xx)) 

microbenchmark::microbenchmark(
    cpp_med2(xx), median2(xx), 
    median(xx), times = 200L 
) 
# Unit: milliseconds 
#   expr  min  lq  mean median  uq  max neval 
# cpp_med2(xx) 10.89060 11.34894 13.15313 12.72861 13.56161 33.92103 200 
# median2(xx) 84.29518 85.47184 88.57361 86.05363 87.70065 228.07301 200 
# median(xx) 46.18976 48.36627 58.77436 49.31659 53.46830 250.66939 200 
+4

Мне было бы любопытно, что произойдет, если вы использовали только последние четыре строки 'median.default' и заменили' mean() 'на' .Internal (среднее (среднее значение))) '. Я бы предположил, что это будет очень близко к 'median2', может быть, даже быстрее. – joran

+2

... поэтому, проверив, что это определенно не так быстро, как 'median2', но он намного ближе. – joran

+1

@joran Это, вероятно, стоит проверить на большом векторе; когда я сравнивал «median.default» с двумя версиями C++ на векторе «rnorm (1e5)», тайминги были намного ближе. – nrussell

0

Я не уверен, какую «стандартную» реализацию вы бы имели в виду.

В любом случае: если бы он был одним из них, он был бы частью стандартной библиотеки, конечно же, не был бы позволен изменять порядок элементов в векторе (как это делает ваша реализация), поэтому ему определенно придется работать над копия.

Для создания этой копии потребуется время и процессор (и значительная память), что повлияет на время выполнения.

+2

Код C++ также создает копию, поэтому время копирования должно быть примерно одинаковым. – NathanOliver

+1

Он передает вектор по * значению *, а не по ссылке const. –

+0

Я имею в виду медианную функцию пакета статистики (стандартный пакет). Спасибо, что заметили, что я изменил переменную x, я этого не заметил. edit: это пропуск по значению, поэтому сделана копия – Ruben

2

Даже ваш код может быть открыт для значительного улучшения. В частности, вы сортируете весь ввод, даже если вам нужен только один или два элемента.

Вы можете изменить это значение с O (n log n) на O (n), используя std::nth_element вместо std::sort. В случае четного количества элементов вы обычно хотите использовать std::nth_element, чтобы найти элемент непосредственно перед посещением, затем используйте std::min_element, чтобы найти сразу следующий элемент - но std::nth_element также разделяет входные элементы, поэтому std::min_element имеет только для запуска над элементами выше середины после nth_element, а не для всего массива ввода. То есть, после того, как nth_element, вы получите ситуацию, как это:

enter image description here

Сложность std::nth_element является «линейная в среднем», и (конечно) std::min_element является линейным, так что общая сложность линейна ,

Итак, для простого случая (нечетное число элементов), вы получите что-то вроде:

auto pos = x.begin() + x.size()/2; 

std::nth_element(x.begin(), pos, x.end()); 
return *pos; 

... и для более сложного случая (четное число элементов):

std::nth_element(x.begin(), pos, x.end()); 
auto pos2 = std::min_element(pos+1, x.end()); 
return (*pos + *pos2)/2.0;