never executed always true always false
1 {-# LANGUAGE BangPatterns #-}
2
3 module Regex where
4
5 -- From https://github.com/Gabriel439/slides/blob/master/regex/regex.md
6
7 import Data.Array.Unboxed (Array, UArray, (!))
8 import Data.Bits ((.|.), (.&.))
9 import Data.ByteString (ByteString)
10 import Data.Foldable (foldl')
11 import Data.Word (Word8)
12 import Foreign (peek, plusPtr, withForeignPtr)
13
14 import qualified Data.Array.Base
15 import qualified Data.Array.Unboxed as Array
16 import qualified Data.Bits as Bits
17 import qualified Data.ByteString as ByteString
18 import qualified Data.ByteString.Internal as ByteString.Internal
19 import qualified GHC.Arr
20
-- | Left shift specialised to 'Integer'; used for the slow-path state sets.
integerShiftL :: Integer -> Int -> Integer
integerShiftL value amount = Bits.unsafeShiftL value amount
{-# INLINE integerShiftL #-}
24
-- | Number of bits in a machine 'Word'; the fast path is used when a regex has
-- at most this many states.
sizeOfWord :: Int
sizeOfWord = Bits.finiteBitSize (minBound :: Word)
{-# INLINE sizeOfWord #-}
28
-- | Strict left fold over the indices of the set bits of an 'Integer', visited
-- from the least significant bit upwards.
--
-- The loop counts down from the population count, so it stops as soon as the
-- last set bit has been consumed rather than scanning a fixed width.
--
-- NOTE(review): intended for non-negative bit sets; behaviour on a negative
-- 'Integer' has not been checked here -- confirm before relying on it.
integerFoldl' :: (b -> Int -> b) -> b -> Integer -> b
integerFoldl' step seed bitset = loop seed (Bits.popCount bitset) 0
  where
    -- @remaining@ is the number of set bits not yet folded in; @bitIx@ is the
    -- bit position currently being inspected.
    loop !acc remaining bitIx
        | remaining == 0 =
            acc
        | Bits.testBit bitset bitIx =
            loop (step acc bitIx) (remaining - 1) (bitIx + 1)
        | otherwise =
            loop acc remaining (bitIx + 1)
{-# INLINE integerFoldl' #-}
38
-- | A regular expression over input symbols of type @i@, stored as an
-- automaton whose sets of states are bit masks (bit @j@ stands for state @j@).
--
-- Every regex carries two equivalent encodings: a fast one that packs the
-- state set into a 'Word', used when the number of states is less than or
-- equal to the number of bits in a 'Word', and a slow one that uses 'Integer'
-- when the number of states is greater than that (noted upstream as ~10x
-- slower).
data Regex i = Regex
    { _numberOfStates :: Int
      -- ^ Total number of states; selects the fast or the slow path.

    , _startingStates :: Word
      -- ^ Fast path: bit mask of the starting states.
    , _transitionFunction :: i -> Int -> Word
      -- ^ Fast path: states reachable from the given state index on the
      -- given input symbol.
    , _acceptingStates :: Word
      -- ^ Fast path: bit mask of the accepting states.

    , _startingStatesSlow :: Integer
      -- ^ Slow path: bit mask of the starting states.
    , _transitionFunctionSlow :: i -> Int -> Integer
      -- ^ Slow path: states reachable from the given state index on the
      -- given input symbol.
    , _acceptingStatesSlow :: Integer
      -- ^ Slow path: bit mask of the accepting states.
    }
56
-- | Numeric syntax for building regexes:
--
-- * @0@ is the regex with no states at all, so it never accepts.
-- * Any positive literal is a single state that is both starting and
--   accepting and has no outgoing transitions, so it accepts only the empty
--   input.
-- * @('+')@ is alternation and @('*')@ is sequencing.
--
-- Negative literals, 'negate', @('-')@, 'abs' and 'signum' have no meaning
-- for regexes and raise an error when evaluated.
instance Num (Regex i) where
    fromInteger n
        | 0 == n    = Regex 0 0 f 0 0 g 0
        | 0 <  n    = Regex 1 1 f 1 1 g 1
        | otherwise = error "fromInteger[Regex]: Negative numbers unsupported"
      where
        -- Neither literal has any transitions.
        f _ _ = 0
        g _ _ = 0
    {-# INLINE fromInteger #-}

    -- Alternation: the right operand's states are renumbered to sit above the
    -- left operand's @nL@ states (every mask is shifted left by @nL@), and
    -- the starting / accepting sets are the unions of both sides.
    Regex nL asL fL bsL csL gL dsL + Regex nR asR fR bsR csR gR dsR =
        Regex n as f bs cs g ds
      where
        n = nL + nR

        as = Bits.unsafeShiftL asR nL .|. asL

        -- State indices below @nL@ belong to the left operand, the rest to
        -- the (shifted) right operand.
        f i j =
            if j < nL
                then fL i j
                else Bits.unsafeShiftL (fR i (j - nL)) nL

        bs = Bits.unsafeShiftL bsR nL .|. bsL

        cs = integerShiftL csR nL .|. csL

        g i j =
            if j < nL
                then gL i j
                else integerShiftL (gR i (j - nL)) nL

        ds = integerShiftL dsR nL .|. dsL
    {-# INLINE (+) #-}

    -- Sequencing: only the right operand's (shifted) accepting states accept.
    -- Whenever the left operand is in one of its accepting states, the right
    -- operand's starting states are switched on as well.
    Regex nL asL fL bsL csL gL dsL * Regex nR asR fR bsR csR gR dsR =
        -- Force the shifted starting sets once, up front, because the
        -- transition functions below reuse them on every call.
        asR' `seq` csR' `seq` Regex n as f bs cs g ds
      where
        n = nL + nR

        asR' = Bits.unsafeShiftL asR nL

        -- If the left operand already accepts in its starting states, the
        -- right operand starts immediately too.
        as =
            if asL .&. bsL == 0
                then asL
                else asL .|. asR'

        f i j =
            if j < nL
                then
                    if s .&. bsL == 0
                        then s
                        else s .|. asR'
                else Bits.unsafeShiftL (fR i (j - nL)) nL
          where
            s = fL i j

        bs = Bits.unsafeShiftL bsR nL

        csR' = integerShiftL csR nL

        cs =
            if csL .&. dsL == 0
                then csL
                else csR' .|. csL

        g i j =
            if j < nL
                then
                    if s .&. dsL == 0
                        then s
                        else s .|. csR'
                else integerShiftL (gR i (j - nL)) nL
          where
            s = gL i j

        ds = integerShiftL dsR nL
    {-# INLINE (*) #-}

    -- The remaining 'Num' methods are not meaningful for regexes.  They are
    -- defined explicitly because the class defaults for 'negate' and @('-')@
    -- are written in terms of each other, so leaving both undefined makes any
    -- use of them loop instead of failing with a clear message.
    negate _ = error "negate[Regex]: Negation unsupported"

    _ - _ = error "(-)[Regex]: Subtraction unsupported"

    abs _ = error "abs[Regex]: Absolute value unsupported"

    signum _ = error "signum[Regex]: Sign unsupported"
134
-- | Zero-or-more repetition.
--
-- The starting states double as the accepting states, and every transition
-- that lands on one of the original accepting states also switches the
-- starting states back on so that another repetition can begin.
star :: Regex i -> Regex i
star (Regex count start0 delta accept0 startS deltaS acceptS) =
    Regex count start0 restartFast start0 startS restartSlow startS
  where
    restartFast sym state
        | next .&. accept0 == 0 = next
        | otherwise             = next .|. start0
      where
        next = delta sym state

    restartSlow sym state
        | next .&. acceptS == 0 = next
        | otherwise             = next .|. startS
      where
        next = deltaS sym state
{-# INLINE star #-}
150
-- | One-or-more repetition.
--
-- Starting and accepting states are left as they are; every transition that
-- lands on an accepting state additionally switches the starting states back
-- on so that another repetition can begin.
plus :: Regex i -> Regex i
plus (Regex count start0 delta accept0 startS deltaS acceptS) =
    Regex count start0 restartFast accept0 startS restartSlow acceptS
  where
    restartFast sym state
        | next .&. accept0 == 0 = next
        | otherwise             = next .|. start0
      where
        next = delta sym state

    restartSlow sym state
        | next .&. acceptS == 0 = next
        | otherwise             = next .|. startS
      where
        next = deltaS sym state
{-# INLINE plus #-}
166
-- | Run a regex over a whole list of input symbols and report whether any
-- accepting state is active once the input is exhausted.
match :: Regex i -> [i] -> Bool
match (Regex count start0 delta accept0 startS deltaS acceptS) input
    -- Fast path (bit arithmetic on `Word`s)
    | count <= sizeOfWord = accept0 .&. finalFast /= 0
    -- Slow path (bit arithmetic on `Integer`s)
    | otherwise           = acceptS .&. finalSlow /= 0
  where
    finalFast = foldl' advance start0 input

    finalSlow = foldl' advanceSlow startS input

    -- Union of the transitions out of every active state, peeling active
    -- states off the mask lowest bit first.
    advance active sym = collect 0 active
      where
        collect !acc pending
            | pending == 0 = acc
            | otherwise =
                collect (acc .|. delta sym lowest) (Bits.clearBit pending lowest)
          where
            lowest = Bits.countTrailingZeros pending

    advanceSlow active sym =
        integerFoldl' (\acc state -> acc .|. deltaS sym state) 0 active
{-# INLINE match #-}
183
-- | Regex that consumes exactly one input symbol for which the predicate
-- holds: two states, with state 0 starting and state 1 accepting.
satisfy :: (i -> Bool) -> Regex i
satisfy predicate = Regex 2 1 stepFast 2 1 stepSlow 2
  where
    -- The only transition goes from state 0 to state 1 (mask 2).
    stepFast sym state
        | state == 0 && predicate sym = 2
        | otherwise                   = 0

    stepSlow sym state
        | state == 0 && predicate sym = 2
        | otherwise                   = 0
{-# INLINE satisfy #-}
193
-- | Regex that consumes exactly one input symbol equal to the given one.
once :: Eq i => i -> Regex i
once expected = satisfy (\actual -> actual == expected)
{-# INLINE once #-}
197
-- | Regex that consumes exactly one input symbol, whatever it is.
dot :: Regex i
dot = satisfy (const True)
{-# INLINE dot #-}
201
-- | Regex with a single state that is starting and accepting and loops back
-- to itself on every input symbol, so it accepts any input.
chars :: Regex i
chars = Regex 1 1 (\_ _ -> 1) 1 1 (\_ _ -> 1) 1
{-# INLINE chars #-}
208
-- | Regex that matches the given byte string literally.
--
-- For a literal of length @len@ there are @len + 1@ states: state @k@ means
-- "the first @k@ bytes have been matched".  State 0 is the starting state and
-- state @len@ is the only accepting one.
bytes :: ByteString -> Regex Word8
bytes w8s =
    Regex (len + 1) 1 advance acceptFast 1 advanceSlow acceptSlow
  where
    len = ByteString.length w8s

    acceptFast = Bits.unsafeShiftL 1 len

    acceptSlow = integerShiftL 1 len

    -- State @pos@ moves on only when it is not the final state and the byte
    -- at offset @pos@ of the literal equals the input byte.
    consumes byte pos = pos /= len && ByteString.index w8s pos == byte

    advance byte pos =
        if consumes byte pos
            then Bits.unsafeShiftL 1 (pos + 1)
            else 0

    advanceSlow byte pos =
        if consumes byte pos
            then integerShiftL 1 (pos + 1)
            else 0
229
-- | Run a byte regex over an entire 'ByteString' and report whether an
-- accepting state is active after the last byte.
--
-- The transition function is tabulated once per call for all 256 byte values
-- and every state index, and the input buffer is walked directly through its
-- pointer.  The 'IO' action only reads from the buffer with 'peek', which is
-- why it is run with @accursedUnutterablePerformIO@.
matchBytes :: Regex Word8 -> ByteString -> Bool
matchBytes (Regex n as f bs cs g ds) (ByteString.Internal.PS fp off len)
    -- Fast path (state sets are `Word`s)
    | n <= sizeOfWord =
        ByteString.Internal.accursedUnutterablePerformIO
            (withForeignPtr fp (\p ->
                loop as (p `plusPtr` off) (p `plusPtr` (off + len))))
    -- Slow path (state sets are `Integer`s)
    | otherwise =
        ByteString.Internal.accursedUnutterablePerformIO
            (withForeignPtr fp (\p ->
                loop' cs (p `plusPtr` off) (p `plusPtr` (off + len))))
  where
    -- An empty state set can never become non-empty again, so stop early.
    loop 0 _ _ = return False
    loop !z !p !q
        | p == q    = return (bs .&. z /= 0)
        | otherwise = do
            x <- peek p
            loop (step z x) (p `plusPtr` 1) q

    -- Union of the tabulated transitions out of every active state, peeling
    -- active states off the mask lowest bit first.
    step :: Word -> Word8 -> Word
    step !s0 i0 = go 0 s0
      where
        go :: Word -> Word -> Word
        go !acc 0 = acc
        go !acc s = go acc' s'
          where
            acc' = acc .|. m
            m    = Data.Array.Base.unsafeAt table ix
            ix   = GHC.Arr.unsafeIndex bounds (i0, j)
            s'   = s .&. Bits.complement (Bits.unsafeShiftL 1 j)
            j    = Bits.countTrailingZeros s

    -- Index space shared by both tables: (input byte, state index).
    bounds :: ((Word8, Int), (Word8, Int))
    bounds = ((0, 0), (255, n - 1))

    table :: UArray (Word8, Int) Word
    table =
        Array.listArray bounds
            [ f i j
            | i <- [0..255]
            , j <- [0..n-1]
            ]

    loop' 0 _ _ = return False
    loop' !z !p !q
        | p == q    = return (ds .&. z /= 0)
        | otherwise = do
            x <- peek p
            loop' (step' z x) (p `plusPtr` 1) q

    step' :: Integer -> Word8 -> Integer
    step' !s0 i0 =
        integerFoldl' (\acc j -> acc .|. tableSlow ! (i0, j)) 0 s0

    -- Kept next to the fast-path table rather than inside `step'`, so that it
    -- is a single shared value for the whole call and does not rely on the
    -- optimiser to avoid rebuilding it for every input byte.
    tableSlow :: Array (Word8, Int) Integer
    tableSlow =
        Array.listArray bounds
            [ g i j
            | i <- [0..255]
            , j <- [0..n-1]
            ]
289
-- | Scan a 'ByteString' and report whether the regex reaches an accepting
-- state at any point of the input.
--
-- Unlike a whole-input match, the starting states are switched back on after
-- every byte, so a match may begin at any offset, and the scan returns 'True'
-- as soon as an accepting state is active without reading the rest.
--
-- The transition function is tabulated once per call for all 256 byte values
-- and every state index, and the input buffer is walked directly through its
-- pointer.  The 'IO' action only reads from the buffer with 'peek', which is
-- why it is run with @accursedUnutterablePerformIO@.
hasBytes :: Regex Word8 -> ByteString -> Bool
hasBytes (Regex n as f bs cs g ds) (ByteString.Internal.PS fp off len)
    -- Fast path (state sets are `Word`s)
    | n <= sizeOfWord =
        ByteString.Internal.accursedUnutterablePerformIO
            (withForeignPtr fp (\p ->
                loop as (p `plusPtr` off) (p `plusPtr` (off + len))))
    -- Slow path (state sets are `Integer`s)
    | otherwise =
        ByteString.Internal.accursedUnutterablePerformIO
            (withForeignPtr fp (\p ->
                loop' cs (p `plusPtr` off) (p `plusPtr` (off + len))))
  where
    -- Acceptance is checked before the end-of-input test, so a match that
    -- completes on the last byte is still reported.
    loop !z !p !q
        | bs .&. z /= 0 = return True
        | p == q        = return False
        | otherwise     = do
            x <- peek p
            loop (step z x .|. as) (p `plusPtr` 1) q

    -- Union of the tabulated transitions out of every active state, peeling
    -- active states off the mask lowest bit first.
    step :: Word -> Word8 -> Word
    step !s0 i0 = go 0 s0
      where
        go :: Word -> Word -> Word
        go !acc 0 = acc
        go !acc s = go acc' s'
          where
            acc' = acc .|. m
            m    = Data.Array.Base.unsafeAt table ix
            ix   = GHC.Arr.unsafeIndex bounds (i0, j)
            s'   = s .&. Bits.complement (Bits.unsafeShiftL 1 j)
            j    = Bits.countTrailingZeros s

    -- Index space shared by both tables: (input byte, state index).
    bounds :: ((Word8, Int), (Word8, Int))
    bounds = ((0, 0), (255, n - 1))

    table :: UArray (Word8, Int) Word
    table =
        Array.listArray bounds
            [ f i j
            | i <- [0..255]
            , j <- [0..n-1]
            ]

    loop' !z !p !q
        | ds .&. z /= 0 = return True
        | p == q        = return False
        | otherwise     = do
            x <- peek p
            loop' (step' z x .|. cs) (p `plusPtr` 1) q

    step' :: Integer -> Word8 -> Integer
    step' !s0 i0 =
        integerFoldl' (\acc j -> acc .|. tableSlow ! (i0, j)) 0 s0

    -- Kept next to the fast-path table rather than inside `step'`, so that it
    -- is a single shared value for the whole call and does not rely on the
    -- optimiser to avoid rebuilding it for every input byte.
    tableSlow :: Array (Word8, Int) Integer
    tableSlow =
        Array.listArray bounds
            [ g i j
            | i <- [0..255]
            , j <- [0..n-1]
            ]