You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

117 lines
3.6 KiB
C

1 month ago
/*------------------------------------------------------------------------------
* Copyright (c) 2023 by Bai Bing (seread@163.com)
* S++ COPYING file for copying and redistribution conditions.
*
* Alians IT Studio.
*----------------------------------------------------------------------------*/
#pragma once
#include <complex>
#include <numeric>
#include <string>
#include "core/Error.h"
#include "ASMatrix.h"
namespace ais
{
//===========================================================================
// Method Description:
/// Arithmetic mean of a real-valued matrix, either over every element or
/// along one axis. Sums are accumulated in double precision.
///
/// - Axis::NONE   : sum of all elements divided by array.size(); the result
///                  matrix is built from that single value.
/// - Axis::COLUMN : one mean per row of the input (row sum divided by
///                  array.columns()), stored in a matrix constructed as
///                  (1, array.rows()).
/// - Axis::ROW    : the input is transposed and reduced as Axis::COLUMN.
/// - Any other axis value is reported through THROW_INVALID_ARGUMENT.
///
/// @param array matrix to reduce; dtype must be an arithmetic type
/// @param axis (Optional, default NONE) axis to reduce along
///
/// @return Matrix<double> holding the mean value(s)
///
template <typename dtype>
Matrix<double> mean(const Matrix<dtype> &array, Axis axis = Axis::NONE)
{
    STATIC_ASSERT_ARITHMETIC(dtype);

    if (axis == Axis::NONE)
    {
        // Fold every element into a double accumulator, then average.
        const double total = std::accumulate(array.cbegin(), array.cend(), 0.);
        Matrix<double> overall = {total / static_cast<double>(array.size())};
        return overall;
    }

    if (axis == Axis::ROW)
    {
        // Reuse the per-row reduction below on the transposed input.
        return mean(array.transpose(), Axis::COLUMN);
    }

    if (axis == Axis::COLUMN)
    {
        Matrix<double> rowMeans(1, array.rows());
        for (size_t r = 0; r < array.rows(); ++r)
        {
            const double rowTotal = std::accumulate(array.cbegin(r), array.cend(r), 0.);
            rowMeans(0, r) = rowTotal / static_cast<double>(array.columns());
        }
        return rowMeans;
    }

    THROW_INVALID_ARGUMENT("Unimplemented axis type.");
    return {}; // keeps compilers quiet about a missing return after the throw
}
//============================================================================
// Method Description:
/// Compute the mean of a complex-valued matrix along the specified axis.
///
/// Each element is widened to std::complex<double> before it is added to the
/// running sum, so the element type does not have to be std::complex<double>
/// itself (std::complex has no operator+ for operands of different value
/// types).
///
/// - Axis::NONE   : sum of all elements divided by array.size(); the result
///                  matrix is built from that single value.
/// - Axis::COLUMN : one mean per row of the input (row sum divided by
///                  array.columns()), stored in a matrix constructed as
///                  (1, array.rows()).
/// - Axis::ROW    : the input is transposed and reduced as Axis::COLUMN.
/// - Any other axis value is reported through THROW_INVALID_ARGUMENT.
///
/// @param array matrix of std::complex<dtype>; dtype must be an arithmetic type
/// @param axis (Optional, default NONE) axis to reduce along
///
/// @return Matrix<std::complex<double>> holding the mean value(s)
///
template <typename dtype>
Matrix<std::complex<double>> mean(const Matrix<std::complex<dtype>> &array, Axis axis = Axis::NONE)
{
    STATIC_ASSERT_ARITHMETIC(dtype);

    // Fold step: convert the element to double precision explicitly, because
    // complex<double> + complex<dtype> is ill-formed unless dtype is double.
    const auto addWidened = [](const std::complex<double> &partial, const std::complex<dtype> &value)
    {
        return partial + std::complex<double>(static_cast<double>(value.real()),
                                              static_cast<double>(value.imag()));
    };

    switch (axis)
    {
    case Axis::NONE:
    {
        const auto sum = std::accumulate(array.cbegin(), array.cend(), std::complex<double>(0.), addWidened);
        Matrix<std::complex<double>> returnArray = {
            sum / std::complex<double>(static_cast<double>(array.size()))};
        return returnArray;
    }
    case Axis::COLUMN:
    {
        Matrix<std::complex<double>> returnArray(1, array.rows());
        // The divisor is the same for every row, so build it once.
        const std::complex<double> columnCount(static_cast<double>(array.columns()));
        for (size_t row = 0; row < array.rows(); ++row)
        {
            const auto sum =
                std::accumulate(array.cbegin(row), array.cend(row), std::complex<double>(0.), addWidened);
            returnArray(0, row) = sum / columnCount;
        }
        return returnArray;
    }
    case Axis::ROW:
    {
        // Same reduction as COLUMN on the transposed input; delegating avoids
        // a second copy of the loop and keeps the element type of the
        // transposed matrix unchanged.
        return mean(array.transpose(), Axis::COLUMN);
    }
    default:
    {
        THROW_INVALID_ARGUMENT("Unimplemented axis type.");
        return {}; // get rid of compiler warning
    }
    }
}
}