From f49c0463d4252d2544e6a4a35e853491677017eb Mon Sep 17 00:00:00 2001 From: SEK1RO Date: Thu, 19 Sep 2024 07:54:46 +0300 Subject: [PATCH] fix(baseN): sizeEnc/Dec: zeros counting --- include/base/baseN.hpp | 2 +- src/base58.cpp | 2 +- src/baseN.cpp | 16 +++++++++++----- test/test-baseN.cpp | 18 ++++++++++++++++++ 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/include/base/baseN.hpp b/include/base/baseN.hpp index e877de1..4dc058f 100644 --- a/include/base/baseN.hpp +++ b/include/base/baseN.hpp @@ -11,7 +11,7 @@ namespace baseN bool isValid(std::string_view str, const int8_t *map) noexcept; uint64_t sizeEncoded(std::span data, uint8_t base); - uint64_t sizeDecoded(std::string_view str, uint8_t base) noexcept; + uint64_t sizeDecoded(std::string_view str, uint8_t base, const char* digits) noexcept; void encode(const uint8_t *data, uint64_t data_size, char *str, uint64_t str_size, uint8_t base, const char *digits); std::string encode(std::span data, uint8_t base, const char *digits) noexcept; diff --git a/src/base58.cpp b/src/base58.cpp index a22836a..0c7afb2 100644 --- a/src/base58.cpp +++ b/src/base58.cpp @@ -41,7 +41,7 @@ namespace base58 } uint64_t sizeDecoded(std::string_view str) noexcept { - return baseN::sizeDecoded(str, 58); + return baseN::sizeDecoded(str, 58, digits); } void encode(const uint8_t *data, uint64_t data_size, char *str, uint64_t str_size) noexcept { diff --git a/src/baseN.cpp b/src/baseN.cpp index 02f3e65..5c2337b 100644 --- a/src/baseN.cpp +++ b/src/baseN.cpp @@ -20,15 +20,21 @@ namespace baseN } uint64_t sizeEncoded(std::span data, uint8_t base) { - if (data.size() > std::numeric_limits::max() / log256) + std::span dv(std::find_if(data.begin(), data.end(), [](uint8_t item) + { return item != 0; }), + data.end()); + if (dv.size() > std::numeric_limits::max() / log256) { throw std::overflow_error("baseN::sizeEncoded: overflow"); } - return data.size() * log256 / std::log(base) + 1; + return dv.size() * log256 / std::log(base) + 1 + (data.size() - dv.size()); } - uint64_t sizeDecoded(std::string_view str, uint8_t base) noexcept + uint64_t sizeDecoded(std::string_view str, uint8_t base, const char *digits) noexcept { - return str.size() * std::log(base) / log256 + 1; + std::string_view sv(std::find_if(str.begin(), str.end(), [digits](uint8_t ch) + { return ch != digits[0]; }), + str.end()); + return sv.size() * std::log(base) / log256 + 1 + (str.size() - sv.size()); } void encode(const uint8_t *data, uint64_t data_size, char *str, uint64_t str_size, uint8_t base, const char *digits) { @@ -119,7 +125,7 @@ namespace baseN } std::vector decode(std::string_view str, uint8_t base, const char *digits, const int8_t *map) noexcept { - std::vector data(baseN::sizeDecoded(str, base)); + std::vector data(baseN::sizeDecoded(str, base, digits)); baseN::decode(str.data(), str.size(), data.data(), data.size(), base, digits, map); data.erase(data.begin(), std::find_if(data.begin(), data.end(), [](uint8_t item) { return item != 0; })); diff --git a/test/test-baseN.cpp b/test/test-baseN.cpp index abd7638..8ddb77b 100644 --- a/test/test-baseN.cpp +++ b/test/test-baseN.cpp @@ -16,6 +16,24 @@ TEST(baseN, isValid) for (auto it : tests) EXPECT_EQ(it.first, isValid(it.second, base58::map)); } +TEST(baseN, sizeEncoded) +{ + std::vector> tests = { + {6, "12341234"}, + {5, "00000000"}, + }; + for (auto it : tests) + EXPECT_EQ(it.first, sizeEncoded(hex::decode(it.second), 58)); +} +TEST(baseN, sizeDecoded) +{ + std::vector> tests = { + {3, "qwer"}, + {5, "1111"}, + }; + for (auto it : tests) + EXPECT_EQ(it.first, sizeDecoded(it.second, 58, base58::digits)); +} std::vector> tests = { {"", ""}, {"Ky", "044c"},