diff --git a/include/base/base64.hpp b/include/base/base64.hpp index 2b90a4f..6d61c44 100644 --- a/include/base/base64.hpp +++ b/include/base/base64.hpp @@ -10,6 +10,9 @@ namespace base64 bool isValid(const char *str, uint64_t str_size) noexcept; bool isValid(std::string_view str) noexcept; + uint64_t sizeEncoded(std::span data); + uint64_t sizeDecoded(std::string_view str_size) noexcept; + void encode(const uint8_t *data, uint64_t data_size, char *str, uint64_t str_size); std::string encode(std::span data) noexcept; diff --git a/include/base/hex.hpp b/include/base/hex.hpp index d7dcb2f..404104b 100644 --- a/include/base/hex.hpp +++ b/include/base/hex.hpp @@ -10,6 +10,9 @@ namespace hex bool isValid(const char *str, uint64_t str_size) noexcept; bool isValid(std::string_view str) noexcept; + uint64_t sizeEncoded(std::span data); + uint64_t sizeDecoded(std::string_view str_size) noexcept; + void encode(const uint8_t *data, uint64_t data_size, char *str, uint64_t str_size); std::string encode(std::span data) noexcept; diff --git a/src/base64.cpp b/src/base64.cpp index 1845277..ba10e3e 100644 --- a/src/base64.cpp +++ b/src/base64.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -44,9 +45,23 @@ namespace base64 } return baseN::isValid(sv, b64map); } + uint64_t sizeEncoded(std::span data) + { + uint64_t str_size = data.size() / 3; + if (str_size > std::numeric_limits::max() / 4) + { + throw std::overflow_error("base64::sizeEncoded: overflow"); + } + str_size = str_size * 4 + (data.size() % 3 ? 4 : 0); + return str_size; + } + // uint64_t sizeDecoded(std::string_view str) noexcept + // { + + // } void encode(const uint8_t *data, uint64_t data_size, char *str, uint64_t str_size) { - if (str_size < data_size / 3 * 4 + (data_size % 3 ? 4 : 0)) + if (str_size < base64::sizeEncoded(std::span(data, data_size))) { throw std::logic_error("base64::encode: not enough allocated length"); } @@ -81,7 +96,7 @@ namespace base64 } std::string encode(std::span data) noexcept { - std::string str(data.size() / 3 * 4 + (data.size() % 3 ? 4 : 0), ' '); + std::string str(base64::sizeEncoded(data), ' '); base64::encode(data.data(), data.size(), str.data(), str.size()); return str; } diff --git a/src/hex.cpp b/src/hex.cpp index 011d92d..73c65d3 100644 --- a/src/hex.cpp +++ b/src/hex.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -35,9 +36,21 @@ namespace hex { return baseN::isValid(str, hexmap); } + uint64_t sizeEncoded(std::span data) + { + if (data.size() > std::numeric_limits::max() / 2) + { + throw std::overflow_error("hex::sizeEncoded: overflow"); + } + return data.size() * 2; + } + uint64_t sizeDecoded(std::string_view str) noexcept + { + return str.size() / 2; + } void encode(const uint8_t *data, uint64_t data_size, char *str, uint64_t str_size) { - if (str_size < data_size * 2) + if (str_size < hex::sizeEncoded(std::span(data, data_size))) { throw std::logic_error("hex::encode: not enough allocated length"); } @@ -49,7 +62,7 @@ namespace hex } std::string encode(std::span data) noexcept { - std::string str(data.size() * 2, ' '); + std::string str(hex::sizeEncoded(data), ' '); hex::encode(data.data(), data.size(), str.data(), str.size()); return str; } @@ -59,7 +72,7 @@ namespace hex { throw std::logic_error("hex::decode: isn't hex"); } - if (data_size < str_size / 2) + if (data_size < hex::sizeDecoded(std::string_view(str, str_size))) { throw std::logic_error("hex::decode: not enough allocated length"); } @@ -74,7 +87,7 @@ namespace hex } std::vector decode(std::string_view str) noexcept { - std::vector data(str.size() / 2); + std::vector data(hex::sizeDecoded(str)); hex::decode(str.data(), str.size(), data.data(), data.size()); return data; }