5 changed files with 516 additions and 51 deletions
@ -0,0 +1,55 @@ |
|||
import strutils |
|||
from net import IpAddress, parseIpAddress, Port, `$` |
|||
|
|||
proc parseField(input: string, output: var string) = |
|||
output = input |
|||
|
|||
proc parseField[T: SomeUnsignedInt](input: string, output: var T) = |
|||
let parsed = parseUInt(input) |
|||
if parsed > T.high: |
|||
raise newException(ValueError, "Unsigned integer out of range") |
|||
output = parsed.T |
|||
|
|||
proc parseField(input: string, output: var IpAddress) = |
|||
output = parseIpAddress(input) |
|||
|
|||
proc parseField(input: string, output: var Port) = |
|||
var portNumber: uint16 |
|||
parseField(input, portNumber) |
|||
output = Port(portNumber) |
|||
|
|||
proc parseField[S, T](input: string, output: var array[S, T]) = |
|||
let parts = input.split(",", S.high) |
|||
if parts.len != S.high + 1: |
|||
raise newException(ValueError, "Array has wrong length") |
|||
for i in 0 .. S.high: |
|||
parseField(parts[i], output[i]) |
|||
|
|||
proc parseField[T](input: string, output: var seq[T]) = |
|||
let parts = input.split(",") |
|||
if parts.len < 1: |
|||
raise newException(ValueError, "Sequence is empty") |
|||
output = newSeq[T](parts.len) |
|||
for i in 0 .. parts.len - 1: |
|||
parseField(parts[i], output[i]) |
|||
|
|||
proc parseField[T: tuple | object](input: string, output: var T) = |
|||
var fieldCount = 0 |
|||
for _ in output.fields: |
|||
fieldCount = fieldCount + 1 |
|||
let args = input.parseArgs(fieldCount) |
|||
var i = 0 |
|||
for value in output.fields: |
|||
parseField(args[i], value) |
|||
i.inc |
|||
|
|||
proc parseArgs*(input: string, count: int, optionalCount = 0): seq[string] = |
|||
assert(optionalCount <= count) |
|||
result = input.split("|", count - 1) |
|||
if result.len < count: |
|||
if result.len < count - optionalCount: |
|||
raise newException(ValueError, "invalid message") |
|||
result.add(repeat("", count - result.len)) |
|||
|
|||
proc parseMessage*[T: tuple | object](input: string): T = |
|||
parseField(input, result) |
@ -0,0 +1,165 @@ |
|||
import algorithm |
|||
import net |
|||
import sequtils |
|||
import unittest |
|||
|
|||
const RandomPortCount = 1000 |
|||
|
|||
proc min(a, b: uint16): uint16 = |
|||
min(a.int32, b.int32).uint16 |
|||
|
|||
proc toUint16(p: Port): uint16 = uint16(p) |
|||
|
|||
proc toPort(u: uint16): Port = Port(u) |
|||
|
|||
proc addOffset(port: uint16, offset: uint16, minValue = 1024'u16, |
|||
maxValue = uint16.high): uint16 = |
|||
assert(port >= minValue) |
|||
assert(port <= maxValue) |
|||
let distanceToMaxValue = maxValue - port |
|||
if distanceToMaxValue < offset: |
|||
return minValue + offset - distanceToMaxValue - 1 |
|||
return port + offset |
|||
|
|||
proc subtractOffset(port: uint16, offset: uint16, minValue = 1024'u16, |
|||
maxValue = uint16.high): uint16 = |
|||
assert(port >= minValue) |
|||
assert(port <= maxValue) |
|||
let distanceToMinValue = port - minValue |
|||
if distanceToMinValue < offset: |
|||
return maxValue - offset + distanceToMinValue + 1 |
|||
return port - offset |
|||
|
|||
proc predictPortRange*(localPort: Port, probedPorts: seq[Port]): seq[Port] = |
|||
if probedPorts.len == 0: |
|||
# No probed ports, so our only guess can be that the NAT is a cone-type NAT |
|||
# and the port mapping preserves the local Port. |
|||
return @[localPort] |
|||
let localPortUint = localPort.uint16 |
|||
let probedPortsUint = probedPorts.map(toUint16) |
|||
if probedPorts.len == 1: |
|||
# Only one server was used for probing, so we cannot know if the NAT is |
|||
# symmetric or not. We are trying the probed port (assuming cone-type NAT) |
|||
# and the next port in a progressive sequence if applicable (assuming |
|||
# symmetric NAT with progressive port mapping). |
|||
result.add(probedPorts[0]) |
|||
if probedPortsUint[0] > localPortUint: |
|||
let offset = probedPortsUint[0] - localPortUint |
|||
result.add(Port(probedPortsUint[0].addOffset(offset))) |
|||
elif probedPortsUint[0] < localPortUint: |
|||
let offset = localPortUint - probedPortsUint[0] |
|||
result.add(Port(probedPortsUint[0].subtractOffset(offset))) |
|||
return |
|||
let deduplicatedPorts = probedPortsUint.deduplicate() |
|||
if deduplicatedPorts.len() == 1: |
|||
# It looks like the NAT is a cone-type NAT. |
|||
return deduplicatedPorts.map(toPort) |
|||
let probedPortsSorted = probedPortsUint.sorted() |
|||
let minPort = probedPortsSorted[probedPortsSorted.minIndex()] |
|||
let maxPort = probedPortsSorted[probedPortsSorted.maxIndex()] |
|||
var minDistance = uint16.high() |
|||
var maxDistance = uint16.low() |
|||
for i in 1 .. probedPortsSorted.len() - 1: |
|||
# FIXME: use rotated distance |
|||
let distance = probedPortsSorted[i] - probedPortsSorted[i - 1] |
|||
minDistance = min(minDistance, distance) |
|||
maxDistance = max(maxDistance, distance) |
|||
if maxDistance < 10: |
|||
if probedPortsUint.isSorted(Ascending): |
|||
# assume symmetric NAT with positive-progressive port mapping |
|||
if minDistance == maxDistance: |
|||
return @[Port(maxPort.addOffset(maxDistance))] |
|||
else: |
|||
for i in countup(0'u16, maxDistance): |
|||
result.add(Port(minPort.addOffset(i))) |
|||
return |
|||
if probedPortsUint.isSorted(Descending): |
|||
# assume symmetric NAT with negative-progressive port mapping |
|||
if minDistance == maxDistance: |
|||
return @[Port(minPort.subtractOffset(maxDistance))] |
|||
else: |
|||
for i in countup(0'u16, maxDistance): |
|||
result.add(Port(maxPort.subtractOffset(i))) |
|||
return |
|||
# assume symmetric NAT with random port mapping |
|||
let portRange = maxPort - minPort |
|||
let first = if portRange > RandomPortCount: |
|||
minPort |
|||
else: |
|||
let notCovered = RandomPortCount - portRange |
|||
max(minPort - notCovered shr 1, 1024) |
|||
let last = first + RandomPortCount |
|||
for i in first .. last: |
|||
result.add(Port(i)) |
|||
|
|||
suite "port prediction tests": |
|||
test "single port": |
|||
let predicted = predictPortRange(Port(1234), @[]) |
|||
check(predicted == @[Port(1234)]) |
|||
|
|||
test "single probe equal": |
|||
let predicted = predictPortRange(Port(1234), @[Port(1234)]) |
|||
check(predicted == @[Port(1234)]) |
|||
|
|||
test "single probe positive-progressive": |
|||
let predicted = predictPortRange(Port(1234), @[Port(1236)]) |
|||
check(predicted == @[Port(1236), Port(1238)]) |
|||
|
|||
test "single probe negative-progressive": |
|||
let predicted = predictPortRange(Port(1234), @[Port(1232)]) |
|||
check(predicted == @[Port(1232), Port(1230)]) |
|||
|
|||
test "all equal": |
|||
let predicted = predictPortRange(Port(1234), @[Port(1234), Port(1234)]) |
|||
check(predicted == @[Port(1234)]) |
|||
|
|||
test "positive-progressive, offset 1": |
|||
let predicted = predictPortRange(Port(1234), @[Port(2034), Port(2035)]) |
|||
check(predicted == @[Port(2036)]) |
|||
|
|||
test "positive-progressive, offset 9": |
|||
let predicted = predictPortRange(Port(1234), @[Port(2034), Port(2043)]) |
|||
check(predicted == @[Port(2052)]) |
|||
|
|||
test "negative-progressive, offset 1": |
|||
let predicted = predictPortRange(Port(1234), @[Port(1100), Port(1099)]) |
|||
check(predicted == @[Port(1098)]) |
|||
|
|||
test "negative-progressive, offset 9": |
|||
let predicted = predictPortRange(Port(1234), @[Port(1100), Port(1091)]) |
|||
check(predicted == @[Port(1082)]) |
|||
|
|||
test "positive-progressive, 3 probed ports, low offset": |
|||
let predicted = predictPortRange(Port(1234), @[Port(2000), Port(2000), Port(2002)]) |
|||
check(predicted == @[Port(2000), Port(2001), Port(2002)]) |
|||
|
|||
test "negative-progressive, 3 probed ports, low offset": |
|||
let predicted = predictPortRange(Port(1234), @[Port(2002), Port(2000), Port(2000)]) |
|||
check(predicted == @[Port(2002), Port(2001), Port(2000)]) |
|||
|
|||
test "high port, positive-progressive, offset 1": |
|||
let predicted = predictPortRange(Port(1234), @[Port(65534), Port(65535)]) |
|||
check(predicted == @[Port(1024)]) |
|||
|
|||
test "high port, positive-progressive, offset 9": |
|||
let predicted = predictPortRange(Port(1234), @[Port(65520), Port(65529)]) |
|||
check(predicted == @[Port(1026)]) |
|||
|
|||
test "low port, negative-progressive, offset 1": |
|||
let predicted = predictPortRange(Port(1234), @[Port(1025), Port(1024)]) |
|||
check(predicted == @[Port(65535)]) |
|||
|
|||
test "low port, negative-progressive, offset 9": |
|||
let predicted = predictPortRange(Port(1234), @[Port(1039), Port(1030)]) |
|||
check(predicted == @[Port(65533)]) |
|||
|
|||
test "random mapping, distance > RandomPortCount": |
|||
let predicted = predictPortRange(Port(1234), @[Port(3546), Port(7624)]) |
|||
check(predicted == toSeq(countup(3546'u16, 3546'u16 + RandomPortCount)).map(toPort)) |
|||
|
|||
test "random mapping, distance < RandomPortCount": |
|||
let centerPort = 30000'u16 |
|||
let minPort = centerPort - RandomPortCount.uint16 shr 1 + 1 |
|||
let maxPort = centerPort + RandomPortCount.uint16 shr 1 - 1 |
|||
let predicted = predictPortRange(Port(centerPort), @[Port(minPort), Port(maxPort)]) |
|||
check(predicted == toSeq(countup(minPort - 1, maxPort + 1)).map(toPort)) |
@ -0,0 +1,84 @@ |
|||
import asyncdispatch, asyncnet, net, port_prediction |
|||
|
|||
from nativesockets import SockAddr, SockAddr_storage, SockLen |
|||
from sequtils import any |
|||
|
|||
type |
|||
Attempt = object |
|||
## A hole punching attempt. |
|||
srcPort: Port |
|||
dstIp: IpAddress |
|||
dstPorts: seq[Port] |
|||
future: Future[Port] |
|||
|
|||
Puncher* = ref object |
|||
sock: AsyncSocket |
|||
attempts: seq[Attempt] |
|||
|
|||
PunchHoleError* = object of ValueError |
|||
|
|||
const Timeout = 3000 |
|||
|
|||
proc `==`(a, b: Attempt): bool = |
|||
## ``==`` for hole punching attempts. |
|||
## |
|||
## Two hole punching attempts are considered equal if their ``srcPort`` and |
|||
## ``dstIp`` are equal and their ``dstPorts`` overlap. |
|||
a.srcPort == b.srcPort and a.dstIp == b.dstIp and |
|||
a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts) |
|||
|
|||
proc initPuncher*(sock: AsyncSocket): Puncher = |
|||
Puncher(sock: sock) |
|||
|
|||
proc punch(puncher: Puncher, peerIp: IpAddress, peerPort: Port, |
|||
peerProbedPorts: seq[Port], msg: string): Future[Port] {.async.} = |
|||
let punchFuture = newFuture[Port]("punch") |
|||
let predictedDstPorts = predictPortRange(peerPort, peerProbedPorts) |
|||
let (_, myPort) = puncher.sock.getLocalAddr() |
|||
let attempt = Attempt(srcPort: myPort, dstIp: peerIp, |
|||
dstPorts: predictedDstPorts, future: punchFuture) |
|||
if puncher.attempts.contains(attempt): |
|||
raise newException(PunchHoleError, |
|||
"hole punching for given parameters already active") |
|||
puncher.attempts.add(attempt) |
|||
var peerAddr: Sockaddr_storage |
|||
var peerSockLen: SockLen |
|||
try: |
|||
for dstPort in attempt.dstPorts: |
|||
toSockAddr(attempt.dstIp, dstPort, peerAddr, peerSockLen) |
|||
# TODO: replace asyncdispatch.sendTo with asyncnet.sendTo (Nim 1.4 required) |
|||
await sendTo(puncher.sock.getFd().AsyncFD, msg.cstring, msg.len, |
|||
cast[ptr SockAddr](addr peerAddr), peerSockLen) |
|||
await punchFuture or sleepAsync(Timeout) |
|||
if punchFuture.finished(): |
|||
result = punchFuture.read() |
|||
else: |
|||
raise newException(PunchHoleError, "timeout") |
|||
except OSError as e: |
|||
raise newException(PunchHoleError, e.msg) |
|||
|
|||
proc initiate*(puncher: Puncher, peerIp: IpAddress, peerPort: Port, |
|||
peerProbedPorts: seq[Port]): Future[Port] = |
|||
punch(puncher, peerIp, peerPort, peerProbedPorts, "SYN") |
|||
|
|||
proc respond*(puncher: Puncher, peerIp: IpAddress, peerPort: Port, |
|||
peerProbedPorts: seq[Port]): Future[Port] = |
|||
punch(puncher, peerIp, peerPort, peerProbedPorts, "ACK") |
|||
|
|||
proc handleMsg*(puncher: Puncher, msg: string, peerIp: IpAddress, |
|||
peerPort: Port) = |
|||
## Handles an incoming UDP message which may complete the Futures returned by |
|||
## ``initiate`` and ``respond``. |
|||
let (_, myPort) = puncher.sock.getLocalAddr() |
|||
let query = Attempt(srcPort: myPort, dstIp: peerIp, dstPorts: @[peerPort]) |
|||
let i = puncher.attempts.find(query) |
|||
if i != -1: |
|||
puncher.attempts[i].future.complete(peerPort) |
|||
puncher.attempts.del(i) |
|||
|
|||
proc handleMsg*(puncher: Puncher, msg: string, |
|||
peerAddr: SockAddr | Sockaddr_storage, peerSockLen: SockLen) = |
|||
var peerIp: IpAddress |
|||
var peerPort: Port |
|||
fromSockAddr(peerAddr, peerSockLen, peerIp, peerPort) |
|||
handleMsg(puncher, msg, peerIp, peerPort) |
@ -0,0 +1,104 @@ |
|||
import asyncdispatch, asyncnet, message, net, tables, random, strformat |
|||
|
|||
type |
|||
Endpoint* = tuple[hostname: string, port: Port] |
|||
|
|||
ServerConnection* = ref object |
|||
sock: AsyncSocket |
|||
outMessages: TableRef[string, Future[string]] |
|||
peerNotifications*: FutureStream[string] |
|||
probedIp*: IpAddress |
|||
srcPort*: Port |
|||
probedSrcPorts*: seq[Port] |
|||
|
|||
ServerError* = object of ValueError |
|||
|
|||
OkGetPeerinfo* = object |
|||
ip*: IpAddress |
|||
localPort*: Port |
|||
probedPorts*: seq[Port] |
|||
|
|||
OkGetEndpoint* = object |
|||
ip*: IpAddress |
|||
port*: Port |
|||
|
|||
NotifyPeer* = object |
|||
sender*: string |
|||
recipient*: string |
|||
technique*: string |
|||
srcIp*: IpAddress |
|||
srcPort*: Port |
|||
probedSrcPorts*: seq[Port] |
|||
dstIp*: IpAddress |
|||
dstPort*: Port |
|||
probedDstPorts*: seq[Port] |
|||
extraArgs*: string |
|||
|
|||
proc getEndpoint(srcPort: Port, serverHostname: string, serverPort: Port): |
|||
Future[OkGetEndpoint] {.async.} = |
|||
let sock = newAsyncSocket() |
|||
var failCount = 0 |
|||
while true: |
|||
try: |
|||
sock.bindAddr(srcPort) |
|||
break |
|||
except OSError as e: |
|||
if failCount == 3: |
|||
raise e |
|||
failCount.inc |
|||
await sleepAsync(100) |
|||
await sock.connect(serverHostname, serverPort) |
|||
let id = rand(uint32) |
|||
await sock.send(&"get-endpoint|{id}\n") |
|||
let line = await sock.recvLine(maxLength = 400) |
|||
let args = line.parseArgs(3) |
|||
assert(args[0] == "ok") |
|||
assert(args[1] == $id) |
|||
result = parseMessage[OkGetEndpoint](args[2]) |
|||
let emptyLine = await sock.recvLine(maxLength = 400) |
|||
assert(emptyLine.len == 0) |
|||
sock.close() |
|||
|
|||
proc initServerConnection*(serverHostname: string, serverPort: Port, |
|||
srcPort: Port, probingServers: seq[Endpoint]): |
|||
Future[ServerConnection] {.async.} = |
|||
result.srcPort = srcPort |
|||
for s in probingServers: |
|||
let endpoint = await getEndpoint(srcPort, s.hostname, s.port) |
|||
# FIXME: what if we get get different IPs from different servers |
|||
result.probedIp = endpoint.ip |
|||
result.probedSrcPorts.add(endpoint.port) |
|||
result.sock = await asyncnet.dial(serverHostname, |
|||
serverPort) |
|||
result.outMessages = newTable[string, Future[string]]() |
|||
result.peerNotifications = newFutureStream[string]("initServerConnection") |
|||
|
|||
proc handleServerMessages*(conn: ServerConnection) {.async.} = |
|||
while true: |
|||
let line = await conn.sock.recvLine(maxLength = 400) |
|||
let args = line.parseArgs(3, 1) |
|||
case args[0]: |
|||
of "ok": |
|||
let future = conn.outMessages[args[1]] |
|||
conn.outMessages.del(args[1]) |
|||
future.complete(args[2]) |
|||
of "error": |
|||
let future = conn.outMessages[args[1]] |
|||
conn.outMessages.del(args[1]) |
|||
future.fail(newException(ServerError, args[2])) |
|||
of "notify-peer": |
|||
asyncCheck conn.peerNotifications.write(line.substr(args[0].len + 1)) |
|||
else: |
|||
raise newException(ValueError, "invalid server message") |
|||
|
|||
proc sendRequest*(connection: ServerConnection, command: string, |
|||
content: string): Future[string] = |
|||
result = newFuture[string]("sendRequest") |
|||
let id = $rand(uint32) |
|||
var request: string |
|||
if content.len != 0: |
|||
request = &"{command}|{id}|{content}\n" |
|||
else: |
|||
request = &"{command}|{id}\n" |
|||
asyncCheck connection.sock.send(request) |
|||
connection.outMessages[id] = result |
Loading…
Reference in new issue