diff --git a/puncher.nim b/puncher.nim index 3621391..bba33b6 100644 --- a/puncher.nim +++ b/puncher.nim @@ -11,13 +11,15 @@ from sequtils import any type Attempt* = object ## A hole punching attempt. - srcPort*: Port + srcPorts*: seq[Port] dstIp*: IpAddress dstPorts*: seq[Port] - future*: Future[Port] + future*: Future[(AsyncSocket, Port)] Puncher* = ref object - sock: AsyncSocket + socks: seq[AsyncSocket] + srcPorts: seq[Port] + natProps: NatProperties attempts: seq[Attempt] PunchHoleError* = object of ValueError @@ -30,23 +32,34 @@ const Timeout = 3000 proc `==`(a, b: Attempt): bool = ## ``==`` for hole punching attempts. ## - ## Two hole punching attempts are considered equal if their ``srcPort`` and - ## ``dstIp`` are equal and their ``dstPorts`` overlap. - a.srcPort == b.srcPort and a.dstIp == b.dstIp and + ## Two hole punching attempts are considered equal if their ``dstIp`` is + ## equal, their ``srcPorts`` overlap and their ``dstPorts`` overlap. + a.dstIp == b.dstIp and + a.srcPorts.any(proc (p: Port): bool = p in b.srcPorts) and a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts) -proc initPuncher*(sock: AsyncSocket): Puncher = - Puncher(sock: sock) +proc initPuncher*(socks: seq[AsyncSocket], probedSrcPorts: seq[Port]): Puncher = + assert(socks.len > 0) + var srcPorts = newSeq[Port](socks.len) + for i in 0 .. socks.len - 1: + let (_, srcPort) = socks[i].getLocalAddr() + srcPorts[i] = srcPort + let natProps = getNatProperties(srcPorts[0], probedSrcPorts) + Puncher(socks: socks, srcPorts: srcPorts, natProps: natProps) proc punch(puncher: Puncher, peerIp: IpAddress, peerPort: Port, peerProbedPorts: seq[Port], lowTTL: bool, msg: string): Future[Attempt] {.async.} = - let punchFuture = newFuture[Port]("punch") + let punchFuture = newFuture[(AsyncSocket, Port)]("punch") let natProps = getNatProperties(peerPort, peerProbedPorts) let predictedDstPorts = predictPortRange(natProps) - let (_, myPort) = puncher.sock.getLocalAddr() - result = Attempt(srcPort: myPort, dstIp: peerIp, dstPorts: predictedDstPorts, - future: punchFuture) + result = Attempt(srcPorts: @[puncher.srcPorts[0]], dstIp: peerIp, + dstPorts: predictedDstPorts, future: punchFuture) + if puncher.natProps.natType == SymmetricRandom and puncher.srcPorts.len > 1: + # our NAT is of the evil symmetric type with random port allocation. We are + # trying to help the other peer by punching more holes using all our + # sockets. + result.srcPorts.add(puncher.srcPorts[1 .. ^1]) if puncher.attempts.contains(result): raise newException(PunchHoleError, "hole punching for given parameters already active") @@ -56,16 +69,19 @@ proc punch(puncher: Puncher, peerIp: IpAddress, peerPort: Port, var peerSockLen: SockLen try: var defaultTTL: int - if lowTTL: - defaultTTL = puncher.sock.getFd.getSockOptInt(IPPROTO_IP, IP_TTL) - puncher.sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2) - for dstPort in result.dstPorts: - toSockAddr(result.dstIp, dstPort, peerAddr, peerSockLen) - # TODO: replace asyncdispatch.sendTo with asyncnet.sendTo (Nim 1.4 required) - await sendTo(puncher.sock.getFd().AsyncFD, msg.cstring, msg.len, - cast[ptr SockAddr](addr peerAddr), peerSockLen) - if lowTTL: - puncher.sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, defaultTTL) + for i in 0 .. result.srcPorts.len - 1: + let sock = puncher.socks[i] + if lowTTL: + defaultTTL = sock.getFd.getSockOptInt(IPPROTO_IP, IP_TTL) + sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2) + + for dstPort in result.dstPorts: + toSockAddr(result.dstIp, dstPort, peerAddr, peerSockLen) + # TODO: replace asyncdispatch.sendTo with asyncnet.sendTo (Nim 1.4 required) + await sendTo(sock.getFd.AsyncFD, msg.cstring, msg.len, + cast[ptr SockAddr](addr peerAddr), peerSockLen) + if lowTTL: + sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, defaultTTL) except OSError as e: raise newException(PunchHoleError, e.msg) @@ -77,37 +93,37 @@ proc respond*(puncher: Puncher, peerIp: IpAddress, peerPort: Port, peerProbedPorts: seq[Port]): Future[Attempt] = punch(puncher, peerIp, peerPort, peerProbedPorts, false, "ACK") -proc finalize*(attempt: Attempt): Future[Port] {.async.} = +proc finalize*(attempt: Attempt): Future[(AsyncSocket, Port)] {.async.} = await attempt.future or sleepAsync(Timeout) if attempt.future.finished: result = attempt.future.read() else: raise newException(PunchHoleError, "timeout") -proc handleMsg*(puncher: Puncher, msg: string, peerIp: IpAddress, - peerPort: Port) = +proc handleMsg*(puncher: Puncher, msg: string, sock: AsyncSocket, + peerIp: IpAddress, peerPort: Port) = ## Handles an incoming UDP message which may complete the Futures returned by ## ``initiate`` and ``respond``. if msg == "SYN": # We received a SYN packet. We ignore it because we expected it to be # filtered by our NAT. return - let (_, myPort) = puncher.sock.getLocalAddr() - let query = Attempt(srcPort: myPort, dstIp: peerIp, dstPorts: @[peerPort]) + let query = Attempt(srcPorts: puncher.srcPorts, dstIp: peerIp, + dstPorts: @[peerPort]) let i = puncher.attempts.find(query) if i != -1: if msg == "ACK": echo &"handling ACK message from {peerIp}:{peerPort}" else: echo &"handling QUIC message from {peerIp}:{peerPort}" - puncher.attempts[i].future.complete(peerPort) + puncher.attempts[i].future.complete((sock, peerPort)) puncher.attempts.del(i) else: echo &"received unexpected packet from {peerIp}:{peerPort}" -proc handleMsg*(puncher: Puncher, msg: string, +proc handleMsg*(puncher: Puncher, msg: string, sock: AsyncSocket, peerAddr: SockAddr | Sockaddr_storage, peerSockLen: SockLen) = var peerIp: IpAddress var peerPort: Port fromSockAddr(peerAddr, peerSockLen, peerIp, peerPort) - handleMsg(puncher, msg, peerIp, peerPort) + handleMsg(puncher, msg, sock, peerIp, peerPort) diff --git a/quicp2p.nim b/quicp2p.nim index f7c0a0c..93ae39f 100644 --- a/quicp2p.nim +++ b/quicp2p.nim @@ -25,6 +25,7 @@ import strutils from nativesockets import SockAddr, Sockaddr_storage, SockLen, getHostByName from posix import IOVec +from sequtils import filter from strutils import parseUInt from openssl import @@ -42,11 +43,12 @@ from openssl import type Connection = ref object conn: ptr quicly_conn_t + sock: AsyncSocket certs: seq[Certificate] - peerId: string + expectedPeerId: string QuicP2PContext = ref object - sock: AsyncSocket + socks: seq[AsyncSocket] puncher: Puncher streamOpen: quicly_stream_open_t nextCid: quicly_cid_plaintext_t @@ -66,7 +68,7 @@ const rendezvousServers: seq[tuple[hostname: string, port: Port]] = @[ ("ulrich.earth", Port(5320)) ] -proc getRelativeTimeout(ctx: QuicP2PContext): int32 = +proc relativeTimeout(ctx: QuicP2PContext): int32 = ## Obtain the absolute int64 timeout from quicly and convert it to the ## relative int32 timeout expected by poll. result = 0 @@ -80,7 +82,11 @@ proc getRelativeTimeout(ctx: QuicP2PContext): int32 = let delta = nextTimeout - now result = min(delta, int32.high).int32 -proc getPeerId(ctx: QuicP2PContext): string = +proc srcPort(ctx: QuicP2PContext): Port = + let (_, myPort) = ctx.socks[0].getLocalAddr() + myPort + +proc peerId(ctx: QuicP2PContext): string = assert(ctx.tlsCtx.certificates.count == 2) let firstCertAddr = cast[ByteAddress](ctx.tlsCtx.certificates.list) let secondCertIovec = cast[ptr ptls_iovec_t](firstCertAddr + sizeof(ptls_iovec_t)) @@ -88,7 +94,7 @@ proc getPeerId(ctx: QuicP2PContext): string = copyMem(caCert.cstring, secondCertIovec.base, secondCertIovec.len) result = caCert.getPublicKey().encode(pad = false) -proc getPeerId(conn: Connection): string = +proc peerId(conn: Connection): string = assert(conn.certs.len() == 2) result = conn.certs[1].getPublicKey().encode(pad = false) @@ -108,7 +114,7 @@ proc onServerReceive(stream: ptr quicly_stream_t, offset: csize_t, src: pointer, var msg = newString(input.len) copyMem(addr msg[0], input.base, input.len) let conn = cast[Connection](quicly_get_data(stream.conn)[]) - echo &"client {conn.getPeerId()} sends \"{msg}\"" + echo &"client {conn.peerId} sends \"{msg}\"" if quicly_sendstate_is_open(addr stream.sendstate) != 0 and input.len > 0: discard quicly_streambuf_egress_write(stream, input.base, input.len) if quicly_recvstate_transfer_complete(addr stream.recvstate) != 0: @@ -123,7 +129,7 @@ proc onClientReceive(stream: ptr quicly_stream_t, offset: csize_t, let msg = newString(input.len) copyMem(msg.cstring, input.base, input.len) let conn = cast[Connection](quicly_get_data(stream.conn)[]) - echo &"server {conn.getPeerId()} replies \"{msg}\"" + echo &"server {conn.peerId} replies \"{msg}\"" if quicly_recvstate_transfer_complete(addr stream.recvstate) != 0: discard quicly_close(stream.conn, 0, "") quicly_streambuf_ingress_shift(stream, input.len) @@ -208,7 +214,7 @@ proc verifyCerts(self: ptr ptls_verify_certificate_t, tls: ptr ptls_t, X509_STORE_free(store) X509_free(caCert) -proc initContext(sock: AsyncSocket, puncher: Puncher, certChainPath: string, +proc initContext(socks: seq[AsyncSocket], certChainPath: string, keyPath: string, streamOpenCb: typeof(quicly_stream_open_t.cb)): QuicP2PContext = @@ -217,7 +223,7 @@ proc initContext(sock: AsyncSocket, puncher: Puncher, certChainPath: string, keyExchanges: ptls_openssl_key_exchanges, cipherSuites: ptls_openssl_cipher_suites) quicly_amend_ptls_context(addr tlsCtx) - result = QuicP2PContext(sock: sock, puncher: puncher, + result = QuicP2PContext(socks: socks, streamOpen: quicly_stream_open_t(cb: streamOpenCb), verifyCertsCb: ptls_verify_certificate_t(cb: verifyCerts), tlsCtx: tlsCtx, quiclyCtx: quicly_spec_context) @@ -236,10 +242,11 @@ proc initContext(sock: AsyncSocket, puncher: Puncher, certChainPath: string, result.tlsCtx.sign_certificate = addr result.signCertCb.super proc addConnection(ctx: QuicP2PContext, connPtr: ptr quicly_conn_t, - peerId: string) = + sock: AsyncSocket, expectedPeerId: string) = assert(not connPtr.isNil) let data = quicly_get_data(connPtr) - var conn = Connection(conn: connPtr, peerId: peerId) + var conn = Connection(conn: connPtr, sock: sock, + expectedPeerId: expectedPeerId) data[] = addr conn[] ctx.connections.add(conn) @@ -267,7 +274,7 @@ proc sendPackets(ctx: QuicP2PContext) = for j in 0 .. dgramCount - 1: var sockLen = quicly_get_socklen(addr dstAddr.sa) # TODO: replace asyncdispatch.sendTo with asyncnet.sendTo (Nim 1.4 required) - asyncCheck sendTo(ctx.sock.getFd().AsyncFD, dgrams[j].iov_base, + asyncCheck sendTo(conns[i].sock.getFd().AsyncFD, dgrams[j].iov_base, dgrams[j].iov_len.int, addr dstAddr.sa, sockLen) of QUICLY_ERROR_FREE_CONNECTION: echo "deleting connection" @@ -276,7 +283,8 @@ proc sendPackets(ctx: QuicP2PContext) = raise newException(ValueError, &"quicly_send returned {sendResult}") proc initiateQuicConnection(ctx: QuicP2PContext, peerId: string, - peerIp: IpAddress, peerPort: Port) = + sock: AsyncSocket, peerIp: IpAddress, + peerPort: Port) = var conn: ptr quicly_conn_t var peerAddr: SockAddr_storage var peerSockLen: SockLen @@ -290,9 +298,10 @@ proc initiateQuicConnection(ctx: QuicP2PContext, peerId: string, return var stream: ptr quicly_stream_t discard quicly_open_stream(conn, addr stream, 0) - ctx.addConnection(conn, peerId) + ctx.addConnection(conn, sock, peerId) proc handleMsg(ctx: QuicP2PContext, msg: string, peerId: string, + puncher: Puncher, sock: AsyncSocket, peerAddr: ptr Sockaddr_storage, peerSockLen: SockLen) = var offset: csize_t = 0 while offset < msg.len().csize_t: @@ -304,7 +313,7 @@ proc handleMsg(ctx: QuicP2PContext, msg: string, peerId: string, echo "unable to decode packet" return var conn: ptr quicly_conn_t = nil - for c in ctx.connections: + for c in ctx.connections.filter(proc(c: Connection): bool = c.sock == sock): if quicly_is_destination(c.conn, nil, peerAddr, addr decoded) != 0: conn = c.conn break @@ -313,44 +322,47 @@ proc handleMsg(ctx: QuicP2PContext, msg: string, peerId: string, else: # The puncher needs to be informed about this message because it may # be the peer's response to our respond call. - ctx.puncher.handleMsg(msg, peerAddr[], peerSockLen) + puncher.handleMsg(msg, sock, peerAddr[], peerSockLen) if peerId.len == 0: let acceptResult = quicly_accept(addr conn, addr ctx.quiclyCtx, nil, peerAddr, addr decoded, nil, addr ctx.nextCid, nil) if acceptResult == 0: - ctx.addConnection(conn, peerId) + ctx.addConnection(conn, sock, peerId) -proc receive(ctx: QuicP2PContext, peerId: string) {.async.} = +proc receive(ctx: QuicP2PContext, puncher: Puncher, sock: AsyncSocket, + peerId: string) {.async.} = while true: # TODO: replace asyncdispatch.recvFromInto with asyncnet.recvFrom (Nim 1.4 required) var msg = newString(BufferSize) var peerAddr: Sockaddr_storage var peerAddrLen = SockLen(sizeof(peerAddr)) - let msgLen = await recvFromInto(ctx.sock.getFd().AsyncFD, msg.cstring, - msg.len, cast[ptr SockAddr](addr peerAddr), + let msgLen = await recvFromInto(sock.getFd().AsyncFD, msg.cstring, msg.len, + cast[ptr SockAddr](addr peerAddr), addr peerAddrLen) msg.setLen(msgLen) if msg.len > 0: - handleMsg(ctx, msg, peerId, addr peerAddr, peerAddrLen) + handleMsg(ctx, msg, peerId, puncher, sock, addr peerAddr, peerAddrLen) -proc handleNotification(ctx: QuicP2PContext, notification: NotifyPeer) - {.async.} = - let attempt = await ctx.puncher.respond(notification.srcIp, notification.srcPort, - notification.probedsrcPorts) +proc handleNotification(puncher: Puncher, notification: NotifyPeer) {.async.} = + let attempt = await puncher.respond(notification.srcIp, notification.srcPort, + notification.probedsrcPorts) discard await attempt.finalize() -proc runApp(ctx: QuicP2PContext, srcPort: Port, peerId: string) {.async.} = +proc runApp(ctx: QuicP2PContext, peerId: string) {.async.} = let serverConn = await initServerConnection(rendezvousServers[0].hostname, rendezvousServers[0].port, - srcPort, rendezvousServers) + ctx.srcPort, rendezvousServers) + let puncher = initPuncher(ctx.socks, serverConn.probedSrcPorts) + asyncCheck handleServerMessages(serverConn) - asyncCheck receive(ctx, peerId) + for sock in ctx.socks: + asyncCheck receive(ctx, puncher, sock, peerId) if peerId.len == 0: # We are the responder let probedPorts = serverConn.probedSrcPorts.join(",") - let req = &"{ctx.getPeerId()}|{serverConn.probedIp}|{srcPort}|{probedPorts}" + let req = &"{ctx.peerId}|{serverConn.probedIp}|{ctx.srcPort}|{probedPorts}" discard await serverConn.sendRequest("register", req) while true: let (hasData, data) = await serverConn.peerNotifications.read() @@ -359,7 +371,7 @@ proc runApp(ctx: QuicP2PContext, srcPort: Port, peerId: string) {.async.} = try: let msg = parseMessage[NotifyPeer](data) # FIXME: check if we want to receive messages from the sender - asyncCheck handleNotification(ctx, msg) + asyncCheck handleNotification(puncher, msg) except ValueError as e: echo e.msg discard @@ -370,42 +382,49 @@ proc runApp(ctx: QuicP2PContext, srcPort: Port, peerId: string) {.async.} = let peerInfo = parseMessage[OkGetPeerInfo](serverResponse) let myProbedPorts = serverConn.probedSrcPorts.join(",") let peerProbedPorts = peerInfo.probedPorts.join(",") - let req = &"{ctx.getPeerId()}|{peerId}|{serverConn.probedIp}|{srcPort}|{myProbedPorts}|{peerInfo.ip}|{peerInfo.localPort}|{peerProbedPorts}" - let attempt = await ctx.puncher.initiate(peerInfo.ip, peerInfo.localPort, - peerInfo.probedPorts) + let req = &"{ctx.peerId}|{peerId}|{serverConn.probedIp}|{ctx.srcPort}|{myProbedPorts}|{peerInfo.ip}|{peerInfo.localPort}|{peerProbedPorts}" + let attempt = await puncher.initiate(peerInfo.ip, peerInfo.localPort, + peerInfo.probedPorts) discard await serverConn.sendRequest("notify-peer", req) - let peerPort = await attempt.finalize() - initiateQuicConnection(ctx, peerId, peerInfo.ip, peerPort) + let (sock, peerPort) = await attempt.finalize() + initiateQuicConnection(ctx, peerId, sock, peerInfo.ip, peerPort) proc main() = var ctx: QuicP2PContext - let sock = newAsyncSocket(sockType = SOCK_DGRAM, protocol = IPPROTO_UDP, - buffered = false) + var socks = newSeq[AsyncSocket]() randomize() - let srcPort = rand(Port(1024) .. Port.high) - sock.bindAddr(srcPort) - let puncher = initPuncher(sock) + for i in 0 .. 4: + # FIXME: close socks + let sock = newAsyncSocket(sockType = SOCK_DGRAM, protocol = IPPROTO_UDP, + buffered = false) + let srcPort = rand(Port(1024) .. Port.high) + # FIXME: Once we start using UDP for endpoint probing (currently done in + # initServerConnction) we either have to + # - finish the probing before we bind the srcPort here + # - pass the primary socket to initServerConnection + sock.bindAddr(srcPort) + socks.add(sock) case paramCount(): of 0: - ctx = initContext(sock, puncher, serverCertChainPath, serverKeyPath, + ctx = initContext(socks, serverCertChainPath, serverKeyPath, onServerStreamOpen) ctx.tlsCtx.require_client_authentication = 1 - asyncCheck runApp(ctx, srcPort, "") + asyncCheck runApp(ctx, "") of 1: let peerId = paramStr(1) - ctx = initContext(sock, puncher, clientCertChainPath, clientKeyPath, + ctx = initContext(socks, clientCertChainPath, clientKeyPath, onClientStreamOpen) - asyncCheck runApp(ctx, srcPort, peerId) + asyncCheck runApp(ctx, peerId) else: usage() quit(1) - echo "My peer ID is ", ctx.getPeerId() + echo "My peer ID is ", ctx.peerId while true: - let nextTimeout = ctx.getRelativeTimeout() + let nextTimeout = ctx.relativeTimeout poll(nextTimeout) ctx.sendPackets()