Open
Hello 👋
I am having issues using mup.get_coord_data because some of my modules return dataclass objects. Currently, only dicts, lists, tuples, and tensors are supported. It would be great, and fairly easy, to also support dataclasses.
I think the only code that would need to change is:
Lines 129 to 148 in 1981497
```python
def get_stat(d, x, fdict):
    if isinstance(x, (tuple, list)):
        for i, _x in enumerate(x):
            _d = copy(d)
            _d['module'] += f'[{i}]'
            get_stat(_d, _x, fdict)
    elif isinstance(x, dict):
        for name, _x in x.items():
            _d = copy(d)
            _d['module'] += f'[{name}]'
            get_stat(_d, _x, fdict)
    elif isinstance(x, torch.Tensor):
        _d = copy(d)
        for fname, f in fdict.items():
            _d[fname] = f(x).item()
        records.append(_d)
    elif x is None:
        pass
    else:
        raise NotImplementedError(f'Unexpected output type: {type(x)}')
```
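To illustrate the idea, here is a minimal sketch of the same traversal with a dataclass branch added. It is written as a standalone generator so that it runs without torch. The names `iter_outputs` and `Output`, and the `.field` path suffix, are my own assumptions for the example and not part of mup.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Iterator, Tuple


def iter_outputs(x: Any, path: str = "") -> Iterator[Tuple[str, Any]]:
    """Yield (path, leaf) pairs from a nested module output.

    Follows the same recursion as get_stat, plus one extra branch for
    dataclass instances. Leaves are yielded untouched, so the caller
    decides how to handle tensors, None, or unexpected types.
    """
    if isinstance(x, (tuple, list)):
        for i, _x in enumerate(x):
            yield from iter_outputs(_x, f"{path}[{i}]")
    elif isinstance(x, dict):
        for name, _x in x.items():
            yield from iter_outputs(_x, f"{path}[{name}]")
    elif is_dataclass(x) and not isinstance(x, type):
        # Hypothetical new branch: is_dataclass() is also true for the class
        # itself, so the second check restricts this to instances.
        # Fields are read with getattr rather than dataclasses.asdict,
        # which would deep-copy every value.
        for field in fields(x):
            yield from iter_outputs(getattr(x, field.name), f"{path}.{field.name}")
    else:
        yield path, x


@dataclass
class Output:  # stand-in for a module's return type
    logits: Any
    hidden: Any = None


if __name__ == "__main__":
    out = Output(logits=1.0, hidden=[2.0, {"k": 3.0}])
    for name, leaf in iter_outputs(out, "layer"):
        print(name, leaf)
```

In get_stat itself, the equivalent change would presumably be one more `elif` that copies `d`, appends the field name to `_d['module']`, and recurses on each field value, leaving the tensor, `None`, and error branches as they are.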
I can do a PR for that.