/*------------------------------------------------------------------------------ * Copyright (c) 2023 by Bai Bing (seread@163.com) * S++ COPYING file for copying and redistribution conditions. * * Alians IT Studio. *----------------------------------------------------------------------------*/ #pragma once #include #include "core/Error.h" #include "ASMatrix.h" namespace ais { //=========================================================================== // Method Description: /// Compute the mean along the specified axis. /// /// @param array /// @param axis (Optional, default NONE) /// /// @return Matrix /// template Matrix mean(const Matrix &array, Axis axis = Axis::NONE) { STATIC_ASSERT_ARITHMETIC(dtype); switch (axis) { case Axis::NONE: { auto sum = std::accumulate(array.cbegin(), array.cend(), 0.); Matrix returnArray = {sum /= static_cast(array.size())}; return returnArray; } case Axis::COLUMN: { Matrix returnArray(1, array.rows()); for (size_t row = 0; row < array.rows(); ++row) { auto sum = std::accumulate(array.cbegin(row), array.cend(row), 0.); returnArray(0, row) = sum / static_cast(array.columns()); } return returnArray; } case Axis::ROW: { return mean(array.transpose(), Axis::COLUMN); } default: { THROW_INVALID_ARGUMENT("Unimplemented axis type."); return {}; } } } //============================================================================ // Method Description: /// Compute the mean along the specified axis. /// /// @param array /// @param axis (Optional, default NONE) /// /// @return Matrix /// template Matrix> mean(const Matrix> &array, Axis axis = Axis::NONE) { STATIC_ASSERT_ARITHMETIC(dtype); switch (axis) { case Axis::NONE: { auto sum = std::accumulate(array.cbegin(), array.cend(), std::complex(0.)); Matrix> returnArray = {sum /= std::complex(array.size())}; return returnArray; } case Axis::COLUMN: { Matrix> returnArray(1, array.rows()); for (size_t row = 0; row < array.rows(); ++row) { auto sum = std::accumulate(array.cbegin(row), array.cend(row), std::complex(0.)); returnArray(0, row) = sum / std::complex(array.columns()); } return returnArray; } case Axis::ROW: { Matrix> transposedArray = array.transpose(); Matrix> returnArray(1, transposedArray.rows()); for (size_t row = 0; row < transposedArray.rows(); ++row) { auto sum = std::accumulate(transposedArray.cbegin(row), transposedArray.cend(row), std::complex(0.)); returnArray(0, row) = sum / std::complex(transposedArray.columns()); } return returnArray; } default: { THROW_INVALID_ARGUMENT("Unimplemented axis type."); return {}; // get rid of compiler warning } } } }