You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

89 lines
3.2 KiB
C

1 month ago
/*------------------------------------------------------------------------------
* Copyright (c) 2023 by Bai Bing (seread@163.com)
* S++ COPYING file for copying and redistribution conditions.
*
* Alians IT Studio.
*----------------------------------------------------------------------------*/
#pragma once
#include <algorithm>
#include <cstdint>
#include <execution>
#include <string>

#include "core/Error.h"
#include "ASMatrix.h"
namespace ais
{
    //============================================================================
    // Method Description:
    /// Append values to the end of an array.
    ///
    /// @param array        Matrix whose elements come first in the result.
    /// @param appendValues Matrix whose elements are placed after those of
    ///                     @p array.
    /// @param axis         (Optional, default NONE): The axis along which values
    ///                     are appended.
    ///                     - NONE:   both inputs are flattened (in their iterator
    ///                               order) and the result is a single row holding
    ///                               array.size() + appendValues.size() elements.
    ///                     - ROW:    rows of @p appendValues are placed after the
    ///                               rows of @p array; column counts must match.
    ///                     - COLUMN: columns of @p appendValues are placed to the
    ///                               right of the columns of @p array; row counts
    ///                               must match.
    /// @return Matrix      A new matrix built by copying both inputs (the inputs
    ///                     are taken by const reference and are not modified).
    /// @throws Whatever THROW_INVALID_ARGUMENT raises: when the dimension other
    ///         than the concatenation axis differs (ROW / COLUMN), or when
    ///         @p axis is not NONE, ROW or COLUMN.
    ///
    template <typename dtype>
    Matrix<dtype> append(const Matrix<dtype> &array, const Matrix<dtype> &appendValues, Axis axis = Axis::NONE)
    {
        switch (axis)
        {
            case Axis::NONE:
            {
                Matrix<dtype> returnArray(1, array.size() + appendValues.size());
                std::copy(std::execution::par_unseq, array.cbegin(), array.cend(), returnArray.begin());
                std::copy(std::execution::par_unseq, appendValues.cbegin(),
                          appendValues.cend(),
                          returnArray.begin() + array.size());
                return returnArray;
            }
            case Axis::ROW:
            {
                const Shape inShape = array.shape();
                const Shape appendShape = appendValues.shape();
                if (inShape.cols != appendShape.cols)
                {
                    THROW_INVALID_ARGUMENT(
                        "all the input array dimensions except for the concatenation axis must match exactly");
                }
                // Stacking rows by two flat copies is only correct if Matrix
                // iterates its elements row by row.
                // NOTE(review): confirm Matrix storage/iteration is row-major.
                Matrix<dtype> returnArray(inShape.rows + appendShape.rows, inShape.cols);
                std::copy(std::execution::par_unseq, array.cbegin(), array.cend(), returnArray.begin());
                std::copy(std::execution::par_unseq, appendValues.cbegin(),
                          appendValues.cend(),
                          returnArray.begin() + array.size());
                return returnArray;
            }
            case Axis::COLUMN:
            {
                const Shape inShape = array.shape();
                const Shape appendShape = appendValues.shape();
                if (inShape.rows != appendShape.rows)
                {
                    THROW_INVALID_ARGUMENT(
                        "all the input array dimensions except for the concatenation axis must match exactly");
                }
                Matrix<dtype> returnArray(inShape.rows, inShape.cols + appendShape.cols);
                for (uint32_t row = 0; row < inShape.rows; ++row)
                {
                    // This loop already runs once per row.
                    // A parallel execution policy here would dispatch a parallel
                    // algorithm for every single row-sized copy, so copy
                    // sequentially instead. The second copy continues where
                    // the first one stopped.
                    auto rowOut = std::copy(array.cbegin(row), array.cend(row), returnArray.begin(row));
                    std::copy(appendValues.cbegin(row), appendValues.cend(row), rowOut);
                }
                return returnArray;
            }
            default:
            {
                THROW_INVALID_ARGUMENT("Unimplemented axis type.");
                return {}; // get rid of compiler warning
            }
        }
    }
} // namespace ais