kms client/server: replace unix domain socket file with socketpair after connecting (fixes issue of .gsr-kms-socket files remaining in $HOME)

This commit is contained in:
dec05eba 2023-11-12 10:55:02 +01:00
parent 1ac862d155
commit 290db495ff
5 changed files with 175 additions and 48 deletions

3
TODO
View File

@ -106,4 +106,5 @@ Make it possible to select which /dev/dri/card* to use, but that requires opengl
Support I915_FORMAT_MOD_Y_TILED_CCS (and other power saving modifiers, see https://trac.ffmpeg.org/ticket/8542). The only fix may be to use desktop portal for recording. This issue doesn't appear on x11 since these modifiers are not used by xorg server. Support I915_FORMAT_MOD_Y_TILED_CCS (and other power saving modifiers, see https://trac.ffmpeg.org/ticket/8542). The only fix may be to use desktop portal for recording. This issue doesn't appear on x11 since these modifiers are not used by xorg server.
Test if p2 state can be worked around by using pure nvenc api and overwriting cuInit/cuCtxCreate* to not do anything. Cuda might be loaded when using nvenc but it might not be used, with certain record options? (such as h264 p5). Test if p2 state can be worked around by using pure nvenc api and overwriting cuInit/cuCtxCreate* to not do anything. Cuda might be loaded when using nvenc but it might not be used, with certain record options? (such as h264 p5).
nvenc uses cuda when using b frames and rgb->yuv conversion, so convert the image ourselves instead.

View File

@ -12,6 +12,12 @@
#include <sys/wait.h> #include <sys/wait.h>
#include <sys/capability.h> #include <sys/capability.h>
#define GSR_SOCKET_PAIR_LOCAL 0
#define GSR_SOCKET_PAIR_REMOTE 1
static void cleanup_initial_socket(gsr_kms_client *self, bool kill_server);
static int gsr_kms_client_replace_connection(gsr_kms_client *self);
static bool generate_random_characters(char *buffer, int buffer_size, const char *alphabet, size_t alphabet_size) { static bool generate_random_characters(char *buffer, int buffer_size, const char *alphabet, size_t alphabet_size) {
int fd = open("/dev/urandom", O_RDONLY); int fd = open("/dev/urandom", O_RDONLY);
if(fd == -1) { if(fd == -1) {
@ -34,6 +40,14 @@ static bool generate_random_characters(char *buffer, int buffer_size, const char
return true; return true;
} }
static void close_fds(gsr_kms_response *response) {
for(int i = 0; i < response->num_fds; ++i) {
if(response->fds[i].fd > 0)
close(response->fds[i].fd);
response->fds[i].fd = 0;
}
}
static int send_msg_to_server(int server_fd, gsr_kms_request *request) { static int send_msg_to_server(int server_fd, gsr_kms_request *request) {
struct iovec iov; struct iovec iov;
iov.iov_base = request; iov.iov_base = request;
@ -72,9 +86,7 @@ static int recv_msg_from_server(int server_fd, gsr_kms_response *response) {
response->fds[i].fd = fds[i]; response->fds[i].fd = fds[i];
} }
} else { } else {
for(int i = 0; i < response->num_fds; ++i) { close_fds(response);
response->fds[i].fd = 0;
}
} }
} }
@ -142,13 +154,15 @@ static bool find_program_in_path(const char *program_name, char *filepath, int f
int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) { int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
int result = -1; int result = -1;
self->kms_server_pid = -1; self->kms_server_pid = -1;
self->socket_fd = -1; self->initial_socket_fd = -1;
self->client_fd = -1; self->initial_client_fd = -1;
self->socket_path[0] = '\0'; self->initial_socket_path[0] = '\0';
self->socket_pair[0] = -1;
self->socket_pair[1] = -1;
struct sockaddr_un local_addr = {0}; struct sockaddr_un local_addr = {0};
struct sockaddr_un remote_addr = {0}; struct sockaddr_un remote_addr = {0};
if(!create_socket_path(self->socket_path, sizeof(self->socket_path))) { if(!create_socket_path(self->initial_socket_path, sizeof(self->initial_socket_path))) {
fprintf(stderr, "gsr error: gsr_kms_client_init: failed to create path to kms socket\n"); fprintf(stderr, "gsr error: gsr_kms_client_init: failed to create path to kms socket\n");
return -1; return -1;
} }
@ -185,20 +199,25 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
} }
} }
self->socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); if(socketpair(AF_UNIX, SOCK_STREAM, 0, self->socket_pair) == -1) {
if(self->socket_fd == -1) { fprintf(stderr, "gsr error: gsr_kms_client_init: socketpair failed, error: %s\n", strerror(errno));
goto err;
}
self->initial_socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if(self->initial_socket_fd == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: socket failed, error: %s\n", strerror(errno)); fprintf(stderr, "gsr error: gsr_kms_client_init: socket failed, error: %s\n", strerror(errno));
goto err; goto err;
} }
local_addr.sun_family = AF_UNIX; local_addr.sun_family = AF_UNIX;
strncpy_safe(local_addr.sun_path, self->socket_path, sizeof(local_addr.sun_path)); strncpy_safe(local_addr.sun_path, self->initial_socket_path, sizeof(local_addr.sun_path));
if(bind(self->socket_fd, (struct sockaddr*)&local_addr, sizeof(local_addr.sun_family) + strlen(local_addr.sun_path)) == -1) { if(bind(self->initial_socket_fd, (struct sockaddr*)&local_addr, sizeof(local_addr.sun_family) + strlen(local_addr.sun_path)) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: failed to bind socket, error: %s\n", strerror(errno)); fprintf(stderr, "gsr error: gsr_kms_client_init: failed to bind socket, error: %s\n", strerror(errno));
goto err; goto err;
} }
if(listen(self->socket_fd, 1) == -1) { if(listen(self->initial_socket_fd, 1) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: failed to listen on socket, error: %s\n", strerror(errno)); fprintf(stderr, "gsr error: gsr_kms_client_init: failed to listen on socket, error: %s\n", strerror(errno));
goto err; goto err;
} }
@ -209,13 +228,13 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
goto err; goto err;
} else if(pid == 0) { /* child */ } else if(pid == 0) { /* child */
if(inside_flatpak) { if(inside_flatpak) {
const char *args[] = { "flatpak-spawn", "--host", "pkexec", "flatpak", "run", "--command=gsr-kms-server", "com.dec05eba.gpu_screen_recorder", self->socket_path, card_path, NULL }; const char *args[] = { "flatpak-spawn", "--host", "pkexec", "flatpak", "run", "--command=gsr-kms-server", "com.dec05eba.gpu_screen_recorder", self->initial_socket_path, card_path, NULL };
execvp(args[0], (char *const*)args); execvp(args[0], (char *const*)args);
} else if(has_perm) { } else if(has_perm) {
const char *args[] = { server_filepath, self->socket_path, card_path, NULL }; const char *args[] = { server_filepath, self->initial_socket_path, card_path, NULL };
execvp(args[0], (char *const*)args); execvp(args[0], (char *const*)args);
} else { } else {
const char *args[] = { "pkexec", server_filepath, self->socket_path, card_path, NULL }; const char *args[] = { "pkexec", server_filepath, self->initial_socket_path, card_path, NULL };
execvp(args[0], (char *const*)args); execvp(args[0], (char *const*)args);
} }
fprintf(stderr, "gsr error: gsr_kms_client_init: execvp failed, error: %s\n", strerror(errno)); fprintf(stderr, "gsr error: gsr_kms_client_init: execvp failed, error: %s\n", strerror(errno));
@ -229,16 +248,16 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
struct timeval tv; struct timeval tv;
fd_set rfds; fd_set rfds;
FD_ZERO(&rfds); FD_ZERO(&rfds);
FD_SET(self->socket_fd, &rfds); FD_SET(self->initial_socket_fd, &rfds);
tv.tv_sec = 0; tv.tv_sec = 0;
tv.tv_usec = 100 * 1000; // 100 ms tv.tv_usec = 100 * 1000; // 100 ms
int select_res = select(1 + self->socket_fd, &rfds, NULL, NULL, &tv); int select_res = select(1 + self->initial_socket_fd, &rfds, NULL, NULL, &tv);
if(select_res > 0) { if(select_res > 0) {
socklen_t sock_len = 0; socklen_t sock_len = 0;
self->client_fd = accept(self->socket_fd, (struct sockaddr*)&remote_addr, &sock_len); self->initial_client_fd = accept(self->initial_socket_fd, (struct sockaddr*)&remote_addr, &sock_len);
if(self->client_fd == -1) { if(self->initial_client_fd == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: accept failed on socket, error: %s\n", strerror(errno)); fprintf(stderr, "gsr error: gsr_kms_client_init: accept failed on socket, error: %s\n", strerror(errno));
goto err; goto err;
} }
@ -260,6 +279,13 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
} }
fprintf(stderr, "gsr info: gsr_kms_client_init: server connected\n"); fprintf(stderr, "gsr info: gsr_kms_client_init: server connected\n");
fprintf(stderr, "gsr info: replacing file-backed unix domain socket with socketpair\n");
if(gsr_kms_client_replace_connection(self) != 0)
goto err;
cleanup_initial_socket(self, false);
fprintf(stderr, "gsr info: using socketpair\n");
return 0; return 0;
err: err:
@ -267,47 +293,104 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
return result; return result;
} }
void gsr_kms_client_deinit(gsr_kms_client *self) { void cleanup_initial_socket(gsr_kms_client *self, bool kill_server) {
if(self->client_fd != -1) { if(self->initial_client_fd != -1) {
close(self->client_fd); close(self->initial_client_fd);
self->client_fd = -1; self->initial_client_fd = -1;
} }
if(self->socket_fd != -1) { if(self->initial_socket_fd != -1) {
close(self->socket_fd); close(self->initial_socket_fd);
self->socket_fd = -1; self->initial_socket_fd = -1;
} }
if(self->kms_server_pid != -1) { if(kill_server && self->kms_server_pid != -1) {
kill(self->kms_server_pid, SIGINT); kill(self->kms_server_pid, SIGINT);
int status; int status;
waitpid(self->kms_server_pid, &status, 0); waitpid(self->kms_server_pid, &status, 0);
self->kms_server_pid = -1; self->kms_server_pid = -1;
} }
if(self->socket_path[0] != '\0') { if(self->initial_socket_path[0] != '\0') {
remove(self->socket_path); remove(self->initial_socket_path);
self->socket_path[0] = '\0'; self->initial_socket_path[0] = '\0';
} }
} }
int gsr_kms_client_get_kms(gsr_kms_client *self, gsr_kms_response *response) { void gsr_kms_client_deinit(gsr_kms_client *self) {
response->result = KMS_RESULT_FAILED_TO_SEND; cleanup_initial_socket(self, true);
strcpy(response->err_msg, "failed to send");
for(int i = 0; i < 2; ++i) {
if(self->socket_pair[i] > 0) {
close(self->socket_pair[i]);
self->socket_pair[i] = -1;
}
}
}
int gsr_kms_client_replace_connection(gsr_kms_client *self) {
gsr_kms_response response;
response.version = 0;
response.result = KMS_RESULT_FAILED_TO_SEND;
response.err_msg[0] = '\0';
gsr_kms_request request; gsr_kms_request request;
request.type = KMS_REQUEST_TYPE_GET_KMS; request.version = GSR_KMS_PROTOCOL_VERSION;
if(send_msg_to_server(self->client_fd, &request) == -1) { request.type = KMS_REQUEST_TYPE_REPLACE_CONNECTION;
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to send request message to server\n"); request.new_connection_fd = self->socket_pair[GSR_SOCKET_PAIR_REMOTE];
if(send_msg_to_server(self->initial_client_fd, &request) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_replace_connection: failed to send request message to server\n");
return -1; return -1;
} }
const int recv_res = recv_msg_from_server(self->client_fd, response); const int recv_res = recv_msg_from_server(self->socket_pair[GSR_SOCKET_PAIR_LOCAL], &response);
if(recv_res == 0) { if(recv_res == 0) {
fprintf(stderr, "gsr warning: gsr_kms_client_get_kms: kms server shut down\n"); fprintf(stderr, "gsr warning: gsr_kms_client_replace_connection: kms server shut down\n");
return -1; return -1;
} else if(recv_res == -1) { } else if(recv_res == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to receive response\n"); fprintf(stderr, "gsr error: gsr_kms_client_replace_connection: failed to receive response\n");
return -1;
}
if(response.version != GSR_KMS_PROTOCOL_VERSION) {
fprintf(stderr, "gsr error: gsr_kms_client_replace_connection: expected gsr-kms-server protocol version to be %u, but it's %u\n", GSR_KMS_PROTOCOL_VERSION, response.version);
/*close_fds(response);*/
return -1;
}
return 0;
}
int gsr_kms_client_get_kms(gsr_kms_client *self, gsr_kms_response *response) {
response->version = 0;
response->result = KMS_RESULT_FAILED_TO_SEND;
response->err_msg[0] = '\0';
gsr_kms_request request;
request.version = GSR_KMS_PROTOCOL_VERSION;
request.type = KMS_REQUEST_TYPE_GET_KMS;
request.new_connection_fd = 0;
if(send_msg_to_server(self->socket_pair[GSR_SOCKET_PAIR_LOCAL], &request) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to send request message to server\n");
strcpy(response->err_msg, "failed to send");
return -1;
}
const int recv_res = recv_msg_from_server(self->socket_pair[GSR_SOCKET_PAIR_LOCAL], response);
if(recv_res == 0) {
fprintf(stderr, "gsr warning: gsr_kms_client_get_kms: kms server shut down\n");
strcpy(response->err_msg, "failed to receive");
return -1;
} else if(recv_res == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to receive response\n");
strcpy(response->err_msg, "failed to receive");
return -1;
}
if(response->version != GSR_KMS_PROTOCOL_VERSION) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: expected gsr-kms-server protocol version to be %u, but it's %u\n", GSR_KMS_PROTOCOL_VERSION, response->version);
/*close_fds(response);*/
strcpy(response->err_msg, "mismatching protocol version");
return -1; return -1;
} }

View File

@ -7,9 +7,10 @@
typedef struct { typedef struct {
pid_t kms_server_pid; pid_t kms_server_pid;
int socket_fd; int initial_socket_fd;
int client_fd; int initial_client_fd;
char socket_path[PATH_MAX]; char initial_socket_path[PATH_MAX];
int socket_pair[2];
} gsr_kms_client; } gsr_kms_client;
/* |card_path| should be a path to card, for example /dev/dri/card0 */ /* |card_path| should be a path to card, for example /dev/dri/card0 */

View File

@ -4,9 +4,11 @@
#include <stdint.h> #include <stdint.h>
#include <stdbool.h> #include <stdbool.h>
#define GSR_KMS_PROTOCOL_VERSION 1
#define GSR_KMS_MAX_PLANES 32 #define GSR_KMS_MAX_PLANES 32
typedef enum { typedef enum {
KMS_REQUEST_TYPE_REPLACE_CONNECTION,
KMS_REQUEST_TYPE_GET_KMS KMS_REQUEST_TYPE_GET_KMS
} gsr_kms_request_type; } gsr_kms_request_type;
@ -19,7 +21,9 @@ typedef enum {
} gsr_kms_result; } gsr_kms_result;
typedef struct { typedef struct {
int type; /* gsr_kms_request_type */ uint32_t version; /* GSR_KMS_PROTOCOL_VERSION */
int type; /* gsr_kms_request_type */
int new_connection_fd;
} gsr_kms_request; } gsr_kms_request;
typedef struct { typedef struct {
@ -40,7 +44,8 @@ typedef struct {
} gsr_kms_response_fd; } gsr_kms_response_fd;
typedef struct { typedef struct {
int result; /* gsr_kms_result */ uint32_t version; /* GSR_KMS_PROTOCOL_VERSION */
int result; /* gsr_kms_result */
char err_msg[128]; char err_msg[128];
gsr_kms_response_fd fds[GSR_KMS_MAX_PLANES]; gsr_kms_response_fd fds[GSR_KMS_MAX_PLANES];
int num_fds; int num_fds;

View File

@ -408,6 +408,8 @@ int main(int argc, char **argv) {
for(;;) { for(;;) {
gsr_kms_request request; gsr_kms_request request;
request.version = 0;
request.type = -1;
struct iovec iov; struct iovec iov;
iov.iov_base = &request; iov.iov_base = &request;
iov.iov_len = sizeof(request); iov.iov_len = sizeof(request);
@ -431,9 +433,42 @@ int main(int argc, char **argv) {
continue; continue;
} }
if(request.version != GSR_KMS_PROTOCOL_VERSION) {
fprintf(stderr, "kms server error: expected gpu screen recorder protocol version to be %u, but it's %u\n", GSR_KMS_PROTOCOL_VERSION, request.version);
/*
if(request.new_connection_fd > 0)
close(request.new_connection_fd);
*/
continue;
}
switch(request.type) { switch(request.type) {
case KMS_REQUEST_TYPE_REPLACE_CONNECTION: {
gsr_kms_response response;
response.version = GSR_KMS_PROTOCOL_VERSION;
response.num_fds = 0;
if(request.new_connection_fd > 0) {
if(socket_fd > 0)
close(socket_fd);
socket_fd = request.new_connection_fd;
response.result = KMS_RESULT_OK;
if(send_msg_to_client(socket_fd, &response) == -1)
fprintf(stderr, "kms server error: failed to respond to client KMS_REQUEST_TYPE_REPLACE_CONNECTION request\n");
} else {
response.result = KMS_RESULT_INVALID_REQUEST;
snprintf(response.err_msg, sizeof(response.err_msg), "received invalid connection fd");
fprintf(stderr, "kms server error: %s\n", response.err_msg);
if(send_msg_to_client(socket_fd, &response) == -1)
fprintf(stderr, "kms server error: failed to respond to client request\n");
}
break;
}
case KMS_REQUEST_TYPE_GET_KMS: { case KMS_REQUEST_TYPE_GET_KMS: {
gsr_kms_response response; gsr_kms_response response;
response.version = GSR_KMS_PROTOCOL_VERSION;
response.num_fds = 0; response.num_fds = 0;
if(kms_get_fb(&drm, &response, &c2crtc_map) == 0) { if(kms_get_fb(&drm, &response, &c2crtc_map) == 0) {
@ -452,13 +487,15 @@ int main(int argc, char **argv) {
} }
default: { default: {
gsr_kms_response response; gsr_kms_response response;
response.version = GSR_KMS_PROTOCOL_VERSION;
response.result = KMS_RESULT_INVALID_REQUEST; response.result = KMS_RESULT_INVALID_REQUEST;
response.num_fds = 0;
snprintf(response.err_msg, sizeof(response.err_msg), "invalid request type %d, expected %d (%s)", request.type, KMS_REQUEST_TYPE_GET_KMS, "KMS_REQUEST_TYPE_GET_KMS"); snprintf(response.err_msg, sizeof(response.err_msg), "invalid request type %d, expected %d (%s)", request.type, KMS_REQUEST_TYPE_GET_KMS, "KMS_REQUEST_TYPE_GET_KMS");
fprintf(stderr, "kms server error: %s\n", response.err_msg); fprintf(stderr, "kms server error: %s\n", response.err_msg);
if(send_msg_to_client(socket_fd, &response) == -1) { if(send_msg_to_client(socket_fd, &response) == -1)
fprintf(stderr, "kms server error: failed to respond to client request\n"); fprintf(stderr, "kms server error: failed to respond to client request\n");
break;
}
break; break;
} }
} }