You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

169 lines
5.1 KiB
C

1 month ago
/*------------------------------------------------------------------------------
* Copyright (c) 2023 by Bai Bing (seread@163.com)
* S++ COPYING file for copying and redistribution conditions.
*
* Alians IT Studio.
*----------------------------------------------------------------------------*/
#pragma once
#include <initializer_list>
#include <string>
#include "core/Error.h"
#include "ASShape.h"
#include "ASMatrix.h"
namespace ais
{
//============================================================================
// Method Description:
/// Concatenate matrices vertically: each input's rows are appended below
/// the rows of the previous input.
///
/// The result shape is seeded from the first input seen while the
/// accumulated shape is still null (per Shape::is_null()); every later
/// input must have the same column count and adds its rows to the total.
///
/// @param arrayList: {list} of matrices to stack, top to bottom
/// @return Matrix holding all input rows in list order
/// @exception THROW_INVALID_ARGUMENT when an input's column count differs
///            from the accumulated one
///
template <typename dtype>
Matrix<dtype> row_stack(const std::initializer_list<Matrix<dtype>> &arrayList)
{
    // Pass 1: validate column counts and work out the size of the result.
    Shape stackedShape;
    for (const auto &part : arrayList)
    {
        const Shape partShape = part.shape();
        if (stackedShape.is_null())
        {
            stackedShape = partShape;
            continue;
        }
        if (partShape.cols != stackedShape.cols)
        {
            THROW_INVALID_ARGUMENT("input arrays must have the same number of columns.");
        }
        stackedShape.rows += partShape.rows;
    }

    // Pass 2: copy each input into its band of rows in the output.
    Matrix<dtype> stacked(stackedShape);
    size_t rowOffset = 0;
    for (const auto &part : arrayList)
    {
        const Shape partShape = part.shape();
        for (size_t r = 0; r < partShape.rows; ++r)
        {
            for (size_t c = 0; c < partShape.cols; ++c)
            {
                stacked(rowOffset + r, c) = part(r, c);
            }
        }
        rowOffset += partShape.rows;
    }
    return stacked;
}
//============================================================================
// Method Description:
/// Concatenate matrices horizontally: each input's columns are placed to
/// the right of the columns of the previous input.
///
/// The result shape is seeded from the first input seen while the
/// accumulated shape is still null (per Shape::is_null()); every later
/// input must have the same row count and adds its columns to the total.
///
/// @param arrayList: {list} of matrices to stack, left to right
/// @return Matrix holding all input columns in list order
/// @exception THROW_INVALID_ARGUMENT when an input's row count differs
///            from the accumulated one
///
template <typename dtype>
Matrix<dtype> column_stack(const std::initializer_list<Matrix<dtype>> &arrayList)
{
    // Pass 1: validate row counts and work out the size of the result.
    Shape stackedShape;
    for (const auto &part : arrayList)
    {
        const Shape partShape = part.shape();
        if (stackedShape.is_null())
        {
            stackedShape = partShape;
            continue;
        }
        if (partShape.rows != stackedShape.rows)
        {
            THROW_INVALID_ARGUMENT("input arrays must have the same number of rows.");
        }
        stackedShape.cols += partShape.cols;
    }

    // Pass 2: copy each input into its band of columns in the output.
    Matrix<dtype> stacked(stackedShape);
    size_t colOffset = 0;
    for (const auto &part : arrayList)
    {
        const Shape partShape = part.shape();
        for (size_t r = 0; r < partShape.rows; ++r)
        {
            for (size_t c = 0; c < partShape.cols; ++c)
            {
                stacked(r, colOffset + c) = part(r, c);
            }
        }
        colOffset += partShape.cols;
    }
    return stacked;
}
//============================================================================
// Method Description:
/// Join a sequence of matrices along the given axis.
///
/// Axis::ROW stacks the inputs vertically (same as row_stack / vstack).
/// Axis::COLUMN and Axis::NONE (the default) stack them horizontally
/// (same as column_stack / hstack).
///
/// @param arrayList: {list} of matrices to stack
/// @param axis: axis along which to stack the input matrices
///              (default Axis::NONE, handled like Axis::COLUMN)
/// @return Matrix holding the stacked inputs
/// @exception THROW_INVALID_ARGUMENT for any other axis value, or when the
///            input shapes are incompatible (raised by row_stack /
///            column_stack)
///
template <typename dtype>
Matrix<dtype> stack(std::initializer_list<Matrix<dtype>> arrayList, Axis axis = Axis::NONE)
{
    switch (axis)
    {
    case Axis::ROW:
    {
        return row_stack(arrayList);
    }
    case Axis::COLUMN:
    case Axis::NONE: // NONE is handled the same as COLUMN
    {
        return column_stack(arrayList);
    }
    default:
    {
        THROW_INVALID_ARGUMENT("axis must be either ROW or COL.");
        // Only reached if the macro does not throw; keeps compilers from
        // warning about a missing return value.
        return {};
    }
    }
}
//============================================================================
// Method Description:
/// Stack matrices horizontally (column wise); equivalent to column_stack.
///
/// @param arrayList: {list} of matrices to stack, left to right; all must
///                   have the same number of rows
/// @return Matrix whose columns are the inputs' columns in list order
/// @exception THROW_INVALID_ARGUMENT (from column_stack) when row counts differ
///
template <typename dtype>
Matrix<dtype> hstack(std::initializer_list<Matrix<dtype>> arrayList)
{
    return column_stack(arrayList);
}
//============================================================================
// Method Description:
/// Stack matrices vertically (row wise); equivalent to row_stack.
///
/// @param arrayList: {list} of matrices to stack, top to bottom; all must
///                   have the same number of columns
/// @return Matrix whose rows are the inputs' rows in list order
/// @exception THROW_INVALID_ARGUMENT (from row_stack) when column counts differ
///
template <typename dtype>
Matrix<dtype> vstack(std::initializer_list<Matrix<dtype>> arrayList)
{
    return row_stack(arrayList);
}
} // namespace ais