module Vulkan.Utils.ShaderQQ.Shaderc
  ( hlsl
  , comp
  , frag
  , geom
  , tesc
  , tese
  , vert
  , ShadercError
  , ShadercWarning
  , compileShaderQ
  , compileShader
  , processShadercMessages
  ) where

import           Control.Monad                  ( void )
import           Control.Monad.IO.Class
import           Data.ByteString                ( ByteString )
import qualified Data.ByteString               as BS
import qualified Data.ByteString.Lazy.Char8    as BSL
import           Data.FileEmbed
import           Data.Foldable                  ( asum )
import           Data.List.Extra
import           Language.Haskell.TH
import           Language.Haskell.TH.Quote
import           System.Exit
import           System.IO.Temp
import           System.Process.Typed
import           Text.ParserCombinators.ReadP
import           Vulkan.Utils.ShaderQQ.Interpolate

-- $setup
-- >>> :set -XQuasiQuotes

-- | 'hlsl' is a QuasiQuoter which produces HLSL source code with a @#line@
-- directive inserted so that error locations point to the correct location in
-- the Haskell source file. It also permits basic string interpolation.
--
-- - Interpolated variables are prefixed with @$@
-- - They can optionally be surrounded with braces like @${foo}@
-- - Interpolated variables are converted to strings with 'show'
-- - To escape a @$@ use @\\$@
--
-- It is intended to be used in concert with 'compileShaderQ' like so
--
-- @
-- myConstant = 3.141 -- Note that this will have to be in a different module
-- myFragmentShader = $(compileShaderQ "frag" [hlsl|
--   static const float myConstant = ${myConstant};
--   float main (){
--     return myConstant;
--   }
-- |])
-- @
--
-- An explicit example (@<interactive>@ is from doctest):
--
-- >>> let foo = 450 :: Int in [hlsl|const float foo = $foo|]
-- "#line 31 \"<interactive>\"\nconst float foo = 450"
--
-- Note that the line numbers will be thrown off if any of the interpolated
-- variables contain newlines.
--
-- Only the expression context is supported; using the quoter as a pattern,
-- type or declaration is an error.
hlsl :: QuasiQuoter
hlsl = (badQQ "hlsl")
  { quoteExp = \s -> do
      loc <- location
      -- The directive is inserted here, before interpolation. When the result
      -- is passed to `compileShaderQ` another directive is prepended, but as
      -- it comes before this one this one still takes effect for the code.
      interpExp (insertLineDirective s loc)
  }

-- | QuasiQuoter for creating a compute shader.
--
-- Equivalent to calling @$(compileShaderQ "comp" [hlsl|...|])@ without
-- interpolation support.
comp :: QuasiQuoter
comp = shaderQQ "comp"

-- | QuasiQuoter for creating a fragment shader.
--
-- Equivalent to calling @$(compileShaderQ "frag" [hlsl|...|])@ without
-- interpolation support.
frag :: QuasiQuoter
frag = shaderQQ "frag"

-- | QuasiQuoter for creating a geometry shader.
--
-- Equivalent to calling @$(compileShaderQ "geom" [hlsl|...|])@ without
-- interpolation support.
geom :: QuasiQuoter
geom = shaderQQ "geom"

-- | QuasiQuoter for creating a tessellation control shader.
--
-- Equivalent to calling @$(compileShaderQ "tesc" [hlsl|...|])@ without
-- interpolation support.
tesc :: QuasiQuoter
tesc = shaderQQ "tesc"

-- | QuasiQuoter for creating a tessellation evaluation shader.
--
-- Equivalent to calling @$(compileShaderQ "tese" [hlsl|...|])@ without
-- interpolation support.
tese :: QuasiQuoter
tese = shaderQQ "tese"

-- | QuasiQuoter for creating a vertex shader.
--
-- Equivalent to calling @$(compileShaderQ "vert" [hlsl|...|])@ without
-- interpolation support.
vert :: QuasiQuoter
vert = shaderQQ "vert"

-- | Build a QuasiQuoter which compiles its contents for the given shader
-- stage with 'compileShaderQ'.
--
-- The stage name doubles as the quoter name in the message produced when the
-- quoter is used anywhere other than in an expression.
shaderQQ :: String -> QuasiQuoter
shaderQQ stage = (badQQ stage) { quoteExp = compileShaderQ stage }

-- * Utilities

-- | Compile a HLSL shader to SPIR-V using glslc (from the shaderc project)
--
-- Messages are converted to GHC warnings or errors depending on compilation
-- success.
--
-- When compilation fails with at least one error message, the messages are
-- reported with 'reportError' and an empty bytestring is spliced in, so that
-- the rest of the module can still be checked. When it fails without any
-- message the splice itself fails.
compileShaderQ
  :: String
  -- ^ stage
  -> String
  -- ^ HLSL code
  -> Q Exp
  -- ^ Spir-V bytecode
compileShaderQ stage code = do
  loc                <- location
  (warnings, result) <- compileShader (Just loc) stage code
  case warnings of
    []    -> pure ()
    _some -> reportWarning $ prepare warnings

  spirv <- case result of
    Left []     -> fail "glslc failed with no errors"
    Left errors -> do
      reportError $ prepare errors
      pure mempty
    Right bytes -> pure bytes

  bsToExp spirv

 where
  -- A single message is reported as is; several messages are grouped under a
  -- "glslc:" heading with each one indented beneath it.
  prepare :: [String] -> String
  prepare [singleLine] = singleLine
  prepare multiline =
    intercalate "\n" $ "glslc:" : map (mappend "        ") multiline

-- | An error message, as produced by 'processShadercMessages'
type ShadercError = String
-- | A warning message, as produced by 'processShadercMessages'
type ShadercWarning = String

-- | Compile a HLSL shader to spir-v using glslc
--
-- The code is written to a file in a fresh temporary directory and compiled
-- with @glslc -fshader-stage=\<stage\> -x hlsl@. Everything glslc prints on
-- stdout and stderr is passed to 'processShadercMessages'.
--
-- If a source location is given, a @#line@ directive pointing at it is
-- inserted at the top of the code.
--
-- When glslc exits successfully the result is the warnings along with the
-- compiled bytecode, otherwise it is the warnings along with the errors.
compileShader
  :: MonadIO m
  => Maybe Loc
  -- ^ Source location
  -> String
  -- ^ stage
  -> String
  -- ^ HLSL code
  -> m ([ShadercWarning], Either [ShadercError] ByteString)
  -- ^ Spir-V bytecode with warnings or errors
compileShader loc stage code =
  liftIO $ withSystemTempDirectory "th-shader" $ \dir -> do
    let codeWithLineDirective = maybe code (insertLineDirective code) loc
        shader                = dir <> "/shader.hlsl"
        spirv                 = dir <> "/shader.spv"
    writeFile shader codeWithLineDirective

    (rc, out, err) <- readProcess $ proc
      "glslc"
      ["-fshader-stage=" <> stage, "-x", "hlsl", shader, "-o", spirv]
    let (warnings, errors) = processShadercMessages (out <> err)
    case rc of
      ExitSuccess -> do
        -- Read strictly, the file is deleted along with the directory
        bs <- BS.readFile spirv
        pure (warnings, Right bs)
      ExitFailure _rc -> pure (warnings, Left errors)

-- | Sort compiler output into warnings and errors.
--
-- The output is split into lines and each line is interpreted on its own.
-- Lines in an unrecognised format are treated as errors.
processShadercMessages :: BSL.ByteString -> ([ShadercWarning], [ShadercError])
processShadercMessages = foldMap parseMsg . lines . BSL.unpack

-- Interpret a single line of compiler output as a warning or an error.
--
-- Recognised messages are rewritten to the form @file:line: message@. The
-- formats are tried in order and the first one which matches is used; a line
-- which matches none of them is returned unchanged as an error.
--
-- >>> parseMsg "2 errors generated"
-- ([],[])
--
-- >>> parseMsg "blah"
-- ([],["blah"])
--
-- >>> parseMsg "foo:2: error: unknown var"
-- ([],["foo:2: unknown var"])
--
-- >>> parseMsg "foo:2: warning: unknown var"
-- (["foo:2: unknown var"],[])
--
-- >>> parseMsg "bar:2: error: 'a' : unknown variable"
-- ([],["bar:2: 'a' : unknown variable"])
--
-- >>> parseMsg "f:o: error: f:o:2: 'a' : unknown variable"
-- ([],["f:o:2: 'a' : unknown variable"])
--
-- >>> parseMsg "f:o: error: f:o:2: 'return' : type does not match, or is not convertible to, the function's return type"
-- ([],["f:o:2: 'return' : type does not match, or is not convertible to, the function's return type"])
--
-- >>> parseMsg "foo: foo(1): error at column 3, HLSL parsing failed."
-- ([],["foo:1: error at column 3, HLSL parsing failed."])
parseMsg :: String -> ([ShadercWarning], [ShadercError])
parseMsg = runParser $ foldl1
  (<++)
  [locatedMsg, relocatedMsg, hlslParseFailure, errorSummary, unknownMsg]
 where
  -- @file:line: (error|warning): message@
  locatedMsg = do
    f    <- filename
    line <- between colon colon number
    skipSpaces
    t   <- msgType
    msg <- manyTill get eof
    pure $ formatMsg t f line msg

  -- @file: (error|warning): file:line: message@
  relocatedMsg = do
    f <- filename
    colon *> skipSpaces
    t    <- msgType
    _    <- string f
    line <- between (char ':') (char ':') number
    skipSpaces
    msg <- manyTill get eof
    pure $ formatMsg t f line msg

  -- @file: file(line): message@, always an error
  hlslParseFailure = do
    f <- filename
    colon *> skipSpaces
    _    <- string f
    line <- between (char '(') (char ')') number
    colon *> skipSpaces
    msg <- manyTill get eof
    pure $ formatMsg asError f line msg

  -- @N errors generated@, which carries no message of its own
  -- NOTE(review): only this exact wording is dropped; a singular or
  -- differently punctuated summary falls through to 'unknownMsg' and is
  -- reported as an error. Confirm against the output of the glslc in use.
  errorSummary = do
    _ <- number
    skipSpaces
    _ <- string "errors generated"
    eof
    pure ([], [])

  -- Unknown format, pass the whole line on as an error
  unknownMsg = do
    msg <- manyTill get eof
    pure ([], [msg])

  formatMsg :: (String -> t) -> String -> Integer -> String -> t
  formatMsg t f line msg = t (f <> ":" <> show line <> ": " <> msg)

  -- Any non-empty run of characters; what follows it decides where it ends
  filename :: ReadP String
  filename = many1 get

  number :: ReadP Integer
  number = readS_to_P reads

  colon :: ReadP ()
  colon = void $ char ':'

  asError, asWarning :: String -> ([ShadercWarning], [ShadercError])
  asError x = ([], [x])
  asWarning x = ([x], [])

  -- @error:@ or @warning:@ along with any following spaces
  msgType :: ReadP (String -> ([ShadercWarning], [ShadercError]))
  msgType =
    asum [asError <$ string "error", asWarning <$ string "warning"]
      <* colon
      <* skipSpaces

-- Run a parser over the whole of the input.
--
-- The result is that of the first parse which consumes all of the input, or
-- 'mempty' when there is no such parse. An input which can be parsed in more
-- than one way still yields a result rather than being discarded.
runParser :: Monoid p => ReadP p -> String -> p
runParser p s = case [ r | (r, "") <- readP_to_S p s ] of
  r : _ -> r
  []    -> mempty

-- Insert a #line directive with the specified location at the top of the file
--
-- The directive names the line on which the location starts and the file it
-- is in, and is followed by the code on the next line.
insertLineDirective :: String -> Loc -> String
insertLineDirective code loc = lineDirective <> "\n" <> code
 where
  startLine = fst (loc_start loc)
  lineDirective =
    "#line " <> show startLine <> " \"" <> loc_filename loc <> "\""

----------------------------------------------------------------
-- Utils
----------------------------------------------------------------

-- | A QuasiQuoter with the given name which fails in every context.
--
-- It is meant as a base for record update: override the contexts which are
-- supported and the rest report that the quoter can't be used there.
badQQ :: String -> QuasiQuoter
badQQ name = QuasiQuoter { quoteExp  = bad "expression"
                         , quotePat  = bad "pattern"
                         , quoteType = bad "type"
                         , quoteDec  = bad "declaration"
                         }
 where
  -- Fail inside 'Q' so that misuse is reported as an ordinary compile error
  -- at the quote, the quoted text itself is of no interest.
  bad :: String -> String -> Q a
  bad context _ =
    fail $ "Can't use " <> name <> " quote in a " <> context <> " context"