Commit 0bd3c574 authored Oct 17, 2023 by Estrella Vallecillo
perplexity Analyzer
parent 8b6a0f7b
Showing 1 changed file with 97 additions and 0 deletions
textflow/PerplexityAnalyzer.py (new file, 0 → 100644)
from typing import Optional
from textflow.Analyzer import Analyzer
import re
import torch
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
class PerplexityAnalyzer(Analyzer):
    """
    A class that provides methods to analyze the perplexity of the text of a sequence.

    Attributes:
        device: the device on which the model runs ("cuda" or "cpu").
        model: the causal language model used to compute the perplexity.
        tokenizer: the tokenizer associated with the model.
    """
    def __init__(
        self,
        model=GPT2LMHeadModel.from_pretrained("PlanTL-GOB-ES/gpt2-large-bne"),
        tokenizer=GPT2TokenizerFast.from_pretrained("PlanTL-GOB-ES/gpt2-large-bne"),
        device="cuda",
    ):
        """
        Create a perplexity analyzer.

        Args:
            model: the causal language model used to compute the perplexity.
            tokenizer: the tokenizer associated with the model.
            device: the device on which to run the model ("cuda" or "cpu").
        """
        self.device = device
        self.model = model.to(device)
        self.tokenizer = tokenizer
    def analyze(self, sequence, tag, levelOfAnalyzer, levelOfResult: Optional[str] = ""):
        """
        Analyze a sequence with the perplexity function.

        Args:
            sequence: the Sequence we want to analyze.
            tag: the label to store the analysis result.
            levelOfAnalyzer: the path of the sequence level to analyze inside of the result.
            levelOfResult: the path of the sequence level to store the result.
        """
        super().analyze(self.perplexityResult, sequence, tag, levelOfAnalyzer, levelOfResult, True)
    def perplexityResult(self, arrayText):
        """
        Function that analyzes the perplexity of a list of texts.

        Args:
            arrayText: list that contains the texts that we want to analyze.

        Returns:
            A float list. Each float value of the list represents the
            perplexity of the corresponding text.
        """
        resultsList = []
        for t in arrayText:
            resultsList.append(self.perplexity(t))
        return resultsList
    def perplexity(self, text):
        """
        Function that calculates the perplexity of a text.

        Args:
            text: the text that we want to analyze.

        Returns:
            The perplexity of the text as a float.
        """
        encodings = self.tokenizer(text, return_tensors="pt")
        max_length = self.model.config.n_positions
        stride = 512
        seq_len = encodings.input_ids.size(1)

        nlls = []
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, seq_len, stride)):
            end_loc = min(begin_loc + max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(self.device)
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100

            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)
                # loss is calculated using CrossEntropyLoss which averages over valid labels
                # N.B. the model only calculates loss over trg_len - 1 labels, because it
                # internally shifts the labels to the left by 1.
                neg_log_likelihood = outputs.loss

            nlls.append(neg_log_likelihood)

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        ppl = torch.exp(torch.stack(nlls).mean())
        return ppl.item()
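The loop above is the sliding-window evaluation strategy from the Hugging Face perplexity guide: each window spans up to max_length tokens, but only the trg_len tokens not already scored by a previous window contribute to the loss (earlier positions are masked with -100). The result is exp of the mean per-window loss, which weights every window equally rather than by token count, so for texts longer than one window it is an approximation of the true token-level perplexity. A standalone sketch of the index arithmetic, with made-up toy sizes instead of the real model limits:

# Toy walk-through of the window indices computed in perplexity().
# seq_len, max_length and stride are hypothetical values for illustration;
# the real code uses model.config.n_positions and stride = 512.
seq_len, max_length, stride = 1200, 1024, 512

prev_end_loc = 0
for begin_loc in range(0, seq_len, stride):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # tokens newly scored by this window
    print(f"window [{begin_loc}:{end_loc}) scores its last {trg_len} tokens")
    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

With these toy numbers the first window scores all 1024 of its tokens, while the second window ([512:1200)) only scores the final 176 tokens that the first window did not cover.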
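The analyzer can also be tried on its own, outside of the textflow Sequence pipeline. Below is a minimal usage sketch, not part of the commit: it assumes the PlanTL-GOB-ES/gpt2-large-bne checkpoint can be downloaded, and it adds a CPU fallback because the committed default device="cuda" fails on machines without a GPU.

import torch
from textflow.PerplexityAnalyzer import PerplexityAnalyzer

# Hypothetical fallback: use the GPU when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
analyzer = PerplexityAnalyzer(device=device)

# perplexity() can be called directly on a raw string; lower values mean
# the text is more predictable under the (Spanish) language model.
text = "El gato duerme tranquilamente en el sofá."
print(analyzer.perplexity(text))

One design note: because the model and tokenizer are supplied as default argument values, they are instantiated once when the module is imported, not each time the analyzer is constructed.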