summaryrefslogtreecommitdiff
path: root/src/Language/GraphQL/Execute/Subscribe.hs
blob: 5d8d29490e3d35a3f12ba86e76842b7a8f1ff992 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
{- This Source Code Form is subject to the terms of the Mozilla Public License,
   v. 2.0. If a copy of the MPL was not distributed with this file, You can
   obtain one at https://mozilla.org/MPL/2.0/. -}

{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedStrings #-}
module Language.GraphQL.Execute.Subscribe
    ( subscribe
    ) where

import Conduit
import Control.Arrow (left)
import Control.Monad.Catch (Exception(..), MonadCatch(..))
import Control.Monad.Trans.Reader (ReaderT(..), runReaderT)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import qualified Data.List.NonEmpty as NonEmpty
import Data.Sequence (Seq(..))
import qualified Language.GraphQL.AST as Full
import Language.GraphQL.Execute.Coerce
import Language.GraphQL.Execute.Execution
import Language.GraphQL.Execute.Internal
import qualified Language.GraphQL.Execute.OrderedMap as OrderedMap
import qualified Language.GraphQL.Execute.Transform as Transform
import Language.GraphQL.Error
    ( Error(..)
    , ResolverException
    , Response
    , ResponseEventStream
    , runCollectErrs
    )
import qualified Language.GraphQL.Type.Definition as Definition
import qualified Language.GraphQL.Type as Type
import qualified Language.GraphQL.Type.Out as Out
import Language.GraphQL.Type.Schema

-- | Entry point for executing a subscription operation.
--
-- First builds the source event stream with 'createSourceEventStream'. If
-- that yields a 'Left' error, the error is returned unchanged; otherwise the
-- source stream is converted into a response event stream by
-- 'mapSourceToResponseEvent', using the same types, subscription object
-- type, location and selections.
subscribe :: (MonadCatch m, Serialize a)
    => HashMap Full.Name (Type m)
    -> Out.ObjectType m
    -> Full.Location
    -> Seq (Transform.Selection m)
    -> m (Either Error (ResponseEventStream m a))
subscribe types' objectType objectLocation fields =
    createSourceEventStream types' objectType objectLocation fields
        -- 'traverse' over 'Either' only runs the mapping on a 'Right'.
        >>= traverse toResponseStream
  where
    toResponseStream =
        mapSourceToResponseEvent types' objectType objectLocation fields

-- | Turns a source event stream into a response event stream.
--
-- Every value emitted by the source stream is passed through
-- 'executeSubscriptionEvent' (via 'mapMC'), so the resulting conduit yields
-- one response per source event. The conduit is only constructed here and
-- wrapped with 'pure'; nothing is consumed from the source stream by this
-- function itself.
mapSourceToResponseEvent :: (MonadCatch m, Serialize a)
    => HashMap Full.Name (Type m)
    -> Out.ObjectType m
    -> Full.Location
    -> Seq (Transform.Selection m)
    -> Out.SourceEventStream m
    -> m (ResponseEventStream m a)
mapSourceToResponseEvent types' subscriptionType objectLocation fields sourceStream =
    pure (sourceStream .| mapMC executeEvent)
  where
    -- Executes the subscription selection set against a single source event.
    executeEvent event =
        executeSubscriptionEvent types' subscriptionType objectLocation fields event

-- | Creates the source event stream for a subscription operation.
--
-- The selections are grouped with 'collectFields'. When the grouping
-- contains exactly one field group and the subscription type maps that
-- field's name to an 'Out.EventStreamResolver', the field arguments are
-- coerced with 'coerceArgumentValues' and the resolver is run through
-- 'resolveFieldEventStream'.
--
-- The result is a 'Left' 'Error' when:
--
-- * argument coercion fails;
-- * 'resolveFieldEventStream' reports a failure (converted with
--   'singleError' at the field's location);
-- * the single selected field is not present in the subscription type's
--   field map or is not backed by an event stream resolver;
-- * the grouped field set does not consist of exactly one field group.
--
-- The first argument (the type map) is currently unused.
createSourceEventStream :: MonadCatch m
    => HashMap Full.Name (Type m)
    -> Out.ObjectType m
    -> Full.Location
    -> Seq (Transform.Selection m)
    -> m (Either Error (Out.SourceEventStream m))
createSourceEventStream _types subscriptionType objectLocation fields
    | [fieldGroup] <- OrderedMap.elems groupedFieldSet
    , Transform.Field _ fieldName arguments' _ errorLocation <- NonEmpty.head fieldGroup
    , Out.ObjectType _ _ _ fieldTypes <- subscriptionType
    -- Total lookup: an unknown field name must fall through to an 'Error'
    -- result instead of throwing from a partial 'HashMap.!'.
    , Just resolverT <- HashMap.lookup fieldName fieldTypes
    , Out.EventStreamResolver fieldDefinition _ resolver <- resolverT
    , Out.Field _ _fieldType argumentDefinitions <- fieldDefinition =
        case coerceArgumentValues argumentDefinitions arguments' of
            Left _ -> pure
                $ Left
                $ Error "Argument coercion failed." [errorLocation] []
            Right argumentValues -> left (singleError [errorLocation])
                <$> resolveFieldEventStream Type.Null argumentValues resolver
    -- Exactly one field was selected, but the guard above could not obtain
    -- an event stream resolver for it (unknown field or a value resolver).
    | [fieldGroup] <- OrderedMap.elems groupedFieldSet
    , Transform.Field _ _ _ _ errorLocation <- NonEmpty.head fieldGroup = pure
        $ Left
        $ Error "Subscription field has no event stream resolver." [errorLocation] []
    | otherwise = pure
        $ Left
        $ Error "Subscription contains more than one field." [objectLocation] []
  where
    groupedFieldSet = collectFields subscriptionType fields

-- | Runs a subscription resolver to obtain its source event stream.
--
-- The resolver is executed with a 'Type.Context' built from the given
-- parent value and coerced arguments. A successful run is wrapped in
-- 'Right'. If the resolver throws a 'ResolverException', the exception is
-- caught and its 'displayException' text is returned in 'Left'.
resolveFieldEventStream :: MonadCatch m
    => Type.Value
    -> Type.Subs
    -> Out.Subscribe m
    -> m (Either String (Out.SourceEventStream m))
resolveFieldEventStream result args resolver =
    (Right <$> runReaderT resolver resolverContext) `catch` onResolverException
  where
    -- The explicit signature pins the caught exception type to
    -- 'ResolverException'.
    onResolverException :: MonadCatch m
        => ResolverException
        -> m (Either String (Out.SourceEventStream m))
    onResolverException exception = pure $ Left $ displayException exception
    resolverContext = Type.Context
        { Type.arguments = Type.Arguments args
        , Type.values = result
        }

-- | Produces the response for a single source event.
--
-- The event value is used as the initial value when executing the
-- subscription's selection set with 'executeSelectionSet'; the result is
-- then run through 'runCollectErrs' to obtain the final 'Response'.
executeSubscriptionEvent :: (MonadCatch m, Serialize a)
    => HashMap Full.Name (Type m)
    -> Out.ObjectType m
    -> Full.Location
    -> Seq (Transform.Selection m)
    -> Definition.Value
    -> m (Response a)
executeSubscriptionEvent types' objectType objectLocation fields initialValue =
    runCollectErrs types' selectionResult
  where
    selectionResult =
        executeSelectionSet initialValue objectType objectLocation fields