# Source code for bittensor.utils.registration
# The MIT License (MIT)
# Copyright © 2024 Opentensor Foundation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import functools
import os
from typing import TYPE_CHECKING
import numpy
from bittensor.utils.btlogging import logging
def use_torch() -> bool:
    """Force the use of torch over numpy for certain operations.

    The environment is consulted on every call, so a change to ``USE_TORCH``
    made while the process is running is picked up by the next call.

    Returns:
        bool: ``True`` only when the ``USE_TORCH`` environment variable is exactly ``"1"``;
            ``False`` when it is unset or holds any other value.
    """
    return os.getenv("USE_TORCH") == "1"
def legacy_torch_api_compat(func):
    """
    Convert function operating on numpy Input&Output to legacy torch Input&Output API if `use_torch()` is True.

    When `use_torch()` is False the wrapped function is called with its arguments untouched. Otherwise every
    positional and keyword argument that is a torch tensor is detached, moved to the CPU and converted to a
    numpy array before the call, and a numpy array return value is converted back with `torch.from_numpy`.
    Return values of any other type are passed through unchanged.

    Args:
        func (function): Function with numpy Input/Output to be decorated.
    Returns:
        decorated (function): Decorated function.
    """

    def _to_numpy(value):
        # `torch.Tensor` is looked up only when a value is actually inspected, so a call
        # without arguments does not touch the torch proxy before `func` runs.
        if isinstance(value, torch.Tensor):
            # detach() first: Tensor.numpy() refuses tensors that require grad.
            return value.detach().cpu().numpy()
        return value

    @functools.wraps(func)
    def decorated(*args, **kwargs):
        # Read the switch once so the input and output conversions always agree,
        # even if the environment changes while `func` is running.
        if not use_torch():
            return func(*args, **kwargs)

        # if argument is a Torch tensor, convert it to numpy
        args = [_to_numpy(arg) for arg in args]
        kwargs = {key: _to_numpy(value) for key, value in kwargs.items()}
        ret = func(*args, **kwargs)
        # if return value is a numpy array, convert it to Torch tensor
        if isinstance(ret, numpy.ndarray):
            ret = torch.from_numpy(ret)
        return ret

    return decorated
@functools.cache
def _get_real_torch():
    """Import torch on first use and return the module, or ``None`` if the import fails.

    The result is memoized by `functools.cache`, so the import is attempted at most once
    and every later call returns the same object.
    """
    try:
        import torch as torch_module
    except ImportError:
        # torch is optional; callers decide how to react to its absence.
        return None
    return torch_module
def log_no_torch_error():
    """Log an error explaining that torch is required and how to install and enable it.

    The trailing ``{command}`` is emitted literally: the message is a plain string and is
    not formatted here.
    """
    message = (
        "This command requires torch. You can install torch for bittensor"
        ' with `pip install bittensor[torch]` or `pip install ".[torch]"`'
        " if installing from source, and then run the command with USE_TORCH=1 {command}"
    )
    logging.error(message)
class LazyLoadedTorch:
    """A lazy-loading proxy for the torch module.

    Attribute access that is not satisfied by the proxy itself is forwarded to the real
    torch module, which is resolved through `_get_real_torch()` at access time.
    """

    def __bool__(self):
        """Return ``True`` when the real torch module could be resolved, ``False`` otherwise."""
        real_torch = _get_real_torch()
        return bool(real_torch)

    def __getattr__(self, name):
        """Forward the lookup of ``name`` to the real torch module.

        Raises:
            ImportError: If torch could not be resolved; an explanatory error is logged first.
        """
        real_torch = _get_real_torch()
        if not real_torch:
            log_no_torch_error()
            raise ImportError("torch not installed")
        return getattr(real_torch, name)
# At runtime `torch` is bound to the lazy proxy, so importing this module does not
# import torch by itself; static type checkers see the real module instead.
if not TYPE_CHECKING:
    torch = LazyLoadedTorch()
else:
    import torch