You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

125 lines
4.3 KiB
C

1 month ago
/*------------------------------------------------------------------------------
* Copyright (c) 2023 by Bai Bing (seread@163.com)
* S++ COPYING file for copying and redistribution conditions.
*
* Alians IT Studio.
*----------------------------------------------------------------------------*/
#pragma once
#include <algorithm>
#include <cmath>
#include <random>
#include <string>
#include <type_traits>
#include "core/DataTypeInfo.h"
#include "core/Error.h"
#include "core/StaticAsserts.h"
#include "utils/EssentiallyEqual.h"
#include "ASMatrix.h"
#include "ASShape.h"
namespace ais
{
namespace random
{
/// Shared pseudo-random engine used by every function in ais::random.
/// Declared `inline` (C++17) so every translation unit that includes this
/// header refers to the same engine. With `static`, each translation unit
/// received its own private copy, so seed() called in one translation unit
/// had no effect on values drawn in another.
inline std::mt19937_64 generator_;
//============================================================================
// Method Description:
/// Seeds the random number generator shared by all ais::random functions.
///
/// @param seed value forwarded to std::mt19937_64::seed (converted to the
///             engine's unsigned result type).
///
inline void seed(int seed)
{
    generator_.seed(seed);
}
//============================================================================
// Method Description:
/// Return a random value from low (inclusive) to high (exclusive).
/// If no high value is input then the range will go from [0, low)
/// (or [low, 0) when low is negative, because the bounds are swapped).
///
/// Integral types draw uniformly from [low, high - 1]; floating point
/// types draw uniformly from [low, high).
///
/// @param low
/// @param high default 0.
/// @return dtype
/// @throws via THROW_INVALID_ARGUMENT when low and high are essentially equal.
///
template <typename dtype>
dtype rand(dtype low, dtype high = 0)
{
    STATIC_ASSERT_ARITHMETIC(dtype);
    if (utils::essentially_equal(low, high))
    {
        THROW_INVALID_ARGUMENT("input low value must be less than the input high value.");
    }
    else if (low > high)
    {
        std::swap(low, high);
    }
    // use constexpr to implement the correct comparison
    if constexpr (ais::is_floating_point_v<dtype>)
    {
        // uniform_real_distribution is already half-open [low, high). Subtracting a
        // fixed epsilon from high does not make it exclusive for large |high| (it
        // rounds back to high) and inverts the range when high - low < epsilon.
        std::uniform_real_distribution<dtype> dist(low, high);
        const dtype value = dist(generator_);
        // Some standard libraries can round a draw up to exactly `high` (LWG 2524);
        // clamp to the largest representable value below high to keep it exclusive.
        return value < high ? value : std::nextafter(high, low);
    }
    else if constexpr (ais::is_integral_v<dtype>)
    {
        // uniform_int_distribution is only defined for short/int/long/long long and
        // their unsigned forms; draw 1-byte types (char, int8_t, uint8_t, ...) in a
        // wider type and narrow the result, which always fits in [low, high - 1].
        using draw_type = std::conditional_t<(sizeof(dtype) < sizeof(short)),
                                             std::conditional_t<std::is_signed_v<dtype>, short, unsigned short>,
                                             dtype>;
        std::uniform_int_distribution<draw_type> dist(static_cast<draw_type>(low),
                                                      static_cast<draw_type>(high - 1));
        return static_cast<dtype>(dist(generator_));
    }
}
//============================================================================
// Method Description:
/// Return a Matrix of the given shape filled with random values from
/// low (inclusive) to high (exclusive). If no high value is input then the
/// range will go from [0, low) (or [low, 0) when low is negative, because
/// the bounds are swapped).
///
/// Integral types draw uniformly from [low, high - 1]; floating point
/// types draw uniformly from [low, high).
///
/// @param shape
/// @param low
/// @param high default 0.
/// @return Matrix
/// @throws via THROW_INVALID_ARGUMENT when low and high are essentially equal.
///
template <typename dtype>
Matrix<dtype> rand_matrix(const Shape &shape, dtype low, dtype high = 0)
{
    STATIC_ASSERT_ARITHMETIC(dtype);
    if (utils::essentially_equal(low, high))
    {
        THROW_INVALID_ARGUMENT("input low value must be less than the input high value.");
    }
    else if (low > high)
    {
        std::swap(low, high);
    }
    Matrix<dtype> returnArray(shape);
    // Elements are filled sequentially on purpose: std::mt19937_64 and the
    // distribution objects are stateful and not thread-safe, so drawing from the
    // shared generator_ inside a par_unseq for_each is a data race (undefined
    // behaviour) and makes seeded output non-reproducible.
    if constexpr (ais::is_integral_v<dtype>)
    {
        // uniform_int_distribution is only defined for short/int/long/long long and
        // their unsigned forms; draw 1-byte types in a wider type and narrow back.
        using draw_type = std::conditional_t<(sizeof(dtype) < sizeof(short)),
                                             std::conditional_t<std::is_signed_v<dtype>, short, unsigned short>,
                                             dtype>;
        std::uniform_int_distribution<draw_type> dist(static_cast<draw_type>(low),
                                                      static_cast<draw_type>(high - 1));
        for (dtype &value : returnArray)
        {
            value = static_cast<dtype>(dist(generator_));
        }
    }
    else if constexpr (ais::is_floating_point_v<dtype>)
    {
        // The distribution range is already half-open [low, high); no epsilon offset
        // is needed (it could invert the range for very narrow intervals).
        std::uniform_real_distribution<dtype> dist(low, high);
        for (dtype &value : returnArray)
        {
            const dtype sample = dist(generator_);
            // Guard against implementations that can round a draw up to `high`.
            value = sample < high ? sample : std::nextafter(high, low);
        }
    }
    return returnArray;
}
} // namespace random
} // namespace ais