Skip to content

Commit

Permalink
Fix type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
sosthene-nitrokey committed Feb 29, 2024
1 parent ddfdd3f commit b0b5d38
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 38 deletions.
55 changes: 35 additions & 20 deletions pynitrokey/cli/nk3/piv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ def piv() -> None:
)
def admin_auth(admin_key: str) -> None:
try:
admin_key: bytes = bytearray.fromhex(admin_key)
admin_key_bytes = bytearray.fromhex(admin_key)
except ValueError:
local_critical(
"Key is expected to be an hexadecimal string",
support_hint=False,
)

device = PivApp()
device.authenticate_admin(admin_key)
device.authenticate_admin(admin_key_bytes)
local_print("Authenticated successfully")


Expand All @@ -52,15 +52,15 @@ def admin_auth(admin_key: str) -> None:
)
def init(admin_key: str) -> None:
try:
admin_key: bytes = bytearray.fromhex(admin_key)
admin_key_bytes = bytearray.fromhex(admin_key)
except ValueError:
local_critical(
"Key is expected to be an hexadecimal string",
support_hint=False,
)

device = PivApp()
device.authenticate_admin(admin_key)
device.authenticate_admin(admin_key_bytes)
guid = device.init()
local_print("Device intialized successfully")
local_print(f"GUID: {guid.hex().upper()}")
Expand All @@ -83,9 +83,11 @@ def info() -> None:
if not printed_head:
local_print("Keys:")
printed_head = True
cert = cryptography.x509.load_der_x509_certificate(cert)
parsed_cert = cryptography.x509.load_der_x509_certificate(cert)
local_print(f" {key}")
local_print(f" algorithm: {cert.signature_algorithm_oid._name}")
local_print(
f" algorithm: {parsed_cert.signature_algorithm_oid._name}"
)
if not printed_head:
local_print("No certificate found")
pass
Expand All @@ -103,17 +105,17 @@ def info() -> None:
)
def change_admin_key(current_admin_key: str, new_admin_key: str) -> None:
try:
current_admin_key: bytes = bytearray.fromhex(current_admin_key)
new_admin_key: bytes = bytearray.fromhex(new_admin_key)
current_admin_key_bytes = bytearray.fromhex(current_admin_key)
new_admin_key_bytes = bytearray.fromhex(new_admin_key)
except ValueError:
local_critical(
"Key is expected to be an hexadecimal string",
support_hint=False,
)

device = PivApp()
device.authenticate_admin(current_admin_key)
device.set_admin_key(new_admin_key)
device.authenticate_admin(current_admin_key_bytes)
device.set_admin_key(new_admin_key_bytes)
local_print("Changed key successfully")


Expand Down Expand Up @@ -168,7 +170,7 @@ def change_puk(current_puk: str, new_puk: str) -> None:
prompt="Enter the new PIN",
hide_input=True,
)
def reset_retry_counter(puk: str, new_pin: str):
def reset_retry_counter(puk: str, new_pin: str) -> None:
device = PivApp()
device.reset_retry_counter(puk, new_pin)
local_print("Unlocked PIN successfully")
Expand Down Expand Up @@ -294,7 +296,7 @@ def generate_key(
out_file: str,
) -> None:
try:
admin_key: bytes = bytearray.fromhex(admin_key)
admin_key_bytes = bytearray.fromhex(admin_key)
except ValueError:
local_critical(
"Key is expected to be an hexadecimal string",
Expand All @@ -304,7 +306,7 @@ def generate_key(
key_ref = int(key_hex, 16)

device = PivApp()
device.authenticate_admin(admin_key)
device.authenticate_admin(admin_key_bytes)
device.login(pin)

if algo == "rsa2048":
Expand All @@ -326,9 +328,13 @@ def generate_key(
data = Tlv.parse(find_by_id(0x7F49, data), recursive=False)

if algo == "nistp256":
key = find_by_id(0x86, data)[1:]
public_x = int.from_bytes(key[:32], byteorder="big", signed=False)
public_y = int.from_bytes(key[32:], byteorder="big", signed=False)
key_data = find_by_id(0x86, data)
if key_data is None:
local_critical("Device did not send public key data")
return
key_data = key_data[1:]
public_x = int.from_bytes(key_data[:32], byteorder="big", signed=False)
public_y = int.from_bytes(key_data[32:], byteorder="big", signed=False)
public_numbers = ec.EllipticCurvePublicNumbers(
public_x,
public_y,
Expand All @@ -340,8 +346,14 @@ def generate_key(
serialization.PublicFormat.SubjectPublicKeyInfo,
)
elif algo == "rsa2048":
modulus = int.from_bytes(find_by_id(0x81, data), byteorder="big", signed=False)
exponent = int.from_bytes(find_by_id(0x82, data), byteorder="big", signed=False)
modulus_data = find_by_id(0x81, data)
exponent_data = find_by_id(0x82, data)
if modulus_data is None or exponent_data is None:
local_critical("Device did not send public key data")
return

modulus = int.from_bytes(modulus_data, byteorder="big", signed=False)
exponent = int.from_bytes(exponent_data, byteorder="big", signed=False)
public_numbers = rsa.RSAPublicNumbers(exponent, modulus)
public_key = public_numbers.public_key()
public_key_der = public_key.public_bytes(
Expand All @@ -353,6 +365,9 @@ def generate_key(

public_key_info = PublicKeyInfo.load(public_key_der, strict=True)

if domain_component is None:
domain_component = []

if subject_name is None:
rdns = []
else:
Expand Down Expand Up @@ -552,15 +567,15 @@ def generate_key(
)
def write_certificate(admin_key: str, format: str, key: str, path: str) -> None:
try:
admin_key: bytes = bytearray.fromhex(admin_key)
admin_key_bytes: bytes = bytearray.fromhex(admin_key)
except ValueError:
local_critical(
"Key is expected to be an hexadecimal string",
support_hint=False,
)

device = PivApp()
device.authenticate_admin(admin_key)
device.authenticate_admin(admin_key_bytes)

with click.open_file(path, mode="rb") as f:
cert_bytes = f.read()
Expand Down
59 changes: 41 additions & 18 deletions pynitrokey/nk3/piv_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Optional, Sequence, Union

from ber_tlv.tlv import Tlv
from cryptography.hazmat.primitives import hashes
Expand All @@ -19,6 +19,7 @@ def find_by_id(tag: int, data: Sequence[tuple[int, bytes]]) -> Optional[bytes]:
for t, b in data:
if t == tag:
return b
return None


# size is in bytes
Expand All @@ -41,7 +42,7 @@ class StatusError(Exception):
def __init__(self, value: int):
self.value = value

def __str__(self):
def __str__(self) -> str:
return f"{hex(self.value)}"


Expand Down Expand Up @@ -179,7 +180,9 @@ def _send_receive_inner(self, data: bytes, log_info: str = "") -> bytes:
def authenticate_admin(self, admin_key: bytes) -> None:

if len(admin_key) == 24:
algorithm = algorithms.TripleDES(admin_key)
algorithm: Union[
algorithms.TripleDES, algorithms.AES128, algorithms.AES256
] = algorithms.TripleDES(admin_key)
# algo = "tdes"
algo_byte = 0x03
expected_len = 8
Expand Down Expand Up @@ -209,6 +212,10 @@ def authenticate_admin(self, admin_key: bytes) -> None:
),
)

if challenge is None:
local_critical("Failed to get authentication challenge from the device")
return

# challenge = decoded.first_by_id(0x7C).data.first_by_id(0x80).data
if len(challenge) != expected_len:
local_critical("Got unexpected authentication challenge length")
Expand Down Expand Up @@ -263,32 +270,32 @@ def encode_pin(self, pin: str) -> bytes:
body += bytes([0xFF for i in range(8 - len(body))])
return body

def login(self, pin: str):
def login(self, pin: str) -> None:
body = self.encode_pin(pin)
self.send_receive(0x20, 0x00, 0x80, body)

def change_pin(self, old_pin: str, new_pin: str):
def change_pin(self, old_pin: str, new_pin: str) -> None:
body = self.encode_pin(old_pin) + self.encode_pin(new_pin)
self.send_receive(0x24, 0, 0x80, body)

def change_puk(self, old_puk: str, new_puk: str):
old_puk = old_puk.encode("utf-8")
new_puk = new_puk.encode("utf-8")
if len(old_puk) != 8 or len(new_puk) != 8:
def change_puk(self, old_puk: str, new_puk: str) -> None:
old_puk_bytes = old_puk.encode("utf-8")
new_puk_bytes = new_puk.encode("utf-8")
if len(old_puk_bytes) != 8 or len(new_puk) != 8:
local_critical("PUK must be 8 bytes long", support_hint=False)
body = old_puk + new_puk
body = old_puk_bytes + new_puk_bytes
self.send_receive(0x24, 0, 0x81, body)

def reset_retry_counter(self, puk, new_pin):
puk = puk.encode("utf-8")
def reset_retry_counter(self, puk: str, new_pin: str) -> None:
puk_bytes = puk.encode("utf-8")

if len(puk) != 8:
if len(puk_bytes) != 8:
local_critical("PUK must be 8 bytes long", support_hint=False)

body = puk + self.encode_pin(new_pin)
body = puk_bytes + self.encode_pin(new_pin)
self.send_receive(0x2C, 0, 0x80, body)

def factory_reset(self):
def factory_reset(self) -> None:
self.send_receive(0xFB, 0, 0)

def sign_p256(self, data: bytes, key: int) -> bytes:
Expand All @@ -305,13 +312,22 @@ def sign_rsa2048(self, data: bytes, key: int) -> bytes:
def raw_sign(self, payload: bytes, key: int, algo: int) -> bytes:
body = Tlv.build({0x7C: {0x81: payload, 0x82: b""}})
result = self.send_receive(0x87, algo, key, body)
return find_by_id(

signature = find_by_id(
0x82,
Tlv.parse(
find_by_id(0x7C, Tlv.parse(result, recursive=False)), recursive=False
),
)

if signature is None:
local_critical("Failed to get signature from device")
# Satisfy the type checker.
# local_critical raises always raises an error
return b""

return signature

def init(self) -> bytes:
# Template for card capabilities with nothing but a random ID
template_begin = bytearray.fromhex("f015a000000116")
Expand Down Expand Up @@ -353,7 +369,14 @@ def guid(self) -> bytes:
payload = Tlv.build({0x5C: bytes(bytearray.fromhex("5FC102"))})
chuid = self.send_receive(0xCB, 0x3F, 0xFF, payload)

return find_by_id(0x34, Tlv.parse(find_by_id(0x53, Tlv.parse(chuid))))
chuid_data = find_by_id(0x34, Tlv.parse(find_by_id(0x53, Tlv.parse(chuid))))
if chuid_data is None:
local_critical("Failed to get chuid from device")
# Satisfy the type checker.
# local_critical raises always raises an error
return b""

return chuid_data

def cert(self, container_id: bytes) -> Optional[bytes]:
payload = Tlv.build({0x5C: container_id})
Expand Down Expand Up @@ -381,4 +404,4 @@ def cert(self, container_id: bytes) -> Optional[bytes]:
if e.value == 0x6A82:
return None
else:
raise ValueError(f"{e.value.hex()}, Received error")
raise ValueError(f"{hex(e.value)}, Received error")

0 comments on commit b0b5d38

Please sign in to comment.