/*------------------------------------------------------------------------------ * Copyright (c) 2023 by Bai Bing (seread@163.com) * S++ COPYING file for copying and redistribution conditions. * * Alians IT Studio. *----------------------------------------------------------------------------*/ #pragma once #include #include #include #include "core/DataTypeInfo.h" #include "core/Error.h" #include "core/StaticAsserts.h" #include "utils/EssentiallyEqual.h" #include "ASMatrix.h" #include "ASShape.h" namespace ais { namespace random { static std::mt19937_64 generator_; //============================================================================ // Method Description: /// Seeds the random number generator /// /// @param seed /// inline void seed(int seed) { generator_.seed(seed); } //============================================================================ // Method Description: /// Return random integer from low (inclusive) to high (exclusive), /// with the given shape. If no high value is input then the range will /// go from [0, low). /// /// @param low /// @param high default 0. /// @return Matrix /// template dtype rand(dtype low, dtype high = 0) { STATIC_ASSERT_ARITHMETIC(dtype); if (utils::essentially_equal(low, high)) { THROW_INVALID_ARGUMENT("input low value must be less than the input high value."); } else if (low > high) { std::swap(low, high); } // use constexpr to implement the correct comparison if constexpr (ais::is_floating_point_v) { std::uniform_real_distribution dist(low, high - DataTypeInfo::epsilon()); return dist(generator_); } else if constexpr (ais::is_integral_v) { std::uniform_int_distribution dist(low, high - 1); return dist(generator_); } } //============================================================================ // Method Description: /// Return random integers from low (inclusive) to high (exclusive), /// with the given shape. If no high value is input then the range will /// go from [0, low). /// /// @param shape /// @param low /// @param high default 0. /// @return Matrix /// template Matrix rand_matrix(const Shape &shape, dtype low, dtype high = 0) { STATIC_ASSERT_ARITHMETIC(dtype); if (utils::essentially_equal(low, high)) { THROW_INVALID_ARGUMENT("input low value must be less than the input high value."); } else if (low > high) { std::swap(low, high); } Matrix returnArray(shape); // use constexpr to implement the correct comparison if constexpr (ais::is_integral_v) { std::uniform_int_distribution dist(low, dtype(high - 1)); std::for_each(std::execution::par_unseq, returnArray.begin(), returnArray.end(), [&dist](dtype &value, std::mt19937_64 &generator = generator_) -> void { value = dist(generator); }); } else if constexpr (ais::is_floating_point_v) { std::uniform_real_distribution dist(low, high - DataTypeInfo::epsilon()); std::for_each(std::execution::par_unseq, returnArray.begin(), returnArray.end(), [&dist](dtype &value, std::mt19937_64 &generator = generator_) -> void { value = dist(generator_); }); } return returnArray; } } // namespace random } // namespace ais