0 | ||| Arbitrary length Bits types
  1 | module Goop.Types.BitsN
  2 |
  3 | import public Data.Bits
  4 | import public Data.Nat
  5 | import Derive.Prelude
  6 |
  7 | import Goop.Types.AsBytes
  8 | import Goop.Util.Formatting
  9 |
 10 | %language ElabReflection
 11 | %default total
 12 |
 13 | export
 14 | record BitsN (n : Nat) where
 15 |   constructor MkBitsN
 16 |   inner : Integer
 17 | %name BitsN x,y,z
 18 | %runElab derive "BitsN" [Eq, Ord]
 19 |
 20 | public export
 21 | Bytes : Nat -> Type
 22 | Bytes n = BitsN (8 * n)
 23 |
 24 | public export
 25 | Nibbles : Nat -> Type
 26 | Nibbles n = BitsN (4 * n)
 27 |
 28 | mask : Nat -> Integer
 29 | mask k = (1 `shiftL` k) - 1
 30 |
 31 | export
 32 | {n : Nat} -> Bits (BitsN n) where
 33 |   Index = Fin n
 34 |   (.&.) x y = MkBitsN $ x.inner .&. y.inner
 35 |   (.|.) x y = MkBitsN $ x.inner .|. y.inner
 36 |   xor x y = MkBitsN $ x.inner `xor` y.inner
 37 |   shiftL x y = MkBitsN $ x.inner `shiftL` (finToNat y)
 38 |   shiftR x y = MkBitsN $ x.inner `shiftR` (finToNat y)
 39 |   bit n = MkBitsN $ bit (finToNat n)
 40 |   zeroBits = MkBitsN $ 0
 41 |   oneBits = MkBitsN $ mask n
 42 |   testBit x i = testBit x.inner (finToNat i)
 43 |
 44 | export
 45 | {n : Nat} -> FiniteBits (BitsN n) where
 46 |   bitSize = n
 47 |   bitsToIndex x = x
 48 |   popCount x = popCount' x.inner 0
 49 |     where
 50 |       popCount' : Integer -> (acc : Nat) -> Nat
 51 |       popCount' 0 acc = acc
 52 |       popCount' i acc = assert_total $
 53 |         if testBit i 0
 54 |         then popCount' (i `shiftR` 1) (S acc)
 55 |         else popCount' (i `shiftR` 1) acc
 56 |
 57 | export
 58 | {n : Nat} -> Num (BitsN n) where
 59 |   (+) x y = MkBitsN $ (x.inner + y.inner) .&. mask n
 60 |   (*) x y = MkBitsN $ (x.inner * y.inner) .&. mask n
 61 |   fromInteger x = MkBitsN $ x .&. mask n
 62 |
 63 | export
 64 | {n : Nat} -> Integral (BitsN n) where
 65 |   div x y = MkBitsN $ x.inner `div` y.inner
 66 |   mod x y = MkBitsN $ x.inner `mod` y.inner
 67 |
 68 | export
 69 | {n : Nat} -> Interpolation (BitsN n) where
 70 |   interpolate x = paddedHex x
 71 |
 72 | export
 73 | {n : Nat} -> Show (BitsN n) where
 74 |   show x = show x.inner
 75 |
 76 | export infixr 7 ++
 77 | export
 78 | (++) : {n, m : Nat} -> BitsN n -> BitsN m -> BitsN (n + m)
 79 | (++) x y = MkBitsN $ (x.inner `shiftL` m) .|. y.inner
 80 |
 81 | public export
 82 | sumLengths : List Nat -> Nat
 83 | sumLengths [] = 0
 84 | sumLengths (x :: xs) = x + sumLengths xs
 85 |
 86 | export
 87 | concat : {lengths : List Nat} -> All BitsN lengths -> BitsN (sumLengths lengths)
 88 | concat {lengths = []} [] = 0
 89 | concat {lengths = (x :: [])} [y] =
 90 |   rewrite plusZeroRightNeutral x in y
 91 | concat {lengths = (x :: xs)} (y :: ys) =
 92 |   y ++ concat ys
 93 |
 94 | export
 95 | LTE n 64 => Cast (BitsN n) Bits64 where
 96 |   cast x = cast x.inner
 97 |
 98 | export
 99 | [integerIntermediate]
100 | {n : Nat} -> Cast x Integer => Cast x (BitsN n) where
101 |   cast x = MkBitsN $ cast x .&. mask n
102 |
103 | export
104 | Cast (BitsN n) Integer where
105 |   cast x = x.inner
106 |
107 | export
108 | {n : Nat} -> Cast Integer (BitsN n) where
109 |   cast x = fromInteger x
110 |
111 | export
112 | Cast Bits8 (BitsN 8) where
113 |   cast x = MkBitsN (cast x)
114 |
115 | export
116 | Cast Bits16 (BitsN 16) where
117 |   cast x = MkBitsN (cast x)
118 |
119 | export
120 | Cast Bits32 (BitsN 32) where
121 |   cast x = MkBitsN (cast x)
122 |
123 | export
124 | Cast Bits64 (BitsN 64) where
125 |   cast x = MkBitsN (cast x)
126 |
127 | ||| Split off the first N bits of a number
128 | export
129 | split : {m : Nat} -> (n : Nat) -> BitsN (n + m) -> (BitsN n, BitsN m)
130 | split n (MkBitsN inner) =
131 |   let fst = inner `shiftR` m
132 |       snd = inner .&. mask m
133 |   in (MkBitsN fst, MkBitsN snd)
134 |
135 | ||| Split a number into a list of lengths
136 | export
137 | splits : (lengths : List Nat) -> BitsN (sumLengths lengths) -> All BitsN lengths
138 | splits [] y = []
139 | splits (x :: xs) y =
140 |   let (head, rest) = split x y
141 |   in head :: splits xs rest
142 |
143 | ||| Convert to a list of bytes, LSB first, not including leading zeros
144 | export
145 | toBytes : BitsN n -> List Bits8
146 | toBytes x = if x.inner == 0 then [0] else toBytes' x.inner
147 |   where
148 |   toBytes' : Integer -> List Bits8
149 |   toBytes' 0 = []
150 |   toBytes' i =
151 |     let last = i `mod` 0x100
152 |         rest = i `div` 0x100
153 |     in cast last :: toBytes' (assert_smaller i rest)
154 |
155 | ||| The number of bytes required to contain n bits
156 | public export
157 | BytesFor : (n : Nat) -> Nat
158 | BytesFor n =
159 |   let bytes = n `div` 8
160 |   in if n `mod` 8 == 0 then bytes else (S bytes)
161 |
162 | data BytesForV : Nat -> Type where
163 |   Plus8 : {n : Nat} -> BytesForV (8 + n)
164 |   Done : BytesForV n
165 |
166 | bytesForV : (n : Nat) -> BytesForV n
167 | bytesForV (S (S (S (S (S (S (S (S n)))))))) = Plus8
168 | bytesForV n = Done
169 |
170 | ||| Convert to a vector of bytes, LSB first
171 | export
172 | toVectBytes : {n : Nat} -> BitsN n -> Vect (BytesFor n) Bits8
173 | toVectBytes x = toVectBytes' (BytesFor n) x.inner
174 |   where
175 |     toVectBytes' : (digits : Nat) -> Integer -> Vect digits Bits8
176 |     toVectBytes' 0 i = []
177 |     toVectBytes' (S k) i =
178 |       let digit = i .&. 0xFF
179 |           rest = i `shiftR` 8
180 |       in cast digit :: toVectBytes' k rest
181 |
182 | export
183 | {n : Nat} -> AsBytes (BitsN n) where
184 |   asBytes x = asBytes' (BytesFor n) x.inner
185 |     where
186 |       asBytes' : (digits : Nat) -> Integer -> LazyList Bits8
187 |       asBytes' 0 i = []
188 |       asBytes' (S k) i =
189 |         let digit = i .&. 0xFF
190 |             rest = i `shiftR` 8
191 |         in cast digit :: asBytes' k rest
192 |
193 | export
194 | widen : m `LTE` n => BitsN m -> BitsN n
195 | widen x = MkBitsN x.inner
196 |
197 | export
198 | xorFold : {m : Nat} -> (n : Nat) -> m `LTE` n => BitsN (n + m) -> BitsN n
199 | xorFold @{lte} n x =
200 |   let (x, y) = split n x
201 |       y = widen @{lte} y
202 |   in x `xor` y
203 |
204 | export
205 | max : {n : Nat} -> BitsN n
206 | max = MkBitsN (mask n)
207 |