You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
89 lines
3.2 KiB
C
89 lines
3.2 KiB
C
|
1 month ago
|
/*------------------------------------------------------------------------------
 * Copyright (c) 2023 by Bai Bing (seread@163.com)
 * See the COPYING file for copying and redistribution conditions.
 *
 * Alians IT Studio.
 *----------------------------------------------------------------------------*/
#pragma once

#include <algorithm>
#include <execution>
#include <string>

#include "core/Error.h"
#include "ASMatrix.h"
namespace ais
|
||
|
|
{
|
||
|
|
//============================================================================
|
||
|
|
// Method Description:
|
||
|
|
/// Append values to the end of an array.
|
||
|
|
///
|
||
|
|
/// @param array
|
||
|
|
/// @param appendValues
|
||
|
|
/// @param axis (Optional, default NONE): The axis along which values are appended.
|
||
|
|
/// If axis is not given, both array and appendValues
|
||
|
|
/// are flattened before use.
|
||
|
|
/// @return Matrix
|
||
|
|
///
|
||
|
|
template <typename dtype>
|
||
|
|
Matrix<dtype> append(const Matrix<dtype> &array, const Matrix<dtype> &appendValues, Axis axis = Axis::NONE)
|
||
|
|
{
|
||
|
|
switch (axis)
|
||
|
|
{
|
||
|
|
case Axis::NONE:
|
||
|
|
{
|
||
|
|
Matrix<dtype> returnArray(1, array.size() + appendValues.size());
|
||
|
|
std::copy(std::execution::par_unseq, array.cbegin(), array.cend(), returnArray.begin());
|
||
|
|
std::copy(std::execution::par_unseq, appendValues.cbegin(),
|
||
|
|
appendValues.cend(),
|
||
|
|
returnArray.begin() + array.size());
|
||
|
|
|
||
|
|
return returnArray;
|
||
|
|
}
|
||
|
|
case Axis::ROW:
|
||
|
|
{
|
||
|
|
const Shape inShape = array.shape();
|
||
|
|
const Shape appendShape = appendValues.shape();
|
||
|
|
if (inShape.cols != appendShape.cols)
|
||
|
|
{
|
||
|
|
THROW_INVALID_ARGUMENT(
|
||
|
|
"all the input array dimensions except for the concatenation axis must match exactly");
|
||
|
|
}
|
||
|
|
|
||
|
|
Matrix<dtype> returnArray(inShape.rows + appendShape.rows, inShape.cols);
|
||
|
|
std::copy(std::execution::par_unseq, array.cbegin(), array.cend(), returnArray.begin());
|
||
|
|
std::copy(std::execution::par_unseq, appendValues.cbegin(),
|
||
|
|
appendValues.cend(),
|
||
|
|
returnArray.begin() + array.size());
|
||
|
|
|
||
|
|
return returnArray;
|
||
|
|
}
|
||
|
|
case Axis::COLUMN:
|
||
|
|
{
|
||
|
|
const Shape inShape = array.shape();
|
||
|
|
const Shape appendShape = appendValues.shape();
|
||
|
|
if (inShape.rows != appendShape.rows)
|
||
|
|
{
|
||
|
|
THROW_INVALID_ARGUMENT(
|
||
|
|
"all the input array dimensions except for the concatenation axis must match exactly");
|
||
|
|
}
|
||
|
|
|
||
|
|
Matrix<dtype> returnArray(inShape.rows, inShape.cols + appendShape.cols);
|
||
|
|
for (uint32_t row = 0; row < returnArray.shape().rows; ++row)
|
||
|
|
{
|
||
|
|
std::copy(std::execution::par_unseq, array.cbegin(row), array.cend(row), returnArray.begin(row));
|
||
|
|
std::copy(std::execution::par_unseq, appendValues.cbegin(row),
|
||
|
|
appendValues.cend(row),
|
||
|
|
returnArray.begin(row) + inShape.cols);
|
||
|
|
}
|
||
|
|
|
||
|
|
return returnArray;
|
||
|
|
}
|
||
|
|
default:
|
||
|
|
{
|
||
|
|
THROW_INVALID_ARGUMENT("Unimplemented axis type.");
|
||
|
|
return {}; // get rid of compiler warning
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
} // namespace ais
|