#include "mbedtls/ssl.h"
#include "mbedtls/entropy.h"
#include "mbedtls/net_sockets.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedsock.h"

#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

/*
 * Allocate a context on the heap and initialize it with mbedsock_ctx_new().
 *
 * Returns the new context, or NULL if the allocation or the initialization
 * failed. Release a non-NULL result with mbedsock_ctx_free_ex().
 */
struct mbedsock_ctx *mbedsock_ctx_new_ex(const char *capath)
{
    struct mbedsock_ctx *ctx = malloc(sizeof *ctx);

    if (ctx == NULL) {
        return NULL;
    }

    if (mbedsock_ctx_new(ctx, capath) != 0) {
        /* mbedsock_ctx_new() runs every *_init() before it can fail, so the
         * matching *_free() calls are safe on the partially set up context. */
        mbedsock_ctx_free(ctx);
        free(ctx);
        return NULL;
    }

    return ctx;
}

/*
 * Allocate a socket on the heap and initialize it with mbedsock_new().
 *
 * Returns the new socket, or NULL if the allocation or the initialization
 * failed. Release a non-NULL result with mbedsock_free_ex().
 */
struct mbedsock *mbedsock_new_ex(struct mbedsock_ctx *ctx)
{
    struct mbedsock *sock = malloc(sizeof *sock);

    if (sock == NULL) {
        return NULL;
    }

    if (mbedsock_new(ctx, sock) != 0) {
        /* mbedsock_new() runs every *_init() before it can fail, so the
         * matching *_free() calls are safe on the partially set up socket. */
        mbedsock_free(sock);
        free(sock);
        return NULL;
    }

    return sock;
}

/*
 * Initialize a caller-provided context: certificate chain, CTR_DRBG and
 * entropy source. The DRBG is seeded using SSL_PERS / SSL_PERS_LEN as the
 * personalization data, and the CA certificates are loaded from the directory
 * `capath`.
 *
 * Returns 0 on success, or the non-zero Mbed TLS error code of the step that
 * failed. A positive result from mbedtls_x509_crt_parse_path() is not treated
 * as an error; only negative values are.
 *
 * On failure the context has still been through all *_init() calls, so it
 * may (and should) be passed to mbedsock_ctx_free().
 */
int mbedsock_ctx_new(struct mbedsock_ctx *ctx, const char *capath)
{
    int ret;

    mbedtls_x509_crt_init(&ctx->chain);
    mbedtls_ctr_drbg_init(&ctx->ctr_drbg);
    mbedtls_entropy_init(&ctx->entropy);

    ret = mbedtls_ctr_drbg_seed(&ctx->ctr_drbg, mbedtls_entropy_func,
                                &ctx->entropy,
                                (unsigned char *) SSL_PERS, SSL_PERS_LEN);
    if (ret != 0) {
        return ret;
    }

    ret = mbedtls_x509_crt_parse_path(&ctx->chain, capath);
    if (ret < 0) {
        return ret;
    }

    return 0;
}

/*
 * Initialize a caller-provided socket as a TLS stream client that uses the
 * CA chain and random generator of `ctx`. Peer verification is required.
 * The socket starts out in the non-secure state.
 *
 * `ctx` is referenced, not copied: the configuration stores pointers to
 * ctx->chain and ctx->ctr_drbg.
 *
 * Returns 0 on success, or the Mbed TLS error code of the step that failed.
 * On failure the socket has still been through all *_init() calls, so it
 * may (and should) be passed to mbedsock_free().
 */
int mbedsock_new(struct mbedsock_ctx *ctx, struct mbedsock *sock)
{
    int ret;

    mbedtls_net_init(&sock->server_fd);
    mbedtls_ssl_init(&sock->ssl);
    mbedtls_ssl_config_init(&sock->conf);
    sock->secure = false;

    /* Load the defaults first: the configuration has to be populated before
     * it is customized and before it is bound to the SSL context. */
    ret = mbedtls_ssl_config_defaults(&sock->conf,
                                      MBEDTLS_SSL_IS_CLIENT,
                                      MBEDTLS_SSL_TRANSPORT_STREAM,
                                      MBEDTLS_SSL_PRESET_DEFAULT);
    if (ret != 0) {
        return ret;
    }

    mbedtls_ssl_conf_authmode(&sock->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
    mbedtls_ssl_conf_ca_chain(&sock->conf, &ctx->chain, NULL);
    mbedtls_ssl_conf_rng(&sock->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg);

    ret = mbedtls_ssl_setup(&sock->ssl, &sock->conf);
    if (ret != 0) {
        return ret;
    }

    mbedtls_ssl_set_bio(&sock->ssl, &sock->server_fd,
                        mbedtls_net_send, mbedtls_net_recv, NULL);

    return 0;
}

/*
 * Release the resources held by a socket set up with mbedsock_new().
 * Does not free the `sock` storage itself.
 */
void mbedsock_free(struct mbedsock *sock)
{
    mbedtls_net_free(&sock->server_fd);
    mbedtls_ssl_free(&sock->ssl);
    mbedtls_ssl_config_free(&sock->conf);
}

/*
 * Release the resources held by a context set up with mbedsock_ctx_new().
 * Does not free the `ctx` storage itself.
 */
void mbedsock_ctx_free(struct mbedsock_ctx *ctx)
{
    mbedtls_x509_crt_free(&ctx->chain);
    mbedtls_ctr_drbg_free(&ctx->ctr_drbg);
    mbedtls_entropy_free(&ctx->entropy);
}

/*
 * Release a socket and then free() its storage; counterpart of
 * mbedsock_new_ex().
 */
void mbedsock_free_ex(struct mbedsock *sock)
{
    mbedsock_free(sock);
    free(sock);
}

/*
 * Release a context and then free() its storage; counterpart of
 * mbedsock_ctx_new_ex().
 */
void mbedsock_ctx_free_ex(struct mbedsock_ctx *ctx)
{
    mbedsock_ctx_free(ctx);
    free(ctx);
}

/*
 * Run the TLS handshake on an already connected socket.
 *
 * alpn: single ALPN protocol name to offer, or NULL to offer none.
 * sni:  host name passed to mbedtls_ssl_set_hostname(), or NULL to skip it.
 *
 * Returns 0 on success and marks the socket as secure. Otherwise returns the
 * non-zero Mbed TLS error code, or the non-zero certificate verification
 * flags, and leaves the secure flag unchanged.
 */
int mbedsock_do_handshake(struct mbedsock *sock, const char *alpn, const char *sni)
{
    /* Function scope on purpose: Mbed TLS records the pointer to this list
     * instead of copying it, so it has to stay alive for the whole handshake
     * loop below.
     * NOTE(review): the configuration still holds this pointer after we
     * return. If the list can be consulted again later (e.g. a second
     * handshake on the same socket), it should live in struct mbedsock. */
    const char *alpn_list[2] = { alpn, NULL };
    int ret;

    // Set ALPN, if desired
    if (alpn != NULL) {
        ret = mbedtls_ssl_conf_alpn_protocols(&sock->conf, alpn_list);
        if (ret != 0) {
            return ret;
        }
    }

    // Set SNI, if desired
    if (sni != NULL) {
        ret = mbedtls_ssl_set_hostname(&sock->ssl, sni);
        if (ret != 0) {
            return ret;
        }
    }

    /* Retry while the library only asks for more I/O. */
    while ((ret = mbedtls_ssl_handshake(&sock->ssl)) != 0) {
        if (ret != MBEDTLS_ERR_SSL_WANT_READ &&
            ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
            return ret;
        }
    }

    // Verify the certificates
    ret = mbedtls_ssl_get_verify_result(&sock->ssl);
    if (ret != 0) {
        return ret;
    }

    sock->secure = true;
    return 0;
}

/*
 * Open a TCP connection to host:port and perform the TLS handshake.
 *
 * Returns 0 on success, or the non-zero code of the failing step
 * (mbedtls_net_connect() or mbedsock_do_handshake()).
 */
int mbedsock_connect_secure(struct mbedsock *sock, const char *host, const char *port,
                            const char *alpn, const char *sni)
{
    int ret;

    ret = mbedtls_net_connect(&sock->server_fd, host, port, MBEDTLS_NET_PROTO_TCP);
    if (ret != 0) {
        return ret;
    }

    return mbedsock_do_handshake(sock, alpn, sni);
}

/*
 * Open a plain TCP connection to host:port without TLS.
 *
 * Returns the result of mbedtls_net_connect().
 */
int mbedsock_connect(struct mbedsock *sock, const char *host, const char *port)
{
    return mbedtls_net_connect(&sock->server_fd, host, port, MBEDTLS_NET_PROTO_TCP);
}

/*
 * Send up to `len` bytes from `data`, through TLS if the socket is secure and
 * directly over the network socket otherwise.
 *
 * Returns the positive value reported by the underlying send call, or -1 on
 * any error, on a zero result, or if `len` is negative.
 */
int mbedsock_write(struct mbedsock *sock, const unsigned char *data, int len)
{
    int ret;

    /* A negative length would be converted to a huge size_t below. */
    if (len < 0) {
        return -1;
    }

    if (!sock->secure) {
        ret = mbedtls_net_send(&sock->server_fd, data, (size_t) len);
        return ret <= 0 ? -1 : ret;
    }

    /* Retry while the library only asks for more I/O. */
    while ((ret = mbedtls_ssl_write(&sock->ssl, data, (size_t) len)) <= 0) {
        if (ret != MBEDTLS_ERR_SSL_WANT_READ &&
            ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
            return -1;
        }
    }

    return ret;
}

/*
 * Receive up to `len` bytes into `buf`, through TLS if the socket is secure
 * and directly from the network socket otherwise. The first `len` bytes of
 * `buf` are zeroed before reading.
 *
 * Returns the non-negative value reported by the underlying receive call,
 * 0 when the TLS peer sent a close notification, or -1 on any error or if
 * `len` is negative.
 */
int mbedsock_read(struct mbedsock *sock, unsigned char *buf, int len)
{
    int ret;

    /* A negative length would be converted to a huge size_t below. */
    if (len < 0) {
        return -1;
    }

    memset(buf, 0, (size_t) len);

    if (!sock->secure) {
        ret = mbedtls_net_recv(&sock->server_fd, buf, (size_t) len);
        return ret < 0 ? -1 : ret;
    }

    for (;;) {
        ret = mbedtls_ssl_read(&sock->ssl, buf, (size_t) len);

        /* The library only asks for more I/O: try again. */
        if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
            ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
            continue;
        }

        // TODO: Notify
        if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
            return 0;
        }

        return ret < 0 ? -1 : ret;
    }
}

/*
 * Report whether the socket's secure flag is set, i.e. whether
 * mbedsock_do_handshake() has completed successfully on it.
 */
bool mbedsock_is_secure(struct mbedsock *sock)
{
    return sock->secure;
}