You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
117 lines
3.6 KiB
C++
117 lines
3.6 KiB
C++
/*------------------------------------------------------------------------------
|
|
* Copyright (c) 2023 by Bai Bing (seread@163.com)
|
|
* S++ COPYING file for copying and redistribution conditions.
|
|
*
|
|
* Alians IT Studio.
|
|
*----------------------------------------------------------------------------*/
|
|
#pragma once
|
|
|
|
#include <complex>
#include <numeric>
#include <string>

#include "core/Error.h"
#include "ASMatrix.h"
|
|
|
|
namespace ais
|
|
{
|
|
//===========================================================================
|
|
// Method Description:
|
|
/// Compute the mean along the specified axis.
|
|
///
|
|
/// @param array
|
|
/// @param axis (Optional, default NONE)
|
|
///
|
|
/// @return Matrix
|
|
///
|
|
template <typename dtype>
|
|
Matrix<double> mean(const Matrix<dtype> &array, Axis axis = Axis::NONE)
|
|
{
|
|
STATIC_ASSERT_ARITHMETIC(dtype);
|
|
|
|
switch (axis)
|
|
{
|
|
case Axis::NONE:
|
|
{
|
|
auto sum = std::accumulate(array.cbegin(), array.cend(), 0.);
|
|
Matrix<double> returnArray = {sum /= static_cast<double>(array.size())};
|
|
|
|
return returnArray;
|
|
}
|
|
case Axis::COLUMN:
|
|
{
|
|
Matrix<double> returnArray(1, array.rows());
|
|
for (size_t row = 0; row < array.rows(); ++row)
|
|
{
|
|
auto sum = std::accumulate(array.cbegin(row), array.cend(row), 0.);
|
|
returnArray(0, row) = sum / static_cast<double>(array.columns());
|
|
}
|
|
|
|
return returnArray;
|
|
}
|
|
case Axis::ROW:
|
|
{
|
|
return mean(array.transpose(), Axis::COLUMN);
|
|
}
|
|
default:
|
|
{
|
|
THROW_INVALID_ARGUMENT("Unimplemented axis type.");
|
|
return {};
|
|
}
|
|
}
|
|
}
|
|
|
|
//============================================================================
|
|
// Method Description:
|
|
/// Compute the mean along the specified axis.
|
|
///
|
|
/// @param array
|
|
/// @param axis (Optional, default NONE)
|
|
///
|
|
/// @return Matrix
|
|
///
|
|
template <typename dtype>
|
|
Matrix<std::complex<double>> mean(const Matrix<std::complex<dtype>> &array, Axis axis = Axis::NONE)
|
|
{
|
|
STATIC_ASSERT_ARITHMETIC(dtype);
|
|
|
|
switch (axis)
|
|
{
|
|
case Axis::NONE:
|
|
{
|
|
auto sum = std::accumulate(array.cbegin(), array.cend(), std::complex<double>(0.));
|
|
Matrix<std::complex<double>> returnArray = {sum /= std::complex<double>(array.size())};
|
|
|
|
return returnArray;
|
|
}
|
|
case Axis::COLUMN:
|
|
{
|
|
Matrix<std::complex<double>> returnArray(1, array.rows());
|
|
for (size_t row = 0; row < array.rows(); ++row)
|
|
{
|
|
auto sum = std::accumulate(array.cbegin(row), array.cend(row), std::complex<double>(0.));
|
|
returnArray(0, row) = sum / std::complex<double>(array.columns());
|
|
}
|
|
|
|
return returnArray;
|
|
}
|
|
case Axis::ROW:
|
|
{
|
|
Matrix<std::complex<double>> transposedArray = array.transpose();
|
|
Matrix<std::complex<double>> returnArray(1, transposedArray.rows());
|
|
for (size_t row = 0; row < transposedArray.rows(); ++row)
|
|
{
|
|
auto sum = std::accumulate(transposedArray.cbegin(row),
|
|
transposedArray.cend(row),
|
|
std::complex<double>(0.));
|
|
returnArray(0, row) = sum / std::complex<double>(transposedArray.columns());
|
|
}
|
|
|
|
return returnArray;
|
|
}
|
|
default:
|
|
{
|
|
THROW_INVALID_ARGUMENT("Unimplemented axis type.");
|
|
return {}; // get rid of compiler warning
|
|
}
|
|
}
|
|
}
|
|
} |