import os
import sys
import tokenize
from typing import Any, Dict, List

import astroid
from astroid import nodes
7
8
9class ML_SmellDetector:
10 """A detector for identifying common code smells in machine learning code.
11
12 This class analyzes Python source code to detect potential issues and anti-patterns
13 commonly found in machine learning projects, such as data leakage, improper feature
14 scaling, missing cross-validation, and more.
15 """
16
17 def __init__(self):
18 """Initialize the ML smell detector with empty collections for tracking analysis results."""
19 self.smells: List[Dict[str, Any]] = []
20 self.imports: Dict[str, Any] = {}
21 self.variables: Dict[str, Any] = {}
22 self.functions: Dict[str, Any] = {}
23 self.classes: Dict[str, Any] = {}
24
25 def add_smell(self, smell: str, node: nodes.NodeNG, file_path: str):
26 """Add a detected code smell to the collection.
27
28 Args:
29 smell: Description of the detected smell
30 node: The AST node where the smell was detected
31 file_path: Path to the file containing the smell
32 """
33 self.smells.append({
34 "smell": smell,
35 "line_number": node.lineno,
36 "code_snippet": node.as_string(),
37 "file_path": file_path
38 })
39
40 def detect_smells(self, file_path: str) -> List[Dict[str, Any]]:
41 """Analyze a Python file for ML-related code smells.
42
43 Args:
44 file_path: Path to the Python file to analyze
45
46 Returns:
47 List of dictionaries containing detected smell information
48 """
49 try:
50 with open(file_path, 'r') as file:
51 content = file.read()
52 module_name = os.path.splitext(os.path.basename(file_path))[0]
53 module = astroid.parse(content, module_name=module_name)
54
55 # Check if any ML-related packages are imported
56 ml_packages = ['pandas', 'numpy', 'sklearn', 'tensorflow', 'torch', 'transformers']
57 if any(self.is_package_used(module, package) for package in ml_packages):
58 self.visit_module(module, file_path)
59 else:
60 print(f"Skipping ML smell detection for {file_path}: No ML-related packages imported", file=sys.stderr)
61 except astroid.exceptions.AstroidSyntaxError as e:
62 print(f"Error parsing {file_path}: {str(e)}", file=sys.stderr)
63 except Exception as e:
64 print(f"Unexpected error while processing {file_path}: {str(e)}", file=sys.stderr)
65 return self.smells
66
67 def is_package_used(self, node: nodes.Module, package: str) -> bool:
68 """Check if a specific package is imported in the module.
69
70 Args:
71 node: AST node representing the module
72 package: Name of the package to check for
73
74 Returns:
75 True if the package is imported, False otherwise
76 """
77 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
78 if isinstance(import_node, nodes.Import):
79 if any(name.split('.')[0] == package for name, _ in import_node.names):
80 return True
81 elif isinstance(import_node, nodes.ImportFrom):
82 if import_node.modname.split('.')[0] == package:
83 return True
84 return False
85
86 def visit_module(self, node: nodes.Module, file_path: str):
87 """Run all smell detection checks on a module.
88
89 Args:
90 node: AST node representing the module
91 file_path: Path to the file being analyzed
92 """
93 self.check_imports(node, file_path)
94 self.check_data_leakage(node, file_path)
95 self.check_magic_numbers(node, file_path)
96 self.check_feature_scaling(node, file_path)
97 self.check_cross_validation(node, file_path)
98 self.check_imbalanced_dataset(node, file_path)
99 self.check_feature_selection(node, file_path)
100 self.check_metric_selection(node, file_path)
101 self.check_model_persistence(node, file_path)
102 self.check_reproducibility(node, file_path)
103 self.check_data_loading(node, file_path)
104 self.check_unused_features(node, file_path)
105 self.check_overfit_prone_practices(node, file_path)
106 self.check_error_handling(node, file_path)
107 self.check_hardcoded_filepaths(node, file_path)
108 self.check_documentation(node, file_path)
109
110 def check_imports(self, node: nodes.Module, file_path: str):
111 """Track imports used in the module for later analysis.
112
113 Args:
114 node: AST node representing the module
115 file_path: Path to the file being analyzed
116 """
117 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
118 if isinstance(import_node, nodes.Import):
119 for name, alias in import_node.names:
120 self.imports[alias or name] = name
121 elif isinstance(import_node, nodes.ImportFrom):
122 for name, alias in import_node.names:
123 self.imports[alias or name] = f"{import_node.modname}.{name}"
124
125 @staticmethod
126 def _is_inside_function(call_node):
127 """Return True if call_node is nested inside a FunctionDef."""
128 parent = call_node.parent
129 while parent:
130 if isinstance(parent, nodes.FunctionDef):
131 return True
132 if isinstance(parent, nodes.Module):
133 break
134 parent = parent.parent
135 return False
136
137 def _check_leakage_in_scope(self, scope_node, report_node, file_path):
138 """Check a single scope (module or function) for fit-before-split leakage."""
139 preprocessing_before_split = False
140 train_test_split_pos = -1
141 last_preprocessing_pos = -1
142
143 for call in scope_node.nodes_of_class(nodes.Call):
144 if call.func.as_string().endswith(('fit', 'fit_transform')):
145 preprocessing_before_split = True
146 last_preprocessing_pos = call.lineno
147 elif 'train_test_split' in call.func.as_string():
148 train_test_split_pos = call.lineno
149
150 if preprocessing_before_split and train_test_split_pos > 0:
151 if last_preprocessing_pos < train_test_split_pos:
152 self.add_smell(
153 "Potential data leakage: Preprocessing applied before train-test split",
154 report_node,
155 file_path)
156
157 def check_data_leakage(self, node: nodes.Module, file_path: str):
158 """Detect potential data leakage from preprocessing before train-test split.
159
160 Args:
161 node: AST node representing the module
162 file_path: Path to the file being analyzed
163 """
164 # Check inside each function/method (existing behaviour)
165 for func in node.nodes_of_class(nodes.FunctionDef):
166 self._check_leakage_in_scope(func, func, file_path)
167
168 # Also check module-level code that sits outside any function
169 preprocessing_before_split = False
170 train_test_split_pos = -1
171 last_preprocessing_pos = -1
172 for call in node.nodes_of_class(nodes.Call):
173 if self._is_inside_function(call):
174 continue
175 if call.func.as_string().endswith(('fit', 'fit_transform')):
176 preprocessing_before_split = True
177 last_preprocessing_pos = call.lineno
178 elif 'train_test_split' in call.func.as_string():
179 train_test_split_pos = call.lineno
180 if preprocessing_before_split and train_test_split_pos > 0:
181 if last_preprocessing_pos < train_test_split_pos:
182 self.add_smell(
183 "Potential data leakage: Preprocessing applied before train-test split "
184 "(module-level code)",
185 node,
186 file_path)
187
188 def check_magic_numbers(self, node: nodes.Module, file_path: str):
189 """Detect magic numbers in ML-related code.
190
191 Args:
192 node: AST node representing the module
193 file_path: Path to the file being analyzed
194 """
195 # Common acceptable values that shouldn't trigger warnings
196 acceptable_values = {0, 1, -1, 100, 0.5, 2} # Common ML-related constants
197 for assign in node.nodes_of_class(nodes.Assign):
198 if isinstance(assign.value, nodes.Const) and isinstance(assign.value.value, (int, float)):
199 # Skip if it's an acceptable value
200 if assign.value.value in acceptable_values:
201 continue
202 # Skip if the variable name suggests it's a legitimate constant
203 if any(assign.targets[0].as_string().lower().startswith(prefix) for prefix in
204 ['num_', 'size_', 'batch_', 'epoch', 'learning_rate', 'lr_', 'threshold_']):
205 continue
206 self.add_smell(f"Magic number detected: {assign.value.value}", assign, file_path)
207
208 def check_feature_scaling(self, node: nodes.Module, file_path: str):
209 """Detect inconsistent feature scaling methods across the codebase.
210
211 Looks for multiple different scaling methods (StandardScaler, MinMaxScaler, etc.)
212 being used, which could lead to inconsistent results.
213
214 Args:
215 node: AST node representing the module
216 file_path: Path to the file being analyzed
217 """
218 scaling_methods = ['StandardScaler', 'MinMaxScaler', 'RobustScaler']
219 scalers_used = set()
220
221 for call in node.nodes_of_class(nodes.Call):
222 if any(method in call.func.as_string() for method in scaling_methods):
223 scalers_used.add(call.func.as_string())
224
225 # Only raise warning if multiple different scaling methods are used
226 if len(scalers_used) > 1:
227 scalers = ', '.join(scalers_used)
228 self.add_smell(
229 f"Inconsistent scaling methods detected: {scalers}. Consider using the same scaler across the pipeline.",
230 call,
231 file_path)
232
233 def check_cross_validation(self, node: nodes.Module, file_path: str):
234 """Check if cross-validation is properly implemented in model training.
235
236 Detects if common cross-validation methods (KFold, cross_val_score, etc.)
237 are missing in model training code.
238
239 Args:
240 node: AST node representing the module
241 file_path: Path to the file being analyzed
242 """
243 cv_methods = [
244 'cross_val_score',
245 'KFold',
246 'cross_validate',
247 'GridSearchCV',
248 'RandomizedSearchCV',
249 'TimeSeriesSplit']
250 cv_detected = False
251 is_training_file = False
252
253 # Skip if this is likely not a main training file
254 if any(pattern in file_path.lower() for pattern in [
255 'test_', 'utils', 'helper', 'preprocessing', 'visualization',
256 'evaluate', 'predict', 'inference', 'deploy'
257 ]):
258 return
259
260 # Check imports to see if this is likely a training file
261 training_imports = {'sklearn.model_selection', 'sklearn.linear_model', 'sklearn.ensemble',
262 'tensorflow', 'torch', 'xgboost', 'lightgbm'}
263 has_training_imports = False
264 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
265 if isinstance(import_node, nodes.ImportFrom):
266 if any(imp in import_node.modname for imp in training_imports):
267 has_training_imports = True
268 break
269
270 if not has_training_imports:
271 return
272
273 for call in node.nodes_of_class(nodes.Call):
274 # Check if file contains model training code
275 if any(method in call.func.as_string() for method in ['fit', 'train', 'compile']):
276 # Skip if it's in a test method
277 if any(ancestor.name.startswith('test_') for ancestor in call.node_ancestors()
278 if isinstance(ancestor, nodes.FunctionDef)):
279 continue
280 is_training_file = True
281
282 # Check for CV methods
283 if any(method in call.func.as_string() for method in cv_methods):
284 cv_detected = True
285 break
286
287 # Also check for custom CV implementations
288 if ('split' in call.func.as_string() and
289 any(val in call.as_string() for val in ['fold', 'cv', 'validation'])):
290 cv_detected = True
291 break
292
293 # Only raise warning if it's a substantial training file (has multiple model-related calls)
294 model_related_calls = sum(1 for call in node.nodes_of_class(nodes.Call)
295 if any(term in call.func.as_string()
296 for term in ['fit', 'train', 'predict', 'score']))
297
298 if (is_training_file and not cv_detected and model_related_calls >= 2
299 and not any(term in file_path.lower() for term in ['quick', 'example', 'demo'])):
300 self.add_smell(
301 "Cross-validation not detected in model training code. "
302 "Consider using cross-validation for more robust evaluation.",
303 node, file_path
304 )
305
306 def check_imbalanced_dataset(self, node: nodes.Module, file_path: str):
307 """Check if imbalanced dataset handling techniques are used in classification tasks.
308
309 Looks for common techniques like SMOTE, class weights, or stratification
310 when dealing with classification problems.
311
312 Args:
313 node: AST node representing the module
314 file_path: Path to the file being analyzed
315 """
316 # Skip if this is likely not a main training file
317 if any(pattern in file_path.lower() for pattern in [
318 'test_', 'utils', 'helper', 'visualization', 'evaluate',
319 'predict', 'inference', 'deploy', 'preprocess'
320 ]):
321 return
322
323 balance_methods = [
324 'SMOTE', 'class_weight', 'StratifiedKFold', 'RandomOverSampler',
325 'RandomUnderSampler', 'sample_weight', 'balanced_accuracy',
326 'WeightedRandomSampler', 'BalancedBaggingClassifier'
327 ]
328 imbalance_handling = False
329 is_classification = False
330 has_data_processing = False
331
332 # Check imports first
333 classification_imports = {
334 'sklearn.linear_model', 'sklearn.ensemble', 'sklearn.svm',
335 'sklearn.tree', 'sklearn.naive_bayes', 'xgboost', 'lightgbm'
336 }
337 has_classification_imports = False
338 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
339 if isinstance(import_node, nodes.ImportFrom):
340 if any(imp in import_node.modname for imp in classification_imports):
341 has_classification_imports = True
342 break
343
344 if not has_classification_imports:
345 return
346
347 # Check if this is a classification task
348 for call in node.nodes_of_class(nodes.Call):
349 if any(clf in call.func.as_string() for clf in [
350 'Classifier', 'LogisticRegression', 'SVC', 'RandomForestClassifier',
351 'GradientBoostingClassifier', 'XGBClassifier', 'LGBMClassifier',
352 'DecisionTreeClassifier', 'KNeighborsClassifier'
353 ]):
354 is_classification = True
355
356 # Check for data processing/analysis that might indicate class distribution checks
357 if any(term in call.func.as_string() for term in [
358 'value_counts', 'unique', 'hist', 'countplot', 'distribution',
359 'balance_ratio', 'class_distribution'
360 ]):
361 has_data_processing = True
362
363 # Check for imbalance handling methods
364 if any(method in call.func.as_string() for method in balance_methods):
365 imbalance_handling = True
366 break
367
368 # Check for custom handling in strings (e.g., parameter names)
369 if isinstance(call.func, nodes.Attribute):
370 if any(term in str(call.args) + str(call.keywords) for term in [
371 'weight', 'balanced', 'stratif'
372 ]):
373 imbalance_handling = True
374 break
375
376 # Count model-related calls to ensure it's a substantial training file
377 model_related_calls = sum(1 for call in node.nodes_of_class(nodes.Call)
378 if any(term in call.func.as_string()
379 for term in ['fit', 'train', 'predict', 'score']))
380
381 # Only raise warning if:
382 # 1. It's a classification task
383 # 2. No imbalance handling detected
384 # 3. Has multiple model-related calls
385 # 4. Has data processing (suggesting actual data analysis)
386 # 5. Not a quick example/demo
387 if (is_classification and not imbalance_handling and
388 model_related_calls >= 2 and has_data_processing and
389 not any(term in file_path.lower() for term in ['quick', 'example', 'demo'])):
390
391 self.add_smell(
392 "No imbalanced dataset handling detected in classification task. "
393 "Consider techniques like SMOTE or class weights if dealing with imbalanced data.",
394 node, file_path
395 )
396
397 def check_feature_selection(self, node: nodes.Module, file_path: str):
398 """Detect feature selection practices and validate their implementation.
399
400 Ensures feature selection is performed with proper validation strategy
401 to avoid selection bias.
402
403 Args:
404 node: AST node representing the module
405 file_path: Path to the file being analyzed
406 """
407 # Skip if this is likely not a feature selection file
408 if any(pattern in file_path.lower() for pattern in [
409 'test_', 'utils', 'helper', 'visualization',
410 'predict', 'inference', 'deploy', 'evaluate'
411 ]):
412 return
413
414 feature_selection_methods = [
415 'SelectKBest', 'RFE', 'SelectFromModel', 'PCA', 'VarianceThreshold',
416 'mutual_info', 'chi2', 'f_classif', 'SelectPercentile',
417 'GenericUnivariateSelect', 'RFECV', 'SequentialFeatureSelector'
418 ]
419 validation_methods = [
420 'cross_val_score', 'train_test_split', 'GridSearchCV',
421 'RandomizedSearchCV', 'KFold', 'StratifiedKFold',
422 'TimeSeriesSplit', 'cross_validate'
423 ]
424 feature_selection = False
425 validation_detected = False
426 has_ml_imports = False
427
428 # Check for relevant imports first
429 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
430 if isinstance(import_node, nodes.ImportFrom):
431 if any(pkg in import_node.modname for pkg in [
432 'sklearn.feature_selection', 'sklearn.decomposition',
433 'sklearn.model_selection'
434 ]):
435 has_ml_imports = True
436 break
437
438 if not has_ml_imports:
439 return
440
441 for call in node.nodes_of_class(nodes.Call):
442 # Check for feature selection methods
443 if any(method in call.func.as_string() for method in feature_selection_methods):
444 feature_selection = True
445
446 # Check for validation methods
447 if any(method in call.func.as_string() for method in validation_methods):
448 validation_detected = True
449
450 # Check for custom validation in parameter names or strings
451 if isinstance(call.func, nodes.Attribute):
452 if any(term in str(call.args) + str(call.keywords) for term in
453 ['valid', 'test', 'split', 'cv', 'fold']):
454 validation_detected = True
455
456 # Count substantial ML operations
457 ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
458 if any(term in call.func.as_string()
459 for term in ['fit', 'transform', 'predict', 'score']))
460
461 if (feature_selection and not validation_detected and
462 ml_operations >= 2 and
463 not any(term in file_path.lower() for term in ['quick', 'example', 'demo'])):
464 self.add_smell(
465 "Feature selection detected without clear validation strategy. "
466 "Ensure it's applied with proper validation to avoid selection bias.",
467 node, file_path
468 )
469
470 def check_metric_selection(self, node: nodes.Module, file_path: str):
471 """Validate the choice of evaluation metrics for ML models.
472
473 Ensures appropriate metrics are used for classification and regression tasks,
474 warning against using only basic metrics like accuracy.
475
476 Args:
477 node: AST node representing the module
478 file_path: Path to the file being analyzed
479 """
480 # Skip if this is likely not an evaluation file
481 if any(pattern in file_path.lower() for pattern in [
482 'test_', 'utils', 'helper', 'preprocess',
483 'data', 'feature', 'transform'
484 ]):
485 return
486
487 metrics = set()
488 is_classification = False
489 is_regression = False
490 has_ml_imports = False
491
492 # Check imports first
493 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
494 if isinstance(import_node, nodes.ImportFrom):
495 if any(pkg in import_node.modname for pkg in [
496 'sklearn.metrics', 'sklearn.model_selection',
497 'sklearn.linear_model', 'sklearn.ensemble'
498 ]):
499 has_ml_imports = True
500 break
501
502 if not has_ml_imports:
503 return
504
505 # Determine if it's classification or regression
506 for call in node.nodes_of_class(nodes.Call):
507 if any(clf in call.func.as_string() for clf in [
508 'Classifier', 'LogisticRegression', 'SVC', 'RandomForestClassifier',
509 'GradientBoostingClassifier', 'XGBClassifier', 'LGBMClassifier'
510 ]):
511 is_classification = True
512 elif any(reg in call.func.as_string() for reg in [
513 'Regressor', 'LinearRegression', 'SVR', 'RandomForestRegressor',
514 'GradientBoostingRegressor', 'XGBRegressor', 'LGBMRegressor'
515 ]):
516 is_regression = True
517
518 # Collect metrics
519 if any(metric in call.func.as_string() for metric in [
520 'accuracy_score', 'precision_score', 'recall_score', 'f1_score',
521 'mean_squared_error', 'r2_score', 'mean_absolute_error',
522 'roc_auc_score', 'average_precision_score', 'confusion_matrix',
523 'classification_report', 'explained_variance_score'
524 ]):
525 metrics.add(call.func.as_string())
526
527 # Check for custom metric implementations
528 if isinstance(call.func, nodes.Attribute):
529 if any(term in str(call.args) + str(call.keywords) for term in [
530 'metric', 'score', 'evaluation', 'performance'
531 ]):
532 metrics.add('custom_metric')
533
534 # Count substantial ML operations
535 ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
536 if any(term in call.func.as_string()
537 for term in ['fit', 'predict', 'score', 'evaluate']))
538
539 # Only raise warnings for specific cases with substantial ML usage
540 if ml_operations >= 2 and not any(term in file_path.lower() for term in ['quick', 'example', 'demo']):
541 if is_classification and 'accuracy_score' in metrics and len(metrics) == 1:
542 self.add_smell(
543 "Only accuracy metric detected for classification. "
544 "Consider adding precision, recall, or F1-score for a more comprehensive evaluation.",
545 node, file_path
546 )
547 elif is_regression and 'mean_squared_error' in metrics and len(metrics) == 1:
548 self.add_smell(
549 "Only MSE detected for regression. "
550 "Consider adding R2 score or MAE for a more comprehensive evaluation.",
551 node, file_path
552 )
553
554 def check_model_persistence(self, node: nodes.Module, file_path: str):
555 """Check model saving practices and associated preprocessing steps.
556
557 Ensures models are saved with their preprocessing steps and proper versioning
558 for reproducibility.
559
560 Args:
561 node: AST node representing the module
562 file_path: Path to the file being analyzed
563 """
564 # Skip if this is likely not a model saving file
565 if any(pattern in file_path.lower() for pattern in [
566 'test_', 'utils', 'helper', 'data', 'preprocess',
567 'explore', 'analyze', 'visualize'
568 ]):
569 return
570
571 model_save = False
572 preprocessing_save = False
573 version_control = False
574 has_ml_imports = False
575
576 # Check for relevant imports first
577 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
578 if isinstance(import_node, nodes.ImportFrom):
579 if any(pkg in import_node.modname for pkg in [
580 'sklearn', 'tensorflow', 'torch', 'joblib', 'pickle'
581 ]):
582 has_ml_imports = True
583 break
584
585 if not has_ml_imports:
586 return
587
588 # Check for model training/fitting first
589 has_model_training = False
590 for call in node.nodes_of_class(nodes.Call):
591 if any(term in call.func.as_string() for term in ['fit', 'train']):
592 has_model_training = True
593 break
594
595 if not has_model_training:
596 return
597
598 for call in node.nodes_of_class(nodes.Call):
599 # Check for model saving operations
600 if any(save in call.func.as_string() for save in [
601 'save', 'dump', 'to_pickle', 'save_model', 'save_weights',
602 'torch.save', 'joblib.dump'
603 ]):
604 model_save = True
605
606 # Check for preprocessing steps being saved
607 if any(prep in call.as_string().lower() for prep in [
608 'scaler', 'encoder', 'preprocessor', 'pipeline',
609 'transform', 'processor', 'tokenizer'
610 ]):
611 preprocessing_save = True
612
613 # Check for version control
614 if any(ver in call.as_string().lower() for ver in [
615 'version', 'v1', 'v2', 'v3', 'timestamp', 'date',
616 '_v', '.v', 'release'
617 ]):
618 version_control = True
619
620 # Check for version control in variable names
621 if isinstance(call.func, nodes.Attribute):
622 if any(ver in str(call.args) + str(call.keywords) for ver in [
623 'version', 'timestamp', 'date', 'release'
624 ]):
625 version_control = True
626
627 # Only raise warnings if we have substantial model operations
628 ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
629 if any(term in call.func.as_string()
630 for term in ['fit', 'train', 'predict', 'transform']))
631
632 if ml_operations >= 2 and not any(term in file_path.lower() for term in ['quick', 'example', 'demo']):
633 if model_save and not preprocessing_save:
634 self.add_smell(
635 "Model saving detected without preprocessing steps. "
636 "Remember to save preprocessing steps for proper model deployment.",
637 node, file_path
638 )
639 elif model_save and not version_control:
640 self.add_smell(
641 "Model saving detected without clear versioning. "
642 "Consider adding version control for model artifacts.",
643 node, file_path
644 )
645
646 def check_reproducibility(self, node: nodes.Module, file_path: str):
647 """Check if random seeds are properly set for reproducibility.
648
649 Ensures random seeds are set for all relevant libraries (numpy, random,
650 framework-specific) in ML operations.
651
652 Args:
653 node: AST node representing the module
654 file_path: Path to the file being analyzed
655 """
656 # Skip if this is likely not a training file
657 if any(pattern in file_path.lower() for pattern in [
658 'test_', 'utils', 'helper', 'data', 'preprocess',
659 'explore', 'analyze', 'visualize', 'inference'
660 ]):
661 return
662
663 seed_methods = {
664 'random_state', 'seed', 'torch.manual_seed', 'tf.random.set_seed',
665 'np.random.seed', 'random.seed', 'cuda.manual_seed',
666 'cuda.manual_seed_all', 'tensorflow.random.set_seed'
667 }
668 seeds_set = set()
669 has_ml_operations = False
670 has_ml_imports = False
671
672 # Check imports first
673 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
674 if isinstance(import_node, nodes.ImportFrom):
675 if any(pkg in import_node.modname for pkg in [
676 'sklearn', 'tensorflow', 'torch', 'numpy', 'random'
677 ]):
678 has_ml_imports = True
679 break
680
681 if not has_ml_imports:
682 return
683
684 for call in node.nodes_of_class(nodes.Call):
685 # Check if file contains substantial ML operations
686 if any(op in call.func.as_string() for op in [
687 'fit', 'train', 'predict', 'transform', 'split',
688 'sample', 'shuffle', 'random'
689 ]):
690 has_ml_operations = True
691
692 # Check for seed setting
693 for seed_method in seed_methods:
694 if seed_method in call.as_string():
695 seeds_set.add(seed_method)
696
697 # Check for seed parameters in ML operations
698 if isinstance(call.func, nodes.Attribute):
699 if any(seed in str(call.args) + str(call.keywords) for seed in [
700 'random_state', 'seed', 'deterministic'
701 ]):
702 seeds_set.add('parameter_seed')
703
704 # Count substantial ML operations
705 ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
706 if any(term in call.func.as_string()
707 for term in ['fit', 'train', 'predict', 'transform']))
708
709 if (ml_operations >= 2 and has_ml_operations and
710 not any(term in file_path.lower() for term in ['quick', 'example', 'demo'])):
711 if not seeds_set:
712 self.add_smell(
713 "No random seed setting detected in ML operations. "
714 "Consider setting seeds for reproducibility.",
715 node, file_path
716 )
717 elif len(seeds_set) < 2 and any(framework in self.imports for framework in ['tensorflow', 'torch']):
718 self.add_smell(
719 "Incomplete seed setting detected. Remember to set seeds for all "
720 "relevant libraries (numpy, random, framework-specific).",
721 node, file_path
722 )
723
724 def check_data_loading(self, node: nodes.Module, file_path: str):
725 """Analyze data loading practices for potential issues.
726
727 Checks for proper handling of large datasets, including batch processing
728 and file size checks.
729
730 Args:
731 node: AST node representing the module
732 file_path: Path to the file being analyzed
733 """
734 # Skip if this is likely not a data loading file
735 if any(pattern in file_path.lower() for pattern in [
736 'test_', 'utils', 'helper', 'model', 'train',
737 'evaluate', 'predict', 'inference'
738 ]):
739 return
740
741 data_loading_methods = {
742 'read_csv', 'load_data', 'read_excel', 'read_parquet',
743 'read_json', 'read_sql', 'load_dataset'
744 }
745 batch_processing_methods = {
746 'batch_size', 'generator', 'DataLoader', 'dataset',
747 'chunk', 'iterator', 'yield', 'flow_from_directory'
748 }
749 memory_handling_methods = {
750 'dask', 'vaex', 'datatable', 'memory_limit',
751 'low_memory', 'nrows', 'usecols'
752 }
753
754 file_size_check = False
755 batch_processing = False
756 memory_handling = False
757 has_data_imports = False
758
759 # Check imports first
760 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
761 if isinstance(import_node, nodes.ImportFrom):
762 if any(pkg in import_node.modname for pkg in [
763 'pandas', 'dask', 'vaex', 'torch.utils.data',
764 'tensorflow.data', 'datasets'
765 ]):
766 has_data_imports = True
767 break
768
769 if not has_data_imports:
770 return
771
772 for call in node.nodes_of_class(nodes.Call):
773 if any(method in call.func.as_string() for method in data_loading_methods):
774 # Check for file size checks
775 if any(check in call.as_string() for check in [
776 'os.path.getsize', 'file_size', 'stat', 'memory_usage'
777 ]):
778 file_size_check = True
779
780 # Check for batch processing
781 if any(method in call.as_string() for method in batch_processing_methods):
782 batch_processing = True
783
784 # Check for memory handling
785 if any(method in call.as_string() for method in memory_handling_methods):
786 memory_handling = True
787
788 # Check for parameters indicating memory consideration
789 if isinstance(call.func, nodes.Attribute):
790 if any(param in str(call.args) + str(call.keywords) for param in [
791 'chunksize', 'batch_size', 'iterator', 'memory',
792 'nrows', 'usecols', 'dtype'
793 ]):
794 memory_handling = True
795
796 # Only raise warning if we have substantial data loading operations
797 data_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
798 if any(term in call.func.as_string()
799 for term in ['read_', 'load_', 'open']))
800
801 if (data_operations >= 2 and
802 not any(term in file_path.lower() for term in ['quick', 'example', 'demo']) and
803 not (file_size_check or batch_processing or memory_handling)):
804 self.add_smell(
805 "Data loading detected without size checks or batch processing. "
806 "Consider using generators or batch processing for large datasets.",
807 node, file_path
808 )
809
810 def check_unused_features(self, node: nodes.Module, file_path: str):
811 """Detect potentially unused features or variables in ML code.
812
813 Identifies variables that are defined but not used, excluding common
814 variable names and special prefixes.
815
816 Args:
817 node: AST node representing the module
818 file_path: Path to the file being analyzed
819 """
820 # Skip if this is likely not a main code file
821 if any(pattern in file_path.lower() for pattern in [
822 'test_', 'conftest', 'setup', '__init__',
823 'utils', 'helper', 'config', 'constants'
824 ]):
825 return
826
827 features = set()
828 used_features = set()
829 # Expanded list of common variables to ignore
830 common_vars = {
831 'self', 'i', 'j', 'k', 'x', 'y', 'X', 'y', 'df', 'data',
832 'model', 'clf', 'reg', 'pred', 'proba', 'score',
833 'train', 'test', 'val', 'valid', 'result', 'output',
834 'input', 'params', 'args', 'kwargs', 'config', 'options'
835 }
836
837 # Skip prefixes for variables that are commonly used in specific ways
838 skip_prefixes = {
839 '_', 'temp_', 'tmp_', 'test_', 'debug_', 'log_',
840 'cache_', 'old_', 'new_', 'raw_', 'processed_'
841 }
842
843 # Skip suffixes that indicate special usage
844 skip_suffixes = {
845 '_id', '_idx', '_index', '_key', '_val', '_list',
846 '_dict', '_map', '_set', '_array', '_df', '_series'
847 }
848
849 has_ml_imports = False
850 # Check for ML-related imports
851 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
852 if isinstance(import_node, nodes.ImportFrom):
853 if any(pkg in import_node.modname for pkg in [
854 'sklearn', 'tensorflow', 'torch', 'pandas', 'numpy'
855 ]):
856 has_ml_imports = True
857 break
858
859 if not has_ml_imports:
860 return
861
862 for assign in node.nodes_of_class(nodes.Assign):
863 if isinstance(assign.targets[0], nodes.Name):
864 name = assign.targets[0].name
865 # Skip if name matches any exclusion criteria
866 if (name not in common_vars and
867 not any(name.startswith(prefix) for prefix in skip_prefixes) and
868 not any(name.endswith(suffix) for suffix in skip_suffixes)):
869 features.add(name)
870
871 # Collect used features from various contexts
872 for name in node.nodes_of_class(nodes.Name):
873 used_features.add(name.name)
874
875 # Check for usage in attributes
876 for attr in node.nodes_of_class(nodes.Attribute):
877 if isinstance(attr.expr, nodes.Name):
878 used_features.add(attr.expr.name)
879
880 unused = features - used_features - common_vars
881 if unused:
882 # Additional checks for usage in various contexts
883 for node_type in [nodes.Const, nodes.Dict, nodes.List, nodes.Set]:
884 for item in node.nodes_of_class(node_type):
885 if isinstance(item.value, str):
886 unused = unused - {feat for feat in unused if feat in item.value}
887
888 # Check for usage in f-strings
889 for string in node.nodes_of_class(nodes.JoinedStr):
890 unused = unused - {feat for feat in unused if feat in string.as_string()}
891
892 # Only report if we have substantial ML operations
893 ml_operations = sum(1 for call in node.nodes_of_class(nodes.Call)
894 if any(term in call.func.as_string()
895 for term in ['fit', 'predict', 'transform']))
896
897 if unused and ml_operations >= 2:
898 self.add_smell(
899 f"Potentially unused features detected: {', '.join(unused)}. "
900 "Verify if these are actually needed.",
901 node, file_path
902 )
903
904 def check_overfit_prone_practices(self, node: nodes.Module, file_path: str):
905 """Detect practices that might lead to overfitting.
906
907 Checks feature engineering functions for proper train/test separation
908 to avoid data leakage.
909
910 Args:
911 node: AST node representing the module
912 file_path: Path to the file being analyzed
913 """
914 # Skip if this is likely not a feature engineering file
915 if any(pattern in file_path.lower() for pattern in [
916 'test_', 'utils', 'helper', 'config', 'constants',
917 'visualization', 'plotting', 'display', 'logging'
918 ]):
919 return
920
921 has_ml_imports = False
922 # Check for ML-related imports
923 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
924 if isinstance(import_node, nodes.ImportFrom):
925 if any(pkg in import_node.modname for pkg in [
926 'sklearn', 'pandas', 'numpy', 'tensorflow', 'torch'
927 ]):
928 has_ml_imports = True
929 break
930
931 if not has_ml_imports:
932 return
933
934 for func in node.nodes_of_class(nodes.FunctionDef):
935 # Check if it's a feature engineering function
936 if any(term in func.name.lower() for term in [
937 'feature', 'transform', 'process', 'engineer', 'prepare'
938 ]):
939 # Skip if it's clearly safe
940 if any(safe_term in func.name.lower() for safe_term in [
941 'train', 'fit', 'training_only', 'train_set',
942 'single', 'one', 'individual', 'row'
943 ]):
944 continue
945
946 # Check function body for proper data handling
947 has_safe_handling = False
948 for call in func.nodes_of_class(nodes.Call):
949 if any(term in call.as_string().lower() for term in [
950 'train_test_split', 'validation', 'train_data', 'test_data',
951 'fit_transform', 'transform', 'partial_fit', 'single_transform'
952 ]):
953 has_safe_handling = True
954 break
955
956 # Check for parameter names indicating proper handling
957 for arg in func.args.args:
958 if any(term in arg.name.lower() for term in [
959 'train', 'test', 'valid', 'single', 'row'
960 ]):
961 has_safe_handling = True
962 break
963
964 # Only warn if we have substantial data operations
965 data_operations = sum(1 for call in func.nodes_of_class(nodes.Call)
966 if any(term in call.func.as_string()
967 for term in ['fit', 'transform', 'process']))
968
969 if not has_safe_handling and data_operations >= 2:
970 self.add_smell(
971 "Feature engineering function detected without clear train/test separation. "
972 "Ensure it's not applied to the entire dataset to avoid data leakage.",
973 func, file_path
974 )
975
976 def check_error_handling(self, node: nodes.Module, file_path: str):
977 """Check for proper error handling in critical ML operations.
978
979 Ensures try-except blocks or validation checks are present for data
980 operations and model-related tasks.
981
982 Args:
983 node: AST node representing the module
984 file_path: Path to the file being analyzed
985 """
986 # Skip if this is likely not a main code file
987 if any(pattern in file_path.lower() for pattern in [
988 'test_', 'conftest', 'setup', '__init__',
989 'utils', 'helper', 'config', 'constants',
990 'visualization', 'plotting', 'display'
991 ]):
992 return
993
994 has_data_operations = False
995 has_error_handling = False
996 critical_operations = [
997 'read_csv', 'load_data', 'open', 'fit', 'predict',
998 'transform', 'save', 'dump', 'to_pickle', 'load_model',
999 'read_excel', 'read_json', 'read_sql'
1000 ]
1001
1002 has_ml_imports = False
1003 # Check for relevant imports
1004 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
1005 if isinstance(import_node, nodes.ImportFrom):
1006 if any(pkg in import_node.modname for pkg in [
1007 'pandas', 'sklearn', 'tensorflow', 'torch',
1008 'pickle', 'joblib'
1009 ]):
1010 has_ml_imports = True
1011 break
1012
1013 if not has_ml_imports:
1014 return
1015
1016 # Check for critical operations
1017 critical_op_count = 0
1018 for call in node.nodes_of_class(nodes.Call):
1019 if any(op in call.func.as_string() for op in critical_operations):
1020 has_data_operations = True
1021 critical_op_count += 1
1022
1023 if has_data_operations:
1024 # Check for different types of error handling
1025 for block in node.nodes_of_class((nodes.Try, nodes.ExceptHandler)):
1026 has_error_handling = True
1027 break
1028
1029 # Check for validation checks
1030 for if_block in node.nodes_of_class(nodes.If):
1031 if any(check in if_block.as_string().lower() for check in [
1032 'isinstance', 'isfile', 'exists', 'shape', 'empty', 'null',
1033 'none', 'isnull', 'isna', 'type', 'hasattr', 'in',
1034 'validate', 'check', 'verify'
1035 ]):
1036 has_error_handling = True
1037 break
1038
1039 # Check for assertion statements
1040 for assert_node in node.nodes_of_class(nodes.Assert):
1041 has_error_handling = True
1042 break
1043
1044 if not has_error_handling and critical_op_count >= 2:
1045 self.add_smell(
1046 "No error handling detected in data processing. "
1047 "Consider adding try-except blocks or validation checks for robustness.",
1048 node, file_path
1049 )
1050
1051 def check_hardcoded_filepaths(self, node: nodes.Module, file_path: str):
1052 """Detect hardcoded file paths in the code.
1053
1054 Identifies hardcoded paths that should be moved to configuration files
1055 or environment variables, excluding common development paths.
1056
1057 Args:
1058 node: AST node representing the module
1059 file_path: Path to the file being analyzed
1060 """
1061 # Common acceptable patterns
1062 acceptable_patterns = [
1063 './test/', './tests/',
1064 '../test/', '../tests/',
1065 'fixtures/', 'data/test/',
1066 '__pycache__', '.git/',
1067 'venv/', 'env/'
1068 ]
1069
1070 config_vars = set()
1071 # First collect any config/environment variables
1072 for assign in node.nodes_of_class(nodes.Assign):
1073 if isinstance(assign.targets[0], nodes.Name):
1074 if any(config_term in assign.targets[0].name.lower() for config_term in
1075 ['path', 'dir', 'folder', 'file', 'config']):
1076 config_vars.add(assign.targets[0].name)
1077
1078 for string in node.nodes_of_class(nodes.Const):
1079 if isinstance(string.value, str) and ('/' in string.value or '\\' in string.value):
1080 # Skip if it's a test/common development path
1081 if any(pattern in string.value for pattern in acceptable_patterns):
1082 continue
1083
1084 # Skip if it's used in a config/path variable assignment
1085 if any(config_var in string.scope().locals for config_var in config_vars):
1086 continue
1087
1088 self.add_smell(
1089 f"Hardcoded file path detected: {string.value}. "
1090 "Consider using configuration files or environment variables.",
1091 string, file_path
1092 )
1093
1094 def check_documentation(self, node: nodes.Module, file_path: str):
1095 """Check for proper documentation in ML-related code.
1096
1097 Ensures functions and classes have appropriate docstrings, especially
1098 for those with parameters or return values.
1099
1100 Args:
1101 node: AST node representing the module
1102 file_path: Path to the file being analyzed
1103 """
1104 # Skip certain types of files/functions
1105 skip_patterns = [
1106 'test_', 'fixture', 'conftest',
1107 'setup', 'init', 'main',
1108 'helper', 'util'
1109 ]
1110
1111 for func in node.nodes_of_class(nodes.FunctionDef):
1112 # Skip if it's a simple function (few lines)
1113 if len(list(func.get_children())) <= 3:
1114 continue
1115
1116 # Skip if it's a test or utility function
1117 if any(pattern in func.name.lower() for pattern in skip_patterns):
1118 continue
1119
1120 if not isinstance(func.doc_node, nodes.Const):
1121 # Only warn for functions with parameters or return values
1122 if func.args.args or 'return' in func.as_string():
1123 self.add_smell(
1124 f"Missing docstring for function: {func.name}. "
1125 "Consider adding documentation for parameters and return values.",
1126 func, file_path
1127 )
1128
1129 for cls in node.nodes_of_class(nodes.ClassDef):
1130 # Skip test classes
1131 if any(pattern in cls.name.lower() for pattern in skip_patterns):
1132 continue
1133
1134 if not isinstance(cls.doc_node, nodes.Const):
1135 # Check if class has public methods
1136 has_public_methods = any(
1137 not method.name.startswith('_')
1138 for method in cls.mymethods()
1139 )
1140 if has_public_methods:
1141 self.add_smell(
1142 f"Missing docstring for class: {cls.name}. "
1143 "Consider adding class-level documentation.",
1144 cls, file_path
1145 )
1146
1147 def generate_report(self) -> str:
1148 """Generate a human-readable report of all detected smells.
1149
1150 Returns:
1151 Formatted string containing the analysis report
1152 """
1153 report = "General ML Code Smell Report\n============================\n\n"
1154 smell_counts = {}
1155 for i, smell in enumerate(self.smells, 1):
1156 if smell['smell'] not in smell_counts:
1157 smell_counts[smell['smell']] = 0
1158 smell_counts[smell['smell']] += 1
1159
1160 report += f"{i}. Smell: {smell['smell']}\n"
1161 report += f" File: {smell['file_path']}\n"
1162
1163 # Only show line number if it's not 0
1164 if smell['line_number'] != 0:
1165 report += f" Line: {smell['line_number']}\n"
1166
1167 # Only include code snippet if it's 3 lines or fewer
1168 code_lines = smell['code_snippet'].strip().split('\n')
1169 if len(code_lines) <= 3:
1170 report += f" Code Snippet:\n{smell['code_snippet']}\n"
1171 report += "\n"
1172
1173 report += "Smell Counts:\n"
1174 for smell, count in smell_counts.items():
1175 report += f" {smell}: {count}\n"
1176 report += f"\nTotal smells detected: {len(self.smells)}"
1177 return report
1178
1179 def get_results(self) -> List[Dict[str, str]]:
1180 """Get the analysis results in a structured format.
1181
1182 Returns:
1183 List of dictionaries containing smell details and locations
1184 """
1185 return [
1186 {
1187 'framework': 'General ML',
1188 'name': smell['smell'],
1189 'fix': "Not specified",
1190 'benefits': "Not specified",
1191 'location': f"Line {smell['line_number']}" if smell['line_number'] != 0 else ""
1192 }
1193 for smell in self.smells
1194 ]