punchd/ip_packet.nim

190 lines
7.1 KiB
Nim

from nativesockets import ntohs, ntohl, htons, htonl
from net import IpAddress, IpAddressFamily, Port, `$`
from posix import InAddr, inet_ntoa
from random import randomize, rand
import unittest
type
  Ether_header {.importc: "struct ether_header", pure, final,
                 header: "<netinet/if_ether.h>".} = object
    ## Ethernet frame header, mapped onto the C `struct ether_header`.
    ether_dhost: array[6, uint8] # destination hardware address
    ether_shost: array[6, uint8] # source hardware address
    ether_type: cushort          # EtherType, in network byte order

  Ip {.importc: "struct ip", pure, final,
       header: "<netinet/ip.h>".} = object
    ## IPv4 header, mapped onto the C `struct ip`. Multi-byte fields are in
    ## network byte order.
    when cpuEndian == littleEndian:
      ip_hl {.bitsize: 4.}: uint # header length, in 32-bit words
      ip_v {.bitsize: 4.}: uint  # version
    else:
      ip_v {.bitsize: 4.}: uint  # version
      ip_hl {.bitsize: 4.}: uint # header length, in 32-bit words
    ip_tos: uint8   # type of service
    ip_len: cushort # total length (header + payload), in bytes
    ip_id: cushort  # identification
    ip_off: cushort # fragment offset field
    ip_ttl: uint8   # time to live
    ip_p: cuchar    # protocol
    ip_sum: cushort # header checksum
    ip_src: InAddr  # source address
    ip_dst: InAddr  # destination address

  Tcphdr {.importc: "struct tcphdr", pure, final,
           header: "<netinet/tcp.h>".} = object
    ## TCP header, mapped onto the C `struct tcphdr`. Multi-byte fields are
    ## in network byte order.
    th_sport: uint16 # source port
    th_dport: uint16 # destination port
    th_seq: uint32   # sequence number
    th_ack: uint32   # acknowledgment number
    when cpuEndian == littleEndian:
      th_x2 {.bitsize: 4.}: uint8  # (unused)
      th_off {.bitsize: 4.}: uint8 # data offset, in 32-bit words
    else:
      th_off {.bitsize: 4.}: uint8 # data offset, in 32-bit words
      th_x2 {.bitsize: 4.}: uint8  # (unused)
    th_flags: uint8  # flags; bit layout matches `set[TcpFlag]`
    th_win: uint16   # window
    th_sum: uint16   # checksum
    th_urp: uint16   # urgent pointer

  TransportProtocol* = enum
    ## Transport protocol carried by an `IpPacket`.
    tcp   ## TCP; the `tcp*` fields of `IpPacket` are valid
    other ## anything else; only the common IPv4 fields may be set

  TcpFlag* {.size: sizeof(uint8).} = enum
    ## TCP control flags. Each member's ordinal is its bit position in the
    ## TCP flags byte, so a `set[TcpFlag]` (one byte) can be cast to and from
    ## `Tcphdr.th_flags`. Bits 6 and 7 of that byte have no member here.
    FIN
    SYN
    RST
    PSH
    ACK
    URG

  IpPacket* = object
    ## Decoded IPv4 packet. Ports, sequence/acknowledgment numbers and the
    ## window size are in host byte order.
    ipAddrSrc*: IpAddress
    ipAddrDst*: IpAddress
    ipTTL*: uint8
    case protocol*: TransportProtocol
    of tcp:
      tcpPortSrc*: Port
      tcpPortDst*: Port
      tcpSeqNumber*: uint32
      tcpAckNumber*: uint32
      tcpFlags*: set[TcpFlag]
      tcpWindowSize*: uint16
    else:
      discard
# EtherType value identifying an IP payload, taken from the C header. It is a
# host-order value: it is compared against `ntohs(ether_type)`.
var ETHERTYPE_IP {.importc: "ETHERTYPE_IP",
                   header: "<netinet/if_ether.h>".}: cushort

# IP protocol number for TCP, taken from the C header.
var IPPROTO_TCP {.importc: "IPPROTO_TCP",
                  header: "<netinet/in.h>".}: cint
proc parseEthernetPacket*(input: string): IpPacket =
  ## Parses a raw Ethernet frame and decodes the IPv4 and TCP headers it
  ## carries.
  ##
  ## When `input` is an IPv4 frame carrying TCP, returns a packet with
  ## `protocol == tcp` and the addresses, TTL, ports, sequence and
  ## acknowledgment numbers, flags and window filled in (numeric TCP fields
  ## converted to host byte order).
  ##
  ## Every other frame yields `IpPacket(protocol: other)`. This covers a
  ## different EtherType, a non-TCP IPv4 payload, an IPv4 header length
  ## below the minimum, and a frame too short for the headers it declares.
  ## No byte outside `input` is read.
  result = IpPacket(protocol: other)
  let etherLen = sizeof(Ether_header)
  # The headers are copied into local values instead of being read through
  # pointers into `input`. After the length checks below, this keeps every
  # read inside `input`. It also avoids unaligned access: the IPv4 header
  # follows the 14-byte Ethernet header, so it is not 4-byte aligned.
  if input.len < etherLen + sizeof(Ip):
    return
  var etherHeader: Ether_header
  copyMem(addr etherHeader, unsafeAddr input[0], etherLen)
  if ntohs(etherHeader.ether_type) != ETHERTYPE_IP:
    return

  var ipHeader: Ip
  copyMem(addr ipHeader, unsafeAddr input[etherLen], sizeof(Ip))
  if ipHeader.ip_p.int != IPPROTO_TCP:
    return
  # ip_hl counts 32-bit words; anything shorter than the fixed header is
  # malformed, and the TCP header must fit inside the frame.
  let ipHeaderLen = ipHeader.ip_hl.int * 4
  let tcpOffset = etherLen + ipHeaderLen
  if ipHeaderLen < sizeof(Ip) or input.len < tcpOffset + sizeof(Tcphdr):
    return

  var tcpHeader: Tcphdr
  copyMem(addr tcpHeader, unsafeAddr input[tcpOffset], sizeof(Tcphdr))
  # Bits 6 and 7 of the flags byte (ECE/CWR, RFC 3168) have no TcpFlag
  # member. Mask them off so the set only ever holds valid enum values.
  let flags = cast[set[TcpFlag]](tcpHeader.th_flags and 0x3F'u8)
  result = IpPacket(protocol: tcp,
    ipAddrSrc: IpAddress(family: IpAddressFamily.IPv4,
      address_v4: cast[array[4, uint8]](ipHeader.ip_src.s_addr)),
    ipAddrDst: IpAddress(family: IpAddressFamily.IPv4,
      address_v4: cast[array[4, uint8]](ipHeader.ip_dst.s_addr)),
    ipTTL: ipHeader.ip_ttl,
    tcpPortSrc: Port(ntohs(tcpHeader.th_sport)),
    tcpPortDst: Port(ntohs(tcpHeader.th_dport)),
    tcpSeqNumber: ntohl(tcpHeader.th_seq),
    tcpAckNumber: ntohl(tcpHeader.th_ack),
    tcpFlags: flags,
    tcpWindowSize: ntohs(tcpHeader.th_win))
proc tcpChecksum(buffer: string): uint16 =
  ## Computes the TCP checksum (RFC 793, RFC 1071) of the IPv4 packet that
  ## starts at offset 0 of `buffer`. The result is in host byte order.
  ##
  ## The sum covers the IPv4 pseudo header (source address, destination
  ## address, protocol, TCP length), followed by the TCP segment. The
  ## segment length is derived from the IPv4 total-length and header-length
  ## fields. The segment's own checksum field is summed as-is, so it must be
  ## zero when computing a checksum for transmission.
  ##
  ## Raises `ValueError` if `buffer` is too short for the lengths its IPv4
  ## header declares, or if those lengths are inconsistent.
  template byteAt(i: int): uint32 = uint32(ord(buffer[i]))
  template wordAt(i: int): uint32 = (byteAt(i) shl 8) or byteAt(i + 1)

  if buffer.len < 20:
    raise newException(ValueError, "buffer too short for an IPv4 header")
  # IHL is the low nibble of byte 0 (in 32-bit words); total length is the
  # big-endian word at bytes 2-3.
  let ipHeaderLen = int(byteAt(0) and 0x0F) * 4
  let ipTotalLen = int(wordAt(2))
  if ipHeaderLen < 20 or ipTotalLen < ipHeaderLen or ipTotalLen > buffer.len:
    raise newException(ValueError, "invalid IPv4 header lengths")
  let tcpLen = ipTotalLen - ipHeaderLen

  var sum = 0'u32
  # Pseudo header: source (bytes 12-15) and destination (16-19) addresses,
  # the protocol byte (byte 9) zero-extended, and the TCP segment length.
  for i in countup(12, 18, 2):
    sum += wordAt(i)
  sum += byteAt(9)
  sum += uint32(tcpLen)
  # TCP segment as big-endian 16-bit words.
  var i = ipHeaderLen
  while i + 1 < ipTotalLen:
    sum += wordAt(i)
    i += 2
  # An odd trailing byte is padded on the right with a zero byte, i.e. it
  # is the high byte of the final word (RFC 1071).
  if i < ipTotalLen:
    sum += byteAt(i) shl 8
  # Fold in 32-bit arithmetic until no carry remains. A single fold can
  # itself carry, so a uint16 addition here would silently drop that carry.
  while (sum shr 16) != 0:
    sum = (sum and 0xFFFF) + (sum shr 16)
  result = not uint16(sum)
proc serialize*(packet: IpPacket): string =
  ## Serializes `packet` into a raw IPv4 packet with no link-layer header.
  ## The output is an option-less IPv4 header (`ip_hl = 5`) followed by an
  ## option-less TCP header (`th_off = 5`) and no payload.
  ##
  ## The IPv4 identification field is random. The TCP checksum is computed
  ## over the finished packet. The IPv4 header checksum is left zero —
  ## presumably for the kernel to fill in when sent on a raw socket
  ## (NOTE(review): confirm for the target platform).
  ##
  ## Raises `ValueError` if `packet.protocol` is not `tcp`, or if either
  ## address is not IPv4.
  case packet.protocol
  of tcp:
    if packet.ipAddrSrc.family != IpAddressFamily.IPv4 or
        packet.ipAddrDst.family != IpAddressFamily.IPv4:
      raise newException(ValueError, "only IPv4 addresses are supported")
    # Seed the global RNG once rather than on every call. Reseeding from the
    # clock on each call can repeat IDs for calls made close together, and
    # discards any seeding done elsewhere in the program.
    var rngSeeded {.global.} = false
    if not rngSeeded:
      randomize()
      rngSeeded = true

    result = newString(sizeof(Ip) + sizeof(Tcphdr))
    zeroMem(result.cstring, result.len)

    let ipHeader = cast[ptr Ip](addr result[0])
    ipHeader.ip_hl = 5
    ipHeader.ip_v = 4
    ipHeader.ip_tos = 0
    ipHeader.ip_len = htons(uint16(sizeof(Ip) + sizeof(Tcphdr)))
    ipHeader.ip_id = rand(cushort)
    ipHeader.ip_off = 0
    ipHeader.ip_ttl = packet.ipTTL
    ipHeader.ip_p = cuchar(IPPROTO_TCP)
    # ip_sum stays zero (see doc comment).
    ipHeader.ip_src = InAddr(s_addr: cast[uint32](packet.ipAddrSrc.address_v4))
    ipHeader.ip_dst = InAddr(s_addr: cast[uint32](packet.ipAddrDst.address_v4))

    let tcpHeader = cast[ptr Tcphdr](addr result[sizeof(Ip)])
    tcpHeader.th_sport = htons(packet.tcpPortSrc.uint16)
    tcpHeader.th_dport = htons(packet.tcpPortDst.uint16)
    tcpHeader.th_seq = htonl(packet.tcpSeqNumber)
    tcpHeader.th_ack = htonl(packet.tcpAckNumber)
    tcpHeader.th_off = 5
    tcpHeader.th_flags = cast[uint8](packet.tcpFlags)
    tcpHeader.th_win = htons(packet.tcpWindowSize)
    tcpHeader.th_urp = 0
    # Computed last, over the finished packet, while th_sum is still zero.
    tcpHeader.th_sum = htons(tcpChecksum(result))
  else:
    raise newException(ValueError, "protocol not supported")
when isMainModule:
  # The unit tests run only when this file is compiled as the main module,
  # not every time another module imports it.
  suite "ip_packet tests":
    setup:
      var
        ipHeader = Ip(ip_hl: 5, ip_v: 4, ip_tos: 0, ip_len: htons(40),
                      ip_id: htons(54321), ip_off: 0, ip_ttl: 64,
                      ip_p: 6.cuchar,
                      ip_src: InAddr(s_addr: htonl(0x5abb2bd9.uint32)),
                      ip_dst: InAddr(s_addr: htonl(0x0a000069.uint32)))
        tcpHeader = Tcphdr(th_sport: htons(4321), th_dport: htons(1234),
                           th_seq: htonl(345364), th_ack: 0, th_off: 5,
                           th_flags: cast[uint8]({SYN}),
                           th_win: htons(1452 * 10), th_urp: 0)

    test "tcpChecksum":
      var buffer = newString(sizeof(Ip) + sizeof(Tcphdr))
      zeroMem(addr buffer[0], buffer.len)
      copyMem(addr buffer[0], addr ipHeader, sizeof(Ip))
      copyMem(addr buffer[sizeof(Ip)], addr tcpHeader, sizeof(Tcphdr))
      check(tcpChecksum(buffer) == 35681)

    test "serialize then parseEthernetPacket round-trips a TCP packet":
      let packet = IpPacket(protocol: tcp,
        ipAddrSrc: IpAddress(family: IpAddressFamily.IPv4,
                             address_v4: [90'u8, 187, 43, 217]),
        ipAddrDst: IpAddress(family: IpAddressFamily.IPv4,
                             address_v4: [10'u8, 0, 0, 105]),
        ipTTL: 64,
        tcpPortSrc: Port(4321), tcpPortDst: Port(1234),
        tcpSeqNumber: 345364, tcpAckNumber: 7,
        tcpFlags: {SYN, ACK}, tcpWindowSize: 14520)
      # Prepend a 14-byte Ethernet header: zeroed MACs, EtherType 0x0800.
      var frame = newString(14)
      zeroMem(addr frame[0], frame.len)
      frame[12] = '\x08'
      frame.add(serialize(packet))
      let parsed = parseEthernetPacket(frame)
      require(parsed.protocol == tcp)
      check($parsed.ipAddrSrc == "90.187.43.217")
      check($parsed.ipAddrDst == "10.0.0.105")
      check(parsed.ipTTL == 64'u8)
      check(parsed.tcpPortSrc.uint16 == 4321'u16)
      check(parsed.tcpPortDst.uint16 == 1234'u16)
      check(parsed.tcpSeqNumber == 345364'u32)
      check(parsed.tcpAckNumber == 7'u32)
      check(parsed.tcpFlags == {SYN, ACK})
      check(parsed.tcpWindowSize == 14520'u16)

    test "parseEthernetPacket ignores non-IPv4 frames":
      var frame = newString(60)
      zeroMem(addr frame[0], frame.len)
      frame[12] = '\x86'
      frame[13] = '\xDD'
      check(parseEthernetPacket(frame).protocol == other)