diff --git a/jwt.zig b/jwt.zig index 89031ec..a22125f 100644 --- a/jwt.zig +++ b/jwt.zig @@ -14,6 +14,14 @@ const Algorithm = enum { pub fn jsonStringify(value: Self, options: std.json.StringifyOptions, writer: anytype) @TypeOf(writer).Error!void { try std.json.stringify(std.meta.tagName(value), options, writer); } + + pub fn CryptoFn(self: Self) type { + return switch (self) { + .HS256 => std.crypto.auth.hmac.sha2.HmacSha256, + .HS384 => std.crypto.auth.hmac.sha2.HmacSha384, + .HS512 => std.crypto.auth.hmac.sha2.HmacSha512, + }; + } }; const JWTType = enum { @@ -63,11 +71,7 @@ pub fn encodeMessage(allocator: *std.mem.Allocator, comptime alg: Algorithm, mes jwt_text.items[protected_header_base64_len] = '.'; _ = base64url.Encoder.encode(message_base64, message); - const signature = switch (signatureOptions.alg) { - .HS256 => &generate_signature_hmac_sha256(signatureOptions.key, protected_header_base64, message_base64), - .HS384 => &generate_signature_hmac_sha384(signatureOptions.key, protected_header_base64, message_base64), - .HS512 => &generate_signature_hmac_sha512(signatureOptions.key, protected_header_base64, message_base64), - }; + const signature = &generate_signature(alg, signatureOptions.key, protected_header_base64, message_base64); const signature_base64_len = base64url.Encoder.calcSize(signature.len); try jwt_text.resize(message_base64_len + 1 + protected_header_base64_len + 1 + signature_base64_len); @@ -185,11 +189,7 @@ pub fn validateMessage(allocator: *std.mem.Allocator, comptime expectedAlg: Algo defer allocator.free(signature); try base64url.Decoder.decode(signature, signature_base64); - const gen_sig = switch (signatureOptions.alg) { - .HS256 => &generate_signature_hmac_sha256(signatureOptions.key, jose_base64, payload_base64), - .HS384 => &generate_signature_hmac_sha384(signatureOptions.key, jose_base64, payload_base64), - .HS512 => &generate_signature_hmac_sha512(signatureOptions.key, jose_base64, payload_base64), - }; + const gen_sig = &generate_signature(expectedAlg, signatureOptions.key, jose_base64, payload_base64); if (!std.mem.eql(u8, signature, gen_sig)) { return error.InvalidSignature; } @@ -225,40 +225,14 @@ pub fn validateMessage(allocator: *std.mem.Allocator, comptime expectedAlg: Algo return message; } -const HmacSha256 = std.crypto.auth.hmac.sha2.HmacSha256; -pub fn generate_signature_hmac_sha256(key: []const u8, protectedHeaderBase64: []const u8, payloadBase64: []const u8) [HmacSha256.mac_length]u8 { - var h = HmacSha256.init(key); +pub fn generate_signature(comptime algo: Algorithm, key: []const u8, protectedHeaderBase64: []const u8, payloadBase64: []const u8) [algo.CryptoFn().mac_length]u8 { + const T = algo.CryptoFn(); + var h = T.init(key); h.update(protectedHeaderBase64); h.update("."); h.update(payloadBase64); - var out: [HmacSha256.mac_length]u8 = undefined; - h.final(&out); - - return out; -} - -const HmacSha384 = std.crypto.auth.hmac.sha2.HmacSha384; -pub fn generate_signature_hmac_sha384(key: []const u8, protectedHeaderBase64: []const u8, payloadBase64: []const u8) [HmacSha384.mac_length]u8 { - var h = HmacSha384.init(key); - h.update(protectedHeaderBase64); - h.update("."); - h.update(payloadBase64); - - var out: [HmacSha384.mac_length]u8 = undefined; - h.final(&out); - - return out; -} - -const HmacSha512 = std.crypto.auth.hmac.sha2.HmacSha512; -pub fn generate_signature_hmac_sha512(key: []const u8, protectedHeaderBase64: []const u8, payloadBase64: []const u8) [HmacSha512.mac_length]u8 { - var h = HmacSha512.init(key); - h.update(protectedHeaderBase64); - h.update("."); - h.update(payloadBase64); - - var out: [HmacSha512.mac_length]u8 = undefined; + var out: [T.mac_length]u8 = undefined; h.final(&out); return out;