store transport protocol in an Attempt; consider protocol when comparing attempts; puncher.getProtocol not needed anymore
This commit is contained in:
parent
5c6050faaf
commit
5bdb69f214
19
punchd.nim
19
punchd.nim
|
@ -31,10 +31,10 @@ proc sendToClient(unixSock: AsyncSocket, msg: string,
|
||||||
let unixFd = unixSock.getFd.AsyncFD
|
let unixFd = unixSock.getFd.AsyncFD
|
||||||
await unixFd.asyncSendMsg(msg, cmsgs)
|
await unixFd.asyncSendMsg(msg, cmsgs)
|
||||||
|
|
||||||
proc findAttemptsByLocalAddr(punchd: Punchd, srcIp: IpAddress,
|
proc findAttemptsByLocalAddr(punchd: Punchd, protocol: Protocol,
|
||||||
srcPort: Port): seq[Attempt] =
|
srcIp: IpAddress, srcPort: Port): seq[Attempt] =
|
||||||
proc matchesLocalAddr(a: Attempt): bool =
|
proc matchesLocalAddr(a: Attempt): bool =
|
||||||
a.srcIp == srcIp and a.srcPort == srcPort
|
a.protocol == protocol and a.srcIp == srcIp and a.srcPort == srcPort
|
||||||
punchd.attempts.filter(matchesLocalAddr)
|
punchd.attempts.filter(matchesLocalAddr)
|
||||||
|
|
||||||
proc acceptConnections(punchd: Punchd, ip: IpAddress, port: Port,
|
proc acceptConnections(punchd: Punchd, ip: IpAddress, port: Port,
|
||||||
|
@ -68,19 +68,20 @@ proc acceptConnections(punchd: Punchd, ip: IpAddress, port: Port,
|
||||||
else:
|
else:
|
||||||
let acceptFuture = punchd.attempts[i].acceptFuture.get()
|
let acceptFuture = punchd.attempts[i].acceptFuture.get()
|
||||||
acceptFuture.complete(peer)
|
acceptFuture.complete(peer)
|
||||||
let localAddrMatches = punchd.findAttemptsByLocalAddr(ip, port)
|
let localAddrMatches = punchd.findAttemptsByLocalAddr(protocol, ip, port)
|
||||||
if localAddrMatches.len() <= 1:
|
if localAddrMatches.len() <= 1:
|
||||||
break
|
break
|
||||||
sock.close()
|
sock.close()
|
||||||
|
|
||||||
proc addAttempt(punchd: Punchd, attempt: Attempt, puncher: Puncher) =
|
proc addAttempt(punchd: Punchd, attempt: Attempt) =
|
||||||
let localAddrMatches = punchd.findAttemptsByLocalAddr(attempt.srcIp,
|
let localAddrMatches = punchd.findAttemptsByLocalAddr(attempt.protocol,
|
||||||
|
attempt.srcIp,
|
||||||
attempt.srcPort)
|
attempt.srcPort)
|
||||||
punchd.attempts.add(attempt)
|
punchd.attempts.add(attempt)
|
||||||
if localAddrMatches.len() == 0:
|
if localAddrMatches.len() == 0:
|
||||||
if attempt.acceptFuture.isSome():
|
if attempt.acceptFuture.isSome():
|
||||||
asyncCheck punchd.acceptConnections(attempt.srcIp, attempt.srcPort,
|
asyncCheck punchd.acceptConnections(attempt.srcIp, attempt.srcPort,
|
||||||
puncher.getProtocol())
|
attempt.protocol)
|
||||||
elif localAddrMatches.contains(attempt):
|
elif localAddrMatches.contains(attempt):
|
||||||
raise newException(PunchHoleError,
|
raise newException(PunchHoleError,
|
||||||
"hole punching for given parameters already active")
|
"hole punching for given parameters already active")
|
||||||
|
@ -101,7 +102,7 @@ proc handleRequest(punchd: Punchd, line: string,
|
||||||
case args[0]:
|
case args[0]:
|
||||||
of "initiate":
|
of "initiate":
|
||||||
attempt = puncher.parseInitiateRequest(args[3])
|
attempt = puncher.parseInitiateRequest(args[3])
|
||||||
punchd.addAttempt(attempt, puncher)
|
punchd.addAttempt(attempt)
|
||||||
proc progress(extraArgs: string) {.async.} =
|
proc progress(extraArgs: string) {.async.} =
|
||||||
let msg = &"progress|{id}|{args[2]}|{args[3]}|{extraArgs}\n"
|
let msg = &"progress|{id}|{args[2]}|{args[3]}|{extraArgs}\n"
|
||||||
await sendToClient(unixSock, msg)
|
await sendToClient(unixSock, msg)
|
||||||
|
@ -110,7 +111,7 @@ proc handleRequest(punchd: Punchd, line: string,
|
||||||
|
|
||||||
of "respond":
|
of "respond":
|
||||||
attempt = puncher.parseRespondRequest(args[3])
|
attempt = puncher.parseRespondRequest(args[3])
|
||||||
punchd.addAttempt(attempt, puncher)
|
punchd.addAttempt(attempt)
|
||||||
sock = await puncher.respond(attempt)
|
sock = await puncher.respond(attempt)
|
||||||
punchd.removeAttempt(attempt)
|
punchd.removeAttempt(attempt)
|
||||||
|
|
||||||
|
|
14
puncher.nim
14
puncher.nim
|
@ -16,6 +16,7 @@ type
|
||||||
## protocol can be obtained by calling ``getProcotol``. The puncher expects
|
## protocol can be obtained by calling ``getProcotol``. The puncher expects
|
||||||
## the caller to complete the future when a connections from
|
## the caller to complete the future when a connections from
|
||||||
## ``dstIp``:``dstPort`` has been accepted.
|
## ``dstIp``:``dstPort`` has been accepted.
|
||||||
|
protocol*: Protocol
|
||||||
srcIp*: IpAddress
|
srcIp*: IpAddress
|
||||||
srcPort*: Port
|
srcPort*: Port
|
||||||
dstIp*: IpAddress
|
dstIp*: IpAddress
|
||||||
|
@ -40,10 +41,11 @@ const Timeout* = 3000
|
||||||
proc `==`*(a, b: Attempt): bool =
|
proc `==`*(a, b: Attempt): bool =
|
||||||
## ``==`` for hole punching attempts.
|
## ``==`` for hole punching attempts.
|
||||||
##
|
##
|
||||||
## Two hole punching attempts are considered equal if their ``srcIp``,
|
## Two hole punching attempts are considered equal if their ``protocol`` is
|
||||||
## ``srcPort`` and ``dstIp`` are equal and their ``dstPorts`` overlap.
|
## the same, ``srcIp``, ``srcPort`` and ``dstIp`` are equal and their
|
||||||
a.srcIp == b.srcIp and a.srcPort == b.srcPort and a.dstIp == b.dstIp and
|
## ``dstPorts`` overlap.
|
||||||
a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts)
|
a.protocol == b.protocol and a.srcIp == b.srcIp and a.srcPort == b.srcPort and
|
||||||
|
a.dstIp == b.dstIp and a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts)
|
||||||
|
|
||||||
method cleanup*(attempt: Attempt): Future[void] {.base, async.} =
|
method cleanup*(attempt: Attempt): Future[void] {.base, async.} =
|
||||||
## Cleans up when an attempt finished (either successful or not).
|
## Cleans up when an attempt finished (either successful or not).
|
||||||
|
@ -51,10 +53,6 @@ method cleanup*(attempt: Attempt): Future[void] {.base, async.} =
|
||||||
## Does nothing. Override for custom attempt types.
|
## Does nothing. Override for custom attempt types.
|
||||||
discard
|
discard
|
||||||
|
|
||||||
method getProtocol*(puncher: Puncher): Protocol {.base.} =
|
|
||||||
## Returns the transport protocol the puncher employs.
|
|
||||||
raise newException(CatchableError, "Method without implementation override")
|
|
||||||
|
|
||||||
method parseInitiateRequest*(puncher: Puncher, args: string): Attempt {.base.} =
|
method parseInitiateRequest*(puncher: Puncher, args: string): Attempt {.base.} =
|
||||||
## Creates a new hole punching attempt by parsing arguments of an ``initiate``
|
## Creates a new hole punching attempt by parsing arguments of an ``initiate``
|
||||||
## request.
|
## request.
|
||||||
|
|
|
@ -56,23 +56,22 @@ proc initTcpNutssPuncher*(): TcpNutssPuncher =
|
||||||
randomize()
|
randomize()
|
||||||
TcpNutssPuncher()
|
TcpNutssPuncher()
|
||||||
|
|
||||||
method getProtocol*(puncher: TcpNutssPuncher): Protocol =
|
|
||||||
IPPROTO_TCP
|
|
||||||
|
|
||||||
method parseInitiateRequest*(puncher: TcpNutssPuncher, args: string): Attempt =
|
method parseInitiateRequest*(puncher: TcpNutssPuncher, args: string): Attempt =
|
||||||
let parsed = parseMessage[InitiateRequest](args)
|
let parsed = parseMessage[InitiateRequest](args)
|
||||||
let localIp = getPrimaryIPAddr(parsed.dstIp)
|
let localIp = getPrimaryIPAddr(parsed.dstIp)
|
||||||
let predictedDstPorts = predictPortRange(parsed.dstPorts)
|
let predictedDstPorts = predictPortRange(parsed.dstPorts)
|
||||||
let acceptFuture = newFuture[AsyncSocket]("parseInitiateRequest")
|
let acceptFuture = newFuture[AsyncSocket]("parseInitiateRequest")
|
||||||
Attempt(srcIp: localIp, srcPort: parsed.srcPorts[0], dstIp: parsed.dstIp,
|
Attempt(protocol: IPPROTO_TCP, srcIp: localIp, srcPort: parsed.srcPorts[0],
|
||||||
dstPorts: predictedDstPorts, acceptFuture: some(acceptFuture))
|
dstIp: parsed.dstIp, dstPorts: predictedDstPorts,
|
||||||
|
acceptFuture: some(acceptFuture))
|
||||||
|
|
||||||
method parseRespondRequest*(puncher: TcpNutssPuncher, args: string): Attempt =
|
method parseRespondRequest*(puncher: TcpNutssPuncher, args: string): Attempt =
|
||||||
let parsed = parseMessage[RespondRequest](args)
|
let parsed = parseMessage[RespondRequest](args)
|
||||||
let localIp = getPrimaryIPAddr(parsed.dstIp)
|
let localIp = getPrimaryIPAddr(parsed.dstIp)
|
||||||
let predictedDstPorts = predictPortRange(parsed.dstPorts)
|
let predictedDstPorts = predictPortRange(parsed.dstPorts)
|
||||||
Attempt(srcIp: localIp, srcPort: parsed.srcPorts[0], dstIp: parsed.dstIp,
|
Attempt(protocol: IPPROTO_TCP, srcIp: localIp, srcPort: parsed.srcPorts[0],
|
||||||
dstPorts: predictedDstPorts, acceptFuture: none(Future[AsyncSocket]))
|
dstIp: parsed.dstIp, dstPorts: predictedDstPorts,
|
||||||
|
acceptFuture: none(Future[AsyncSocket]))
|
||||||
|
|
||||||
method initiate*(puncher: TcpNutssPuncher, attempt: Attempt,
|
method initiate*(puncher: TcpNutssPuncher, attempt: Attempt,
|
||||||
progress: PunchProgressCb): Future[AsyncSocket] {.async.} =
|
progress: PunchProgressCb): Future[AsyncSocket] {.async.} =
|
||||||
|
|
10
tcp_syni.nim
10
tcp_syni.nim
|
@ -137,8 +137,9 @@ method parseInitiateRequest*(puncher: TcpSyniPuncher, args: string): Attempt =
|
||||||
let parsed = parseMessage[InitiateRequest](args)
|
let parsed = parseMessage[InitiateRequest](args)
|
||||||
let localIp = getPrimaryIPAddr(parsed.dstIp)
|
let localIp = getPrimaryIPAddr(parsed.dstIp)
|
||||||
let predictedDstPorts = predictPortRange(parsed.dstPorts)
|
let predictedDstPorts = predictPortRange(parsed.dstPorts)
|
||||||
TcpSyniInitiateAttempt(srcIp: localIp, srcPort: parsed.srcPorts[0],
|
TcpSyniInitiateAttempt(protocol: IPPROTO_TCP, srcIp: localIp,
|
||||||
dstIp: parsed.dstIp, dstPorts: predictedDstPorts,
|
srcPort: parsed.srcPorts[0], dstIp: parsed.dstIp,
|
||||||
|
dstPorts: predictedDstPorts,
|
||||||
acceptFuture: none(Future[AsyncSocket]))
|
acceptFuture: none(Future[AsyncSocket]))
|
||||||
|
|
||||||
method parseRespondRequest*(puncher: TcpSyniPuncher, args: string): Attempt =
|
method parseRespondRequest*(puncher: TcpSyniPuncher, args: string): Attempt =
|
||||||
|
@ -146,8 +147,9 @@ method parseRespondRequest*(puncher: TcpSyniPuncher, args: string): Attempt =
|
||||||
let localIp = getPrimaryIPAddr(parsed.dstIp)
|
let localIp = getPrimaryIPAddr(parsed.dstIp)
|
||||||
let predictedDstPorts = predictPortRange(parsed.dstPorts)
|
let predictedDstPorts = predictPortRange(parsed.dstPorts)
|
||||||
let acceptFuture = newFuture[AsyncSocket]("parseRespondRequest")
|
let acceptFuture = newFuture[AsyncSocket]("parseRespondRequest")
|
||||||
TcpSyniRespondAttempt(srcIp: localIp, srcPort: parsed.srcPorts[0],
|
TcpSyniRespondAttempt(protocol: IPPROTO_TCP, srcIp: localIp,
|
||||||
dstIp: parsed.dstIp, dstPorts: predictedDstPorts,
|
srcPort: parsed.srcPorts[0], dstIp: parsed.dstIp,
|
||||||
|
dstPorts: predictedDstPorts,
|
||||||
acceptFuture: some(acceptFuture),
|
acceptFuture: some(acceptFuture),
|
||||||
seqNums: parsed.seqNums)
|
seqNums: parsed.seqNums)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue