summaryrefslogtreecommitdiff
path: root/src/Language
diff options
context:
space:
mode:
authorEugen Wissner <belka@caraus.de>2020-10-04 18:51:21 +0200
committerEugen Wissner <belka@caraus.de>2020-10-05 14:51:21 +0200
commita91bc7f2d218ea2df308d3968587b60351625150 (patch)
tree3c3170437b0c903e2c63540c028c1aaa4ff35c17 /src/Language
parentd5f518fe827d3d279d6c37740820f296689539e4 (diff)
downloadgraphql-a91bc7f2d218ea2df308d3968587b60351625150.tar.gz
Validate required input fields
Diffstat (limited to 'src/Language')
-rw-r--r--src/Language/GraphQL/AST/Document.hs14
-rw-r--r--src/Language/GraphQL/AST/Encoder.hs11
-rw-r--r--src/Language/GraphQL/AST/Parser.hs6
-rw-r--r--src/Language/GraphQL/Execute/Transform.hs12
-rw-r--r--src/Language/GraphQL/Validate.hs47
-rw-r--r--src/Language/GraphQL/Validate/Rules.hs65
-rw-r--r--src/Language/GraphQL/Validate/Validation.hs2
7 files changed, 110 insertions, 47 deletions
diff --git a/src/Language/GraphQL/AST/Document.hs b/src/Language/GraphQL/AST/Document.hs
index 489a242..b30271c 100644
--- a/src/Language/GraphQL/AST/Document.hs
+++ b/src/Language/GraphQL/AST/Document.hs
@@ -1,5 +1,7 @@
+{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Safe #-}
-- | This module defines an abstract syntax tree for the @GraphQL@ language. It
@@ -75,10 +77,13 @@ instance Ord Location where
-- | Contains some tree node with a location.
data Node a = Node
- { value :: a
+ { node :: a
, location :: Location
} deriving (Eq, Show)
+instance Functor Node where
+ fmap f Node{..} = Node (f node) location
+
-- ** Document
-- | GraphQL document.
@@ -241,8 +246,11 @@ data ConstValue
-- | Key-value pair.
--
-- A list of 'ObjectField's represents a GraphQL object type.
-data ObjectField a = ObjectField Name a Location
- deriving (Eq, Show)
+data ObjectField a = ObjectField
+ { name :: Name
+ , value :: Node a
+ , location :: Location
+ } deriving (Eq, Show)
-- ** Variables
diff --git a/src/Language/GraphQL/AST/Encoder.hs b/src/Language/GraphQL/AST/Encoder.hs
index 51a801e..9b3eea3 100644
--- a/src/Language/GraphQL/AST/Encoder.hs
+++ b/src/Language/GraphQL/AST/Encoder.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Safe #-}
-- | This module defines a minifier and a printer for the @GraphQL@ language.
@@ -101,7 +102,7 @@ variableDefinition formatter variableDefinition' =
in variable variableName
<> eitherFormat formatter ": " ":"
<> type' variableType
- <> maybe mempty (defaultValue formatter) (Full.value <$> defaultValue')
+ <> maybe mempty (defaultValue formatter) (Full.node <$> defaultValue')
defaultValue :: Formatter -> Full.ConstValue -> Lazy.Text
defaultValue formatter val
@@ -164,7 +165,7 @@ argument :: Formatter -> Full.Argument -> Lazy.Text
argument formatter (Full.Argument name value' _)
= Lazy.Text.fromStrict name
<> colon formatter
- <> value formatter (Full.value value')
+ <> value formatter (Full.node value')
-- * Fragments
@@ -222,8 +223,8 @@ fromConstValue (Full.ConstEnum x) = Full.Enum x
fromConstValue (Full.ConstList x) = Full.List $ fromConstValue <$> x
fromConstValue (Full.ConstObject x) = Full.Object $ fromConstObjectField <$> x
where
- fromConstObjectField (Full.ObjectField key value' location) =
- Full.ObjectField key (fromConstValue value') location
+ fromConstObjectField Full.ObjectField{value = value', ..} =
+ Full.ObjectField name (fromConstValue <$> value') location
booleanValue :: Bool -> Lazy.Text
booleanValue True = "true"
@@ -292,7 +293,7 @@ objectValue formatter = intercalate $ objectField formatter
. fmap f
objectField :: Formatter -> Full.ObjectField Full.Value -> Lazy.Text
-objectField formatter (Full.ObjectField name value' _) =
+objectField formatter (Full.ObjectField name (Full.Node value' _) _) =
Lazy.Text.fromStrict name <> colon formatter <> value formatter value'
-- | Converts a 'Type' a type into a string.
diff --git a/src/Language/GraphQL/AST/Parser.hs b/src/Language/GraphQL/AST/Parser.hs
index 2695e6f..46c8fa3 100644
--- a/src/Language/GraphQL/AST/Parser.hs
+++ b/src/Language/GraphQL/AST/Parser.hs
@@ -461,7 +461,7 @@ value = Full.Variable <$> variable
<|> Full.String <$> stringValue
<|> Full.Enum <$> try enumValue
<|> Full.List <$> brackets (some value)
- <|> Full.Object <$> braces (some $ objectField value)
+ <|> Full.Object <$> braces (some $ objectField $ valueNode value)
<?> "Value"
constValue :: Parser Full.ConstValue
@@ -472,7 +472,7 @@ constValue = Full.ConstFloat <$> try float
<|> Full.ConstString <$> stringValue
<|> Full.ConstEnum <$> try enumValue
<|> Full.ConstList <$> brackets (some constValue)
- <|> Full.ConstObject <$> braces (some $ objectField constValue)
+ <|> Full.ConstObject <$> braces (some $ objectField $ valueNode constValue)
<?> "Value"
booleanValue :: Parser Bool
@@ -493,7 +493,7 @@ stringValue = blockString <|> string <?> "StringValue"
nullValue :: Parser Text
nullValue = symbol "null" <?> "NullValue"
-objectField :: Parser a -> Parser (Full.ObjectField a)
+objectField :: forall a. Parser (Full.Node a) -> Parser (Full.ObjectField a)
objectField valueParser = label "ObjectField" $ do
location <- getLocation
fieldName <- name
diff --git a/src/Language/GraphQL/Execute/Transform.hs b/src/Language/GraphQL/Execute/Transform.hs
index 80e7a83..a8a2ae2 100644
--- a/src/Language/GraphQL/Execute/Transform.hs
+++ b/src/Language/GraphQL/Execute/Transform.hs
@@ -153,7 +153,7 @@ coerceVariableValues types operationDefinition variableValues =
forEach variableDefinition coercedValues = do
let Full.VariableDefinition variableName variableTypeName defaultValue _ =
variableDefinition
- let defaultValue' = constValue . Full.value <$> defaultValue
+ let defaultValue' = constValue . Full.node <$> defaultValue
variableType <- lookupInputType variableTypeName types
Coerce.matchFieldValues
@@ -178,7 +178,8 @@ constValue (Full.ConstList l) = Type.List $ constValue <$> l
constValue (Full.ConstObject o) =
Type.Object $ HashMap.fromList $ constObjectField <$> o
where
- constObjectField (Full.ObjectField key value' _) = (key, constValue value')
+ constObjectField Full.ObjectField{value = value', ..} =
+ (name, constValue $ Full.node value')
-- | Rewrites the original syntax tree into an intermediate representation used
-- for query execution.
@@ -384,7 +385,8 @@ value (Full.List list) = Type.List <$> traverse value list
value (Full.Object object) =
Type.Object . HashMap.fromList <$> traverse objectField object
where
- objectField (Full.ObjectField name value' _) = (name,) <$> value value'
+ objectField Full.ObjectField{value = value', ..} =
+ (name,) <$> value (Full.node value')
input :: forall m. Full.Value -> State (Replacement m) (Maybe Input)
input (Full.Variable name) =
@@ -400,8 +402,8 @@ input (Full.Object object) = do
objectFields <- foldM objectField HashMap.empty object
pure $ pure $ Object objectFields
where
- objectField resultMap (Full.ObjectField name value' _) =
- inputField resultMap name value'
+ objectField resultMap Full.ObjectField{value = value', ..} =
+ inputField resultMap name $ Full.node value'
inputField :: forall m
. HashMap Full.Name Input
diff --git a/src/Language/GraphQL/Validate.hs b/src/Language/GraphQL/Validate.hs
index 5acb26a..d904e8c 100644
--- a/src/Language/GraphQL/Validate.hs
+++ b/src/Language/GraphQL/Validate.hs
@@ -3,6 +3,7 @@
obtain one at https://mozilla.org/MPL/2.0/. -}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -349,25 +350,25 @@ variableDefinition :: forall m
variableDefinition context rule (Full.VariableDefinition _ typeName value' _)
| Just defaultValue' <- value'
, variableType <- lookupInputType typeName $ Validation.types context =
- constValue rule variableType $ Full.value defaultValue'
+ constValue rule variableType defaultValue'
variableDefinition _ _ _ = mempty
constValue :: forall m
. Validation.Rule m
-> Maybe In.Type
- -> Full.ConstValue
+ -> Full.Node Full.ConstValue
-> Seq (Validation.RuleT m)
constValue (Validation.ValueRule _ rule) valueType = go valueType
where
- go inputObjectType value'@(Full.ConstObject fields)
+ go inputObjectType value'@(Full.Node (Full.ConstObject fields) _)
= foldMap (forEach inputObjectType) (Seq.fromList fields)
|> rule inputObjectType value'
- go listType value'@(Full.ConstList values)
- = foldMap (go $ valueTypeFromList listType) (Seq.fromList values)
+ go listType value'@(Full.Node (Full.ConstList values) location')
+ = embedListLocation go listType values location'
|> rule listType value'
go anotherValue value' = pure $ rule anotherValue value'
- forEach inputObjectType (Full.ObjectField fieldName fieldValue _) =
- go (valueTypeByName fieldName inputObjectType) fieldValue
+ forEach inputObjectType Full.ObjectField{value = value', ..} =
+ go (valueTypeByName name inputObjectType) value'
constValue _ _ = const mempty
inputFieldType :: In.InputField -> In.Type
@@ -379,10 +380,6 @@ valueTypeByName fieldName (Just( In.InputObjectBaseType inputObjectType)) =
in inputFieldType <$> HashMap.lookup fieldName fieldTypes
valueTypeByName _ _ = Nothing
-valueTypeFromList :: Maybe In.Type -> Maybe In.Type
-valueTypeFromList (Just (In.ListBaseType listType)) = Just listType
-valueTypeFromList _ = Nothing
-
fragmentDefinition :: forall m
. Validation.Rule m
-> Validation m
@@ -468,26 +465,40 @@ argument :: forall m
-> Full.Argument
-> Seq (Validation.RuleT m)
argument rule argumentType (Full.Argument _ value' _) =
- value rule (valueType <$> argumentType) $ Full.value value'
+ value rule (valueType <$> argumentType) value'
where
valueType (In.Argument _ valueType' _) = valueType'
+-- valueTypeFromList :: Maybe In.Type -> Maybe In.Type
+embedListLocation :: forall a m
+ . (Maybe In.Type -> Full.Node a -> Seq m)
+ -> Maybe In.Type
+ -> [a]
+ -> Full.Location
+ -> Seq m
+embedListLocation go listType values location'
+ = foldMap (go $ valueTypeFromList listType)
+ $ flip Full.Node location' <$> Seq.fromList values
+ where
+ valueTypeFromList (Just (In.ListBaseType baseType)) = Just baseType
+ valueTypeFromList _ = Nothing
+
value :: forall m
. Validation.Rule m
-> Maybe In.Type
- -> Full.Value
+ -> Full.Node Full.Value
-> Seq (Validation.RuleT m)
value (Validation.ValueRule rule _) valueType = go valueType
where
- go inputObjectType value'@(Full.Object fields)
+ go inputObjectType value'@(Full.Node (Full.Object fields) _)
= foldMap (forEach inputObjectType) (Seq.fromList fields)
|> rule inputObjectType value'
- go listType value'@(Full.List values)
- = foldMap (go $ valueTypeFromList listType) (Seq.fromList values)
+ go listType value'@(Full.Node (Full.List values) location')
+ = embedListLocation go listType values location'
|> rule listType value'
go anotherValue value' = pure $ rule anotherValue value'
- forEach inputObjectType (Full.ObjectField fieldName fieldValue _) =
- go (valueTypeByName fieldName inputObjectType) fieldValue
+ forEach inputObjectType Full.ObjectField{value = value', ..} =
+ go (valueTypeByName name inputObjectType) value'
value _ _ = const mempty
inlineFragment :: forall m
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs
index a5754c6..11f4482 100644
--- a/src/Language/GraphQL/Validate/Rules.hs
+++ b/src/Language/GraphQL/Validate/Rules.hs
@@ -24,6 +24,7 @@ module Language.GraphQL.Validate.Rules
, noUndefinedVariablesRule
, noUnusedFragmentsRule
, noUnusedVariablesRule
+ , providedRequiredInputFieldsRule
, providedRequiredArgumentsRule
, scalarLeafsRule
, singleFieldSubscriptionsRule
@@ -91,6 +92,7 @@ specifiedRules =
-- Values
, knownInputFieldNamesRule
, uniqueInputFieldNamesRule
+ , providedRequiredInputFieldsRule
-- Directives.
, knownDirectiveNamesRule
, directivesInValidLocationsRule
@@ -646,7 +648,7 @@ variableUsageDifference difference errorMessage = OperationDefinitionRule $ \cas
findDirectiveVariables (Directive _ arguments _) = mapArguments arguments
mapArguments = Seq.fromList . mapMaybe findArgumentVariables
mapDirectives = foldMap findDirectiveVariables
- findArgumentVariables (Argument _ Node{ value = Variable value', ..} _) =
+ findArgumentVariables (Argument _ Node{ node = Variable value', ..} _) =
Just (value', [location])
findArgumentVariables _ = Nothing
makeError operationName (variableName, locations') = Error
@@ -680,12 +682,12 @@ uniqueInputFieldNamesRule :: forall m. Rule m
uniqueInputFieldNamesRule =
ValueRule (const $ lift . go) (const $ lift . constGo)
where
- go (Object fields) = filterFieldDuplicates fields
+ go (Node (Object fields) _) = filterFieldDuplicates fields
go _ = mempty
filterFieldDuplicates fields =
filterDuplicates getFieldName "input field" fields
getFieldName (ObjectField fieldName _ location') = (fieldName, location')
- constGo (ConstObject fields) = filterFieldDuplicates fields
+ constGo (Node (ConstObject fields) _) = filterFieldDuplicates fields
constGo _ = mempty
-- | The target field of a field selection must be defined on the scoped type of
@@ -848,11 +850,11 @@ knownDirectiveNamesRule = DirectivesRule $ const $ \directives' -> do
knownInputFieldNamesRule :: Rule m
knownInputFieldNamesRule = ValueRule go constGo
where
- go (Just valueType) (Object inputFields)
+ go (Just valueType) (Node (Object inputFields) _)
| In.InputObjectBaseType objectType <- valueType =
lift $ Seq.fromList $ mapMaybe (forEach objectType) inputFields
go _ _ = lift mempty
- constGo (Just valueType) (ConstObject inputFields)
+ constGo (Just valueType) (Node (ConstObject inputFields) _)
| In.InputObjectBaseType objectType <- valueType =
lift $ Seq.fromList $ mapMaybe (forEach objectType) inputFields
constGo _ _ = lift mempty
@@ -915,13 +917,6 @@ providedRequiredArgumentsRule = ArgumentsRule fieldRule directiveRule
let forEach = go (directiveMessage directiveName) arguments location'
in lift $ HashMap.foldrWithKey forEach Seq.empty definitions
_ -> lift mempty
- inputTypeName (In.ScalarBaseType (Definition.ScalarType typeName _)) =
- typeName
- inputTypeName (In.EnumBaseType (Definition.EnumType typeName _ _)) =
- typeName
- inputTypeName (In.InputObjectBaseType (In.InputObjectType typeName _ _)) =
- typeName
- inputTypeName (In.ListBaseType listType) = inputTypeName listType
go makeMessage arguments location' argumentName argumentType errors
| In.Argument _ type' optionalValue <- argumentType
, In.isNonNullType type'
@@ -956,3 +951,49 @@ providedRequiredArgumentsRule = ArgumentsRule fieldRule directiveRule
, Text.unpack typeName
, "\" is required, but it was not provided."
]
+
+inputTypeName :: In.Type -> Text
+inputTypeName (In.ScalarBaseType (Definition.ScalarType typeName _)) = typeName
+inputTypeName (In.EnumBaseType (Definition.EnumType typeName _ _)) = typeName
+inputTypeName (In.InputObjectBaseType (In.InputObjectType typeName _ _)) =
+ typeName
+inputTypeName (In.ListBaseType listType) = inputTypeName listType
+
+-- | Input object fields may be required. Much like a field may have required
+-- arguments, an input object may have required fields. An input field is
+-- required if it has a non‐null type and does not have a default value.
+-- Otherwise, the input object field is optional.
+providedRequiredInputFieldsRule :: Rule m
+providedRequiredInputFieldsRule = ValueRule go constGo
+ where
+ go (Just valueType) (Node (Object inputFields) location')
+ | In.InputObjectBaseType objectType <- valueType
+ , In.InputObjectType objectTypeName _ fieldDefinitions <- objectType
+ = lift
+ $ Seq.fromList
+ $ HashMap.elems
+ $ flip HashMap.mapMaybeWithKey fieldDefinitions
+ $ forEach inputFields objectTypeName location'
+ go _ _ = lift mempty
+ constGo _ _ = lift mempty
+ forEach inputFields typeName location' definitionName fieldDefinition
+ | In.InputField _ inputType optionalValue <- fieldDefinition
+ , In.isNonNullType inputType
+ , isNothing optionalValue
+ , isNothingOrNull $ find (lookupField definitionName) inputFields =
+ Just $ makeError definitionName typeName location'
+ | otherwise = Nothing
+ isNothingOrNull (Just (ObjectField _ (Node Null _) _)) = True
+ isNothingOrNull x = isNothing x
+ lookupField needle (ObjectField fieldName _ _) = needle == fieldName
+ makeError fieldName typeName location' = Error
+ { message = errorMessage fieldName typeName
+ , locations = [location']
+ }
+ errorMessage fieldName typeName = concat
+ [ "Input field \""
+ , Text.unpack fieldName
+ , "\" of type \""
+ , Text.unpack typeName
+ , "\" is required, but it was not provided."
+ ]
diff --git a/src/Language/GraphQL/Validate/Validation.hs b/src/Language/GraphQL/Validate/Validation.hs
index 32a454e..7ffab10 100644
--- a/src/Language/GraphQL/Validate/Validation.hs
+++ b/src/Language/GraphQL/Validate/Validation.hs
@@ -48,7 +48,7 @@ data Rule m
| ArgumentsRule (Maybe (Out.Type m) -> Field -> RuleT m) (Directive -> RuleT m)
| DirectivesRule (DirectiveLocation -> [Directive] -> RuleT m)
| VariablesRule ([VariableDefinition] -> RuleT m)
- | ValueRule (Maybe In.Type -> Value -> RuleT m) (Maybe In.Type -> ConstValue -> RuleT m)
+ | ValueRule (Maybe In.Type -> Node Value -> RuleT m) (Maybe In.Type -> Node ConstValue -> RuleT m)
-- | Monad transformer used by the rules.
type RuleT m = ReaderT (Validation m) Seq Error