diff --git a/jwt.zig b/jwt.zig index a3eb3cf..1c2c28a 100644 --- a/jwt.zig +++ b/jwt.zig @@ -83,18 +83,14 @@ pub fn encodeMessage(allocator: std.mem.Allocator, comptime alg: Algorithm, mess return jwt_text.toOwnedSlice(); } -pub fn validate(comptime P: type, allocator: std.mem.Allocator, comptime alg: Algorithm, tokenText: []const u8, signatureOptions: SignatureOptions) !P { +pub fn validate(comptime P: type, allocator: std.mem.Allocator, comptime alg: Algorithm, tokenText: []const u8, signatureOptions: SignatureOptions) !std.json.Parsed(P) { const message = try validateMessage(allocator, alg, tokenText, signatureOptions); defer allocator.free(message); // 10. Verify that the resulting octet sequence is a UTF-8-encoded // representation of a completely valid JSON object conforming to // RFC 7159 [RFC7159]; let the JWT Claims Set be this JSON object. - return std.json.parseFromSlice(P, allocator, message, .{}); -} - -pub fn validateFree(comptime P: type, allocator: std.mem.Allocator, value: P) void { - std.json.parseFree(P, allocator, value); + return std.json.parseFromSlice(P, allocator, message, .{ .allocate = .alloc_always }); } pub fn validateMessage(allocator: std.mem.Allocator, comptime expectedAlg: Algorithm, tokenText: []const u8, signatureOptions: SignatureOptions) ![]const u8 { @@ -118,13 +114,10 @@ pub fn validateMessage(allocator: std.mem.Allocator, comptime expectedAlg: Algor // TODO: Make sure the JSON parser confirms everything above - var parser = std.json.Parser.init(allocator, .alloc_always); - defer parser.deinit(); - var cty_opt = @as(?[]const u8, null); defer if (cty_opt) |cty| allocator.free(cty); - var jwt_tree = try parser.parse(jose_json); + var jwt_tree = try std.json.parseFromSlice(std.json.Value, allocator, jose_json, .{}); defer jwt_tree.deinit(); // 5. Verify that the resulting JOSE Header includes only parameters @@ -132,7 +125,7 @@ pub fn validateMessage(allocator: std.mem.Allocator, comptime expectedAlg: Algor // supported or that are specified as being ignored when not // understood. - var jwt_root = jwt_tree.root; + var jwt_root = jwt_tree.value; if (jwt_root != .object) return error.InvalidFormat; { @@ -325,8 +318,9 @@ fn test_validate(comptime algorithm: Algorithm, expected: TestValidatePayload, t defer std.testing.allocator.free(key); try base64url.Decoder.decode(key, key_base64); - var claims = try validate(TestValidatePayload, std.testing.allocator, algorithm, token, .{ .key = key }); - defer validateFree(TestValidatePayload, std.testing.allocator, claims); + var claims_p = try validate(TestValidatePayload, std.testing.allocator, algorithm, token, .{ .key = key }); + defer claims_p.deinit(); + const claims = claims_p.value; try std.testing.expectEqualSlices(u8, expected.iss, claims.iss); try std.testing.expectEqual(expected.exp, claims.exp); @@ -348,8 +342,9 @@ fn test_generate_then_validate(comptime alg: Algorithm, signatureOptions: Signat const token = try encode(std.testing.allocator, alg, payload, signatureOptions); defer std.testing.allocator.free(token); - var decoded = try validate(Payload, std.testing.allocator, alg, token, signatureOptions); - defer validateFree(Payload, std.testing.allocator, decoded); + var decoded_p = try validate(Payload, std.testing.allocator, alg, token, signatureOptions); + defer decoded_p.value; + const decoded = decoded_p.value; try std.testing.expectEqualSlices(u8, payload.sub, decoded.sub); try std.testing.expectEqualSlices(u8, payload.name, decoded.name);