From a7128eaf34d2dfa79b9fc3a109bac25f866ba073 Mon Sep 17 00:00:00 2001 From: Louis Pearson Date: Tue, 16 Jan 2024 00:57:31 -0700 Subject: [PATCH] feat: recieve file descriptors Feels super hacky, but I'm not sure there is a better way to handle it. The crux of the problem is that wayland does not seem to make any guarantees about the location of file descriptors. --- examples/01_client_connect.zig | 13 +++++-- src/main.zig | 64 +++++++++++++++++++++++++++------- 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/examples/01_client_connect.zig b/examples/01_client_connect.zig index 544d788..f61c52f 100644 --- a/examples/01_client_connect.zig +++ b/examples/01_client_connect.zig @@ -128,7 +128,7 @@ pub fn main() !void { std.debug.print("<- wl_seat.name = {s}\n", .{name.name}); }, } - } else if (header.object_id == 1) { + } else if (header.object_id == DISPLAY_ID) { const event = try wayland.deserialize(wayland.core.Display.Event, header, body); switch (event) { .@"error" => |err| std.debug.print("<- error({}): {} {s}\n", .{ err.object_id, err.code, err.message }), @@ -304,12 +304,19 @@ pub fn main() !void { } else if (header.object_id == wl_keyboard_id) { const event = try wayland.deserialize(wayland.core.Keyboard.Event, header, body); switch (event) { - // .keymap => |keymap| {}, + .keymap => |keymap| { + const fd = conn.fd_queue.orderedRemove(0); + std.debug.print("keymap format={}, size={}, fd={}\n", .{ + keymap.format, + keymap.size, + fd, + }); + }, else => { std.debug.print("<- wl_keyboard@{}\n", .{event}); }, } - } else if (header.object_id == 1) { + } else if (header.object_id == DISPLAY_ID) { const event = try wayland.deserialize(wayland.core.Display.Event, header, body); switch (event) { .@"error" => |err| std.debug.print("<- error({}): {} {s}\n", .{ err.object_id, err.code, err.message }), diff --git a/src/main.zig b/src/main.zig index 63eae7c..0b42321 100644 --- a/src/main.zig +++ b/src/main.zig @@ -113,7 +113,7 @@ pub fn deserializeArguments(comptime Signature: type, buffer: []const u32) !Sign var result: Signature = undefined; var pos: usize = 0; inline for (std.meta.fields(Signature)) |field| { - if (field.type == types.Fd) continue; // Must be handled + if (field.type == types.Fd) continue; switch (@typeInfo(field.type)) { .Int => |int_info| switch (int_info.signedness) { .signed => @field(result, field.name) = try readInt(buffer, &pos), @@ -458,6 +458,7 @@ pub const Conn = struct { allocator: std.mem.Allocator, send_buffer: []u32, recv_buffer: []u32, + fd_queue: std.ArrayListUnmanaged(std.os.fd_t), socket: std.net.Stream, pub fn init(alloc: std.mem.Allocator, display_path: []const u8) !Conn { @@ -467,6 +468,7 @@ pub const Conn = struct { .allocator = alloc, .send_buffer = send_buffer, .recv_buffer = recv_buffer, + .fd_queue = .{}, .socket = try std.net.connectUnixSocket(display_path), }; } @@ -474,6 +476,7 @@ pub const Conn = struct { pub fn deinit(conn: *Conn) void { conn.allocator.free(conn.send_buffer); conn.allocator.free(conn.recv_buffer); + conn.fd_queue.deinit(conn.allocator); conn.socket.close(); } @@ -526,21 +529,58 @@ pub const Conn = struct { pub const Message = struct { Header, []const u32 }; pub fn recv(conn: *Conn) !Message { - // TODO: recvmesg and read fds + // TODO: make this less messy + // Read header + @memset(conn.recv_buffer, 0); + var iov: [1]std.os.iovec = .{.{ + .iov_base = std.mem.sliceAsBytes(conn.recv_buffer).ptr, + .iov_len = @sizeOf(Header), + }}; + var control_msg: cmsg(std.os.fd_t) = undefined; + const control_bytes = std.mem.asBytes(&control_msg); + var socket_msg = std.os.msghdr{ + .name = null, + .namelen = 0, + .iov = &iov, + .iovlen = iov.len, + .control = control_bytes.ptr, + .controllen = @intCast(control_bytes.len), + .flags = 0, + }; + + const size = std.os.linux.recvmsg(conn.socket.handle, &socket_msg, 0); + + if (size < @sizeOf(Header)) return error.SocketClosed; + var header: Header = undefined; - const header_bytes_read = try conn.socket.readAll(std.mem.asBytes(&header)); - if (header_bytes_read < @sizeOf(Header)) { - return error.SocketClosed; + @memcpy(std.mem.asBytes(&header), iov[0].iov_base[0..@sizeOf(Header)]); + + if (socket_msg.controllen != 0) { + try conn.fd_queue.append(conn.allocator, control_msg.data); } - const msg_size = (header.size_and_opcode.size - @sizeOf(Header)) / @sizeOf(u32); - if (msg_size > conn.recv_buffer.len) { - var new_size = conn.recv_buffer.len * 2; - while (new_size < msg_size) new_size *= 2; - conn.recv_buffer = try conn.allocator.realloc(conn.recv_buffer, new_size); + // Read body + const body_size = (header.size_and_opcode.size - @sizeOf(Header)) / @sizeOf(u32); + + iov[0] = .{ + .iov_base = std.mem.sliceAsBytes(conn.recv_buffer).ptr, + .iov_len = body_size * @sizeOf(u32), + }; + socket_msg = std.os.msghdr{ + .name = null, + .namelen = 0, + .iov = &iov, + .iovlen = iov.len, + .control = control_bytes.ptr, + .controllen = @intCast(control_bytes.len), + .flags = 0, + }; + const size2 = std.os.linux.recvmsg(conn.socket.handle, &socket_msg, 0); + const message = conn.recv_buffer[0 .. size2 / @sizeOf(u32)]; + + if (socket_msg.controllen != 0) { + try conn.fd_queue.append(conn.allocator, control_msg.data); } - const bytes_read = try conn.socket.readAll(std.mem.sliceAsBytes(conn.recv_buffer[0..msg_size])); - const message = conn.recv_buffer[0 .. bytes_read / @sizeOf(u32)]; return .{ header, message }; }