1 | module Goop.Types.BitsN
3 | import public Data.Bits
4 | import public Data.Nat
5 | import Derive.Prelude
7 | import Goop.Types.AsBytes
8 | import Goop.Util.Formatting
10 | %language ElabReflection
14 | record BitsN (n : Nat) where
18 | %runElab derive "BitsN" [Eq, Ord]
22 | Bytes n = BitsN (8 * n)
-- Bit vectors whose width is given as a count of 4-bit nibbles.
Nibbles : Nat -> Type
Nibbles count = BitsN (4 * count)
-- Integer with exactly the low k bits set, i.e. 2^k - 1 (so mask 0 = 0).
mask : Nat -> Integer
mask k = shiftL 1 k - 1
-- Bitwise operations on n-bit vectors, implemented on the underlying Integer.
-- Every operation keeps the stored value within the low n bits.
{n : Nat} -> Bits (BitsN n) where
  -- Bit positions are Fin n, i.e. 0 .. n-1.
  Index = Fin n

  (.&.) x y = MkBitsN (x.inner .&. y.inner)
  (.|.) x y = MkBitsN (x.inner .|. y.inner)
  xor x y = MkBitsN (x.inner `xor` y.inner)

  -- Shifting left can move set bits past position n-1. Mask them off so the
  -- result stays an n-bit value, matching how (+) and (*) truncate in the Num
  -- instance.
  shiftL x i = MkBitsN ((x.inner `shiftL` finToNat i) .&. mask n)

  -- Shifting right only discards low bits, so no re-masking is required.
  shiftR x i = MkBitsN (x.inner `shiftR` finToNat i)

  -- i : Fin n, so the single set bit is always below position n.
  bit i = MkBitsN (bit (finToNat i))
  zeroBits = MkBitsN 0
  oneBits = MkBitsN (mask n)
  testBit x i = testBit x.inner (finToNat i)
45 | {n : Nat} -> FiniteBits (BitsN n) where
48 | popCount x = popCount' x.inner 0
50 | popCount' : Integer -> (acc : Nat) -> Nat
51 | popCount' 0 acc = acc
52 | popCount' i acc = assert_total $
54 | then popCount' (i `shiftR` 1) (S acc)
55 | else popCount' (i `shiftR` 1) acc
-- Arithmetic on n-bit vectors: each result is truncated to its low n bits
-- with `mask n`.
{n : Nat} -> Num (BitsN n) where
  (+) a b = MkBitsN ((a.inner + b.inner) .&. mask n)
  (*) a b = MkBitsN ((a.inner * b.inner) .&. mask n)
  fromInteger i = MkBitsN (i .&. mask n)
-- Division and remainder act directly on the underlying Integer values.
{n : Nat} -> Integral (BitsN n) where
  div a b = MkBitsN (div a.inner b.inner)
  mod a b = MkBitsN (mod a.inner b.inner)
-- String interpolation renders the vector using `paddedHex`.
{n : Nat} -> Interpolation (BitsN n) where
  interpolate = paddedHex
-- `show` prints the underlying Integer value, not the padded hex form used
-- for interpolation.
{n : Nat} -> Show (BitsN n) where
  show (MkBitsN value) = show value
-- Concatenates two vectors. The first argument becomes the high n bits and
-- the second the low m bits.
-- NOTE(review): relies on the second argument fitting in m bits; otherwise
-- its excess bits overlap the shifted first argument.
(++) : {n, m : Nat} -> BitsN n -> BitsN m -> BitsN (n + m)
(++) high low = MkBitsN (shiftL high.inner m .|. low.inner)
82 | sumLengths : List Nat -> Nat
84 | sumLengths (x :: xs) = x + sumLengths xs
87 | concat : {lengths : List Nat} -> All BitsN lengths -> BitsN (sumLengths lengths)
88 | concat {lengths = []} [] = 0
89 | concat {lengths = (x :: [])} [y] =
90 | rewrite plusZeroRightNeutral x in y
91 | concat {lengths = (x :: xs)} (y :: ys) =
-- Conversion to Bits64, available only for widths of at most 64 bits.
LTE n 64 => Cast (BitsN n) Bits64 where
  cast (MkBitsN value) = cast value
-- Named (opt-in) conversion from any type that casts to Integer: convert to
-- Integer, then keep only the low n bits.
[integerIntermediate]
{n : Nat} -> Cast x Integer => Cast x (BitsN n) where
  cast value = MkBitsN (cast value .&. mask n)
104 | Cast (BitsN n) Integer where
-- Casting from Integer truncates exactly as `fromInteger` does.
{n : Nat} -> Cast Integer (BitsN n) where
  cast = fromInteger
-- Widen the machine integer to Integer and wrap it as an 8-bit vector.
Cast Bits8 (BitsN 8) where
  cast = MkBitsN . cast
-- Widen the machine integer to Integer and wrap it as a 16-bit vector.
Cast Bits16 (BitsN 16) where
  cast = MkBitsN . cast
-- Widen the machine integer to Integer and wrap it as a 32-bit vector.
Cast Bits32 (BitsN 32) where
  cast = MkBitsN . cast
-- Widen the machine integer to Integer and wrap it as a 64-bit vector.
Cast Bits64 (BitsN 64) where
  cast = MkBitsN . cast
-- Splits an (n + m)-bit vector into its high n bits and its low m bits.
-- This is the inverse of (++) for in-range values.
split : {m : Nat} -> (n : Nat) -> BitsN (n + m) -> (BitsN n, BitsN m)
split _ bits =
  ( MkBitsN (shiftR bits.inner m)
  , MkBitsN (bits.inner .&. mask m)
  )
137 | splits : (lengths : List Nat) -> BitsN (sumLengths lengths) -> All BitsN lengths
139 | splits (x :: xs) y =
140 | let (head, rest) = split x y
141 | in head :: splits xs rest
145 | toBytes : BitsN n -> List Bits8
146 | toBytes x = if x.inner == 0 then [0] else toBytes' x.inner
148 | toBytes' : Integer -> List Bits8
151 | let last = i `mod` 0x100
152 | rest = i `div` 0x100
153 | in cast last :: toBytes' (assert_smaller i rest)
157 | BytesFor : (n : Nat) -> Nat
159 | let bytes = n `div` 8
160 | in if n `mod` 8 == 0 then bytes else (S bytes)
162 | data BytesForV : Nat -> Type where
163 | Plus8 : {n : Nat} -> BytesForV (8 + n)
166 | bytesForV : (n : Nat) -> BytesForV n
167 | bytesForV (S (S (S (S (S (S (S (S n)))))))) = Plus8
172 | toVectBytes : {n : Nat} -> BitsN n -> Vect (BytesFor n) Bits8
173 | toVectBytes x = toVectBytes' (BytesFor n) x.inner
175 | toVectBytes' : (digits : Nat) -> Integer -> Vect digits Bits8
176 | toVectBytes' 0 i = []
177 | toVectBytes' (S k) i =
178 | let digit = i .&. 0xFF
179 | rest = i `shiftR` 8
180 | in cast digit :: toVectBytes' k rest
183 | {n : Nat} -> AsBytes (BitsN n) where
184 | asBytes x = asBytes' (BytesFor n) x.inner
186 | asBytes' : (digits : Nat) -> Integer -> LazyList Bits8
189 | let digit = i .&. 0xFF
190 | rest = i `shiftR` 8
191 | in cast digit :: asBytes' k rest
-- Reinterprets a vector at a width at least as large. The stored Integer is
-- unchanged; the `LTE` proof only restricts which widths callers may use.
widen : m `LTE` n => BitsN m -> BitsN n
widen (MkBitsN value) = MkBitsN value
198 | xorFold : {m : Nat} -> (n : Nat) -> m `LTE` n => BitsN (n + m) -> BitsN n
199 | xorFold @{lte} n x =
200 | let (x, y) = split n x
-- The all-ones vector of width n, i.e. the value 2^n - 1.
max : {n : Nat} -> BitsN n
max {n = width} = MkBitsN (mask width)