96 lines
3.1 KiB
Nim
96 lines
3.1 KiB
Nim
|
import asyncdispatch, strformat
|
||
|
from net import IpAddress, Port, `$`, toSockAddr
|
||
|
from nativesockets import SockAddr, Sockaddr_storage, SockLen
|
||
|
from sequtils import any
|
||
|
import asyncutils
|
||
|
import ip_packet
|
||
|
|
||
|
type
  # Raised by this module when a hole-punching step fails: packet injection
  # (`injectTcpPacket`) or firewall rule insertion (`addFirewallRules`).
  PunchHoleError* = object of ValueError

  # Type class for attempt records. Procs in this module read the fields
  # `srcIp`, `srcPort`, `dstIp` and `dstPorts` from values of this type.
  Attempt = tuple | object

  # Reference-semantics container for the attempts being tracked.
  Puncher*[T: Attempt] = ref object
    attempts*: seq[T] ## Tracked attempts, searched in insertion order.
|
||
|
|
||
|
const
  Timeout*: int = 3000 ## Exported timeout value; presumably milliseconds — NOTE(review): confirm unit at call sites.
|
||
|
|
||
|
proc findAttempt*(puncher: Puncher, srcIp: IpAddress, srcPort: Port,
                  dstIp: IpAddress, dstPorts: seq[Port]): int =
  ## Looks up an attempt by its endpoints.
  ##
  ## An attempt matches when its `srcIp`, `srcPort` and `dstIp` equal the
  ## given values and at least one of its `dstPorts` also occurs in
  ## `dstPorts`.
  ##
  ## Returns the index of the first matching attempt in `puncher.attempts`,
  ## or -1 when none matches.
  result = -1
  for i, candidate in puncher.attempts:
    let sameSource = candidate.srcIp == srcIp and candidate.srcPort == srcPort
    if not sameSource or candidate.dstIp != dstIp:
      continue
    # Destination ports only need to overlap, not be identical.
    if candidate.dstPorts.any(proc (p: Port): bool = p in dstPorts):
      return i
|
||
|
|
||
|
proc findAttemptsByLocalAddr*(puncher: Puncher[Attempt], address: IpAddress,
                              port: Port): seq[Attempt] =
  ## Collects every attempt whose source endpoint (`srcIp`, `srcPort`)
  ## equals `address`:`port`, keeping their order in `puncher.attempts`.
  ## Returns an empty seq when nothing matches.
  for candidate in puncher.attempts:
    if candidate.srcIp != address:
      continue
    if candidate.srcPort == port:
      result.add candidate
|
||
|
|
||
|
proc injectTcpPacket*(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} =
  ## Serializes `ipPacket` and sends it through `rawFd` to the packet's own
  ## destination address (`ipAddrDst`) and TCP port (`tcpPortDst`).
  ##
  ## Precondition: `ipPacket.protocol == tcp` (checked with `assert`).
  ##
  ## Raises `PunchHoleError` when serialization, address conversion or the
  ## send raises `OSError`. The message is the `OSError`'s message and the
  ## `OSError` itself is attached as `parent`, so the OS error code and
  ## stack trace are not lost.
  assert(ipPacket.protocol == tcp)
  try:
    let packet = serialize(ipPacket)
    var sockaddr: Sockaddr_storage
    var sockaddrLen: SockLen
    toSockAddr(ipPacket.ipAddrDst, ipPacket.tcpPortDst, sockaddr, sockaddrLen)
    # NOTE(review): raw pointers to `packet` and `sockaddr` are passed to
    # sendTo; this presumably relies on the async transform keeping these
    # locals alive across the await — confirm if the proc is restructured.
    await rawFd.sendTo(packet.cstring, packet.len,
                       cast[ptr SockAddr](addr sockaddr), sockaddrLen)
  except OSError as e:
    raise newException(PunchHoleError, e.msg, e)
|
||
|
|
||
|
proc predictPortRange*(dstPorts: seq[Port], count: Positive = 1): seq[Port] =
  ## Predicts the destination ports to try next.
  ##
  ## TODO: do real port prediction. For now this returns `count`
  ## consecutive ports starting at an anchor taken from `dstPorts`: the
  ## second entry when there are at least two, otherwise the only entry.
  ## The start is clamped so the range never goes past port 65535.
  ##
  ## `count` defaults to 1 (the original behaviour) and may be at most
  ## 65536, the number of distinct port values.
  ##
  ## Raises `ValueError` if `dstPorts` is empty or `count` exceeds 65536.
  if dstPorts.len == 0:
    raise newException(ValueError, "predictPortRange: dstPorts is empty")
  if count > 65536:
    raise newException(ValueError, "predictPortRange: count exceeds 65536")
  # NOTE(review): index 1 (the second observed port) is the original choice;
  # confirm it is intended rather than dstPorts[0]. With a single port we
  # use it instead of failing with IndexDefect.
  let anchor = dstPorts[min(1, dstPorts.high)].int
  # Clamp so that basePort + count - 1 <= uint16.high (65535).
  let basePort = min(anchor, int(uint16.high) - (count - 1))
  result = newSeqOfCap[Port](count)
  for offset in 0 ..< count:
    result.add Port(uint16(basePort + offset))
|
||
|
|
||
|
proc makeFirewallRule(srcIp: IpAddress, srcPort: Port,
                      dstIp: IpAddress, dstPort: Port): string =
  ## Builds the iptables rule arguments (everything after the chain name)
  ## that DROP ICMP "time-exceeded" packets addressed to `srcIp` which
  ## conntrack marks as RELATED to the TCP flow
  ## `srcIp:srcPort -> dstIp:dstPort`.
  ##
  ## The same string is used for both insertion (`-I`) and deletion (`-D`),
  ## so it is a pure function of its arguments.
  ##
  ## The rule is emitted on a single space-separated line. The earlier
  ## multi-line form depended on a shell treating backslash-newline as a
  ## line continuation; one line works whether or not a shell is involved.
  ## Every interpolated value is typed (`IpAddress`, `Port`), so no shell
  ## metacharacters can be injected.
  let args = [
    "-w",                             # iptables --wait (xtables lock)
    "-d", $srcIp,
    "-p", "icmp",
    "--icmp-type", "time-exceeded",
    "-m", "conntrack",
    "--ctstate", "RELATED",
    "--ctproto", "tcp",
    "--ctorigsrc", $srcIp,
    "--ctorigsrcport", $srcPort.int,
    "--ctorigdst", $dstIp,
    "--ctorigdstport", $dstPort.int,
    "-j", "DROP"]
  for arg in args:
    if result.len > 0:
      result.add ' '
    result.add arg
|
||
|
|
||
|
proc iptablesInsert(chain: string, rule: string) {.async.} =
  ## Runs `iptables -I <chain> <rule>` through `asyncExecCmd`.
  # NOTE(review): asyncExecCmd's result is discarded; if it is the exit
  # status, a failing iptables invocation goes unnoticed here.
  let command = "iptables -I " & chain & " " & rule
  discard await asyncExecCmd(command)
|
||
|
|
||
|
proc iptablesDelete(chain: string, rule: string) {.async.} =
  ## Runs `iptables -D <chain> <rule>` through `asyncExecCmd`.
  # NOTE(review): asyncExecCmd's result is discarded; if it is the exit
  # status, a failing iptables invocation goes unnoticed here.
  let command = "iptables -D " & chain & " " & rule
  discard await asyncExecCmd(command)
|
||
|
|
||
|
proc addFirewallRules*(attempt: Attempt) {.async.} =
  ## Inserts one rule into the iptables INPUT chain for each port in
  ## `attempt.dstPorts`, built by `makeFirewallRule` from the attempt's
  ## `srcIp`, `srcPort` and `dstIp`.
  ##
  ## If an insertion raises `OSError`:
  ## - the message is echoed;
  ## - the rules already inserted by this call are deleted again (best
  ##   effort, so no partial rule set is left in the firewall);
  ## - a `PunchHoleError` with the same message is raised, carrying the
  ##   `OSError` as `parent`.
  var inserted: seq[string]        # rules added by this call, for rollback
  var failure: ref OSError = nil
  for dstPort in attempt.dstPorts:
    let rule = makeFirewallRule(attempt.srcIp, attempt.srcPort,
                                attempt.dstIp, dstPort)
    try:
      await iptablesInsert("INPUT", rule)
    except OSError as e:
      failure = e
      break
    inserted.add rule
  if failure != nil:
    echo "cannot add firewall rule: ", failure.msg
    # Roll back outside the except branch; errors during rollback are
    # ignored so that the original failure is what gets reported.
    for rule in inserted:
      try:
        await iptablesDelete("INPUT", rule)
      except OSError:
        discard
    raise newException(PunchHoleError, failure.msg, failure)
|
||
|
|
||
|
proc deleteFirewallRules*(attempt: Attempt) {.async.} =
  ## Removes from the iptables INPUT chain the rules built by
  ## `makeFirewallRule` for each port in `attempt.dstPorts` — the same
  ## chain and arguments `addFirewallRules` uses.
  ##
  ## Best effort: an `OSError` from any single deletion is ignored and the
  ## remaining deletions are still attempted.
  var rules = newSeqOfCap[string](attempt.dstPorts.len)
  for port in attempt.dstPorts:
    rules.add makeFirewallRule(attempt.srcIp, attempt.srcPort,
                               attempt.dstIp, port)
  for rule in rules:
    try:
      await iptablesDelete("INPUT", rule)
    except OSError:
      discard # at least we tried; continue with the next rule
|