/*------------------------------------------------------------------------------ * Copyright (c) 2023 by Bai Bing (seread@163.com) * S++ COPYING file for copying and redistribution conditions. * * Alians IT Studio. *----------------------------------------------------------------------------*/ #pragma once #include "core/StaticAsserts.h" #include "ASShape.h" #include "ASMatrix.h" namespace ais { //============================================================================ // Method Description: /// An array with ones at and below the given diagonal and zeros elsewhere. /// /// @param num: number of rows and cols /// @param offset: (the sub-diagonal at and below which the array is filled. /// k = 0 is the main diagonal, while k < 0 is below it, /// and k > 0 is above. The default is 0.) /// /// @return Matrix /// template Matrix tril(uint32_t num, int32_t offset = 0) { STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype); uint32_t rowStart = 0; uint32_t colStart = 0; if (offset > 0) { colStart = offset; } else { rowStart = offset * -1; } Matrix returnArray(num); returnArray.zeros(); for (uint32_t row = rowStart; row < num; ++row) { for (uint32_t col = 0; col < row + colStart + 1 - rowStart; ++col) { if (col == num) { break; } returnArray(row, col) = dtype{1}; } } return returnArray; } //============================================================================ // Method Description: /// An array with ones at and below the given diagonal and zeros elsewhere. /// /// @param num: number of rows /// @param inM: number of columns /// @param offset: (the sub-diagonal at and below which the array is filled. /// k = 0 is the main diagonal, while k < 0 is below it, /// and k > 0 is above. The default is 0.) /// /// /// @return Matrix /// template Matrix tril(uint32_t num, uint32_t inM, int32_t offset = 0) { STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype); uint32_t rowStart = 0; uint32_t colStart = 0; if (offset > 0) { colStart = offset; } else if (offset < 0) { rowStart = offset * -1; } Matrix returnArray(num, inM); returnArray.zeros(); for (uint32_t row = rowStart; row < num; ++row) { for (uint32_t col = 0; col < row + colStart + 1 - rowStart; ++col) { if (col == inM) { break; } returnArray(row, col) = dtype{1}; } } return returnArray; } // forward declare template Matrix triu(uint32_t num, uint32_t inM, int32_t offset = 0); //============================================================================ // Method Description: /// Lower triangle of an array. /// /// Return a copy of an array with elements above the k - th diagonal zeroed. /// /// @param array: number of rows and cols /// @param offset: (the sub-diagonal at and below which the array is filled. /// k = 0 is the main diagonal, while k < 0 is below it, /// and k > 0 is above. The default is 0.) /// /// /// @return Matrix /// template Matrix tril(const Matrix &array, size_t offset = 0) { STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype); const Shape inShape = array.shape(); auto outArray = array.copy(); outArray.put_mask(triu(inShape.rows, inShape.cols, offset + 1), 0); return outArray; } //============================================================================ // Method Description: /// An array with ones at and above the given diagonal and zeros elsewhere. /// /// @param num: number of rows /// @param inM: number of columns /// @param offset: (the sub-diagonal at and above which the array is filled. /// k = 0 is the main diagonal, while k < 0 is below it, /// and k > 0 is above. The default is 0.) /// /// /// @return Matrix /// template Matrix triu(uint32_t num, uint32_t inM, int32_t offset) { STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype); // because i'm stealing the lines of code from tril and reversing it, this is necessary offset -= 1; uint32_t rowStart = 0; uint32_t colStart = 0; if (offset > 0) { colStart = offset; } else if (offset < 0) { rowStart = offset * -1; } Matrix returnArray(num, inM); returnArray.ones(); for (uint32_t row = rowStart; row < num; ++row) { for (uint32_t col = 0; col < row + colStart + 1 - rowStart; ++col) { if (col == inM) { break; } returnArray(row, col) = dtype{0}; } } return returnArray; } //============================================================================ // Method Description: /// An array with ones at and above the given diagonal and zeros elsewhere. /// /// @param num: number of rows and cols /// @param offset: (the sub-diagonal at and above which the array is filled. /// k = 0 is the main diagonal, while k < 0 is below it, /// and k > 0 is above. The default is 0.) /// /// /// @return Matrix /// template Matrix triu(uint32_t num, int32_t offset = 0) { STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype); return tril(num, -offset).transpose(); } //============================================================================ // Method Description: /// Upper triangle of an array. /// /// Return a copy of an array with elements below the k - th diagonal zeroed. /// /// @param array: number of rows and cols /// @param offset: (the sub-diagonal at and below which the array is filled. /// k = 0 is the main diagonal, while k < 0 is below it, /// and k > 0 is above. The default is 0.) /// /// /// @return Matrix /// template Matrix triu(const Matrix &array, size_t offset = 0) { STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype); const Shape inShape = array.shape(); auto outArray = array.copy(); outArray.put_mask(tril(inShape.rows, inShape.cols, offset - 1), 0); return outArray; } } // namespace ais