kms client/server: replace unix domain socket file with socketpair after connecting (fixes issue of .gsr-kms-socket files remaining in $HOME)

This commit is contained in:
dec05eba 2023-11-12 10:55:02 +01:00
parent 1ac862d155
commit 290db495ff
5 changed files with 175 additions and 48 deletions

3
TODO
View File

@ -106,4 +106,5 @@ Make it possible to select which /dev/dri/card* to use, but that requires opengl
Support I915_FORMAT_MOD_Y_TILED_CCS (and other power saving modifiers, see https://trac.ffmpeg.org/ticket/8542). The only fix may be to use desktop portal for recording. This issue doesn't appear on x11 since these modifiers are not used by xorg server.
Test if p2 state can be worked around by using pure nvenc api and overwriting cuInit/cuCtxCreate* to not do anything. Cuda might be loaded when using nvenc but it might not be used, with certain record options? (such as h264 p5).
Test if p2 state can be worked around by using pure nvenc api and overwriting cuInit/cuCtxCreate* to not do anything. Cuda might be loaded when using nvenc but it might not be used, with certain record options? (such as h264 p5).
nvenc uses cuda when using b frames and rgb->yuv conversion, so convert the image ourselves instead.

View File

@ -12,6 +12,12 @@
#include <sys/wait.h>
#include <sys/capability.h>
#define GSR_SOCKET_PAIR_LOCAL 0
#define GSR_SOCKET_PAIR_REMOTE 1
static void cleanup_initial_socket(gsr_kms_client *self, bool kill_server);
static int gsr_kms_client_replace_connection(gsr_kms_client *self);
static bool generate_random_characters(char *buffer, int buffer_size, const char *alphabet, size_t alphabet_size) {
int fd = open("/dev/urandom", O_RDONLY);
if(fd == -1) {
@ -34,6 +40,14 @@ static bool generate_random_characters(char *buffer, int buffer_size, const char
return true;
}
static void close_fds(gsr_kms_response *response) {
for(int i = 0; i < response->num_fds; ++i) {
if(response->fds[i].fd > 0)
close(response->fds[i].fd);
response->fds[i].fd = 0;
}
}
static int send_msg_to_server(int server_fd, gsr_kms_request *request) {
struct iovec iov;
iov.iov_base = request;
@ -72,9 +86,7 @@ static int recv_msg_from_server(int server_fd, gsr_kms_response *response) {
response->fds[i].fd = fds[i];
}
} else {
for(int i = 0; i < response->num_fds; ++i) {
response->fds[i].fd = 0;
}
close_fds(response);
}
}
@ -142,13 +154,15 @@ static bool find_program_in_path(const char *program_name, char *filepath, int f
int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
int result = -1;
self->kms_server_pid = -1;
self->socket_fd = -1;
self->client_fd = -1;
self->socket_path[0] = '\0';
self->initial_socket_fd = -1;
self->initial_client_fd = -1;
self->initial_socket_path[0] = '\0';
self->socket_pair[0] = -1;
self->socket_pair[1] = -1;
struct sockaddr_un local_addr = {0};
struct sockaddr_un remote_addr = {0};
if(!create_socket_path(self->socket_path, sizeof(self->socket_path))) {
if(!create_socket_path(self->initial_socket_path, sizeof(self->initial_socket_path))) {
fprintf(stderr, "gsr error: gsr_kms_client_init: failed to create path to kms socket\n");
return -1;
}
@ -185,20 +199,25 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
}
}
self->socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if(self->socket_fd == -1) {
if(socketpair(AF_UNIX, SOCK_STREAM, 0, self->socket_pair) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: socketpair failed, error: %s\n", strerror(errno));
goto err;
}
self->initial_socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if(self->initial_socket_fd == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: socket failed, error: %s\n", strerror(errno));
goto err;
}
local_addr.sun_family = AF_UNIX;
strncpy_safe(local_addr.sun_path, self->socket_path, sizeof(local_addr.sun_path));
if(bind(self->socket_fd, (struct sockaddr*)&local_addr, sizeof(local_addr.sun_family) + strlen(local_addr.sun_path)) == -1) {
strncpy_safe(local_addr.sun_path, self->initial_socket_path, sizeof(local_addr.sun_path));
if(bind(self->initial_socket_fd, (struct sockaddr*)&local_addr, sizeof(local_addr.sun_family) + strlen(local_addr.sun_path)) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: failed to bind socket, error: %s\n", strerror(errno));
goto err;
}
if(listen(self->socket_fd, 1) == -1) {
if(listen(self->initial_socket_fd, 1) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: failed to listen on socket, error: %s\n", strerror(errno));
goto err;
}
@ -209,13 +228,13 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
goto err;
} else if(pid == 0) { /* child */
if(inside_flatpak) {
const char *args[] = { "flatpak-spawn", "--host", "pkexec", "flatpak", "run", "--command=gsr-kms-server", "com.dec05eba.gpu_screen_recorder", self->socket_path, card_path, NULL };
const char *args[] = { "flatpak-spawn", "--host", "pkexec", "flatpak", "run", "--command=gsr-kms-server", "com.dec05eba.gpu_screen_recorder", self->initial_socket_path, card_path, NULL };
execvp(args[0], (char *const*)args);
} else if(has_perm) {
const char *args[] = { server_filepath, self->socket_path, card_path, NULL };
const char *args[] = { server_filepath, self->initial_socket_path, card_path, NULL };
execvp(args[0], (char *const*)args);
} else {
const char *args[] = { "pkexec", server_filepath, self->socket_path, card_path, NULL };
const char *args[] = { "pkexec", server_filepath, self->initial_socket_path, card_path, NULL };
execvp(args[0], (char *const*)args);
}
fprintf(stderr, "gsr error: gsr_kms_client_init: execvp failed, error: %s\n", strerror(errno));
@ -229,16 +248,16 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
struct timeval tv;
fd_set rfds;
FD_ZERO(&rfds);
FD_SET(self->socket_fd, &rfds);
FD_SET(self->initial_socket_fd, &rfds);
tv.tv_sec = 0;
tv.tv_usec = 100 * 1000; // 100 ms
int select_res = select(1 + self->socket_fd, &rfds, NULL, NULL, &tv);
int select_res = select(1 + self->initial_socket_fd, &rfds, NULL, NULL, &tv);
if(select_res > 0) {
socklen_t sock_len = 0;
self->client_fd = accept(self->socket_fd, (struct sockaddr*)&remote_addr, &sock_len);
if(self->client_fd == -1) {
self->initial_client_fd = accept(self->initial_socket_fd, (struct sockaddr*)&remote_addr, &sock_len);
if(self->initial_client_fd == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_init: accept failed on socket, error: %s\n", strerror(errno));
goto err;
}
@ -260,6 +279,13 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
}
fprintf(stderr, "gsr info: gsr_kms_client_init: server connected\n");
fprintf(stderr, "gsr info: replacing file-backed unix domain socket with socketpair\n");
if(gsr_kms_client_replace_connection(self) != 0)
goto err;
cleanup_initial_socket(self, false);
fprintf(stderr, "gsr info: using socketpair\n");
return 0;
err:
@ -267,47 +293,104 @@ int gsr_kms_client_init(gsr_kms_client *self, const char *card_path) {
return result;
}
void gsr_kms_client_deinit(gsr_kms_client *self) {
if(self->client_fd != -1) {
close(self->client_fd);
self->client_fd = -1;
void cleanup_initial_socket(gsr_kms_client *self, bool kill_server) {
if(self->initial_client_fd != -1) {
close(self->initial_client_fd);
self->initial_client_fd = -1;
}
if(self->socket_fd != -1) {
close(self->socket_fd);
self->socket_fd = -1;
if(self->initial_socket_fd != -1) {
close(self->initial_socket_fd);
self->initial_socket_fd = -1;
}
if(self->kms_server_pid != -1) {
if(kill_server && self->kms_server_pid != -1) {
kill(self->kms_server_pid, SIGINT);
int status;
waitpid(self->kms_server_pid, &status, 0);
self->kms_server_pid = -1;
}
if(self->socket_path[0] != '\0') {
remove(self->socket_path);
self->socket_path[0] = '\0';
if(self->initial_socket_path[0] != '\0') {
remove(self->initial_socket_path);
self->initial_socket_path[0] = '\0';
}
}
int gsr_kms_client_get_kms(gsr_kms_client *self, gsr_kms_response *response) {
response->result = KMS_RESULT_FAILED_TO_SEND;
strcpy(response->err_msg, "failed to send");
void gsr_kms_client_deinit(gsr_kms_client *self) {
cleanup_initial_socket(self, true);
for(int i = 0; i < 2; ++i) {
if(self->socket_pair[i] > 0) {
close(self->socket_pair[i]);
self->socket_pair[i] = -1;
}
}
}
int gsr_kms_client_replace_connection(gsr_kms_client *self) {
gsr_kms_response response;
response.version = 0;
response.result = KMS_RESULT_FAILED_TO_SEND;
response.err_msg[0] = '\0';
gsr_kms_request request;
request.type = KMS_REQUEST_TYPE_GET_KMS;
if(send_msg_to_server(self->client_fd, &request) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to send request message to server\n");
request.version = GSR_KMS_PROTOCOL_VERSION;
request.type = KMS_REQUEST_TYPE_REPLACE_CONNECTION;
request.new_connection_fd = self->socket_pair[GSR_SOCKET_PAIR_REMOTE];
if(send_msg_to_server(self->initial_client_fd, &request) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_replace_connection: failed to send request message to server\n");
return -1;
}
const int recv_res = recv_msg_from_server(self->client_fd, response);
const int recv_res = recv_msg_from_server(self->socket_pair[GSR_SOCKET_PAIR_LOCAL], &response);
if(recv_res == 0) {
fprintf(stderr, "gsr warning: gsr_kms_client_get_kms: kms server shut down\n");
fprintf(stderr, "gsr warning: gsr_kms_client_replace_connection: kms server shut down\n");
return -1;
} else if(recv_res == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to receive response\n");
fprintf(stderr, "gsr error: gsr_kms_client_replace_connection: failed to receive response\n");
return -1;
}
if(response.version != GSR_KMS_PROTOCOL_VERSION) {
fprintf(stderr, "gsr error: gsr_kms_client_replace_connection: expected gsr-kms-server protocol version to be %u, but it's %u\n", GSR_KMS_PROTOCOL_VERSION, response.version);
/*close_fds(response);*/
return -1;
}
return 0;
}
int gsr_kms_client_get_kms(gsr_kms_client *self, gsr_kms_response *response) {
response->version = 0;
response->result = KMS_RESULT_FAILED_TO_SEND;
response->err_msg[0] = '\0';
gsr_kms_request request;
request.version = GSR_KMS_PROTOCOL_VERSION;
request.type = KMS_REQUEST_TYPE_GET_KMS;
request.new_connection_fd = 0;
if(send_msg_to_server(self->socket_pair[GSR_SOCKET_PAIR_LOCAL], &request) == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to send request message to server\n");
strcpy(response->err_msg, "failed to send");
return -1;
}
const int recv_res = recv_msg_from_server(self->socket_pair[GSR_SOCKET_PAIR_LOCAL], response);
if(recv_res == 0) {
fprintf(stderr, "gsr warning: gsr_kms_client_get_kms: kms server shut down\n");
strcpy(response->err_msg, "failed to receive");
return -1;
} else if(recv_res == -1) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: failed to receive response\n");
strcpy(response->err_msg, "failed to receive");
return -1;
}
if(response->version != GSR_KMS_PROTOCOL_VERSION) {
fprintf(stderr, "gsr error: gsr_kms_client_get_kms: expected gsr-kms-server protocol version to be %u, but it's %u\n", GSR_KMS_PROTOCOL_VERSION, response->version);
/*close_fds(response);*/
strcpy(response->err_msg, "mismatching protocol version");
return -1;
}

View File

@ -7,9 +7,10 @@
typedef struct {
pid_t kms_server_pid;
int socket_fd;
int client_fd;
char socket_path[PATH_MAX];
int initial_socket_fd;
int initial_client_fd;
char initial_socket_path[PATH_MAX];
int socket_pair[2];
} gsr_kms_client;
/* |card_path| should be a path to card, for example /dev/dri/card0 */

View File

@ -4,9 +4,11 @@
#include <stdint.h>
#include <stdbool.h>
#define GSR_KMS_PROTOCOL_VERSION 1
#define GSR_KMS_MAX_PLANES 32
typedef enum {
KMS_REQUEST_TYPE_REPLACE_CONNECTION,
KMS_REQUEST_TYPE_GET_KMS
} gsr_kms_request_type;
@ -19,7 +21,9 @@ typedef enum {
} gsr_kms_result;
typedef struct {
int type; /* gsr_kms_request_type */
uint32_t version; /* GSR_KMS_PROTOCOL_VERSION */
int type; /* gsr_kms_request_type */
int new_connection_fd;
} gsr_kms_request;
typedef struct {
@ -40,7 +44,8 @@ typedef struct {
} gsr_kms_response_fd;
typedef struct {
int result; /* gsr_kms_result */
uint32_t version; /* GSR_KMS_PROTOCOL_VERSION */
int result; /* gsr_kms_result */
char err_msg[128];
gsr_kms_response_fd fds[GSR_KMS_MAX_PLANES];
int num_fds;

View File

@ -408,6 +408,8 @@ int main(int argc, char **argv) {
for(;;) {
gsr_kms_request request;
request.version = 0;
request.type = -1;
struct iovec iov;
iov.iov_base = &request;
iov.iov_len = sizeof(request);
@ -431,9 +433,42 @@ int main(int argc, char **argv) {
continue;
}
if(request.version != GSR_KMS_PROTOCOL_VERSION) {
fprintf(stderr, "kms server error: expected gpu screen recorder protocol version to be %u, but it's %u\n", GSR_KMS_PROTOCOL_VERSION, request.version);
/*
if(request.new_connection_fd > 0)
close(request.new_connection_fd);
*/
continue;
}
switch(request.type) {
case KMS_REQUEST_TYPE_REPLACE_CONNECTION: {
gsr_kms_response response;
response.version = GSR_KMS_PROTOCOL_VERSION;
response.num_fds = 0;
if(request.new_connection_fd > 0) {
if(socket_fd > 0)
close(socket_fd);
socket_fd = request.new_connection_fd;
response.result = KMS_RESULT_OK;
if(send_msg_to_client(socket_fd, &response) == -1)
fprintf(stderr, "kms server error: failed to respond to client KMS_REQUEST_TYPE_REPLACE_CONNECTION request\n");
} else {
response.result = KMS_RESULT_INVALID_REQUEST;
snprintf(response.err_msg, sizeof(response.err_msg), "received invalid connection fd");
fprintf(stderr, "kms server error: %s\n", response.err_msg);
if(send_msg_to_client(socket_fd, &response) == -1)
fprintf(stderr, "kms server error: failed to respond to client request\n");
}
break;
}
case KMS_REQUEST_TYPE_GET_KMS: {
gsr_kms_response response;
response.version = GSR_KMS_PROTOCOL_VERSION;
response.num_fds = 0;
if(kms_get_fb(&drm, &response, &c2crtc_map) == 0) {
@ -452,13 +487,15 @@ int main(int argc, char **argv) {
}
default: {
gsr_kms_response response;
response.version = GSR_KMS_PROTOCOL_VERSION;
response.result = KMS_RESULT_INVALID_REQUEST;
response.num_fds = 0;
snprintf(response.err_msg, sizeof(response.err_msg), "invalid request type %d, expected %d (%s)", request.type, KMS_REQUEST_TYPE_GET_KMS, "KMS_REQUEST_TYPE_GET_KMS");
fprintf(stderr, "kms server error: %s\n", response.err_msg);
if(send_msg_to_client(socket_fd, &response) == -1) {
if(send_msg_to_client(socket_fd, &response) == -1)
fprintf(stderr, "kms server error: failed to respond to client request\n");
break;
}
break;
}
}