summaryrefslogtreecommitdiff
path: root/src/Language/GraphQL/Validate/Rules.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Language/GraphQL/Validate/Rules.hs')
-rw-r--r--src/Language/GraphQL/Validate/Rules.hs41
1 files changed, 30 insertions, 11 deletions
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs
index a6bc43b..46a14b7 100644
--- a/src/Language/GraphQL/Validate/Rules.hs
+++ b/src/Language/GraphQL/Validate/Rules.hs
@@ -1502,15 +1502,6 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case
hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False
hasNonNullVariableDefaultValue Nothing = False
hasNonNullVariableDefaultValue _ = True
- unwrapInType (In.NonNullScalarType nonNullType) =
- Just $ In.NamedScalarType nonNullType
- unwrapInType (In.NonNullEnumType nonNullType) =
- Just $ In.NamedEnumType nonNullType
- unwrapInType (In.NonNullInputObjectType nonNullType) =
- Just $ In.NamedInputObjectType nonNullType
- unwrapInType (In.NonNullListType nonNullType) =
- Just $ In.ListType nonNullType
- unwrapInType _ = Nothing
makeError variableDefinition expectedType =
let Full.VariableDefinition variableName variableType _ location' =
variableDefinition
@@ -1527,6 +1518,17 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case
, locations = [location']
}
+unwrapInType :: In.Type -> Maybe In.Type
+unwrapInType (In.NonNullScalarType nonNullType) =
+ Just $ In.NamedScalarType nonNullType
+unwrapInType (In.NonNullEnumType nonNullType) =
+ Just $ In.NamedEnumType nonNullType
+unwrapInType (In.NonNullInputObjectType nonNullType) =
+ Just $ In.NamedInputObjectType nonNullType
+unwrapInType (In.NonNullListType nonNullType) =
+ Just $ In.ListType nonNullType
+unwrapInType _ = Nothing
+
-- | Literal values must be compatible with the type expected in the position
-- they are found as per the coercion rules.
--
@@ -1586,8 +1588,8 @@ valuesOfCorrectTypeRule = ValueRule go constGo
| In.InputObjectType{} <- objectType
, Full.ConstObject{} <- node = mempty
check (In.ListBaseType listType) constValue@Full.Node{ .. }
- -- Skip, lists are checked recursively by the validation traverser.
- | Full.ConstList{} <- node = mempty
+ | Full.ConstList values <- node =
+ foldMap (checkNull listType) values
| otherwise = check listType constValue
check inputType Full.Node{ .. } = pure $ Error
{ message = concat
@@ -1599,3 +1601,20 @@ valuesOfCorrectTypeRule = ValueRule go constGo
]
, locations = [location]
}
+ checkNull inputType constValue =
+ let checkResult = check inputType constValue
+ in case null checkResult of
+ True
+ | Just unwrappedType <- unwrapInType inputType
+ , Full.Node{ node = Full.ConstNull, .. } <- constValue ->
+ pure $ Error
+ { message = concat
+ [ "List of non-null values of type \""
+ , show unwrappedType
+ , "\" cannot contain null values."
+ ]
+ , locations = [location]
+ }
+ | otherwise -> mempty
+ _ -> checkResult
+