diff --git a/quicp2p.nim b/quicp2p.nim index f2df705..4976ba7 100644 --- a/quicp2p.nim +++ b/quicp2p.nim @@ -16,25 +16,43 @@ import picotls/openssl as ptls_openssl import strformat from nativesockets import SockAddr, Sockaddr_storage, SockLen, getHostByName -from openssl import DLLSSLName, EVP_PKEY from posix import IOVec from strutils import parseUInt +from openssl import + DLLSSLName, + EVP_PKEY, + SslPtr, + PX509, + PX509_STORE, + X509_STORE_new, + X509_STORE_free, + X509_STORE_add_cert, + PSTACK, + d2i_X509 + const serverCertChainPath = "./certs/server-certchain.pem" const serverKeyPath = "./certs/server-cert.key" const clientCertChainPath = "./certs/server-certchain.pem" const clientKeyPath = "./certs/server-cert.key" +const X509_V_FLAG_CHECK_SS_SIGNATURE = 0x00004000 + type QuicP2PContext = ref object sock: AsyncSocket streamOpen: quicly_stream_open_t nextCid: quicly_cid_plaintext_t - signCertificate: ptls_openssl_sign_certificate_t + signCertCb: ptls_openssl_sign_certificate_t + verifyCertsCb: ptls_verify_certificate_t tlsCtx: ptls_context_t quiclyCtx: quicly_context_t connections: seq[ptr quicly_conn_t] + PX509_STORE_CTX = SslPtr + + PX509_VERIFY_PARAM = SslPtr + proc PEM_read_PrivateKey(fp: File, x: ptr EVP_PKEY, cb: proc(buf: cstring, size: cint, rwflag: cint, u: pointer): cint {.cdecl.}, u: pointer): EVP_PKEY @@ -42,6 +60,31 @@ proc PEM_read_PrivateKey(fp: File, x: ptr EVP_PKEY, proc EVP_PKEY_free(key: EVP_PKEY) {.importc, dynlib: DLLSSLName, cdecl.} +proc X509_free(a: PX509) {.importc, dynlib: DLLSSLName, cdecl.} + +proc X509_STORE_CTX_new(): PX509_STORE_CTX + {.importc, dynlib: DLLSSLName, cdecl.} + +proc X509_STORE_CTX_free(ctx: PX509_STORE_CTX) + {.importc, dynlib: DLLSSLName, cdecl.} + +proc X509_STORE_CTX_init(ctx: PX509_STORE_CTX, store: PX509_STORE, x509: PX509, + chain: PSTACK): cint + {.importc, dynlib: DLLSSLName, cdecl.} + +proc X509_STORE_CTX_get0_param(ctx: PX509_STORE_CTX): PX509_VERIFY_PARAM + {.importc, dynlib: DLLSSLName, cdecl.} + +proc X509_VERIFY_PARAM_get_flags(param: PX509_VERIFY_PARAM): culong + {.importc, dynlib: DLLSSLName, cdecl.} + +proc X509_VERIFY_PARAM_set_flags(param: PX509_VERIFY_PARAM, + flags: culong): int + {.importc, dynlib: DLLSSLName, cdecl.} + +proc X509_verify_cert(ctx: PX509_STORE_CTX): int + {.importc, dynlib: DLLSSLName, cdecl.} + proc getRelativeTimeout(ctx: QuicP2PContext): int32 = ## Obtain the absolute int64 timeout from quicly and convert it to the ## relative int32 timeout expected by poll. @@ -122,10 +165,46 @@ proc onClientStreamOpen(self: ptr quicly_stream_open_t, let msg = "hello server" discard quicly_streambuf_egress_write(stream, msg.cstring, msg.len().csize_t) +proc verifyCACertSignature(cert: PX509): bool = + let ctx = X509_STORE_CTX_new() + let store = X509_STORE_new() + discard X509_STORE_add_cert(store, cert) + discard X509_STORE_CTX_init(ctx, store, cert, nil) + let verifyParams = X509_STORE_CTX_get0_param(ctx) + let flags = X509_VERIFY_PARAM_get_flags(verifyParams) + discard X509_VERIFY_PARAM_set_flags(verifyParams, flags or X509_V_FLAG_CHECK_SS_SIGNATURE) + let verifyResult = X509_verify_cert(ctx) + result = verifyResult == 1 + X509_STORE_CTX_free(ctx) + X509_STORE_free(store) + +proc verifyCerts(self: ptr ptls_verify_certificate_t, tls: ptr ptls_t, + verify_sign: ptr VerifySignCb, verify_data: ptr pointer, + certs: ptr ptls_iovec_t, num_certs: csize_t): cint {.cdecl.} = + if num_certs != 2: + return PTLS_ALERT_UNKNOWN_CA + # parse the highest certificate and use it as CA certificate + var iovec = cast[ptr ptls_iovec_t](cast[ByteAddress](certs) + sizeof(ptls_iovec_t)) + var iovecBase = iovec.base + let caCert = d2i_X509(nil, cast[ptr ptr cuchar](addr iovecBase), + iovec.len.cint) + if caCert.isNil: + return PTLS_ALERT_BAD_CERTIFICATE + if not verifyCACertSignature(caCert): + return PTLS_ALERT_BAD_CERTIFICATE + let store = X509_STORE_new() + discard X509_STORE_add_cert(store, caCert) + var opensslVerifier: ptls_openssl_verify_certificate_t + discard ptls_openssl_init_verify_certificate(addr opensslVerifier, store) + result = opensslVerifier.super.cb(addr opensslVerifier.super, tls, + verify_sign, verify_data, certs, num_certs) + ptls_openssl_dispose_verify_certificate(addr opensslVerifier) + X509_STORE_free(store) + X509_free(caCert) + proc initContext(sock: AsyncSocket, certChainPath: string, keyPath: string, streamOpenCb: typeof(quicly_stream_open_t.cb)): QuicP2PContext = - var tlsCtx = ptls_context_t(randomBytes: ptls_openssl_random_bytes, getTime: addr ptls_get_time, keyExchanges: ptls_openssl_key_exchanges, @@ -133,9 +212,11 @@ proc initContext(sock: AsyncSocket, certChainPath: string, keyPath: string, quicly_amend_ptls_context(addr tlsCtx) result = QuicP2PContext(sock: sock, streamOpen: quicly_stream_open_t(cb: streamOpenCb), + verifyCertsCb: ptls_verify_certificate_t(cb: verifyCerts), tlsCtx: tlsCtx, quiclyCtx: quicly_spec_context) result.quiclyCtx.tls = addr result.tlsCtx result.quiclyCtx.stream_open = addr result.streamOpen + result.tlsCtx.verify_certificate = addr result.verifyCertsCb if ptls_load_certificates(addr result.tlsCtx, certChainPath.cstring) != 0: raise newException(ValueError, &"cannot load certificate chain {certChainPath}") let pKeyFile = open(keyPath) @@ -143,9 +224,9 @@ proc initContext(sock: AsyncSocket, certChainPath: string, keyPath: string, pkeyFile.close() if privateKey == nil: raise newException(ValueError, &"cannot load private key {keyPath}") - discard ptls_openssl_init_sign_certificate(addr result.signCertificate, privateKey) + discard ptls_openssl_init_sign_certificate(addr result.signCertCb, privateKey) EVP_PKEY_free(privateKey) - result.tlsCtx.signCertificate = addr result.signCertificate.super + result.tlsCtx.sign_certificate = addr result.signCertCb.super proc sendPackets(ctx: QuicP2PContext) = if ctx.connections.len == 0: @@ -222,6 +303,7 @@ proc main() = sock.bindAddr(Port(portNumber)) ctx = initContext(sock, serverCertChainPath, serverKeyPath, onServerStreamOpen) + ctx.tlsCtx.require_client_authentication = 1 asyncCheck receive(ctx, true) of 2: