|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import TYPE_CHECKING |
|
|
|
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available |
|
|
|
|
|
# Registry of submodule name -> public names exported from that submodule.
# It is handed to `_LazyModule` in the runtime (non-TYPE_CHECKING) branch at the
# end of this file. The entries in this literal carry no optional-dependency
# guard, so they are registered regardless of which backend (PyTorch /
# TensorFlow / Flax) is installed; backend-specific entries are added below.
# Keep these names in sync with the unconditional imports in the
# `if TYPE_CHECKING:` branch.
_import_structure = {
    "configuration_utils": ["GenerationConfig"],
    "streamers": ["TextIteratorStreamer", "TextStreamer"],
}
|
|
|
# PyTorch-backed exports.
# When `is_torch_available()` returns False, OptionalDependencyNotAvailable is
# raised and immediately caught, so the `else:` clause is skipped and none of
# these submodules are registered in `_import_structure`.
# NOTE(review): the raise-then-catch shape (instead of a plain `if`) is repeated
# for every backend in this file and may be matched textually by repo tooling —
# confirm before simplifying.
# Every list below must stay in sync with the corresponding
# `from .<submodule> import ...` statements in the `if TYPE_CHECKING:` branch.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # Names exported from `.beam_constraints`.
    _import_structure["beam_constraints"] = [
        "Constraint",
        "ConstraintListState",
        "DisjunctiveConstraint",
        "PhrasalConstraint",
    ]
    # Names exported from `.beam_search`.
    _import_structure["beam_search"] = [
        "BeamHypotheses",
        "BeamScorer",
        "BeamSearchScorer",
        "ConstrainedBeamSearchScorer",
    ]
    # Names exported from `.logits_process`.
    _import_structure["logits_process"] = [
        "AlternatingCodebooksLogitsProcessor",
        "ClassifierFreeGuidanceLogitsProcessor",
        "EncoderNoRepeatNGramLogitsProcessor",
        "EncoderRepetitionPenaltyLogitsProcessor",
        "EpsilonLogitsWarper",
        "EtaLogitsWarper",
        "ExponentialDecayLengthPenalty",
        "ForcedBOSTokenLogitsProcessor",
        "ForcedEOSTokenLogitsProcessor",
        "ForceTokensLogitsProcessor",
        "HammingDiversityLogitsProcessor",
        "InfNanRemoveLogitsProcessor",
        "LogitNormalization",
        "LogitsProcessor",
        "LogitsProcessorList",
        "LogitsWarper",
        "MinLengthLogitsProcessor",
        "MinNewTokensLengthLogitsProcessor",
        "NoBadWordsLogitsProcessor",
        "NoRepeatNGramLogitsProcessor",
        "PrefixConstrainedLogitsProcessor",
        "RepetitionPenaltyLogitsProcessor",
        "SequenceBiasLogitsProcessor",
        "SuppressTokensLogitsProcessor",
        "SuppressTokensAtBeginLogitsProcessor",
        "TemperatureLogitsWarper",
        "TopKLogitsWarper",
        "TopPLogitsWarper",
        "TypicalLogitsWarper",
        "UnbatchedClassifierFreeGuidanceLogitsProcessor",
        "WhisperTimeStampLogitsProcessor",
    ]
    # Names exported from `.stopping_criteria`.
    _import_structure["stopping_criteria"] = [
        "MaxNewTokensCriteria",
        "MaxLengthCriteria",
        "MaxTimeCriteria",
        "StoppingCriteria",
        "StoppingCriteriaList",
        "validate_stopping_criteria",
    ]
    # Names exported from `.utils` (the generation mixin, a filtering helper and
    # the per-strategy output classes).
    _import_structure["utils"] = [
        "GenerationMixin",
        "top_k_top_p_filtering",
        "GreedySearchEncoderDecoderOutput",
        "GreedySearchDecoderOnlyOutput",
        "SampleEncoderDecoderOutput",
        "SampleDecoderOnlyOutput",
        "BeamSearchEncoderDecoderOutput",
        "BeamSearchDecoderOnlyOutput",
        "BeamSampleEncoderDecoderOutput",
        "BeamSampleDecoderOnlyOutput",
        "ContrastiveSearchEncoderDecoderOutput",
        "ContrastiveSearchDecoderOnlyOutput",
    ]
|
|
|
# TensorFlow-backed exports.
# If `is_tf_available()` returns False, the raised OptionalDependencyNotAvailable
# is caught right away and the `else:` clause is skipped, so the `tf_*`
# submodules are simply not registered in `_import_structure`.
# Every list below must stay in sync with the corresponding TensorFlow imports
# in the `if TYPE_CHECKING:` branch.
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # Names exported from `.tf_logits_process`.
    _import_structure["tf_logits_process"] = [
        "TFForcedBOSTokenLogitsProcessor",
        "TFForcedEOSTokenLogitsProcessor",
        "TFForceTokensLogitsProcessor",
        "TFLogitsProcessor",
        "TFLogitsProcessorList",
        "TFLogitsWarper",
        "TFMinLengthLogitsProcessor",
        "TFNoBadWordsLogitsProcessor",
        "TFNoRepeatNGramLogitsProcessor",
        "TFRepetitionPenaltyLogitsProcessor",
        "TFSuppressTokensAtBeginLogitsProcessor",
        "TFSuppressTokensLogitsProcessor",
        "TFTemperatureLogitsWarper",
        "TFTopKLogitsWarper",
        "TFTopPLogitsWarper",
    ]
    # Names exported from `.tf_utils`.
    _import_structure["tf_utils"] = [
        "TFGenerationMixin",
        "tf_top_k_top_p_filtering",
        "TFGreedySearchDecoderOnlyOutput",
        "TFGreedySearchEncoderDecoderOutput",
        "TFSampleEncoderDecoderOutput",
        "TFSampleDecoderOnlyOutput",
        "TFBeamSearchEncoderDecoderOutput",
        "TFBeamSearchDecoderOnlyOutput",
        "TFBeamSampleEncoderDecoderOutput",
        "TFBeamSampleDecoderOnlyOutput",
        "TFContrastiveSearchEncoderDecoderOutput",
        "TFContrastiveSearchDecoderOnlyOutput",
    ]
|
|
|
# Flax-backed exports.
# If `is_flax_available()` returns False, the raised OptionalDependencyNotAvailable
# is caught right away and the `else:` clause is skipped, so the `flax_*`
# submodules are simply not registered in `_import_structure`.
# Every list below must stay in sync with the corresponding Flax imports in the
# `if TYPE_CHECKING:` branch.
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # Names exported from `.flax_logits_process`.
    _import_structure["flax_logits_process"] = [
        "FlaxForcedBOSTokenLogitsProcessor",
        "FlaxForcedEOSTokenLogitsProcessor",
        "FlaxForceTokensLogitsProcessor",
        "FlaxLogitsProcessor",
        "FlaxLogitsProcessorList",
        "FlaxLogitsWarper",
        "FlaxMinLengthLogitsProcessor",
        "FlaxSuppressTokensAtBeginLogitsProcessor",
        "FlaxSuppressTokensLogitsProcessor",
        "FlaxTemperatureLogitsWarper",
        "FlaxTopKLogitsWarper",
        "FlaxTopPLogitsWarper",
        "FlaxWhisperTimeStampLogitsProcessor",
    ]
    # Names exported from `.flax_utils`.
    _import_structure["flax_utils"] = [
        "FlaxGenerationMixin",
        "FlaxGreedySearchOutput",
        "FlaxSampleOutput",
        "FlaxBeamSearchOutput",
    ]
|
|
|
if TYPE_CHECKING:
    # Static-analysis branch. `typing.TYPE_CHECKING` is False at runtime, so these
    # eager imports are only seen by type checkers / IDEs. They mirror the
    # `_import_structure` registrations above name-for-name, with the same
    # per-backend availability guards; when adding or removing an export,
    # update both places.
    from .configuration_utils import GenerationConfig
    from .streamers import TextIteratorStreamer, TextStreamer

    # PyTorch-backed names (mirrors the torch `_import_structure` section).
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
        from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
        from .logits_process import (
            AlternatingCodebooksLogitsProcessor,
            ClassifierFreeGuidanceLogitsProcessor,
            EncoderNoRepeatNGramLogitsProcessor,
            EncoderRepetitionPenaltyLogitsProcessor,
            EpsilonLogitsWarper,
            EtaLogitsWarper,
            ExponentialDecayLengthPenalty,
            ForcedBOSTokenLogitsProcessor,
            ForcedEOSTokenLogitsProcessor,
            ForceTokensLogitsProcessor,
            HammingDiversityLogitsProcessor,
            InfNanRemoveLogitsProcessor,
            LogitNormalization,
            LogitsProcessor,
            LogitsProcessorList,
            LogitsWarper,
            MinLengthLogitsProcessor,
            MinNewTokensLengthLogitsProcessor,
            NoBadWordsLogitsProcessor,
            NoRepeatNGramLogitsProcessor,
            PrefixConstrainedLogitsProcessor,
            RepetitionPenaltyLogitsProcessor,
            SequenceBiasLogitsProcessor,
            SuppressTokensAtBeginLogitsProcessor,
            SuppressTokensLogitsProcessor,
            TemperatureLogitsWarper,
            TopKLogitsWarper,
            TopPLogitsWarper,
            TypicalLogitsWarper,
            UnbatchedClassifierFreeGuidanceLogitsProcessor,
            WhisperTimeStampLogitsProcessor,
        )
        from .stopping_criteria import (
            MaxLengthCriteria,
            MaxNewTokensCriteria,
            MaxTimeCriteria,
            StoppingCriteria,
            StoppingCriteriaList,
            validate_stopping_criteria,
        )
        from .utils import (
            BeamSampleDecoderOnlyOutput,
            BeamSampleEncoderDecoderOutput,
            BeamSearchDecoderOnlyOutput,
            BeamSearchEncoderDecoderOutput,
            ContrastiveSearchDecoderOnlyOutput,
            ContrastiveSearchEncoderDecoderOutput,
            GenerationMixin,
            GreedySearchDecoderOnlyOutput,
            GreedySearchEncoderDecoderOutput,
            SampleDecoderOnlyOutput,
            SampleEncoderDecoderOutput,
            top_k_top_p_filtering,
        )

    # TensorFlow-backed names (mirrors the TF `_import_structure` section).
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tf_logits_process import (
            TFForcedBOSTokenLogitsProcessor,
            TFForcedEOSTokenLogitsProcessor,
            TFForceTokensLogitsProcessor,
            TFLogitsProcessor,
            TFLogitsProcessorList,
            TFLogitsWarper,
            TFMinLengthLogitsProcessor,
            TFNoBadWordsLogitsProcessor,
            TFNoRepeatNGramLogitsProcessor,
            TFRepetitionPenaltyLogitsProcessor,
            TFSuppressTokensAtBeginLogitsProcessor,
            TFSuppressTokensLogitsProcessor,
            TFTemperatureLogitsWarper,
            TFTopKLogitsWarper,
            TFTopPLogitsWarper,
        )
        from .tf_utils import (
            TFBeamSampleDecoderOnlyOutput,
            TFBeamSampleEncoderDecoderOutput,
            TFBeamSearchDecoderOnlyOutput,
            TFBeamSearchEncoderDecoderOutput,
            TFContrastiveSearchDecoderOnlyOutput,
            TFContrastiveSearchEncoderDecoderOutput,
            TFGenerationMixin,
            TFGreedySearchDecoderOnlyOutput,
            TFGreedySearchEncoderDecoderOutput,
            TFSampleDecoderOnlyOutput,
            TFSampleEncoderDecoderOutput,
            tf_top_k_top_p_filtering,
        )

    # Flax-backed names (mirrors the Flax `_import_structure` section).
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .flax_logits_process import (
            FlaxForcedBOSTokenLogitsProcessor,
            FlaxForcedEOSTokenLogitsProcessor,
            FlaxForceTokensLogitsProcessor,
            FlaxLogitsProcessor,
            FlaxLogitsProcessorList,
            FlaxLogitsWarper,
            FlaxMinLengthLogitsProcessor,
            FlaxSuppressTokensAtBeginLogitsProcessor,
            FlaxSuppressTokensLogitsProcessor,
            FlaxTemperatureLogitsWarper,
            FlaxTopKLogitsWarper,
            FlaxTopPLogitsWarper,
            FlaxWhisperTimeStampLogitsProcessor,
        )
        from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
else:
    import sys

    # Runtime branch: replace this package's entry in `sys.modules` with a
    # `_LazyModule` built from `_import_structure`. Any code that later imports
    # this package gets that object instead of the plain module namespace.
    # NOTE(review): presumably `_LazyModule` imports each submodule only on first
    # attribute access, keeping backend imports off the package-import path —
    # confirm against its implementation in `..utils`.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
|