{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | MemoBytes is an abstraction for a data type that carries its own serialization.
--   The idea is to define a type as a newtype around MemoBytes applied to a
--   non-memoizing version of that type.
--   For example:   newtype Foo = Foo (MemoBytes NonMemoizingFoo)
--   This way all the instances for Foo (Eq, Show, Ord, ToCBOR, FromCBOR (Annotator Foo),
--   NoThunks, Generic) can be derived for free.
module Data.MemoBytes
  ( MemoBytes (..),
    memoBytes,
    Mem,
    shorten,
    showMemo,
    printMemo,
    roundTripMemo,
  )
where

import Cardano.Binary
  ( Annotator (..),
    FromCBOR (fromCBOR),
    ToCBOR (toCBOR),
    encodePreEncoded,
    withSlice,
  )
import Codec.CBOR.Read (DeserialiseFailure, deserialiseFromBytes)
import Codec.CBOR.Write (toLazyByteString)
import Control.DeepSeq (NFData (..))
import Data.ByteString.Lazy (fromStrict, toStrict)
import qualified Data.ByteString.Lazy as Lazy
import Data.ByteString.Short (ShortByteString, fromShort, toShort)
import Data.Coders (Encode, encode, runE)
import Data.Typeable
import GHC.Generics (Generic)
import NoThunks.Class (AllowThunksIn (..), NoThunks (..))
import Prelude hiding (span)

-- ========================================================================

-- | A value of type @t@ paired with its CBOR serialisation.
--
-- Invariant (assumed, not enforced by this type): 'memobytes' is a valid
-- CBOR encoding of 'memotype'.
data MemoBytes t = Memo
  { -- | The wrapped value. Strict, so it is evaluated to WHNF whenever the
    -- 'Memo' constructor itself is.
    memotype :: !t,
    -- | The serialised form of 'memotype'. Deliberately lazy: the bytes may
    -- remain an unevaluated thunk until first demanded, which is why the
    -- 'NoThunks' instance below whitelists this field.
    memobytes :: ShortByteString
  }
  -- Check for unexpected thunks everywhere except in 'memobytes'.
  deriving (NoThunks) via AllowThunksIn '["memobytes"] (MemoBytes t)

-- | Stock 'Generic' representation; it also provides the defaults used by
-- the anyclass-derived 'NFData' instance below.
deriving stock instance Generic (MemoBytes t)

-- | Full evaluation of both the wrapped value and the memoized bytes,
-- using the 'Generic'-based default 'rnf'.
deriving anyclass instance NFData t => NFData (MemoBytes t)

-- | Serialises by emitting the memoized bytes verbatim with
-- 'encodePreEncoded'. The wrapped value is never re-encoded, so the output
-- is exactly the stored bytes (for decoded values, the original input slice).
instance (Typeable t) => ToCBOR (MemoBytes t) where
  toCBOR (Memo _ bytes) = encodePreEncoded (fromShort bytes)

-- | Decodes the annotated @t@ and, through 'withSlice', also captures the
-- input bytes that the decoder consumed for it. Those bytes become
-- 'memobytes', so re-serialising via 'toCBOR' reproduces the original input.
instance (Typeable t, FromCBOR (Annotator t)) => FromCBOR (Annotator (MemoBytes t)) where
  fromCBOR = do
    (Annotator getT, Annotator getBytes) <- withSlice fromCBOR
    pure . Annotator $ \fullbytes ->
      -- The bytes field is lazy, so this conversion stays a thunk that
      -- refers to 'fullbytes' until the memoized bytes are first demanded.
      Memo (getT fullbytes) (toShort (toStrict (getBytes fullbytes)))

-- | Equality compares only the memoized bytes, never the wrapped values.
-- No @Eq t@ constraint is needed, and two values that are equal as @t@ but
-- were encoded differently are /not/ equal as 'MemoBytes'.
instance Eq (MemoBytes t) where
  Memo _ x == Memo _ y = x == y

-- | Shows only the wrapped value; the memoized bytes are omitted.
--
-- Delegates through 'showsPrec' rather than defining only 'show'. The
-- caller's precedence is therefore forwarded, and a 'MemoBytes' nested
-- inside another structure (e.g. @Just m@) is parenthesised exactly as
-- @t@ itself would be. For lawful @Show t@ instances, 'show' yields the
-- same string as @show (memotype m)@.
instance Show t => Show (MemoBytes t) where
  showsPrec d (Memo y _) = showsPrec d y

-- | Orders primarily by the wrapped value. Ties are broken on the memoized
-- bytes.
--
-- The tie-break keeps 'compare' consistent with the bytes-based 'Eq'
-- instance. Without it, two values that are equal as @t@ but encoded
-- differently would compare 'EQ' while being '/='. That violates
-- antisymmetry and confuses 'Data.Map' / 'Data.Set', which rely solely on
-- 'compare'. The relative order of values with distinct @t@ is unaffected.
instance Ord t => Ord (MemoBytes t) where
  compare (Memo x bx) (Memo y by) = compare x y <> compare bx by

-- | Convert a lazy 'Lazy.ByteString' to a compact 'ShortByteString' by
-- first making it strict (which copies all chunks into one buffer).
shorten :: Lazy.ByteString -> ShortByteString
shorten = toShort . toStrict

-- | The annotated decoding target for a 'MemoBytes'.
--
-- Handy as the @via@ type when deriving @FromCBOR (Annotator T)@ for a
-- newtype wrapping 'MemoBytes', for example:
--
-- > deriving via (Mem T) instance (Era era) => FromCBOR (Annotator T)
type Mem a = Annotator (MemoBytes a)

-- | Debug rendering of both halves of a 'MemoBytes': the wrapped value and
-- the raw memoized bytes (via the 'ShortByteString' 'Show' instance).
-- Unlike the 'Show' instance, the bytes are included.
showMemo :: Show t => MemoBytes t -> String
showMemo (Memo t b) = concat ["(Memo ", show t, "  ", show b, ")"]

-- | Print the 'showMemo' rendering of the argument to stdout, followed by
-- a newline.
printMemo :: Show t => MemoBytes t -> IO ()
printMemo = putStrLn . showMemo

-- | Build a 'MemoBytes' from a single 'Encode' description. The value
-- comes from 'runE' and the bytes from serialising 'encode' of the same
-- description. Because the bytes field is lazy, the serialisation is not
-- forced when the 'Memo' is constructed.
memoBytes :: Encode w t -> MemoBytes t
memoBytes enc = Memo (runE enc) (shorten (toLazyByteString (encode enc)))

-- | Re-decode the memoized bytes with @t@'s plain 'FromCBOR' decoder.
--
-- On success, returns any unconsumed trailing input together with a new
-- 'MemoBytes' that pairs the freshly decoded value with the /same/ bytes.
-- The previously stored value is ignored. A decoding failure is returned
-- unchanged as 'Left'.
roundTripMemo ::
  (FromCBOR t) =>
  MemoBytes t ->
  Either Codec.CBOR.Read.DeserialiseFailure (Lazy.ByteString, MemoBytes t)
roundTripMemo (Memo _ bytes) =
  -- 'fmap' over 'Either' rewraps only the 'Right' case; failures pass through.
  fmap (\(leftover, newt) -> (leftover, Memo newt bytes)) $
    deserialiseFromBytes fromCBOR (fromStrict (fromShort bytes))