import os
import sys
import tokenize
from typing import Any, Dict, List

import astroid
from astroid import nodes
7
8
class HuggingFaceSmellDetector:
    """A detector class that identifies common code smells in Hugging Face Transformers code.

    This detector analyzes Python code that uses the Hugging Face Transformers library and identifies
    potential issues and best practices violations related to model training, data processing,
    and performance optimization.

    Detected smells accumulate in ``self.smells`` across calls to ``detect_smells``, so one
    instance can analyze several files and then produce a single combined report.
    """

    # Substrings of the callee expression that identify model / tokenizer ``from_pretrained`` calls.
    _MODEL_MARKERS = ('AutoModel', 'PreTrainedModel')
    _TOKENIZER_MARKERS = ('AutoTokenizer', 'PreTrainedTokenizer')

    # Keyword arguments whose presence on a ``from_pretrained`` call counts as "caching configured".
    _CACHE_KEYWORDS = ('cache_dir', 'local_files_only')

    # Keyword arguments whose presence on a tokenizer ``from_pretrained`` call counts as
    # "tokenization settings specified".
    # NOTE(review): truncation/padding/max_length/return_tensors are normally passed when calling
    # the tokenizer rather than to from_pretrained; this list may be worth revisiting.
    _DETERMINISTIC_TOKENIZER_KEYWORDS = (
        'do_lower_case', 'strip_accents', 'truncation',
        'padding', 'max_length', 'return_tensors',
    )

    # TrainingArguments keywords that indicate distributed training is configured.
    # 'n_gpu' and 'distributed_training' are kept for backward compatibility of detection.
    _DISTRIBUTED_KEYWORDS = (
        'local_rank', 'n_gpu', 'distributed_training', 'tpu_num_cores',
        'ddp_backend', 'deepspeed', 'fsdp',
    )

    # TrainingArguments keywords that indicate a learning-rate scheduler is configured.
    _LR_SCHEDULER_KEYWORDS = ('learning_rate_scheduler', 'lr_scheduler_type')

    # Substrings whose presence in imports, calls or assignments marks the module as training code.
    _TRAINING_INDICATORS = (
        'Trainer',
        'TrainingArguments',
        '.train(',
        'optimizer',
        'train_dataset',
        'eval_dataset',
    )

    def __init__(self):
        self.smells: List[Dict[str, Any]] = []

    def detect_smells(self, file_path: str) -> List[Dict[str, Any]]:
        """Analyze a Python file for Hugging Face-related code smells.

        The checks only run when the file imports ``transformers``. Read errors, parse
        errors and unexpected errors are reported on stderr instead of being raised.

        Args:
            file_path: Path to the Python file to analyze.

        Returns:
            The detector's accumulated list of smell dictionaries, including smells
            found by earlier calls on this instance.
        """
        try:
            # tokenize.open honours PEP 263 coding cookies and a UTF-8 BOM (defaulting to
            # UTF-8) instead of depending on the platform's locale encoding.
            with tokenize.open(file_path) as file:
                content = file.read()
            module_name = os.path.splitext(os.path.basename(file_path))[0]
            module = astroid.parse(content, module_name=module_name)

            # Check if 'transformers' is imported
            if self.is_framework_used(module, 'transformers'):
                self.visit_module(module, file_path)
            else:
                print(
                    f"Skipping Hugging Face smell detection for {file_path}: 'transformers' not imported",
                    file=sys.stderr)
        except astroid.exceptions.AstroidSyntaxError as e:
            print(f"Error parsing {file_path}: {e}", file=sys.stderr)
        except (OSError, SyntaxError, UnicodeDecodeError) as e:
            # Missing/unreadable file, invalid coding cookie, or undecodable bytes.
            print(f"Error reading {file_path}: {e}", file=sys.stderr)
        except Exception as e:  # top-level boundary: report and let the caller continue
            print(f"Unexpected error while processing {file_path}: {e}", file=sys.stderr)
        return self.smells

    def is_framework_used(self, node: nodes.Module, framework: str) -> bool:
        """Check if a specific framework is imported anywhere in the module.

        Both ``import pkg[.sub]`` and absolute ``from pkg[.sub] import x`` forms are
        recognised. Relative imports (``from .pkg import x``) are ignored because
        they refer to local modules.

        Args:
            node: AST node representing the module
            framework: Top-level package name of the framework to check for

        Returns:
            True if the framework is imported, False otherwise
        """
        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
            if isinstance(import_node, nodes.Import):
                if any(name.split('.')[0] == framework for name, _ in import_node.names):
                    return True
            elif not import_node.level and import_node.modname.split('.')[0] == framework:
                return True
        return False

    def visit_module(self, node: nodes.Module, file_path: str):
        """Visit a module node and run all smell detection checks.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        self.check_model_versioning(node, file_path)
        self.check_tokenizer_caching(node, file_path)
        self.check_model_caching(node, file_path)
        self.check_deterministic_tokenization(node, file_path)
        self.check_efficient_data_loading(node, file_path)
        self.check_distributed_training(node, file_path)
        self.check_mixed_precision_training(node, file_path)
        self.check_gradient_accumulation(node, file_path)
        self.check_learning_rate_scheduling(node, file_path)
        self.check_early_stopping(node, file_path)

    def add_smell(self, smell: str, fix: str, benefits: str, strategies: str, node: nodes.NodeNG, file_path: str):
        """Add a detected code smell to the results.

        Args:
            smell: Description of the code smell
            fix: How to fix the issue
            benefits: Benefits of fixing the issue
            strategies: Specific strategies to implement the fix
            node: AST node where the smell was detected (the module node for file-wide smells)
            file_path: Path to the file containing the smell
        """
        self.smells.append({
            "smell": smell,
            "how_to_fix": fix,
            "benefits": benefits,
            "strategies": strategies,
            # Nodes without a line number are recorded as 0, which the report treats as "no line".
            "line_number": node.lineno or 0,
            "code_snippet": node.as_string(),
            "file_path": file_path,
        })

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _callee_name(call: nodes.Call) -> str:
        """Return the bare callee name of *call* (``foo`` for ``foo(...)`` or ``a.b.foo(...)``)."""
        func = call.func
        if isinstance(func, nodes.Name):
            return func.name
        if isinstance(func, nodes.Attribute):
            return func.attrname
        return ""

    @staticmethod
    def _keywords(call: nodes.Call) -> Dict[str, nodes.NodeNG]:
        """Map the explicit keyword names of *call* to their value nodes (``**kwargs`` is skipped)."""
        return {kw.arg: kw.value for kw in (call.keywords or []) if kw.arg is not None}

    @staticmethod
    def _flag_enabled(value: nodes.NodeNG) -> bool:
        """Return False only for a literal falsy constant; non-literals cannot be evaluated statically."""
        if isinstance(value, nodes.Const):
            return bool(value.value)
        return True

    @staticmethod
    def _accumulates_gradients(value: nodes.NodeNG) -> bool:
        """Return True if a gradient_accumulation_steps value is > 1, or is not a literal."""
        if isinstance(value, nodes.Const):
            steps = value.value
            return isinstance(steps, (int, float)) and steps > 1
        return True

    def _from_pretrained_calls(self, node: nodes.Module, markers: tuple) -> List[nodes.Call]:
        """Return ``from_pretrained`` calls whose callee expression contains one of *markers*."""
        calls = []
        for call in node.nodes_of_class(nodes.Call):
            func_source = call.func.as_string()
            if 'from_pretrained' in func_source and any(marker in func_source for marker in markers):
                calls.append(call)
        return calls

    def _training_arguments_calls(self, node: nodes.Module) -> List[nodes.Call]:
        """Return every call to a ``*TrainingArguments`` class (e.g. Seq2SeqTrainingArguments)."""
        return [
            call for call in node.nodes_of_class(nodes.Call)
            if self._callee_name(call).endswith('TrainingArguments')
        ]

    # ------------------------------------------------------------------ checks

    def check_model_versioning(self, node: nodes.Module, file_path: str):
        """Check if model versions are pinned when loading pre-trained models.

        Flags model ``from_pretrained`` calls that do not pass a ``revision`` keyword.
        Unpinned loads can silently pick up a different checkpoint and hurt reproducibility.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        for call in self._from_pretrained_calls(node, self._MODEL_MARKERS):
            if 'revision' not in self._keywords(call):
                self.add_smell(
                    "Model versioning not specified",
                    "Specify model version when loading pre-trained models",
                    "Ensures consistency and reproducibility of results",
                    "Pass revision=<tag, branch name or commit hash> to from_pretrained "
                    "(e.g., revision='v1.0')",
                    call,
                    file_path
                )

    def check_tokenizer_caching(self, node: nodes.Module, file_path: str):
        """Check if tokenizer caching is configured when loading tokenizers.

        Flags tokenizer ``from_pretrained`` calls that pass neither ``cache_dir`` nor
        ``local_files_only``.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        for call in self._from_pretrained_calls(node, self._TOKENIZER_MARKERS):
            keywords = self._keywords(call)
            if not any(name in keywords for name in self._CACHE_KEYWORDS):
                self.add_smell(
                    "Tokenizer caching not used",
                    "Cache tokenizers to avoid re-downloading",
                    "Reduces loading time and network dependency",
                    "Use cache_dir parameter when loading tokenizers",
                    call,
                    file_path
                )

    def check_model_caching(self, node: nodes.Module, file_path: str):
        """Check if model caching is configured when loading models.

        Flags model ``from_pretrained`` calls that pass neither ``cache_dir`` nor
        ``local_files_only``.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        for call in self._from_pretrained_calls(node, self._MODEL_MARKERS):
            keywords = self._keywords(call)
            if not any(name in keywords for name in self._CACHE_KEYWORDS):
                self.add_smell(
                    "Model caching not used",
                    "Cache models to avoid re-downloading",
                    "Improves loading efficiency and reduces network dependency",
                    "Use cache_dir parameter when loading models",
                    call,
                    file_path
                )

    def check_deterministic_tokenization(self, node: nodes.Module, file_path: str):
        """Check if tokenization parameters are explicitly specified.

        Flags tokenizer ``from_pretrained`` calls that pass none of the keywords in
        ``_DETERMINISTIC_TOKENIZER_KEYWORDS``.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        for call in self._from_pretrained_calls(node, self._TOKENIZER_MARKERS):
            keywords = self._keywords(call)
            if not any(name in keywords for name in self._DETERMINISTIC_TOKENIZER_KEYWORDS):
                self.add_smell(
                    "Deterministic tokenization settings not specified",
                    "Use consistent tokenization settings",
                    "Ensures reproducible pre-processing and consistent model inputs",
                    "Set tokenization parameters explicitly when loading tokenizers",
                    call,
                    file_path
                )

    def check_efficient_data_loading(self, node: nodes.Module, file_path: str):
        """Check if efficient data loading techniques are being used.

        The module passes if it imports the ``datasets`` package, or if it calls anything
        matching ``load_dataset``, ``Dataset.from_*``, ``DataLoader`` or ``IterableDataset``.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        datasets_imported = self.is_framework_used(node, 'datasets')

        efficient_patterns = (
            'load_dataset',
            'Dataset.from_',
            'DataLoader',
            'IterableDataset',
        )

        has_efficient_loading = any(
            pattern in call.func.as_string()
            for call in node.nodes_of_class(nodes.Call)
            for pattern in efficient_patterns
        )

        if not (datasets_imported or has_efficient_loading):
            self.add_smell(
                "Efficient data loading not detected",
                "Use efficient data loading techniques",
                "Enhances data processing speed and model training efficiency",
                "Use datasets library for loading and processing data",
                node,
                file_path
            )

    def check_distributed_training(self, node: nodes.Module, file_path: str):
        """Check if distributed training is configured when TrainingArguments are built.

        Reports once per module when at least one ``*TrainingArguments(...)`` call exists
        and none of them passes a keyword from ``_DISTRIBUTED_KEYWORDS``.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        training_calls = self._training_arguments_calls(node)
        if not training_calls:
            return  # No TrainingArguments: nothing to configure.

        distributed_config = any(
            name in self._keywords(call)
            for call in training_calls
            for name in self._DISTRIBUTED_KEYWORDS
        )

        if not distributed_config:
            self.add_smell(
                "Distributed training not configured",
                "Utilize distributed training capabilities",
                "Speeds up training and leverages multiple GPUs/TPUs",
                "Configure TrainingArguments with distributed settings such as local_rank, "
                "ddp_backend, deepspeed, fsdp, or tpu_num_cores",
                node,
                file_path
            )

    def _has_training_arguments(self, node: nodes.Module) -> bool:
        """Return True if the module calls a ``*TrainingArguments`` class anywhere."""
        return any(
            self._callee_name(call).endswith('TrainingArguments')
            for call in node.nodes_of_class(nodes.Call)
        )

    def check_mixed_precision_training(self, node: nodes.Module, file_path: str):
        """Check if mixed precision training is enabled.

        A TrainingArguments call counts as enabling it when it passes ``fp16`` or ``bf16``
        with a value that is not a literal falsy constant, or passes ``half_precision_backend``.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        if not self._has_training_arguments(node):
            return  # Skip if no TrainingArguments used

        mixed_precision_used = False
        for call in self._training_arguments_calls(node):
            keywords = self._keywords(call)
            if (('fp16' in keywords and self._flag_enabled(keywords['fp16'])) or
                    ('bf16' in keywords and self._flag_enabled(keywords['bf16'])) or
                    'half_precision_backend' in keywords):
                mixed_precision_used = True
                break

        if not mixed_precision_used:
            self.add_smell(
                "Mixed precision training not enabled",
                "Use mixed precision training to improve performance",
                "Accelerates training and reduces memory usage",
                "Enable mixed precision training using fp16=True or bf16=True in TrainingArguments",
                node,
                file_path
            )

    def check_gradient_accumulation(self, node: nodes.Module, file_path: str):
        """Check if gradient accumulation is configured for training.

        A TrainingArguments call counts as configuring it when ``gradient_accumulation_steps``
        is a literal greater than 1 or a non-literal expression.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        # Skip if no TrainingArguments used (a TrainingArguments call also implies training code).
        if not self._has_training_arguments(node):
            return

        gradient_accumulation = any(
            'gradient_accumulation_steps' in keywords
            and self._accumulates_gradients(keywords['gradient_accumulation_steps'])
            for keywords in map(self._keywords, self._training_arguments_calls(node))
        )

        if not gradient_accumulation:
            self.add_smell(
                "Gradient accumulation not configured",
                "Implement gradient accumulation for large batch sizes",
                "Allows training with larger effective batch sizes and improves convergence",
                "Set gradient_accumulation_steps in Trainer configuration",
                node,
                file_path
            )

    def check_learning_rate_scheduling(self, node: nodes.Module, file_path: str):
        """Check if learning rate scheduling is configured.

        The module passes when a TrainingArguments call sets a scheduler keyword. It also
        passes when the module calls a transformers scheduler factory (``get_scheduler`` /
        ``get_*_schedule*``) or something under an ``lr_scheduler`` namespace.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        # Skip if no TrainingArguments used (a TrainingArguments call also implies training code).
        if not self._has_training_arguments(node):
            return

        lr_scheduler_used = any(
            name in self._keywords(call)
            for call in self._training_arguments_calls(node)
            for name in self._LR_SCHEDULER_KEYWORDS
        )

        if not lr_scheduler_used:
            for call in node.nodes_of_class(nodes.Call):
                callee = self._callee_name(call)
                if (callee == 'get_scheduler' or
                        (callee.startswith('get_') and '_schedule' in callee) or
                        'lr_scheduler' in call.func.as_string()):
                    lr_scheduler_used = True
                    break

        if not lr_scheduler_used:
            self.add_smell(
                "Learning rate scheduler not detected",
                "Use learning rate schedulers to dynamically adjust learning rate",
                "Optimizes training process and enhances model performance",
                "Configure lr_scheduler_type in TrainingArguments or use transformers built-in schedulers",
                node,
                file_path
            )

    def check_early_stopping(self, node: nodes.Module, file_path: str):
        """Check if early stopping is implemented in training.

        The module passes if it calls ``EarlyStoppingCallback`` or if any TrainingArguments
        call passes a keyword starting with ``early_stopping_``.

        Args:
            node: AST node representing the module
            file_path: Path to the file being analyzed
        """
        # Skip if no training-related code is present
        if not self._has_training_code(node):
            return

        early_stopping_used = any(
            'EarlyStoppingCallback' in call.func.as_string()
            for call in node.nodes_of_class(nodes.Call)
        ) or any(
            name.startswith('early_stopping_')
            for call in self._training_arguments_calls(node)
            for name in self._keywords(call)
        )

        if not early_stopping_used:
            self.add_smell(
                "Early stopping not implemented",
                "Implement early stopping to avoid overfitting",
                "Prevents overfitting and reduces unnecessary training time",
                "Use EarlyStoppingCallback or configure early_stopping parameters in TrainingArguments",
                node,
                file_path
            )

    def _has_training_code(self, node: nodes.Module) -> bool:
        """Return True if any import name, call or assignment contains a training indicator substring."""
        indicators = self._TRAINING_INDICATORS

        # Check imports
        for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
            if any(indicator in name for name, _ in import_node.names for indicator in indicators):
                return True

        # Check function calls and assignments
        return any(
            any(indicator in item.as_string() for indicator in indicators)
            for item in node.nodes_of_class((nodes.Call, nodes.Assign))
        )

    # ------------------------------------------------------------------ reporting

    def generate_report(self) -> str:
        """Generate a formatted report of all detected code smells.

        Line numbers of 0 are omitted, and code snippets longer than three lines are
        left out to keep file-wide smells readable.

        Returns:
            A string containing the formatted report with all detected smells
            and their counts.
        """
        parts = ["Hugging Face Code Smell Report\n==============================\n\n"]
        smell_counts: Dict[str, int] = {}
        for i, smell in enumerate(self.smells, 1):
            smell_counts[smell['smell']] = smell_counts.get(smell['smell'], 0) + 1

            parts.append(f"{i}. Smell: {smell['smell']}\n")
            parts.append(f" File: {smell['file_path']}\n")

            # Only show line number if it's not 0
            if smell['line_number'] != 0:
                parts.append(f" Line: {smell['line_number']}\n")

            # Only include code snippet if it's 3 lines or fewer
            code_lines = smell['code_snippet'].strip().split('\n')
            if len(code_lines) <= 3:
                parts.append(f" Code Snippet:\n{smell['code_snippet']}\n")

            parts.append(f" How to Fix: {smell['how_to_fix']}\n")
            parts.append(f" Benefits: {smell['benefits']}\n")
            parts.append(f" Strategies: {smell['strategies']}\n\n")

        parts.append("Smell Counts:\n")
        for name, count in smell_counts.items():
            parts.append(f" {name}: {count}\n")
        parts.append(f"\nTotal smells detected: {len(self.smells)}")
        return "".join(parts)

    def get_results(self) -> List[Dict[str, str]]:
        """Get the detected smells in a simplified format.

        Returns:
            List of dictionaries containing smell details in a simplified format
            suitable for integration with other tools. ``location`` is empty for
            smells recorded with line number 0.
        """
        return [
            {
                'framework': 'Hugging Face',
                'name': smell['smell'],
                'fix': smell['how_to_fix'],
                'benefits': smell['benefits'],
                'location': f"Line {smell['line_number']}" if smell['line_number'] != 0 else ""
            }
            for smell in self.smells
        ]