{-# LINE 1 "src/Hookup/OpenSSL.hsc" #-}
{-# Language CApiFFI #-}
{-|
Module      : Hookup.OpenSSL
Description : Hack into the internals of OpenSSL to add missing functionality
Copyright   : (c) Eric Mertens, 2016
License     : ISC
Maintainer  : emertens@gmail.com
-}






{-# LINE 17 "src/Hookup/OpenSSL.hsc" #-}

module Hookup.OpenSSL (installVerification, getPubKeyDer) where

import           Control.Monad (unless, when)
import           Foreign.C (CString(..), CSize(..), CUInt(..), CInt(..), withCStringLen)
import           Foreign.Marshal (with)
import           Foreign.Ptr (Ptr, castPtr, nullPtr)

import           Data.ByteString (ByteString)
import qualified Data.ByteString.Internal as B
import           OpenSSL.Session (SSLContext, SSLContext_, withContext)
import           OpenSSL.X509 (withX509Ptr, X509, X509_)

------------------------------------------------------------------------
-- Bindings to OpenSSL's hostname verification and public key encoding interfaces
------------------------------------------------------------------------

-- | Opaque stand-in for OpenSSL's @X509_VERIFY_PARAM@. It has no
-- constructors and is only used behind 'Ptr'. No CTYPE pragma is given;
-- in this module it only appears in @ccall@ imports.
data X509_VERIFY_PARAM_
-- | Opaque stand-in for OpenSSL's @X509_PUBKEY@. The CTYPE pragma names the
-- C type so that @capi@ imports can refer to it.
data {-# CTYPE "openssl/ssl.h" "X509_PUBKEY" #-} X509_PUBKEY_
-- | Opaque stand-in for OpenSSL's @X509@, with a CTYPE for @capi@ imports.
-- Pointers to HsOpenSSL's 'X509_' are 'castPtr'-ed to this type before
-- being passed to such imports.
data {-# CTYPE "openssl/ssl.h" "X509" #-} X509__

-- | Fetch the verification parameters attached to an SSL context.
-- Per OpenSSL's @get0@ naming, the result is presumably still owned by the
-- context and must not be freed by the caller.
--
-- C prototype: @X509_VERIFY_PARAM *SSL_CTX_get0_param(SSL_CTX *ctx);@
foreign import ccall unsafe "SSL_CTX_get0_param"
  sslGet0Param
    :: Ptr SSLContext_             -- ^ ctx
    -> IO (Ptr X509_VERIFY_PARAM_)

-- | Set the flags that control how certificate names are matched against
-- the expected hostname.
--
-- C prototype:
-- @void X509_VERIFY_PARAM_set_hostflags(X509_VERIFY_PARAM *param, unsigned int flags);@
foreign import ccall unsafe "X509_VERIFY_PARAM_set_hostflags"
  x509VerifyParamSetHostflags
    :: Ptr X509_VERIFY_PARAM_ -- ^ param
    -> CUInt                  -- ^ flags (bitmask of @X509_CHECK_FLAG_*@ values)
    -> IO ()

-- | Set the expected peer hostname used during certificate verification.
--
-- C prototype:
-- @int X509_VERIFY_PARAM_set1_host(X509_VERIFY_PARAM *param, const char *name, size_t namelen);@
foreign import ccall unsafe "X509_VERIFY_PARAM_set1_host"
  x509VerifyParamSet1Host
    :: Ptr X509_VERIFY_PARAM_ -- ^ param
    -> CString                -- ^ name
    -> CSize                  -- ^ namelen
    -> IO CInt                -- ^ 1 success, 0 failure

-- | Get the public key structure embedded in a certificate.
--
-- Imported with @capi@ rather than @ccall@, so GHC generates a C wrapper
-- compiled against the declaration in @openssl/x509.h@.
--
-- C prototype: @X509_PUBKEY *X509_get_X509_PUBKEY(X509 *x);@
foreign import capi unsafe "openssl/x509.h X509_get_X509_PUBKEY"
  x509getX509Pubkey
    :: Ptr X509__             -- ^ x
    -> IO (Ptr X509_PUBKEY_)

-- | DER-encode a public key structure.
--
-- Per OpenSSL's @i2d@ conventions:
--
-- * When @ppout@ is null, only the encoded length is returned.
-- * Otherwise the encoding is written to @*ppout@, which is advanced past it.
-- * A negative result signals an error.
--
-- C prototype: @int i2d_X509_PUBKEY(X509_PUBKEY *p, unsigned char **ppout);@
foreign import ccall unsafe "i2d_X509_PUBKEY"
  i2dX509Pubkey
    :: Ptr X509_PUBKEY_ -- ^ p
    -> Ptr CString      -- ^ ppout
    -> IO CInt

-- | Serialize a certificate's public key to DER (via @i2d_X509_PUBKEY@).
--
-- This uses OpenSSL's two-pass @i2d@ protocol. The first call, with a null
-- output pointer, only reports the encoded length. The second call writes
-- that many bytes into a freshly allocated 'ByteString'.
--
-- Throws an 'IOError' (via 'fail') in three cases:
--
-- * the certificate yields no public key structure;
-- * OpenSSL reports an encoding error;
-- * the second pass writes a different number of bytes than announced.
getPubKeyDer :: X509 -> IO ByteString
getPubKeyDer x509 =
  withX509Ptr x509 $ \x509ptr ->
  do pubkey <- x509getX509Pubkey (castPtr x509ptr)
     when (pubkey == nullPtr) (fail "Unable to get certificate public key")
     len <- fromIntegral <$> i2dX509Pubkey pubkey nullPtr
     -- i2d functions return a negative value on error; without this check
     -- B.create would fail with an opaque "negative size" error instead.
     when (len < 0) (fail "Unable to encode certificate public key")
     B.create len $ \bsPtr ->
       -- i2d writes through *ptrPtr (and advances it), so hand it a
       -- pointer-to-pointer whose target is the start of our buffer.
       with (castPtr bsPtr) $ \ptrPtr ->
         do written <- i2dX509Pubkey pubkey ptrPtr
            -- Never return a buffer that was only partially filled.
            unless (fromIntegral written == len)
              (fail "Unable to encode certificate public key")


-- | Add hostname checking to the certificate verification step.
-- Partial wildcards matching is disabled.
--
-- Throws an 'IOError' (via 'fail') if the hostname is empty, or if
-- OpenSSL refuses to store it.
installVerification :: SSLContext -> String {- ^ hostname -} -> IO ()
installVerification ctx host =
  do -- Reject an empty name up front, for two reasons:
     --
     -- 1. OpenSSL reads a zero length as "name is NUL-terminated" and would
     --    strlen() the buffer, which withCStringLen does not NUL-terminate.
     -- 2. An empty name clears the host list, which would silently switch
     --    hostname checking off.
     when (null host) (fail "Unable to set verification host: empty hostname")
     withContext ctx $ \ctxPtr ->
       withCStringLen host $ \(ptr, len) ->
         do param <- sslGet0Param ctxPtr
            -- 4 is X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS, inlined by hsc2hs.
            x509VerifyParamSetHostflags param 4
            success <- x509VerifyParamSet1Host param ptr (fromIntegral len)
            unless (success == 1) (fail "Unable to set verification host")