1import os
2import sys
3from typing import Any, Dict, List
4
5import astroid
6from astroid import nodes
7
8
9class FrameworkSpecificSmellDetector:
10 """A detector for identifying code smells specific to ML frameworks like
11 Pandas, NumPy, Scikit-learn, PyTorch and TensorFlow.
12
13 This class analyzes Python code to detect common anti-patterns and suboptimal implementations
14 when using popular machine learning frameworks. It provides suggestions for improvements
15 and best practices.
16 """
17
18 def __init__(self):
19 self.smells: List[Dict[str, Any]] = []
20 self.framework_smells = self.get_smells()
21
22 def detect_smells(self, file_path: str) -> List[Dict[str, str]]:
23 """Analyze a Python file for framework-specific code smells.
24
25 Args:
26 file_path: Path to the Python file to analyze
27
28 Returns:
29 List of dictionaries containing detected code smells with details like:
30 - framework: The ML framework the smell relates to
31 - name: Name of the code smell
32 - how_to_fix: Instructions for fixing the issue
33 - benefits: Benefits of fixing the issue
34 - line_number: Line where the smell was detected
35 - code_snippet: The problematic code
36 - file_path: Path to the file containing the smell
37 """
38 try:
39 with open(file_path, 'r') as file:
40 content = file.read()
41 module_name = os.path.splitext(os.path.basename(file_path))[0]
42 module = astroid.parse(content, module_name=module_name)
43
44 if not module:
45 print(f"Error: Could not parse module for {file_path}", file=sys.stderr)
46 return self.smells
47
48 frameworks_used = self.get_frameworks_used(module)
49 if frameworks_used is None:
50 print(f"Error: Could not determine frameworks for {file_path}", file=sys.stderr)
51 return self.smells
52
53 if frameworks_used:
54 self.visit_module(module, file_path, frameworks_used)
55 else:
56 print(
57 f"Skipping framework-specific smell detection for {file_path}: No relevant frameworks imported",
58 file=sys.stderr)
59 except astroid.exceptions.AstroidSyntaxError as e:
60 print(f"Error parsing {file_path}: {str(e)}", file=sys.stderr)
61 except Exception as e:
62 print(f"Unexpected error while processing {file_path}: {str(e)}", file=sys.stderr)
63 return self.smells
64
65 def get_frameworks_used(self, node: nodes.Module) -> List[str]:
66 """Detect which ML frameworks are imported in the code.
67
68 Args:
69 node: AST node representing the Python module
70
71 Returns:
72 List of framework names found in imports (e.g. ['Pandas', 'NumPy'])
73 """
74 if not node:
75 return []
76
77 frameworks = []
78 framework_imports = {
79 'pandas': 'Pandas',
80 'numpy': 'NumPy',
81 'sklearn': 'ScikitLearn',
82 'tensorflow': 'TensorFlow',
83 'torch': 'PyTorch'
84 }
85
86 try:
87 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
88 if isinstance(import_node, nodes.Import):
89 for name, _ in import_node.names:
90 root = name.split('.')[0]
91 if root in framework_imports and framework_imports[root] not in frameworks:
92 frameworks.append(framework_imports[root])
93 elif isinstance(import_node, nodes.ImportFrom):
94 root = import_node.modname.split('.')[0]
95 if root in framework_imports and framework_imports[root] not in frameworks:
96 frameworks.append(framework_imports[root])
97 except Exception as e:
98 print(f"Error in get_frameworks_used: {str(e)}", file=sys.stderr)
99 return []
100
101 return frameworks
102
103 def visit_module(self, node: nodes.Module, file_path: str, frameworks_used: List[str]):
104 """Visit a Python module to detect framework-specific smells.
105
106 Args:
107 node: AST node representing the Python module
108 file_path: Path to the file being analyzed
109 frameworks_used: List of frameworks detected in the code
110 """
111 if not node or not frameworks_used:
112 return
113
114 try:
115 self.check_imports(node, file_path)
116 for framework in frameworks_used:
117 if framework:
118 detector_method = getattr(self, f"detect_{framework.lower()}_smells", None)
119 if detector_method:
120 detector_method(node, file_path)
121 except Exception as e:
122 print(f"Error in visit_module for {file_path}: {str(e)}", file=sys.stderr)
123
124 def add_smell(self, framework: str, smell_name: str, node: nodes.NodeNG, file_path: str):
125 """Add a detected code smell to the results.
126
127 Args:
128 framework: Name of the framework (e.g. 'Pandas', 'NumPy')
129 smell_name: Name of the detected smell
130 node: AST node where the smell was found
131 file_path: Path to the file containing the smell
132 """
133 smell = next((s for s in self.framework_smells[framework] if s['name'] == smell_name), None)
134 if smell:
135 self.smells.append({
136 "framework": framework,
137 "name": smell['name'],
138 "how_to_fix": smell['how_to_fix'],
139 "benefits": smell['benefits'],
140 "strategies": smell['strategies'],
141 "line_number": node.lineno,
142 "code_snippet": node.as_string(),
143 "file_path": file_path
144 })
145
146 # Pandas Detection Methods
147 def detect_pandas_smells(self, node: nodes.Module, file_path: str):
148 """Detect Pandas-specific code smells like:
149 - Unnecessary iteration instead of vectorization
150 - Chain indexing
151 - Missing merge parameters
152 - Inplace operations
153 - Inefficient values usage
154 - Missing dtype specifications
155 - Suboptimal column selection
156 - DataFrame modifications in loops
157
158 Args:
159 node: AST node representing the Python module
160 file_path: Path to the file being analyzed
161 """
162 self.detect_iterrows_usage(node, file_path)
163 self.detect_chain_indexing(node, file_path)
164 self.detect_merge_parameters(node, file_path)
165 self.detect_inplace_operations(node, file_path)
166 self.detect_values_usage(node, file_path)
167 self.detect_dtype_specification(node, file_path)
168 self.detect_column_selection(node, file_path)
169 self.detect_dataframe_modification(node, file_path)
170
171 def detect_iterrows_usage(self, node: nodes.Module, file_path: str):
172 # Only flag iterrows if better alternatives could be used
173 vectorizable_operations = ['sum', 'mean', 'max', 'min', 'apply', 'map']
174 for call in node.nodes_of_class(nodes.Call):
175 if 'iterrows' in call.func.as_string():
176 # Check if the iterrows is used within a loop
177 parent = call.parent
178 while parent and not isinstance(parent, nodes.For):
179 parent = parent.parent
180 if parent and any(op in parent.as_string() for op in vectorizable_operations):
181 self.add_smell('Pandas', 'Unnecessary Iteration', call, file_path)
182
183 def detect_chain_indexing(self, node: nodes.Module, file_path: str):
184 for subscript in node.nodes_of_class(nodes.Subscript):
185 if isinstance(subscript.value, nodes.Subscript):
186 # Only flag if it's not within a with pd.option_context block
187 parent = subscript.parent
188 while parent and not isinstance(parent, nodes.With):
189 parent = parent.parent
190 if not parent or 'option_context' not in parent.as_string():
191 # Check if it's an assignment (more dangerous) vs just access
192 if isinstance(subscript.parent, nodes.Assign):
193 self.add_smell('Pandas', 'Chain Indexing', subscript, file_path)
194
195 def detect_merge_parameters(self, node: nodes.Module, file_path: str):
196 for call in node.nodes_of_class(nodes.Call):
197 if ('merge' in call.func.as_string() and
198 isinstance(call.func, nodes.Attribute) and
199 # Check if it's actually a pandas merge
200 ('pd' in call.func.expr.as_string() or 'pandas' in call.func.expr.as_string()) and
201 not any(kw.arg in ['how', 'on', 'validate'] for kw in call.keywords)):
202 self.add_smell('Pandas', 'Merge Parameter Checker', call, file_path)
203
204 def detect_inplace_operations(self, node: nodes.Module, file_path: str):
205 inplace_operations = ['sort_values', 'fillna', 'drop', 'replace', 'rename']
206 for call in node.nodes_of_class(nodes.Call):
207 if ('inplace=True' in call.as_string() and
208 any(op in call.func.as_string() for op in inplace_operations)):
209 # Check if the result is used later
210 parent = call.parent
211 while parent and not isinstance(parent, nodes.Module):
212 if isinstance(parent, nodes.Assign):
213 break
214 parent = parent.parent
215 if not isinstance(parent, nodes.Assign):
216 self.add_smell('Pandas', 'InPlace Checker', call, file_path)
217
218 def detect_values_usage(self, node: nodes.Module, file_path: str):
219 for call in node.nodes_of_class(nodes.Call):
220 if '.values' in call.as_string():
221 # Check if it's used in a context where to_numpy() would be better
222 numpy_contexts = ['np.', 'array', 'asarray', 'reshape', 'transpose']
223 if any(ctx in call.as_string() for ctx in numpy_contexts):
224 self.add_smell('Pandas', 'DataFrame Conversion Checker', call, file_path)
225
226 def detect_dtype_specification(self, node: nodes.Module, file_path: str):
227 for call in node.nodes_of_class(nodes.Call):
228 if ('read_csv' in call.func.as_string() and
229 # Check if it's actually pandas read_csv
230 ('pd' in call.func.as_string() or 'pandas' in call.func.as_string()) and
231 not any(kw.arg in ['dtype', 'parse_dates'] for kw in call.keywords)):
232 # Check if the DataFrame is used in operations sensitive to dtypes
233 parent = call.parent
234 dtype_sensitive_ops = ['groupby', 'merge', 'join', 'sort_values', 'arithmetic_operations']
235 while parent and not isinstance(parent, nodes.Module):
236 if any(op in parent.as_string() for op in dtype_sensitive_ops):
237 self.add_smell('Pandas', 'Datatype Checker', call, file_path)
238 break
239 parent = parent.parent
240
241 def detect_column_selection(self, node: nodes.Module, file_path: str):
242 # Only check if there are actual DataFrame operations
243 df_operations = False
244 column_selections = False
245
246 for subscript in node.nodes_of_class(nodes.Subscript):
247 if isinstance(subscript.value, nodes.Name):
248 # Look for DataFrame operations
249 if any(op in node.as_string() for op in ['DataFrame', 'pd.', 'pandas']):
250 df_operations = True
251 if '[[' in subscript.as_string():
252 column_selections = True
253 break
254
255 if df_operations and not column_selections:
256 self.add_smell('Pandas', 'Column Selection Checker', node, file_path)
257
258 def detect_dataframe_modification(self, node: nodes.Module, file_path: str):
259 for assign in node.nodes_of_class(nodes.Assign):
260 if (isinstance(assign.targets[0], nodes.Subscript) and
261 isinstance(assign.targets[0].value, nodes.Name)):
262 # Check if it's within a loop and modifying a DataFrame
263 parent = assign.parent
264 while parent and not isinstance(parent, nodes.For):
265 parent = parent.parent
266
267 # Check if it's actually a DataFrame modification
268 df_indicators = ['DataFrame', 'pd.', 'pandas']
269 if (parent and
270 any(indicator in assign.as_string() for indicator in df_indicators) and
271 not any(safe_op in assign.as_string() for safe_op in ['loc', 'iloc', 'at', 'iat'])):
272 self.add_smell('Pandas', 'DataFrame Iteration Modification', assign, file_path)
273
274 # NumPy Detection Methods
275 def detect_numpy_smells(self, node: nodes.Module, file_path: str):
276 """Detect NumPy-specific code smells like:
277 - NaN equality comparisons
278 - Missing random seeds
279 - Inefficient array creation
280 - Non-vectorized operations
281 - Dtype inconsistencies
282 - Broadcasting issues
283 - Copy/view confusion
284 - Missing axis specifications
285
286 Args:
287 node: AST node representing the Python module
288 file_path: Path to the file being analyzed
289 """
290 self.detect_nan_equality(node, file_path)
291 self.detect_random_seed(node, file_path)
292 self.detect_array_creation(node, file_path)
293 self.detect_inefficient_operations(node, file_path)
294 self.detect_dtype_consistency(node, file_path)
295 self.detect_broadcasting_issues(node, file_path)
296 self.detect_copy_view_issues(node, file_path)
297 self.detect_axis_specification(node, file_path)
298
299 def detect_nan_equality(self, node: nodes.Module, file_path: str):
300 for compare in node.nodes_of_class(nodes.Compare):
301 if any('np.nan' in getattr(op[1], 'as_string', lambda: '')() for op in compare.ops):
302 self.add_smell('NumPy', 'NaN Equality Checker', compare, file_path)
303
304 def detect_random_seed(self, node: nodes.Module, file_path: str):
305 # List of numpy random operations to check for
306 random_operations = [
307 'np.random.rand',
308 'np.random.randn',
309 'np.random.randint',
310 'np.random.choice',
311 'np.random.shuffle',
312 'np.random.permutation',
313 'np.random.normal',
314 'np.random.uniform'
315 ]
316
317 # Find all random operation calls
318 random_op_calls = [
319 call for call in node.nodes_of_class(nodes.Call)
320 if any(op in getattr(call.func, 'as_string', lambda: '')() for op in random_operations)
321 ]
322
323 # Find all random seed calls
324 random_seed_calls = [
325 call for call in node.nodes_of_class(nodes.Call)
326 if 'np.random.seed' in getattr(call.func, 'as_string', lambda: '')()
327 ]
328
329 # Only trigger smell if random operations are used without seed
330 if random_op_calls and not random_seed_calls:
331 self.add_smell('NumPy', 'Randomness Control Checker', node, file_path)
332
333 def detect_array_creation(self, node: nodes.Module, file_path: str):
334 """Detect inefficient array creation patterns"""
335 for call in node.nodes_of_class(nodes.Call):
336 # Check for list to array conversion without dtype
337 if ('np.array' in call.func.as_string() and
338 not any(kw.arg == 'dtype' for kw in call.keywords)):
339 self.add_smell('NumPy', 'Array Creation Efficiency', call, file_path)
340
341 # Check for zeros/ones/empty without dtype
342 if any(func in call.func.as_string() for func in ['np.zeros', 'np.ones', 'np.empty']) and \
343 not any(kw.arg == 'dtype' for kw in call.keywords):
344 self.add_smell('NumPy', 'Array Creation Efficiency', call, file_path)
345
346 def detect_inefficient_operations(self, node: nodes.Module, file_path: str):
347 """Detect inefficient numerical operations"""
348 for loop in node.nodes_of_class(nodes.For):
349 # Check for element-wise operations in loops
350 if any(op in loop.as_string() for op in ['np.sum', 'np.mean', 'np.max', 'np.min']):
351 self.add_smell('NumPy', 'Inefficient Operations', loop, file_path)
352
353 # Check for inefficient concatenation
354 for call in node.nodes_of_class(nodes.Call):
355 if 'np.concatenate' in call.func.as_string():
356 parent = call.parent
357 while parent and not isinstance(parent, nodes.For):
358 parent = parent.parent
359 if parent: # If concatenate is inside a loop
360 self.add_smell('NumPy', 'Inefficient Operations', call, file_path)
361
362 def detect_dtype_consistency(self, node: nodes.Module, file_path: str):
363 """Detect potential dtype inconsistency issues"""
364 for binop in node.nodes_of_class(nodes.BinOp):
365 # Check for mixed integer and float operations
366 if (isinstance(binop.left, nodes.Call) and isinstance(binop.right, nodes.Call) and
367 'np.' in binop.left.func.as_string() and 'np.' in binop.right.func.as_string()):
368 if ('int' in binop.left.func.as_string() and 'float' in binop.right.func.as_string()) or \
369 ('float' in binop.left.func.as_string() and 'int' in binop.right.func.as_string()):
370 self.add_smell('NumPy', 'Dtype Consistency', binop, file_path)
371
372 def detect_broadcasting_issues(self, node: nodes.Module, file_path: str):
373 """Detect potential broadcasting issues"""
374 for binop in node.nodes_of_class(nodes.BinOp):
375 if isinstance(binop.left, nodes.Call) and isinstance(binop.right, nodes.Call):
376 # Check for operations between arrays that might have broadcasting issues
377 if ('reshape' in binop.left.as_string() or 'reshape' in binop.right.as_string() or
378 'transpose' in binop.left.as_string() or 'transpose' in binop.right.as_string()):
379 self.add_smell('NumPy', 'Broadcasting Risk', binop, file_path)
380
381 def detect_copy_view_issues(self, node: nodes.Module, file_path: str):
382 """Detect potential copy/view confusion in NumPy array operations"""
383 for assign in node.nodes_of_class(nodes.Assign):
384 # Only check assignments involving array slicing
385 if isinstance(assign.value, nodes.Subscript):
386 # Skip if it's not a NumPy array operation
387 if not any(np_indicator in assign.as_string()
388 for np_indicator in ['np.', 'numpy.']):
389 continue
390
391 # Check if the slice is being modified later
392 is_modified = False
393 parent = assign.parent
394 while parent and not isinstance(parent, nodes.Module):
395 if isinstance(parent, (nodes.Assign, nodes.AugAssign)):
396 # Look for modifications to the assigned variable
397 target_name = assign.targets[0].as_string()
398 if target_name in parent.as_string():
399 is_modified = True
400 break
401 parent = parent.parent
402
403 # Only flag if:
404 # 1. The slice is modified later
405 # 2. No explicit copy is made
406 # 3. Not using advanced indexing (which creates copies)
407 if (is_modified and
408 not any(method in assign.as_string() for method in ['.copy()', 'np.copy']) and
409 not any(idx in assign.value.as_string() for idx in ['[[', 'bool', 'mask'])):
410 self.add_smell('NumPy', 'Copy-View Confusion', assign, file_path)
411
412 def detect_axis_specification(self, node: nodes.Module, file_path: str):
413 """Detect missing axis specifications in array operations"""
414 axis_operations = ['sum', 'mean', 'max', 'min', 'argmax', 'argmin', 'any', 'all']
415 for call in node.nodes_of_class(nodes.Call):
416 if (any(op in call.func.as_string() for op in axis_operations) and
417 'np.' in call.func.as_string() and
418 not any(kw.arg == 'axis' for kw in call.keywords) and
419 len(call.args) < 2): # axis can also be specified as positional argument
420 self.add_smell('NumPy', 'Missing Axis Specification', call, file_path)
421
422 # Scikit-learn Detection Methods
423 def detect_scikitlearn_smells(self, node: nodes.Module, file_path: str):
424 """Detect Scikit-learn specific code smells like:
425 - Missing data scaling
426 - Not using pipelines
427 - Missing cross validation
428 - Missing random state
429 - Missing verbose mode
430 - Using only threshold-dependent metrics
431 - Missing unit tests
432 - Data leakage risks
433 - Missing exception handling
434
435 Args:
436 node: AST node representing the Python module
437 file_path: Path to the file being analyzed
438 """
439 self.detect_scaling_usage(node, file_path)
440 self.detect_pipeline_usage(node, file_path)
441 self.detect_cross_validation(node, file_path)
442 self.detect_random_state(node, file_path)
443 self.detect_verbose_mode(node, file_path)
444 self.detect_threshold_metrics(node, file_path)
445 self.detect_unit_tests(node, file_path)
446 self.detect_data_leakage(node, file_path)
447 self.detect_exception_handling(node, file_path)
448
449 def detect_scaling_usage(self, node: nodes.Module, file_path: str):
450 scaling_sensitive_estimators = [
451 'SVM', 'SVR', 'PCA', 'KMeans', 'NeuralNetwork', 'LogisticRegression'
452 ]
453 scaling_methods = ['StandardScaler', 'MinMaxScaler', 'RobustScaler']
454
455 # Check if scaling-sensitive estimators are used
456 has_sensitive_estimator = any(
457 estimator in call.func.as_string()
458 for call in node.nodes_of_class(nodes.Call)
459 for estimator in scaling_sensitive_estimators
460 )
461
462 # Check if any scaling is applied
463 has_scaling = any(
464 method in call.func.as_string()
465 for call in node.nodes_of_class(nodes.Call)
466 for method in scaling_methods
467 )
468
469 # Only report if using sensitive estimators without scaling
470 if has_sensitive_estimator and not has_scaling:
471 self.add_smell('ScikitLearn', 'Scaler Missing Checker', node, file_path)
472
473 def detect_pipeline_usage(self, node: nodes.Module, file_path: str):
474 # Check if there are multiple preprocessing or model fitting steps
475 preprocessing_steps = [
476 'StandardScaler', 'MinMaxScaler', 'PCA', 'SelectKBest',
477 'PolynomialFeatures', 'OneHotEncoder', 'LabelEncoder'
478 ]
479 model_steps = [
480 'fit', 'predict', 'transform', 'fit_transform'
481 ]
482
483 has_preprocessing = any(
484 step in call.func.as_string()
485 for call in node.nodes_of_class(nodes.Call)
486 for step in preprocessing_steps
487 )
488 has_model_steps = any(
489 step in call.func.as_string()
490 for call in node.nodes_of_class(nodes.Call)
491 for step in model_steps
492 )
493
494 # Only suggest Pipeline if multiple steps are present
495 if has_preprocessing and has_model_steps and not any(
496 'Pipeline' in call.func.as_string()
497 for call in node.nodes_of_class(nodes.Call)
498 ):
499 self.add_smell('ScikitLearn', 'Pipeline Checker', node, file_path)
500
501 def detect_cross_validation(self, node: nodes.Module, file_path: str):
502 # Check if model training is performed
503 training_indicators = ['fit', 'train']
504 has_training = any(
505 indicator in call.func.as_string()
506 for call in node.nodes_of_class(nodes.Call)
507 for indicator in training_indicators
508 )
509
510 # Check for any cross-validation technique
511 cv_methods = [
512 'cross_val_score', 'KFold', 'StratifiedKFold',
513 'cross_validate', 'GridSearchCV', 'RandomizedSearchCV'
514 ]
515 has_cv = any(
516 method in call.func.as_string()
517 for call in node.nodes_of_class(nodes.Call)
518 for method in cv_methods
519 )
520
521 # Only suggest cross-validation for model training scenarios
522 if has_training and not has_cv:
523 self.add_smell('ScikitLearn', 'Cross Validation Checker', node, file_path)
524
525 def detect_random_state(self, node: nodes.Module, file_path: str):
526 # List of methods that accept random_state
527 random_state_methods = [
528 'train_test_split', 'KFold', 'RandomForest', 'KMeans',
529 'PCA', 'shuffle', 'random_state'
530 ]
531
532 # Check if any random-dependent operations are used
533 random_dependent_calls = [
534 call for call in node.nodes_of_class(nodes.Call)
535 if any(method in call.func.as_string() for method in random_state_methods)
536 ]
537
538 # Check if random_state is set for these calls
539 for call in random_dependent_calls:
540 if not any(
541 kw.arg == 'random_state' for kw in call.keywords
542 ):
543 self.add_smell('ScikitLearn', 'Randomness Control Checker', call, file_path)
544
545 def detect_verbose_mode(self, node: nodes.Module, file_path: str):
546 # Only check for verbose in time-consuming operations
547 time_consuming_ops = [
548 'GridSearchCV', 'RandomizedSearchCV', 'fit',
549 'RandomForest', 'GradientBoosting'
550 ]
551
552 for call in node.nodes_of_class(nodes.Call):
553 if (any(op in call.func.as_string() for op in time_consuming_ops) and
554 not any(kw.arg == 'verbose' for kw in call.keywords)):
555 self.add_smell('ScikitLearn', 'Verbose Mode Checker', call, file_path)
556
557 def detect_threshold_metrics(self, node: nodes.Module, file_path: str):
558 # Only check for classification-related tasks
559 classification_indicators = [
560 'classifier', 'predict_proba', 'accuracy_score',
561 'precision_score', 'recall_score', 'f1_score',
562 'classification_report'
563 ]
564
565 # Check if it's a classification task
566 is_classification = any(
567 indicator in node.as_string()
568 for indicator in classification_indicators
569 )
570
571 # Check for threshold-independent metrics
572 threshold_independent_metrics = [
573 'roc_auc_score', 'average_precision_score',
574 'precision_recall_curve', 'roc_curve'
575 ]
576 has_threshold_metrics = any(
577 metric in call.func.as_string()
578 for call in node.nodes_of_class(nodes.Call)
579 for metric in threshold_independent_metrics
580 )
581
582 # Only suggest threshold-independent metrics for classification tasks
583 if is_classification and not has_threshold_metrics:
584 self.add_smell('ScikitLearn', 'Dependent Threshold Checker', node, file_path)
585
586 def detect_unit_tests(self, node: nodes.Module, file_path: str):
587 # Check if this is a source file (not already a test file)
588 is_test_file = ('test' in file_path.lower() or
589 node.as_string().lower().startswith('test'))
590
591 # Check for model training or evaluation code
592 ml_operations = [
593 'fit', 'predict', 'transform', 'score',
594 'cross_val_score', 'GridSearchCV'
595 ]
596 has_ml_operations = any(
597 op in call.func.as_string()
598 for call in node.nodes_of_class(nodes.Call)
599 for op in ml_operations
600 )
601
602 # Check for testing frameworks
603 test_frameworks = [
604 'unittest', 'pytest', 'nose',
605 'TestCase', '@test', '@pytest'
606 ]
607 has_tests = any(
608 framework in node.as_string()
609 for framework in test_frameworks
610 )
611
612 # Only suggest adding tests for non-test ML files without tests
613 if not is_test_file and has_ml_operations and not has_tests:
614 self.add_smell('ScikitLearn', 'Unit Testing Checker', node, file_path)
615
616 def detect_data_leakage(self, node: nodes.Module, file_path: str):
617 # Check for data preprocessing operations
618 preprocessing_ops = [
619 'fit_transform', 'transform', 'StandardScaler',
620 'MinMaxScaler', 'PCA', 'feature_selection'
621 ]
622 has_preprocessing = any(
623 op in call.func.as_string()
624 for call in node.nodes_of_class(nodes.Call)
625 for op in preprocessing_ops
626 )
627
628 # Check for model training
629 has_model_training = any(
630 'fit' in call.func.as_string()
631 for call in node.nodes_of_class(nodes.Call)
632 )
633
634 # Check for proper train-test splitting
635 splitting_methods = [
636 'train_test_split', 'KFold', 'StratifiedKFold',
637 'GroupKFold', 'TimeSeriesSplit'
638 ]
639 has_proper_split = any(
640 method in call.func.as_string()
641 for call in node.nodes_of_class(nodes.Call)
642 for method in splitting_methods
643 )
644
645 # Only report if preprocessing and training without proper splitting
646 if (has_preprocessing and has_model_training and not has_proper_split):
647 self.add_smell('ScikitLearn', 'Data Leakage Checker', node, file_path)
648
649 def detect_exception_handling(self, node: nodes.Module, file_path: str):
650 # Check for risky operations that should have exception handling
651 risky_operations = [
652 'fit', 'predict', 'transform', 'inverse_transform',
653 'cross_val_score', 'GridSearchCV', 'load_', 'dump_',
654 'pickle', 'joblib'
655 ]
656
657 # Find all risky operation calls
658 risky_calls = [
659 call for call in node.nodes_of_class(nodes.Call)
660 if any(op in call.func.as_string() for op in risky_operations)
661 ]
662
663 # Check if these calls are within try-except blocks
664 for call in risky_calls:
665 parent = call.parent
666 within_try_block = False
667
668 while parent and not isinstance(parent, nodes.Module):
669 if isinstance(parent, nodes.Try):
670 within_try_block = True
671 break
672 parent = parent.parent
673
674 # Only report if risky operations are not in try-except blocks
675 if risky_calls and not within_try_block:
676 self.add_smell('ScikitLearn', 'Exception Handling Checker', call, file_path)
677
678 def check_imports(self, node: nodes.Module, file_path: str):
679 for import_node in node.nodes_of_class((nodes.Import, nodes.ImportFrom)):
680 if isinstance(import_node, nodes.Import):
681 for name, alias in import_node.names:
682 if name == 'numpy' and alias != 'np':
683 self.add_smell('General', 'Import Checker', import_node, file_path)
684 elif name == 'pandas' and alias != 'pd':
685 self.add_smell('General', 'Import Checker', import_node, file_path)
686 elif isinstance(import_node, nodes.ImportFrom):
687 if import_node.modname == 'numpy' and any(alias != 'np' for _, alias in import_node.names):
688 self.add_smell('General', 'Import Checker', import_node, file_path)
689 elif import_node.modname == 'pandas' and any(alias != 'pd' for _, alias in import_node.names):
690 self.add_smell('General', 'Import Checker', import_node, file_path)
691
692 # PyTorch Detection Methods
693 def detect_pytorch_smells(self, node: nodes.Module, file_path: str):
694 """Detect PyTorch-specific code smells like:
695 - Missing random seeds
696 - Not using deterministic mode
697 - DataLoader randomness issues
698 - Missing masks for numerical operations
699 - Direct forward() calls
700 - Missing gradient zeroing
701 - Missing batch normalization
702 - Missing dropout
703 - Missing data augmentation
704 - Missing learning rate scheduling
705 - Missing logging
706 - Missing model evaluation mode
707
708 Args:
709 node: AST node representing the Python module
710 file_path: Path to the file being analyzed
711 """
712 self.detect_pytorch_random_seed(node, file_path)
713 self.detect_pytorch_deterministic(node, file_path)
714 self.detect_pytorch_dataloader_random(node, file_path)
715 self.detect_pytorch_mask(node, file_path)
716 self.detect_pytorch_forward(node, file_path)
717 self.detect_pytorch_grad_zero(node, file_path)
718 self.detect_pytorch_batch_norm(node, file_path)
719 self.detect_pytorch_dropout(node, file_path)
720 self.detect_pytorch_augmentation(node, file_path)
721 self.detect_pytorch_lr_scheduler(node, file_path)
722 self.detect_pytorch_logging(node, file_path)
723 self.detect_pytorch_eval_mode(node, file_path)
724
725 def detect_pytorch_random_seed(self, node: nodes.Module, file_path: str):
726 # Check if there are any random operations that need seeding
727 random_operations = [
728 'torch.rand', 'torch.randn', 'torch.randint',
729 'torch.randperm', 'torch.bernoulli', 'torch.normal',
730 'torch.dropout', 'torch.nn.Dropout'
731 ]
732
733 has_random_ops = any(
734 op in call.func.as_string()
735 for call in node.nodes_of_class(nodes.Call)
736 for op in random_operations
737 )
738
739 has_manual_seed = any(
740 'torch.manual_seed' in call.func.as_string() or
741 'torch.cuda.manual_seed' in call.func.as_string() or
742 'torch.cuda.manual_seed_all' in call.func.as_string()
743 for call in node.nodes_of_class(nodes.Call)
744 )
745
746 # Only report if random operations are used without seeding
747 if has_random_ops and not has_manual_seed:
748 self.add_smell('PyTorch', 'Randomness Control Checker', node, file_path)
749
750 def detect_pytorch_deterministic(self, node: nodes.Module, file_path: str):
751 # Check for operations that benefit from deterministic algorithms
752 deterministic_sensitive_ops = [
753 'torch.nn.Conv', 'torch.nn.LSTM', 'torch.nn.GRU',
754 'torch.backends.cudnn', 'torch.cuda', 'DataLoader'
755 ]
756
757 has_sensitive_ops = any(
758 op in node.as_string()
759 for op in deterministic_sensitive_ops
760 )
761
762 has_deterministic_setting = any(
763 'torch.use_deterministic_algorithms' in call.func.as_string() or
764 'torch.backends.cudnn.deterministic' in call.as_string()
765 for call in node.nodes_of_class(nodes.Call)
766 )
767
768 # Only report if sensitive operations are used without deterministic setting
769 if has_sensitive_ops and not has_deterministic_setting:
770 self.add_smell('PyTorch', 'Deterministic Algorithm Usage Checker', node, file_path)
771
772 def detect_pytorch_dataloader_random(self, node: nodes.Module, file_path: str):
773 for call in node.nodes_of_class(nodes.Call):
774 if 'DataLoader' in call.func.as_string():
775 # Check if it's actually a PyTorch DataLoader
776 is_pytorch_dataloader = any(
777 'torch' in imp.modname
778 for imp in node.nodes_of_class(nodes.ImportFrom)
779 )
780
781 # Check if shuffling is enabled
782 has_shuffle = any(
783 (kw.arg == 'shuffle' and getattr(kw.value, 'value', False) is True)
784 for kw in call.keywords
785 )
786
787 # Check for proper random state control
788 has_random_control = any(
789 (kw.arg in ['worker_init_fn', 'generator'])
790 for kw in call.keywords
791 )
792
793 # Only report if it's a PyTorch DataLoader with shuffling but no random control
794 if is_pytorch_dataloader and has_shuffle and not has_random_control:
795 self.add_smell('PyTorch', 'Randomness Control Checker (PyTorch-Dataloader)', call, file_path)
796
797 def detect_pytorch_mask(self, node: nodes.Module, file_path: str):
798 for call in node.nodes_of_class(nodes.Call):
799 if 'torch.log' in call.func.as_string():
800 # Check if the input might contain zeros or negative values
801 input_arg = call.args[0] if call.args else None
802 if input_arg:
803 # Look for potential risky operations before the log
804 risky_ops = ['zeros', 'randn', 'rand', 'sub', 'subtract']
805 has_risky_input = any(
806 op in input_arg.as_string()
807 for op in risky_ops
808 )
809
810 # Check if masking is applied
811 has_mask = (
812 len(call.args) > 1 or
813 any('clamp' in node.as_string() or
814 'mask' in node.as_string() or
815 'where' in node.as_string())
816 )
817
818 # Only report if there's a risk of invalid input without masking
819 if has_risky_input and not has_mask:
820 self.add_smell('PyTorch', 'Mask Missing Checker', call, file_path)
821
822 def detect_pytorch_forward(self, node: nodes.Module, file_path: str):
823 for call in node.nodes_of_class(nodes.Call):
824 if 'forward' in call.func.as_string():
825 # Check if it's a direct forward call on a neural network
826 is_nn_forward = (
827 isinstance(call.func, nodes.Attribute) and
828 'forward' in call.func.attrname and
829 any(parent_cls.bases and
830 'nn.Module' in parent_cls.bases[0].as_string()
831 for parent_cls in node.nodes_of_class(nodes.ClassDef))
832 )
833
834 # Check if it's called directly instead of using __call__
835 is_direct_call = (
836 'net.forward' in call.func.as_string() or
837 'model.forward' in call.func.as_string()
838 )
839
840 # Only report if it's a direct forward call on an nn.Module
841 if is_nn_forward and is_direct_call:
842 self.add_smell('PyTorch', 'Net Forward Checker', call, file_path)
843
844 def detect_pytorch_grad_zero(self, node: nodes.Module, file_path: str):
845 # Check if there's actual training happening
846 has_training_loop = any(
847 'loss.backward' in call.func.as_string() or
848 'backward' in call.func.as_string()
849 for call in node.nodes_of_class(nodes.Call)
850 )
851
852 has_optimizer = any(
853 'optim.' in call.func.as_string()
854 for call in node.nodes_of_class(nodes.Call)
855 )
856
857 has_grad_zero = any(
858 'zero_grad' in call.func.as_string() or
859 'set_to_none' in call.func.as_string() # Alternative method
860 for call in node.nodes_of_class(nodes.Call)
861 )
862
863 # Only report if there's training without gradient zeroing
864 if has_training_loop and has_optimizer and not has_grad_zero:
865 self.add_smell('PyTorch', 'Gradient Clear Checker', node, file_path)
866
867 def detect_pytorch_batch_norm(self, node: nodes.Module, file_path: str):
868 # Check if it's a CNN or deep network that could benefit from BatchNorm
869 conv_layers = [
870 'Conv1d', 'Conv2d', 'Conv3d',
871 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d'
872 ]
873
874 has_conv_layers = any(
875 layer in call.func.as_string()
876 for call in node.nodes_of_class(nodes.Call)
877 for layer in conv_layers
878 )
879
880 has_deep_structure = len([
881 call for call in node.nodes_of_class(nodes.Call)
882 if any(layer in call.func.as_string() for layer in ['Linear', 'Conv'])
883 ]) > 2
884
885 has_batch_norm = any(
886 'BatchNorm' in call.func.as_string()
887 for call in node.nodes_of_class(nodes.Call)
888 )
889
890 # Only suggest BatchNorm for appropriate architectures
891 if (has_conv_layers or has_deep_structure) and not has_batch_norm:
892 self.add_smell('PyTorch', 'Batch Normalisation Checker', node, file_path)
893
894 def detect_pytorch_dropout(self, node: nodes.Module, file_path: str):
895 # Check if the model is complex enough to benefit from dropout
896 has_multiple_layers = len([
897 call for call in node.nodes_of_class(nodes.Call)
898 if any(layer in call.func.as_string()
899 for layer in ['Linear', 'Conv', 'LSTM', 'GRU'])
900 ]) > 2
901
902 has_training_code = any(
903 'train' in call.func.as_string() or
904 'backward' in call.func.as_string()
905 for call in node.nodes_of_class(nodes.Call)
906 )
907
908 has_dropout = any(
909 'Dropout' in call.func.as_string()
910 for call in node.nodes_of_class(nodes.Call)
911 )
912
913 # Only suggest dropout for complex models during training
914 if has_multiple_layers and has_training_code and not has_dropout:
915 self.add_smell('PyTorch', 'Dropout Usage Checker', node, file_path)
916
917 def detect_pytorch_augmentation(self, node: nodes.Module, file_path: str):
918 # Check if it's a computer vision task
919 vision_indicators = [
920 'ImageFolder', 'Dataset', 'DataLoader',
921 'Conv2d', 'Conv3d', 'ResNet', 'VGG',
922 'image', 'img', 'PIL'
923 ]
924
925 is_vision_task = any(
926 indicator in node.as_string()
927 for indicator in vision_indicators
928 )
929
930 has_training = any(
931 'train' in call.func.as_string() or
932 'fit' in call.func.as_string()
933 for call in node.nodes_of_class(nodes.Call)
934 )
935
936 has_augmentation = any(
937 'transforms' in call.func.as_string() or
938 'augment' in call.func.as_string().lower()
939 for call in node.nodes_of_class(nodes.Call)
940 )
941
942 # Only suggest augmentation for vision tasks during training
943 if is_vision_task and has_training and not has_augmentation:
944 self.add_smell('PyTorch', 'Data Augmentation Checker', node, file_path)
945
946 def detect_pytorch_lr_scheduler(self, node: nodes.Module, file_path: str):
947 # Check if there's a training loop with enough epochs
948 has_epochs = False
949 for node in node.nodes_of_class(nodes.For):
950 if 'epoch' in node.as_string().lower():
951 try:
952 # Try to determine the number of epochs
953 if hasattr(node.iter, 'args') and len(node.iter.args) > 0:
954 epoch_num = int(node.iter.args[0].value)
955 has_epochs = epoch_num > 10 # Suggest scheduler for longer training
956 except (AttributeError, ValueError):
957 continue
958
959 has_optimizer = any(
960 'optim.' in call.func.as_string()
961 for call in node.nodes_of_class(nodes.Call)
962 )
963
964 has_scheduler = any(
965 'lr_scheduler' in call.func.as_string() or
966 'LRScheduler' in call.func.as_string()
967 for call in node.nodes_of_class(nodes.Call)
968 )
969
970 # Only suggest scheduler for long training processes
971 if has_epochs and has_optimizer and not has_scheduler:
972 self.add_smell('PyTorch', 'Learning Rate Scheduler Checker', node, file_path)
973
974 def detect_pytorch_logging(self, node: nodes.Module, file_path: str):
975 # Check if there's actual training to log
976 has_training_loop = any(
977 'train' in call.func.as_string() or
978 'backward' in call.func.as_string()
979 for call in node.nodes_of_class(nodes.Call)
980 )
981
982 has_metrics = any(
983 metric in node.as_string().lower()
984 for metric in ['loss', 'accuracy', 'score', 'metric']
985 )
986
987 has_logging = any(
988 logger in node.as_string()
989 for logger in [
990 'tensorboard', 'SummaryWriter', 'wandb',
991 'MLflow', 'Neptune', 'logger'
992 ]
993 )
994
995 # Only suggest logging for training with metrics
996 if has_training_loop and has_metrics and not has_logging:
997 self.add_smell('PyTorch', 'Logging Checker', node, file_path)
998
999 def detect_pytorch_eval_mode(self, node: nodes.Module, file_path: str):
1000 # Check if there's validation/testing happening
1001 evaluation_indicators = [
1002 'val_loader', 'test_loader', 'validate',
1003 'evaluation', 'testing', 'predict'
1004 ]
1005
1006 has_evaluation = any(
1007 indicator in node.as_string().lower()
1008 for indicator in evaluation_indicators
1009 )
1010
1011 has_model_usage = any(
1012 'forward' in call.func.as_string() or
1013 'model(' in call.as_string()
1014 for call in node.nodes_of_class(nodes.Call)
1015 )
1016
1017 has_eval_mode = any(
1018 'eval' in call.func.as_string() or
1019 'train(False)' in call.func.as_string()
1020 for call in node.nodes_of_class(nodes.Call)
1021 )
1022
1023 # Only suggest eval mode when doing validation/testing
1024 if has_evaluation and has_model_usage and not has_eval_mode:
1025 self.add_smell('PyTorch', 'Model Evaluation Checker', node, file_path)
1026
1027 # TensorFlow Detection Methods
1028 def detect_tensorflow_smells(self, node: nodes.Module, file_path: str):
1029 """Detect TensorFlow-specific code smells like:
1030 - Missing random seeds
1031 - Missing early stopping
1032 - Missing checkpointing
1033 - Memory leaks
1034 - Missing masks
1035 - Python lists instead of TensorArrays
1036 - Missing threshold-independent metrics
1037 - Missing logging
1038 - Missing batch normalization
1039 - Missing dropout
1040 - Missing data augmentation
1041 - Missing learning rate scheduling
1042 - Missing model evaluation
1043
1044 Args:
1045 node: AST node representing the Python module
1046 file_path: Path to the file being analyzed
1047 """
1048 self.detect_tf_random_seed(node, file_path)
1049 self.detect_tf_early_stopping(node, file_path)
1050 self.detect_tf_checkpointing(node, file_path)
1051 self.detect_tf_memory_release(node, file_path)
1052 self.detect_tf_mask(node, file_path)
1053 self.detect_tf_tensor_array(node, file_path)
1054 self.detect_tf_metrics(node, file_path)
1055 self.detect_tf_logging(node, file_path)
1056 self.detect_tf_batch_norm(node, file_path)
1057 self.detect_tf_dropout(node, file_path)
1058 self.detect_tf_augmentation(node, file_path)
1059 self.detect_tf_lr_scheduler(node, file_path)
1060 self.detect_tf_model_evaluation(node, file_path)
1061
1062 def detect_tf_random_seed(self, node: nodes.Module, file_path: str):
1063 # Check for operations that need random seed control
1064 random_ops = [
1065 'random.normal', 'random.uniform', 'random.shuffle',
1066 'dropout', 'RandomRotation', 'RandomFlip', 'RandomZoom'
1067 ]
1068
1069 has_random_ops = any(
1070 op in call.func.as_string()
1071 for call in node.nodes_of_class(nodes.Call)
1072 for op in random_ops
1073 )
1074
1075 has_seed_set = any(
1076 'random.set_seed' in call.func.as_string() or
1077 'set_random_seed' in call.func.as_string()
1078 for call in node.nodes_of_class(nodes.Call)
1079 )
1080
1081 # Only report if random operations are used without seed
1082 if has_random_ops and not has_seed_set:
1083 self.add_smell('TensorFlow', 'Randomness Control Checker', node, file_path)
1084
1085 def detect_tf_early_stopping(self, node: nodes.Module, file_path: str):
1086 # Check if there's actual model training
1087 has_training = any(
1088 'model.fit' in call.func.as_string() or
1089 'fit(' in call.func.as_string()
1090 for call in node.nodes_of_class(nodes.Call)
1091 )
1092
1093 # Check for training loop with multiple epochs
1094 has_multiple_epochs = any(
1095 'epochs' in kw.arg and getattr(kw.value, 'value', 1) > 1
1096 for call in node.nodes_of_class(nodes.Call)
1097 for kw in call.keywords
1098 )
1099
1100 has_early_stopping = any(
1101 'EarlyStopping' in call.func.as_string()
1102 for call in node.nodes_of_class(nodes.Call)
1103 )
1104
1105 # Only suggest early stopping for actual training with multiple epochs
1106 if has_training and has_multiple_epochs and not has_early_stopping:
1107 self.add_smell('TensorFlow', 'Early Stopping Checker', node, file_path)
1108
1109 def detect_tf_checkpointing(self, node: nodes.Module, file_path: str):
1110 # Check for model training and complexity
1111 has_training = any(
1112 'model.fit' in call.func.as_string()
1113 for call in node.nodes_of_class(nodes.Call)
1114 )
1115
1116 has_complex_model = any(
1117 layer in node.as_string()
1118 for layer in ['Dense', 'Conv', 'LSTM', 'GRU']
1119 )
1120
1121 has_checkpointing = any(
1122 'ModelCheckpoint' in call.func.as_string() or
1123 'save_weights' in call.func.as_string()
1124 for call in node.nodes_of_class(nodes.Call)
1125 )
1126
1127 # Only suggest checkpointing for complex models during training
1128 if has_training and has_complex_model and not has_checkpointing:
1129 self.add_smell('TensorFlow', 'Checkpointing Checker', node, file_path)
1130
1131 def detect_tf_memory_release(self, node: nodes.Module, file_path: str):
1132 # Check for memory-intensive operations
1133 memory_intensive_ops = [
1134 'model.fit', 'predict', 'evaluate',
1135 'Conv', 'LSTM', 'GRU', 'Attention'
1136 ]
1137
1138 has_intensive_ops = any(
1139 op in call.func.as_string()
1140 for call in node.nodes_of_class(nodes.Call)
1141 for op in memory_intensive_ops
1142 )
1143
1144 has_memory_release = any(
1145 'clear_session' in call.func.as_string() or
1146 'reset_states' in call.func.as_string()
1147 for call in node.nodes_of_class(nodes.Call)
1148 )
1149
1150 # Only suggest memory release for memory-intensive operations
1151 if has_intensive_ops and not has_memory_release:
1152 self.add_smell('TensorFlow', 'Memory Release Checker', node, file_path)
1153
1154 def detect_tf_mask(self, node: nodes.Module, file_path: str):
1155 for call in node.nodes_of_class(nodes.Call):
1156 if hasattr(call.func, 'as_string') and 'tf.math.log' in call.func.as_string():
1157 # Check for potential zero/negative inputs
1158 input_arg = call.args[0] if call.args else None
1159 if input_arg and hasattr(input_arg, 'as_string'): # Add null check
1160 risky_ops = ['zeros', 'random', 'subtract', 'sub']
1161 has_risky_input = any(
1162 op in input_arg.as_string()
1163 for op in risky_ops
1164 )
1165
1166 has_mask = (
1167 len(call.args) > 1 or
1168 any(mask_op in node.as_string()
1169 for mask_op in ['where', 'clip', 'maximum'])
1170 )
1171
1172 # Only report if there's risk of invalid input without masking
1173 if has_risky_input and not has_mask:
1174 self.add_smell('TensorFlow', 'Mask Missing Checker', call, file_path)
1175
1176 def detect_tf_tensor_array(self, node: nodes.Module, file_path: str):
1177 # Check for dynamic sequence operations
1178 dynamic_ops = [
1179 'RNN', 'LSTM', 'GRU', 'while_loop',
1180 'map_fn', 'scan', 'dynamic'
1181 ]
1182
1183 has_dynamic_ops = any(
1184 op in node.as_string()
1185 for op in dynamic_ops
1186 )
1187
1188 using_python_list = any(
1189 'append' in call.func.as_string() or
1190 'extend' in call.func.as_string()
1191 for call in node.nodes_of_class(nodes.Call)
1192 )
1193
1194 has_tensor_array = any(
1195 'TensorArray' in call.func.as_string()
1196 for call in node.nodes_of_class(nodes.Call)
1197 )
1198
1199 # Only suggest TensorArray for dynamic operations using Python lists
1200 if has_dynamic_ops and using_python_list and not has_tensor_array:
1201 self.add_smell('TensorFlow', 'Tensor Array Checker', node, file_path)
1202
1203 def detect_tf_metrics(self, node: nodes.Module, file_path: str):
1204 # Check if it's a classification task
1205 classification_indicators = [
1206 'Binary', 'Categorical', 'sparse_categorical',
1207 'accuracy', 'precision', 'recall', 'f1'
1208 ]
1209
1210 is_classification = any(
1211 indicator in node.as_string()
1212 for indicator in classification_indicators
1213 )
1214
1215 has_basic_metrics = any(
1216 metric in call.func.as_string()
1217 for call in node.nodes_of_class(nodes.Call)
1218 for metric in ['accuracy', 'precision', 'recall']
1219 )
1220
1221 has_threshold_independent = any(
1222 metric in call.func.as_string()
1223 for call in node.nodes_of_class(nodes.Call)
1224 for metric in ['AUC', 'AUROCScore', 'PrecisionRecallCurve']
1225 )
1226
1227 # Only suggest threshold-independent metrics for classification tasks
1228 if is_classification and has_basic_metrics and not has_threshold_independent:
1229 self.add_smell('TensorFlow', 'Dependent Threshold Checker', node, file_path)
1230
1231 def detect_tf_logging(self, node: nodes.Module, file_path: str):
1232 # Check if there's actual training to log
1233 has_training_loop = any(
1234 'model.fit' in call.func.as_string() or
1235 'train' in call.func.as_string()
1236 for call in node.nodes_of_class(nodes.Call)
1237 )
1238
1239 has_metrics = any(
1240 metric in node.as_string().lower()
1241 for metric in ['loss', 'accuracy', 'score', 'metric']
1242 )
1243
1244 has_logging = any(
1245 logger in node.as_string()
1246 for logger in [
1247 'tf.summary', 'TensorBoard', 'CSVLogger',
1248 'WandbCallback', 'MLflow'
1249 ]
1250 )
1251
1252 # Only suggest logging for training with metrics
1253 if has_training_loop and has_metrics and not has_logging:
1254 self.add_smell('TensorFlow', 'Logging Checker', node, file_path)
1255
1256 def detect_tf_batch_norm(self, node: nodes.Module, file_path: str):
1257 # Check if it's a deep network that could benefit from BatchNorm
1258 conv_layers = [
1259 'Conv1D', 'Conv2D', 'Conv3D',
1260 'Dense', 'SeparableConv'
1261 ]
1262
1263 has_conv_layers = any(
1264 layer in call.func.as_string()
1265 for call in node.nodes_of_class(nodes.Call)
1266 for layer in conv_layers
1267 )
1268
1269 # Check if model is deep enough
1270 layer_count = len([
1271 call for call in node.nodes_of_class(nodes.Call)
1272 if any(layer in call.func.as_string()
1273 for layer in conv_layers)
1274 ])
1275
1276 has_batch_norm = any(
1277 'BatchNormalization' in call.func.as_string()
1278 for call in node.nodes_of_class(nodes.Call)
1279 )
1280
1281 # Only suggest BatchNorm for deep networks with conv layers
1282 if has_conv_layers and layer_count > 2 and not has_batch_norm:
1283 self.add_smell('TensorFlow', 'Batch Normalisation Checker', node, file_path)
1284
1285 def detect_tf_dropout(self, node: nodes.Module, file_path: str):
1286 # Check if model is complex enough to need dropout
1287 deep_layers = ['Dense', 'Conv', 'LSTM', 'GRU']
1288
1289 layer_count = len([
1290 call for call in node.nodes_of_class(nodes.Call)
1291 if any(layer in call.func.as_string()
1292 for layer in deep_layers)
1293 ])
1294
1295 has_training = any(
1296 'model.fit' in call.func.as_string() or
1297 'training=True' in call.as_string()
1298 for call in node.nodes_of_class(nodes.Call)
1299 )
1300
1301 has_dropout = any(
1302 'Dropout' in call.func.as_string()
1303 for call in node.nodes_of_class(nodes.Call)
1304 )
1305
1306 # Only suggest dropout for complex models during training
1307 if layer_count > 2 and has_training and not has_dropout:
1308 self.add_smell('TensorFlow', 'Dropout Usage Checker', node, file_path)
1309
1310 def detect_tf_augmentation(self, node: nodes.Module, file_path: str):
1311 # Check if it's a computer vision task
1312 vision_indicators = [
1313 'image', 'img', 'Conv2D', 'Conv3D',
1314 'ImageDataGenerator', 'load_img'
1315 ]
1316
1317 is_vision_task = any(
1318 indicator in node.as_string()
1319 for indicator in vision_indicators
1320 )
1321
1322 has_training = any(
1323 'model.fit' in call.func.as_string() or
1324 'train' in call.func.as_string().lower()
1325 for call in node.nodes_of_class(nodes.Call)
1326 )
1327
1328 has_augmentation = any(
1329 aug in call.func.as_string()
1330 for call in node.nodes_of_class(nodes.Call)
1331 for aug in ['ImageDataGenerator', 'RandomFlip', 'RandomRotation', 'RandomZoom']
1332 )
1333
1334 # Only suggest augmentation for vision tasks during training
1335 if is_vision_task and has_training and not has_augmentation:
1336 self.add_smell('TensorFlow', 'Data Augmentation Checker', node, file_path)
1337
1338 def detect_tf_lr_scheduler(self, node: nodes.Module, file_path: str):
1339 # Check if there's a long training process
1340 has_epochs = False
1341 for node in node.nodes_of_class(nodes.For):
1342 if 'epoch' in node.as_string().lower():
1343 try:
1344 if hasattr(node.iter, 'args') and len(node.iter.args) > 0:
1345 epoch_num = int(node.iter.args[0].value)
1346 has_epochs = epoch_num > 5 # Suggest scheduler for longer training
1347 except (AttributeError, ValueError):
1348 continue
1349
1350 has_optimizer = any(
1351 'optimizer' in call.func.as_string().lower()
1352 for call in node.nodes_of_class(nodes.Call)
1353 )
1354
1355 has_scheduler = any(
1356 'LearningRateScheduler' in call.func.as_string() or
1357 'schedules' in call.func.as_string() or
1358 'ReduceLROnPlateau' in call.func.as_string()
1359 for call in node.nodes_of_class(nodes.Call)
1360 )
1361
1362 # Only suggest scheduler for long training processes
1363 if has_epochs and has_optimizer and not has_scheduler:
1364 self.add_smell('TensorFlow', 'Learning Rate Scheduler Checker', node, file_path)
1365
1366 def detect_tf_model_evaluation(self, node: nodes.Module, file_path: str):
1367 # Check if there's a model to evaluate
1368 has_model = any(
1369 'model' in call.func.as_string() or
1370 'Sequential' in call.func.as_string() or
1371 'Model(' in call.func.as_string()
1372 for call in node.nodes_of_class(nodes.Call)
1373 )
1374
1375 has_test_data = any(
1376 data in node.as_string().lower()
1377 for data in ['test', 'val', 'valid', 'evaluation']
1378 )
1379
1380 has_evaluation = any(
1381 'evaluate' in call.func.as_string() or
1382 'predict' in call.func.as_string()
1383 for call in node.nodes_of_class(nodes.Call)
1384 )
1385
1386 # Only suggest evaluation when there's a model and test data
1387 if has_model and has_test_data and not has_evaluation:
1388 self.add_smell('TensorFlow', 'Model Evaluation Checker', node, file_path)
1389
1390 def generate_report(self) -> str:
1391 """Generate a detailed text report of all detected code smells.
1392
1393 Returns:
1394 A formatted string containing the analysis report with:
1395 - Framework-specific smell counts
1396 - Details for each smell including location and fixes
1397 - Total number of smells detected
1398 """
1399 report = "Framework-Specific Code Smell Report\n====================================\n\n"
1400 smell_counts = {}
1401 for smell in self.smells:
1402 framework = smell['framework']
1403 if framework not in smell_counts:
1404 smell_counts[framework] = {}
1405 if smell['name'] not in smell_counts[framework]:
1406 smell_counts[framework][smell['name']] = 0
1407 smell_counts[framework][smell['name']] += 1
1408
1409 report += f"Framework: {framework}\n"
1410 report += f"Smell: {smell['name']}\n"
1411 report += f"File: {smell['file_path']}\n"
1412
1413 if smell['line_number'] != 0:
1414 report += f"Line: {smell['line_number']}\n"
1415
1416 code_lines = smell['code_snippet'].strip().split('\n')
1417 if len(code_lines) <= 3:
1418 report += f"Code Snippet:\n{smell['code_snippet']}\n"
1419
1420 report += f"How to Fix: {smell['how_to_fix']}\n"
1421 report += f"Benefits: {smell['benefits']}\n"
1422 report += f"Strategies: {smell['strategies']}\n\n"
1423
1424 report += "Smell Counts:\n"
1425 for framework, counts in smell_counts.items():
1426 report += f"{framework}:\n"
1427 for smell, count in counts.items():
1428 report += f" {smell}: {count}\n"
1429 report += f"\nTotal smells detected: {len(self.smells)}"
1430 return report
1431
1432 def get_results(self) -> List[Dict[str, str]]:
1433 """Get a simplified list of detected code smells.
1434
1435 Returns:
1436 List of dictionaries containing:
1437 - framework: The ML framework
1438 - name: Name of the smell
1439 - fix: How to fix it
1440 - benefits: Benefits of fixing
1441 - location: Where it was found
1442 """
1443 return [
1444 {
1445 'framework': smell['framework'],
1446 'name': smell['name'],
1447 'fix': smell['how_to_fix'],
1448 'benefits': smell['benefits'],
1449 'location': f"Line {smell['line_number']}" if smell['line_number'] != 0 else ""
1450 }
1451 for smell in self.smells
1452 ]
1453
1454 def get_smells(self) -> Dict[str, List[Dict[str, str]]]:
1455 return {"General": [{"name": "Import Checker",
1456 "how_to_fix": "Use standard naming conventions for imported modules.",
1457 "benefits": "Improves code readability and maintainability.",
1458 "strategies": "Follow standard naming conventions (e.g., import numpy as np, import pandas as pd)."}],
1459 "Pandas": [{"name": "Unnecessary Iteration",
1460 "how_to_fix": "Use vectorized operations instead of loops.",
1461 "benefits": "Enhances performance and reduces execution time.",
1462 "strategies": "Replace loops with Pandas vectorized functions (e.g., apply, map, vectorized arithmetic operations)."},
1463 {"name": "DataFrame Iteration Modification",
1464 "how_to_fix": "Avoid modifying DataFrame during iteration.",
1465 "benefits": "Prevents unexpected behaviour and potential data corruption.",
1466 "strategies": "Use temporary variables or vectorized operations for modifications."},
1467 {"name": "Chain Indexing",
1468 "how_to_fix": "Use single indexing or .loc[], .iloc[] methods.",
1469 "benefits": "Enhances code readability and prevents performance issues.",
1470 "strategies": "Use .loc[] or .iloc[] for DataFrame indexing instead of chained indexing."},
1471 {"name": "Datatype Checker",
1472 "how_to_fix": "Set data types explicitly when importing data.",
1473 "benefits": "Ensures correct data format and reduces memory usage.",
1474 "strategies": "Use dtype parameter in Pandas read functions (e.g., pd.read_csv)."},
1475 {"name": "Column Selection Checker",
1476 "how_to_fix": "Select necessary columns after importing DataFrame.",
1477 "benefits": "Clarifies data usage and improves performance.",
1478 "strategies": "Use column selection methods (e.g., df[['col1', 'col2']])."},
1479 {"name": "Merge Parameter Checker",
1480 "how_to_fix": "Specify how, on, and validate parameters in merge operations.",
1481 "benefits": "Ensures accurate data merging and prevents data loss.",
1482 "strategies": "Use appropriate parameters in pd.merge function."},
1483 {"name": "InPlace Checker",
1484 "how_to_fix": "Assign operations to a new DataFrame variable.",
1485 "benefits": "Prevents data loss and improves code clarity.",
1486 "strategies": "Assign operation results to a new variable instead of using inplace=True."},
1487 {"name": "DataFrame Conversion Checker",
1488 "how_to_fix": "Use .to_numpy() instead of .values.",
1489 "benefits": "Ensures future compatibility and avoids unexpected behaviour.",
1490 "strategies": "Replace .values with .to_numpy() in Pandas DataFrame conversion."}],
1491 "NumPy": [{"name": "NaN Equality Checker",
1492 "how_to_fix": "Use np.isnan() to check for NaN values.",
1493 "benefits": "Ensures accurate data handling and avoids logical errors.",
1494 "strategies": "Replace == np.nan with np.isnan()."},
1495 {"name": "Randomness Control Checker",
1496 "how_to_fix": "Use np.random.seed() for reproducibility.",
1497 "benefits": "Enables reproducible results and debugging.",
1498 "strategies": "Set random seed using np.random.seed(seed_value)."},
1499 {"name": "Array Creation Efficiency",
1500 "how_to_fix": "Specify dtype when creating arrays.",
1501 "benefits": "Improves memory efficiency and avoids unnecessary conversions.",
1502 "strategies": "Use dtype parameter in np.array, np.zeros, np.ones, np.empty functions."},
1503 {"name": "Inefficient Operations",
1504 "how_to_fix": "Optimize numerical operations.",
1505 "benefits": "Improves performance and reduces execution time.",
1506 "strategies": "Use vectorized functions instead of loops for element-wise operations."},
1507 {"name": "Dtype Consistency",
1508 "how_to_fix": "Ensure consistent data types in operations.",
1509 "benefits": "Improves accuracy and avoids logical errors.",
1510 "strategies": "Use consistent data types in np.sum, np.mean, np.max, np.min functions."},
1511 {"name": "Broadcasting Risk",
1512 "how_to_fix": "Specify axis in operations between arrays.",
1513 "benefits": "Improves performance and avoids broadcasting issues.",
1514 "strategies": "Use axis parameter in np.reshape, np.transpose functions."},
1515 {"name": "Copy-View Confusion",
1516 "how_to_fix": "Explicitly copy arrays before slicing.",
1517 "benefits": "Improves performance and avoids unexpected behaviour.",
1518 "strategies": "Use .copy() method or np.copy function before slicing."},
1519 {"name": "Missing Axis Specification",
1520 "how_to_fix": "Specify axis in array operations.",
1521 "benefits": "Improves performance and avoids unexpected behaviour.",
1522 "strategies": "Use axis parameter in np.sum, np.mean, np.max, np.min, np.argmax, np.argmin, np.any, np.all functions."}],
1523 "ScikitLearn": [{"name": "Scaler Missing Checker",
1524 "how_to_fix": "Apply scaling before scaling-sensitive operations.",
1525 "benefits": "Improves model performance and accuracy.",
1526 "strategies": "Use StandardScaler, MinMaxScaler, etc., before applying PCA, SVM, etc."},
1527 {"name": "Pipeline Checker",
1528 "how_to_fix": "Use Pipelines for all scikit-learn estimators.",
1529 "benefits": "Prevents data leakage and ensures correct model evaluation.",
1530 "strategies": "Implement Pipeline from sklearn.pipeline for preprocessing and model fitting."},
1531 {"name": "Cross Validation Checker",
1532 "how_to_fix": "Implement cross-validation techniques for robust model evaluation.",
1533 "benefits": "Enhances model performance and reduces overfitting.",
1534 "strategies": "Use cross_val_score, KFold, or other cross-validation methods."},
1535 {"name": "Randomness Control Checker",
1536 "how_to_fix": "Set random_state in estimators for reproducibility.",
1537 "benefits": "Ensures reproducible results and consistent model behaviour.",
1538 "strategies": "Set random_state to a fixed value in scikit-learn estimators."},
1539 {"name": "Verbose Mode Checker",
1540 "how_to_fix": "Enable verbose mode for long training processes.",
1541 "benefits": "Provides better insights into model training progress.",
1542 "strategies": "Set verbose=True in scikit-learn estimators with long training times."},
1543 {"name": "Dependent Threshold Checker",
1544 "how_to_fix": "Use threshold-independent metrics alongside threshold-dependent ones.",
1545 "benefits": "Provides a comprehensive evaluation of model performance.",
1546 "strategies": "Include metrics like ROC AUC score alongside accuracy or F1-score."},
1547 {"name": "Unit Testing Checker",
1548 "how_to_fix": "Write unit tests for data processing and model components.",
1549 "benefits": "Ensures code reliability and prevents bugs.",
1550 "strategies": "Use unittest or pytest to write and run tests for individual components."},
1551 {"name": "Data Leakage Checker",
1552 "how_to_fix": "Split data before any preprocessing or feature engineering.",
1553 "benefits": "Prevents data leakage and ensures valid model evaluation.",
1554 "strategies": "Use train_test_split before any data preprocessing steps."},
1555 {"name": "Exception Handling Checker",
1556 "how_to_fix": "Implement proper exception handling for model operations.",
1557 "benefits": "Improves code robustness and error reporting.",
1558 "strategies": "Use try-except blocks to handle potential exceptions in model training and prediction."}],
1559 "PyTorch": [{"name": "Randomness Control Checker",
1560 "how_to_fix": "Set random seed using torch.manual_seed() for reproducibility.",
1561 "benefits": "Ensures reproducible results across different runs.",
1562 "strategies": "Add torch.manual_seed(seed_value) at the start of your script."},
1563 {"name": "Deterministic Algorithm Usage Checker",
1564 "how_to_fix": "Enable deterministic algorithms using torch.use_deterministic_algorithms(True).",
1565 "benefits": "Ensures consistent results across different hardware and runs.",
1566 "strategies": "Set torch.use_deterministic_algorithms(True) and handle any required adjustments."},
1567 {"name": "Randomness Control Checker (PyTorch-Dataloader)",
1568 "how_to_fix": "Set worker_init_fn and generator in DataLoader for reproducible data loading.",
1569 "benefits": "Ensures consistent data loading across different runs.",
1570 "strategies": "Configure worker_init_fn and generator parameters in DataLoader initialization."},
1571 {"name": "Mask Missing Checker",
1572 "how_to_fix": "Use appropriate masking when dealing with log operations.",
1573 "benefits": "Prevents numerical errors and improves model stability.",
1574 "strategies": "Apply masks before log operations to handle invalid values."},
1575 {"name": "Net Forward Checker",
1576 "how_to_fix": "Use model(input) instead of model.forward(input).",
1577 "benefits": "Follows PyTorch best practices and handles hooks properly.",
1578 "strategies": "Replace direct forward() calls with the recommended calling syntax."},
1579 {"name": "Gradient Clear Checker",
1580 "how_to_fix": "Clear gradients before each backward pass using optimizer.zero_grad().",
1581 "benefits": "Prevents gradient accumulation and ensures correct updates.",
1582 "strategies": "Add optimizer.zero_grad() before loss.backward() in training loop."},
1583 {"name": "Batch Normalisation Checker",
1584 "how_to_fix": "Include BatchNorm layers in your model architecture.",
1585 "benefits": "Improves training stability and model convergence.",
1586 "strategies": "Add torch.nn.BatchNorm layers after convolutional or linear layers."},
1587 {"name": "Dropout Usage Checker",
1588 "how_to_fix": "Include Dropout layers for regularization.",
1589 "benefits": "Reduces overfitting and improves model generalization.",
1590 "strategies": "Add torch.nn.Dropout layers in your model architecture."},
1591 {"name": "Data Augmentation Checker",
1592 "how_to_fix": "Implement data augmentation using torchvision.transforms.",
1593 "benefits": "Improves model robustness and reduces overfitting.",
1594 "strategies": "Use torchvision.transforms to apply various augmentation techniques."},
1595 {"name": "Learning Rate Scheduler Checker",
1596 "how_to_fix": "Implement learning rate scheduling.",
1597 "benefits": "Improves training convergence and final model performance.",
1598 "strategies": "Use torch.optim.lr_scheduler for dynamic learning rate adjustment."},
1599 {"name": "Logging Checker",
1600 "how_to_fix": "Implement proper logging using tensorboard or similar tools.",
1601 "benefits": "Enables better experiment tracking and debugging.",
1602 "strategies": "Use tensorboardX or torch.utils.tensorboard for logging."},
1603 {"name": "Model Evaluation Checker",
1604 "how_to_fix": "Set model to evaluation mode using model.eval().",
1605 "benefits": "Ensures correct behavior of layers like BatchNorm and Dropout during inference.",
1606 "strategies": "Call model.eval() before validation/testing and model.train() before training."}],
1607 "TensorFlow": [{"name": "Randomness Control Checker",
1608 "how_to_fix": "Set random seed using tf.random.set_seed().",
1609 "benefits": "Ensures reproducible results across different runs.",
1610 "strategies": "Add tf.random.set_seed(seed_value) at the start of your script."},
1611 {"name": "Early Stopping Checker",
1612 "how_to_fix": "Implement early stopping using tf.keras.callbacks.EarlyStopping.",
1613 "benefits": "Prevents overfitting and reduces unnecessary training time.",
1614 "strategies": "Add EarlyStopping callback to model.fit()."},
1615 {"name": "Checkpointing Checker",
1616 "how_to_fix": "Implement model checkpointing using tf.keras.callbacks.ModelCheckpoint.",
1617 "benefits": "Enables model recovery and saves best performing models.",
1618 "strategies": "Add ModelCheckpoint callback to model.fit()."},
1619 {"name": "Memory Release Checker",
1620 "how_to_fix": "Clear session after model creation/training using tf.keras.backend.clear_session().",
1621 "benefits": "Prevents memory leaks and reduces resource usage.",
1622 "strategies": "Call clear_session() after completing major operations."},
1623 {"name": "Mask Missing Checker",
1624 "how_to_fix": "Use appropriate masking for log operations.",
1625 "benefits": "Prevents numerical errors and improves model stability.",
1626 "strategies": "Apply masks before log operations to handle invalid values."},
1627 {"name": "Tensor Array Checker",
1628 "how_to_fix": "Use tf.TensorArray for dynamic tensor operations.",
1629 "benefits": "Improves memory efficiency for dynamic computations.",
1630 "strategies": "Replace Python lists with tf.TensorArray for growing tensors."},
1631 {"name": "Dependent Threshold Checker",
1632 "how_to_fix": "Use threshold-independent metrics like AUC.",
1633 "benefits": "Provides more robust model evaluation.",
1634 "strategies": "Include tf.keras.metrics.AUC in your model metrics."},
1635 {"name": "Logging Checker",
1636 "how_to_fix": "Implement TensorBoard logging.",
1637 "benefits": "Enables better experiment tracking and visualization.",
1638 "strategies": "Use tf.summary or TensorBoard callbacks for logging."},
1639 {"name": "Batch Normalisation Checker",
1640 "how_to_fix": "Include BatchNormalization layers in your model.",
1641 "benefits": "Improves training stability and model convergence.",
1642 "strategies": "Add tf.keras.layers.BatchNormalization layers to your model."},
1643 {"name": "Dropout Usage Checker",
1644 "how_to_fix": "Include Dropout layers for regularization.",
1645 "benefits": "Reduces overfitting and improves generalization.",
1646 "strategies": "Add tf.keras.layers.Dropout layers to your model."},
1647 {"name": "Data Augmentation Checker",
1648 "how_to_fix": "Implement data augmentation using ImageDataGenerator.",
1649 "benefits": "Improves model robustness and reduces overfitting.",
1650 "strategies": "Use tf.keras.preprocessing.image.ImageDataGenerator for augmentation."},
1651 {"name": "Learning Rate Scheduler Checker",
1652 "how_to_fix": "Implement learning rate scheduling.",
1653 "benefits": "Improves training convergence and final model performance.",
1654 "strategies": "Use LearningRateScheduler callback or custom scheduling."},
1655 {"name": "Model Evaluation Checker",
1656 "how_to_fix": "Evaluate model performance using model.evaluate().",
1657 "benefits": "Provides standardized model evaluation.",
1658 "strategies": "Use model.evaluate() on validation/test data."},
1659 {"name": "Unit Testing Checker",
1660 "how_to_fix": "Implement unit tests using tf.test.TestCase.",
1661 "benefits": "Ensures code reliability and prevents regressions.",
1662 "strategies": "Create test classes inheriting from tf.test.TestCase."},
1663 {"name": "Exception Handling Checker",
1664 "how_to_fix": "Implement proper exception handling.",
1665 "benefits": "Improves code robustness and error reporting.",
1666 "strategies": "Use try-except blocks to handle TensorFlow-specific exceptions."}]}