You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
231 lines
6.9 KiB
C++
231 lines
6.9 KiB
C++
/*------------------------------------------------------------------------------
|
|
* Copyright (c) 2023 by Bai Bing (seread@163.com)
|
|
* S++ COPYING file for copying and redistribution conditions.
|
|
*
|
|
* Alians IT Studio.
|
|
*----------------------------------------------------------------------------*/
|
|
#pragma once
|
|
|
|
#include "core/StaticAsserts.h"
|
|
#include "ASShape.h"
|
|
#include "ASMatrix.h"
|
|
|
|
namespace ais
|
|
{
|
|
//============================================================================
|
|
// Method Description:
|
|
/// An array with ones at and below the given diagonal and zeros elsewhere.
|
|
///
|
|
/// @param num: number of rows and cols
|
|
/// @param offset: (the sub-diagonal at and below which the array is filled.
|
|
/// k = 0 is the main diagonal, while k < 0 is below it,
|
|
/// and k > 0 is above. The default is 0.)
|
|
///
|
|
/// @return Matrix
|
|
///
|
|
template <typename dtype>
|
|
Matrix<dtype> tril(uint32_t num, int32_t offset = 0)
|
|
{
|
|
STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
|
|
|
|
uint32_t rowStart = 0;
|
|
uint32_t colStart = 0;
|
|
if (offset > 0)
|
|
{
|
|
colStart = offset;
|
|
}
|
|
else
|
|
{
|
|
rowStart = offset * -1;
|
|
}
|
|
|
|
Matrix<dtype> returnArray(num);
|
|
returnArray.zeros();
|
|
for (uint32_t row = rowStart; row < num; ++row)
|
|
{
|
|
for (uint32_t col = 0; col < row + colStart + 1 - rowStart; ++col)
|
|
{
|
|
if (col == num)
|
|
{
|
|
break;
|
|
}
|
|
|
|
returnArray(row, col) = dtype{1};
|
|
}
|
|
}
|
|
|
|
return returnArray;
|
|
}
|
|
|
|
//============================================================================
|
|
// Method Description:
|
|
/// An array with ones at and below the given diagonal and zeros elsewhere.
|
|
///
|
|
/// @param num: number of rows
|
|
/// @param inM: number of columns
|
|
/// @param offset: (the sub-diagonal at and below which the array is filled.
|
|
/// k = 0 is the main diagonal, while k < 0 is below it,
|
|
/// and k > 0 is above. The default is 0.)
|
|
///
|
|
///
|
|
/// @return Matrix
|
|
///
|
|
template <typename dtype>
|
|
Matrix<dtype> tril(uint32_t num, uint32_t inM, int32_t offset = 0)
|
|
{
|
|
STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
|
|
|
|
uint32_t rowStart = 0;
|
|
uint32_t colStart = 0;
|
|
if (offset > 0)
|
|
{
|
|
colStart = offset;
|
|
}
|
|
else if (offset < 0)
|
|
{
|
|
rowStart = offset * -1;
|
|
}
|
|
|
|
Matrix<dtype> returnArray(num, inM);
|
|
returnArray.zeros();
|
|
for (uint32_t row = rowStart; row < num; ++row)
|
|
{
|
|
for (uint32_t col = 0; col < row + colStart + 1 - rowStart; ++col)
|
|
{
|
|
if (col == inM)
|
|
{
|
|
break;
|
|
}
|
|
|
|
returnArray(row, col) = dtype{1};
|
|
}
|
|
}
|
|
|
|
return returnArray;
|
|
}
|
|
|
|
// forward declare
|
|
template <typename dtype>
|
|
Matrix<dtype> triu(uint32_t num, uint32_t inM, int32_t offset = 0);
|
|
|
|
//============================================================================
|
|
// Method Description:
|
|
/// Lower triangle of an array.
|
|
///
|
|
/// Return a copy of an array with elements above the k - th diagonal zeroed.
|
|
///
|
|
/// @param array: number of rows and cols
|
|
/// @param offset: (the sub-diagonal at and below which the array is filled.
|
|
/// k = 0 is the main diagonal, while k < 0 is below it,
|
|
/// and k > 0 is above. The default is 0.)
|
|
///
|
|
///
|
|
/// @return Matrix
|
|
///
|
|
template <typename dtype>
|
|
Matrix<dtype> tril(const Matrix<dtype> &array, size_t offset = 0)
|
|
{
|
|
STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
|
|
|
|
const Shape inShape = array.shape();
|
|
auto outArray = array.copy();
|
|
outArray.put_mask(triu<bool>(inShape.rows, inShape.cols, offset + 1), 0);
|
|
return outArray;
|
|
}
|
|
|
|
//============================================================================
|
|
// Method Description:
|
|
/// An array with ones at and above the given diagonal and zeros elsewhere.
|
|
///
|
|
/// @param num: number of rows
|
|
/// @param inM: number of columns
|
|
/// @param offset: (the sub-diagonal at and above which the array is filled.
|
|
/// k = 0 is the main diagonal, while k < 0 is below it,
|
|
/// and k > 0 is above. The default is 0.)
|
|
///
|
|
///
|
|
/// @return Matrix
|
|
///
|
|
template <typename dtype>
|
|
Matrix<dtype> triu(uint32_t num, uint32_t inM, int32_t offset)
|
|
{
|
|
STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
|
|
|
|
// because i'm stealing the lines of code from tril and reversing it, this is necessary
|
|
offset -= 1;
|
|
|
|
uint32_t rowStart = 0;
|
|
uint32_t colStart = 0;
|
|
if (offset > 0)
|
|
{
|
|
colStart = offset;
|
|
}
|
|
else if (offset < 0)
|
|
{
|
|
rowStart = offset * -1;
|
|
}
|
|
|
|
Matrix<dtype> returnArray(num, inM);
|
|
returnArray.ones();
|
|
for (uint32_t row = rowStart; row < num; ++row)
|
|
{
|
|
for (uint32_t col = 0; col < row + colStart + 1 - rowStart; ++col)
|
|
{
|
|
if (col == inM)
|
|
{
|
|
break;
|
|
}
|
|
|
|
returnArray(row, col) = dtype{0};
|
|
}
|
|
}
|
|
|
|
return returnArray;
|
|
}
|
|
|
|
//============================================================================
|
|
// Method Description:
|
|
/// An array with ones at and above the given diagonal and zeros elsewhere.
|
|
///
|
|
/// @param num: number of rows and cols
|
|
/// @param offset: (the sub-diagonal at and above which the array is filled.
|
|
/// k = 0 is the main diagonal, while k < 0 is below it,
|
|
/// and k > 0 is above. The default is 0.)
|
|
///
|
|
///
|
|
/// @return Matrix
|
|
///
|
|
template <typename dtype>
|
|
Matrix<dtype> triu(uint32_t num, int32_t offset = 0)
|
|
{
|
|
STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
|
|
|
|
return tril<dtype>(num, -offset).transpose();
|
|
}
|
|
|
|
//============================================================================
|
|
// Method Description:
|
|
/// Upper triangle of an array.
|
|
///
|
|
/// Return a copy of an array with elements below the k - th diagonal zeroed.
|
|
///
|
|
/// @param array: number of rows and cols
|
|
/// @param offset: (the sub-diagonal at and below which the array is filled.
|
|
/// k = 0 is the main diagonal, while k < 0 is below it,
|
|
/// and k > 0 is above. The default is 0.)
|
|
///
|
|
///
|
|
/// @return Matrix
|
|
///
|
|
template <typename dtype>
|
|
Matrix<dtype> triu(const Matrix<dtype> &array, size_t offset = 0)
|
|
{
|
|
STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype);
|
|
|
|
const Shape inShape = array.shape();
|
|
auto outArray = array.copy();
|
|
outArray.put_mask(tril<bool>(inShape.rows, inShape.cols, offset - 1), 0);
|
|
return outArray;
|
|
}
|
|
} // namespace ais
|