You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

231 lines
6.9 KiB
C

1 month ago
/*------------------------------------------------------------------------------
* Copyright (c) 2023 by Bai Bing (seread@163.com)
* S++ COPYING file for copying and redistribution conditions.
*
* Alians IT Studio.
*----------------------------------------------------------------------------*/
#pragma once
#include "core/StaticAsserts.h"
#include "ASShape.h"
#include "ASMatrix.h"
namespace ais
{
//============================================================================
// Method Description:
/// Build a square num x num matrix with ones on and below diagonal `offset`
/// and zeros everywhere else: element (r, c) is 1 exactly when
/// c - r <= offset.
///
/// @param num: number of rows and of columns
/// @param offset: diagonal index k. k = 0 is the main diagonal, k < 0 lies
///                below it and k > 0 lies above it. Defaults to 0.
///
/// @return Matrix
///
template <typename dtype>
Matrix<dtype> tril(uint32_t num, int32_t offset = 0)
{
    STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);

    // A positive k widens every row to the right; a negative k leaves the
    // first -k rows empty. At most one of these shifts is non-zero.
    const uint32_t colShift = offset > 0 ? static_cast<uint32_t>(offset) : 0u;
    const uint32_t rowShift = offset < 0 ? static_cast<uint32_t>(offset * -1) : 0u;

    Matrix<dtype> result(num);
    result.zeros();

    for (uint32_t r = rowShift; r < num; ++r)
    {
        // Count of leading columns of row r that sit on/below diagonal k,
        // clamped to the matrix width.
        const uint32_t width = r + colShift + 1 - rowShift;
        const uint32_t stop = width < num ? width : num;
        for (uint32_t c = 0; c < stop; ++c)
        {
            result(r, c) = dtype{1};
        }
    }
    return result;
}
//============================================================================
// Method Description:
/// Build a num x inM matrix with ones on and below diagonal `offset` and
/// zeros everywhere else: element (r, c) is 1 exactly when c - r <= offset.
///
/// @param num: number of rows
/// @param inM: number of columns
/// @param offset: diagonal index k. k = 0 is the main diagonal, k < 0 lies
///                below it and k > 0 lies above it. Defaults to 0.
///
/// @return Matrix
///
template <typename dtype>
Matrix<dtype> tril(uint32_t num, uint32_t inM, int32_t offset = 0)
{
    STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);

    // A positive k widens every row to the right; a negative k leaves the
    // first -k rows empty. At most one of these shifts is non-zero.
    const uint32_t colShift = offset > 0 ? static_cast<uint32_t>(offset) : 0u;
    const uint32_t rowShift = offset < 0 ? static_cast<uint32_t>(offset * -1) : 0u;

    Matrix<dtype> result(num, inM);
    result.zeros();

    for (uint32_t r = rowShift; r < num; ++r)
    {
        // Count of leading columns of row r that sit on/below diagonal k,
        // clamped to the number of columns.
        const uint32_t width = r + colShift + 1 - rowShift;
        const uint32_t stop = width < inM ? width : inM;
        for (uint32_t c = 0; c < stop; ++c)
        {
            result(r, c) = dtype{1};
        }
    }
    return result;
}
// forward declare: the Matrix overload of tril below builds its mask with triu
template <typename dtype>
Matrix<dtype> triu(uint32_t num, uint32_t inM, int32_t offset = 0);
//============================================================================
// Method Description:
/// Lower triangle of an array.
///
/// Return a copy of `array` with every element above diagonal `offset`
/// zeroed, i.e. element (r, c) is kept only when c - r <= offset.
///
/// @param array: input matrix (not modified)
/// @param offset: diagonal index k at and below which elements are kept.
///                k = 0 is the main diagonal, k < 0 lies below it,
///                and k > 0 lies above it. The default is 0.
///                Signed (int32_t) to match the other tril/triu overloads;
///                it was previously size_t, which accepted negative k only
///                via unsigned wrap-around and a narrowing conversion.
///
/// @return Matrix
///
template <typename dtype>
Matrix<dtype> tril(const Matrix<dtype> &array, int32_t offset = 0)
{
    STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
    const Shape inShape = array.shape();
    auto outArray = array.copy();
    // triu<bool>(.., offset + 1) is true strictly above diagonal `offset`;
    // presumably put_mask writes 0 wherever the mask is true — confirm.
    outArray.put_mask(triu<bool>(inShape.rows, inShape.cols, offset + 1), 0);
    return outArray;
}
//============================================================================
// Method Description:
/// Build a num x inM matrix with ones on and above diagonal `offset` and
/// zeros everywhere else: element (r, c) is 1 exactly when c - r >= offset.
///
/// @param num: number of rows
/// @param inM: number of columns
/// @param offset: diagonal index k. k = 0 is the main diagonal, k < 0 lies
///                below it and k > 0 lies above it. The default (0) is
///                supplied by the forward declaration.
///
/// @return Matrix
///
template <typename dtype>
Matrix<dtype> triu(uint32_t num, uint32_t inM, int32_t offset)
{
    STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);

    // Start from all ones and clear the lower triangle at diagonal k - 1;
    // what survives is exactly the part on/above diagonal k.
    const int32_t below = offset - 1;
    const uint32_t colShift = below > 0 ? static_cast<uint32_t>(below) : 0u;
    const uint32_t rowShift = below < 0 ? static_cast<uint32_t>(below * -1) : 0u;

    Matrix<dtype> result(num, inM);
    result.ones();

    for (uint32_t r = rowShift; r < num; ++r)
    {
        // Count of leading columns of row r that sit on/below diagonal k - 1,
        // clamped to the number of columns.
        const uint32_t width = r + colShift + 1 - rowShift;
        const uint32_t stop = width < inM ? width : inM;
        for (uint32_t c = 0; c < stop; ++c)
        {
            result(r, c) = dtype{0};
        }
    }
    return result;
}
//============================================================================
// Method Description:
/// Build a square num x num matrix with ones on and above diagonal `offset`
/// and zeros everywhere else: element (r, c) is 1 exactly when
/// c - r >= offset.
///
/// @param num: number of rows and of columns
/// @param offset: diagonal index k. k = 0 is the main diagonal, k < 0 lies
///                below it and k > 0 lies above it. Defaults to 0.
///
/// @return Matrix
///
template <typename dtype>
Matrix<dtype> triu(uint32_t num, int32_t offset = 0)
{
    STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
    // Transposing the lower triangle at diagonal -k yields the upper
    // triangle at diagonal k.
    Matrix<dtype> lower = tril<dtype>(num, -offset);
    return lower.transpose();
}
//============================================================================
// Method Description:
/// Upper triangle of an array.
///
/// Return a copy of `array` with every element below diagonal `offset`
/// zeroed, i.e. element (r, c) is kept only when c - r >= offset.
///
/// @param array: input matrix (not modified)
/// @param offset: diagonal index k at and above which elements are kept.
///                k = 0 is the main diagonal, k < 0 lies below it,
///                and k > 0 lies above it. The default is 0.
///                Signed (int32_t) to match the other tril/triu overloads;
///                it was previously size_t, which accepted negative k only
///                via unsigned wrap-around and a narrowing conversion.
///
/// @return Matrix
///
template <typename dtype>
Matrix<dtype> triu(const Matrix<dtype> &array, int32_t offset = 0)
{
    STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
    const Shape inShape = array.shape();
    auto outArray = array.copy();
    // tril<bool>(.., offset - 1) is true strictly below diagonal `offset`;
    // presumably put_mask writes 0 wherever the mask is true — confirm.
    outArray.put_mask(tril<bool>(inShape.rows, inShape.cols, offset - 1), 0);
    return outArray;
}
} // namespace ais