summaryrefslogtreecommitdiff
path: root/src/Language/GraphQL/Validate/Rules.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Language/GraphQL/Validate/Rules.hs')
-rw-r--r--src/Language/GraphQL/Validate/Rules.hs32
1 files changed, 15 insertions, 17 deletions
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs
index 2d7adba..e60d39d 100644
--- a/src/Language/GraphQL/Validate/Rules.hs
+++ b/src/Language/GraphQL/Validate/Rules.hs
@@ -50,14 +50,15 @@ import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT)
import Control.Monad.Trans.State (StateT, evalStateT, gets, modify)
import Data.Bifunctor (first)
-import Data.Foldable (find, fold, foldl', toList)
+import Data.Foldable (Foldable(..), find)
import qualified Data.HashMap.Strict as HashMap
import Data.HashMap.Strict (HashMap)
import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet
-import Data.List (groupBy, sortBy, sortOn)
+import Data.List (sortBy)
import Data.Maybe (fromMaybe, isJust, isNothing, mapMaybe)
import Data.List.NonEmpty (NonEmpty(..))
+import qualified Data.List.NonEmpty as NonEmpty
import Data.Ord (comparing)
import Data.Sequence (Seq(..), (|>))
import qualified Data.Sequence as Seq
@@ -253,14 +254,16 @@ findDuplicates :: (Full.Definition -> [Full.Location] -> [Full.Location])
-> Full.Location
-> String
-> RuleT m
-findDuplicates filterByName thisLocation errorMessage = do
- ast' <- asks ast
- let locations' = foldr filterByName [] ast'
- if length locations' > 1 && head locations' == thisLocation
- then pure $ error' locations'
- else lift mempty
+findDuplicates filterByName thisLocation errorMessage =
+ asks ast >>= go . foldr filterByName []
where
- error' locations' = Error
+ go locations' =
+ case locations' of
+ headLocation : otherLocations -- length locations' > 1
+ | not $ null otherLocations
+ , headLocation == thisLocation -> pure $ makeError locations'
+ _ -> lift mempty
+ makeError locations' = Error
{ message = errorMessage
, locations = locations'
}
@@ -536,11 +539,6 @@ uniqueDirectiveNamesRule = DirectivesRule
extract (Full.Directive directiveName _ location') =
(directiveName, location')
-groupSorted :: forall a. (a -> Text) -> [a] -> [[a]]
-groupSorted getName = groupBy equalByName . sortOn getName
- where
- equalByName lhs rhs = getName lhs == getName rhs
-
filterDuplicates :: forall a
. (a -> (Text, Full.Location))
-> String
@@ -549,12 +547,12 @@ filterDuplicates :: forall a
filterDuplicates extract nodeType = Seq.fromList
. fmap makeError
. filter ((> 1) . length)
- . groupSorted getName
+ . NonEmpty.groupAllWith getName
where
getName = fst . extract
makeError directives' = Error
- { message = makeMessage $ head directives'
- , locations = snd . extract <$> directives'
+ { message = makeMessage $ NonEmpty.head directives'
+ , locations = snd . extract <$> toList directives'
}
makeMessage directive = concat
[ "There can be only one "