@@ -167,81 +167,81 @@ class SpecialVocab:
tokenizer_config['bos_token'] = special_bos = special_cls
if not special_eos and special_sep and tokenizer_config:
tokenizer_config['eos_token'] = special_eos = special_sep
- post_processor = tokenizer.get('post_processor', {})
- for processor in post_processor.get('processors', [post_processor]):
- if processor.get('type') == 'RobertaProcessing':
- self.add_special_token['bos'] = True
- self.add_special_token['eos'] = True
- self.add_special_token['sep'] = True
- if not special_cls and tokenizer_config:
- special_cls = processor.get('cls', [special_bos])[0]
- tokenizer_config['cls_token'] = special_cls
- if not special_sep and tokenizer_config:
- special_sep = processor.get('sep', [special_eos])[0]
- tokenizer_config['sep_token'] = special_sep
- continue
- # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
- # Only works with simple templates, **will** get it wrong on unusual sequences
- if processor.get('type') == 'TemplateProcessing':
- tmpl_single = processor.get('single', [])
- tmpl_pair = processor.get('pair', [])
- special_first = None
- special_last = None
- if len(tmpl_single) > 1:
- if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
- if not tokenizer_config:
- special_bos = special_first
- self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
- if special_first not in (special_bos, special_cls):
- logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
- if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
- if not tokenizer_config:
- special_eos = special_last
- elif special_last != special_eos:
- if 'eot' not in self.special_token_types:
- self.special_token_types = tuple(self.special_token_types) + ('eot', )
- tokenizer_config['eot_token'] = special_eos
- elif 'eom' not in self.special_token_types:
- self.special_token_types = tuple(self.special_token_types) + ('eom', )
- tokenizer_config['eom_token'] = special_eos
- else:
- logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
- tokenizer_config['eos_token'] = special_eos = special_last
- self.add_special_token['eos'] = True if special_last == special_eos else False
- if special_last != special_eos:
- logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
- if tmpl_pair:
- seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
- seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
- if (special_first and seq_start == 0) or (special_last and seq_stop is None):
- logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
- if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
- tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
- tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
- if tmpl_a != 'A' or tmpl_b != 'B':
- logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
- # A [sep] [eos] B
- if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
- add_sep = False
- if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
- if special_entry in (special_sep, special_eos) and not special_last:
- add_sep = True
- if special_entry not in (special_sep, special_eos):
- logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
- else:
- logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
- if len(tmpl_pair) == 2:
- if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
- if special_entry in (special_sep, special_eos):
+ if post_processor := tokenizer.get('post_processor'):
+ for processor in post_processor.get('processors', [post_processor]):
+ if processor.get('type') == 'RobertaProcessing':
+ self.add_special_token['bos'] = True
+ self.add_special_token['eos'] = True
+ self.add_special_token['sep'] = True
+ if not special_cls and tokenizer_config:
+ special_cls = processor.get('cls', [special_bos])[0]
+ tokenizer_config['cls_token'] = special_cls
+ if not special_sep and tokenizer_config:
+ special_sep = processor.get('sep', [special_eos])[0]
+ tokenizer_config['sep_token'] = special_sep
+ continue
+ # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
+ # Only works with simple templates, **will** get it wrong on unusual sequences
+ if processor.get('type') == 'TemplateProcessing':
+ tmpl_single = processor.get('single', [])
+ tmpl_pair = processor.get('pair', [])
+ special_first = None
+ special_last = None
+ if len(tmpl_single) > 1:
+ if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
+ if not tokenizer_config:
+ special_bos = special_first
+ self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
+ if special_first not in (special_bos, special_cls):
+ logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
+ if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
+ if not tokenizer_config:
+ special_eos = special_last
+ elif special_last != special_eos:
+ if 'eot' not in self.special_token_types:
+ self.special_token_types = tuple(self.special_token_types) + ('eot', )
+ tokenizer_config['eot_token'] = special_eos
+ elif 'eom' not in self.special_token_types:
+ self.special_token_types = tuple(self.special_token_types) + ('eom', )
+ tokenizer_config['eom_token'] = special_eos
+ else:
+ logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
+ tokenizer_config['eos_token'] = special_eos = special_last
+ self.add_special_token['eos'] = True if special_last == special_eos else False
+ if special_last != special_eos:
+ logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
+ if tmpl_pair:
+ seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
+ seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
+ if (special_first and seq_start == 0) or (special_last and seq_stop is None):
+ logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
+ if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
+ tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
+ tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
+ if tmpl_a != 'A' or tmpl_b != 'B':
+ logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
+ # A [sep] [eos] B
+ if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
+ add_sep = False
+ if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
+ if special_entry in (special_sep, special_eos) and not special_last:
add_sep = True
if special_entry not in (special_sep, special_eos):
- logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+ logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
else:
- logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
- self.add_special_token['sep'] = add_sep
- if add_sep and not special_sep and tokenizer_config:
- tokenizer_config['sep_token'] = special_eos
- continue
+ logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
+ if len(tmpl_pair) == 2:
+ if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
+ if special_entry in (special_sep, special_eos):
+ add_sep = True
+ if special_entry not in (special_sep, special_eos):
+ logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+ else:
+ logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
+ self.add_special_token['sep'] = add_sep
+ if add_sep and not special_sep and tokenizer_config:
+ tokenizer_config['sep_token'] = special_eos
+ continue
if not tokenizer_config:
return True
chat_template_alt = None
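For context (not part of the patch): a minimal sketch of the tokenizer.json post_processor shape that the TemplateProcessing branch walks. The layout follows the Hugging Face tokenizers serialization of TemplateProcessing as I understand it; the '[CLS]'/'[SEP]' token names and the small driver loop around it are illustrative assumptions only.

# Illustrative stand-in for tokenizer.get('post_processor'), assuming the usual
# Hugging Face `tokenizers` serialization of TemplateProcessing.
# Token names ('[CLS]', '[SEP]') are assumptions, not taken from the patch.
post_processor = {
    'type': 'TemplateProcessing',
    'single': [
        {'SpecialToken': {'id': '[CLS]', 'type_id': 0}},
        {'Sequence': {'id': 'A', 'type_id': 0}},
        {'SpecialToken': {'id': '[SEP]', 'type_id': 0}},
    ],
    'pair': [
        {'SpecialToken': {'id': '[CLS]', 'type_id': 0}},
        {'Sequence': {'id': 'A', 'type_id': 0}},
        {'SpecialToken': {'id': '[SEP]', 'type_id': 0}},
        {'Sequence': {'id': 'B', 'type_id': 1}},
        {'SpecialToken': {'id': '[SEP]', 'type_id': 1}},
    ],
}

# A single post-processor has no 'processors' key, so the fallback list
# [post_processor] lets the loop treat it like a one-element Sequence; a
# 'Sequence' post-processor would expose its children under 'processors'.
for processor in post_processor.get('processors', [post_processor]):
    if processor.get('type') == 'TemplateProcessing':
        tmpl_single = processor.get('single', [])
        leading = tmpl_single[0].get('SpecialToken', {}).get('id')    # '[CLS]' -> BOS/CLS is added
        trailing = tmpl_single[-1].get('SpecialToken', {}).get('id')  # '[SEP]' -> EOS/SEP is added
        print(leading, trailing)

A RobertaProcessing entry is flatter, roughly {'type': 'RobertaProcessing', 'cls': ['<s>', 0], 'sep': ['</s>', 2]} with (token, id) pairs, which is why that branch takes processor.get('cls', ...)[0] to pull out the token string.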