never executed always true always false
    1 {-# LANGUAGE BangPatterns #-}
    2 
    3 module Regex where
    4 
    5 -- From https://github.com/Gabriel439/slides/blob/master/regex/regex.md
    6 
    7 import Data.Array.Unboxed (Array, UArray, (!))
    8 import Data.Bits ((.|.), (.&.))
    9 import Data.ByteString (ByteString)
   10 import Data.Foldable (foldl')
   11 import Data.Word (Word8)
   12 import Foreign (peek, plusPtr, withForeignPtr)
   13 
   14 import qualified Data.Array.Base
   15 import qualified Data.Array.Unboxed       as Array
   16 import qualified Data.Bits                as Bits
   17 import qualified Data.ByteString          as ByteString
   18 import qualified Data.ByteString.Internal as ByteString.Internal
   19 import qualified GHC.Arr
   20 
   21 integerShiftL :: Integer -> Int -> Integer
   22 integerShiftL = Bits.unsafeShiftL
   23 {-# INLINE integerShiftL #-}
   24 
   25 sizeOfWord :: Int
   26 sizeOfWord = Bits.finiteBitSize (0 :: Word)
   27 {-# INLINE sizeOfWord #-}
   28 
   29 integerFoldl' :: (b -> Int -> b) -> b -> Integer -> b
   30 integerFoldl' f acc0 bits = go acc0 (Bits.popCount bits) 0
   31   where
   32     go !acc 0 _ = acc
   33     go !acc n b =
   34         if Bits.testBit bits b
   35         then go (f acc b) (n - 1) (b + 1)
   36         else go    acc     n      (b + 1)
   37 {-# INLINE integerFoldl' #-}
   38 
-- | A regular expression over input symbols of type @i@, represented as a
-- nondeterministic finite automaton whose sets of states are bit masks: bit
-- @j@ of a mask is set when state @j@ belongs to the set.
--
-- Every value carries two copies of the automaton, one using 'Word' masks and
-- one using 'Integer' masks.  The matching functions in this module pick one
-- of the two by comparing '_numberOfStates' against the width of a 'Word'.
data Regex i = Regex
    { _numberOfStates         :: Int
    -- ^ Number of states in the automaton; states are indexed from 0.

    -- Fast path, if the number of states is less than or equal to the number of
    -- bits in a `Word`
    , _startingStates         :: Word
    -- ^ Mask of the states that are active before any input is consumed.
    , _transitionFunction     :: i -> Int -> Word
    -- ^ Given an input symbol and the index of one active state, the mask of
    -- states reached from that state on that symbol.
    , _acceptingStates        :: Word
    -- ^ Mask of accepting states; a run succeeds when the final set of active
    -- states intersects this mask.

    -- Slow path, if the number of states is greater than the number of bits in
    -- a `Word`
    --
    -- This is ~10x slower
    , _startingStatesSlow     :: Integer
    -- ^ 'Integer' counterpart of '_startingStates'.
    , _transitionFunctionSlow :: i -> Int -> Integer
    -- ^ 'Integer' counterpart of '_transitionFunction'.
    , _acceptingStatesSlow    :: Integer
    -- ^ 'Integer' counterpart of '_acceptingStates'.
    }
   56 
   57 instance Num (Regex i) where
   58     fromInteger n
   59         | 0 == n    = Regex 0 0 f 0 0 g 0
   60         | 0 <  n    = Regex 1 1 f 1 1 g 1
   61         | otherwise = error "fromInteger[Regex]: Negative numbers unsupported"
   62       where
   63         f _ _ = 0
   64         g _ _ = 0
   65     {-# INLINE fromInteger #-}
   66 
   67     Regex nL asL fL bsL csL gL dsL + Regex nR asR fR bsR csR gR dsR =
   68         Regex n as f bs cs g ds
   69       where
   70         n  = nL + nR
   71 
   72         as = Bits.unsafeShiftL asR nL .|. asL
   73 
   74         f i j =
   75             if j < nL
   76             then fL i j
   77             else Bits.unsafeShiftL (fR i (j - nL)) nL
   78 
   79         bs = Bits.unsafeShiftL bsR nL .|. bsL
   80 
   81         cs = integerShiftL csR nL .|. csL
   82 
   83         g i j =
   84             if j < nL
   85             then gL i j
   86             else integerShiftL (gR i (j - nL)) nL
   87 
   88         ds = integerShiftL dsR nL .|. dsL
   89     {-# INLINE (+) #-}
   90 
   91     Regex nL asL fL bsL csL gL dsL * Regex nR asR fR bsR csR gR dsR =
   92         asR' `seq` csR' `seq` Regex n as f bs cs g ds
   93       where
   94         n = nL + nR
   95 
   96         asR' = Bits.unsafeShiftL asR nL
   97 
   98         as =
   99             if asL .&. bsL == 0
  100             then asL
  101             else asL .|. asR'
  102 
  103         f i j =
  104             if j < nL
  105             then
  106                 if s .&. bsL == 0
  107                 then s
  108                 else s .|. asR'
  109             else Bits.unsafeShiftL (fR i (j - nL)) nL
  110           where
  111             s = fL i j
  112 
  113         bs = Bits.unsafeShiftL bsR nL
  114 
  115         csR' = integerShiftL csR nL
  116 
  117         cs =
  118             if csL .&. dsL == 0
  119             then csL
  120             else csR' .|. csL
  121 
  122         g i j =
  123             if j < nL
  124             then
  125                 if s .&. dsL == 0
  126                 then s
  127                 else s .|. csR'
  128             else integerShiftL (gR i (j - nL)) nL
  129           where
  130             s = gL i j
  131 
  132         ds = integerShiftL dsR nL
  133     {-# INLINE (*) #-}
  134 
  135 star :: Regex i -> Regex i
  136 star (Regex n as f bs cs g ds) = Regex n as f' as cs g' cs
  137   where
  138     f' i j =
  139         let s = f i j
  140         in  if s .&. bs == 0
  141             then s
  142             else s .|. as
  143 
  144     g' i j =
  145         let s = g i j
  146         in  if s .&. ds == 0
  147             then s
  148             else s .|. cs
  149 {-# INLINE star #-}
  150 
  151 plus :: Regex i -> Regex i
  152 plus (Regex n as f bs cs g ds) = Regex n as f' bs cs g' ds
  153   where
  154     f' i j =
  155         let s = f i j
  156         in  if s .&. bs == 0
  157             then s
  158             else s .|. as
  159 
  160     g' i j =
  161         let s = g i j
  162         in  if s .&. ds == 0
  163             then s
  164             else s .|. cs
  165 {-# INLINE plus #-}
  166 
  167 match :: Regex i -> [i] -> Bool
  168 match (Regex n as f bs cs g ds) is
  169     -- Fast path (Bit arithmetic on `Word`s)
  170     | n <= sizeOfWord = bs .&. foldl' step  as is /= 0
  171     -- Slow path (Bit arithmetic on `Integer`s)
  172     | otherwise       = ds .&. foldl' step' cs is /= 0
  173   where
  174     step s0 i = go 0 s0
  175       where
  176         go !acc 0 = acc
  177         go !acc s = go (acc .|. f i j) (Bits.clearBit s j)
  178           where
  179             j = Bits.countTrailingZeros s
  180 
  181     step' s0 i = integerFoldl' (\acc j -> acc .|. g i j) 0 s0
  182 {-# INLINE match #-}
  183 
  184 satisfy :: (i -> Bool) -> Regex i
  185 satisfy predicate = Regex 2 1 f 2 1 g 2
  186   where
  187     f c 0 | predicate c = 2
  188     f _ _               = 0
  189 
  190     g c 0 | predicate c = 2
  191     g _ _               = 0
  192 {-# INLINE satisfy #-}
  193 
  194 once :: Eq i => i -> Regex i
  195 once x = satisfy (== x)
  196 {-# INLINE once #-}
  197 
  198 dot :: Regex i
  199 dot = satisfy (\_ -> True)
  200 {-# INLINE dot #-}
  201 
  202 chars :: Regex i
  203 chars = Regex 1 1 f 1 1 g 1
  204   where
  205     f _ _ = 1
  206     g _ _ = 1
  207 {-# INLINE chars #-}
  208 
  209 bytes :: ByteString -> Regex Word8
  210 bytes w8s = Regex (n + 1) 1 f (Bits.unsafeShiftL 1 n) 1 g (integerShiftL 1 n)
  211   where
  212     n = fromIntegral (ByteString.length w8s)
  213 
  214     f w8 i
  215         | i == n                                      =
  216             0
  217         | ByteString.index w8s (fromIntegral i) == w8 =
  218             Bits.unsafeShiftL 1 (i + 1)
  219         | otherwise                                   =
  220             0
  221 
  222     g w8 i
  223         | i == n                                      =
  224             0
  225         | ByteString.index w8s (fromIntegral i) == w8 =
  226             integerShiftL 1 (i + 1)
  227         | otherwise                                   =
  228             0
  229 
  230 matchBytes :: Regex Word8 -> ByteString -> Bool
  231 matchBytes (Regex n as f bs cs g ds) (ByteString.Internal.PS fp off len)
  232     | n <= sizeOfWord = do
  233         ByteString.Internal.accursedUnutterablePerformIO
  234             (withForeignPtr fp (\p ->
  235                 loop as (p `plusPtr` off) (p `plusPtr` (off+len)) ))
  236     | otherwise       = do
  237         ByteString.Internal.accursedUnutterablePerformIO
  238             (withForeignPtr fp (\p ->
  239                 loop' cs (p `plusPtr` off) (p `plusPtr` (off+len)) ))
  240   where
  241     loop  0  _  _   = return False
  242     loop !z !p !q
  243         | p == q    = return (bs .&. z /= 0)
  244         | otherwise = do
  245             x <- peek p
  246             loop (step z x) (p `plusPtr` 1) q
  247 
  248     step :: Word -> Word8 -> Word
  249     step !s0 i0 = go 0 s0
  250       where
  251         go :: Word -> Word -> Word
  252         go !acc 0 = acc
  253         go !acc s = go acc' s'
  254           where
  255             acc' = acc .|. m
  256             m    = Data.Array.Base.unsafeAt table ix
  257             ix   = GHC.Arr.unsafeIndex bounds (i0, j)
  258             s'   = s .&. Bits.complement (Bits.unsafeShiftL 1 j)
  259             j    = Bits.countTrailingZeros s
  260 
  261         bounds :: ((Word8, Int), (Word8, Int))
  262         bounds = ((0, 0), (255, n - 1))
  263 
  264         table :: UArray (Word8, Int) Word
  265         table =
  266             Array.listArray bounds
  267                 [ f i j
  268                 | i <- [0..255]
  269                 , j <- [0..n-1]
  270                 ]
  271 
  272     loop'  0  _  _  = return False
  273     loop' !z !p !q
  274         | p == q    = return (ds .&. z /= 0)
  275         | otherwise = do
  276             x <- peek p
  277             loop' (step' z x) (p `plusPtr` 1) q
  278 
  279     step' :: Integer -> Word8 -> Integer
  280     step' !s0 i0 = integerFoldl' (\acc j -> acc .|. table ! (i0, j)) 0 s0
  281       where
  282         table :: Array (Word8, Int) Integer
  283         table =
  284             Array.listArray ((0, 0), (255, n - 1))
  285                 [ g i j
  286                 | i <- [0..255]
  287                 , j <- [0..n-1]
  288                 ]
  289 
  290 hasBytes :: Regex Word8 -> ByteString -> Bool
  291 hasBytes (Regex n as f bs cs g ds) (ByteString.Internal.PS fp off len)
  292     | n <= sizeOfWord = do
  293         ByteString.Internal.accursedUnutterablePerformIO
  294             (withForeignPtr fp (\p ->
  295                 loop as (p `plusPtr` off) (p `plusPtr` (off+len)) ))
  296     | otherwise       = do
  297         ByteString.Internal.accursedUnutterablePerformIO
  298             (withForeignPtr fp (\p ->
  299                 loop' cs (p `plusPtr` off) (p `plusPtr` (off+len)) ))
  300   where
  301     loop !z !p !q
  302         | bs .&. z /= 0 = return True
  303         | p == q        = return False
  304         | otherwise     = do
  305             x <- peek p
  306             loop (step z x .|. as) (p `plusPtr` 1) q
  307 
  308     step :: Word -> Word8 -> Word
  309     step !s0 i0 = go 0 s0
  310       where
  311         go :: Word -> Word -> Word
  312         go !acc 0 = acc
  313         go !acc s = go acc' s'
  314           where
  315             acc' = acc .|. m
  316             m    = Data.Array.Base.unsafeAt table ix
  317             ix   = GHC.Arr.unsafeIndex bounds (i0, j)
  318             s'   = s .&. Bits.complement (Bits.unsafeShiftL 1 j)
  319             j    = Bits.countTrailingZeros s
  320 
  321         bounds :: ((Word8, Int), (Word8, Int))
  322         bounds = ((0, 0), (255, n - 1))
  323 
  324         table :: UArray (Word8, Int) Word
  325         table =
  326             Array.listArray bounds
  327                 [ f i j
  328                 | i <- [0..255]
  329                 , j <- [0..n-1]
  330                 ]
  331 
  332     loop' !z !p !q
  333         | ds .&. z /= 0 = return True
  334         | p == q        = return False
  335         | otherwise     = do
  336             x <- peek p
  337             loop' (step' z x .|. cs) (p `plusPtr` 1) q
  338 
  339     step' :: Integer -> Word8 -> Integer
  340     step' !s0 i0 = integerFoldl' (\acc j -> acc .|. table ! (i0, j)) 0 s0
  341       where
  342         table :: Array (Word8, Int) Integer
  343         table =
  344             Array.listArray ((0, 0), (255, n - 1))
  345                 [ g i j
  346                 | i <- [0..255]
  347                 , j <- [0..n-1]
  348                 ]