diff options
Diffstat (limited to 'src/Language/GraphQL/Validate/Rules.hs')
| -rw-r--r-- | src/Language/GraphQL/Validate/Rules.hs | 41 |
1 files changed, 30 insertions, 11 deletions
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index a6bc43b..46a14b7 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -1502,15 +1502,6 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False hasNonNullVariableDefaultValue Nothing = False hasNonNullVariableDefaultValue _ = True - unwrapInType (In.NonNullScalarType nonNullType) = - Just $ In.NamedScalarType nonNullType - unwrapInType (In.NonNullEnumType nonNullType) = - Just $ In.NamedEnumType nonNullType - unwrapInType (In.NonNullInputObjectType nonNullType) = - Just $ In.NamedInputObjectType nonNullType - unwrapInType (In.NonNullListType nonNullType) = - Just $ In.ListType nonNullType - unwrapInType _ = Nothing makeError variableDefinition expectedType = let Full.VariableDefinition variableName variableType _ location' = variableDefinition @@ -1527,6 +1518,17 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case , locations = [location'] } +unwrapInType :: In.Type -> Maybe In.Type +unwrapInType (In.NonNullScalarType nonNullType) = + Just $ In.NamedScalarType nonNullType +unwrapInType (In.NonNullEnumType nonNullType) = + Just $ In.NamedEnumType nonNullType +unwrapInType (In.NonNullInputObjectType nonNullType) = + Just $ In.NamedInputObjectType nonNullType +unwrapInType (In.NonNullListType nonNullType) = + Just $ In.ListType nonNullType +unwrapInType _ = Nothing + -- | Literal values must be compatible with the type expected in the position -- they are found as per the coercion rules. -- @@ -1586,8 +1588,8 @@ valuesOfCorrectTypeRule = ValueRule go constGo | In.InputObjectType{} <- objectType , Full.ConstObject{} <- node = mempty check (In.ListBaseType listType) constValue@Full.Node{ .. } - -- Skip, lists are checked recursively by the validation traverser. - | Full.ConstList{} <- node = mempty + | Full.ConstList values <- node = + foldMap (checkNull listType) values | otherwise = check listType constValue check inputType Full.Node{ .. } = pure $ Error { message = concat @@ -1599,3 +1601,20 @@ valuesOfCorrectTypeRule = ValueRule go constGo ] , locations = [location] } + checkNull inputType constValue = + let checkResult = check inputType constValue + in case null checkResult of + True + | Just unwrappedType <- unwrapInType inputType + , Full.Node{ node = Full.ConstNull, .. } <- constValue -> + pure $ Error + { message = concat + [ "List of non-null values of type \"" + , show unwrappedType + , "\" cannot contain null values." + ] + , locations = [location] + } + | otherwise -> mempty + _ -> checkResult + |
