/*------------------------------------------------------------------------------ * Copyright (c) 2023 by Bai Bing (seread@163.com) * S++ COPYING file for copying and redistribution conditions. * * Alians IT Studio. *----------------------------------------------------------------------------*/ #pragma once #include #include "core/Error.h" #include "ASMatrix.h" namespace ais { //============================================================================ // Method Description: /// Append values to the end of an array. /// /// @param array /// @param appendValues /// @param axis (Optional, default NONE): The axis along which values are appended. /// If axis is not given, both array and appendValues /// are flattened before use. /// @return Matrix /// template Matrix append(const Matrix &array, const Matrix &appendValues, Axis axis = Axis::NONE) { switch (axis) { case Axis::NONE: { Matrix returnArray(1, array.size() + appendValues.size()); std::copy(std::execution::par_unseq, array.cbegin(), array.cend(), returnArray.begin()); std::copy(std::execution::par_unseq, appendValues.cbegin(), appendValues.cend(), returnArray.begin() + array.size()); return returnArray; } case Axis::ROW: { const Shape inShape = array.shape(); const Shape appendShape = appendValues.shape(); if (inShape.cols != appendShape.cols) { THROW_INVALID_ARGUMENT( "all the input array dimensions except for the concatenation axis must match exactly"); } Matrix returnArray(inShape.rows + appendShape.rows, inShape.cols); std::copy(std::execution::par_unseq, array.cbegin(), array.cend(), returnArray.begin()); std::copy(std::execution::par_unseq, appendValues.cbegin(), appendValues.cend(), returnArray.begin() + array.size()); return returnArray; } case Axis::COLUMN: { const Shape inShape = array.shape(); const Shape appendShape = appendValues.shape(); if (inShape.rows != appendShape.rows) { THROW_INVALID_ARGUMENT( "all the input array dimensions except for the concatenation axis must match exactly"); } Matrix returnArray(inShape.rows, inShape.cols + appendShape.cols); for (uint32_t row = 0; row < returnArray.shape().rows; ++row) { std::copy(std::execution::par_unseq, array.cbegin(row), array.cend(row), returnArray.begin(row)); std::copy(std::execution::par_unseq, appendValues.cbegin(row), appendValues.cend(row), returnArray.begin(row) + inShape.cols); } return returnArray; } default: { THROW_INVALID_ARGUMENT("Unimplemented axis type."); return {}; // get rid of compiler warning } } } } // namespace ais