Hugging Face Smell Detector

Source Code

  1import os
  2import sys
  3from typing import Any, Dict, List
  4
  5import astroid
  6from astroid import nodes
  7
  8
  9class HuggingFaceSmellDetector:
 10    """A detector class that identifies common code smells in Hugging Face Transformers code.
 11
 12    This detector analyzes Python code that uses the Hugging Face Transformers library and identifies
 13    potential issues and best practices violations related to model training, data processing,
 14    and performance optimization.
 15    """
 16
 17    def __init__(self):
 18        self.smells: List[Dict[str, Any]] = []
 19
 20    def detect_smells(self, file_path: str) -> List[Dict[str, Any]]:
 21        """Analyze a Python file for Hugging Face-related code smells.
 22
 23        Args:
 24            file_path: Path to the Python file to analyze.
 25
 26        Returns:
 27            List of dictionaries containing detected code smells and their details.
 28        """
 29        try:
 30            with open(file_path, 'r') as file:
 31                content = file.read()
 32            module_name = os.path.splitext(os.path.basename(file_path))[0]
 33            module = astroid.parse(content, module_name=module_name)
 34
 35            # Check if 'transformers' is imported
 36            if self.is_framework_used(module, 'transformers'):
 37                self.visit_module(module, file_path)
 38            else:
 39                print(
 40                    f"Skipping Hugging Face smell detection for {file_path}: 'transformers' not imported",
 41                    file=sys.stderr)
 42        except astroid.exceptions.AstroidSyntaxError as e:
 43            print(f"Error parsing {file_path}: {str(e)}", file=sys.stderr)
 44        except Exception as e:
 45            print(f"Unexpected error while processing {file_path}: {str(e)}", file=sys.stderr)
 46        return self.smells
 47
 48    def is_framework_used(self, node: nodes.Module, framework: str) -> bool:
 49        """Check if a specific framework is imported in the module.
 50
 51        Args:
 52            node: AST node representing the module
 53            framework: Name of the framework to check for
 54
 55        Returns:
 56            True if the framework is imported, False otherwise
 57        """
 58        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
 59            if isinstance(import_node, nodes.Import):
 60                if any(name.split('.')[0] == framework for name, _ in import_node.names):
 61                    return True
 62            elif isinstance(import_node, nodes.ImportFrom):
 63                if import_node.modname.split('.')[0] == framework:
 64                    return True
 65        return False
 66
 67    def visit_module(self, node: nodes.Module, file_path: str):
 68        """Visit a module node and run all smell detection checks.
 69
 70        Args:
 71            node: AST node representing the module
 72            file_path: Path to the file being analyzed
 73        """
 74        self.check_model_versioning(node, file_path)
 75        self.check_tokenizer_caching(node, file_path)
 76        self.check_model_caching(node, file_path)
 77        self.check_deterministic_tokenization(node, file_path)
 78        self.check_efficient_data_loading(node, file_path)
 79        self.check_distributed_training(node, file_path)
 80        self.check_mixed_precision_training(node, file_path)
 81        self.check_gradient_accumulation(node, file_path)
 82        self.check_learning_rate_scheduling(node, file_path)
 83        self.check_early_stopping(node, file_path)
 84
 85    def add_smell(self, smell: str, fix: str, benefits: str, strategies: str, node: nodes.NodeNG, file_path: str):
 86        """Add a detected code smell to the results.
 87
 88        Args:
 89            smell: Description of the code smell
 90            fix: How to fix the issue
 91            benefits: Benefits of fixing the issue
 92            strategies: Specific strategies to implement the fix
 93            node: AST node where the smell was detected
 94            file_path: Path to the file containing the smell
 95        """
 96        self.smells.append({
 97            "smell": smell,
 98            "how_to_fix": fix,
 99            "benefits": benefits,
100            "strategies": strategies,
101            "line_number": node.lineno,
102            "code_snippet": node.as_string(),
103            "file_path": file_path
104        })
105
106    def check_model_versioning(self, node: nodes.Module, file_path: str):
107        """Check if model versions are explicitly specified when loading pre-trained models.
108
109        Detects cases where models are loaded without version tags, which could lead to
110        reproducibility issues.
111
112        Args:
113            node: AST node representing the module
114            file_path: Path to the file being analyzed
115        """
116        for call in node.nodes_of_class(nodes.Call):
117            if ('from_pretrained' in call.func.as_string() and
118                ('AutoModel' in call.func.as_string() or
119                 'PreTrainedModel' in call.func.as_string())):
120                if not any('@' in arg.as_string() for arg in call.args):
121                    self.add_smell(
122                        "Model versioning not specified",
123                        "Specify model version when loading pre-trained models",
124                        "Ensures consistency and reproducibility of results",
125                        "Use model_name_or_path@revision when loading models (e.g., bert-base-uncased@v1)",
126                        call,
127                        file_path
128                    )
129
130    def check_tokenizer_caching(self, node: nodes.Module, file_path: str):
131        """Check if tokenizer caching is enabled when loading tokenizers.
132
133        Detects cases where tokenizers are loaded without caching configuration, which
134        could lead to unnecessary re-downloads and slower loading times.
135
136        Args:
137            node: AST node representing the module
138            file_path: Path to the file being analyzed
139        """
140        for call in node.nodes_of_class(nodes.Call):
141            if ('from_pretrained' in call.func.as_string() and
142                ('AutoTokenizer' in call.func.as_string() or
143                 'PreTrainedTokenizer' in call.func.as_string())):
144                if not any(keyword.arg in ['cache_dir', 'local_files_only']
145                           for keyword in call.keywords):
146                    self.add_smell(
147                        "Tokenizer caching not used",
148                        "Cache tokenizers to avoid re-downloading",
149                        "Reduces loading time and network dependency",
150                        "Use cache_dir parameter when loading tokenizers",
151                        call,
152                        file_path
153                    )
154
155    def check_model_caching(self, node: nodes.Module, file_path: str):
156        """Check if model caching is enabled when loading models.
157
158        Detects cases where models are loaded without caching configuration, which
159        could lead to unnecessary re-downloads and slower loading times.
160
161        Args:
162            node: AST node representing the module
163            file_path: Path to the file being analyzed
164        """
165        for call in node.nodes_of_class(nodes.Call):
166            if ('from_pretrained' in call.func.as_string() and
167                ('AutoModel' in call.func.as_string() or
168                 'PreTrainedModel' in call.func.as_string())):
169                if not any(keyword.arg in ['cache_dir', 'local_files_only']
170                           for keyword in call.keywords):
171                    self.add_smell(
172                        "Model caching not used",
173                        "Cache models to avoid re-downloading",
174                        "Improves loading efficiency and reduces network dependency",
175                        "Use cache_dir parameter when loading models",
176                        call,
177                        file_path
178                    )
179
180    def check_deterministic_tokenization(self, node: nodes.Module, file_path: str):
181        """Check if tokenization parameters are explicitly specified.
182
183        Detects cases where tokenization settings are not explicitly defined,
184        which could lead to inconsistent preprocessing across runs.
185
186        Args:
187            node: AST node representing the module
188            file_path: Path to the file being analyzed
189        """
190        for call in node.nodes_of_class(nodes.Call):
191            if ('from_pretrained' in call.func.as_string() and
192                ('AutoTokenizer' in call.func.as_string() or
193                 'PreTrainedTokenizer' in call.func.as_string())):
194                deterministic_params = [
195                    'do_lower_case', 'strip_accents', 'truncation',
196                    'padding', 'max_length', 'return_tensors'
197                ]
198                if not any(keyword.arg in deterministic_params for keyword in call.keywords):
199                    self.add_smell(
200                        "Deterministic tokenization settings not specified",
201                        "Use consistent tokenization settings",
202                        "Ensures reproducible pre-processing and consistent model inputs",
203                        "Set tokenization parameters explicitly when loading tokenizers",
204                        call,
205                        file_path
206                    )
207
208    def check_efficient_data_loading(self, node: nodes.Module, file_path: str):
209        """Check if efficient data loading techniques are being used.
210
211        Detects cases where standard data loading is used instead of optimized
212        methods like datasets library or DataLoader.
213
214        Args:
215            node: AST node representing the module
216            file_path: Path to the file being analyzed
217        """
218        datasets_imported = any('datasets' in import_node.names[0][0]
219                                for import_node in node.nodes_of_class(nodes.ImportFrom))
220
221        efficient_patterns = [
222            'load_dataset',
223            'Dataset.from_',
224            'DataLoader',
225            'IterableDataset'
226        ]
227
228        has_efficient_loading = any(
229            pattern in call.func.as_string()
230            for call in node.nodes_of_class(nodes.Call)
231            for pattern in efficient_patterns
232        )
233
234        if not (datasets_imported or has_efficient_loading):
235            self.add_smell(
236                "Efficient data loading not detected",
237                "Use efficient data loading techniques",
238                "Enhances data processing speed and model training efficiency",
239                "Use datasets library for loading and processing data",
240                node,
241                file_path
242            )
243
244    def check_distributed_training(self, node: nodes.Module, file_path: str):
245        """Check if distributed training is configured when using training functionality.
246
247        Detects cases where training code is present but distributed training
248        settings are not configured.
249
250        Args:
251            node: AST node representing the module
252            file_path: Path to the file being analyzed
253        """
254        # Check if training-related imports exist
255        has_training_imports = any(
256            'Trainer' in import_node.names[0][0] or 'TrainingArguments' in import_node.names[0][0]
257            for import_node in node.nodes_of_class(nodes.ImportFrom)
258        )
259
260        if not has_training_imports:
261            return  # Skip if no training-related imports
262
263        distributed_config = False
264        for assign in node.nodes_of_class(nodes.Assign):
265            if isinstance(assign.targets[0], nodes.AssignName) and assign.targets[0].name == 'TrainingArguments':
266                if any(keyword.arg in ['local_rank', 'n_gpu', 'distributed_training', 'tpu_num_cores']
267                       for keyword in assign.value.keywords):
268                    distributed_config = True
269                    break
270
271        # Only report if TrainingArguments is used but without distributed config
272        if not distributed_config and self._has_training_arguments(node):
273            self.add_smell(
274                "Distributed training not configured",
275                "Utilize distributed training capabilities",
276                "Speeds up training and leverages multiple GPUs/TPUs",
277                "Configure Trainer with distributed settings using local_rank, n_gpu, or tpu_num_cores",
278                node,
279                file_path
280            )
281
282    def _has_training_arguments(self, node: nodes.Module) -> bool:
283        """Helper method to check if TrainingArguments is actually used in the code"""
284        return any(
285            isinstance(assign.targets[0], nodes.AssignName) and
286            assign.targets[0].name == 'TrainingArguments'
287            for assign in node.nodes_of_class(nodes.Assign)
288        )
289
290    def check_mixed_precision_training(self, node: nodes.Module, file_path: str):
291        """Check if mixed precision training is enabled.
292
293        Detects cases where training is performed without mixed precision settings,
294        which could lead to suboptimal performance and memory usage.
295
296        Args:
297            node: AST node representing the module
298            file_path: Path to the file being analyzed
299        """
300        if not self._has_training_arguments(node):
301            return  # Skip if no TrainingArguments used
302
303        fp16_used = False
304        for assign in node.nodes_of_class(nodes.Assign):
305            if isinstance(assign.targets[0], nodes.AssignName) and assign.targets[0].name == 'TrainingArguments':
306                if any((keyword.arg == 'fp16' and keyword.value.value) or
307                       (keyword.arg == 'bf16' and keyword.value.value) or
308                       keyword.arg == 'half_precision_backend'
309                       for keyword in assign.value.keywords):
310                    fp16_used = True
311                    break
312
313        if not fp16_used:
314            self.add_smell(
315                "Mixed precision training not enabled",
316                "Use mixed precision training to improve performance",
317                "Accelerates training and reduces memory usage",
318                "Enable mixed precision training using fp16=True or bf16=True in TrainingArguments",
319                node,
320                file_path
321            )
322
323    def check_gradient_accumulation(self, node: nodes.Module, file_path: str):
324        """Check if gradient accumulation is configured for training.
325
326        Detects cases where training is performed without gradient accumulation,
327        which could be beneficial for handling larger effective batch sizes.
328
329        Args:
330            node: AST node representing the module
331            file_path: Path to the file being analyzed
332        """
333        # Skip if no TrainingArguments used
334        if not self._has_training_arguments(node):
335            return
336
337        gradient_accumulation = False
338        for assign in node.nodes_of_class(nodes.Assign):
339            if isinstance(assign.targets[0], nodes.AssignName) and assign.targets[0].name == 'TrainingArguments':
340                if any(keyword.arg == 'gradient_accumulation_steps' and keyword.value.value > 1
341                       for keyword in assign.value.keywords):
342                    gradient_accumulation = True
343                    break
344
345        # Only report if training configuration is present but gradient accumulation isn't
346        if not gradient_accumulation and self._has_training_code(node):
347            self.add_smell(
348                "Gradient accumulation not configured",
349                "Implement gradient accumulation for large batch sizes",
350                "Allows training with larger effective batch sizes and improves convergence",
351                "Set gradient_accumulation_steps in Trainer configuration",
352                node,
353                file_path
354            )
355
356    def check_learning_rate_scheduling(self, node: nodes.Module, file_path: str):
357        """Check if learning rate scheduling is configured.
358
359        Detects cases where training is performed without learning rate scheduling,
360        which could lead to suboptimal training dynamics.
361
362        Args:
363            node: AST node representing the module
364            file_path: Path to the file being analyzed
365        """
366        # Skip if no TrainingArguments used
367        if not self._has_training_arguments(node):
368            return
369
370        lr_scheduler_used = False
371        for assign in node.nodes_of_class(nodes.Assign):
372            if isinstance(assign.targets[0], nodes.AssignName) and assign.targets[0].name == 'TrainingArguments':
373                if any(keyword.arg in ['learning_rate_scheduler', 'lr_scheduler_type']
374                       for keyword in assign.value.keywords):
375                    lr_scheduler_used = True
376                    break
377
378        # Only report if training configuration is present but lr scheduler isn't
379        if not lr_scheduler_used and self._has_training_code(node):
380            self.add_smell(
381                "Learning rate scheduler not detected",
382                "Use learning rate schedulers to dynamically adjust learning rate",
383                "Optimizes training process and enhances model performance",
384                "Configure lr_scheduler_type in TrainingArguments or use transformers built-in schedulers",
385                node,
386                file_path
387            )
388
389    def check_early_stopping(self, node: nodes.Module, file_path: str):
390        """Check if early stopping is implemented in training.
391
392        Detects cases where training code is present but early stopping
393        mechanisms are not configured, which could lead to overfitting.
394
395        Args:
396            node: AST node representing the module
397            file_path: Path to the file being analyzed
398        """
399        # Skip if no training-related code is present
400        if not self._has_training_code(node):
401            return
402
403        early_stopping_used = False
404        for call in node.nodes_of_class(nodes.Call):
405            if 'EarlyStoppingCallback' in call.func.as_string():
406                early_stopping_used = True
407                break
408
409        # Also check TrainingArguments for early_stopping_* parameters
410        for assign in node.nodes_of_class(nodes.Assign):
411            if isinstance(assign.targets[0], nodes.AssignName) and assign.targets[0].name == 'TrainingArguments':
412                if any(keyword.arg.startswith('early_stopping_') for keyword in assign.value.keywords):
413                    early_stopping_used = True
414                    break
415
416        if not early_stopping_used:
417            self.add_smell(
418                "Early stopping not implemented",
419                "Implement early stopping to avoid overfitting",
420                "Prevents overfitting and reduces unnecessary training time",
421                "Use EarlyStoppingCallback or configure early_stopping parameters in TrainingArguments",
422                node,
423                file_path
424            )
425
426    def _has_training_code(self, node: nodes.Module) -> bool:
427        """Helper method to check if the code contains training-related elements"""
428        training_indicators = [
429            'Trainer',
430            'TrainingArguments',
431            '.train(',
432            'optimizer',
433            'train_dataset',
434            'eval_dataset'
435        ]
436
437        # Check imports
438        has_training_imports = any(
439            any(indicator in name for name, _ in import_node.names)
440            for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom))
441            for indicator in training_indicators
442        )
443
444        # Check function calls and assignments
445        has_training_usage = any(
446            any(indicator in node_item.as_string()
447                for indicator in training_indicators)
448            for node_item in node.nodes_of_class((nodes.Call, nodes.Assign))
449        )
450
451        return has_training_imports or has_training_usage
452
453    def generate_report(self) -> str:
454        """Generate a formatted report of all detected code smells.
455
456        Returns:
457            A string containing the formatted report with all detected smells
458            and their counts.
459        """
460        report = "Hugging Face Code Smell Report\n==============================\n\n"
461        smell_counts = {}
462        for i, smell in enumerate(self.smells, 1):
463            if smell['smell'] not in smell_counts:
464                smell_counts[smell['smell']] = 0
465            smell_counts[smell['smell']] += 1
466
467            report += f"{i}. Smell: {smell['smell']}\n"
468            report += f"   File: {smell['file_path']}\n"
469
470            # Only show line number if it's not 0
471            if smell['line_number'] != 0:
472                report += f"   Line: {smell['line_number']}\n"
473
474            # Only include code snippet if it's 3 lines or fewer
475            code_lines = smell['code_snippet'].strip().split('\n')
476            if len(code_lines) <= 3:
477                report += f"   Code Snippet:\n{smell['code_snippet']}\n"
478
479            report += f"   How to Fix: {smell['how_to_fix']}\n"
480            report += f"   Benefits: {smell['benefits']}\n"
481            report += f"   Strategies: {smell['strategies']}\n\n"
482
483        report += "Smell Counts:\n"
484        for smell, count in smell_counts.items():
485            report += f"  {smell}: {count}\n"
486        report += f"\nTotal smells detected: {len(self.smells)}"
487        return report
488
489    def get_results(self) -> List[Dict[str, str]]:
490        """Get the detected smells in a simplified format.
491
492        Returns:
493            List of dictionaries containing smell details in a simplified format
494            suitable for integration with other tools.
495        """
496        return [
497            {
498                'framework': 'Hugging Face',
499                'name': smell['smell'],
500                'fix': smell['how_to_fix'],
501                'benefits': smell['benefits'],
502                'location': f"Line {smell['line_number']}" if smell['line_number'] != 0 else ""
503            }
504            for smell in self.smells
505        ]