diff --git a/jwt.zig b/jwt.zig index fa5040e..89031ec 100644 --- a/jwt.zig +++ b/jwt.zig @@ -22,24 +22,23 @@ const JWTType = enum { }; pub const SignatureOptions = struct { - alg: Algorithm, key: []const u8, kid: ?[]const u8 = null, }; -pub fn encode(allocator: *std.mem.Allocator, payload: anytype, signatureOptions: SignatureOptions) ![]const u8 { +pub fn encode(allocator: *std.mem.Allocator, comptime alg: Algorithm, payload: anytype, signatureOptions: SignatureOptions) ![]const u8 { var payload_json = std.ArrayList(u8).init(allocator); defer payload_json.deinit(); try std.json.stringify(payload, .{}, payload_json.writer()); - return try encodeMessage(allocator, payload_json.items, signatureOptions); + return try encodeMessage(allocator, alg, payload_json.items, signatureOptions); } -pub fn encodeMessage(allocator: *std.mem.Allocator, message: []const u8, signatureOptions: SignatureOptions) ![]const u8 { +pub fn encodeMessage(allocator: *std.mem.Allocator, comptime alg: Algorithm, message: []const u8, signatureOptions: SignatureOptions) ![]const u8 { var protected_header = std.json.ObjectMap.init(allocator); defer protected_header.deinit(); - try protected_header.put("alg", .{ .String = std.meta.tagName(signatureOptions.alg) }); + try protected_header.put("alg", .{ .String = std.meta.tagName(alg) }); try protected_header.put("typ", .{ .String = "JWT" }); if (signatureOptions.kid) |kid| { try protected_header.put("kid", .{ .String = kid }); @@ -80,8 +79,8 @@ pub fn encodeMessage(allocator: *std.mem.Allocator, message: []const u8, signatu return jwt_text.toOwnedSlice(); } -pub fn validate(comptime P: type, allocator: *std.mem.Allocator, tokenText: []const u8, signatureOptions: SignatureOptions) !P { - const message = try validateMessage(allocator, tokenText, signatureOptions); +pub fn validate(comptime P: type, allocator: *std.mem.Allocator, comptime alg: Algorithm, tokenText: []const u8, signatureOptions: SignatureOptions) !P { + const message = try validateMessage(allocator, alg, tokenText, signatureOptions); defer allocator.free(message); // 10. Verify that the resulting octet sequence is a UTF-8-encoded @@ -94,7 +93,7 @@ pub fn validateFree(comptime P: type, allocator: *std.mem.Allocator, value: P) v std.json.parseFree(P, value, .{ .allocator = allocator }); } -pub fn validateMessage(allocator: *std.mem.Allocator, tokenText: []const u8, signatureOptions: SignatureOptions) ![]const u8 { +pub fn validateMessage(allocator: *std.mem.Allocator, comptime expectedAlg: Algorithm, tokenText: []const u8, signatureOptions: SignatureOptions) ![]const u8 { // 1. Verify that the JWT contains at least one period ('.') // character. // 2. Let the Encoded JOSE Header be the portion of the JWT before the @@ -138,7 +137,7 @@ pub fn validateMessage(allocator: *std.mem.Allocator, tokenText: []const u8, sig const alg = std.meta.stringToEnum(Algorithm, alg_val.String) orelse return error.InvalidAlgorithm; // Make sure that the algorithm matches: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - if (alg != signatureOptions.alg) return error.InvalidAlgorithm; + if (alg != expectedAlg) return error.InvalidAlgorithm; // TODO: Determine if "jku"/"jwk" need to be parsed and validated @@ -320,8 +319,8 @@ test "validate jws based tokens" { } test "generate and then validate jws token" { - try test_generate_then_validate(.{ .alg = .HS256, .key = "a jws hmac sha-256 test key" }); - try test_generate_then_validate(.{ .alg = .HS384, .key = "a jws hmac sha-384 test key" }); + try test_generate_then_validate(.HS256, .{ .key = "a jws hmac sha-256 test key" }); + try test_generate_then_validate(.HS384, .{ .key = "a jws hmac sha-384 test key" }); } const TestPayload = struct { @@ -330,12 +329,12 @@ const TestPayload = struct { iat: i64, }; -fn test_generate(algorithm: Algorithm, payload: TestPayload, expected: []const u8, key_base64: []const u8) !void { +fn test_generate(comptime algorithm: Algorithm, payload: TestPayload, expected: []const u8, key_base64: []const u8) !void { var key = try std.testing.allocator.alloc(u8, try base64url.Decoder.calcSizeForSlice(key_base64)); defer std.testing.allocator.free(key); try base64url.Decoder.decode(key, key_base64); - const token = try encode(std.testing.allocator, payload, .{ .alg = algorithm, .key = key }); + const token = try encode(std.testing.allocator, algorithm, payload, .{ .key = key }); defer std.testing.allocator.free(token); try std.testing.expectEqualSlices(u8, expected, token); @@ -347,12 +346,12 @@ const TestValidatePayload = struct { @"http://example.com/is_root": bool, }; -fn test_validate(algorithm: Algorithm, expected: TestValidatePayload, token: []const u8, key_base64: []const u8) !void { +fn test_validate(comptime algorithm: Algorithm, expected: TestValidatePayload, token: []const u8, key_base64: []const u8) !void { var key = try std.testing.allocator.alloc(u8, try base64url.Decoder.calcSizeForSlice(key_base64)); defer std.testing.allocator.free(key); try base64url.Decoder.decode(key, key_base64); - var claims = try validate(TestValidatePayload, std.testing.allocator, token, .{ .alg = algorithm, .key = key }); + var claims = try validate(TestValidatePayload, std.testing.allocator, algorithm, token, .{ .key = key }); defer validateFree(TestValidatePayload, std.testing.allocator, claims); try std.testing.expectEqualSlices(u8, expected.iss, claims.iss); @@ -360,7 +359,7 @@ fn test_validate(algorithm: Algorithm, expected: TestValidatePayload, token: []c try std.testing.expectEqual(expected.@"http://example.com/is_root", claims.@"http://example.com/is_root"); } -fn test_generate_then_validate(signatureOptions: SignatureOptions) !void { +fn test_generate_then_validate(comptime alg: Algorithm, signatureOptions: SignatureOptions) !void { const Payload = struct { sub: []const u8, name: []const u8, @@ -372,10 +371,10 @@ fn test_generate_then_validate(signatureOptions: SignatureOptions) !void { .iat = 1516239022, }; - const token = try encode(std.testing.allocator, payload, signatureOptions); + const token = try encode(std.testing.allocator, alg, payload, signatureOptions); defer std.testing.allocator.free(token); - var decoded = try validate(Payload, std.testing.allocator, token, signatureOptions); + var decoded = try validate(Payload, std.testing.allocator, alg, token, signatureOptions); defer validateFree(Payload, std.testing.allocator, decoded); try std.testing.expectEqualSlices(u8, payload.sub, decoded.sub);