You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
125 lines
4.3 KiB
C
125 lines
4.3 KiB
C
|
1 month ago
|
/*------------------------------------------------------------------------------
|
||
|
|
* Copyright (c) 2023 by Bai Bing (seread@163.com)
|
||
|
|
* S++ COPYING file for copying and redistribution conditions.
|
||
|
|
*
|
||
|
|
* Alians IT Studio.
|
||
|
|
*----------------------------------------------------------------------------*/
|
||
|
|
#pragma once
|
||
|
|
|
||
|
|
#include <algorithm>
#include <random>
#include <string>
#include <utility>

#include "core/DataTypeInfo.h"
#include "core/Error.h"
#include "core/StaticAsserts.h"
#include "utils/EssentiallyEqual.h"

#include "ASMatrix.h"
#include "ASShape.h"
|
||
|
|
|
||
|
|
namespace ais
|
||
|
|
{
|
||
|
|
namespace random
|
||
|
|
{
|
||
|
|
static std::mt19937_64 generator_;
|
||
|
|
|
||
|
|
//============================================================================
|
||
|
|
// Method Description:
|
||
|
|
/// Seeds the random number generator
|
||
|
|
///
|
||
|
|
/// @param seed
|
||
|
|
///
|
||
|
|
inline void seed(int seed)
|
||
|
|
{
|
||
|
|
generator_.seed(seed);
|
||
|
|
}
|
||
|
|
//============================================================================
|
||
|
|
// Method Description:
|
||
|
|
/// Return random integer from low (inclusive) to high (exclusive),
|
||
|
|
/// with the given shape. If no high value is input then the range will
|
||
|
|
/// go from [0, low).
|
||
|
|
///
|
||
|
|
/// @param low
|
||
|
|
/// @param high default 0.
|
||
|
|
/// @return Matrix
|
||
|
|
///
|
||
|
|
template <typename dtype>
|
||
|
|
dtype rand(dtype low, dtype high = 0)
|
||
|
|
{
|
||
|
|
STATIC_ASSERT_ARITHMETIC(dtype);
|
||
|
|
|
||
|
|
if (utils::essentially_equal(low, high))
|
||
|
|
{
|
||
|
|
THROW_INVALID_ARGUMENT("input low value must be less than the input high value.");
|
||
|
|
}
|
||
|
|
else if (low > high)
|
||
|
|
{
|
||
|
|
std::swap(low, high);
|
||
|
|
}
|
||
|
|
|
||
|
|
// use constexpr to implement the correct comparison
|
||
|
|
if constexpr (ais::is_floating_point_v<dtype>)
|
||
|
|
{
|
||
|
|
std::uniform_real_distribution<dtype> dist(low, high - DataTypeInfo<dtype>::epsilon());
|
||
|
|
return dist(generator_);
|
||
|
|
}
|
||
|
|
else if constexpr (ais::is_integral_v<dtype>)
|
||
|
|
{
|
||
|
|
std::uniform_int_distribution<dtype> dist(low, high - 1);
|
||
|
|
return dist(generator_);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
//============================================================================
|
||
|
|
// Method Description:
|
||
|
|
/// Return random integers from low (inclusive) to high (exclusive),
|
||
|
|
/// with the given shape. If no high value is input then the range will
|
||
|
|
/// go from [0, low).
|
||
|
|
///
|
||
|
|
/// @param shape
|
||
|
|
/// @param low
|
||
|
|
/// @param high default 0.
|
||
|
|
/// @return Matrix
|
||
|
|
///
|
||
|
|
template <typename dtype>
|
||
|
|
Matrix<dtype> rand_matrix(const Shape &shape, dtype low, dtype high = 0)
|
||
|
|
{
|
||
|
|
|
||
|
|
STATIC_ASSERT_ARITHMETIC(dtype);
|
||
|
|
|
||
|
|
if (utils::essentially_equal(low, high))
|
||
|
|
{
|
||
|
|
THROW_INVALID_ARGUMENT("input low value must be less than the input high value.");
|
||
|
|
}
|
||
|
|
else if (low > high)
|
||
|
|
{
|
||
|
|
std::swap(low, high);
|
||
|
|
}
|
||
|
|
|
||
|
|
Matrix<dtype> returnArray(shape);
|
||
|
|
|
||
|
|
// use constexpr to implement the correct comparison
|
||
|
|
if constexpr (ais::is_integral_v<dtype>)
|
||
|
|
{
|
||
|
|
std::uniform_int_distribution<dtype> dist(low, dtype(high - 1));
|
||
|
|
std::for_each(std::execution::par_unseq,
|
||
|
|
returnArray.begin(),
|
||
|
|
returnArray.end(),
|
||
|
|
[&dist](dtype &value, std::mt19937_64 &generator = generator_) -> void
|
||
|
|
{ value = dist(generator); });
|
||
|
|
}
|
||
|
|
else if constexpr (ais::is_floating_point_v<dtype>)
|
||
|
|
{
|
||
|
|
std::uniform_real_distribution<dtype> dist(low, high - DataTypeInfo<dtype>::epsilon());
|
||
|
|
std::for_each(std::execution::par_unseq,
|
||
|
|
returnArray.begin(),
|
||
|
|
returnArray.end(),
|
||
|
|
[&dist](dtype &value, std::mt19937_64 &generator = generator_) -> void
|
||
|
|
{ value = dist(generator_); });
|
||
|
|
}
|
||
|
|
|
||
|
|
return returnArray;
|
||
|
|
}
|
||
|
|
} // namespace random
|
||
|
|
} // namespace ais
|