fix(3p/nix/hash): provide a Status-returning constructor

Additionally, add IsValidBase16() to restore the behavior of rejecting invalid base16, which absl's HexStringToBytes does not do.

Change-Id: I777a36f5dc787aa54a2aa316d6728f68da129768
Reviewed-on: https://cl.tvl.fyi/c/depot/+/1484
Tested-by: BuildkiteCI
Reviewed-by: tazjin <mail@tazj.in>
This commit is contained in:
Kane York 2020-07-27 16:50:56 -07:00 committed by kanepyork
parent 976a36c2e4
commit 31f9ee58d0
5 changed files with 123 additions and 52 deletions

View file

@ -48,6 +48,7 @@ target_sources(nixutil
target_link_libraries(nixutil
absl::strings
absl::statusor
glog
BZip2::BZip2
LibLZMA::LibLZMA

View file

@ -4,6 +4,7 @@
#include <iostream>
#include <absl/strings/escaping.h>
#include <absl/strings/str_format.h>
#include <fcntl.h>
#include <openssl/md5.h>
#include <openssl/sha.h>
@ -75,8 +76,18 @@ static std::string printHash16(const Hash& hash) {
return std::string(buf, hash.hashSize * 2);
}
bool Hash::IsValidBase16(absl::string_view s) {
for (char c : s) {
if ('0' <= c && c <= '9') continue;
if ('a' <= c && c <= 'f') continue;
if ('A' <= c && c <= 'F') continue;
return false;
}
return true;
}
// omitted: E O U T
const std::string base32Chars = "0123456789abcdfghijklmnpqrsvwxyz";
constexpr char base32Chars[] = "0123456789abcdfghijklmnpqrsvwxyz";
constexpr signed char kUnBase32[] = {
-1, -1, -1, -1, -1, -1, -1, -1, /* unprintables */
@ -167,6 +178,15 @@ std::string Hash::to_string(Base base, bool includeType) const {
}
Hash::Hash(const std::string& s, HashType type) : type(type) {
absl::StatusOr<Hash> result = deserialize(s, type);
if (result.ok()) {
*this = *result;
} else {
throw BadHash(result.status().message());
}
}
absl::StatusOr<Hash> Hash::deserialize(const std::string& s, HashType type) {
size_t pos = 0;
bool isSRI = false;
@ -176,90 +196,88 @@ Hash::Hash(const std::string& s, HashType type) : type(type) {
if (sep != std::string::npos) {
isSRI = true;
} else if (type == htUnknown) {
throw BadHash("hash '%s' does not include a type", s);
return absl::InvalidArgumentError(
absl::StrCat("hash string '", s, " does not include a type"));
}
}
HashType parsedType = type;
if (sep != std::string::npos) {
std::string hts = std::string(s, 0, sep);
this->type = parseHashType(hts);
if (this->type == htUnknown) {
throw BadHash("unknown hash type '%s'", hts);
}
if (type != htUnknown && type != this->type) {
throw BadHash("hash '%s' should have type '%s'", s, printHashType(type));
parsedType = parseHashType(hts);
if (parsedType != type) {
return absl::InvalidArgumentError(
absl::StrCat("hash '", s, "' should have type '", printHashType(type),
"', found '", printHashType(parsedType), "'"));
}
pos = sep + 1;
}
init();
Hash dest(parsedType);
size_t size = s.size() - pos;
absl::string_view sv(s.data() + pos, size);
if (!isSRI && size == base16Len()) {
auto parseHexDigit = [&](char c) {
if (c >= '0' && c <= '9') {
return c - '0';
if (!isSRI && size == dest.base16Len()) {
std::string bytes;
if (!IsValidBase16(sv)) {
return absl::InvalidArgumentError(
absl::StrCat("invalid base-16 hash: bad character in '", s, "'"));
}
if (c >= 'A' && c <= 'F') {
return c - 'A' + 10;
}
if (c >= 'a' && c <= 'f') {
return c - 'a' + 10;
}
throw BadHash("invalid base-16 hash '%s'", s);
};
for (unsigned int i = 0; i < hashSize; i++) {
hash[i] = parseHexDigit(s[pos + i * 2]) << 4 |
parseHexDigit(s[pos + i * 2 + 1]);
bytes = absl::HexStringToBytes(sv);
if (bytes.size() != dest.hashSize) {
return absl::InvalidArgumentError(
absl::StrCat("hash '", s, "' has wrong length for base16 ",
printHashType(dest.type)));
}
memcpy(dest.hash, bytes.data(), dest.hashSize);
}
else if (!isSRI && size == base32Len()) {
else if (!isSRI && size == dest.base32Len()) {
for (unsigned int n = 0; n < size; ++n) {
char c = s[pos + size - n - 1];
unsigned char digit = 0;
for (digit = 0; digit < base32Chars.size(); ++digit) { /* !!! slow */
if (base32Chars[digit] == c) {
break;
}
}
if (digit >= 32) {
throw BadHash("invalid base-32 hash '%s'", s);
char c = sv[size - n - 1];
// range: -1, 0..31
signed char digit = kUnBase32[static_cast<unsigned char>(c)];
if (digit < 0) {
return absl::InvalidArgumentError(
absl::StrCat("invalid base-32 hash: bad character ",
absl::CEscape(absl::string_view(&c, 1))));
}
unsigned int b = n * 5;
unsigned int i = b / 8;
unsigned int j = b % 8;
hash[i] |= digit << j;
dest.hash[i] |= digit << j;
if (i < hashSize - 1) {
hash[i + 1] |= digit >> (8 - j);
if (i < dest.hashSize - 1) {
dest.hash[i + 1] |= digit >> (8 - j);
} else {
if ((digit >> (8 - j)) != 0) {
throw BadHash("invalid base-32 hash '%s'", s);
return absl::InvalidArgumentError(
absl::StrCat("invalid base-32 hash '", s, "'"));
}
}
}
}
else if (isSRI || size == base64Len()) {
std::string d;
if (!absl::Base64Unescape(std::string(s, pos), &d)) {
// TODO(grfn): replace this with StatusOr
throw Error("Invalid Base64");
else if (isSRI || size == dest.base64Len()) {
std::string decoded;
if (!absl::Base64Unescape(sv, &decoded)) {
return absl::InvalidArgumentError("invalid base-64 hash");
}
if (d.size() != hashSize) {
throw BadHash("invalid %s hash '%s'", isSRI ? "SRI" : "base-64", s);
if (decoded.size() != dest.hashSize) {
return absl::InvalidArgumentError(
absl::StrCat("hash '", s, "' has wrong length for base64 ",
printHashType(dest.type)));
}
assert(hashSize);
memcpy(hash, d.data(), hashSize);
memcpy(dest.hash, decoded.data(), dest.hashSize);
}
else {
throw BadHash("hash '%s' has wrong length for hash type '%s'", s,
printHashType(type));
return absl::InvalidArgumentError(absl::StrCat(
"hash '", s, "' has wrong length for ", printHashType(dest.type)));
}
return dest;
}
union Ctx {

View file

@ -1,5 +1,7 @@
#pragma once
#include <absl/status/statusor.h>
#include "libutil/serialise.hh"
#include "libutil/types.hh"
@ -36,6 +38,10 @@ struct Hash {
string. */
Hash(const std::string& s, HashType type = htUnknown);
/* Status-returning version of above constructor */
static absl::StatusOr<Hash> deserialize(const std::string& s,
HashType type = htUnknown);
void init();
/* Check whether a hash is set. */
@ -64,6 +70,10 @@ struct Hash {
(e.g. "sha256:"). */
std::string to_string(Base base = Base32, bool includeType = true) const;
/* Returns whether the passed string contains entirely valid base16
characters. */
static bool IsValidBase16(absl::string_view s);
/* Returns whether the passed string contains entirely valid base32
characters. */
static bool IsValidBase32(absl::string_view s);

View file

@ -44,6 +44,8 @@ struct FormatOrString {
inline std::string fmt(const std::string& s) { return s; }
inline std::string fmt(std::string_view s) { return std::string(s); }
inline std::string fmt(const char* s) { return s; }
inline std::string fmt(const FormatOrString& fs) { return fs.s; }

View file

@ -1,12 +1,16 @@
#include "libutil/hash.hh"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
class HashTest : public ::testing::Test {};
using testing::EndsWith;
using testing::HasSubstr;
namespace nix {
TEST(HASH_TEST, SHA256) {
TEST(HashTest, SHA256) {
auto hash = hashString(HashType::htSHA256, "foo");
ASSERT_EQ(hash.base64Len(), 44);
ASSERT_EQ(hash.base32Len(), 52);
@ -40,4 +44,40 @@ TEST(HashTest, SHA256Decode) {
ASSERT_EQ(hash, *base64);
}
TEST(HashTest, SHA256DecodeFail) {
EXPECT_THAT(
Hash::deserialize("sha256:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm56==",
HashType::htSHA256)
.status()
.message(),
HasSubstr("wrong length"));
EXPECT_THAT(
Hash::deserialize("sha256:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm56,=",
HashType::htSHA256)
.status()
.message(),
HasSubstr("invalid base-64"));
EXPECT_THAT(Hash::deserialize(
"sha256:1bp7cri8hplaz6hbz0v4f0nl44rl84q1sg25kgwqzipzd1mv89i",
HashType::htSHA256)
.status()
.message(),
HasSubstr("wrong length"));
absl::StatusOr<Hash> badB32Char = Hash::deserialize(
"sha256:1bp7cri8hplaz6hbz0v4f0nl44rl84q1sg25kgwqzipzd1mv89i,",
HashType::htSHA256);
EXPECT_THAT(badB32Char.status().message(), HasSubstr("invalid base-32"));
EXPECT_THAT(badB32Char.status().message(), EndsWith(","));
EXPECT_THAT(
Hash::deserialize(
"sha256:"
"2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7 ",
HashType::htSHA256)
.status()
.message(),
HasSubstr("invalid base-16"));
}
} // namespace nix